87 lines
2.1 KiB
Python
87 lines
2.1 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
from dataclasses import dataclass
|
|
from typing import Any, Iterable, Tuple
|
|
|
|
from faster_whisper import WhisperModel
|
|
|
|
|
|
@dataclass(frozen=True)
class SegmentOut:
    """Immutable record describing one transcribed segment.

    ``transcribe_file`` builds these from faster-whisper segments and keeps only
    those whose stripped text is non-empty.
    """

    start: float  # segment start offset, coerced with float() (presumably seconds)
    end: float  # segment end offset, coerced with float() (presumably seconds)
    text: str  # segment text with surrounding whitespace stripped
|
|
|
|
|
|
# Process-wide cache for the lazily constructed Whisper model. It is filled on
# the first call to _get_model() and reused by every later call.
_MODEL: WhisperModel | None = None


def _env_setting(name: str, default: str) -> str:
    """Return the stripped value of environment variable *name*, or *default*.

    ``os.getenv(name, default)`` falls back only when the variable is absent. A
    variable that is present but blank (e.g. ``APP_WHISPER_DEVICE=`` in a .env
    file) would otherwise be forwarded to WhisperModel as ``""``. Blank or
    whitespace-only values are therefore treated the same as "unset".
    """
    value = os.getenv(name)
    if value is None:
        return default
    value = value.strip()
    return value or default


def _get_model() -> WhisperModel:
    """Return the shared WhisperModel, constructing it on first use.

    The environment is consulted only on the first call:

    - ``APP_WHISPER_MODEL`` (default ``"small"``): model name passed to WhisperModel.
    - ``APP_WHISPER_DEVICE`` (default ``"cpu"``): ``device`` argument.
    - ``APP_WHISPER_COMPUTE_TYPE`` (default ``"int8"``): ``compute_type`` argument.

    Blank values fall back to the defaults. Once the model is cached in
    ``_MODEL``, later changes to these variables have no effect.

    NOTE(review): construction is not guarded by a lock, so concurrent first
    calls may each build a model. Add a lock if this runs in a threaded server.
    """
    global _MODEL
    if _MODEL is not None:
        return _MODEL

    model_name = _env_setting("APP_WHISPER_MODEL", "small")
    device = _env_setting("APP_WHISPER_DEVICE", "cpu")
    compute_type = _env_setting("APP_WHISPER_COMPUTE_TYPE", "int8")

    # WhisperModel download/cache handled by faster-whisper internally.
    _MODEL = WhisperModel(model_name, device=device, compute_type=compute_type)
    return _MODEL
|
|
|
|
|
|
def transcribe_iter(
    audio_path: str,
    *,
    language: str | None = None,
    vad_filter: bool = True,
    beam_size: int = 5,
) -> Tuple[Iterable[Any], Any]:
    """Start a transcription of *audio_path* with the shared model.

    All options are forwarded unchanged to ``WhisperModel.transcribe``.

    Returns:
        ``(segments, info)`` as produced by faster-whisper. faster-whisper
        documents ``segments`` as a lazy generator: decoding happens while the
        caller iterates it. ``info`` carries metadata such as ``language``,
        ``language_probability`` and ``duration``.
    """
    whisper = _get_model()
    segments, transcription_info = whisper.transcribe(
        audio_path,
        beam_size=beam_size,
        vad_filter=vad_filter,
        language=language,
    )
    return segments, transcription_info
|
|
|
|
|
|
def _segment_from_raw(raw: Any) -> SegmentOut:
    """Convert one faster-whisper segment into a SegmentOut with stripped text."""
    return SegmentOut(
        start=float(raw.start),
        end=float(raw.end),
        text=(raw.text or "").strip(),
    )


def transcribe_file(
    audio_path: str,
    *,
    language: str | None = None,
    vad_filter: bool = True,
    beam_size: int = 5,
) -> dict[str, Any]:
    """Transcribe *audio_path* to completion and return a plain-dict summary.

    Keyword options are forwarded to ``transcribe_iter``. A segment whose text
    is missing or whitespace-only after stripping is dropped.

    Returns:
        A dict with keys:

        - ``"text"``: kept segment texts joined with newlines.
        - ``"segments"``: one ``{"start", "end", "text"}`` dict per kept segment.
        - ``"detected_language"``, ``"language_probability"``, ``"duration_sec"``:
          read from the transcription info, or ``None`` if the attribute is absent.
    """
    raw_segments, info = transcribe_iter(
        audio_path,
        language=language,
        vad_filter=vad_filter,
        beam_size=beam_size,
    )

    # Iterating the segments is what drives decoding (faster-whisper is lazy);
    # segments with no text left after stripping are discarded.
    kept = [
        seg
        for seg in map(_segment_from_raw, _iter_segments(raw_segments))
        if seg.text
    ]

    return {
        "text": "\n".join(seg.text for seg in kept).strip(),
        "segments": [vars(seg) for seg in kept],
        "detected_language": getattr(info, "language", None),
        "language_probability": getattr(info, "language_probability", None),
        "duration_sec": getattr(info, "duration", None),
    }
|
|
|
|
|
|
def _iter_segments(segments_iter: Iterable[Any]) -> Iterable[Any]:
    """Lazily re-yield every item of *segments_iter*, in order, unchanged."""
    yield from segments_iter
|
|
|