214 lines
5.9 KiB
Python
214 lines
5.9 KiB
Python
from pathlib import Path
|
|
from typing import List
|
|
|
|
import logging
|
|
|
|
from dotenv import load_dotenv
|
|
from fastapi import FastAPI, HTTPException, Request
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import FileResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi.templating import Jinja2Templates
|
|
from pydantic import BaseModel
|
|
|
|
from .db import (
|
|
init_db,
|
|
create_item,
|
|
update_filename,
|
|
update_size_bytes,
|
|
list_items,
|
|
get_item,
|
|
delete_items,
|
|
delete_item_by_id,
|
|
)
|
|
from .tts_service import text_to_mp3
|
|
|
|
# Directory layout, resolved relative to this module:
#   BASE_DIR      - directory containing this file
#   ROOT_DIR      - its parent, treated as the project root
#   CLIENT_DIR    - front-end assets (static files and Jinja2 templates)
#   RESOURCES_DIR - where generated MP3 files are written and served from
BASE_DIR = Path(__file__).resolve().parent
ROOT_DIR = BASE_DIR.parent
CLIENT_DIR = ROOT_DIR / "client"
RESOURCES_DIR = ROOT_DIR / "resources"

# Explicitly load the .env file at the project root, using an absolute path so
# it does not depend on the current working directory.
# NOTE(review): this runs after `.tts_service` has been imported; if that module
# reads environment variables at import time it will not see these values — confirm.
load_dotenv(dotenv_path=ROOT_DIR / ".env")
|
|
|
|
app = FastAPI()

# Service logger; used below to record TTS generation failures.
logger = logging.getLogger("tts")

# Accept cross-origin requests from any origin, with any method and header.
# NOTE(review): the client page is served by this same app (see the "/" route),
# so CORS may not be needed at all. Combining allow_origins=["*"] with
# allow_credentials=True makes Starlette echo the caller's Origin, which lets
# any site make credentialed requests. Restrict the origin list if cookies/auth
# are ever introduced — confirm the intended policy.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
# Serve client/static under the /static URL prefix and render HTML pages from
# client/templates.
app.mount("/static", StaticFiles(directory=str(CLIENT_DIR / "static")), name="static")
templates = Jinja2Templates(directory=str(CLIENT_DIR / "templates"))
|
|
|
|
|
|
class TtsCreateRequest(BaseModel):
    # Request body for POST /api/tts.
    # Text to synthesize; the endpoint strips it and requires at least 11 characters.
    text: str
    # Optional voice name; the endpoint strips and lowercases it ("" when omitted).
    voice: str | None = None
|
|
|
|
|
|
class TtsDeleteRequest(BaseModel):
    # Request body for DELETE /api/tts: ids of the TTS entries to remove.
    ids: list[int]
|
|
|
|
|
|
def format_display_time(dt):
    """Format *dt* in the system's local timezone using a Korean date layout.

    The value is converted with ``astimezone()`` (naive datetimes are treated
    as local time) and rendered like ``"2024년 01월 02일 03:04:05"``.
    """
    korean_layout = "%Y년 %m월 %d일 %H:%M:%S"
    as_local = dt.astimezone()
    return as_local.strftime(korean_layout)
|
|
|
|
|
|
def ensure_resources_dir():
    """Create the MP3 output directory (and any missing parents) if absent."""
    target_dir = RESOURCES_DIR
    target_dir.mkdir(exist_ok=True, parents=True)
|
|
|
|
|
|
def format_size(bytes_size: int) -> str:
    """Format a byte count as a short human-readable string.

    Sizes below 1024 bytes are shown as whole bytes; larger sizes use one
    decimal place in KB or MB, where 1 KB = 1024 bytes.

    Args:
        bytes_size: Size in bytes.

    Returns:
        The formatted size, e.g. ``"512B"``, ``"1.5KB"``, ``"3.2MB"``.
    """
    if bytes_size < 1024:
        return f"{bytes_size}B"
    kb_text = f"{bytes_size / 1024:.1f}"
    # Choose the unit after rounding: sizes just under 1 MB (e.g. 1048575 bytes)
    # would otherwise print as "1024.0KB" instead of "1.0MB".
    if float(kb_text) < 1024:
        return f"{kb_text}KB"
    return f"{bytes_size / (1024 * 1024):.1f}MB"
|
|
|
|
|
|
def get_file_size_display(size_bytes: int | None) -> str | None:
|
|
if size_bytes is None:
|
|
return None
|
|
return format_size(size_bytes)
|
|
|
|
|
|
def get_file_size_bytes(filename: str | None) -> int | None:
|
|
if not filename:
|
|
return None
|
|
file_path = RESOURCES_DIR / filename
|
|
if not file_path.exists():
|
|
return None
|
|
return file_path.stat().st_size
|
|
|
|
|
|
@app.on_event("startup")
def on_startup():
    """Prepare local state before serving: the MP3 directory first, then the DB."""
    # NOTE(review): recent FastAPI versions deprecate `on_event` in favour of
    # lifespan handlers; kept as-is to avoid changing the app wiring.
    for prepare in (ensure_resources_dir, init_db):
        prepare()
|
|
|
|
|
|
@app.get("/")
def index(request: Request):
    """Render the client's ``index.html`` template."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
|
|
|
|
|
|
def _backfilled_size(record):
    """Return the row's size in bytes, measuring and persisting it when missing."""
    stored = record.get("size_bytes")
    if stored is not None or not record.get("filename"):
        return stored
    measured = get_file_size_bytes(record["filename"])
    if measured is None:
        # File is gone or unreadable: leave the column empty.
        return stored
    update_size_bytes(record["id"], measured)
    return measured


@app.get("/api/tts")
def api_list_tts():
    """List every stored TTS entry as JSON-serialisable dicts.

    Rows with an empty ``size_bytes`` column that reference a file have their
    size measured from disk and written back with ``update_size_bytes``
    (a lazy backfill), so later listings can skip the filesystem.
    """
    entries = []
    for record in list_items():
        # Resolve the size first so the backfill side effect happens before
        # any other field of the row is read.
        known_size = _backfilled_size(record)
        created = record["created_at"]
        entries.append(
            {
                "id": record["id"],
                "created_at": created.isoformat(),
                "display_time": format_display_time(created),
                "filename": record["filename"],
                "size_display": get_file_size_display(known_size),
            }
        )
    return entries
|
|
|
|
|
|
@app.post("/api/tts")
def api_create_tts(payload: TtsCreateRequest):
    """Synthesize speech for the submitted text and register the resulting MP3.

    Flow: validate the text, insert a DB row, render the MP3 under a name built
    from the row id and its creation time, then store the filename and size.

    Args:
        payload: Request body with ``text`` (required) and an optional ``voice``.

    Returns:
        dict with ``id``, ``created_at`` (ISO 8601), ``display_time``,
        ``filename`` and ``size_display`` (``None`` if the size could not be read).

    Raises:
        HTTPException: 400 if the stripped text is shorter than 11 characters;
            500 if MP3 generation fails. In that case the DB row and any
            partially written file are removed.
    """
    text = (payload.text or "").strip()
    # Voice is normalised to lowercase; a missing voice is passed on as "".
    voice = (payload.voice or "").strip().lower()
    if len(text) < 11:
        raise HTTPException(status_code=400, detail="텍스트는 11글자 이상이어야 합니다.")

    created = create_item(text)
    tts_id = created["id"]
    created_at = created["created_at"]

    timestamp = created_at.astimezone().strftime("%Y%m%d_%H%M%S")
    filename = f"tts_{tts_id}_{timestamp}.mp3"
    mp3_path = RESOURCES_DIR / filename

    try:
        text_to_mp3(text=text, mp3_path=str(mp3_path), voice=voice)
    except Exception as exc:
        logger.exception("TTS 생성 실패")
        # Roll back completely: a failure partway through writing may leave a
        # partial MP3 that no DB row references any more. Cleanup is
        # best-effort so the original error is still reported to the client.
        try:
            mp3_path.unlink(missing_ok=True)
        except OSError:
            logger.warning("부분 생성된 MP3 삭제 실패: %s", mp3_path, exc_info=True)
        delete_item_by_id(tts_id)
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    size_bytes = get_file_size_bytes(filename)
    update_filename(tts_id, filename)
    if size_bytes is not None:
        update_size_bytes(tts_id, size_bytes)

    return {
        "id": tts_id,
        "created_at": created_at.isoformat(),
        "display_time": format_display_time(created_at),
        "filename": filename,
        "size_display": get_file_size_display(size_bytes),
    }
|
|
|
|
|
|
@app.get("/api/tts/{tts_id}")
def api_get_tts(tts_id: int):
    """Return one TTS entry, including its text and a download URL.

    Raises:
        HTTPException: 404 when no row exists for *tts_id*.
    """
    record = get_item(tts_id)
    if not record:
        raise HTTPException(status_code=404, detail="해당 항목이 없습니다.")

    item_id = record["id"]
    created = record["created_at"]
    return {
        "id": item_id,
        "text": record["text"],
        "created_at": created.isoformat(),
        "display_time": format_display_time(created),
        "filename": record["filename"],
        "download_url": f"/api/tts/{item_id}/download",
    }
|
|
|
|
|
|
@app.get("/api/tts/{tts_id}/download")
def api_download_tts(tts_id: int):
    """Send the MP3 for *tts_id* as an ``audio/mpeg`` file download.

    Raises:
        HTTPException: 404 when the row is missing, has no filename, or the
            file no longer exists on disk.
    """
    record = get_item(tts_id)
    stored_name = record["filename"] if record else None
    file_path = RESOURCES_DIR / stored_name if stored_name else None
    if file_path is None or not file_path.exists():
        raise HTTPException(status_code=404, detail="파일이 없습니다.")

    return FileResponse(
        path=str(file_path),
        media_type="audio/mpeg",
        filename=stored_name,
    )
|
|
|
|
|
|
@app.delete("/api/tts")
def api_delete_tts(payload: TtsDeleteRequest):
    """Delete the given TTS entries and, best-effort, their MP3 files.

    Args:
        payload: Request body whose ``ids`` lists the entries to delete.

    Returns:
        ``{"deleted": [...]}`` with the ids of the rows returned by ``delete_items``.

    Raises:
        HTTPException: 400 when no usable id was supplied.
    """
    # Defensive filter: pydantic already validates ``ids`` as ints, but keep
    # only ints or digit strings in case the model is constructed unvalidated.
    ids = [int(i) for i in payload.ids if isinstance(i, int) or str(i).isdigit()]
    if not ids:
        raise HTTPException(status_code=400, detail="삭제할 항목이 없습니다.")

    deleted_ids = []
    for row in delete_items(ids):
        deleted_ids.append(row["id"])
        filename = row.get("filename")
        if not filename:
            continue
        file_path = RESOURCES_DIR / filename
        # The DB rows are already gone, so a failed file removal is logged
        # rather than failing the request. missing_ok avoids the racy
        # exists()/unlink() pair.
        try:
            file_path.unlink(missing_ok=True)
        except OSError:
            logger.warning("MP3 파일 삭제 실패: %s", file_path, exc_info=True)

    return {"deleted": deleted_ids}
|