Day4-5 (10.28-29, 목-금)
Baseline 기준으로 predict 하는 흐름 이해하기
1. 코드 흐름 이해하기
Retriever 데이터 가져오기
def run_sparse_retrieval(
tokenize_fn: Callable[[str], List[str]],
datasets: DatasetDict,
training_args: TrainingArguments,
data_args: DataTrainingArguments,
data_path: str = "../data",
context_path: str = "wikipedia_documents.json",
) -> DatasetDict:
df = retriever.retrieve(datasets["validation"], topk=data_args.top_k_retrieval)
doc_scores, doc_indices = self.get_relevant_doc_bulk(
query_or_dataset["question"], k=topk
)
get_relevant_doc_bulk
로 부터 query와 유사도가 높은 document를 추출한다.
for i in range(result.shape[0]):
sorted_result = np.argsort(result[i, :])[::-1]
doc_scores.append(result[i, :][sorted_result].tolist()[:k])
doc_indices.append(sorted_result.tolist()[:k])

보다시피 각 query마다 k(5)개의 index와 각각의 score를 리턴하는 것을 확인할 수 있다.
그 후 5개의 문장을 띄어쓰기 기준으로 리턴하고 있다.
tmp = {
# Query와 해당 id를 반환합니다.
"question": example["question"],
"id": example["id"],
# Retrieve한 Passage의 id, context를 반환합니다.
"context_id": doc_indices[idx],
"context": " ".join(
[self.contexts[pid] for pid in doc_indices[idx]]
),
}
음... 띄어쓰기로 그냥 이어서 Reader에게 전달하고, 리더는 다시 max_length 기준으로 잘라서 추론하게된다. 음... 결국 document를 전체를 한 context를 볼게 아니라 잘라서 하나의 context라고 봐도 무관하겠구나 라는 생각이 드는 부분이었다.
Last updated
Was this helpful?