2025-10-05 23:22:54 +09:00

391 lines
12 KiB
Python

"""
LangChain v0.3 기반 연구QA 챗봇 API
향후 고도화를 위한 확장 가능한 아키텍처
"""
from fastapi import FastAPI, HTTPException, Depends, UploadFile, File, Form
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from typing import List, Optional
from contextlib import asynccontextmanager
import os
import uuid
import shutil
from datetime import datetime
import json
import logging
import psycopg2
from psycopg2.extras import RealDictCursor
# LangChain 서비스 임포트
from services.langchain_service import langchain_service
from parser.pdf.MainParser import PDFParser
# 로깅 설정
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Pydantic 모델들
class ChatRequest(BaseModel):
message: str
user_id: Optional[str] = None
class ChatResponse(BaseModel):
response: str
sources: List[str]
timestamp: str
class FileUploadResponse(BaseModel):
message: str
file_id: str
filename: str
status: str
class FileListResponse(BaseModel):
files: List[dict]
total: int
# FastAPI 앱 생성
@asynccontextmanager
async def lifespan(app: FastAPI):
"""앱 시작/종료 시 실행"""
# 시작 시
logger.info("🚀 LangChain 기반 연구QA 챗봇 시작")
try:
langchain_service.initialize()
logger.info("✅ LangChain 서비스 초기화 완료")
except Exception as e:
logger.error(f"❌ LangChain 서비스 초기화 실패: {e}")
raise
yield
# 종료 시
logger.info("🛑 LangChain 기반 연구QA 챗봇 종료")
app = FastAPI(
title="연구QA Chatbot API",
description="LangChain v0.3 기반 고성능 PDF 파싱과 벡터 검색을 활용한 연구 질의응답 시스템",
version="2.0.0",
lifespan=lifespan
)
# CORS 설정
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:3000", "http://127.0.0.1:3000"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 보안 설정
security = HTTPBearer(auto_error=False)
def get_db_connection():
"""PostgreSQL 데이터베이스 연결"""
try:
connection = psycopg2.connect(
host="localhost",
port=5432,
database="researchqa",
user="woonglab",
password="!@#woonglab"
)
connection.autocommit = True
return connection
except Exception as e:
logger.error(f"PostgreSQL 연결 실패: {e}")
raise HTTPException(status_code=500, detail="데이터베이스 연결 실패")
# API 엔드포인트들
@app.get("/")
async def root():
"""루트 엔드포인트"""
return {
"message": "LangChain 기반 연구QA 챗봇 API",
"version": "2.0.0",
"status": "running"
}
@app.get("/health")
async def health_check():
"""헬스 체크"""
try:
# LangChain 서비스 상태 확인
collection_info = langchain_service.get_collection_info()
return {
"status": "healthy",
"langchain_service": "active",
"collection_info": collection_info,
"timestamp": datetime.now().isoformat()
}
except Exception as e:
logger.error(f"헬스 체크 실패: {e}")
raise HTTPException(status_code=500, detail=f"서비스 상태 불량: {e}")
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
"""LangChain RAG 기반 채팅"""
try:
logger.info(f"💬 채팅 요청: {request.message}")
# LangChain RAG를 통한 답변 생성
result = langchain_service.generate_answer(request.message)
response = ChatResponse(
response=result["answer"],
sources=result["references"],
timestamp=datetime.now().isoformat()
)
logger.info(f"✅ 답변 생성 완료: {len(result['references'])}개 참조")
return response
except Exception as e:
logger.error(f"❌ 채팅 처리 실패: {e}")
raise HTTPException(status_code=500, detail=f"채팅 처리 실패: {e}")
@app.post("/upload", response_model=FileUploadResponse)
async def upload_file(file: UploadFile = File(...)):
"""PDF 파일 업로드 및 LangChain 처리"""
try:
# 파일 유효성 검사
if not file.filename.lower().endswith('.pdf'):
raise HTTPException(status_code=400, detail="PDF 파일만 업로드 가능합니다")
# 파일 ID 생성 (UUID)
file_id = str(uuid.uuid4())
filename = file.filename
logger.info(f"📄 파일 업로드 시작: {filename}")
# 파일 저장
upload_dir = "uploads"
os.makedirs(upload_dir, exist_ok=True)
file_path = os.path.join(upload_dir, f"{file_id}_{filename}")
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
# PDF 파싱
parser = PDFParser()
result = parser.process_pdf(file_path)
if not result["success"]:
raise HTTPException(status_code=400, detail=f"PDF 파싱 실패: {result.get('error', 'Unknown error')}")
# LangChain 문서로 변환
from langchain_core.documents import Document
langchain_docs = []
# 청크별로 문서 생성
for i, chunk in enumerate(result["chunks"]):
langchain_doc = Document(
page_content=chunk,
metadata={
"filename": filename,
"chunk_index": i,
"file_id": file_id,
"upload_time": datetime.now().isoformat(),
"total_chunks": len(result["chunks"])
}
)
langchain_docs.append(langchain_doc)
# LangChain 벡터스토어에 추가
langchain_service.add_documents(langchain_docs)
# 데이터베이스에 메타데이터 저장
db_conn = get_db_connection()
cursor = db_conn.cursor()
cursor.execute("""
INSERT INTO uploaded_file (filename, file_path, status, upload_dt)
VALUES (%s, %s, %s, %s)
""", (filename, file_path, "processed", datetime.now()))
cursor.close()
logger.info(f"✅ 파일 업로드 완료: {filename} ({len(langchain_docs)}개 문서)")
return FileUploadResponse(
message=f"파일 업로드 및 처리 완료: {len(langchain_docs)}개 문서",
file_id=file_id,
filename=filename,
status="success"
)
except Exception as e:
logger.error(f"❌ 파일 업로드 실패: {e}")
raise HTTPException(status_code=500, detail=f"파일 업로드 실패: {e}")
@app.get("/files", response_model=FileListResponse)
async def get_files():
"""업로드된 파일 목록 조회"""
try:
db_conn = get_db_connection()
cursor = db_conn.cursor(cursor_factory=RealDictCursor)
cursor.execute("""
SELECT id, filename, upload_dt as upload_time, status
FROM uploaded_file
ORDER BY upload_dt DESC
""")
files = cursor.fetchall()
cursor.close()
return FileListResponse(
files=[dict(file) for file in files],
total=len(files)
)
except Exception as e:
logger.error(f"❌ 파일 목록 조회 실패: {e}")
raise HTTPException(status_code=500, detail=f"파일 목록 조회 실패: {e}")
@app.delete("/files/{file_id}")
async def delete_file(file_id: str):
"""파일 삭제"""
try:
db_conn = get_db_connection()
cursor = db_conn.cursor()
# 파일 정보 조회
cursor.execute("SELECT filename FROM uploaded_file WHERE id = %s", (file_id,))
result = cursor.fetchone()
if not result:
raise HTTPException(status_code=404, detail="파일을 찾을 수 없습니다")
filename = result[0]
# LangChain 벡터스토어에서 삭제
langchain_service.delete_documents_by_filename(filename)
# 데이터베이스에서 삭제
cursor.execute("DELETE FROM uploaded_file WHERE id = %s", (file_id,))
# 실제 파일 삭제
try:
os.remove(f"uploads/{file_id}_{filename}")
except FileNotFoundError:
pass
cursor.close()
logger.info(f"✅ 파일 삭제 완료: {filename}")
return {"message": f"파일 삭제 완료: {filename}"}
except Exception as e:
logger.error(f"❌ 파일 삭제 실패: {e}")
raise HTTPException(status_code=500, detail=f"파일 삭제 실패: {e}")
@app.get("/pdf/{file_id}/view")
async def view_pdf(file_id: str):
"""PDF 파일 뷰어"""
try:
db_conn = get_db_connection()
cursor = db_conn.cursor()
# UUID가 전달된 경우 정수 ID로 변환
try:
# 먼저 정수 ID로 시도
cursor.execute("SELECT filename, file_path FROM uploaded_file WHERE id = %s", (int(file_id),))
result = cursor.fetchone()
except ValueError:
# UUID가 전달된 경우 file_path에서 UUID를 찾아서 매칭
cursor.execute("SELECT id, filename, file_path FROM uploaded_file")
all_files = cursor.fetchall()
result = None
for file_row in all_files:
if file_id in file_row[2]: # file_path에 UUID가 포함되어 있는지 확인
result = (file_row[1], file_row[2]) # filename, file_path
break
if not result:
raise HTTPException(status_code=404, detail="파일을 찾을 수 없습니다")
filename = result[0]
file_path = result[1]
# 절대 경로로 변환
if not os.path.isabs(file_path):
file_path = os.path.abspath(file_path)
if not os.path.exists(file_path):
raise HTTPException(status_code=404, detail="파일이 존재하지 않습니다")
cursor.close()
return FileResponse(
path=file_path,
media_type="application/pdf",
filename=filename
)
except Exception as e:
logger.error(f"❌ PDF 뷰어 실패: {e}")
raise HTTPException(status_code=500, detail=f"PDF 뷰어 실패: {e}")
@app.get("/search")
async def search_documents(query: str, limit: int = 5):
"""문서 검색"""
try:
# LangChain 유사 문서 검색
documents = langchain_service.search_similar_documents(query, k=limit)
results = []
for doc in documents:
results.append({
"content": doc.page_content[:200] + "...",
"metadata": doc.metadata,
"score": getattr(doc, 'score', 0.0)
})
return {
"query": query,
"results": results,
"total": len(results)
}
except Exception as e:
logger.error(f"❌ 문서 검색 실패: {e}")
raise HTTPException(status_code=500, detail=f"문서 검색 실패: {e}")
@app.get("/stats")
async def get_stats():
"""시스템 통계"""
try:
# LangChain 컬렉션 정보
collection_info = langchain_service.get_collection_info()
# 데이터베이스 통계
db_conn = get_db_connection()
cursor = db_conn.cursor()
cursor.execute("SELECT COUNT(*) FROM uploaded_file")
file_count = cursor.fetchone()[0]
cursor.close()
return {
"langchain_stats": collection_info,
"database_stats": {
"total_files": file_count
},
"timestamp": datetime.now().isoformat()
}
except Exception as e:
logger.error(f"❌ 통계 조회 실패: {e}")
raise HTTPException(status_code=500, detail=f"통계 조회 실패: {e}")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)