diff --git a/.gitignore b/.gitignore index c107cee..57f6125 100644 --- a/.gitignore +++ b/.gitignore @@ -64,3 +64,6 @@ Thumbs.db .pytest_cache/ .coverage coverage.xml + +# Ollama specific +ollama-data/ \ No newline at end of file diff --git a/coquitts/Dockerfile b/coquitts/Dockerfile index 3c0ef51..6cb6072 100644 --- a/coquitts/Dockerfile +++ b/coquitts/Dockerfile @@ -8,6 +8,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ build-essential \ git \ wget \ + espeak-ng \ + espeak-ng-data \ && rm -rf /var/lib/apt/lists/* COPY requirements.txt ./ diff --git a/coquitts/requirements.txt b/coquitts/requirements.txt index a6b080e..dd1fe77 100644 --- a/coquitts/requirements.txt +++ b/coquitts/requirements.txt @@ -1,5 +1,5 @@ fastapi==0.100.0 uvicorn[standard]==0.22.0 -TTS>=0.12.0 +coqui-tts==0.27.2 soundfile==0.12.1 numpy==1.26.0 diff --git a/coquitts/server.py b/coquitts/server.py index 478a5cf..d9f0fa3 100644 --- a/coquitts/server.py +++ b/coquitts/server.py @@ -1,18 +1,35 @@ -from fastapi import FastAPI, HTTPException -from fastapi import Body +from fastapi import FastAPI, HTTPException, Body, BackgroundTasks from fastapi.responses import FileResponse, JSONResponse from TTS.api import TTS import tempfile import os +import logging app = FastAPI() -# Load a German-capable model. Model may be downloaded on first run. -tts = TTS(model_name="tts_models/de/thorsten_hsmm") +# Load TTS model at startup with a valid 4-part model name +model_name = os.environ.get("TTS_MODEL", "tts_models/de/thorsten/tacotron2-DDC") +try: + tts = TTS(model_name=model_name) + logging.info(f"Successfully loaded TTS model: {model_name}") +except Exception as exc: + logging.exception("Failed to load TTS model '%s'", model_name) + raise RuntimeError( + f"Failed to load TTS model '{model_name}': {exc}.\n" + "Set the environment variable TTS_MODEL to a valid model id in the format 'tts_models///' " + "or install the model manually." + ) from exc + + +def _cleanup_file(path: str) -> None: + try: + os.remove(path) + except Exception: + logging.exception("Failed to remove temporary file %s", path) @app.post("/speak") -def speak(payload: dict = Body(...)): +def speak(payload: dict = Body(...), background_tasks: BackgroundTasks = None): text = payload.get("text") language = payload.get("language", "de") if not text: @@ -21,8 +38,18 @@ def speak(payload: dict = Body(...)): fd, path = tempfile.mkstemp(suffix=".wav") os.close(fd) try: - tts.tts_to_file(text=text, speaker=None, language=language, file_path=path) + # Check if the model is multilingual before passing language parameter + if hasattr(tts, 'is_multi_lingual') and tts.is_multi_lingual: + tts.tts_to_file(text=text, speaker=None, language=language, file_path=path) + else: + # For non-multilingual models, don't pass the language parameter + tts.tts_to_file(text=text, speaker=None, file_path=path) + # schedule cleanup after response has been sent + if background_tasks is not None: + background_tasks.add_task(_cleanup_file, path) return FileResponse(path, media_type="audio/wav", filename="response.wav") - finally: - # FileResponse will stream the file; don't remove immediately. Consumer can manage cleanup. - pass + except Exception: + logging.exception("TTS generation failed") + # try to remove file immediately on failure + _cleanup_file(path) + raise HTTPException(status_code=500, detail="TTS generation failed") diff --git a/frontend/pages/index.js b/frontend/pages/index.js index e974590..6d8c681 100644 --- a/frontend/pages/index.js +++ b/frontend/pages/index.js @@ -31,23 +31,35 @@ export default function Home() { } async function sendAudio(blob) { - const form = new FormData() - // Convert webm to wav on the client is complex; many servers accept webm/ogg. - form.append('file', blob, 'recording.webm') + try { + const form = new FormData() + // Convert webm to wav on the client is complex; many servers accept webm/ogg. + form.append('file', blob, 'recording.webm') - const res = await fetch('http://localhost:8000/chat', { method: 'POST', body: form }) - if (!res.ok) { - const text = await res.text() - alert('Error: ' + res.status + ' ' + text) - return - } - const audioBlob = await res.blob() - const url = URL.createObjectURL(audioBlob) - if (audioRef.current) { - audioRef.current.src = url - audioRef.current.play() - setPlaying(true) - audioRef.current.onended = () => setPlaying(false) + console.log('Sending request to backend...') + const res = await fetch('http://localhost:8000/chat', { method: 'POST', body: form }) + console.log('Response received:', res.status, res.statusText) + + if (!res.ok) { + const text = await res.text() + alert('Error: ' + res.status + ' ' + text) + return + } + + console.log('Converting response to blob...') + const audioBlob = await res.blob() + console.log('Audio blob created, size:', audioBlob.size) + + const url = URL.createObjectURL(audioBlob) + if (audioRef.current) { + audioRef.current.src = url + audioRef.current.play() + setPlaying(true) + audioRef.current.onended = () => setPlaying(false) + } + } catch (error) { + console.error('Error in sendAudio:', error) + alert('Failed to process audio: ' + error.message) } } diff --git a/middleware/server.py b/middleware/server.py index 3dc23e2..86ccd43 100644 --- a/middleware/server.py +++ b/middleware/server.py @@ -1,15 +1,34 @@ from fastapi import FastAPI, File, UploadFile, HTTPException from fastapi.responses import StreamingResponse, JSONResponse +from fastapi.middleware.cors import CORSMiddleware import httpx import tempfile import shutil import asyncio +import os +import json +import logging +import traceback app = FastAPI() +# Add CORS middleware to allow frontend requests +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # In production, replace with specific origins like ["http://localhost:3000"] + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("middleware") + WHISPER_URL = "http://whisper:8001/transcribe" COQUITTS_URL = "http://coquitts:8002/speak" -OLLAMA_URL = "http://ollama:11434/v1/complete" +OLLAMA_URL = "http://ollama:11434/api/generate" +LLM_MODEL = os.getenv("LLM_MODEL", "gemma3:270m") +logger.info("Using LLM model: %s", LLM_MODEL) @app.post("/chat") @@ -17,40 +36,134 @@ async def chat(file: UploadFile = File(...)): if not file.content_type.startswith("audio"): raise HTTPException(status_code=400, detail="File must be audio") - # save file to temp - with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp: - contents = await file.read() - tmp.write(contents) - tmp.flush() - tmp_path = tmp.name + tmp_path = None + try: + # save file to temp + with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp: + contents = await file.read() + tmp.write(contents) + tmp.flush() + tmp_path = tmp.name + logger.info("Saved uploaded audio to %s", tmp_path) - async with httpx.AsyncClient() as client: - # Send audio to whisper - with open(tmp_path, "rb") as f: - files = {"file": ("audio.wav", f, "audio/wav")} - r = await client.post(WHISPER_URL, files=files, timeout=120.0) + async with httpx.AsyncClient() as client: + # Send audio to whisper + try: + with open(tmp_path, "rb") as f: + files = {"file": ("audio.wav", f, "audio/wav")} + r = await client.post(WHISPER_URL, files=files, timeout=120.0) + except httpx.RequestError as e: + logger.exception("Request to Whisper failed") + raise HTTPException(status_code=502, detail=f"Whisper request failed: {e}") - if r.status_code != 200: - raise HTTPException(status_code=502, detail=f"Whisper error: {r.status_code} {r.text}") + if r.status_code != 200: + logger.error("Whisper returned non-200: %s body=%s", r.status_code, r.text) + raise HTTPException(status_code=502, detail=f"Whisper error: {r.status_code} {r.text}") - text = r.json().get("text", "") + try: + text = r.json().get("text", "") + except Exception: + logger.exception("Failed to decode Whisper JSON: %s", r.text) + raise HTTPException(status_code=502, detail="Invalid JSON from Whisper") - # Send text to ollama for reasoning - # We assume Ollama HTTP API accepts JSON {"model":"", "prompt":"..."} - ollama_payload = {"model": "llama2", "prompt": text} - ro = await client.post(OLLAMA_URL, json=ollama_payload, timeout=120.0) - if ro.status_code != 200: - raise HTTPException(status_code=502, detail=f"Ollama error: {ro.status_code} {ro.text}") + logger.info("Whisper transcribed text: %s", text) - answer_json = ro.json() - # Depending on API shape, try to extract text - answer_text = answer_json.get("response") or answer_json.get("text") or answer_json.get("output") or str(answer_json) + # Send text to ollama for reasoning. Ollama often streams incremental + # JSON objects (one per line) instead of returning a single JSON body. + # Use a streaming request and parse JSON lines as they arrive. + ollama_payload = {"model": LLM_MODEL, "prompt": text} + try: + async with client.stream("POST", OLLAMA_URL, json=ollama_payload, timeout=120.0) as ro: + if ro.status_code != 200: + # read body for error message + body = await ro.aread() + body_text = body.decode(errors="ignore") + logger.error("Ollama returned non-200: %s body=%s", ro.status_code, body_text) + raise HTTPException(status_code=502, detail=f"Ollama error: {ro.status_code} {body_text}") - # Send answer to coquitts to generate German audio - coquitts_payload = {"text": answer_text, "language": "de"} - co = await client.post(COQUITTS_URL, json=coquitts_payload, timeout=120.0) - if co.status_code != 200: - raise HTTPException(status_code=502, detail=f"CoquiTTS error: {co.status_code} {co.text}") + parts = [] + # Ollama can send multiple JSON objects line-by-line; collect 'response' parts + async for raw_line in ro.aiter_lines(): + if not raw_line: + continue + line = raw_line.strip() + # handle Server-Sent Events style 'data: ...' prefixes + if line.startswith("data:"): + line = line[len("data:"):].strip() + try: + j = json.loads(line) + except Exception: + logger.debug("Skipping non-JSON line from Ollama: %s", line) + continue - # stream the audio back - return StreamingResponse(co.aiter_bytes(), media_type="audio/wav") + # prefer 'response', fallback to other candidate fields + piece = j.get("response") or j.get("text") or j.get("output") + if piece: + parts.append(str(piece)) + + # stop when stream indicates completion + if j.get("done") or j.get("done_reason"): + break + + answer_text = "".join(parts).strip() + except httpx.RequestError as e: + logger.exception("Request to Ollama failed") + raise HTTPException(status_code=502, detail=f"Ollama request failed: {e}") + + if not answer_text: + # as a last resort, try to fetch full text non-streamed + try: + fallback = await client.post(OLLAMA_URL, json=ollama_payload, timeout=30.0) + if fallback.status_code == 200: + try: + fallback_json = fallback.json() + answer_text = ( + fallback_json.get("response") + or fallback_json.get("text") + or fallback_json.get("output") + or str(fallback_json) + ) + except Exception: + answer_text = (await fallback.aread()).decode(errors="ignore") + else: + logger.debug("Ollama fallback non-200: %s %s", fallback.status_code, await fallback.aread()) + except Exception: + # nothing more to do; keep answer_text empty and let downstream handle it + logger.debug("Ollama fallback failed", exc_info=True) + + logger.info("Ollama returned: %s", answer_text) + + # Send answer to coquitts to generate German audio + coquitts_payload = {"text": answer_text, "language": "de"} + try: + co = await client.post(COQUITTS_URL, json=coquitts_payload, timeout=120.0) + except httpx.RequestError as e: + logger.exception("Request to CoquiTTS failed") + raise HTTPException(status_code=502, detail=f"CoquiTTS request failed: {e}") + + if co.status_code != 200: + logger.error("CoquiTTS returned non-200: %s body=%s", co.status_code, co.text) + raise HTTPException(status_code=502, detail=f"CoquiTTS error: {co.status_code} {co.text}") + + logger.info("CoquiTTS returned audio, streaming back to client") + + # Get the audio content as bytes instead of streaming + audio_content = await co.aread() + + # Return the audio as a regular response with proper headers + return StreamingResponse( + iter([audio_content]), + media_type="audio/wav", + headers={ + "Content-Length": str(len(audio_content)), + "Accept-Ranges": "bytes" + } + ) + finally: + # cleanup temp file + try: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + logger.info("Removed temp file %s", tmp_path) + except Exception: + logger.warning("Failed to remove temp file %s: %s", tmp_path, traceback.format_exc()) diff --git a/whisper/requirements.txt b/whisper/requirements.txt index 54b866b..160f002 100644 --- a/whisper/requirements.txt +++ b/whisper/requirements.txt @@ -1,6 +1,6 @@ fastapi==0.100.0 uvicorn[standard]==0.22.0 -whisper==1.1.10 +openai-whisper pydub==0.25.1 aiofiles==23.1.0 python-multipart==0.0.6 diff --git a/whisper/server.py b/whisper/server.py index 5791250..ca1e2f3 100644 --- a/whisper/server.py +++ b/whisper/server.py @@ -3,30 +3,91 @@ from fastapi.responses import JSONResponse import whisper import tempfile import shutil +import os +import logging +from pydub import AudioSegment +logging.basicConfig(level=logging.INFO) app = FastAPI() -model = whisper.load_model("small") +# Load model at startup +try: + model = whisper.load_model("small") +except Exception: + logging.exception("Failed to load Whisper model") + # re-raise so container fails fast if model can't be loaded + raise + + +def convert_to_wav(src_path: str) -> str: + """Convert an audio file (webm/ogg/mp3/...) to a 16 kHz mono WAV file using pydub/ffmpeg. + + Returns path to the new WAV file (caller is responsible for cleanup). + """ + audio = AudioSegment.from_file(src_path) + audio = audio.set_frame_rate(16000).set_channels(1) + wav_tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") + wav_tmp.close() + audio.export(wav_tmp.name, format="wav") + return wav_tmp.name @app.post("/transcribe") async def transcribe(file: UploadFile = File(...)): - if not file.content_type.startswith("audio"): + if not file.content_type or not file.content_type.startswith("audio"): raise HTTPException(status_code=400, detail="File must be audio") - with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp: + # preserve original extension if possible + filename = file.filename or "upload" + ext = os.path.splitext(filename)[1] or "" + if not ext: + # try to infer common extension from content-type + if "webm" in file.content_type: + ext = ".webm" + elif "ogg" in file.content_type or "opus" in file.content_type: + ext = ".ogg" + elif "mpeg" in file.content_type or "mp3" in file.content_type: + ext = ".mp3" + else: + ext = ".wav" + + with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp: contents = await file.read() tmp.write(contents) tmp.flush() tmp_path = tmp.name - try: - result = model.transcribe(tmp_path, language=None) - text = result.get("text", "") - finally: - try: - shutil.os.remove(tmp_path) - except Exception: - pass + logging.info("Received upload %s (%d bytes, content-type=%s)", filename, os.path.getsize(tmp_path), file.content_type) - return JSONResponse({"text": text}) + # If the uploaded file is not a WAV, convert it to WAV first to ensure ffmpeg/pydub compatibility. + wav_path = tmp_path + converted = False + try: + if not tmp_path.lower().endswith('.wav'): + try: + wav_path = convert_to_wav(tmp_path) + converted = True + logging.info("Converted to wav: %s (size=%d)", wav_path, os.path.getsize(wav_path)) + except Exception as e: + # conversion failed; return a helpful error including ffmpeg/pydub message + logging.exception("Failed to convert uploaded audio to wav") + # try to surface the underlying error text + raise HTTPException(status_code=400, detail=f"Failed to convert audio: {e}") + + try: + result = model.transcribe(wav_path, language=None) + text = result.get("text", "") + except RuntimeError as e: + # likely ffmpeg failed while loading audio; include error message for debugging + logging.exception("Whisper failed to transcribe audio") + raise HTTPException(status_code=500, detail=str(e)) + + return JSONResponse({"text": text}) + finally: + # cleanup temp files + for path in {tmp_path, wav_path}: + try: + if path and os.path.exists(path): + os.remove(path) + except Exception: + logging.exception("Failed to remove temp file %s", path)