Day4-5 (10.28-29, 목-금)

  • Baseline 기준으로 predict 하는 흐름 이해하기

1. 코드 흐름 이해하기

Retriever 데이터 가져오기

def run_sparse_retrieval(
    tokenize_fn: Callable[[str], List[str]],
    datasets: DatasetDict,
    training_args: TrainingArguments,
    data_args: DataTrainingArguments,
    data_path: str = "../data",
    context_path: str = "wikipedia_documents.json",
) -> DatasetDict:
df = retriever.retrieve(datasets["validation"], topk=data_args.top_k_retrieval)
doc_scores, doc_indices = self.get_relevant_doc_bulk(
                    query_or_dataset["question"], k=topk
                )
  • get_relevant_doc_bulk 로 부터 query와 유사도가 높은 document를 추출한다.

for i in range(result.shape[0]):
            sorted_result = np.argsort(result[i, :])[::-1]
            doc_scores.append(result[i, :][sorted_result].tolist()[:k])
            doc_indices.append(sorted_result.tolist()[:k])
  • 보다시피 각 query마다 k(5)개의 index와 각각의 score를 리턴하는 것을 확인할 수 있다.

  • 그 후 5개의 문장을 띄어쓰기 기준으로 리턴하고 있다.

tmp = {
    # Query와 해당 id를 반환합니다.
    "question": example["question"],
    "id": example["id"],
    # Retrieve한 Passage의 id, context를 반환합니다.
    "context_id": doc_indices[idx],
    "context": " ".join(
        [self.contexts[pid] for pid in doc_indices[idx]]
    ),
}

음... 띄어쓰기로 그냥 이어서 Reader에게 전달하고, 리더는 다시 max_length 기준으로 잘라서 추론하게된다. 음... 결국 document를 전체를 한 context를 볼게 아니라 잘라서 하나의 context라고 봐도 무관하겠구나 라는 생각이 드는 부분이었다.

Last updated