Day2 (9.29, Wed)
from tqdm import tqdm
from transformers import PreTrainedTokenizerFast, BartForConditionalGeneration
import torch
import pandas as pd
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Load the model and tokenizer
tokenizer = PreTrainedTokenizerFast.from_pretrained("ainize/kobart-news")
model = BartForConditionalGeneration.from_pretrained("ainize/kobart-news")
model.to(device)
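
Before looping over the full test set, a quick sanity check (the sample sentence below is made up, not from the competition data) confirms the model generates on the chosen device:

# Hypothetical sample sentence, used only to verify the pipeline end to end
sample_text = "정부는 오늘 새로운 경제 정책을 발표했다."
sample_ids = tokenizer.encode(sample_text, return_tensors="pt").to(device)
with torch.no_grad():
    sample_out = model.generate(input_ids=sample_ids, num_beams=4, max_length=64)
print(tokenizer.decode(sample_out[0], skip_special_tokens=True))
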
# Load the test set and the submission template
test = pd.read_csv('/content/drive/MyDrive/boostcamp/dacon/aihub-2021/test_data.csv')
text_list = list(test.text)
submission_csv = pd.read_csv('/content/drive/MyDrive/boostcamp/dacon/aihub-2021/sample_submission.csv')
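
The four indices skipped below were presumably found by trial and error; judging from the commented-out truncation line, those rows are likely longer than the encoder can handle. A hedged way to spot them up front is to count tokens per row against the limit read from the model config:

# Sketch: list rows whose token count exceeds the encoder's position limit
limit = model.config.max_position_embeddings  # typically 1024 for BART variants
too_long = [i for i, t in enumerate(text_list) if len(tokenizer.encode(t)) > limit]
print(f'{len(too_long)} rows exceed {limit} tokens: {too_long[:10]}')
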
error_cnt = 0
for index, input_text in enumerate(tqdm(text_list)):
    # Skip the four rows that raised errors (likely over-length inputs)
    if index in [4083, 4913, 5525, 8788]:
        continue
    try:
        # input_text = input_text[:2300]
        input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
        summary_text_ids = model.generate(
            input_ids=input_ids,
            bos_token_id=model.config.bos_token_id,
            eos_token_id=model.config.eos_token_id,
            length_penalty=2.0,
            max_length=142,
            min_length=56,
            num_beams=4,
        )
        # .loc avoids pandas' chained-assignment warning
        submission_csv.loc[index, 'summary'] = tokenizer.decode(summary_text_ids[0], skip_special_tokens=True)
    except Exception:
        error_cnt += 1
        print(f'index error {index}')
# index=False keeps the file in the sample_submission column layout
submission_csv.to_csv('/content/drive/MyDrive/boostcamp/dacon/aihub-2021/submission_1.csv', index=False)
print(f'Job done ({error_cnt} rows failed)')
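
An alternative sketch (not what was submitted here) truncates at tokenization time instead of skipping the failing rows, so every row gets a summary:

def encode_truncated(text):
    # Truncate to the encoder's maximum positions so generation cannot fail on length
    max_len = model.config.max_position_embeddings
    return tokenizer.encode(text, return_tensors="pt",
                            truncation=True, max_length=max_len).to(device)

Swapping this in for the tokenizer.encode call inside the loop would make the index skip list unnecessary.
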
Submission
1st submission

2nd submission
