Persist file size at creation and backfill missing sizes on list responses so the UI can display sizes reliably.
212 lines
5.8 KiB
Python
212 lines
5.8 KiB
Python
from pathlib import Path
|
|
from typing import List
|
|
|
|
import logging
|
|
|
|
from dotenv import load_dotenv
|
|
from fastapi import FastAPI, HTTPException, Request
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import FileResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi.templating import Jinja2Templates
|
|
from pydantic import BaseModel
|
|
|
|
from .db import (
|
|
init_db,
|
|
create_item,
|
|
update_filename,
|
|
update_size_bytes,
|
|
list_items,
|
|
get_item,
|
|
delete_items,
|
|
delete_item_by_id,
|
|
)
|
|
from .tts_service import text_to_mp3
|
|
|
|
# Directory layout, all anchored at this file's location so it does not depend
# on the current working directory.
BASE_DIR = Path(__file__).resolve().parent  # package directory of this module
ROOT_DIR = BASE_DIR.parent  # project root (one level above the package)
CLIENT_DIR = ROOT_DIR.joinpath("client")  # static assets and templates
RESOURCES_DIR = ROOT_DIR.joinpath("resources")  # generated mp3 files

# Explicitly load the .env file located in the project root.
# NOTE(review): `.tts_service` is imported before this call runs; if that module
# reads environment variables at import time it will not see values from this
# .env — confirm.
load_dotenv(dotenv_path=ROOT_DIR / ".env")
|
|
|
|
# FastAPI application instance; every route in this module is registered on it.
app = FastAPI()

# Module logger named "tts".
logger = logging.getLogger("tts")

# Accept cross-origin requests from any origin, with any method and header.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is very
# permissive (Starlette may reflect the request Origin for credentialed
# requests) — confirm this is intended before exposing the app publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve client/static under /static and render pages from client/templates.
app.mount("/static", StaticFiles(directory=str(CLIENT_DIR / "static")), name="static")
templates = Jinja2Templates(directory=str(CLIENT_DIR / "templates"))
|
|
|
|
|
|
# Request body for POST /api/tts. The endpoint strips `text` and rejects
# results shorter than 11 characters.
class TtsCreateRequest(BaseModel):
    # Text to synthesize into an mp3.
    text: str
|
|
|
|
|
|
# Request body for DELETE /api/tts.
class TtsDeleteRequest(BaseModel):
    # Ids of the TTS records to delete; declared as a list of ints.
    ids: List[int]
|
|
|
|
|
|
def format_display_time(dt):
    """Render *dt* in the local timezone using Korean date notation.

    Output looks like ``2024년 01월 02일 03:04:05``. A naive *dt* is treated
    as local time by ``datetime.astimezone``.
    """
    # datetime.__format__ delegates to strftime with the given pattern.
    return f"{dt.astimezone():%Y년 %m월 %d일 %H:%M:%S}"
|
|
|
|
|
|
def ensure_resources_dir():
    """Create the mp3 output directory, including missing parents, if absent."""
    target_dir = RESOURCES_DIR
    target_dir.mkdir(exist_ok=True, parents=True)
|
|
|
|
|
|
def format_size(bytes_size: int) -> str:
    """Format a byte count as a short human-readable label.

    Uses 1024-based units: plain bytes below 1024, otherwise one decimal
    place in ``KB`` or ``MB``.

    Args:
        bytes_size: Size in bytes.

    Returns:
        A label such as ``"512B"``, ``"1.5KB"`` or ``"3.2MB"``.
    """
    if bytes_size < 1024:
        return f"{bytes_size}B"
    size_kb = bytes_size / 1024
    # Pick the unit from the value as it will be *displayed*: 1048575 B is
    # 1023.999 KB, which would otherwise print as "1024.0KB" instead of
    # rolling over to "1.0MB".
    if round(size_kb, 1) < 1024:
        return f"{size_kb:.1f}KB"
    return f"{bytes_size / (1024 * 1024):.1f}MB"
|
|
|
|
|
|
def get_file_size_display(size_bytes: int | None) -> str | None:
|
|
if size_bytes is None:
|
|
return None
|
|
return format_size(size_bytes)
|
|
|
|
|
|
def get_file_size_bytes(filename: str | None) -> int | None:
|
|
if not filename:
|
|
return None
|
|
file_path = RESOURCES_DIR / filename
|
|
if not file_path.exists():
|
|
return None
|
|
return file_path.stat().st_size
|
|
|
|
|
|
@app.on_event("startup")
def on_startup():
    """Prepare the process before serving: the mp3 directory first, then the DB."""
    for init_step in (ensure_resources_dir, init_db):
        init_step()
|
|
|
|
|
|
@app.get("/")
def index(request: Request):
    """Render the client's index.html template."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
|
|
|
|
|
|
@app.get("/api/tts")
def api_list_tts():
    """List every TTS record for the UI.

    Rows that have a filename but no stored ``size_bytes`` (e.g. rows created
    before sizes were persisted) get their size read from disk. When the file
    is found, the size is written back with ``update_size_bytes`` so later
    calls skip the disk lookup.
    """

    def _entry(row):
        # Resolve the size first, so the backfill side effect happens before
        # the response dict is assembled.
        size_bytes = row.get("size_bytes")
        needs_backfill = size_bytes is None and bool(row.get("filename"))
        if needs_backfill:
            on_disk = get_file_size_bytes(row["filename"])
            if on_disk is not None:
                update_size_bytes(row["id"], on_disk)
                size_bytes = on_disk
        return {
            "id": row["id"],
            "created_at": row["created_at"].isoformat(),
            "display_time": format_display_time(row["created_at"]),
            "filename": row["filename"],
            "size_display": get_file_size_display(size_bytes),
        }

    return [_entry(row) for row in list_items()]
|
|
|
|
|
|
@app.post("/api/tts")
def api_create_tts(payload: TtsCreateRequest):
    """Create a TTS record, synthesize its mp3, and return the new entry.

    The DB row is created first so its id can be embedded in the file name
    ``tts_<id>_<YYYYmmdd_HHMMSS>.mp3`` (local time). After synthesis, the file
    name and on-disk size are persisted.

    Raises:
        HTTPException: 400 if the stripped text is shorter than 11 characters;
            500 if synthesis fails (the DB row and any partial mp3 are removed).
    """
    text = (payload.text or "").strip()
    if len(text) < 11:
        raise HTTPException(status_code=400, detail="텍스트는 11글자 이상이어야 합니다.")

    created = create_item(text)
    tts_id = created["id"]
    created_at = created["created_at"]

    timestamp = created_at.astimezone().strftime("%Y%m%d_%H%M%S")
    filename = f"tts_{tts_id}_{timestamp}.mp3"
    mp3_path = RESOURCES_DIR / filename

    try:
        text_to_mp3(text=text, mp3_path=str(mp3_path))
    except Exception as exc:
        logger.exception("TTS 생성 실패")
        # Roll back both halves of the half-created item. A synthesis error may
        # leave a partially written mp3 that no DB row would ever reference.
        # Removal is best-effort so it never masks the original error.
        try:
            mp3_path.unlink(missing_ok=True)
        except OSError:
            logger.warning("불완전한 mp3 파일 삭제 실패: %s", mp3_path, exc_info=True)
        delete_item_by_id(tts_id)
        # NOTE(review): str(exc) exposes internal error text to the client.
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    # Persist the size now so the list endpoint does not have to stat the file.
    size_bytes = get_file_size_bytes(filename)
    update_filename(tts_id, filename)
    if size_bytes is not None:
        update_size_bytes(tts_id, size_bytes)

    return {
        "id": tts_id,
        "created_at": created_at.isoformat(),
        "display_time": format_display_time(created_at),
        "filename": filename,
        "size_display": get_file_size_display(size_bytes),
    }
|
|
|
|
|
|
@app.get("/api/tts/{tts_id}")
def api_get_tts(tts_id: int):
    """Return one TTS record, including its text and a download URL.

    Raises:
        HTTPException: 404 when no record exists for *tts_id*.
    """
    row = get_item(tts_id)
    if not row:
        raise HTTPException(status_code=404, detail="해당 항목이 없습니다.")

    record_id = row["id"]
    created_at = row["created_at"]
    return {
        "id": record_id,
        "text": row["text"],
        "created_at": created_at.isoformat(),
        "display_time": format_display_time(created_at),
        "filename": row["filename"],
        "download_url": f"/api/tts/{record_id}/download",
    }
|
|
|
|
|
|
@app.get("/api/tts/{tts_id}/download")
def api_download_tts(tts_id: int):
    """Serve the mp3 for *tts_id* via FileResponse with media type audio/mpeg.

    Raises:
        HTTPException: 404 when the record is missing, has no file name yet,
            or its file is absent from RESOURCES_DIR.
    """
    row = get_item(tts_id)
    stored_name = row["filename"] if row else None
    mp3_file = RESOURCES_DIR / stored_name if stored_name else None
    if mp3_file is None or not mp3_file.exists():
        raise HTTPException(status_code=404, detail="파일이 없습니다.")

    return FileResponse(
        path=str(mp3_file),
        media_type="audio/mpeg",
        filename=stored_name,
    )
|
|
|
|
|
|
@app.delete("/api/tts")
def api_delete_tts(payload: TtsDeleteRequest):
    """Delete TTS records and their mp3 files.

    Returns:
        ``{"deleted": [...]}`` with the ids of the rows ``delete_items`` reported
        as deleted.

    Raises:
        HTTPException: 400 when no usable id is supplied.
    """
    # Keep only integer-like ids.
    ids = [int(i) for i in payload.ids if isinstance(i, int) or str(i).isdigit()]
    if not ids:
        raise HTTPException(status_code=400, detail="삭제할 항목이 없습니다.")

    deleted_rows = delete_items(ids)
    deleted_ids = []
    for row in deleted_rows:
        deleted_ids.append(row["id"])
        filename = row.get("filename")
        if not filename:
            continue
        # Best-effort: the DB row is already gone, so a file that cannot be
        # removed must not fail the request. It is logged rather than silently
        # ignored. missing_ok avoids the exists()/unlink() race.
        try:
            (RESOURCES_DIR / filename).unlink(missing_ok=True)
        except OSError:
            logger.warning("mp3 파일 삭제 실패: %s", filename, exc_info=True)

    return {"deleted": deleted_ids}
|