feat(web): speaker diarization via pyannote (parity with whisper_stt)

- Add app/diarize.py: local snapshot, A/B labels, disclaimer text
- transcribe_file and async jobs support diarize flag; Form diarize on API
- UI checkbox (default on); requirements: pyannote.audio, huggingface_hub
- README: env vars and model notes

Made-with: Cursor
This commit is contained in:
dosangyoon
2026-03-23 15:23:49 +09:00
parent 2caa74ac05
commit 26ff9b59c2
6 changed files with 280 additions and 3 deletions

172
app/diarize.py Normal file
View File

@@ -0,0 +1,172 @@
"""웹 STT용 화자 분리 — whisper_stt.py와 동일한 pyannote 로컬 스냅샷 + 타임라인 정렬."""
from __future__ import annotations
import os
from pathlib import Path
from typing import Any
from .pyannote_auth import load_pyannote_pipeline
PROJECT_ROOT = Path(__file__).resolve().parent.parent
DEFAULT_DIARIZE_MODEL_DIR = PROJECT_ROOT / "models" / "pyannote-diarization-3.1"
DIARIZE_DISCLAIMER_KO = (
"※ 화자 A, B, C… 는 실제 이름이 아니라, 이 녹음에서 말이 처음 잡힌 순서로 붙인 구분자입니다.\n"
"※ 같은 사람이 여러 구간으로 나뉘면 라벨이 바뀌거나 섞일 수 있으니, 중요한 회의는 검수가 필요합니다.\n\n"
)
def diarization_annotation(diarization: Any) -> Any:
"""pyannote.audio 4.x는 DiarizeOutput; 구간은 .speaker_diarization에 있다."""
ann = getattr(diarization, "speaker_diarization", None)
return diarization if ann is None else ann
def validate_pyannote_snapshot(model_dir: Path | str) -> None:
cfg = Path(model_dir) / "config.yaml"
if cfg.is_file():
return
p = Path(model_dir).resolve()
raise ValueError(
f"pyannote 모델 폴더가 불완전합니다 (config.yaml 없음): {p}. "
"hf download pyannote/speaker-diarization-3.1 --local-dir ./models/pyannote-diarization-3.1"
)
def resolve_local_diarize_dir(override: str | None) -> Path:
if override:
path = Path(override).expanduser().resolve()
if path.is_dir():
return path
raise ValueError(f"화자 분리 모델 폴더가 없습니다: {path}")
for cand in (os.environ.get("WHISPER_DIARIZE_MODEL_DIR"), os.environ.get("PYANNOTE_MODEL_DIR")):
if cand:
path = Path(cand).expanduser().resolve()
if path.is_dir():
return path
path = DEFAULT_DIARIZE_MODEL_DIR.resolve()
if path.is_dir():
return path
raise ValueError(
f"화자 분리 모델 폴더가 없습니다: {path}. "
"프로젝트 루트에서: hf download pyannote/speaker-diarization-3.1 "
"--local-dir ./models/pyannote-diarization-3.1 (약관 동의·HF 토큰 필요)"
)
def speaker_turns(audio_path: str, *, model_dir: str | None = None) -> list[tuple[float, float, str]]:
import torch
resolved = resolve_local_diarize_dir(model_dir)
validate_pyannote_snapshot(resolved)
pipeline = load_pyannote_pipeline(resolved)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipeline.to(device)
diarization = pipeline(audio_path)
ann = diarization_annotation(diarization)
turns: list[tuple[float, float, str]] = []
for segment, _, label in ann.itertracks(yield_label=True):
turns.append((float(segment.start), float(segment.end), str(label)))
turns.sort(key=lambda x: x[0])
return turns
def _overlap_sec(a0: float, a1: float, b0: float, b1: float) -> float:
return max(0.0, min(a1, b1) - max(a0, b0))
def _assign_speaker(
seg_start: float, seg_end: float, turns: list[tuple[float, float, str]]
) -> str | None:
best: str | None = None
best_ov = 0.0
for t0, t1, sp in turns:
ov = _overlap_sec(seg_start, seg_end, t0, t1)
if ov > best_ov:
best_ov = ov
best = sp
if best is None or best_ov < 0.05:
return None
return best
def speaker_label_order(turns: list[tuple[float, float, str]]) -> dict[str, str]:
order: list[str] = []
for t0, _, sp in sorted(turns, key=lambda x: x[0]):
if sp not in order:
order.append(sp)
def letter(i: int) -> str:
if i < 26:
return chr(ord("A") + i)
return f"SP{i + 1}"
return {sp: letter(i) for i, sp in enumerate(order)}
def format_diarized_text(
whisper_segments: list[dict[str, Any]],
turns: list[tuple[float, float, str]],
) -> str:
labels = speaker_label_order(turns)
lines: list[str] = []
current_letter: str | None = None
current_parts: list[str] = []
def flush() -> None:
nonlocal current_letter, current_parts
if current_letter is not None and current_parts:
lines.append(f"{current_letter}: {' '.join(current_parts).strip()}")
current_letter = None
current_parts = []
for seg in whisper_segments:
text = (seg.get("text") or "").strip()
if not text:
continue
start = float(seg["start"])
end = float(seg["end"])
sp = _assign_speaker(start, end, turns)
letter = labels.get(sp, "?") if sp is not None else "?"
if letter == current_letter:
current_parts.append(text)
else:
flush()
current_letter = letter
current_parts = [text]
flush()
return "\n".join(lines)
def segments_with_speakers(
whisper_segments: list[dict[str, Any]],
turns: list[tuple[float, float, str]],
) -> list[dict[str, Any]]:
labels = speaker_label_order(turns)
out: list[dict[str, Any]] = []
for seg in whisper_segments:
text = (seg.get("text") or "").strip()
if not text:
continue
sp = _assign_speaker(float(seg["start"]), float(seg["end"]), turns)
letter = labels.get(sp, "?") if sp is not None else "?"
out.append({**seg, "text": text, "speaker": letter})
return out
def build_diarized_output(
whisper_segments: list[dict[str, Any]],
audio_path: str,
*,
model_dir: str | None = None,
with_disclaimer: bool = True,
) -> tuple[str, list[dict[str, Any]]]:
turns = speaker_turns(audio_path, model_dir=model_dir)
body = format_diarized_text(whisper_segments, turns)
text = (DIARIZE_DISCLAIMER_KO + body) if with_disclaimer else body
segs = segments_with_speakers(whisper_segments, turns)
return text, segs

View File

@@ -61,6 +61,7 @@ class _Job:
language: str | None
vad_filter: bool
beam_size: int
diarize: bool
author_id: str
language_requested: str | None
status: str = "queued" # queued|running|completed|failed|cancelled
@@ -128,6 +129,7 @@ async def api_create_job(
language: str = Form(default="ko"),
vad_filter: bool = Form(default=True),
beam_size: int = Form(default=5),
diarize: bool = Form(default=True),
author_id: str = Form(default=DEFAULT_AUTHOR_ID),
) -> dict[str, Any]:
_cleanup_jobs()
@@ -146,6 +148,7 @@ async def api_create_job(
language=(lang or None),
vad_filter=bool(vad_filter),
beam_size=int(beam_size),
diarize=bool(diarize),
author_id=(author_id.strip() or DEFAULT_AUTHOR_ID),
language_requested=(language.strip() or None),
status="queued",
@@ -188,6 +191,7 @@ async def api_transcribe(
language: str = Form(default="ko"),
vad_filter: bool = Form(default=True),
beam_size: int = Form(default=5),
diarize: bool = Form(default=True),
author_id: str = Form(default=DEFAULT_AUTHOR_ID),
) -> dict[str, Any]:
_validate_upload(file)
@@ -203,6 +207,7 @@ async def api_transcribe(
language=(lang or None),
vad_filter=bool(vad_filter),
beam_size=int(beam_size),
diarize=bool(diarize),
)
# 단발성 API도 DB 저장
try:
@@ -357,6 +362,7 @@ def _run_job(job_id: str) -> None:
language = job.language
vad_filter = job.vad_filter
beam_size = job.beam_size
do_diarize = job.diarize
author_id = job.author_id
language_requested = job.language_requested
filename = job.filename
@@ -421,6 +427,34 @@ def _run_job(job_id: str) -> None:
job.progress = None
job.updated_at = time.time()
if not cancelled and do_diarize:
with _JOBS_LOCK:
job = _JOBS.get(job_id)
if job is None:
return
if job.cancel_event.is_set():
cancelled = True
else:
segs_snapshot = [dict(s) for s in job.segments]
path_for_diar = job.tmp_path
if not cancelled and segs_snapshot:
from . import diarize as dz
mdir = os.getenv("APP_PYANNOTE_MODEL_DIR") or None
with _JOBS_LOCK:
job = _JOBS.get(job_id)
if job is not None:
job.progress = 0.97
job.updated_at = time.time()
text_d, segs_d = dz.build_diarized_output(segs_snapshot, path_for_diar, model_dir=mdir)
with _JOBS_LOCK:
job = _JOBS.get(job_id)
if job is not None:
job.text = text_d
job.segments = segs_d
job.updated_at = time.time()
with _JOBS_LOCK:
job = _JOBS.get(job_id)
if job is None:

View File

@@ -305,6 +305,11 @@
VAD 필터 (무음 구간 감소)
</label>
<label>
<input id="diarize" type="checkbox" checked />
화자 분리 (pyannote, whisper_stt.py와 동일 방식 — 서버에 로컬 모델·HF 토큰 필요)
</label>
<div class="row" style="margin-top: 12px">
<button class="btn primary" id="go" disabled>전사(STT) 실행</button>
<button class="btn" id="cancel" disabled>취소</button>
@@ -314,7 +319,8 @@
<div class="hint">
- 허용: mp3, m4a, wav, mp4, aac, ogg, flac, webm<br />
- 첫 실행 시 Whisper 모델 다운로드로 시간이 걸릴 수 있습니다.
- 첫 실행 시 Whisper 모델 다운로드로 시간이 걸릴 수 있습니다.<br />
- 화자 분리 켜짐: <span class="mono">./models/pyannote-diarization-3.1</span> 및 gated HF 모델 동의(README 참고).
</div>
<div class="progress">
@@ -420,6 +426,7 @@
const progTextEl = $("progText");
const downloadEl = $("download");
const clearEl = $("clear");
const diarizeEl = $("diarize");
const healthEl = $("health");
const metaEl = $("meta");
const timingEl = $("timing");
@@ -677,6 +684,7 @@
const author = (authorEl?.value || "").trim();
if (author) fd.append("author_id", author);
fd.append("vad_filter", $("vad").checked ? "true" : "false");
fd.append("diarize", !diarizeEl || diarizeEl.checked ? "true" : "false");
fd.append("beam_size", $("beam").value);
uploadController = new AbortController();

View File

@@ -54,6 +54,8 @@ def transcribe_file(
language: str | None = None,
vad_filter: bool = True,
beam_size: int = 5,
diarize: bool = True,
diarize_model_dir: str | None = None,
) -> dict[str, Any]:
segments_iter, info = transcribe_iter(
audio_path,
@@ -70,10 +72,23 @@ def transcribe_file(
segments.append(seg)
texts.append(seg.text)
seg_dicts = [seg.__dict__ for seg in segments]
full_text = "\n".join(texts).strip()
if diarize:
from . import diarize as dz
mdir = diarize_model_dir or os.getenv("APP_PYANNOTE_MODEL_DIR") or None
full_text, seg_dicts = dz.build_diarized_output(
seg_dicts,
audio_path,
model_dir=mdir,
with_disclaimer=True,
)
return {
"text": full_text,
"segments": [seg.__dict__ for seg in segments],
"segments": seg_dicts,
"detected_language": getattr(info, "language", None),
"language_probability": getattr(info, "language_probability", None),
"duration_sec": getattr(info, "duration", None),