267 lines
11 KiB
Python
267 lines
11 KiB
Python
"""
|
|
컨텍스트 기반 검색 시스템
|
|
- 컨텍스트 임베딩: 질문과 문서를 함께 임베딩하여 더 정확한 검색
|
|
- 컨텍스트 BM25: 질문과 문서의 컨텍스트를 고려한 키워드 검색
|
|
- Reranker: 검색 결과를 재순위화하여 정확도 향상
|
|
"""
|
|
|
|
import logging
|
|
import numpy as np
|
|
from typing import List, Dict, Any, Tuple
|
|
from rank_bm25 import BM25Okapi
|
|
from sentence_transformers import SentenceTransformer
|
|
from sklearn.metrics.pairwise import cosine_similarity
|
|
import re
|
|
from collections import Counter
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class ContextEmbedding:
|
|
"""컨텍스트 임베딩을 통한 검색"""
|
|
|
|
def __init__(self, model_name: str = "jhgan/ko-sroberta-multitask"):
|
|
self.model = SentenceTransformer(model_name)
|
|
logger.info(f"✅ 컨텍스트 임베딩 모델 로드 완료: {model_name}")
|
|
|
|
def create_context_embedding(self, question: str, document: str) -> np.ndarray:
|
|
"""질문과 문서를 함께 임베딩하여 컨텍스트 임베딩 생성"""
|
|
# 질문과 문서를 결합하여 컨텍스트 생성
|
|
context = f"질문: {question}\n문서: {document}"
|
|
embedding = self.model.encode(context)
|
|
return embedding
|
|
|
|
def search_with_context(self, question: str, documents: List[Dict[str, Any]], top_k: int = 10) -> List[Dict[str, Any]]:
|
|
"""컨텍스트 임베딩을 사용한 검색"""
|
|
logger.info(f"🔍 컨텍스트 임베딩 검색 시작: {len(documents)}개 문서")
|
|
|
|
# 질문 임베딩 생성
|
|
question_embedding = self.model.encode(question)
|
|
|
|
# 각 문서에 대해 컨텍스트 임베딩 생성 및 유사도 계산
|
|
scored_documents = []
|
|
for doc in documents:
|
|
doc_content = doc.get('content', '')
|
|
if not doc_content:
|
|
continue
|
|
|
|
# 컨텍스트 임베딩 생성
|
|
context_embedding = self.create_context_embedding(question, doc_content)
|
|
|
|
# 코사인 유사도 계산
|
|
similarity = cosine_similarity([question_embedding], [context_embedding])[0][0]
|
|
|
|
scored_documents.append({
|
|
'document': doc,
|
|
'context_score': similarity,
|
|
'content': doc_content
|
|
})
|
|
|
|
# 유사도 기준으로 정렬
|
|
scored_documents.sort(key=lambda x: x['context_score'], reverse=True)
|
|
|
|
logger.info(f"📊 컨텍스트 임베딩 검색 완료: {len(scored_documents)}개 결과")
|
|
return scored_documents[:top_k]
|
|
|
|
class ContextBM25:
|
|
"""컨텍스트 BM25를 통한 검색"""
|
|
|
|
def __init__(self):
|
|
self.bm25 = None
|
|
self.documents = []
|
|
logger.info("✅ 컨텍스트 BM25 초기화 완료")
|
|
|
|
def preprocess_text(self, text: str) -> List[str]:
|
|
"""텍스트 전처리 및 토큰화"""
|
|
# 한글, 영문, 숫자만 추출
|
|
text = re.sub(r'[^\w\s가-힣]', ' ', text)
|
|
# 공백으로 분리
|
|
tokens = text.split()
|
|
# 빈 토큰 제거
|
|
tokens = [token.strip() for token in tokens if token.strip()]
|
|
return tokens
|
|
|
|
def build_index(self, documents: List[Dict[str, Any]]):
|
|
"""BM25 인덱스 구축"""
|
|
self.documents = documents
|
|
corpus = []
|
|
|
|
for doc in documents:
|
|
content = doc.get('content', '')
|
|
tokens = self.preprocess_text(content)
|
|
corpus.append(tokens)
|
|
|
|
self.bm25 = BM25Okapi(corpus)
|
|
logger.info(f"📚 BM25 인덱스 구축 완료: {len(corpus)}개 문서")
|
|
|
|
def search_with_context(self, question: str, top_k: int = 10) -> List[Dict[str, Any]]:
|
|
"""컨텍스트 BM25를 사용한 검색"""
|
|
if not self.bm25:
|
|
logger.warning("⚠️ BM25 인덱스가 구축되지 않았습니다.")
|
|
return []
|
|
|
|
logger.info(f"🔍 컨텍스트 BM25 검색 시작: {question}")
|
|
|
|
# 질문 토큰화
|
|
question_tokens = self.preprocess_text(question)
|
|
|
|
# BM25 점수 계산
|
|
scores = self.bm25.get_scores(question_tokens)
|
|
|
|
# 점수와 문서를 매핑
|
|
scored_documents = []
|
|
for i, (doc, score) in enumerate(zip(self.documents, scores)):
|
|
scored_documents.append({
|
|
'document': doc,
|
|
'bm25_score': score,
|
|
'content': doc.get('content', '')
|
|
})
|
|
|
|
# 점수 기준으로 정렬
|
|
scored_documents.sort(key=lambda x: x['bm25_score'], reverse=True)
|
|
|
|
logger.info(f"📊 컨텍스트 BM25 검색 완료: {len(scored_documents)}개 결과")
|
|
return scored_documents[:top_k]
|
|
|
|
class Reranker:
|
|
"""검색 결과 재순위화"""
|
|
|
|
def __init__(self, model_name: str = "jhgan/ko-sroberta-multitask"):
|
|
self.model = SentenceTransformer(model_name)
|
|
logger.info(f"✅ Reranker 모델 로드 완료: {model_name}")
|
|
|
|
def calculate_relevance_score(self, question: str, document_content: str) -> float:
|
|
"""질문과 문서의 관련성 점수 계산"""
|
|
# 질문과 문서의 임베딩 생성
|
|
question_embedding = self.model.encode(question)
|
|
doc_embedding = self.model.encode(document_content)
|
|
|
|
# 코사인 유사도 계산
|
|
similarity = cosine_similarity([question_embedding], [doc_embedding])[0][0]
|
|
|
|
# 키워드 매칭 점수 추가
|
|
keyword_score = self._calculate_keyword_score(question, document_content)
|
|
|
|
# 최종 점수 (임베딩 유사도 70% + 키워드 매칭 30%)
|
|
final_score = 0.7 * similarity + 0.3 * keyword_score
|
|
|
|
return final_score
|
|
|
|
def _calculate_keyword_score(self, question: str, document: str) -> float:
|
|
"""키워드 매칭 점수 계산"""
|
|
# 질문에서 키워드 추출
|
|
question_tokens = re.findall(r'\b\w+\b', question.lower())
|
|
doc_tokens = re.findall(r'\b\w+\b', document.lower())
|
|
|
|
# 토큰 빈도 계산
|
|
question_counter = Counter(question_tokens)
|
|
doc_counter = Counter(doc_tokens)
|
|
|
|
# 공통 토큰의 가중치 계산
|
|
common_tokens = set(question_tokens) & set(doc_tokens)
|
|
if not common_tokens:
|
|
return 0.0
|
|
|
|
# TF-IDF 스타일 점수 계산
|
|
total_score = 0.0
|
|
for token in common_tokens:
|
|
question_freq = question_counter[token]
|
|
doc_freq = doc_counter[token]
|
|
# 간단한 TF-IDF 스타일 점수
|
|
score = (question_freq * doc_freq) / (len(question_tokens) * len(doc_tokens))
|
|
total_score += score
|
|
|
|
return min(total_score, 1.0) # 최대 1.0으로 제한
|
|
|
|
def rerank_documents(self, question: str, documents: List[Dict[str, Any]], top_k: int = 10) -> List[Dict[str, Any]]:
|
|
"""문서 재순위화"""
|
|
logger.info(f"🔄 Reranker 재순위화 시작: {len(documents)}개 문서")
|
|
|
|
reranked_documents = []
|
|
for doc_info in documents:
|
|
content = doc_info.get('content', '')
|
|
if not content:
|
|
continue
|
|
|
|
# 관련성 점수 계산
|
|
relevance_score = self.calculate_relevance_score(question, content)
|
|
|
|
# 기존 점수와 재순위화 점수 결합
|
|
original_score = doc_info.get('context_score', 0) + doc_info.get('bm25_score', 0)
|
|
final_score = 0.6 * relevance_score + 0.4 * original_score
|
|
|
|
reranked_documents.append({
|
|
**doc_info,
|
|
'rerank_score': relevance_score,
|
|
'final_score': final_score
|
|
})
|
|
|
|
# 최종 점수 기준으로 정렬
|
|
reranked_documents.sort(key=lambda x: x['final_score'], reverse=True)
|
|
|
|
logger.info(f"📊 Reranker 재순위화 완료: {len(reranked_documents)}개 결과")
|
|
return reranked_documents[:top_k]
|
|
|
|
class ContextRetrieval:
|
|
"""컨텍스트 기반 검색 시스템 통합 클래스"""
|
|
|
|
def __init__(self, model_name: str = "jhgan/ko-sroberta-multitask"):
|
|
self.context_embedding = ContextEmbedding(model_name)
|
|
self.context_bm25 = ContextBM25()
|
|
self.reranker = Reranker(model_name)
|
|
logger.info("✅ 컨텍스트 검색 시스템 초기화 완료")
|
|
|
|
def build_index(self, documents: List[Dict[str, Any]]):
|
|
"""검색 인덱스 구축"""
|
|
logger.info(f"📚 검색 인덱스 구축 시작: {len(documents)}개 문서")
|
|
self.context_bm25.build_index(documents)
|
|
logger.info("✅ 검색 인덱스 구축 완료")
|
|
|
|
def search(self, question: str, top_k: int = 10) -> List[Dict[str, Any]]:
|
|
"""컨텍스트 기반 통합 검색"""
|
|
logger.info(f"🔍 컨텍스트 기반 통합 검색 시작: {question}")
|
|
|
|
# 1. 컨텍스트 임베딩 검색
|
|
embedding_results = self.context_embedding.search_with_context(question, self.context_bm25.documents, top_k * 2)
|
|
|
|
# 2. 컨텍스트 BM25 검색
|
|
bm25_results = self.context_bm25.search_with_context(question, top_k * 2)
|
|
|
|
# 3. 두 결과를 결합하여 중복 제거
|
|
combined_results = self._combine_results(embedding_results, bm25_results)
|
|
|
|
# 4. Reranker로 재순위화
|
|
final_results = self.reranker.rerank_documents(question, combined_results, top_k)
|
|
|
|
logger.info(f"📊 컨텍스트 기반 통합 검색 완료: {len(final_results)}개 결과")
|
|
return final_results
|
|
|
|
def _combine_results(self, embedding_results: List[Dict], bm25_results: List[Dict]) -> List[Dict]:
|
|
"""임베딩과 BM25 결과를 결합하여 중복 제거"""
|
|
combined = {}
|
|
|
|
# 임베딩 결과 추가
|
|
for result in embedding_results:
|
|
doc_id = id(result['document'])
|
|
if doc_id not in combined:
|
|
combined[doc_id] = result
|
|
else:
|
|
# 기존 점수와 새 점수 결합
|
|
combined[doc_id]['context_score'] = max(
|
|
combined[doc_id].get('context_score', 0),
|
|
result['context_score']
|
|
)
|
|
|
|
# BM25 결과 추가
|
|
for result in bm25_results:
|
|
doc_id = id(result['document'])
|
|
if doc_id not in combined:
|
|
combined[doc_id] = result
|
|
else:
|
|
# 기존 점수와 새 점수 결합
|
|
combined[doc_id]['bm25_score'] = max(
|
|
combined[doc_id].get('bm25_score', 0),
|
|
result['bm25_score']
|
|
)
|
|
|
|
return list(combined.values())
|