Initial commit after re-install
This commit is contained in:
484
app/main.py
Normal file
484
app/main.py
Normal file
@@ -0,0 +1,484 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import mimetypes
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from pydantic import BaseModel
|
||||
|
||||
from . import db
|
||||
from .stt import transcribe_file, transcribe_iter
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
APP_ROOT = Path(__file__).resolve().parent
|
||||
PROJECT_ROOT = APP_ROOT.parent
|
||||
STATIC_DIR = APP_ROOT / "static"
|
||||
UPLOAD_DIR = PROJECT_ROOT / "resources" / "uploads"
|
||||
|
||||
ALLOWED_EXTS = {".mp3", ".m4a", ".wav", ".mp4", ".aac", ".ogg", ".flac", ".webm"}
|
||||
ALLOWED_MIME_PREFIXES = ("audio/",)
|
||||
ALLOWED_MIMES = {
|
||||
"video/mp4", # m4a가 video/mp4로 인식되는 경우가 흔함
|
||||
"application/octet-stream", # 일부 브라우저/OS 조합
|
||||
}
|
||||
|
||||
MAX_UPLOAD_MB = int(os.getenv("APP_MAX_UPLOAD_MB", "200"))
|
||||
MAX_UPLOAD_BYTES = MAX_UPLOAD_MB * 1024 * 1024
|
||||
|
||||
JOB_TTL_SEC = int(os.getenv("APP_JOB_TTL_SEC", "3600"))
|
||||
DEFAULT_AUTHOR_ID = os.getenv("APP_DEFAULT_AUTHOR_ID", "dosangyoon@gmail.com").strip() or "dosangyoon@gmail.com"
|
||||
|
||||
|
||||
app = FastAPI(title="Web STT")
|
||||
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
def _startup() -> None:
|
||||
# .env 기반으로 DB 테이블 자동 생성
|
||||
db.init_db()
|
||||
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _Job:
|
||||
job_id: str
|
||||
filename: str
|
||||
tmp_path: str
|
||||
language: str | None
|
||||
vad_filter: bool
|
||||
beam_size: int
|
||||
author_id: str
|
||||
language_requested: str | None
|
||||
status: str = "queued" # queued|running|completed|failed|cancelled
|
||||
progress: float | None = 0.0
|
||||
text: str = ""
|
||||
segments: list[dict[str, Any]] = dataclasses.field(default_factory=list)
|
||||
detected_language: str | None = None
|
||||
language_probability: float | None = None
|
||||
duration_sec: float | None = None
|
||||
error: str | None = None
|
||||
created_at: float = dataclasses.field(default_factory=time.time)
|
||||
updated_at: float = dataclasses.field(default_factory=time.time)
|
||||
cancel_event: threading.Event = dataclasses.field(default_factory=threading.Event, repr=False)
|
||||
|
||||
|
||||
_JOBS: dict[str, _Job] = {}
|
||||
_JOBS_LOCK = threading.Lock()
|
||||
|
||||
|
||||
def _cleanup_jobs(now: float | None = None) -> None:
|
||||
now = time.time() if now is None else now
|
||||
to_delete: list[str] = []
|
||||
with _JOBS_LOCK:
|
||||
for job_id, job in _JOBS.items():
|
||||
if job.status in ("running", "queued"):
|
||||
continue
|
||||
if now - job.updated_at > JOB_TTL_SEC:
|
||||
to_delete.append(job_id)
|
||||
for job_id in to_delete:
|
||||
job = _JOBS.pop(job_id, None)
|
||||
if job is None:
|
||||
continue
|
||||
try:
|
||||
os.remove(job.tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def _job_public(job: _Job) -> dict[str, Any]:
|
||||
return {
|
||||
"job_id": job.job_id,
|
||||
"filename": job.filename,
|
||||
"status": job.status,
|
||||
"progress": job.progress,
|
||||
"text": job.text,
|
||||
"segments": job.segments,
|
||||
"detected_language": job.detected_language,
|
||||
"language_probability": job.language_probability,
|
||||
"duration_sec": job.duration_sec,
|
||||
"error": job.error,
|
||||
"created_at": job.created_at,
|
||||
"updated_at": job.updated_at,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
def index() -> HTMLResponse:
|
||||
index_path = STATIC_DIR / "index.html"
|
||||
return HTMLResponse(index_path.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
@app.post("/api/jobs")
|
||||
async def api_create_job(
|
||||
file: UploadFile = File(...),
|
||||
language: str = Form(default="ko"),
|
||||
vad_filter: bool = Form(default=True),
|
||||
beam_size: int = Form(default=5),
|
||||
author_id: str = Form(default=DEFAULT_AUTHOR_ID),
|
||||
) -> dict[str, Any]:
|
||||
_cleanup_jobs()
|
||||
_validate_upload(file)
|
||||
job_id = str(uuid4())
|
||||
saved_path = await _save_upload(file, file_id=job_id)
|
||||
|
||||
lang = language.strip().lower()
|
||||
if lang in ("", "auto"):
|
||||
lang = ""
|
||||
|
||||
job = _Job(
|
||||
job_id=job_id,
|
||||
filename=file.filename,
|
||||
tmp_path=saved_path,
|
||||
language=(lang or None),
|
||||
vad_filter=bool(vad_filter),
|
||||
beam_size=int(beam_size),
|
||||
author_id=(author_id.strip() or DEFAULT_AUTHOR_ID),
|
||||
language_requested=(language.strip() or None),
|
||||
status="queued",
|
||||
)
|
||||
|
||||
with _JOBS_LOCK:
|
||||
_JOBS[job_id] = job
|
||||
|
||||
threading.Thread(target=_run_job, args=(job_id,), daemon=True).start()
|
||||
return {"job_id": job_id}
|
||||
|
||||
|
||||
@app.get("/api/jobs/{job_id}")
|
||||
def api_get_job(job_id: str) -> dict[str, Any]:
|
||||
_cleanup_jobs()
|
||||
with _JOBS_LOCK:
|
||||
job = _JOBS.get(job_id)
|
||||
if job is None:
|
||||
raise HTTPException(status_code=404, detail="job not found")
|
||||
job.updated_at = time.time()
|
||||
return _job_public(job)
|
||||
|
||||
|
||||
@app.post("/api/jobs/{job_id}/cancel")
|
||||
def api_cancel_job(job_id: str) -> dict[str, Any]:
|
||||
with _JOBS_LOCK:
|
||||
job = _JOBS.get(job_id)
|
||||
if job is None:
|
||||
raise HTTPException(status_code=404, detail="job not found")
|
||||
if job.status in ("completed", "failed", "cancelled"):
|
||||
return _job_public(job)
|
||||
job.cancel_event.set()
|
||||
job.updated_at = time.time()
|
||||
return _job_public(job)
|
||||
|
||||
|
||||
@app.post("/api/transcribe")
|
||||
async def api_transcribe(
|
||||
file: UploadFile = File(...),
|
||||
language: str = Form(default="ko"),
|
||||
vad_filter: bool = Form(default=True),
|
||||
beam_size: int = Form(default=5),
|
||||
author_id: str = Form(default=DEFAULT_AUTHOR_ID),
|
||||
) -> dict[str, Any]:
|
||||
_validate_upload(file)
|
||||
|
||||
try:
|
||||
file_id = str(uuid4())
|
||||
saved_path = await _save_upload(file, file_id=file_id)
|
||||
lang = language.strip().lower()
|
||||
if lang in ("", "auto"):
|
||||
lang = ""
|
||||
result = transcribe_file(
|
||||
saved_path,
|
||||
language=(lang or None),
|
||||
vad_filter=bool(vad_filter),
|
||||
beam_size=int(beam_size),
|
||||
)
|
||||
# 단발성 API도 DB 저장
|
||||
try:
|
||||
db.insert_record(
|
||||
author_id=(author_id.strip() or DEFAULT_AUTHOR_ID),
|
||||
filename=file.filename,
|
||||
language_requested=(language.strip() or None),
|
||||
detected_language=result.get("detected_language"),
|
||||
language_probability=result.get("language_probability"),
|
||||
duration_sec=result.get("duration_sec"),
|
||||
status="completed",
|
||||
text=result.get("text") or "",
|
||||
segments=result.get("segments") or [],
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
finally:
|
||||
# 업로드 파일은 resources/uploads 아래에 보관 (삭제하지 않음)
|
||||
pass
|
||||
|
||||
|
||||
@app.get("/healthz")
|
||||
def healthz() -> dict[str, str]:
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.get("/api/records")
|
||||
def api_list_records(limit: int = 50, offset: int = 0, author_id: str | None = None, q: str | None = None) -> dict[str, Any]:
|
||||
return db.list_records(limit=limit, offset=offset, author_id=author_id, q=q)
|
||||
|
||||
|
||||
@app.get("/api/records/{record_id}")
|
||||
def api_get_record(record_id: int) -> dict[str, Any]:
|
||||
row = db.get_record(int(record_id))
|
||||
if row is None:
|
||||
raise HTTPException(status_code=404, detail="record not found")
|
||||
return row
|
||||
|
||||
|
||||
class _UpdateRecordIn(BaseModel):
|
||||
author_id: str | None = None
|
||||
text: str | None = None
|
||||
status: str | None = None
|
||||
|
||||
|
||||
@app.put("/api/records/{record_id}")
|
||||
def api_update_record(record_id: int, payload: _UpdateRecordIn) -> dict[str, Any]:
|
||||
author_id = payload.author_id
|
||||
if author_id is not None:
|
||||
author_id = author_id.strip()
|
||||
if not author_id:
|
||||
raise HTTPException(status_code=400, detail="author_id는 비울 수 없습니다.")
|
||||
|
||||
row = db.update_record(int(record_id), author_id=author_id, text=payload.text, status=payload.status)
|
||||
if row is None:
|
||||
raise HTTPException(status_code=404, detail="record not found")
|
||||
return row
|
||||
|
||||
|
||||
@app.delete("/api/records/{record_id}")
|
||||
def api_delete_record(record_id: int) -> dict[str, Any]:
|
||||
ok = db.delete_record(int(record_id))
|
||||
if not ok:
|
||||
raise HTTPException(status_code=404, detail="record not found")
|
||||
return {"deleted": True}
|
||||
|
||||
|
||||
def _validate_upload(file: UploadFile) -> None:
|
||||
if not file or not file.filename:
|
||||
raise HTTPException(status_code=400, detail="파일이 필요합니다.")
|
||||
|
||||
ext = Path(file.filename).suffix.lower()
|
||||
if ext not in ALLOWED_EXTS:
|
||||
raise HTTPException(
|
||||
status_code=415,
|
||||
detail=f"허용되지 않는 확장자입니다: {ext}. 허용: {sorted(ALLOWED_EXTS)}",
|
||||
)
|
||||
|
||||
content_type = (file.content_type or "").lower().strip()
|
||||
guessed, _ = mimetypes.guess_type(file.filename)
|
||||
guessed = (guessed or "").lower()
|
||||
|
||||
def ok_mime(m: str) -> bool:
|
||||
return (m.startswith(ALLOWED_MIME_PREFIXES)) or (m in ALLOWED_MIMES)
|
||||
|
||||
if content_type and not ok_mime(content_type) and guessed and not ok_mime(guessed):
|
||||
raise HTTPException(
|
||||
status_code=415,
|
||||
detail=f"오디오 파일만 업로드 가능합니다. content-type={content_type}, guessed={guessed}",
|
||||
)
|
||||
|
||||
|
||||
_FILENAME_SAFE_RE = re.compile(r"[^A-Za-z0-9._-]+")
|
||||
|
||||
|
||||
def _safe_filename(name: str) -> str:
|
||||
base = Path(name).name # path traversal 방지
|
||||
base = base.strip().replace(" ", "_")
|
||||
base = _FILENAME_SAFE_RE.sub("_", base)
|
||||
if not base:
|
||||
return "upload.bin"
|
||||
if len(base) > 120:
|
||||
stem = Path(base).stem[:100]
|
||||
suf = Path(base).suffix[:20]
|
||||
base = f"{stem}{suf}"
|
||||
return base
|
||||
|
||||
|
||||
async def _save_upload(file: UploadFile, *, file_id: str) -> str:
|
||||
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
safe = _safe_filename(file.filename or "upload.bin")
|
||||
out_path = UPLOAD_DIR / f"{file_id}_{safe}"
|
||||
tmp_path = str(out_path)
|
||||
with open(tmp_path, "wb") as tmp:
|
||||
total = 0
|
||||
while True:
|
||||
chunk = await file.read(1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
total += len(chunk)
|
||||
if total > MAX_UPLOAD_BYTES:
|
||||
try:
|
||||
os.remove(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"파일이 너무 큽니다. 최대 {MAX_UPLOAD_MB}MB 까지 업로드 가능합니다.",
|
||||
)
|
||||
tmp.write(chunk)
|
||||
return tmp_path
|
||||
|
||||
|
||||
def _run_job(job_id: str) -> None:
|
||||
with _JOBS_LOCK:
|
||||
job = _JOBS.get(job_id)
|
||||
if job is None:
|
||||
return
|
||||
job.status = "running"
|
||||
job.progress = 0.0
|
||||
job.updated_at = time.time()
|
||||
|
||||
tmp_path: str | None = None
|
||||
cancelled = False
|
||||
try:
|
||||
with _JOBS_LOCK:
|
||||
job = _JOBS.get(job_id)
|
||||
if job is None:
|
||||
return
|
||||
tmp_path = job.tmp_path
|
||||
language = job.language
|
||||
vad_filter = job.vad_filter
|
||||
beam_size = job.beam_size
|
||||
author_id = job.author_id
|
||||
language_requested = job.language_requested
|
||||
filename = job.filename
|
||||
|
||||
segments_iter, info = transcribe_iter(
|
||||
tmp_path,
|
||||
language=language,
|
||||
vad_filter=vad_filter,
|
||||
beam_size=beam_size,
|
||||
)
|
||||
|
||||
duration = getattr(info, "duration", None)
|
||||
detected_language = getattr(info, "language", None)
|
||||
language_probability = getattr(info, "language_probability", None)
|
||||
|
||||
with _JOBS_LOCK:
|
||||
job = _JOBS.get(job_id)
|
||||
if job is None:
|
||||
return
|
||||
job.duration_sec = duration
|
||||
job.detected_language = detected_language
|
||||
job.language_probability = language_probability
|
||||
job.updated_at = time.time()
|
||||
|
||||
texts: list[str] = []
|
||||
for s in segments_iter:
|
||||
with _JOBS_LOCK:
|
||||
job = _JOBS.get(job_id)
|
||||
if job is None:
|
||||
return
|
||||
if job.cancel_event.is_set():
|
||||
job.status = "cancelled"
|
||||
job.updated_at = time.time()
|
||||
cancelled = True
|
||||
break
|
||||
|
||||
seg_text = (getattr(s, "text", "") or "").strip()
|
||||
if not seg_text:
|
||||
continue
|
||||
|
||||
seg = {
|
||||
"start": float(getattr(s, "start", 0.0)),
|
||||
"end": float(getattr(s, "end", 0.0)),
|
||||
"text": seg_text,
|
||||
}
|
||||
texts.append(seg_text)
|
||||
|
||||
with _JOBS_LOCK:
|
||||
job = _JOBS.get(job_id)
|
||||
if job is None:
|
||||
return
|
||||
if job.cancel_event.is_set():
|
||||
job.status = "cancelled"
|
||||
job.updated_at = time.time()
|
||||
cancelled = True
|
||||
break
|
||||
job.segments.append(seg)
|
||||
job.text = "\n".join(texts).strip()
|
||||
if job.duration_sec and job.duration_sec > 0:
|
||||
job.progress = max(0.0, min(0.999, float(seg["end"]) / float(job.duration_sec)))
|
||||
else:
|
||||
job.progress = None
|
||||
job.updated_at = time.time()
|
||||
|
||||
with _JOBS_LOCK:
|
||||
job = _JOBS.get(job_id)
|
||||
if job is None:
|
||||
return
|
||||
if cancelled or job.cancel_event.is_set():
|
||||
job.status = "cancelled"
|
||||
else:
|
||||
job.status = "completed"
|
||||
job.progress = 1.0
|
||||
job.updated_at = time.time()
|
||||
|
||||
# DB 저장 (완료/취소 모두 저장)
|
||||
try:
|
||||
with _JOBS_LOCK:
|
||||
job = _JOBS.get(job_id)
|
||||
if job is None:
|
||||
return
|
||||
db.insert_record(
|
||||
author_id=author_id,
|
||||
filename=filename,
|
||||
language_requested=language_requested,
|
||||
detected_language=job.detected_language,
|
||||
language_probability=job.language_probability,
|
||||
duration_sec=job.duration_sec,
|
||||
status=job.status,
|
||||
text=job.text,
|
||||
segments=job.segments,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
with _JOBS_LOCK:
|
||||
job = _JOBS.get(job_id)
|
||||
if job is None:
|
||||
return
|
||||
job.status = "failed"
|
||||
job.error = str(e)
|
||||
job.updated_at = time.time()
|
||||
try:
|
||||
with _JOBS_LOCK:
|
||||
job = _JOBS.get(job_id)
|
||||
if job is None:
|
||||
return
|
||||
db.insert_record(
|
||||
author_id=getattr(job, "author_id", DEFAULT_AUTHOR_ID),
|
||||
filename=getattr(job, "filename", None),
|
||||
language_requested=getattr(job, "language_requested", None),
|
||||
detected_language=job.detected_language,
|
||||
language_probability=job.language_probability,
|
||||
duration_sec=job.duration_sec,
|
||||
status="failed",
|
||||
text=job.text,
|
||||
segments=job.segments,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
# 업로드 파일은 resources/uploads 아래에 보관 (삭제하지 않음)
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user