diff --git a/README.md b/README.md index 3a9e24d..2fb335a 100644 --- a/README.md +++ b/README.md @@ -86,13 +86,32 @@ pip install -r requirements-whisper-stt.txt ```bash conda activate stt # 또는 사용 중인 env (예: ncue) -pip uninstall -y torch torchvision torchaudio -pip uninstall -y torch torchvision torchaudio # Skipping만 나올 때까지 반복 +pip uninstall -y torch torchvision torchaudio functorch +pip uninstall -y torch torchvision torchaudio functorch # Skipping만 나올 때까지 반복 pip cache purge pip install torch torchaudio --index-url https://download.pytorch.org/whl/cpu pip install -r requirements-whisper-stt.txt ``` +**같은 오류(`ATen.h` 없음 등)가 `torch` 재설치 시에도 반복되면** +`pip uninstall`만으로는 깨진 `site-packages/torch` 폴더가 남는 경우가 있습니다. 아래로 **잔여 디렉터리를 직접 삭제**한 뒤 다시 설치하세요. (`python3.11`은 `python -c "import sys; print(sys.version_info[:2])"`로 맞춤.) + +```bash +conda activate ncue # 문제 나는 env +pip uninstall -y torch torchvision torchaudio functorch 2>/dev/null || true +rm -rf "$CONDA_PREFIX/lib/python3.11/site-packages/torch" \ + "$CONDA_PREFIX/lib/python3.11/site-packages/torch-"*.dist-info \ + "$CONDA_PREFIX/lib/python3.11/site-packages/torchaudio" \ + "$CONDA_PREFIX/lib/python3.11/site-packages/torchaudio-"*.dist-info \ + "$CONDA_PREFIX/lib/python3.11/site-packages/torchgen" \ + "$CONDA_PREFIX/lib/python3.11/site-packages/functorch" +pip cache purge +pip install torch torchaudio --index-url https://download.pytorch.org/whl/cpu +pip install -r requirements-whisper-stt.txt +``` + +그래도 실패하면 **새 conda 환경**(`conda create -n stt-whisper python=3.11 -y`)을 만들고, 위 README의 **conda로 PyTorch 먼저** 절차만 그 env에서 진행하는 것이 가장 확실합니다. + 애초에 꼬이지 않게 하려면 **PyTorch를 conda로 먼저** 깐 뒤 위 requirements만 pip로 설치하는 것을 권장합니다. ```bash @@ -102,6 +121,8 @@ pip install -r requirements-whisper-stt.txt ``` - **Hugging Face `hf` CLI**: `pip install huggingface_hub` 후 `hf auth login`, `hf download …` (화자 구분용 pyannote 모델 등). + - $ hf auth login + - $ hf download pyannote/speaker-diarization-3.1 --local-dir ./models/pyannote-diarization-3.1 - **화자 구분(기본 켜짐)**: `./models/pyannote-diarization-3.1` 에 pyannote 스냅샷이 있어야 합니다. 없으면 스크립트가 `hf download` 안내 후 종료합니다. 모델 받기: [pyannote/speaker-diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) 약관 동의 후 `hf auth login`, `hf download … --local-dir ./models/pyannote-diarization-3.1`. 다른 경로는 `--diarize-model-dir` 또는 `WHISPER_DIARIZE_MODEL_DIR` 로 지정. - **화자 구분 끄기**: `python whisper_stt.py 입력.m4a 출력.txt --no-diarize` (Whisper 통문만 저장) @@ -131,12 +152,15 @@ uvicorn app.main:app --reload --host 127.0.0.1 --port 8025 브라우저에서 `http://127.0.0.1:8025` 접속. +업로드 전사가 끝나면 **`app/diarize.py`** 가 `whisper_stt.py`와 같은 방식으로 pyannote 화자 구분을 시도합니다. 저장소 루트의 **`models/pyannote-diarization-3.1`** (`config.yaml` 포함)이 있어야 하며, `requirements.txt`에 `pyannote.audio`가 포함되어 있습니다. 스냅샷이 없거나 오류면 전사만 반환하고, 응답에 `speaker_diarization: false` 와 `diarize_skip_reason` 이 붙을 수 있습니다. + --- ## 옵션·환경 변수 - **모델**: 기본 `small` (정확도/속도 균형). `APP_WHISPER_MODEL=base|small|medium|large-v3` 등으로 변경 가능. - **디바이스**: 기본 CPU. Apple Silicon에서 Metal은 `faster-whisper` 단독으로는 제한이 있어 CPU 기본값을 권장. +- **웹 화자 구분**: `APP_DIARIZE=1`(기본) — `0`/`false`/`off` 이면 pyannote 단계 생략. `APP_PYANNOTE_MODEL_DIR` 로 스냅샷 경로 지정(없으면 프로젝트 `models/pyannote-diarization-3.1`). - **기타**: `APP_WHISPER_DEVICE`, `APP_WHISPER_COMPUTE_TYPE`, 업로드 크기 등은 `app/main.py` 및 `.env` 예시를 참고. --- diff --git a/app/diarize.py b/app/diarize.py new file mode 100644 index 0000000..cafbe61 --- /dev/null +++ b/app/diarize.py @@ -0,0 +1,186 @@ +""" +업로드 STT 결과에 pyannote 화자 구분을 합칩니다 (whisper_stt.py 와 동일한 규칙). +환경변수 APP_DIARIZE=0 이면 비활성화. 모델: APP_PYANNOTE_MODEL_DIR 또는 프로젝트 models/pyannote-diarization-3.1 +""" +from __future__ import annotations + +import logging +import os +from pathlib import Path +from typing import Any + +log = logging.getLogger(__name__) + +_APP_DIR = Path(__file__).resolve().parent +_PROJECT_ROOT = _APP_DIR.parent +_DEFAULT_SNAPSHOT = _PROJECT_ROOT / "models" / "pyannote-diarization-3.1" + +_DISCLAIMER = ( + "※ 화자 A, B, C… 는 실제 이름이 아니라, 이 녹음에서 말이 처음 잡힌 순서로 붙인 구분자입니다.\n" + "※ 같은 사람이 여러 구간으로 나뉘면 라벨이 바뀌거나 섞일 수 있으니, 중요한 회의는 검수가 필요합니다.\n\n" +) + + +def _env_disabled() -> bool: + v = os.getenv("APP_DIARIZE", "1").strip().lower() + return v in ("0", "false", "no", "off") + + +def resolve_snapshot_dir() -> Path | None: + raw = os.getenv("APP_PYANNOTE_MODEL_DIR", "").strip() + if raw: + p = Path(raw).expanduser() + if not p.is_absolute(): + p = (_PROJECT_ROOT / p).resolve() + else: + p = _DEFAULT_SNAPSHOT.resolve() + if (p / "config.yaml").is_file(): + return p + return None + + +def _overlap_sec(a0: float, a1: float, b0: float, b1: float) -> float: + return max(0.0, min(a1, b1) - max(a0, b0)) + + +def _assign_speaker( + seg_start: float, seg_end: float, turns: list[tuple[float, float, str]] +) -> str | None: + best: str | None = None + best_ov = 0.0 + for t0, t1, sp in turns: + ov = _overlap_sec(seg_start, seg_end, t0, t1) + if ov > best_ov: + best_ov = ov + best = sp + if best is None or best_ov < 0.05: + return None + return best + + +def _speaker_label_order(turns: list[tuple[float, float, str]]) -> dict[str, str]: + order: list[str] = [] + for t0, _, sp in sorted(turns, key=lambda x: x[0]): + if sp not in order: + order.append(sp) + + def letter(i: int) -> str: + if i < 26: + return chr(ord("A") + i) + return f"SP{i + 1}" + + return {sp: letter(i) for i, sp in enumerate(order)} + + +def _merge_segments( + whisper_segments: list[dict[str, Any]], + turns: list[tuple[float, float, str]], +) -> tuple[str, list[dict[str, Any]]]: + labels = _speaker_label_order(turns) + merged_lines: list[str] = [] + out_segments: list[dict[str, Any]] = [] + + current_letter: str | None = None + current_parts: list[str] = [] + current_start: float | None = None + current_end: float | None = None + + def flush() -> None: + nonlocal current_letter, current_parts, current_start, current_end + if current_letter is not None and current_parts and current_start is not None and current_end is not None: + line = " ".join(current_parts).strip() + merged_lines.append(f"{current_letter}: {line}") + out_segments.append( + { + "start": current_start, + "end": current_end, + "speaker": current_letter, + "text": line, + } + ) + current_letter = None + current_parts = [] + current_start = None + current_end = None + + for seg in whisper_segments: + text = (seg.get("text") or "").strip() + if not text: + continue + start = float(seg["start"]) + end = float(seg["end"]) + sp = _assign_speaker(start, end, turns) + letter = labels.get(sp, "?") if sp is not None else "?" + + if letter == current_letter: + current_parts.append(text) + current_end = end + else: + flush() + current_letter = letter + current_parts = [text] + current_start = start + current_end = end + + flush() + body = "\n".join(merged_lines).strip() + return body, out_segments + + +def _run_pyannote(audio_path: str, model_dir: Path) -> list[tuple[float, float, str]]: + import torch + from pyannote.audio import Pipeline + + pipeline = Pipeline.from_pretrained(str(model_dir)) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + pipeline.to(device) + diarization = pipeline(audio_path) + turns: list[tuple[float, float, str]] = [] + for segment, _, label in diarization.itertracks(yield_label=True): + turns.append((float(segment.start), float(segment.end), str(label))) + turns.sort(key=lambda x: x[0]) + return turns + + +def apply_speaker_diarization(result: dict[str, Any], audio_path: str) -> dict[str, Any]: + """ + transcribe_file 결과에 speaker 필드·A:/B: 본문을 반영. + 실패·비활성 시 원본 유지 및 메타만 추가. + """ + out = dict(result) + out.setdefault("speaker_diarization", False) + out.pop("diarize_skip_reason", None) + + if _env_disabled(): + out["diarize_skip_reason"] = "APP_DIARIZE=0" + return out + + snap = resolve_snapshot_dir() + if snap is None: + out["diarize_skip_reason"] = f"pyannote 스냅샷 없음(config.yaml): {_DEFAULT_SNAPSHOT}" + log.warning("Speaker diarization skipped: %s", out["diarize_skip_reason"]) + return out + + try: + import pyannote.audio # noqa: F401 + except ImportError: + out["diarize_skip_reason"] = "pyannote.audio 미설치" + log.warning("Speaker diarization skipped: pyannote not installed") + return out + + segs = list(out.get("segments") or []) + if not segs: + out["diarize_skip_reason"] = "세그먼트 없음" + return out + + try: + turns = _run_pyannote(audio_path, snap) + body, new_segs = _merge_segments(segs, turns) + out["text"] = _DISCLAIMER + body if body else out.get("text", "") + out["segments"] = new_segs + out["speaker_diarization"] = True + out.pop("diarize_skip_reason", None) + except Exception as e: + out["diarize_skip_reason"] = str(e) + log.exception("Speaker diarization failed") + return out diff --git a/app/main.py b/app/main.py index 9abe7aa..02035bb 100644 --- a/app/main.py +++ b/app/main.py @@ -18,6 +18,7 @@ from fastapi.staticfiles import StaticFiles from pydantic import BaseModel from . import db +from .diarize import apply_speaker_diarization from .stt import transcribe_file, transcribe_iter @@ -74,6 +75,8 @@ class _Job: created_at: float = dataclasses.field(default_factory=time.time) updated_at: float = dataclasses.field(default_factory=time.time) cancel_event: threading.Event = dataclasses.field(default_factory=threading.Event, repr=False) + speaker_diarization: bool = False + diarize_skip_reason: str | None = None _JOBS: dict[str, _Job] = {} @@ -111,6 +114,8 @@ def _job_public(job: _Job) -> dict[str, Any]: "language_probability": job.language_probability, "duration_sec": job.duration_sec, "error": job.error, + "speaker_diarization": job.speaker_diarization, + "diarize_skip_reason": job.diarize_skip_reason, "created_at": job.created_at, "updated_at": job.updated_at, } @@ -204,6 +209,7 @@ async def api_transcribe( vad_filter=bool(vad_filter), beam_size=int(beam_size), ) + result = apply_speaker_diarization(result, saved_path) # 단발성 API도 DB 저장 try: db.insert_record( @@ -428,6 +434,20 @@ def _run_job(job_id: str) -> None: if cancelled or job.cancel_event.is_set(): job.status = "cancelled" else: + merged = apply_speaker_diarization( + { + "text": job.text, + "segments": list(job.segments), + "detected_language": job.detected_language, + "language_probability": job.language_probability, + "duration_sec": job.duration_sec, + }, + tmp_path, + ) + job.text = merged.get("text", job.text) + job.segments = merged.get("segments", job.segments) + job.speaker_diarization = bool(merged.get("speaker_diarization")) + job.diarize_skip_reason = merged.get("diarize_skip_reason") job.status = "completed" job.progress = 1.0 job.updated_at = time.time() diff --git a/app/static/index.html b/app/static/index.html index 7c725da..8615298 100644 --- a/app/static/index.html +++ b/app/static/index.html @@ -314,7 +314,8 @@
models/pyannote-diarization-3.1 필요).