# stt/app/db.py — PostgreSQL (psycopg 3) data-access layer for STT records.
# (Original listing metadata: 230 lines, 6.8 KiB, Python.)
from __future__ import annotations
import os
import re
from typing import Any, Iterable
import psycopg
from psycopg import sql
from psycopg.rows import dict_row
from psycopg.types.json import Json
_IDENT_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
def _table_name() -> str:
name = os.getenv("TABLE", "ncue_stt").strip()
if not _IDENT_RE.match(name):
raise RuntimeError("TABLE 환경변수가 올바르지 않습니다.")
return name
def _conninfo() -> str:
host = os.getenv("DB_HOST", "").strip()
port = os.getenv("DB_PORT", "5432").strip()
dbname = os.getenv("DB_NAME", "").strip()
user = os.getenv("DB_USER", "").strip()
password = os.getenv("DB_PASSWORD", "").strip()
sslmode = os.getenv("DB_SSLMODE", "").strip() # optional
missing = [k for k, v in (("DB_HOST", host), ("DB_NAME", dbname), ("DB_USER", user), ("DB_PASSWORD", password)) if not v]
if missing:
raise RuntimeError(f"DB 환경변수 누락: {', '.join(missing)}")
parts = [
f"host={host}",
f"port={port}",
f"dbname={dbname}",
f"user={user}",
f"password={password}",
]
if sslmode:
parts.append(f"sslmode={sslmode}")
return " ".join(parts)
def connect() -> psycopg.Connection[Any]:
    """Open a new DB connection (5s timeout) that returns rows as dicts."""
    conninfo = _conninfo()
    return psycopg.connect(conninfo, row_factory=dict_row, connect_timeout=5)
def init_db() -> None:
    """Create the records table and its two indexes if they don't exist.

    Idempotent: all DDL uses IF NOT EXISTS, so it is safe to run at startup.
    """
    table = _table_name()
    ident = sql.Identifier(table)
    statements = [
        sql.SQL(
"""
CREATE TABLE IF NOT EXISTS {table} (
id BIGSERIAL PRIMARY KEY,
author_id TEXT NOT NULL,
filename TEXT,
language_requested TEXT,
detected_language TEXT,
language_probability DOUBLE PRECISION,
duration_sec DOUBLE PRECISION,
status TEXT NOT NULL DEFAULT 'completed',
text TEXT NOT NULL DEFAULT '',
segments JSONB NOT NULL DEFAULT '[]'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
        ).format(table=ident),
        # Lookup by author and newest-first listing are the two query paths.
        sql.SQL("CREATE INDEX IF NOT EXISTS {idx} ON {table}(author_id);").format(
            idx=sql.Identifier(f"{table}_author_id_idx"), table=ident
        ),
        sql.SQL("CREATE INDEX IF NOT EXISTS {idx} ON {table}(created_at DESC);").format(
            idx=sql.Identifier(f"{table}_created_at_idx"), table=ident
        ),
    ]
    with connect() as conn, conn.cursor() as cur:
        for statement in statements:
            cur.execute(statement)
        conn.commit()
def insert_record(
    *,
    author_id: str,
    filename: str | None,
    language_requested: str | None,
    detected_language: str | None,
    language_probability: float | None,
    duration_sec: float | None,
    status: str,
    text: str,
    segments: list[dict[str, Any]],
) -> int:
    """Insert one transcription record and return the generated id.

    None/falsy normalization: text falls back to "" and segments to [],
    matching the table's NOT NULL defaults.
    """
    query = sql.SQL(
"""
INSERT INTO {table}
(author_id, filename, language_requested, detected_language, language_probability, duration_sec, status, text, segments)
VALUES
(%s, %s, %s, %s, %s, %s, %s, %s, %s)
RETURNING id;
"""
    ).format(table=sql.Identifier(_table_name()))
    values = (
        author_id,
        filename,
        language_requested,
        detected_language,
        language_probability,
        duration_sec,
        status,
        text or "",
        Json(segments or []),  # wrap for the JSONB column
    )
    with connect() as conn, conn.cursor() as cur:
        cur.execute(query, values)
        new_row = cur.fetchone()
        conn.commit()
    return int(new_row["id"])
def list_records(*, limit: int = 50, offset: int = 0, author_id: str | None = None, q: str | None = None) -> dict[str, Any]:
    """Return a page of records as {"total": <count>, "items": <rows>}.

    Optional filters: exact author_id match, and a case-insensitive substring
    search (q) over filename and text. limit is clamped to 1..200 and offset
    to >= 0. Items are ordered newest first.
    """
    page_size = max(1, min(int(limit), 200))
    skip = max(0, int(offset))

    filters: list[sql.SQL] = []
    params: list[Any] = []
    if author_id:
        filters.append(sql.SQL("author_id = %s"))
        params.append(author_id)
    if q:
        filters.append(sql.SQL("(filename ILIKE %s OR text ILIKE %s)"))
        pattern = f"%{q}%"
        params.extend([pattern, pattern])
    where_sql = sql.SQL("WHERE ") + sql.SQL(" AND ").join(filters) if filters else sql.SQL("")

    ident = sql.Identifier(_table_name())
    count_q = sql.SQL("SELECT count(*)::bigint AS cnt FROM {table} {where};").format(
        table=ident, where=where_sql
    )
    list_q = sql.SQL(
"""
SELECT id, author_id, filename, language_requested, detected_language, duration_sec, status, created_at, updated_at
FROM {table}
{where}
ORDER BY created_at DESC
LIMIT %s OFFSET %s;
"""
    ).format(table=ident, where=where_sql)

    with connect() as conn, conn.cursor() as cur:
        cur.execute(count_q, params)
        total = int(cur.fetchone()["cnt"])
        cur.execute(list_q, params + [page_size, skip])
        items = cur.fetchall()
    return {"total": total, "items": items}
def get_record(record_id: int) -> dict[str, Any] | None:
    """Fetch a single record by id; None when no such row exists."""
    query = sql.SQL("SELECT * FROM {table} WHERE id = %s;").format(
        table=sql.Identifier(_table_name())
    )
    with connect() as conn, conn.cursor() as cur:
        cur.execute(query, (int(record_id),))
        return cur.fetchone()
def update_record(
    record_id: int,
    *,
    author_id: str | None = None,
    text: str | None = None,
    status: str | None = None,
) -> dict[str, Any] | None:
    """Apply the non-None fields to a record and return the updated row.

    With no fields to change, the current row is returned unchanged (or None
    if the id does not exist). updated_at is bumped on every real update.
    """
    assignments: list[sql.SQL] = []
    params: list[Any] = []
    for column, value in (("author_id", author_id), ("text", text), ("status", status)):
        if value is not None:
            assignments.append(sql.SQL(column + " = %s"))
            params.append(value)

    if not assignments:
        # Nothing to change — just read back the current state.
        return get_record(int(record_id))

    assignments.append(sql.SQL("updated_at = now()"))
    query = sql.SQL("UPDATE {table} SET {sets} WHERE id = %s RETURNING *;").format(
        table=sql.Identifier(_table_name()), sets=sql.SQL(", ").join(assignments)
    )
    params.append(int(record_id))

    with connect() as conn, conn.cursor() as cur:
        cur.execute(query, params)
        updated = cur.fetchone()
        conn.commit()
    return updated
def delete_record(record_id: int) -> bool:
    """Delete a record by id; True when a row was actually removed."""
    query = sql.SQL("DELETE FROM {table} WHERE id = %s;").format(
        table=sql.Identifier(_table_name())
    )
    with connect() as conn, conn.cursor() as cur:
        cur.execute(query, (int(record_id),))
        removed = cur.rowcount > 0
        conn.commit()
    return removed