TART reranker 원리와 코드 알아보기
TART 리랭커의 작동 원리는 어떻게 될까요?
Quick context
First, this page captures one concrete build-log step, research note, or project lesson from Jeffrey Kim.
Next, use the tags, related reading, and home archive to move from this note to deeper material in the same topic cluster.
Finally, follow the RSS feed if you want the next experiment, retrospective, or paper review as soon as it ships.
메타에서 나온 Information Retrieval 기법. Task-aware Retrieval with Instructions이라는 논문에 잘 소개가 되어 있고, 깃허브도 있다.
핵심
TART의 핵심은 “retrieval에 검색할 때 Instruction을 포함시킨다” 에 있다. 단순한 유저의 질문이 아니라, 질문 문장에는 담기지 않은 의도까지 포함하여 검색할 수 있는 것이다. 예시를 들어보겠다.
예시
유저가 다음과 같은 질문을 한다고 하자.
질문 : Python에서 1부터 10까지 더하는 방법을 알려줘.
이러한 경우 유저에게는 두 가지 의도가 있을 수 있다.
- Python에서 1부터 10까지 더하는 코드를 작성해줘.
- Python에서 1부터 10까지 더하는 방법을 상세하게 줄글로 설명해줘.
즉, 이 경우 retrieve 해야하는 문서의 내용이 코드와 코드에 대한 풀이글, 두 가지로 나눠지는 것이다. 통상적인 경우에는 이런 경우 코드에 대한 retriever, 풀이글에 대한 retriever을 따로 사용했다.
하지만 TART는 유저의 의도를 함께 retrieve 할 때 넣어주고, 그러면 TART가 의도에 알맞게 문서를 찾아준다. 위의 예시에서 유저가 코드를 원했다면 코드 문서를, 풀이글을 원했다면 풀이글 문서를 retrieve 해 주는 것이다.
만든 방법
연구진들은 BERRI라는 instruction-query-정답 문서 쌍이 있는 데이터셋을 만들고, 그것으로 retriever를 학습시켰다. 더 자세한 부분은 논문을 참고해주시길.
TART-dual vs TART-full
TART는 두가지 모델이 있다. TART-dual은 보통의 임베딩 모델과 비슷하다. 유저의 쿼리와 각 문서들간의 유사도만을 비교한다. TART-full은 유저 쿼리와 각 문서들간의 관계 뿐 아니라, 문서들 간의 관계 역시 고려한다. instruction에 대해 어떤 문서가 더 가까운지 서로 비교해가며 연산하기 때문에 더욱 정확하다. 하지만 당연히 계산량이 훨씬 크겠지?
코드
RAGchain에서는 TART-full을 리랭커 형태로 준비했다. TART-dual의 모델은 허깅페이스에 올라와 있지 않고, TART-full의 연산량이 크기 때문이다. 실제로 논문에서도 TRAT-full은 다른 방법으로 retrieve 한 문서들에 사용했다고 하니, 딱 리랭커인 것이다. TART 리랭커가 추가된 PR은 여기서 확인할 수 있다.
간단한 TART 코드 예시는 아래와 같다.
from typing import List
import torch
import torch.nn.functional as F
from .modeling_enc_t5 import EncT5ForSequenceClassification
from .tokenization_enc_t5 import EncT5Tokenizer
def tart_full(query: str, contents: List[str], instruction: str) -> List[str]:
model_name = "facebook/tart-full-flan-t5-xl"
model = EncT5ForSequenceClassification.from_pretrained(model_name)
tokenizer = EncT5Tokenizer.from_pretrained(model_name)
instruction_queries: List[str] = ['{0} [SEP] {1}'.format(instruction, query) for _ in range(len(contents))]
features = tokenizer(instruction_queries, contents, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
scores = model(**features).logits
normalized_scores = [float(score[1]) for score in F.softmax(scores, dim=1)]
sorted_pairs = sorted(zip(contents, normalized_scores), key=lambda x: x[1], reverse=True)
sorted_contents = [content for content, score in sorted_pairs]
return sorted_contents
modeling_ent_t5와 tokenization_enc_t5는 원저자의 깃허브에도 있고 RAGchain에도 있다. TART 참 쉽죠?