feat(web): speaker diarization via pyannote (parity with whisper_stt)
- Add app/diarize.py: local snapshot, A/B labels, disclaimer text - transcribe_file and async jobs support diarize flag; Form diarize on API - UI checkbox (default on); requirements: pyannote.audio, huggingface_hub - README: env vars and model notes Made-with: Cursor
This commit is contained in:
172
app/diarize.py
Normal file
172
app/diarize.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""웹 STT용 화자 분리 — whisper_stt.py와 동일한 pyannote 로컬 스냅샷 + 타임라인 정렬."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .pyannote_auth import load_pyannote_pipeline
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
DEFAULT_DIARIZE_MODEL_DIR = PROJECT_ROOT / "models" / "pyannote-diarization-3.1"
|
||||
|
||||
DIARIZE_DISCLAIMER_KO = (
|
||||
"※ 화자 A, B, C… 는 실제 이름이 아니라, 이 녹음에서 말이 처음 잡힌 순서로 붙인 구분자입니다.\n"
|
||||
"※ 같은 사람이 여러 구간으로 나뉘면 라벨이 바뀌거나 섞일 수 있으니, 중요한 회의는 검수가 필요합니다.\n\n"
|
||||
)
|
||||
|
||||
|
||||
def diarization_annotation(diarization: Any) -> Any:
|
||||
"""pyannote.audio 4.x는 DiarizeOutput; 구간은 .speaker_diarization에 있다."""
|
||||
ann = getattr(diarization, "speaker_diarization", None)
|
||||
return diarization if ann is None else ann
|
||||
|
||||
|
||||
def validate_pyannote_snapshot(model_dir: Path | str) -> None:
|
||||
cfg = Path(model_dir) / "config.yaml"
|
||||
if cfg.is_file():
|
||||
return
|
||||
p = Path(model_dir).resolve()
|
||||
raise ValueError(
|
||||
f"pyannote 모델 폴더가 불완전합니다 (config.yaml 없음): {p}. "
|
||||
"hf download pyannote/speaker-diarization-3.1 --local-dir ./models/pyannote-diarization-3.1"
|
||||
)
|
||||
|
||||
|
||||
def resolve_local_diarize_dir(override: str | None) -> Path:
|
||||
if override:
|
||||
path = Path(override).expanduser().resolve()
|
||||
if path.is_dir():
|
||||
return path
|
||||
raise ValueError(f"화자 분리 모델 폴더가 없습니다: {path}")
|
||||
|
||||
for cand in (os.environ.get("WHISPER_DIARIZE_MODEL_DIR"), os.environ.get("PYANNOTE_MODEL_DIR")):
|
||||
if cand:
|
||||
path = Path(cand).expanduser().resolve()
|
||||
if path.is_dir():
|
||||
return path
|
||||
|
||||
path = DEFAULT_DIARIZE_MODEL_DIR.resolve()
|
||||
if path.is_dir():
|
||||
return path
|
||||
raise ValueError(
|
||||
f"화자 분리 모델 폴더가 없습니다: {path}. "
|
||||
"프로젝트 루트에서: hf download pyannote/speaker-diarization-3.1 "
|
||||
"--local-dir ./models/pyannote-diarization-3.1 (약관 동의·HF 토큰 필요)"
|
||||
)
|
||||
|
||||
|
||||
def speaker_turns(audio_path: str, *, model_dir: str | None = None) -> list[tuple[float, float, str]]:
|
||||
import torch
|
||||
|
||||
resolved = resolve_local_diarize_dir(model_dir)
|
||||
validate_pyannote_snapshot(resolved)
|
||||
pipeline = load_pyannote_pipeline(resolved)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
pipeline.to(device)
|
||||
diarization = pipeline(audio_path)
|
||||
ann = diarization_annotation(diarization)
|
||||
turns: list[tuple[float, float, str]] = []
|
||||
for segment, _, label in ann.itertracks(yield_label=True):
|
||||
turns.append((float(segment.start), float(segment.end), str(label)))
|
||||
turns.sort(key=lambda x: x[0])
|
||||
return turns
|
||||
|
||||
|
||||
def _overlap_sec(a0: float, a1: float, b0: float, b1: float) -> float:
|
||||
return max(0.0, min(a1, b1) - max(a0, b0))
|
||||
|
||||
|
||||
def _assign_speaker(
|
||||
seg_start: float, seg_end: float, turns: list[tuple[float, float, str]]
|
||||
) -> str | None:
|
||||
best: str | None = None
|
||||
best_ov = 0.0
|
||||
for t0, t1, sp in turns:
|
||||
ov = _overlap_sec(seg_start, seg_end, t0, t1)
|
||||
if ov > best_ov:
|
||||
best_ov = ov
|
||||
best = sp
|
||||
if best is None or best_ov < 0.05:
|
||||
return None
|
||||
return best
|
||||
|
||||
|
||||
def speaker_label_order(turns: list[tuple[float, float, str]]) -> dict[str, str]:
|
||||
order: list[str] = []
|
||||
for t0, _, sp in sorted(turns, key=lambda x: x[0]):
|
||||
if sp not in order:
|
||||
order.append(sp)
|
||||
|
||||
def letter(i: int) -> str:
|
||||
if i < 26:
|
||||
return chr(ord("A") + i)
|
||||
return f"SP{i + 1}"
|
||||
|
||||
return {sp: letter(i) for i, sp in enumerate(order)}
|
||||
|
||||
|
||||
def format_diarized_text(
|
||||
whisper_segments: list[dict[str, Any]],
|
||||
turns: list[tuple[float, float, str]],
|
||||
) -> str:
|
||||
labels = speaker_label_order(turns)
|
||||
lines: list[str] = []
|
||||
current_letter: str | None = None
|
||||
current_parts: list[str] = []
|
||||
|
||||
def flush() -> None:
|
||||
nonlocal current_letter, current_parts
|
||||
if current_letter is not None and current_parts:
|
||||
lines.append(f"{current_letter}: {' '.join(current_parts).strip()}")
|
||||
current_letter = None
|
||||
current_parts = []
|
||||
|
||||
for seg in whisper_segments:
|
||||
text = (seg.get("text") or "").strip()
|
||||
if not text:
|
||||
continue
|
||||
start = float(seg["start"])
|
||||
end = float(seg["end"])
|
||||
sp = _assign_speaker(start, end, turns)
|
||||
letter = labels.get(sp, "?") if sp is not None else "?"
|
||||
|
||||
if letter == current_letter:
|
||||
current_parts.append(text)
|
||||
else:
|
||||
flush()
|
||||
current_letter = letter
|
||||
current_parts = [text]
|
||||
|
||||
flush()
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def segments_with_speakers(
|
||||
whisper_segments: list[dict[str, Any]],
|
||||
turns: list[tuple[float, float, str]],
|
||||
) -> list[dict[str, Any]]:
|
||||
labels = speaker_label_order(turns)
|
||||
out: list[dict[str, Any]] = []
|
||||
for seg in whisper_segments:
|
||||
text = (seg.get("text") or "").strip()
|
||||
if not text:
|
||||
continue
|
||||
sp = _assign_speaker(float(seg["start"]), float(seg["end"]), turns)
|
||||
letter = labels.get(sp, "?") if sp is not None else "?"
|
||||
out.append({**seg, "text": text, "speaker": letter})
|
||||
return out
|
||||
|
||||
|
||||
def build_diarized_output(
|
||||
whisper_segments: list[dict[str, Any]],
|
||||
audio_path: str,
|
||||
*,
|
||||
model_dir: str | None = None,
|
||||
with_disclaimer: bool = True,
|
||||
) -> tuple[str, list[dict[str, Any]]]:
|
||||
turns = speaker_turns(audio_path, model_dir=model_dir)
|
||||
body = format_diarized_text(whisper_segments, turns)
|
||||
text = (DIARIZE_DISCLAIMER_KO + body) if with_disclaimer else body
|
||||
segs = segments_with_speakers(whisper_segments, turns)
|
||||
return text, segs
|
||||
Reference in New Issue
Block a user