170 lines
7.4 KiB
Python
170 lines
7.4 KiB
Python
from fastapi import FastAPI, File, UploadFile, HTTPException
|
|
from fastapi.responses import StreamingResponse, JSONResponse
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
import httpx
|
|
import tempfile
|
|
import shutil
|
|
import asyncio
|
|
import os
|
|
import json
|
|
import logging
|
|
import traceback
|
|
|
|
app = FastAPI()

# Allowed CORS origins for the frontend. Defaults to "*" (any origin); set
# CORS_ALLOW_ORIGINS to a comma-separated list such as
# "http://localhost:3000,https://app.example.com" to restrict it in production.
CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]

# Add CORS middleware to allow frontend requests.
# NOTE(review): with allow_origins=["*"] and allow_credentials=True, Starlette
# echoes the caller's Origin for credentialed requests, effectively letting any
# site make credentialed calls -- confirm for the installed version and set
# CORS_ALLOW_ORIGINS outside development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("middleware")

# Upstream service endpoints and model name. Each value can be overridden by an
# environment variable of the same name; the literal default is used when the
# variable is unset.
WHISPER_URL = os.getenv("WHISPER_URL", "http://whisper:8001/transcribe")
COQUITTS_URL = os.getenv("COQUITTS_URL", "http://coquitts:8002/speak")
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://ollama.ahlgrim.net:11434/api/generate")
LLM_MODEL = os.getenv("LLM_MODEL", "gemma3:270m")

logger.info("Using LLM model: %s", LLM_MODEL)
logger.info(
    "Upstream endpoints: whisper=%s coquitts=%s ollama=%s",
    WHISPER_URL,
    COQUITTS_URL,
    OLLAMA_URL,
)
|
|
|
|
|
|
def _extract_ollama_text(obj) -> str:
    """Return the generated text carried by one decoded Ollama JSON value.

    Prefers the ``response`` field and falls back to ``text`` then ``output``.
    Returns ``""`` when ``obj`` is not a dict or none of those fields is truthy.
    """
    if not isinstance(obj, dict):
        return ""
    piece = obj.get("response") or obj.get("text") or obj.get("output")
    return str(piece) if piece else ""


async def _transcribe(client: httpx.AsyncClient, audio_path: str) -> str:
    """Send the audio file at ``audio_path`` to Whisper and return its transcript.

    Returns the ``text`` field of Whisper's JSON reply (``""`` if absent).

    Raises:
        HTTPException: 502 if Whisper is unreachable, answers non-200, or the
            body is not a JSON object.
    """
    try:
        with open(audio_path, "rb") as f:
            # NOTE(review): the upload is always labelled audio/wav, even if the
            # client sent another audio format; confirm Whisper sniffs the
            # real format.
            files = {"file": ("audio.wav", f, "audio/wav")}
            r = await client.post(WHISPER_URL, files=files, timeout=120.0)
    except httpx.RequestError as e:
        logger.exception("Request to Whisper failed")
        raise HTTPException(status_code=502, detail=f"Whisper request failed: {e}") from e

    if r.status_code != 200:
        logger.error("Whisper returned non-200: %s body=%s", r.status_code, r.text)
        raise HTTPException(status_code=502, detail=f"Whisper error: {r.status_code} {r.text}")

    try:
        return r.json().get("text", "")
    except (ValueError, AttributeError):
        # ValueError: body is not JSON; AttributeError: JSON is not an object.
        logger.exception("Failed to decode Whisper JSON: %s", r.text)
        raise HTTPException(status_code=502, detail="Invalid JSON from Whisper")


async def _generate_answer_non_streamed(client: httpx.AsyncClient, payload: dict) -> str:
    """Last-resort retry asking Ollama for one non-streamed JSON reply.

    Best-effort: any failure is logged at debug level and ``""`` is returned,
    so the caller decides what to do with an empty answer.
    """
    try:
        # Without "stream": False Ollama streams NDJSON again, which .json()
        # cannot parse as a single document.
        fallback = await client.post(
            OLLAMA_URL, json={**payload, "stream": False}, timeout=30.0
        )
    except httpx.HTTPError:
        logger.debug("Ollama fallback failed", exc_info=True)
        return ""

    if fallback.status_code != 200:
        logger.debug("Ollama fallback non-200: %s %s", fallback.status_code, fallback.text)
        return ""

    try:
        # Only the text fields are used: a raw body or dict repr must never be
        # handed to the TTS service as if it were the answer.
        return _extract_ollama_text(fallback.json()).strip()
    except ValueError:
        logger.debug("Ollama fallback returned non-JSON body: %s", fallback.text)
        return ""


async def _generate_answer(client: httpx.AsyncClient, prompt: str) -> str:
    """Ask Ollama (model ``LLM_MODEL``) to answer ``prompt`` and return its text.

    Ollama streams incremental JSON objects (one per line) by default, so the
    reply is read as a stream and the text pieces are concatenated. If that
    yields nothing, a single non-streamed retry is attempted. The result may
    still be ``""``.

    Raises:
        HTTPException: 502 if the request fails, Ollama answers non-200, or the
            stream contains an ``error`` object.
    """
    payload = {"model": LLM_MODEL, "prompt": prompt}
    try:
        async with client.stream("POST", OLLAMA_URL, json=payload, timeout=120.0) as ro:
            if ro.status_code != 200:
                # Read the body only to include it in the error message.
                body_text = (await ro.aread()).decode(errors="ignore")
                logger.error("Ollama returned non-200: %s body=%s", ro.status_code, body_text)
                raise HTTPException(status_code=502, detail=f"Ollama error: {ro.status_code} {body_text}")

            parts = []
            async for raw_line in ro.aiter_lines():
                line = raw_line.strip()
                # Tolerate Server-Sent-Events style "data: ..." prefixes.
                if line.startswith("data:"):
                    line = line[len("data:"):].strip()
                if not line:
                    continue

                try:
                    j = json.loads(line)
                except ValueError:
                    logger.debug("Skipping non-JSON line from Ollama: %s", line)
                    continue
                if not isinstance(j, dict):
                    logger.debug("Skipping non-object JSON line from Ollama: %s", line)
                    continue

                if j.get("error"):
                    logger.error("Ollama reported an error in stream: %s", j["error"])
                    raise HTTPException(status_code=502, detail=f"Ollama error: {j['error']}")

                piece = _extract_ollama_text(j)
                if piece:
                    parts.append(piece)

                # Stop once the stream signals completion.
                if j.get("done") or j.get("done_reason"):
                    break

            answer_text = "".join(parts).strip()
    except httpx.RequestError as e:
        logger.exception("Request to Ollama failed")
        raise HTTPException(status_code=502, detail=f"Ollama request failed: {e}") from e

    if not answer_text:
        answer_text = await _generate_answer_non_streamed(client, payload)
    return answer_text


async def _synthesize(client: httpx.AsyncClient, text: str) -> bytes:
    """Ask CoquiTTS to speak ``text`` with ``language="de"`` and return the audio bytes.

    Raises:
        HTTPException: 502 if CoquiTTS is unreachable or answers non-200.
    """
    coquitts_payload = {"text": text, "language": "de"}
    try:
        co = await client.post(COQUITTS_URL, json=coquitts_payload, timeout=120.0)
    except httpx.RequestError as e:
        logger.exception("Request to CoquiTTS failed")
        raise HTTPException(status_code=502, detail=f"CoquiTTS request failed: {e}") from e

    if co.status_code != 200:
        logger.error("CoquiTTS returned non-200: %s body=%s", co.status_code, co.text)
        raise HTTPException(status_code=502, detail=f"CoquiTTS error: {co.status_code} {co.text}")

    logger.info("CoquiTTS returned audio, streaming back to client")
    # Non-streamed request: the body has already been read in full.
    return co.content


@app.post("/chat")
async def chat(file: UploadFile = File(...)):
    """Answer an uploaded spoken prompt with synthesised German speech.

    Pipeline:
        1. Transcribe the uploaded audio with Whisper.
        2. Send the transcript as a prompt to Ollama (``LLM_MODEL``).
        3. Synthesise the answer with CoquiTTS (``language="de"``).

    The synthesised audio is returned as an ``audio/wav`` response.

    Raises:
        HTTPException: 400 if the upload is not declared as audio; 502 if any
            upstream service fails.
    """
    # content_type is None when the part has no Content-Type header; treat that
    # as "not audio" instead of crashing with AttributeError (-> 500).
    if not (file.content_type or "").startswith("audio"):
        raise HTTPException(status_code=400, detail="File must be audio")

    tmp_path = None
    try:
        # Record the path *before* writing so the finally-block removes the
        # delete=False temp file even if reading the upload or writing fails.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            tmp_path = tmp.name
            tmp.write(await file.read())
            tmp.flush()
        logger.info("Saved uploaded audio to %s", tmp_path)

        async with httpx.AsyncClient() as client:
            text = await _transcribe(client, tmp_path)
            logger.info("Whisper transcribed text: %s", text)

            answer_text = await _generate_answer(client, text)
            # An empty answer is deliberately passed on; CoquiTTS decides how to
            # handle it (an error there surfaces as a 502).
            logger.info("Ollama returned: %s", answer_text)

            audio_content = await _synthesize(client, answer_text)

        # The whole clip is already in memory, so send it as a single chunk with
        # an explicit Content-Length.
        # NOTE(review): "Accept-Ranges: bytes" is advertised, but Range requests
        # are not honoured here; the full body is always returned.
        return StreamingResponse(
            iter([audio_content]),
            media_type="audio/wav",
            headers={
                "Content-Length": str(len(audio_content)),
                "Accept-Ranges": "bytes",
            },
        )
    finally:
        # Best-effort cleanup: a failure is logged, never raised.
        try:
            if tmp_path and os.path.exists(tmp_path):
                os.remove(tmp_path)
                logger.info("Removed temp file %s", tmp_path)
        except OSError:
            logger.warning("Failed to remove temp file %s: %s", tmp_path, traceback.format_exc())
|