This commit is contained in:
YannAhlgrim
2025-10-08 15:23:23 +02:00
parent b59f52cf86
commit 5e6eae61cc
8 changed files with 288 additions and 70 deletions
+3
View File
@@ -64,3 +64,6 @@ Thumbs.db
.pytest_cache/
.coverage
coverage.xml
# Ollama specific
ollama-data/
+2
View File
@@ -8,6 +8,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
git \
wget \
espeak-ng \
espeak-ng-data \
&& rm -rf /var/lib/apt/lists/*
COPY requirements.txt ./
+1 -1
View File
@@ -1,5 +1,5 @@
fastapi==0.100.0
uvicorn[standard]==0.22.0
TTS>=0.12.0
coqui-tts==0.27.2
soundfile==0.12.1
numpy==1.26.0
+36 -9
View File
@@ -1,18 +1,35 @@
from fastapi import FastAPI, HTTPException
from fastapi import Body
from fastapi import FastAPI, HTTPException, Body, BackgroundTasks
from fastapi.responses import FileResponse, JSONResponse
from TTS.api import TTS
import tempfile
import os
import logging
app = FastAPI()
# Load a German-capable model. Model may be downloaded on first run.
tts = TTS(model_name="tts_models/de/thorsten_hsmm")
# Load TTS model at startup with a valid 4-part model name
model_name = os.environ.get("TTS_MODEL", "tts_models/de/thorsten/tacotron2-DDC")
try:
tts = TTS(model_name=model_name)
logging.info(f"Successfully loaded TTS model: {model_name}")
except Exception as exc:
logging.exception("Failed to load TTS model '%s'", model_name)
raise RuntimeError(
f"Failed to load TTS model '{model_name}': {exc}.\n"
"Set the environment variable TTS_MODEL to a valid model id in the format 'tts_models/<lang>/<dataset>/<model>' "
"or install the model manually."
) from exc
def _cleanup_file(path: str) -> None:
try:
os.remove(path)
except Exception:
logging.exception("Failed to remove temporary file %s", path)
@app.post("/speak")
def speak(payload: dict = Body(...)):
def speak(payload: dict = Body(...), background_tasks: BackgroundTasks = None):
text = payload.get("text")
language = payload.get("language", "de")
if not text:
@@ -21,8 +38,18 @@ def speak(payload: dict = Body(...)):
fd, path = tempfile.mkstemp(suffix=".wav")
os.close(fd)
try:
tts.tts_to_file(text=text, speaker=None, language=language, file_path=path)
# Check if the model is multilingual before passing language parameter
if hasattr(tts, 'is_multi_lingual') and tts.is_multi_lingual:
tts.tts_to_file(text=text, speaker=None, language=language, file_path=path)
else:
# For non-multilingual models, don't pass the language parameter
tts.tts_to_file(text=text, speaker=None, file_path=path)
# schedule cleanup after response has been sent
if background_tasks is not None:
background_tasks.add_task(_cleanup_file, path)
return FileResponse(path, media_type="audio/wav", filename="response.wav")
finally:
# FileResponse will stream the file; don't remove immediately. Consumer can manage cleanup.
pass
except Exception:
logging.exception("TTS generation failed")
# try to remove file immediately on failure
_cleanup_file(path)
raise HTTPException(status_code=500, detail="TTS generation failed")
+28 -16
View File
@@ -31,23 +31,35 @@ export default function Home() {
}
async function sendAudio(blob) {
const form = new FormData()
// Converting webm to wav on the client is complex; many servers accept webm/ogg.
form.append('file', blob, 'recording.webm')
try {
const form = new FormData()
// Converting webm to wav on the client is complex; many servers accept webm/ogg.
form.append('file', blob, 'recording.webm')
const res = await fetch('http://localhost:8000/chat', { method: 'POST', body: form })
if (!res.ok) {
const text = await res.text()
alert('Error: ' + res.status + ' ' + text)
return
}
const audioBlob = await res.blob()
const url = URL.createObjectURL(audioBlob)
if (audioRef.current) {
audioRef.current.src = url
audioRef.current.play()
setPlaying(true)
audioRef.current.onended = () => setPlaying(false)
console.log('Sending request to backend...')
const res = await fetch('http://localhost:8000/chat', { method: 'POST', body: form })
console.log('Response received:', res.status, res.statusText)
if (!res.ok) {
const text = await res.text()
alert('Error: ' + res.status + ' ' + text)
return
}
console.log('Converting response to blob...')
const audioBlob = await res.blob()
console.log('Audio blob created, size:', audioBlob.size)
const url = URL.createObjectURL(audioBlob)
if (audioRef.current) {
audioRef.current.src = url
audioRef.current.play()
setPlaying(true)
audioRef.current.onended = () => setPlaying(false)
}
} catch (error) {
console.error('Error in sendAudio:', error)
alert('Failed to process audio: ' + error.message)
}
}
+144 -31
View File
@@ -1,15 +1,34 @@
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import httpx
import tempfile
import shutil
import asyncio
import os
import json
import logging
import traceback
app = FastAPI()
# Add CORS middleware to allow frontend requests
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # In production, replace with specific origins like ["http://localhost:3000"]
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("middleware")
WHISPER_URL = "http://whisper:8001/transcribe"
COQUITTS_URL = "http://coquitts:8002/speak"
OLLAMA_URL = "http://ollama:11434/v1/complete"
OLLAMA_URL = "http://ollama:11434/api/generate"
LLM_MODEL = os.getenv("LLM_MODEL", "gemma3:270m")
logger.info("Using LLM model: %s", LLM_MODEL)
@app.post("/chat")
@@ -17,40 +36,134 @@ async def chat(file: UploadFile = File(...)):
if not file.content_type.startswith("audio"):
raise HTTPException(status_code=400, detail="File must be audio")
# save file to temp
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
contents = await file.read()
tmp.write(contents)
tmp.flush()
tmp_path = tmp.name
tmp_path = None
try:
# save file to temp
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
contents = await file.read()
tmp.write(contents)
tmp.flush()
tmp_path = tmp.name
logger.info("Saved uploaded audio to %s", tmp_path)
async with httpx.AsyncClient() as client:
# Send audio to whisper
with open(tmp_path, "rb") as f:
files = {"file": ("audio.wav", f, "audio/wav")}
r = await client.post(WHISPER_URL, files=files, timeout=120.0)
async with httpx.AsyncClient() as client:
# Send audio to whisper
try:
with open(tmp_path, "rb") as f:
files = {"file": ("audio.wav", f, "audio/wav")}
r = await client.post(WHISPER_URL, files=files, timeout=120.0)
except httpx.RequestError as e:
logger.exception("Request to Whisper failed")
raise HTTPException(status_code=502, detail=f"Whisper request failed: {e}")
if r.status_code != 200:
raise HTTPException(status_code=502, detail=f"Whisper error: {r.status_code} {r.text}")
if r.status_code != 200:
logger.error("Whisper returned non-200: %s body=%s", r.status_code, r.text)
raise HTTPException(status_code=502, detail=f"Whisper error: {r.status_code} {r.text}")
text = r.json().get("text", "")
try:
text = r.json().get("text", "")
except Exception:
logger.exception("Failed to decode Whisper JSON: %s", r.text)
raise HTTPException(status_code=502, detail="Invalid JSON from Whisper")
# Send text to ollama for reasoning
# We assume Ollama HTTP API accepts JSON {"model":"<model>", "prompt":"..."}
ollama_payload = {"model": "llama2", "prompt": text}
ro = await client.post(OLLAMA_URL, json=ollama_payload, timeout=120.0)
if ro.status_code != 200:
raise HTTPException(status_code=502, detail=f"Ollama error: {ro.status_code} {ro.text}")
logger.info("Whisper transcribed text: %s", text)
answer_json = ro.json()
# Depending on API shape, try to extract text
answer_text = answer_json.get("response") or answer_json.get("text") or answer_json.get("output") or str(answer_json)
# Send text to ollama for reasoning. Ollama often streams incremental
# JSON objects (one per line) instead of returning a single JSON body.
# Use a streaming request and parse JSON lines as they arrive.
ollama_payload = {"model": LLM_MODEL, "prompt": text}
try:
async with client.stream("POST", OLLAMA_URL, json=ollama_payload, timeout=120.0) as ro:
if ro.status_code != 200:
# read body for error message
body = await ro.aread()
body_text = body.decode(errors="ignore")
logger.error("Ollama returned non-200: %s body=%s", ro.status_code, body_text)
raise HTTPException(status_code=502, detail=f"Ollama error: {ro.status_code} {body_text}")
# Send answer to coquitts to generate German audio
coquitts_payload = {"text": answer_text, "language": "de"}
co = await client.post(COQUITTS_URL, json=coquitts_payload, timeout=120.0)
if co.status_code != 200:
raise HTTPException(status_code=502, detail=f"CoquiTTS error: {co.status_code} {co.text}")
parts = []
# Ollama can send multiple JSON objects line-by-line; collect 'response' parts
async for raw_line in ro.aiter_lines():
if not raw_line:
continue
line = raw_line.strip()
# handle Server-Sent Events style 'data: ...' prefixes
if line.startswith("data:"):
line = line[len("data:"):].strip()
try:
j = json.loads(line)
except Exception:
logger.debug("Skipping non-JSON line from Ollama: %s", line)
continue
# stream the audio back
return StreamingResponse(co.aiter_bytes(), media_type="audio/wav")
# prefer 'response', fallback to other candidate fields
piece = j.get("response") or j.get("text") or j.get("output")
if piece:
parts.append(str(piece))
# stop when stream indicates completion
if j.get("done") or j.get("done_reason"):
break
answer_text = "".join(parts).strip()
except httpx.RequestError as e:
logger.exception("Request to Ollama failed")
raise HTTPException(status_code=502, detail=f"Ollama request failed: {e}")
if not answer_text:
# as a last resort, try to fetch full text non-streamed
try:
fallback = await client.post(OLLAMA_URL, json=ollama_payload, timeout=30.0)
if fallback.status_code == 200:
try:
fallback_json = fallback.json()
answer_text = (
fallback_json.get("response")
or fallback_json.get("text")
or fallback_json.get("output")
or str(fallback_json)
)
except Exception:
answer_text = (await fallback.aread()).decode(errors="ignore")
else:
logger.debug("Ollama fallback non-200: %s %s", fallback.status_code, await fallback.aread())
except Exception:
# nothing more to do; keep answer_text empty and let downstream handle it
logger.debug("Ollama fallback failed", exc_info=True)
logger.info("Ollama returned: %s", answer_text)
# Send answer to coquitts to generate German audio
coquitts_payload = {"text": answer_text, "language": "de"}
try:
co = await client.post(COQUITTS_URL, json=coquitts_payload, timeout=120.0)
except httpx.RequestError as e:
logger.exception("Request to CoquiTTS failed")
raise HTTPException(status_code=502, detail=f"CoquiTTS request failed: {e}")
if co.status_code != 200:
logger.error("CoquiTTS returned non-200: %s body=%s", co.status_code, co.text)
raise HTTPException(status_code=502, detail=f"CoquiTTS error: {co.status_code} {co.text}")
logger.info("CoquiTTS returned audio, streaming back to client")
# Get the audio content as bytes instead of streaming
audio_content = await co.aread()
# Return the audio as a regular response with proper headers
return StreamingResponse(
iter([audio_content]),
media_type="audio/wav",
headers={
"Content-Length": str(len(audio_content)),
"Accept-Ranges": "bytes"
}
)
finally:
# cleanup temp file
try:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
logger.info("Removed temp file %s", tmp_path)
except Exception:
logger.warning("Failed to remove temp file %s: %s", tmp_path, traceback.format_exc())
+1 -1
View File
@@ -1,6 +1,6 @@
fastapi==0.100.0
uvicorn[standard]==0.22.0
whisper==1.1.10
openai-whisper
pydub==0.25.1
aiofiles==23.1.0
python-multipart==0.0.6
+73 -12
View File
@@ -3,30 +3,91 @@ from fastapi.responses import JSONResponse
import whisper
import tempfile
import shutil
import os
import logging
from pydub import AudioSegment
logging.basicConfig(level=logging.INFO)
app = FastAPI()
model = whisper.load_model("small")
# Load model at startup
try:
model = whisper.load_model("small")
except Exception:
logging.exception("Failed to load Whisper model")
# re-raise so container fails fast if model can't be loaded
raise
def convert_to_wav(src_path: str) -> str:
    """Transcode an audio file (webm/ogg/mp3/...) to 16 kHz mono WAV via pydub/ffmpeg.

    Returns the path of the freshly written WAV file; the caller owns the
    file and is responsible for deleting it.
    """
    segment = (
        AudioSegment.from_file(src_path)
        .set_frame_rate(16000)
        .set_channels(1)
    )
    target = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    target.close()
    segment.export(target.name, format="wav")
    return target.name
@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...)):
if not file.content_type.startswith("audio"):
if not file.content_type or not file.content_type.startswith("audio"):
raise HTTPException(status_code=400, detail="File must be audio")
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
# preserve original extension if possible
filename = file.filename or "upload"
ext = os.path.splitext(filename)[1] or ""
if not ext:
# try to infer common extension from content-type
if "webm" in file.content_type:
ext = ".webm"
elif "ogg" in file.content_type or "opus" in file.content_type:
ext = ".ogg"
elif "mpeg" in file.content_type or "mp3" in file.content_type:
ext = ".mp3"
else:
ext = ".wav"
with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
contents = await file.read()
tmp.write(contents)
tmp.flush()
tmp_path = tmp.name
try:
result = model.transcribe(tmp_path, language=None)
text = result.get("text", "")
finally:
try:
shutil.os.remove(tmp_path)
except Exception:
pass
logging.info("Received upload %s (%d bytes, content-type=%s)", filename, os.path.getsize(tmp_path), file.content_type)
return JSONResponse({"text": text})
# If the uploaded file is not a WAV, convert it to WAV first to ensure ffmpeg/pydub compatibility.
wav_path = tmp_path
converted = False
try:
if not tmp_path.lower().endswith('.wav'):
try:
wav_path = convert_to_wav(tmp_path)
converted = True
logging.info("Converted to wav: %s (size=%d)", wav_path, os.path.getsize(wav_path))
except Exception as e:
# conversion failed; return a helpful error including ffmpeg/pydub message
logging.exception("Failed to convert uploaded audio to wav")
# try to surface the underlying error text
raise HTTPException(status_code=400, detail=f"Failed to convert audio: {e}")
try:
result = model.transcribe(wav_path, language=None)
text = result.get("text", "")
except RuntimeError as e:
# likely ffmpeg failed while loading audio; include error message for debugging
logging.exception("Whisper failed to transcribe audio")
raise HTTPException(status_code=500, detail=str(e))
return JSONResponse({"text": text})
finally:
# cleanup temp files
for path in {tmp_path, wav_path}:
try:
if path and os.path.exists(path):
os.remove(path)
except Exception:
logging.exception("Failed to remove temp file %s", path)