391 lines
12 KiB
Python
391 lines
12 KiB
Python
"""
|
|
LangChain v0.3 기반 연구QA 챗봇 API
|
|
향후 고도화를 위한 확장 가능한 아키텍처
|
|
"""
|
|
|
|
from fastapi import FastAPI, HTTPException, Depends, UploadFile, File, Form
|
|
from fastapi.responses import FileResponse
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
from pydantic import BaseModel
|
|
from typing import List, Optional
|
|
from contextlib import asynccontextmanager
|
|
import os
|
|
import uuid
|
|
import shutil
|
|
from datetime import datetime
|
|
import json
|
|
import logging
|
|
import psycopg2
|
|
from psycopg2.extras import RealDictCursor
|
|
|
|
# LangChain 서비스 임포트
|
|
from services.langchain_service import langchain_service
|
|
from parser.pdf.MainParser import PDFParser
|
|
|
|
# 로깅 설정
|
|
logging.basicConfig(level=logging.INFO)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Pydantic 모델들
|
|
class ChatRequest(BaseModel):
|
|
message: str
|
|
user_id: Optional[str] = None
|
|
|
|
class ChatResponse(BaseModel):
|
|
response: str
|
|
sources: List[str]
|
|
timestamp: str
|
|
|
|
class FileUploadResponse(BaseModel):
|
|
message: str
|
|
file_id: str
|
|
filename: str
|
|
status: str
|
|
|
|
class FileListResponse(BaseModel):
|
|
files: List[dict]
|
|
total: int
|
|
|
|
# FastAPI 앱 생성
|
|
@asynccontextmanager
|
|
async def lifespan(app: FastAPI):
|
|
"""앱 시작/종료 시 실행"""
|
|
# 시작 시
|
|
logger.info("🚀 LangChain 기반 연구QA 챗봇 시작")
|
|
try:
|
|
langchain_service.initialize()
|
|
logger.info("✅ LangChain 서비스 초기화 완료")
|
|
except Exception as e:
|
|
logger.error(f"❌ LangChain 서비스 초기화 실패: {e}")
|
|
raise
|
|
|
|
yield
|
|
|
|
# 종료 시
|
|
logger.info("🛑 LangChain 기반 연구QA 챗봇 종료")
|
|
|
|
app = FastAPI(
|
|
title="연구QA Chatbot API",
|
|
description="LangChain v0.3 기반 고성능 PDF 파싱과 벡터 검색을 활용한 연구 질의응답 시스템",
|
|
version="2.0.0",
|
|
lifespan=lifespan
|
|
)
|
|
|
|
# CORS 설정
|
|
app.add_middleware(
|
|
CORSMiddleware,
|
|
allow_origins=["http://localhost:3000", "http://127.0.0.1:3000"],
|
|
allow_credentials=True,
|
|
allow_methods=["*"],
|
|
allow_headers=["*"],
|
|
)
|
|
|
|
# 보안 설정
|
|
security = HTTPBearer(auto_error=False)
|
|
|
|
def get_db_connection():
|
|
"""PostgreSQL 데이터베이스 연결"""
|
|
try:
|
|
connection = psycopg2.connect(
|
|
host="localhost",
|
|
port=5432,
|
|
database="researchqa",
|
|
user="woonglab",
|
|
password="!@#woonglab"
|
|
)
|
|
connection.autocommit = True
|
|
return connection
|
|
except Exception as e:
|
|
logger.error(f"PostgreSQL 연결 실패: {e}")
|
|
raise HTTPException(status_code=500, detail="데이터베이스 연결 실패")
|
|
|
|
# API 엔드포인트들
|
|
@app.get("/")
|
|
async def root():
|
|
"""루트 엔드포인트"""
|
|
return {
|
|
"message": "LangChain 기반 연구QA 챗봇 API",
|
|
"version": "2.0.0",
|
|
"status": "running"
|
|
}
|
|
|
|
@app.get("/health")
|
|
async def health_check():
|
|
"""헬스 체크"""
|
|
try:
|
|
# LangChain 서비스 상태 확인
|
|
collection_info = langchain_service.get_collection_info()
|
|
|
|
return {
|
|
"status": "healthy",
|
|
"langchain_service": "active",
|
|
"collection_info": collection_info,
|
|
"timestamp": datetime.now().isoformat()
|
|
}
|
|
except Exception as e:
|
|
logger.error(f"헬스 체크 실패: {e}")
|
|
raise HTTPException(status_code=500, detail=f"서비스 상태 불량: {e}")
|
|
|
|
@app.post("/chat", response_model=ChatResponse)
|
|
async def chat(request: ChatRequest):
|
|
"""LangChain RAG 기반 채팅"""
|
|
try:
|
|
logger.info(f"💬 채팅 요청: {request.message}")
|
|
|
|
# LangChain RAG를 통한 답변 생성
|
|
result = langchain_service.generate_answer(request.message)
|
|
|
|
response = ChatResponse(
|
|
response=result["answer"],
|
|
sources=result["references"],
|
|
timestamp=datetime.now().isoformat()
|
|
)
|
|
|
|
logger.info(f"✅ 답변 생성 완료: {len(result['references'])}개 참조")
|
|
return response
|
|
|
|
except Exception as e:
|
|
logger.error(f"❌ 채팅 처리 실패: {e}")
|
|
raise HTTPException(status_code=500, detail=f"채팅 처리 실패: {e}")
|
|
|
|
@app.post("/upload", response_model=FileUploadResponse)
|
|
async def upload_file(file: UploadFile = File(...)):
|
|
"""PDF 파일 업로드 및 LangChain 처리"""
|
|
try:
|
|
# 파일 유효성 검사
|
|
if not file.filename.lower().endswith('.pdf'):
|
|
raise HTTPException(status_code=400, detail="PDF 파일만 업로드 가능합니다")
|
|
|
|
# 파일 ID 생성 (UUID)
|
|
file_id = str(uuid.uuid4())
|
|
filename = file.filename
|
|
|
|
logger.info(f"📄 파일 업로드 시작: {filename}")
|
|
|
|
# 파일 저장
|
|
upload_dir = "uploads"
|
|
os.makedirs(upload_dir, exist_ok=True)
|
|
file_path = os.path.join(upload_dir, f"{file_id}_{filename}")
|
|
|
|
with open(file_path, "wb") as buffer:
|
|
shutil.copyfileobj(file.file, buffer)
|
|
|
|
# PDF 파싱
|
|
parser = PDFParser()
|
|
result = parser.process_pdf(file_path)
|
|
|
|
if not result["success"]:
|
|
raise HTTPException(status_code=400, detail=f"PDF 파싱 실패: {result.get('error', 'Unknown error')}")
|
|
|
|
# LangChain 문서로 변환
|
|
from langchain_core.documents import Document
|
|
langchain_docs = []
|
|
|
|
# 청크별로 문서 생성
|
|
for i, chunk in enumerate(result["chunks"]):
|
|
langchain_doc = Document(
|
|
page_content=chunk,
|
|
metadata={
|
|
"filename": filename,
|
|
"chunk_index": i,
|
|
"file_id": file_id,
|
|
"upload_time": datetime.now().isoformat(),
|
|
"total_chunks": len(result["chunks"])
|
|
}
|
|
)
|
|
langchain_docs.append(langchain_doc)
|
|
|
|
# LangChain 벡터스토어에 추가
|
|
langchain_service.add_documents(langchain_docs)
|
|
|
|
# 데이터베이스에 메타데이터 저장
|
|
db_conn = get_db_connection()
|
|
cursor = db_conn.cursor()
|
|
|
|
cursor.execute("""
|
|
INSERT INTO uploaded_file (filename, file_path, status, upload_dt)
|
|
VALUES (%s, %s, %s, %s)
|
|
""", (filename, file_path, "processed", datetime.now()))
|
|
|
|
cursor.close()
|
|
|
|
logger.info(f"✅ 파일 업로드 완료: {filename} ({len(langchain_docs)}개 문서)")
|
|
|
|
return FileUploadResponse(
|
|
message=f"파일 업로드 및 처리 완료: {len(langchain_docs)}개 문서",
|
|
file_id=file_id,
|
|
filename=filename,
|
|
status="success"
|
|
)
|
|
|
|
except Exception as e:
|
|
logger.error(f"❌ 파일 업로드 실패: {e}")
|
|
raise HTTPException(status_code=500, detail=f"파일 업로드 실패: {e}")
|
|
|
|
@app.get("/files", response_model=FileListResponse)
|
|
async def get_files():
|
|
"""업로드된 파일 목록 조회"""
|
|
try:
|
|
db_conn = get_db_connection()
|
|
cursor = db_conn.cursor(cursor_factory=RealDictCursor)
|
|
|
|
cursor.execute("""
|
|
SELECT id, filename, upload_dt as upload_time, status
|
|
FROM uploaded_file
|
|
ORDER BY upload_dt DESC
|
|
""")
|
|
|
|
files = cursor.fetchall()
|
|
cursor.close()
|
|
|
|
return FileListResponse(
|
|
files=[dict(file) for file in files],
|
|
total=len(files)
|
|
)
|
|
|
|
except Exception as e:
|
|
logger.error(f"❌ 파일 목록 조회 실패: {e}")
|
|
raise HTTPException(status_code=500, detail=f"파일 목록 조회 실패: {e}")
|
|
|
|
@app.delete("/files/{file_id}")
|
|
async def delete_file(file_id: str):
|
|
"""파일 삭제"""
|
|
try:
|
|
db_conn = get_db_connection()
|
|
cursor = db_conn.cursor()
|
|
|
|
# 파일 정보 조회
|
|
cursor.execute("SELECT filename FROM uploaded_file WHERE id = %s", (file_id,))
|
|
result = cursor.fetchone()
|
|
|
|
if not result:
|
|
raise HTTPException(status_code=404, detail="파일을 찾을 수 없습니다")
|
|
|
|
filename = result[0]
|
|
|
|
# LangChain 벡터스토어에서 삭제
|
|
langchain_service.delete_documents_by_filename(filename)
|
|
|
|
# 데이터베이스에서 삭제
|
|
cursor.execute("DELETE FROM uploaded_file WHERE id = %s", (file_id,))
|
|
|
|
# 실제 파일 삭제
|
|
try:
|
|
os.remove(f"uploads/{file_id}_{filename}")
|
|
except FileNotFoundError:
|
|
pass
|
|
|
|
cursor.close()
|
|
|
|
logger.info(f"✅ 파일 삭제 완료: {filename}")
|
|
|
|
return {"message": f"파일 삭제 완료: {filename}"}
|
|
|
|
except Exception as e:
|
|
logger.error(f"❌ 파일 삭제 실패: {e}")
|
|
raise HTTPException(status_code=500, detail=f"파일 삭제 실패: {e}")
|
|
|
|
@app.get("/pdf/{file_id}/view")
|
|
async def view_pdf(file_id: str):
|
|
"""PDF 파일 뷰어"""
|
|
try:
|
|
db_conn = get_db_connection()
|
|
cursor = db_conn.cursor()
|
|
|
|
# UUID가 전달된 경우 정수 ID로 변환
|
|
try:
|
|
# 먼저 정수 ID로 시도
|
|
cursor.execute("SELECT filename, file_path FROM uploaded_file WHERE id = %s", (int(file_id),))
|
|
result = cursor.fetchone()
|
|
except ValueError:
|
|
# UUID가 전달된 경우 file_path에서 UUID를 찾아서 매칭
|
|
cursor.execute("SELECT id, filename, file_path FROM uploaded_file")
|
|
all_files = cursor.fetchall()
|
|
result = None
|
|
for file_row in all_files:
|
|
if file_id in file_row[2]: # file_path에 UUID가 포함되어 있는지 확인
|
|
result = (file_row[1], file_row[2]) # filename, file_path
|
|
break
|
|
|
|
if not result:
|
|
raise HTTPException(status_code=404, detail="파일을 찾을 수 없습니다")
|
|
|
|
filename = result[0]
|
|
file_path = result[1]
|
|
|
|
# 절대 경로로 변환
|
|
if not os.path.isabs(file_path):
|
|
file_path = os.path.abspath(file_path)
|
|
|
|
if not os.path.exists(file_path):
|
|
raise HTTPException(status_code=404, detail="파일이 존재하지 않습니다")
|
|
|
|
cursor.close()
|
|
|
|
return FileResponse(
|
|
path=file_path,
|
|
media_type="application/pdf",
|
|
filename=filename
|
|
)
|
|
|
|
except Exception as e:
|
|
logger.error(f"❌ PDF 뷰어 실패: {e}")
|
|
raise HTTPException(status_code=500, detail=f"PDF 뷰어 실패: {e}")
|
|
|
|
@app.get("/search")
|
|
async def search_documents(query: str, limit: int = 5):
|
|
"""문서 검색"""
|
|
try:
|
|
# LangChain 유사 문서 검색
|
|
documents = langchain_service.search_similar_documents(query, k=limit)
|
|
|
|
results = []
|
|
for doc in documents:
|
|
results.append({
|
|
"content": doc.page_content[:200] + "...",
|
|
"metadata": doc.metadata,
|
|
"score": getattr(doc, 'score', 0.0)
|
|
})
|
|
|
|
return {
|
|
"query": query,
|
|
"results": results,
|
|
"total": len(results)
|
|
}
|
|
|
|
except Exception as e:
|
|
logger.error(f"❌ 문서 검색 실패: {e}")
|
|
raise HTTPException(status_code=500, detail=f"문서 검색 실패: {e}")
|
|
|
|
@app.get("/stats")
|
|
async def get_stats():
|
|
"""시스템 통계"""
|
|
try:
|
|
# LangChain 컬렉션 정보
|
|
collection_info = langchain_service.get_collection_info()
|
|
|
|
# 데이터베이스 통계
|
|
db_conn = get_db_connection()
|
|
cursor = db_conn.cursor()
|
|
|
|
cursor.execute("SELECT COUNT(*) FROM uploaded_file")
|
|
file_count = cursor.fetchone()[0]
|
|
|
|
cursor.close()
|
|
|
|
return {
|
|
"langchain_stats": collection_info,
|
|
"database_stats": {
|
|
"total_files": file_count
|
|
},
|
|
"timestamp": datetime.now().isoformat()
|
|
}
|
|
|
|
except Exception as e:
|
|
logger.error(f"❌ 통계 조회 실패: {e}")
|
|
raise HTTPException(status_code=500, detail=f"통계 조회 실패: {e}")
|
|
|
|
if __name__ == "__main__":
|
|
import uvicorn
|
|
uvicorn.run(app, host="0.0.0.0", port=8000) |