Files
tts/server/tts_service.py
dsyoon 9b3a743c52 Reduce MMS audio distortion
Write MMS wav output as PCM16, simplify filters, and normalize punctuation to avoid garbled speech.
2026-01-30 20:24:44 +09:00

189 lines
5.4 KiB
Python

import os
import re
import subprocess
import tempfile
from pathlib import Path
from typing import Optional, Tuple
import pyttsx3
# Lazily populated (model, tokenizer) pair for MMS TTS; filled by _get_mms().
_MMS_CACHE: Optional[Tuple[object, object]] = None

# Korean pronunciation of each uppercase Latin letter, used to spell out
# abbreviations letter by letter (e.g. "PDF" -> "피 디 에프").
# Every letter needs a non-empty reading: an empty string would silently drop
# that letter from the synthesized speech.
_LETTER_KO = {
    "A": "에이",
    "B": "비",
    "C": "씨",
    "D": "디",
    "E": "이",
    "F": "에프",
    "G": "지",
    "H": "에이치",
    "I": "아이",
    "J": "제이",
    "K": "케이",
    "L": "엘",
    "M": "엠",
    "N": "엔",
    "O": "오",
    "P": "피",
    "Q": "큐",
    "R": "알",
    "S": "에스",
    "T": "티",
    "U": "유",
    "V": "브이",
    "W": "더블유",
    "X": "엑스",
    "Y": "와이",
    "Z": "제트",
}

# English phrases replaced with Korean transliterations before synthesis;
# _preprocess_text matches them case-insensitively.
_PHRASE_MAP = [
    ("Automatic Document Feeder", "오토매틱 도큐먼트 피더"),
    ("Naver Blog", "네이버 블로그"),
    ("Brother Korea", "브라더 코리아"),
]
def _get_mms():
    """Return the MMS ``(model, tokenizer)`` pair, loading it on first use.

    The Hugging Face model id comes from the ``MMS_MODEL`` environment
    variable (default ``facebook/mms-tts-kor``). The loaded pair is stored in
    the module-level ``_MMS_CACHE`` and reused on every later call.

    Raises:
        RuntimeError: if transformers or torch cannot be imported.
    """
    global _MMS_CACHE
    cached = _MMS_CACHE
    if cached is None:
        try:
            from transformers import VitsModel, AutoTokenizer
            import torch  # noqa: F401 -- imported only to verify torch is usable
        except Exception as err:
            raise RuntimeError("MMS TTS 사용을 위해 transformers/torch 설치가 필요합니다.") from err
        repo_id = os.getenv("MMS_MODEL", "facebook/mms-tts-kor")
        tok = AutoTokenizer.from_pretrained(repo_id)
        vits = VitsModel.from_pretrained(repo_id)
        # Inference only: switch off training-mode behaviour.
        vits.eval()
        cached = _MMS_CACHE = (vits, tok)
    return cached
def _text_to_wav_mms(text: str, wav_path: str) -> None:
    """Synthesize *text* with the MMS (VITS) model and save a 16-bit PCM WAV.

    Args:
        text: Text to speak (the caller is expected to have preprocessed it).
        wav_path: Destination WAV file path.

    Raises:
        RuntimeError: if torch or soundfile cannot be imported, or if
            ``_get_mms`` reports that transformers/torch are missing.
    """
    try:
        import torch
    except Exception as exc:
        raise RuntimeError("MMS TTS 사용을 위해 torch/numpy가 정상 설치되어야 합니다.") from exc
    try:
        import soundfile as sf
    except Exception as exc:
        raise RuntimeError("MMS TTS 사용을 위해 soundfile 설치가 필요합니다.") from exc
    model, tokenizer = _get_mms()
    # NOTE(review): some MMS checkpoints expect romanized (uroman) input
    # (see the tokenizer's ``is_uroman`` flag); characters outside the
    # tokenizer vocabulary are dropped. Confirm for the configured model.
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        waveform = model(**inputs).waveform
    # The model outputs float samples nominally in [-1, 1]. Clamp before the
    # PCM16 conversion: float->int conversion does not necessarily clip, so an
    # overshoot could wrap around and be heard as a loud click.
    audio = waveform.squeeze().clamp(-1.0, 1.0).cpu().numpy()
    # VITS configs carry ``sampling_rate``; 16 kHz is the VitsConfig default.
    sample_rate = getattr(model.config, "sampling_rate", 16000)
    # Store as PCM16 (instead of float) to reduce distortion downstream.
    sf.write(wav_path, audio, sample_rate, subtype="PCM_16")
def _select_korean_voice(engine: "pyttsx3.Engine") -> None:
    """Switch *engine* to the first installed voice that looks Korean.

    A voice matches when its languages, id, or name contain the standalone
    token ``ko`` (as in ``ko_KR`` or ``ko-KR``) or the word ``korean``.
    Matching whole tokens avoids false positives such as the Japanese voice
    "Kyoko" or the Konkani code ``kok``, which a plain substring test for
    ``"ko"`` would select. If voices cannot be listed or none match, the
    engine keeps its current voice.
    """
    try:
        voices = engine.getProperty("voices") or []
    except Exception:
        # Voice enumeration is best-effort; keep the default voice.
        return
    for voice in voices:
        lang_values = []
        if getattr(voice, "languages", None):
            lang_values.extend(voice.languages)
        if getattr(voice, "id", None):
            lang_values.append(voice.id)
        if getattr(voice, "name", None):
            lang_values.append(voice.name)
        # Drivers may report languages as bytes; decode them so the language
        # code itself is tokenized rather than its bytes repr.
        joined = " ".join(
            v.decode("utf-8", "ignore") if isinstance(v, bytes) else str(v)
            for v in lang_values
        ).lower()
        tokens = set(re.split(r"[^a-z]+", joined))
        if "ko" in tokens or "korean" in joined:
            try:
                engine.setProperty("voice", voice.id)
                return
            except Exception:
                # This voice could not be applied; try the next candidate.
                continue
def _spell_abbrev(match: re.Match) -> str:
    """Spell the matched abbreviation out letter by letter in Korean.

    Used as a ``re.sub`` replacement callback. Each character of the match is
    looked up in ``_LETTER_KO``; characters without an entry are kept as-is.
    The readings are joined with single spaces.
    """
    abbreviation = match.group(0)
    pieces = []
    for letter in abbreviation:
        pieces.append(_LETTER_KO.get(letter, letter))
    return " ".join(pieces)
def _preprocess_text(text: str) -> str:
    """Normalize *text* so the TTS engines pronounce it more naturally.

    Steps:
      1. Replace known English phrases from ``_PHRASE_MAP`` with Korean
         transliterations (case-insensitive, whole words only).
      2. Spell out standalone 2-6 letter uppercase abbreviations via
         ``_spell_abbrev``.
      3. Replace parentheses with spaces to soften unnatural pauses.

    Word boundaries are evaluated in ASCII mode. With the default Unicode
    mode a Hangul syllable counts as a word character, so an English word
    directly followed by a Korean particle ("PDF를", "Naver Blog에서") would
    never match.
    """
    # Fix pronunciation of English abbreviations / brand names.
    for src, dst in _PHRASE_MAP:
        text = re.sub(
            rf"\b{re.escape(src)}\b",
            dst,
            text,
            flags=re.IGNORECASE | re.ASCII,
        )
    text = re.sub(r"\b[A-Z]{2,6}\b", _spell_abbrev, text, flags=re.ASCII)
    # Reduce choppy speech caused by parentheses.
    text = text.replace("(", " ").replace(")", " ")
    return text
def text_to_mp3(text: str, mp3_path: str) -> None:
    """Synthesize *text* to speech and save it as an MP3 file at *mp3_path*.

    The backend is chosen by the ``TTS_ENGINE`` environment variable:
    ``"mms"`` uses the MMS/VITS model (``_text_to_wav_mms``); any other value
    (default ``"pyttsx3"``) uses the local pyttsx3 engine. Speech is rendered
    to a temporary WAV file, then converted by ffmpeg to 44.1 kHz stereo MP3
    at 192 kbit/s. The temporary WAV is always removed.

    Args:
        text: Text to speak; preprocessed with ``_preprocess_text`` first.
        mp3_path: Destination MP3 path; missing parent directories are created.

    Raises:
        RuntimeError: if the text is empty (or blank after preprocessing),
            ffmpeg is missing or fails, the MP3 was not produced, or a
            file-system error occurs during synthesis/conversion.
    """
    # Reject blank input up front. Whitespace-only text, or text that turns
    # blank after preprocessing (e.g. "()"), would otherwise reach the TTS
    # engine and surface later as a confusing ffmpeg failure.
    text = _preprocess_text(text) if text else ""
    if not text.strip():
        raise RuntimeError("텍스트가 비어 있습니다.")

    mp3_target = Path(mp3_path)
    mp3_target.parent.mkdir(parents=True, exist_ok=True)
    tts_engine = os.getenv("TTS_ENGINE", "pyttsx3").strip().lower()

    wav_fd, wav_path = tempfile.mkstemp(suffix=".wav")
    os.close(wav_fd)
    try:
        if tts_engine == "mms":
            _text_to_wav_mms(text, wav_path)
            # Resample to the output rate *before* the fixed-frequency filters.
            # MMS output is typically 16 kHz, and ffmpeg's biquad filters reject
            # a 12 kHz cutoff above the Nyquist frequency of the stream they
            # run on.
            audio_filter = "aresample=44100,highpass=f=80,lowpass=f=12000"
        else:
            engine = pyttsx3.init()
            # Quality tweaks: slightly faster speech (server voices tend to
            # drag) and full volume. Best-effort, since not every driver
            # supports these properties.
            try:
                engine.setProperty("rate", 210)
                engine.setProperty("volume", 1.0)
            except Exception:
                pass
            _select_korean_voice(engine)
            # Render a WAV with pyttsx3; ffmpeg converts it to MP3 below.
            engine.save_to_file(text, wav_path)
            engine.runAndWait()
            audio_filter = "loudnorm=I=-16:LRA=11:TP=-1.5,atempo=1.15"
        try:
            subprocess.run(
                [
                    "ffmpeg",
                    "-y",
                    "-loglevel",
                    "error",
                    "-i",
                    wav_path,
                    "-ac",
                    "2",
                    "-ar",
                    "44100",
                    "-b:a",
                    "192k",
                    "-af",
                    audio_filter,
                    str(mp3_target),
                ],
                check=True,
                stdout=subprocess.DEVNULL,
                # Keep ffmpeg's error output for the failure message below.
                stderr=subprocess.PIPE,
            )
        except FileNotFoundError as exc:
            # A missing ffmpeg binary is not a permission/path problem with
            # the output file; report it explicitly.
            raise RuntimeError("ffmpeg 실행 파일을 찾을 수 없습니다.") from exc
        if not mp3_target.exists():
            raise RuntimeError("mp3 파일 생성에 실패했습니다.")
    except subprocess.CalledProcessError as exc:
        detail = (exc.stderr or b"").decode("utf-8", "replace").strip()
        message = "ffmpeg 변환에 실패했습니다."
        if detail:
            # Only the tail: the relevant error is usually the last lines.
            message += f" ({detail[-500:]})"
        raise RuntimeError(message) from exc
    except OSError as exc:
        raise RuntimeError("파일 생성 권한 또는 경로 오류입니다.") from exc
    finally:
        try:
            os.remove(wav_path)
        except OSError:
            pass