mvp
This commit is contained in:
@@ -8,6 +8,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
git \
|
||||
wget \
|
||||
espeak-ng \
|
||||
espeak-ng-data \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt ./
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
fastapi==0.100.0
|
||||
uvicorn[standard]==0.22.0
|
||||
TTS>=0.12.0
|
||||
coqui-tts==0.27.2
|
||||
soundfile==0.12.1
|
||||
numpy==1.26.0
|
||||
|
||||
+36
-9
@@ -1,18 +1,35 @@
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi import Body
|
||||
from fastapi import FastAPI, HTTPException, Body, BackgroundTasks
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from TTS.api import TTS
|
||||
import tempfile
|
||||
import os
|
||||
import logging
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Load a German-capable model. Model may be downloaded on first run.
|
||||
tts = TTS(model_name="tts_models/de/thorsten_hsmm")
|
||||
# Load the TTS model once at import time so request handlers can reuse the
# module-level `tts` object. The model id is read from the TTS_MODEL
# environment variable and falls back to the default below; ids use the
# 4-part form 'tts_models/<lang>/<dataset>/<model>'.
model_name = os.environ.get("TTS_MODEL", "tts_models/de/thorsten/tacotron2-DDC")
try:
    # Keep the guarded region to the one call that is expected to fail.
    tts = TTS(model_name=model_name)
except Exception as exc:
    # Record the full traceback, then abort startup with an actionable
    # message; the original exception stays attached as the cause.
    logging.exception("Failed to load TTS model '%s'", model_name)
    raise RuntimeError(
        f"Failed to load TTS model '{model_name}': {exc}.\n"
        "Set the environment variable TTS_MODEL to a valid model id in the format 'tts_models/<lang>/<dataset>/<model>' "
        "or install the model manually."
    ) from exc
else:
    # Lazy %-style arguments, consistent with the failure log above.
    logging.info("Successfully loaded TTS model: %s", model_name)
|
||||
|
||||
|
||||
def _cleanup_file(path: str) -> None:
    """Delete the file at *path* on a best-effort basis.

    Any error raised while deleting is logged with its traceback and then
    swallowed, so this function never raises to its caller.

    Args:
        path: Filesystem path of the temporary file to delete.
    """
    try:
        # os.unlink is the same operation as os.remove.
        os.unlink(path)
    except Exception:
        # Best-effort cleanup: report the failure but do not propagate it.
        logging.exception("Failed to remove temporary file %s", path)
    return None
|
||||
|
||||
|
||||
@app.post("/speak")
|
||||
def speak(payload: dict = Body(...)):
|
||||
def speak(payload: dict = Body(...), background_tasks: BackgroundTasks = None):
|
||||
text = payload.get("text")
|
||||
language = payload.get("language", "de")
|
||||
if not text:
|
||||
@@ -21,8 +38,18 @@ def speak(payload: dict = Body(...)):
|
||||
fd, path = tempfile.mkstemp(suffix=".wav")
|
||||
os.close(fd)
|
||||
try:
|
||||
tts.tts_to_file(text=text, speaker=None, language=language, file_path=path)
|
||||
# Check if the model is multilingual before passing language parameter
|
||||
if hasattr(tts, 'is_multi_lingual') and tts.is_multi_lingual:
|
||||
tts.tts_to_file(text=text, speaker=None, language=language, file_path=path)
|
||||
else:
|
||||
# For non-multilingual models, don't pass the language parameter
|
||||
tts.tts_to_file(text=text, speaker=None, file_path=path)
|
||||
# schedule cleanup after response has been sent
|
||||
if background_tasks is not None:
|
||||
background_tasks.add_task(_cleanup_file, path)
|
||||
return FileResponse(path, media_type="audio/wav", filename="response.wav")
|
||||
finally:
|
||||
# FileResponse will stream the file; don't remove immediately. Consumer can manage cleanup.
|
||||
pass
|
||||
except Exception:
|
||||
logging.exception("TTS generation failed")
|
||||
# try to remove file immediately on failure
|
||||
_cleanup_file(path)
|
||||
raise HTTPException(status_code=500, detail="TTS generation failed")
|
||||
|
||||
Reference in New Issue
Block a user