This commit is contained in:
YannAhlgrim
2025-10-08 15:23:23 +02:00
parent b59f52cf86
commit 5e6eae61cc
8 changed files with 288 additions and 70 deletions
+3
View File
@@ -64,3 +64,6 @@ Thumbs.db
.pytest_cache/ .pytest_cache/
.coverage .coverage
coverage.xml coverage.xml
# Ollama specific
ollama-data/
+2
View File
@@ -8,6 +8,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \ build-essential \
git \ git \
wget \ wget \
espeak-ng \
espeak-ng-data \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
COPY requirements.txt ./ COPY requirements.txt ./
+1 -1
View File
@@ -1,5 +1,5 @@
fastapi==0.100.0 fastapi==0.100.0
uvicorn[standard]==0.22.0 uvicorn[standard]==0.22.0
TTS>=0.12.0 coqui-tts==0.27.2
soundfile==0.12.1 soundfile==0.12.1
numpy==1.26.0 numpy==1.26.0
+35 -8
View File
@@ -1,18 +1,35 @@
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException, Body, BackgroundTasks
from fastapi import Body
from fastapi.responses import FileResponse, JSONResponse from fastapi.responses import FileResponse, JSONResponse
from TTS.api import TTS from TTS.api import TTS
import tempfile import tempfile
import os import os
import logging
app = FastAPI() app = FastAPI()
# Load a German-capable model. Model may be downloaded on first run. # Load TTS model at startup with a valid 4-part model name
tts = TTS(model_name="tts_models/de/thorsten_hsmm") model_name = os.environ.get("TTS_MODEL", "tts_models/de/thorsten/tacotron2-DDC")
try:
tts = TTS(model_name=model_name)
logging.info(f"Successfully loaded TTS model: {model_name}")
except Exception as exc:
logging.exception("Failed to load TTS model '%s'", model_name)
raise RuntimeError(
f"Failed to load TTS model '{model_name}': {exc}.\n"
"Set the environment variable TTS_MODEL to a valid model id in the format 'tts_models/<lang>/<dataset>/<model>' "
"or install the model manually."
) from exc
def _cleanup_file(path: str) -> None:
try:
os.remove(path)
except Exception:
logging.exception("Failed to remove temporary file %s", path)
@app.post("/speak") @app.post("/speak")
def speak(payload: dict = Body(...)): def speak(payload: dict = Body(...), background_tasks: BackgroundTasks = None):
text = payload.get("text") text = payload.get("text")
language = payload.get("language", "de") language = payload.get("language", "de")
if not text: if not text:
@@ -21,8 +38,18 @@ def speak(payload: dict = Body(...)):
fd, path = tempfile.mkstemp(suffix=".wav") fd, path = tempfile.mkstemp(suffix=".wav")
os.close(fd) os.close(fd)
try: try:
# Check if the model is multilingual before passing language parameter
if hasattr(tts, 'is_multi_lingual') and tts.is_multi_lingual:
tts.tts_to_file(text=text, speaker=None, language=language, file_path=path) tts.tts_to_file(text=text, speaker=None, language=language, file_path=path)
else:
# For non-multilingual models, don't pass the language parameter
tts.tts_to_file(text=text, speaker=None, file_path=path)
# schedule cleanup after response has been sent
if background_tasks is not None:
background_tasks.add_task(_cleanup_file, path)
return FileResponse(path, media_type="audio/wav", filename="response.wav") return FileResponse(path, media_type="audio/wav", filename="response.wav")
finally: except Exception:
# FileResponse will stream the file; don't remove immediately. Consumer can manage cleanup. logging.exception("TTS generation failed")
pass # try to remove file immediately on failure
_cleanup_file(path)
raise HTTPException(status_code=500, detail="TTS generation failed")
+12
View File
@@ -31,17 +31,25 @@ export default function Home() {
} }
async function sendAudio(blob) { async function sendAudio(blob) {
try {
const form = new FormData() const form = new FormData()
// Convert webm to wav on the client is complex; many servers accept webm/ogg. // Convert webm to wav on the client is complex; many servers accept webm/ogg.
form.append('file', blob, 'recording.webm') form.append('file', blob, 'recording.webm')
console.log('Sending request to backend...')
const res = await fetch('http://localhost:8000/chat', { method: 'POST', body: form }) const res = await fetch('http://localhost:8000/chat', { method: 'POST', body: form })
console.log('Response received:', res.status, res.statusText)
if (!res.ok) { if (!res.ok) {
const text = await res.text() const text = await res.text()
alert('Error: ' + res.status + ' ' + text) alert('Error: ' + res.status + ' ' + text)
return return
} }
console.log('Converting response to blob...')
const audioBlob = await res.blob() const audioBlob = await res.blob()
console.log('Audio blob created, size:', audioBlob.size)
const url = URL.createObjectURL(audioBlob) const url = URL.createObjectURL(audioBlob)
if (audioRef.current) { if (audioRef.current) {
audioRef.current.src = url audioRef.current.src = url
@@ -49,6 +57,10 @@ export default function Home() {
setPlaying(true) setPlaying(true)
audioRef.current.onended = () => setPlaying(false) audioRef.current.onended = () => setPlaying(false)
} }
} catch (error) {
console.error('Error in sendAudio:', error)
alert('Failed to process audio: ' + error.message)
}
} }
return ( return (
+124 -11
View File
@@ -1,15 +1,34 @@
from fastapi import FastAPI, File, UploadFile, HTTPException from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import httpx import httpx
import tempfile import tempfile
import shutil import shutil
import asyncio import asyncio
import os
import json
import logging
import traceback
app = FastAPI() app = FastAPI()
# Add CORS middleware to allow frontend requests
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # In production, replace with specific origins like ["http://localhost:3000"]
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("middleware")
WHISPER_URL = "http://whisper:8001/transcribe" WHISPER_URL = "http://whisper:8001/transcribe"
COQUITTS_URL = "http://coquitts:8002/speak" COQUITTS_URL = "http://coquitts:8002/speak"
OLLAMA_URL = "http://ollama:11434/v1/complete" OLLAMA_URL = "http://ollama:11434/api/generate"
LLM_MODEL = os.getenv("LLM_MODEL", "gemma3:270m")
logger.info("Using LLM model: %s", LLM_MODEL)
@app.post("/chat") @app.post("/chat")
@@ -17,40 +36,134 @@ async def chat(file: UploadFile = File(...)):
if not file.content_type.startswith("audio"): if not file.content_type.startswith("audio"):
raise HTTPException(status_code=400, detail="File must be audio") raise HTTPException(status_code=400, detail="File must be audio")
tmp_path = None
try:
# save file to temp # save file to temp
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp: with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
contents = await file.read() contents = await file.read()
tmp.write(contents) tmp.write(contents)
tmp.flush() tmp.flush()
tmp_path = tmp.name tmp_path = tmp.name
logger.info("Saved uploaded audio to %s", tmp_path)
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
# Send audio to whisper # Send audio to whisper
try:
with open(tmp_path, "rb") as f: with open(tmp_path, "rb") as f:
files = {"file": ("audio.wav", f, "audio/wav")} files = {"file": ("audio.wav", f, "audio/wav")}
r = await client.post(WHISPER_URL, files=files, timeout=120.0) r = await client.post(WHISPER_URL, files=files, timeout=120.0)
except httpx.RequestError as e:
logger.exception("Request to Whisper failed")
raise HTTPException(status_code=502, detail=f"Whisper request failed: {e}")
if r.status_code != 200: if r.status_code != 200:
logger.error("Whisper returned non-200: %s body=%s", r.status_code, r.text)
raise HTTPException(status_code=502, detail=f"Whisper error: {r.status_code} {r.text}") raise HTTPException(status_code=502, detail=f"Whisper error: {r.status_code} {r.text}")
try:
text = r.json().get("text", "") text = r.json().get("text", "")
except Exception:
logger.exception("Failed to decode Whisper JSON: %s", r.text)
raise HTTPException(status_code=502, detail="Invalid JSON from Whisper")
# Send text to ollama for reasoning logger.info("Whisper transcribed text: %s", text)
# We assume Ollama HTTP API accepts JSON {"model":"<model>", "prompt":"..."}
ollama_payload = {"model": "llama2", "prompt": text} # Send text to ollama for reasoning. Ollama often streams incremental
ro = await client.post(OLLAMA_URL, json=ollama_payload, timeout=120.0) # JSON objects (one per line) instead of returning a single JSON body.
# Use a streaming request and parse JSON lines as they arrive.
ollama_payload = {"model": LLM_MODEL, "prompt": text}
try:
async with client.stream("POST", OLLAMA_URL, json=ollama_payload, timeout=120.0) as ro:
if ro.status_code != 200: if ro.status_code != 200:
raise HTTPException(status_code=502, detail=f"Ollama error: {ro.status_code} {ro.text}") # read body for error message
body = await ro.aread()
body_text = body.decode(errors="ignore")
logger.error("Ollama returned non-200: %s body=%s", ro.status_code, body_text)
raise HTTPException(status_code=502, detail=f"Ollama error: {ro.status_code} {body_text}")
answer_json = ro.json() parts = []
# Depending on API shape, try to extract text # Ollama can send multiple JSON objects line-by-line; collect 'response' parts
answer_text = answer_json.get("response") or answer_json.get("text") or answer_json.get("output") or str(answer_json) async for raw_line in ro.aiter_lines():
if not raw_line:
continue
line = raw_line.strip()
# handle Server-Sent Events style 'data: ...' prefixes
if line.startswith("data:"):
line = line[len("data:"):].strip()
try:
j = json.loads(line)
except Exception:
logger.debug("Skipping non-JSON line from Ollama: %s", line)
continue
# prefer 'response', fallback to other candidate fields
piece = j.get("response") or j.get("text") or j.get("output")
if piece:
parts.append(str(piece))
# stop when stream indicates completion
if j.get("done") or j.get("done_reason"):
break
answer_text = "".join(parts).strip()
except httpx.RequestError as e:
logger.exception("Request to Ollama failed")
raise HTTPException(status_code=502, detail=f"Ollama request failed: {e}")
if not answer_text:
# as a last resort, try to fetch full text non-streamed
try:
fallback = await client.post(OLLAMA_URL, json=ollama_payload, timeout=30.0)
if fallback.status_code == 200:
try:
fallback_json = fallback.json()
answer_text = (
fallback_json.get("response")
or fallback_json.get("text")
or fallback_json.get("output")
or str(fallback_json)
)
except Exception:
answer_text = (await fallback.aread()).decode(errors="ignore")
else:
logger.debug("Ollama fallback non-200: %s %s", fallback.status_code, await fallback.aread())
except Exception:
# nothing more to do; keep answer_text empty and let downstream handle it
logger.debug("Ollama fallback failed", exc_info=True)
logger.info("Ollama returned: %s", answer_text)
# Send answer to coquitts to generate German audio # Send answer to coquitts to generate German audio
coquitts_payload = {"text": answer_text, "language": "de"} coquitts_payload = {"text": answer_text, "language": "de"}
try:
co = await client.post(COQUITTS_URL, json=coquitts_payload, timeout=120.0) co = await client.post(COQUITTS_URL, json=coquitts_payload, timeout=120.0)
except httpx.RequestError as e:
logger.exception("Request to CoquiTTS failed")
raise HTTPException(status_code=502, detail=f"CoquiTTS request failed: {e}")
if co.status_code != 200: if co.status_code != 200:
logger.error("CoquiTTS returned non-200: %s body=%s", co.status_code, co.text)
raise HTTPException(status_code=502, detail=f"CoquiTTS error: {co.status_code} {co.text}") raise HTTPException(status_code=502, detail=f"CoquiTTS error: {co.status_code} {co.text}")
# stream the audio back logger.info("CoquiTTS returned audio, streaming back to client")
return StreamingResponse(co.aiter_bytes(), media_type="audio/wav")
# Get the audio content as bytes instead of streaming
audio_content = await co.aread()
# Return the audio as a regular response with proper headers
return StreamingResponse(
iter([audio_content]),
media_type="audio/wav",
headers={
"Content-Length": str(len(audio_content)),
"Accept-Ranges": "bytes"
}
)
finally:
# cleanup temp file
try:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
logger.info("Removed temp file %s", tmp_path)
except Exception:
logger.warning("Failed to remove temp file %s: %s", tmp_path, traceback.format_exc())
+1 -1
View File
@@ -1,6 +1,6 @@
fastapi==0.100.0 fastapi==0.100.0
uvicorn[standard]==0.22.0 uvicorn[standard]==0.22.0
whisper==1.1.10 openai-whisper
pydub==0.25.1 pydub==0.25.1
aiofiles==23.1.0 aiofiles==23.1.0
python-multipart==0.0.6 python-multipart==0.0.6
+70 -9
View File
@@ -3,30 +3,91 @@ from fastapi.responses import JSONResponse
import whisper import whisper
import tempfile import tempfile
import shutil import shutil
import os
import logging
from pydub import AudioSegment
logging.basicConfig(level=logging.INFO)
app = FastAPI() app = FastAPI()
model = whisper.load_model("small") # Load model at startup
try:
model = whisper.load_model("small")
except Exception:
logging.exception("Failed to load Whisper model")
# re-raise so container fails fast if model can't be loaded
raise
def convert_to_wav(src_path: str) -> str:
"""Convert an audio file (webm/ogg/mp3/...) to a 16 kHz mono WAV file using pydub/ffmpeg.
Returns path to the new WAV file (caller is responsible for cleanup).
"""
audio = AudioSegment.from_file(src_path)
audio = audio.set_frame_rate(16000).set_channels(1)
wav_tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
wav_tmp.close()
audio.export(wav_tmp.name, format="wav")
return wav_tmp.name
@app.post("/transcribe") @app.post("/transcribe")
async def transcribe(file: UploadFile = File(...)): async def transcribe(file: UploadFile = File(...)):
if not file.content_type.startswith("audio"): if not file.content_type or not file.content_type.startswith("audio"):
raise HTTPException(status_code=400, detail="File must be audio") raise HTTPException(status_code=400, detail="File must be audio")
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp: # preserve original extension if possible
filename = file.filename or "upload"
ext = os.path.splitext(filename)[1] or ""
if not ext:
# try to infer common extension from content-type
if "webm" in file.content_type:
ext = ".webm"
elif "ogg" in file.content_type or "opus" in file.content_type:
ext = ".ogg"
elif "mpeg" in file.content_type or "mp3" in file.content_type:
ext = ".mp3"
else:
ext = ".wav"
with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
contents = await file.read() contents = await file.read()
tmp.write(contents) tmp.write(contents)
tmp.flush() tmp.flush()
tmp_path = tmp.name tmp_path = tmp.name
logging.info("Received upload %s (%d bytes, content-type=%s)", filename, os.path.getsize(tmp_path), file.content_type)
# If the uploaded file is not a WAV, convert it to WAV first to ensure ffmpeg/pydub compatibility.
wav_path = tmp_path
converted = False
try: try:
result = model.transcribe(tmp_path, language=None) if not tmp_path.lower().endswith('.wav'):
try:
wav_path = convert_to_wav(tmp_path)
converted = True
logging.info("Converted to wav: %s (size=%d)", wav_path, os.path.getsize(wav_path))
except Exception as e:
# conversion failed; return a helpful error including ffmpeg/pydub message
logging.exception("Failed to convert uploaded audio to wav")
# try to surface the underlying error text
raise HTTPException(status_code=400, detail=f"Failed to convert audio: {e}")
try:
result = model.transcribe(wav_path, language=None)
text = result.get("text", "") text = result.get("text", "")
finally: except RuntimeError as e:
try: # likely ffmpeg failed while loading audio; include error message for debugging
shutil.os.remove(tmp_path) logging.exception("Whisper failed to transcribe audio")
except Exception: raise HTTPException(status_code=500, detail=str(e))
pass
return JSONResponse({"text": text}) return JSONResponse({"text": text})
finally:
# cleanup temp files
for path in {tmp_path, wav_path}:
try:
if path and os.path.exists(path):
os.remove(path)
except Exception:
logging.exception("Failed to remove temp file %s", path)