mvp
This commit is contained in:
@@ -64,3 +64,6 @@ Thumbs.db
|
|||||||
.pytest_cache/
|
.pytest_cache/
|
||||||
.coverage
|
.coverage
|
||||||
coverage.xml
|
coverage.xml
|
||||||
|
|
||||||
|
# Ollama specific
|
||||||
|
ollama-data/
|
||||||
@@ -8,6 +8,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|||||||
build-essential \
|
build-essential \
|
||||||
git \
|
git \
|
||||||
wget \
|
wget \
|
||||||
|
espeak-ng \
|
||||||
|
espeak-ng-data \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
COPY requirements.txt ./
|
COPY requirements.txt ./
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
fastapi==0.100.0
|
fastapi==0.100.0
|
||||||
uvicorn[standard]==0.22.0
|
uvicorn[standard]==0.22.0
|
||||||
TTS>=0.12.0
|
coqui-tts==0.27.2
|
||||||
soundfile==0.12.1
|
soundfile==0.12.1
|
||||||
numpy==1.26.0
|
numpy==1.26.0
|
||||||
|
|||||||
+35
-8
@@ -1,18 +1,35 @@
|
|||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException, Body, BackgroundTasks
|
||||||
from fastapi import Body
|
|
||||||
from fastapi.responses import FileResponse, JSONResponse
|
from fastapi.responses import FileResponse, JSONResponse
|
||||||
from TTS.api import TTS
|
from TTS.api import TTS
|
||||||
import tempfile
|
import tempfile
|
||||||
import os
|
import os
|
||||||
|
import logging
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
# Load a German-capable model. Model may be downloaded on first run.
|
# Load TTS model at startup with a valid 4-part model name
|
||||||
tts = TTS(model_name="tts_models/de/thorsten_hsmm")
|
model_name = os.environ.get("TTS_MODEL", "tts_models/de/thorsten/tacotron2-DDC")
|
||||||
|
try:
|
||||||
|
tts = TTS(model_name=model_name)
|
||||||
|
logging.info(f"Successfully loaded TTS model: {model_name}")
|
||||||
|
except Exception as exc:
|
||||||
|
logging.exception("Failed to load TTS model '%s'", model_name)
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Failed to load TTS model '{model_name}': {exc}.\n"
|
||||||
|
"Set the environment variable TTS_MODEL to a valid model id in the format 'tts_models/<lang>/<dataset>/<model>' "
|
||||||
|
"or install the model manually."
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
|
||||||
|
def _cleanup_file(path: str) -> None:
|
||||||
|
try:
|
||||||
|
os.remove(path)
|
||||||
|
except Exception:
|
||||||
|
logging.exception("Failed to remove temporary file %s", path)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/speak")
|
@app.post("/speak")
|
||||||
def speak(payload: dict = Body(...)):
|
def speak(payload: dict = Body(...), background_tasks: BackgroundTasks = None):
|
||||||
text = payload.get("text")
|
text = payload.get("text")
|
||||||
language = payload.get("language", "de")
|
language = payload.get("language", "de")
|
||||||
if not text:
|
if not text:
|
||||||
@@ -21,8 +38,18 @@ def speak(payload: dict = Body(...)):
|
|||||||
fd, path = tempfile.mkstemp(suffix=".wav")
|
fd, path = tempfile.mkstemp(suffix=".wav")
|
||||||
os.close(fd)
|
os.close(fd)
|
||||||
try:
|
try:
|
||||||
|
# Check if the model is multilingual before passing language parameter
|
||||||
|
if hasattr(tts, 'is_multi_lingual') and tts.is_multi_lingual:
|
||||||
tts.tts_to_file(text=text, speaker=None, language=language, file_path=path)
|
tts.tts_to_file(text=text, speaker=None, language=language, file_path=path)
|
||||||
|
else:
|
||||||
|
# For non-multilingual models, don't pass the language parameter
|
||||||
|
tts.tts_to_file(text=text, speaker=None, file_path=path)
|
||||||
|
# schedule cleanup after response has been sent
|
||||||
|
if background_tasks is not None:
|
||||||
|
background_tasks.add_task(_cleanup_file, path)
|
||||||
return FileResponse(path, media_type="audio/wav", filename="response.wav")
|
return FileResponse(path, media_type="audio/wav", filename="response.wav")
|
||||||
finally:
|
except Exception:
|
||||||
# FileResponse will stream the file; don't remove immediately. Consumer can manage cleanup.
|
logging.exception("TTS generation failed")
|
||||||
pass
|
# try to remove file immediately on failure
|
||||||
|
_cleanup_file(path)
|
||||||
|
raise HTTPException(status_code=500, detail="TTS generation failed")
|
||||||
|
|||||||
@@ -31,17 +31,25 @@ export default function Home() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async function sendAudio(blob) {
|
async function sendAudio(blob) {
|
||||||
|
try {
|
||||||
const form = new FormData()
|
const form = new FormData()
|
||||||
// Convert webm to wav on the client is complex; many servers accept webm/ogg.
|
// Convert webm to wav on the client is complex; many servers accept webm/ogg.
|
||||||
form.append('file', blob, 'recording.webm')
|
form.append('file', blob, 'recording.webm')
|
||||||
|
|
||||||
|
console.log('Sending request to backend...')
|
||||||
const res = await fetch('http://localhost:8000/chat', { method: 'POST', body: form })
|
const res = await fetch('http://localhost:8000/chat', { method: 'POST', body: form })
|
||||||
|
console.log('Response received:', res.status, res.statusText)
|
||||||
|
|
||||||
if (!res.ok) {
|
if (!res.ok) {
|
||||||
const text = await res.text()
|
const text = await res.text()
|
||||||
alert('Error: ' + res.status + ' ' + text)
|
alert('Error: ' + res.status + ' ' + text)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
console.log('Converting response to blob...')
|
||||||
const audioBlob = await res.blob()
|
const audioBlob = await res.blob()
|
||||||
|
console.log('Audio blob created, size:', audioBlob.size)
|
||||||
|
|
||||||
const url = URL.createObjectURL(audioBlob)
|
const url = URL.createObjectURL(audioBlob)
|
||||||
if (audioRef.current) {
|
if (audioRef.current) {
|
||||||
audioRef.current.src = url
|
audioRef.current.src = url
|
||||||
@@ -49,6 +57,10 @@ export default function Home() {
|
|||||||
setPlaying(true)
|
setPlaying(true)
|
||||||
audioRef.current.onended = () => setPlaying(false)
|
audioRef.current.onended = () => setPlaying(false)
|
||||||
}
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error in sendAudio:', error)
|
||||||
|
alert('Failed to process audio: ' + error.message)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
|||||||
+124
-11
@@ -1,15 +1,34 @@
|
|||||||
from fastapi import FastAPI, File, UploadFile, HTTPException
|
from fastapi import FastAPI, File, UploadFile, HTTPException
|
||||||
from fastapi.responses import StreamingResponse, JSONResponse
|
from fastapi.responses import StreamingResponse, JSONResponse
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
import httpx
|
import httpx
|
||||||
import tempfile
|
import tempfile
|
||||||
import shutil
|
import shutil
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
# Add CORS middleware to allow frontend requests
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"], # In production, replace with specific origins like ["http://localhost:3000"]
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger("middleware")
|
||||||
|
|
||||||
WHISPER_URL = "http://whisper:8001/transcribe"
|
WHISPER_URL = "http://whisper:8001/transcribe"
|
||||||
COQUITTS_URL = "http://coquitts:8002/speak"
|
COQUITTS_URL = "http://coquitts:8002/speak"
|
||||||
OLLAMA_URL = "http://ollama:11434/v1/complete"
|
OLLAMA_URL = "http://ollama:11434/api/generate"
|
||||||
|
LLM_MODEL = os.getenv("LLM_MODEL", "gemma3:270m")
|
||||||
|
logger.info("Using LLM model: %s", LLM_MODEL)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/chat")
|
@app.post("/chat")
|
||||||
@@ -17,40 +36,134 @@ async def chat(file: UploadFile = File(...)):
|
|||||||
if not file.content_type.startswith("audio"):
|
if not file.content_type.startswith("audio"):
|
||||||
raise HTTPException(status_code=400, detail="File must be audio")
|
raise HTTPException(status_code=400, detail="File must be audio")
|
||||||
|
|
||||||
|
tmp_path = None
|
||||||
|
try:
|
||||||
# save file to temp
|
# save file to temp
|
||||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
||||||
contents = await file.read()
|
contents = await file.read()
|
||||||
tmp.write(contents)
|
tmp.write(contents)
|
||||||
tmp.flush()
|
tmp.flush()
|
||||||
tmp_path = tmp.name
|
tmp_path = tmp.name
|
||||||
|
logger.info("Saved uploaded audio to %s", tmp_path)
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
# Send audio to whisper
|
# Send audio to whisper
|
||||||
|
try:
|
||||||
with open(tmp_path, "rb") as f:
|
with open(tmp_path, "rb") as f:
|
||||||
files = {"file": ("audio.wav", f, "audio/wav")}
|
files = {"file": ("audio.wav", f, "audio/wav")}
|
||||||
r = await client.post(WHISPER_URL, files=files, timeout=120.0)
|
r = await client.post(WHISPER_URL, files=files, timeout=120.0)
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
logger.exception("Request to Whisper failed")
|
||||||
|
raise HTTPException(status_code=502, detail=f"Whisper request failed: {e}")
|
||||||
|
|
||||||
if r.status_code != 200:
|
if r.status_code != 200:
|
||||||
|
logger.error("Whisper returned non-200: %s body=%s", r.status_code, r.text)
|
||||||
raise HTTPException(status_code=502, detail=f"Whisper error: {r.status_code} {r.text}")
|
raise HTTPException(status_code=502, detail=f"Whisper error: {r.status_code} {r.text}")
|
||||||
|
|
||||||
|
try:
|
||||||
text = r.json().get("text", "")
|
text = r.json().get("text", "")
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to decode Whisper JSON: %s", r.text)
|
||||||
|
raise HTTPException(status_code=502, detail="Invalid JSON from Whisper")
|
||||||
|
|
||||||
# Send text to ollama for reasoning
|
logger.info("Whisper transcribed text: %s", text)
|
||||||
# We assume Ollama HTTP API accepts JSON {"model":"<model>", "prompt":"..."}
|
|
||||||
ollama_payload = {"model": "llama2", "prompt": text}
|
# Send text to ollama for reasoning. Ollama often streams incremental
|
||||||
ro = await client.post(OLLAMA_URL, json=ollama_payload, timeout=120.0)
|
# JSON objects (one per line) instead of returning a single JSON body.
|
||||||
|
# Use a streaming request and parse JSON lines as they arrive.
|
||||||
|
ollama_payload = {"model": LLM_MODEL, "prompt": text}
|
||||||
|
try:
|
||||||
|
async with client.stream("POST", OLLAMA_URL, json=ollama_payload, timeout=120.0) as ro:
|
||||||
if ro.status_code != 200:
|
if ro.status_code != 200:
|
||||||
raise HTTPException(status_code=502, detail=f"Ollama error: {ro.status_code} {ro.text}")
|
# read body for error message
|
||||||
|
body = await ro.aread()
|
||||||
|
body_text = body.decode(errors="ignore")
|
||||||
|
logger.error("Ollama returned non-200: %s body=%s", ro.status_code, body_text)
|
||||||
|
raise HTTPException(status_code=502, detail=f"Ollama error: {ro.status_code} {body_text}")
|
||||||
|
|
||||||
answer_json = ro.json()
|
parts = []
|
||||||
# Depending on API shape, try to extract text
|
# Ollama can send multiple JSON objects line-by-line; collect 'response' parts
|
||||||
answer_text = answer_json.get("response") or answer_json.get("text") or answer_json.get("output") or str(answer_json)
|
async for raw_line in ro.aiter_lines():
|
||||||
|
if not raw_line:
|
||||||
|
continue
|
||||||
|
line = raw_line.strip()
|
||||||
|
# handle Server-Sent Events style 'data: ...' prefixes
|
||||||
|
if line.startswith("data:"):
|
||||||
|
line = line[len("data:"):].strip()
|
||||||
|
try:
|
||||||
|
j = json.loads(line)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Skipping non-JSON line from Ollama: %s", line)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# prefer 'response', fallback to other candidate fields
|
||||||
|
piece = j.get("response") or j.get("text") or j.get("output")
|
||||||
|
if piece:
|
||||||
|
parts.append(str(piece))
|
||||||
|
|
||||||
|
# stop when stream indicates completion
|
||||||
|
if j.get("done") or j.get("done_reason"):
|
||||||
|
break
|
||||||
|
|
||||||
|
answer_text = "".join(parts).strip()
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
logger.exception("Request to Ollama failed")
|
||||||
|
raise HTTPException(status_code=502, detail=f"Ollama request failed: {e}")
|
||||||
|
|
||||||
|
if not answer_text:
|
||||||
|
# as a last resort, try to fetch full text non-streamed
|
||||||
|
try:
|
||||||
|
fallback = await client.post(OLLAMA_URL, json=ollama_payload, timeout=30.0)
|
||||||
|
if fallback.status_code == 200:
|
||||||
|
try:
|
||||||
|
fallback_json = fallback.json()
|
||||||
|
answer_text = (
|
||||||
|
fallback_json.get("response")
|
||||||
|
or fallback_json.get("text")
|
||||||
|
or fallback_json.get("output")
|
||||||
|
or str(fallback_json)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
answer_text = (await fallback.aread()).decode(errors="ignore")
|
||||||
|
else:
|
||||||
|
logger.debug("Ollama fallback non-200: %s %s", fallback.status_code, await fallback.aread())
|
||||||
|
except Exception:
|
||||||
|
# nothing more to do; keep answer_text empty and let downstream handle it
|
||||||
|
logger.debug("Ollama fallback failed", exc_info=True)
|
||||||
|
|
||||||
|
logger.info("Ollama returned: %s", answer_text)
|
||||||
|
|
||||||
# Send answer to coquitts to generate German audio
|
# Send answer to coquitts to generate German audio
|
||||||
coquitts_payload = {"text": answer_text, "language": "de"}
|
coquitts_payload = {"text": answer_text, "language": "de"}
|
||||||
|
try:
|
||||||
co = await client.post(COQUITTS_URL, json=coquitts_payload, timeout=120.0)
|
co = await client.post(COQUITTS_URL, json=coquitts_payload, timeout=120.0)
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
logger.exception("Request to CoquiTTS failed")
|
||||||
|
raise HTTPException(status_code=502, detail=f"CoquiTTS request failed: {e}")
|
||||||
|
|
||||||
if co.status_code != 200:
|
if co.status_code != 200:
|
||||||
|
logger.error("CoquiTTS returned non-200: %s body=%s", co.status_code, co.text)
|
||||||
raise HTTPException(status_code=502, detail=f"CoquiTTS error: {co.status_code} {co.text}")
|
raise HTTPException(status_code=502, detail=f"CoquiTTS error: {co.status_code} {co.text}")
|
||||||
|
|
||||||
# stream the audio back
|
logger.info("CoquiTTS returned audio, streaming back to client")
|
||||||
return StreamingResponse(co.aiter_bytes(), media_type="audio/wav")
|
|
||||||
|
# Get the audio content as bytes instead of streaming
|
||||||
|
audio_content = await co.aread()
|
||||||
|
|
||||||
|
# Return the audio as a regular response with proper headers
|
||||||
|
return StreamingResponse(
|
||||||
|
iter([audio_content]),
|
||||||
|
media_type="audio/wav",
|
||||||
|
headers={
|
||||||
|
"Content-Length": str(len(audio_content)),
|
||||||
|
"Accept-Ranges": "bytes"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
# cleanup temp file
|
||||||
|
try:
|
||||||
|
if tmp_path and os.path.exists(tmp_path):
|
||||||
|
os.remove(tmp_path)
|
||||||
|
logger.info("Removed temp file %s", tmp_path)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to remove temp file %s: %s", tmp_path, traceback.format_exc())
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
fastapi==0.100.0
|
fastapi==0.100.0
|
||||||
uvicorn[standard]==0.22.0
|
uvicorn[standard]==0.22.0
|
||||||
whisper==1.1.10
|
openai-whisper
|
||||||
pydub==0.25.1
|
pydub==0.25.1
|
||||||
aiofiles==23.1.0
|
aiofiles==23.1.0
|
||||||
python-multipart==0.0.6
|
python-multipart==0.0.6
|
||||||
|
|||||||
+69
-8
@@ -3,30 +3,91 @@ from fastapi.responses import JSONResponse
|
|||||||
import whisper
|
import whisper
|
||||||
import tempfile
|
import tempfile
|
||||||
import shutil
|
import shutil
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from pydub import AudioSegment
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
# Load model at startup
|
||||||
|
try:
|
||||||
model = whisper.load_model("small")
|
model = whisper.load_model("small")
|
||||||
|
except Exception:
|
||||||
|
logging.exception("Failed to load Whisper model")
|
||||||
|
# re-raise so container fails fast if model can't be loaded
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_wav(src_path: str) -> str:
|
||||||
|
"""Convert an audio file (webm/ogg/mp3/...) to a 16 kHz mono WAV file using pydub/ffmpeg.
|
||||||
|
|
||||||
|
Returns path to the new WAV file (caller is responsible for cleanup).
|
||||||
|
"""
|
||||||
|
audio = AudioSegment.from_file(src_path)
|
||||||
|
audio = audio.set_frame_rate(16000).set_channels(1)
|
||||||
|
wav_tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
|
||||||
|
wav_tmp.close()
|
||||||
|
audio.export(wav_tmp.name, format="wav")
|
||||||
|
return wav_tmp.name
|
||||||
|
|
||||||
|
|
||||||
@app.post("/transcribe")
|
@app.post("/transcribe")
|
||||||
async def transcribe(file: UploadFile = File(...)):
|
async def transcribe(file: UploadFile = File(...)):
|
||||||
if not file.content_type.startswith("audio"):
|
if not file.content_type or not file.content_type.startswith("audio"):
|
||||||
raise HTTPException(status_code=400, detail="File must be audio")
|
raise HTTPException(status_code=400, detail="File must be audio")
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
# preserve original extension if possible
|
||||||
|
filename = file.filename or "upload"
|
||||||
|
ext = os.path.splitext(filename)[1] or ""
|
||||||
|
if not ext:
|
||||||
|
# try to infer common extension from content-type
|
||||||
|
if "webm" in file.content_type:
|
||||||
|
ext = ".webm"
|
||||||
|
elif "ogg" in file.content_type or "opus" in file.content_type:
|
||||||
|
ext = ".ogg"
|
||||||
|
elif "mpeg" in file.content_type or "mp3" in file.content_type:
|
||||||
|
ext = ".mp3"
|
||||||
|
else:
|
||||||
|
ext = ".wav"
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
|
||||||
contents = await file.read()
|
contents = await file.read()
|
||||||
tmp.write(contents)
|
tmp.write(contents)
|
||||||
tmp.flush()
|
tmp.flush()
|
||||||
tmp_path = tmp.name
|
tmp_path = tmp.name
|
||||||
|
|
||||||
|
logging.info("Received upload %s (%d bytes, content-type=%s)", filename, os.path.getsize(tmp_path), file.content_type)
|
||||||
|
|
||||||
|
# If the uploaded file is not a WAV, convert it to WAV first to ensure ffmpeg/pydub compatibility.
|
||||||
|
wav_path = tmp_path
|
||||||
|
converted = False
|
||||||
try:
|
try:
|
||||||
result = model.transcribe(tmp_path, language=None)
|
if not tmp_path.lower().endswith('.wav'):
|
||||||
|
try:
|
||||||
|
wav_path = convert_to_wav(tmp_path)
|
||||||
|
converted = True
|
||||||
|
logging.info("Converted to wav: %s (size=%d)", wav_path, os.path.getsize(wav_path))
|
||||||
|
except Exception as e:
|
||||||
|
# conversion failed; return a helpful error including ffmpeg/pydub message
|
||||||
|
logging.exception("Failed to convert uploaded audio to wav")
|
||||||
|
# try to surface the underlying error text
|
||||||
|
raise HTTPException(status_code=400, detail=f"Failed to convert audio: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = model.transcribe(wav_path, language=None)
|
||||||
text = result.get("text", "")
|
text = result.get("text", "")
|
||||||
finally:
|
except RuntimeError as e:
|
||||||
try:
|
# likely ffmpeg failed while loading audio; include error message for debugging
|
||||||
shutil.os.remove(tmp_path)
|
logging.exception("Whisper failed to transcribe audio")
|
||||||
except Exception:
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
pass
|
|
||||||
|
|
||||||
return JSONResponse({"text": text})
|
return JSONResponse({"text": text})
|
||||||
|
finally:
|
||||||
|
# cleanup temp files
|
||||||
|
for path in {tmp_path, wav_path}:
|
||||||
|
try:
|
||||||
|
if path and os.path.exists(path):
|
||||||
|
os.remove(path)
|
||||||
|
except Exception:
|
||||||
|
logging.exception("Failed to remove temp file %s", path)
|
||||||
|
|||||||
Reference in New Issue
Block a user