485 lines
15 KiB
Python
485 lines
15 KiB
Python
from __future__ import annotations

import asyncio
import dataclasses
import logging
import mimetypes
import os
import re
import tempfile
import threading
import time
from pathlib import Path
from typing import Any
from uuid import uuid4

from dotenv import load_dotenv
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from . import db
from .stt import transcribe_file, transcribe_iter
|
|
|
|
|
|
# Load .env before any os.getenv()-based setting below is evaluated.
load_dotenv()

# Filesystem layout, all derived from the location of this module.
APP_ROOT = Path(__file__).resolve().parent
PROJECT_ROOT = APP_ROOT.parent
STATIC_DIR = APP_ROOT.joinpath("static")
UPLOAD_DIR = PROJECT_ROOT.joinpath("resources", "uploads")
|
|
|
|
# Upload allow-lists consulted by _validate_upload().
# An upload must carry one of these file extensions ...
ALLOWED_EXTS = {".mp3", ".m4a", ".wav", ".mp4", ".aac", ".ogg", ".flac", ".webm"}
# ... and a MIME type is acceptable when it starts with one of these prefixes
# or is listed explicitly in ALLOWED_MIMES.
ALLOWED_MIME_PREFIXES = ("audio/",)
ALLOWED_MIMES = {
    "video/mp4",  # m4a is commonly reported as video/mp4
    # .webm is an allowed extension, but the stdlib mimetypes table maps it to
    # video/webm; without this entry a .webm upload whose declared type is
    # also video/webm could never pass validation.
    "video/webm",
    "application/octet-stream",  # some browser/OS combinations
}
|
|
|
|
# Upload size cap, configurable via APP_MAX_UPLOAD_MB (megabytes).
MAX_UPLOAD_MB = int(os.environ.get("APP_MAX_UPLOAD_MB", "200"))
MAX_UPLOAD_BYTES = MAX_UPLOAD_MB * 1024**2

# Seconds a finished job may stay idle in memory before _cleanup_jobs() evicts it.
JOB_TTL_SEC = int(os.environ.get("APP_JOB_TTL_SEC", "3600"))

# Author recorded when a request supplies none; a blank env value falls back too.
DEFAULT_AUTHOR_ID = (
    os.environ.get("APP_DEFAULT_AUTHOR_ID", "").strip() or "dosangyoon@gmail.com"
)
|
|
|
|
|
|
# The ASGI application; bundled front-end assets are served under /static.
app = FastAPI(title="Web STT")
app.mount(
    "/static",
    StaticFiles(directory=str(STATIC_DIR)),
    name="static",
)
|
|
|
|
|
|
@app.on_event("startup")
def _startup() -> None:
    """Initialise the database and make sure the upload directory exists."""
    # Auto-create the DB tables based on the .env configuration.
    db.init_db()
    UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
@dataclasses.dataclass
class _Job:
    """In-memory state of one asynchronous transcription job.

    Instances live in the module-level ``_JOBS`` registry and are mutated by
    the worker thread (``_run_job``) and the HTTP handlers under ``_JOBS_LOCK``.
    """

    # --- inputs captured when the job is created ---
    job_id: str
    filename: str  # client-supplied file name
    tmp_path: str  # path of the saved upload on disk
    language: str | None  # None requests automatic language detection
    vad_filter: bool
    beam_size: int
    author_id: str
    language_requested: str | None  # language exactly as submitted by the client

    # --- progress and results ---
    status: str = "queued"  # queued|running|completed|failed|cancelled
    progress: float | None = 0.0  # 0..1, or None when the duration is unknown
    text: str = ""
    segments: list[dict[str, Any]] = dataclasses.field(default_factory=list)
    detected_language: str | None = None
    language_probability: float | None = None
    duration_sec: float | None = None
    error: str | None = None

    # --- bookkeeping ---
    created_at: float = dataclasses.field(default_factory=time.time)
    updated_at: float = dataclasses.field(default_factory=time.time)
    # Set by the cancel endpoint; polled by the worker between segments.
    cancel_event: threading.Event = dataclasses.field(
        default_factory=threading.Event,
        repr=False,
    )
|
|
|
|
|
|
# Registry of live jobs keyed by job_id. Every read or write of the registry,
# and of the _Job objects stored in it, is done while holding _JOBS_LOCK.
_JOBS: dict[str, _Job] = dict()
_JOBS_LOCK = threading.Lock()
|
|
|
|
|
|
def _cleanup_jobs(now: float | None = None) -> None:
    """Evict finished jobs that have been idle for longer than JOB_TTL_SEC.

    Jobs that are still queued or running are never evicted. For each evicted
    job the file at ``job.tmp_path`` is removed on a best-effort basis.

    Args:
        now: Reference timestamp in seconds; defaults to ``time.time()``.
    """
    if now is None:
        now = time.time()

    with _JOBS_LOCK:
        expired = [
            jid
            for jid, candidate in _JOBS.items()
            if candidate.status not in ("running", "queued")
            and now - candidate.updated_at > JOB_TTL_SEC
        ]
        for jid in expired:
            evicted = _JOBS.pop(jid, None)
            if evicted is None:
                continue
            # NOTE(review): other comments in this file say uploads are kept
            # under resources/uploads; this removal deletes them after the TTL.
            # Confirm which behaviour is intended.
            try:
                os.remove(evicted.tmp_path)
            except OSError:
                pass
|
|
|
|
|
|
def _job_public(job: _Job) -> dict[str, Any]:
    """Return the client-facing view of *job* as a plain dict.

    Only the attributes listed below are exposed, in this order; internal
    fields such as the on-disk path are left out. Values are returned as-is
    (the ``segments`` list is not copied).
    """
    exposed = (
        "job_id",
        "filename",
        "status",
        "progress",
        "text",
        "segments",
        "detected_language",
        "language_probability",
        "duration_sec",
        "error",
        "created_at",
        "updated_at",
    )
    return {attr: getattr(job, attr) for attr in exposed}
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
def index() -> HTMLResponse:
    # Serve the front-end page: static/index.html, read as UTF-8 on each request.
    html = (STATIC_DIR / "index.html").read_text(encoding="utf-8")
    return HTMLResponse(html)
|
|
|
|
|
|
@app.post("/api/jobs")
async def api_create_job(
    file: UploadFile = File(...),
    language: str = Form(default="ko"),
    vad_filter: bool = Form(default=True),
    beam_size: int = Form(default=5),
    author_id: str = Form(default=DEFAULT_AUTHOR_ID),
) -> dict[str, Any]:
    # Create an asynchronous transcription job: validate and store the upload,
    # register a queued _Job, start a daemon worker thread for it, and return
    # {"job_id": ...} so the client can poll for progress.
    _cleanup_jobs()
    _validate_upload(file)

    job_id = str(uuid4())
    saved_path = await _save_upload(file, file_id=job_id)

    requested = language.strip()
    normalized = requested.lower()
    new_job = _Job(
        job_id=job_id,
        filename=file.filename,
        tmp_path=saved_path,
        # An empty value or "auto" means: let the engine detect the language.
        language=None if normalized in ("", "auto") else normalized,
        vad_filter=bool(vad_filter),
        beam_size=int(beam_size),
        author_id=author_id.strip() or DEFAULT_AUTHOR_ID,
        language_requested=requested or None,
        status="queued",
    )

    with _JOBS_LOCK:
        _JOBS[job_id] = new_job

    worker = threading.Thread(target=_run_job, args=(job_id,), daemon=True)
    worker.start()
    return {"job_id": job_id}
|
|
|
|
|
|
@app.get("/api/jobs/{job_id}")
def api_get_job(job_id: str) -> dict[str, Any]:
    # Return the public state of a job (404 if unknown). Polling also refreshes
    # updated_at, which postpones TTL eviction of a finished job.
    _cleanup_jobs()
    with _JOBS_LOCK:
        if (job := _JOBS.get(job_id)) is None:
            raise HTTPException(status_code=404, detail="job not found")
        job.updated_at = time.time()
        return _job_public(job)
|
|
|
|
|
|
@app.post("/api/jobs/{job_id}/cancel")
def api_cancel_job(job_id: str) -> dict[str, Any]:
    # Request cancellation of a job (404 if unknown). Jobs already in a final
    # state are returned unchanged; otherwise the cancel flag is raised for the
    # worker to pick up. The job's public state is returned in both cases.
    with _JOBS_LOCK:
        job = _JOBS.get(job_id)
        if job is None:
            raise HTTPException(status_code=404, detail="job not found")
        already_final = job.status in ("completed", "failed", "cancelled")
        if not already_final:
            job.cancel_event.set()
            job.updated_at = time.time()
        return _job_public(job)
|
|
|
|
|
|
@app.post("/api/transcribe")
async def api_transcribe(
    file: UploadFile = File(...),
    language: str = Form(default="ko"),
    vad_filter: bool = Form(default=True),
    beam_size: int = Form(default=5),
    author_id: str = Form(default=DEFAULT_AUTHOR_ID),
) -> dict[str, Any]:
    # One-shot transcription: validate and store the upload, transcribe it,
    # store the result in the DB (best effort) and return the transcription
    # result. Upload validation/size errors surface as HTTPException from
    # _validate_upload / _save_upload.
    _validate_upload(file)

    # The uploaded file is kept under resources/uploads (it is not deleted).
    saved_path = await _save_upload(file, file_id=str(uuid4()))

    requested = language.strip()
    lang = requested.lower()
    if lang in ("", "auto"):
        lang = ""  # empty means: let the engine detect the language

    # transcribe_file is a synchronous call; running it in a worker thread
    # keeps the event loop free to serve other requests (e.g. job polling)
    # while the transcription is in progress.
    result = await asyncio.to_thread(
        transcribe_file,
        saved_path,
        language=(lang or None),
        vad_filter=bool(vad_filter),
        beam_size=int(beam_size),
    )

    # The one-shot API stores its result in the DB as well. A storage failure
    # must not fail the request, but it is logged instead of silently dropped.
    try:
        db.insert_record(
            author_id=(author_id.strip() or DEFAULT_AUTHOR_ID),
            filename=file.filename,
            language_requested=(requested or None),
            detected_language=result.get("detected_language"),
            language_probability=result.get("language_probability"),
            duration_sec=result.get("duration_sec"),
            status="completed",
            text=result.get("text") or "",
            segments=result.get("segments") or [],
        )
    except Exception:
        logging.getLogger(__name__).exception(
            "failed to store transcription record for %r", file.filename
        )
    return result
|
|
|
|
|
|
@app.get("/healthz")
def healthz() -> dict[str, str]:
    # Liveness probe: always answers with a constant payload.
    payload = {"status": "ok"}
    return payload
|
|
|
|
|
|
@app.get("/api/records")
def api_list_records(
    limit: int = 50,
    offset: int = 0,
    author_id: str | None = None,
    q: str | None = None,
) -> dict[str, Any]:
    # List stored records; paging and filter arguments are forwarded unchanged
    # to db.list_records, whose result is returned as-is.
    query = {"limit": limit, "offset": offset, "author_id": author_id, "q": q}
    return db.list_records(**query)
|
|
|
|
|
|
@app.get("/api/records/{record_id}")
def api_get_record(record_id: int) -> dict[str, Any]:
    # Fetch a single stored record; responds 404 when db.get_record finds none.
    record = db.get_record(int(record_id))
    if record is not None:
        return record
    raise HTTPException(status_code=404, detail="record not found")
|
|
|
|
|
|
class _UpdateRecordIn(BaseModel):
    # Request body of PUT /api/records/{record_id}. Every field is optional;
    # a field left out arrives as None and is forwarded as None to
    # db.update_record (presumably meaning "leave unchanged" - verify in db).
    author_id: str | None = None
    text: str | None = None
    status: str | None = None
|
|
|
|
|
|
@app.put("/api/records/{record_id}")
def api_update_record(record_id: int, payload: _UpdateRecordIn) -> dict[str, Any]:
    # Update a stored record. A supplied author_id is trimmed and must not be
    # blank (400); a record that db.update_record cannot find yields 404.
    cleaned_author = None if payload.author_id is None else payload.author_id.strip()
    if cleaned_author == "":
        raise HTTPException(status_code=400, detail="author_id는 비울 수 없습니다.")

    updated = db.update_record(
        int(record_id),
        author_id=cleaned_author,
        text=payload.text,
        status=payload.status,
    )
    if updated is None:
        raise HTTPException(status_code=404, detail="record not found")
    return updated
|
|
|
|
|
|
@app.delete("/api/records/{record_id}")
def api_delete_record(record_id: int) -> dict[str, Any]:
    # Delete a stored record; responds 404 when db.delete_record reports that
    # nothing was deleted.
    if db.delete_record(int(record_id)):
        return {"deleted": True}
    raise HTTPException(status_code=404, detail="record not found")
|
|
|
|
|
|
def _validate_upload(file: UploadFile) -> None:
    """Reject uploads that do not look like an allowed audio file.

    The extension must be in ``ALLOWED_EXTS``. The MIME check is lenient: the
    upload is refused only when both the declared content type and the type
    guessed from the file name are present and neither is acceptable.

    Raises:
        HTTPException: 400 when no file or file name was sent; 415 for a
            disallowed extension or MIME type.
    """
    if not file or not file.filename:
        raise HTTPException(status_code=400, detail="파일이 필요합니다.")

    ext = Path(file.filename).suffix.lower()
    if ext not in ALLOWED_EXTS:
        raise HTTPException(
            status_code=415,
            detail=f"허용되지 않는 확장자입니다: {ext}. 허용: {sorted(ALLOWED_EXTS)}",
        )

    def is_unacceptable(mime: str) -> bool:
        # An empty (unknown) type is never counted against the upload.
        if not mime:
            return False
        return not (mime.startswith(ALLOWED_MIME_PREFIXES) or mime in ALLOWED_MIMES)

    content_type = (file.content_type or "").lower().strip()
    guessed = (mimetypes.guess_type(file.filename)[0] or "").lower()

    if is_unacceptable(content_type) and is_unacceptable(guessed):
        raise HTTPException(
            status_code=415,
            detail=f"오디오 파일만 업로드 가능합니다. content-type={content_type}, guessed={guessed}",
        )
|
|
|
|
|
|
# Any run of characters outside [A-Za-z0-9._-] is collapsed into a single "_".
_FILENAME_SAFE_RE = re.compile(r"[^A-Za-z0-9._-]+")


def _safe_filename(name: str) -> str:
    """Reduce a client-supplied file name to a safe, bounded on-disk name.

    Directory components are dropped (path-traversal guard), spaces become
    underscores, and every run of characters outside ``[A-Za-z0-9._-]`` is
    replaced by a single underscore. Names longer than 120 characters are cut
    to at most 100 stem characters plus at most 20 suffix characters.

    Args:
        name: File name as sent by the client; may contain a path.

    Returns:
        The sanitised name, or ``"upload.bin"`` when nothing usable is left.
    """
    # Treat "\" as a separator too: on POSIX hosts Path() only splits on "/",
    # so a Windows-style client path would otherwise be flattened into the
    # name instead of being reduced to its last component.
    base = Path(name.replace("\\", "/")).name
    base = base.strip().replace(" ", "_")
    base = _FILENAME_SAFE_RE.sub("_", base)
    if not base:
        return "upload.bin"
    if len(base) > 120:
        shortened = Path(base)
        base = f"{shortened.stem[:100]}{shortened.suffix[:20]}"
    return base
|
|
|
|
|
|
async def _save_upload(file: UploadFile, *, file_id: str) -> str:
    """Stream an upload to ``UPLOAD_DIR`` and return the path of the saved file.

    The file is stored as ``<file_id>_<sanitised original name>`` and is read
    in 1 MiB chunks so the size limit is enforced without buffering the whole
    upload in memory.

    Args:
        file: The incoming upload.
        file_id: Unique prefix for the stored file name.

    Returns:
        The saved file's path as a string.

    Raises:
        HTTPException: 413 when the upload exceeds ``MAX_UPLOAD_BYTES``.
    """
    UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
    safe = _safe_filename(file.filename or "upload.bin")
    tmp_path = str(UPLOAD_DIR / f"{file_id}_{safe}")

    try:
        total = 0
        with open(tmp_path, "wb") as out:
            while True:
                chunk = await file.read(1024 * 1024)
                if not chunk:
                    break
                total += len(chunk)
                if total > MAX_UPLOAD_BYTES:
                    raise HTTPException(
                        status_code=413,
                        detail=f"파일이 너무 큽니다. 최대 {MAX_UPLOAD_MB}MB 까지 업로드 가능합니다.",
                    )
                out.write(chunk)
    except BaseException:
        # The handle is closed by now, so the removal also works on Windows
        # (where an open file cannot be deleted). This also cleans up partial
        # files left by read/write errors or a cancelled request.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    return tmp_path
|
|
|
|
|
|
def _run_job(job_id: str) -> None:
    """Worker-thread entry point: transcribe one job and record its outcome.

    The job moves from ``queued`` to ``running`` and then to ``completed``,
    ``cancelled`` or ``failed``. Segments are published to the job as they are
    produced so that polling clients see partial text and progress. Whatever
    the outcome, the result is written to the DB on a best-effort basis.

    All access to the job goes through ``_JOBS`` under ``_JOBS_LOCK``; if the
    job has disappeared from the registry the worker stops silently. The
    uploaded file is kept under resources/uploads (it is not deleted here).

    Args:
        job_id: Key of the job in ``_JOBS``.
    """
    with _JOBS_LOCK:
        job = _JOBS.get(job_id)
        if job is None:
            return
        job.status = "running"
        job.progress = 0.0
        job.updated_at = time.time()
        # Snapshot the inputs so transcription runs without holding the lock.
        audio_path = job.tmp_path
        language = job.language
        vad_filter = job.vad_filter
        beam_size = job.beam_size

    try:
        segments_iter, info = transcribe_iter(
            audio_path,
            language=language,
            vad_filter=vad_filter,
            beam_size=beam_size,
        )

        with _JOBS_LOCK:
            job = _JOBS.get(job_id)
            if job is None:
                return
            job.duration_sec = getattr(info, "duration", None)
            job.detected_language = getattr(info, "language", None)
            job.language_probability = getattr(info, "language_probability", None)
            job.updated_at = time.time()

        texts: list[str] = []
        for s in segments_iter:
            seg_text = (getattr(s, "text", "") or "").strip()
            with _JOBS_LOCK:
                job = _JOBS.get(job_id)
                if job is None:
                    return
                if job.cancel_event.is_set():
                    # Stop consuming segments; the final status is set below.
                    break
                if not seg_text:
                    continue

                seg = {
                    "start": float(getattr(s, "start", 0.0)),
                    "end": float(getattr(s, "end", 0.0)),
                    "text": seg_text,
                }
                texts.append(seg_text)
                job.segments.append(seg)
                job.text = "\n".join(texts)
                if job.duration_sec and job.duration_sec > 0:
                    # Capped below 1.0: only the completion step reports 100 %.
                    job.progress = max(
                        0.0, min(0.999, seg["end"] / float(job.duration_sec))
                    )
                else:
                    job.progress = None  # duration unknown
                job.updated_at = time.time()

        with _JOBS_LOCK:
            job = _JOBS.get(job_id)
            if job is None:
                return
            if job.cancel_event.is_set():
                job.status = "cancelled"
            else:
                job.status = "completed"
                job.progress = 1.0
            job.updated_at = time.time()
            final_status = job.status

        # Store the result in the DB (both completed and cancelled jobs).
        _persist_job(job, status=final_status)

    except Exception as e:
        logging.getLogger(__name__).exception("transcription job %s failed", job_id)
        with _JOBS_LOCK:
            job = _JOBS.get(job_id)
            if job is None:
                return
            job.status = "failed"
            job.error = str(e)
            job.updated_at = time.time()
        _persist_job(job, status="failed")


def _persist_job(job: _Job, *, status: str) -> None:
    """Write the job's current result to the DB; log instead of raising on failure."""
    try:
        db.insert_record(
            author_id=job.author_id,
            filename=job.filename,
            language_requested=job.language_requested,
            detected_language=job.detected_language,
            language_probability=job.language_probability,
            duration_sec=job.duration_sec,
            status=status,
            text=job.text,
            segments=job.segments,
        )
    except Exception:
        # Best effort: the in-memory job already holds its final state, so a
        # storage failure is reported in the log rather than changing it.
        logging.getLogger(__name__).exception(
            "could not store DB record for job %s", job.job_id
        )
|
|
|