Retrieval-augmented generation for knowledge-intensive nlp tasks
1. Abstract
1.1. LM의 특징과 한계
특징
지식을 매개변수에 잘 저장한다.
Downstream NLP task에 대해 fine-tuning을 수행했을 때 많은 부분에서 SOTA를 달성하였다.
한계
확장이나, 수정이 어렵다.
지식에 접근하고, 정확한 조작이 불가능하다.
잘못된 지식을 생성할 수 있다.
1.2. RAG Basic Idea
위 1.1. 에서 나온 한계점들에 대해서 parametic memory와 non-parametric memory를 결합한 hybrid모델이 이런 문제를 일부 해소할 수 있다. 지식을 추가하거나 수정할 수 있고, 접근된 지식에 대해 검증할 수 있다. 이런 아이디어가 적용된 모델로는 REALM과 ORQA가 있는데 이는 모두 extraction 기반으로 구현되어있다.
이 논문에서 소개하고자 하는 RAG모델은 Context(문맥)으로 언어모델의 성능을 향상시키고자 하였다. 그래서 기존 방식인 input x를 넣고 output y를 생성하는 것이 아니라 input x 앞에 context z를 추가하여 input x + context z 로 부터 output y를 생성한다는 아이디어를 가지고 있다.
이렇게 생성된 context y는 자연어처리에 있어 classification이나 text-to-text task에 속하는 semantic similarity등과 같이 적용할 수 있고, fact verification이나 ODQA와 같은 intensive task에서 좋은 성능을 냈다.
2. Method
RAG 모델을 좀 더 자세히 살펴보게 되면, 크게 두개의 component로 나누어 볼 수 있다.
: Query x에 대해 top-k 개의 z를 리턴해주는 component
: x, z를 초기 입력값으로 가지고, i-1 까지의 값으로 부터 토큰을 생성하는 모델
2.1. RAG Sequence & RAG tokens
우선 생성된 텍스트에 대한 분포를 생성하기 위해 latent document에 대해 marginalize하는 두가지 모델을 제안한다.
RAG Sequence
top k개의 z를 선정하고 해당 document 로부터 output y 즉, token을 생성해내는 구조
RAG Token
각 토큰에 대해서 다른 document를 선택할 수 있도록 하였다. 수식에서처럼 토큰마다 z를 선정하는 것을 볼 수 있다.
2.2. Retrieval Component: DPR (Dense Passage Retriever)
각 document에 대한 embedding값을 미리 계산하고 Query에 대한 embedding값을 구해 둘 사이의 유사도를 가지고 추출하게 된다. 여기서 embedding 하게 되는 document와 query는 각각 다른 weight를 가지는 bert 모델을 사용한다.
FAISS 라이브러리를 활용하여 빠르게 retrieve 할 수 있도록 하였다.
2.3. Generator: BART
input x와 retrieve 된 document를 concate 하여 BART의 input으로 사용하게 된다. 이 논문에서는 BART-Large를 사용하였다.
2.4. Training
어떤 document가 검색되어야 할지 알려주지 않고, Retriever와 Generator를 동시에 학습한다.
입력 쌍 가 주어지면 Adam을 통해 negative log likelyhood, 를 최소화한다.
학습 중 documents encoder를 업데이트 하면 document index를 정기적으로 업데이트 해야하므로 비용소모가 크다 -> document encoder는 고정하고, query encoder와 generator를 fine-tuning 한다.
2.5. Decoding
RAG-Token
각 토큰에 대해 계산하게 되면 transition probability를 가진 auto-regressive seq2seq generator로 볼 수 있다.
Use Standard Beam decoder
RAG-Sequence
Run beam search for each document z
Scoring each hypothesis using
모든 beam에 대한 hypothesis y의 확률을 추정하기 위해, y가 존재하지 않는 각 document z에 대해 추가적인 학습을 진행하고, generator score에 를 곱해 marginal에 대한 beam 사이의 확률을 합한다.
3. Experiments & Result
3.1. Resource and Setting
Single Wikipedia dump for non-parametric knowledge source (Dec 2018)
FAISS를 활용하여 속도 개선
, for top-k
3.2. Open-Domain Question Answering
RAG Model is SOTA
Extraction QA에서도 실제로 정답이 없음에도 11.8%로 정답을 찾아냈다.
3.3. Abstractive Question Answering
MS MARCO (Microsoft Machine Reading Comprehension)
2016년 세계 AI 컨퍼런스 NIPS에서 기계의 독해 및 질의응답에 대한 기존 Dataset의 약점 극복을 목적으로 만들어졌다.
일관성 부족 / 질문과 관련이 없는 응답 / 구체적이지 않은 응답
320만 문서 풀(pool)에서 검색해 관련 높은 응답 100개를 뽑고, 질문에 대한 답변 상위 100개를 MRR(Mean Reciprocal Rank metric)을 통해 성능을 측정한다고 한다.
질문에 대해 답이 있는 gold passage를 10개씩 넣어 놓는다. 이 10개의 passage가 없다면 답을 찾기 매우 어려운 질문으로 구성되어있도록 한다.
BART 보다 Rouge-L / BLEU-Score가 각각 2.6 / 2.6 높은 걸 확인할 수 있었다.
또한 BART보다 hallucinate 하지않고 사실에 가까운 문장을 만들어 냈다.
3.4. Jeopardy Question Generation
정답을 가지고 질문을 생성하는 Task
사람이 직접 평가하였는데 RAG-Token이 더 좋은 결과를 냈다고 평가하였다.
3.5. FEVER - Fact Verification
wikipedia로 부터 가져온 185,000개의 데이터로 주장을 받고, 그에 대해 다른사람의 주장을 지지, 반박, 알수 없음 3가지로 분류하는 Task이다. 여기서도 BART보다 좋은 성능을 확인할 수 있었다.
3.6. Effective of retrieving more documents
이 논문에서는 5, 10 으로 구현하였는데 두개의 성능이 크게 차이나지 않으며, 오히려 10을 넘어선 순간부터는 떨어지는 현상을 확인할 수 있었다.
4. Discussion
Parametric과 non-parametric Hybrid Model을 사용
Retrieval로 BERT + Generator로 BART모델을 사용
document z + query x => y
RAG model이 ODQA에서 SOTA를 달성 + 뿐만아니라 다른 Task에서도 BART보다 좋은 성능을
Last updated