whisper_stt: support pyannote 4.x DiarizeOutput for itertracks
In pyannote.audio 4.x the pipeline returns a DiarizeOutput whose speaker_diarization attribute holds the Annotation; fall back to treating the result as a legacy Annotation when that attribute is absent. Made-with: Cursor
This commit is contained in:
@@ -20,6 +20,12 @@ if _STT_ROOT not in sys.path:
|
||||
DEFAULT_DIARIZE_MODEL_DIR = "./models/pyannote-diarization-3.1"
|
||||
|
||||
|
||||
def _diarization_annotation(diarization: Any) -> Any:
|
||||
"""pyannote.audio 4.x는 DiarizeOutput을 반환하고, 구간은 .speaker_diarization(Annotation)에 있다."""
|
||||
ann = getattr(diarization, "speaker_diarization", None)
|
||||
return diarization if ann is None else ann
|
||||
|
||||
|
||||
def _validate_pyannote_snapshot(model_dir: str) -> None:
|
||||
"""README만 있거나 중간에 끊긴 다운로드면 config.yaml 이 없다."""
|
||||
cfg = os.path.join(model_dir, "config.yaml")
|
||||
@@ -242,8 +248,9 @@ def _run_diarization(audio_path: str, *, diarize_model_dir: str | None) -> list[
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
pipeline.to(device)
|
||||
diarization = pipeline(audio_path)
|
||||
ann = _diarization_annotation(diarization)
|
||||
turns: list[tuple[float, float, str]] = []
|
||||
for segment, _, label in diarization.itertracks(yield_label=True):
|
||||
for segment, _, label in ann.itertracks(yield_label=True):
|
||||
turns.append((float(segment.start), float(segment.end), str(label)))
|
||||
turns.sort(key=lambda x: x[0])
|
||||
print(f"[4/4] 화자 분리 완료 ({time.perf_counter() - t0:.1f}초, 구간 {len(turns)}개)", flush=True)
|
||||
|
||||
Reference in New Issue
Block a user