diff --git a/bridge/aria_bridge.py b/bridge/aria_bridge.py index a064e79..2a40927 100644 --- a/bridge/aria_bridge.py +++ b/bridge/aria_bridge.py @@ -325,8 +325,16 @@ class STTEngine: Erkannter Text oder leerer String. """ if self.model is None: - logger.error("Whisper-Modell nicht initialisiert") - return "" + # Lazy-Load: normalerweise laeuft STT remote auf der Gamebox. + # Erst wenn das Fallback hier zuschlaegt, laden wir lokal. + logger.info("Lokales Whisper-Fallback — Modell wird nachgeladen...") + try: + self.initialize() + except Exception: + logger.exception("Lokales Whisper konnte nicht geladen werden") + return "" + if self.model is None: + return "" try: # Audio als float32 normalisieren @@ -523,6 +531,9 @@ class ARIABridge: # Wird fuer die direkt folgende ARIA-Antwort genutzt und dann zurueckgesetzt. # So kann jedes Geraet seine bevorzugte Stimme bekommen (pro Request). self._next_voice_override: Optional[str] = None + # STT-Requests die aktuell auf Antwort von der whisper-bridge (Gamebox) warten. + # requestId → Future mit dem Text (oder None bei Fehler). + self._pending_stt: dict[str, asyncio.Future] = {} def initialize(self) -> None: """Initialisiert alle Komponenten. @@ -535,8 +546,9 @@ class ARIABridge: logger.info("ARIA Voice Bridge startet...") logger.info("=" * 50) - # STT IMMER laden — verarbeitet Audio von der App (braucht kein Sounddevice) - self.stt_engine.initialize() + # STT wird standardmaessig von der whisper-bridge (Gamebox) erledigt. + # Lokales Whisper ist nur Fallback und wird lazy geladen wenn remote nicht + # antwortet. Das spart RAM auf der VM und Startup-Zeit. # Audio-Hardware pruefen (fuer lokales Mikro/Lautsprecher) self.audio_available = False @@ -1195,11 +1207,16 @@ class ARIABridge: changed = True if "whisperModel" in payload: new_model = payload["whisperModel"] - if new_model and new_model != self.stt_engine.model_size: - logger.info("[rvs] Whisper-Modell Wechsel: %s -> %s (laedt...)", self.stt_engine.model_size, new_model) - loop = asyncio.get_event_loop() - if await loop.run_in_executor(None, self.stt_engine.reload, new_model): - changed = True + allowed = {"tiny", "base", "small", "medium", "large-v3"} + if new_model in allowed and new_model != self.stt_engine.model_size: + # Merken und mitschicken an whisper-bridge (Gamebox). + # Lokales Modell wird NICHT geladen — nur das Fallback braucht's, + # und das passiert erst on-demand wenn Remote nicht antwortet. + logger.info("[rvs] Whisper-Modell → %s (nur Config; Modell laedt Gamebox)", + new_model) + self.stt_engine.model_size = new_model + self.stt_engine.model = None + changed = True # Persistent speichern in Shared Volume if changed: try: @@ -1359,22 +1376,111 @@ class ARIABridge: mime_type, duration_ms, len(audio_b64) // 1365) asyncio.create_task(self._process_app_audio(audio_b64, mime_type)) + elif msg_type == "stt_response": + # Antwort der whisper-bridge auf unseren stt_request + request_id = payload.get("requestId", "") + future = self._pending_stt.get(request_id) + if future is None or future.done(): + return + error = payload.get("error", "") + if error: + logger.warning("[rvs] stt_response Fehler: %s", error) + future.set_result(None) + else: + text = payload.get("text", "") + stt_ms = payload.get("sttMs", 0) + model = payload.get("model", "?") + logger.info("[rvs] Remote-STT OK (%s, %dms): '%s'", model, stt_ms, (text or "")[:80]) + future.set_result(text) + return + else: logger.debug("[rvs] Unbekannter Typ: %s", msg_type) + # STT-Orchestrierung: zuerst Remote (Gamebox), Fallback lokal. + # Timeout grosszuegig gewaehlt, damit auch ein erstmaliger Modell-Load + # auf der Gamebox (bis ~30s bei large-v3) durchgeht. + _STT_REMOTE_TIMEOUT_S = 45.0 + async def _process_app_audio(self, audio_b64: str, mime_type: str) -> None: - """Decodiert App-Audio (Base64 AAC/MP4), konvertiert zu 16kHz PCM, STT, sendet an core.""" + """App-Audio → STT → aria-core. Primaer via whisper-bridge (RVS), Fallback lokal.""" + # Erst Remote versuchen + text = await self._stt_remote(audio_b64, mime_type) + if text is None: + # Remote hat nicht geantwortet → lokales Whisper + logger.warning("[rvs] Remote-STT nicht verfuegbar — Fallback auf lokales Whisper") + text = await self._stt_local(audio_b64, mime_type) + if text is None: + return + + if text.strip(): + logger.info("[rvs] STT Ergebnis: '%s'", text[:80]) + # ERST an aria-core senden (wichtigster Schritt) + await self.send_to_core(text, source="app-voice") + # STT-Text an RVS senden (fuer Anzeige in App + Diagnostic) + # sender="stt" damit Bridge es ignoriert (kein Loop) + try: + await self._send_to_rvs({ + "type": "chat", + "payload": { + "text": text, + "sender": "stt", + }, + "timestamp": int(asyncio.get_event_loop().time() * 1000), + }) + except Exception as e: + logger.warning("[rvs] STT-Text konnte nicht an RVS gesendet werden: %s", e) + else: + logger.info("[rvs] Keine Sprache erkannt — ignoriert") + + async def _stt_remote(self, audio_b64: str, mime_type: str) -> Optional[str]: + """Schickt Audio an die whisper-bridge und wartet auf stt_response. + + Rueckgabe: + str — erkannter Text (kann leer sein) + None — Remote-STT nicht erreichbar oder Fehler/Timeout (→ Fallback) + """ + if self.ws_rvs is None: + return None + + request_id = str(uuid.uuid4()) + loop = asyncio.get_event_loop() + future: asyncio.Future = loop.create_future() + self._pending_stt[request_id] = future + + try: + await self._send_to_rvs({ + "type": "stt_request", + "payload": { + "requestId": request_id, + "audio": audio_b64, + "mimeType": mime_type, + "model": getattr(self.stt_engine, "model_size", "small"), + "language": getattr(self.stt_engine, "language", "de"), + }, + "timestamp": int(loop.time() * 1000), + }) + return await asyncio.wait_for(future, timeout=self._STT_REMOTE_TIMEOUT_S) + except asyncio.TimeoutError: + logger.warning("[rvs] Remote-STT Timeout (%.0fs)", self._STT_REMOTE_TIMEOUT_S) + return None + except Exception as e: + logger.warning("[rvs] Remote-STT Fehler: %s", e) + return None + finally: + self._pending_stt.pop(request_id, None) + + async def _stt_local(self, audio_b64: str, mime_type: str) -> Optional[str]: + """Lokales Whisper-Fallback: FFmpeg → float32 → stt_engine.transcribe.""" loop = asyncio.get_event_loop() tmp_in = None tmp_out = None try: - # Base64 → temp-Datei ext = ".mp4" if "mp4" in mime_type else ".wav" if "wav" in mime_type else ".ogg" tmp_in = tempfile.NamedTemporaryFile(suffix=ext, delete=False) tmp_in.write(base64.b64decode(audio_b64)) tmp_in.close() - # FFmpeg: beliebiges Format → 16kHz mono PCM (raw float32) tmp_out = tempfile.NamedTemporaryFile(suffix=".raw", delete=False) tmp_out.close() @@ -1389,45 +1495,21 @@ class ARIABridge: ) if result.returncode != 0: logger.error("[rvs] FFmpeg Fehler: %s", result.stderr.decode()[:200]) - return + return None - # PCM lesen → numpy float32 audio_data = np.fromfile(tmp_out.name, dtype=np.float32) if len(audio_data) == 0: logger.warning("[rvs] Leere Audio-Daten nach Konvertierung") - return + return None duration_s = len(audio_data) / 16000.0 - logger.info("[rvs] Audio konvertiert: %.1fs, %d samples", duration_s, len(audio_data)) - - # STT - text = await loop.run_in_executor(None, self.stt_engine.transcribe, audio_data) - - if text.strip(): - logger.info("[rvs] STT Ergebnis: '%s'", text[:80]) - # ERST an aria-core senden (wichtigster Schritt) - await self.send_to_core(text, source="app-voice") - # STT-Text an RVS senden (fuer Anzeige in App + Diagnostic) - # sender="stt" damit Bridge es ignoriert (kein Loop) - try: - await self._send_to_rvs({ - "type": "chat", - "payload": { - "text": text, - "sender": "stt", - }, - "timestamp": int(asyncio.get_event_loop().time() * 1000), - }) - except Exception as e: - logger.warning("[rvs] STT-Text konnte nicht an RVS gesendet werden: %s", e) - else: - logger.info("[rvs] Keine Sprache erkannt — ignoriert") - + logger.info("[rvs] Lokal-STT: %.1fs Audio, model=%s", duration_s, self.stt_engine.model_size) + return await loop.run_in_executor(None, self.stt_engine.transcribe, audio_data) except Exception: - logger.exception("[rvs] Audio-Verarbeitung fehlgeschlagen") + logger.exception("[rvs] Lokales STT fehlgeschlagen") + return None finally: - # Temp-Dateien aufraeumen - for f in [tmp_in, tmp_out]: + for f in (tmp_in, tmp_out): if f: try: os.unlink(f.name) diff --git a/rvs/server.js b/rvs/server.js index 42696bc..a72d146 100644 --- a/rvs/server.js +++ b/rvs/server.js @@ -20,6 +20,7 @@ const ALLOWED_TYPES = new Set([ "audio_pcm", "xtts_delete_voice", "voice_preload", "voice_ready", + "stt_request", "stt_response", ]); // Token-Raum: token -> { clients: Set } diff --git a/xtts/docker-compose.yml b/xtts/docker-compose.yml index 11e7676..b5dac87 100644 --- a/xtts/docker-compose.yml +++ b/xtts/docker-compose.yml @@ -58,5 +58,37 @@ services: - RVS_TOKEN=${RVS_TOKEN} restart: unless-stopped + # ─── Whisper STT (GPU) ──────────────────────── + # Faster-Whisper auf der Gamebox statt auf der VM (CPU) — + # deutlich schneller. Verbindet sich selbst per WebSocket an + # den RVS und nimmt dort stt_request Nachrichten der aria-bridge + # entgegen, antwortet mit stt_response. Laedt das Modell beim + # Start vor; auf Config-Broadcasts (Diagnostic → whisperModel) + # wird zur Laufzeit hot-swapped. + whisper-bridge: + build: ./whisper + container_name: aria-whisper-bridge + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] + environment: + - RVS_HOST=${RVS_HOST} + - RVS_PORT=${RVS_PORT:-443} + - RVS_TLS=${RVS_TLS:-true} + - RVS_TLS_FALLBACK=${RVS_TLS_FALLBACK:-true} + - RVS_TOKEN=${RVS_TOKEN} + - WHISPER_MODEL=${WHISPER_MODEL:-small} + - WHISPER_DEVICE=${WHISPER_DEVICE:-cuda} + - WHISPER_COMPUTE_TYPE=${WHISPER_COMPUTE_TYPE:-float16} + - WHISPER_LANGUAGE=${WHISPER_LANGUAGE:-de} + volumes: + - whisper-models:/root/.cache/huggingface # Model-Cache persistieren + restart: unless-stopped + volumes: xtts-models: + whisper-models: diff --git a/xtts/whisper/Dockerfile b/xtts/whisper/Dockerfile new file mode 100644 index 0000000..7a55c56 --- /dev/null +++ b/xtts/whisper/Dockerfile @@ -0,0 +1,14 @@ +FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04 + +RUN apt-get update && apt-get install -y --no-install-recommends \ + python3 python3-pip ffmpeg \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +COPY requirements.txt . +RUN pip3 install --no-cache-dir -r requirements.txt + +COPY bridge.py . + +CMD ["python3", "bridge.py"] diff --git a/xtts/whisper/bridge.py b/xtts/whisper/bridge.py new file mode 100644 index 0000000..aa23f6c --- /dev/null +++ b/xtts/whisper/bridge.py @@ -0,0 +1,247 @@ +#!/usr/bin/env python3 +""" +ARIA Whisper Bridge — laeuft auf der Gamebox (RTX 3060). + +Empfaengt stt_request via RVS → FFmpeg-Konvertierung → faster-whisper auf GPU +→ sendet stt_response zurueck an die aria-bridge. + +Env: + RVS_HOST, RVS_PORT, RVS_TLS, RVS_TLS_FALLBACK, RVS_TOKEN + WHISPER_MODEL Default: small + WHISPER_DEVICE Default: cuda + WHISPER_COMPUTE_TYPE Default: float16 + WHISPER_LANGUAGE Default: de +""" +import asyncio +import base64 +import json +import logging +import os +import subprocess +import sys +import tempfile +import time +from typing import Optional + +import numpy as np +import websockets +from faster_whisper import WhisperModel + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%H:%M:%S", +) +logger = logging.getLogger("whisper-bridge") + +RVS_HOST = os.getenv("RVS_HOST", "").strip() +RVS_PORT = int(os.getenv("RVS_PORT", "443")) +RVS_TLS = os.getenv("RVS_TLS", "true").lower() == "true" +RVS_TLS_FALLBACK = os.getenv("RVS_TLS_FALLBACK", "true").lower() == "true" +RVS_TOKEN = os.getenv("RVS_TOKEN", "").strip() + +WHISPER_MODEL = os.getenv("WHISPER_MODEL", "small") +WHISPER_DEVICE = os.getenv("WHISPER_DEVICE", "cuda") +WHISPER_COMPUTE_TYPE = os.getenv("WHISPER_COMPUTE_TYPE", "float16") +WHISPER_LANGUAGE = os.getenv("WHISPER_LANGUAGE", "de") + +ALLOWED_MODELS = {"tiny", "base", "small", "medium", "large-v3"} + + +class WhisperRunner: + """Haelt das Whisper-Modell. Hot-Swap bei Konfig-Wechsel via ensure_loaded().""" + + def __init__(self) -> None: + self.model_size: str = WHISPER_MODEL + self.model: Optional[WhisperModel] = None + self._lock = asyncio.Lock() + + def _load_blocking(self, size: str) -> None: + logger.info( + "Lade Whisper '%s' (device=%s, compute=%s)", + size, WHISPER_DEVICE, WHISPER_COMPUTE_TYPE, + ) + t0 = time.time() + self.model = WhisperModel( + size, device=WHISPER_DEVICE, compute_type=WHISPER_COMPUTE_TYPE, + ) + self.model_size = size + logger.info("Whisper '%s' geladen in %.1fs", size, time.time() - t0) + + async def ensure_loaded(self, desired_size: str) -> None: + if desired_size not in ALLOWED_MODELS: + logger.warning("Ungueltiges Whisper-Modell '%s' — nutze %s", desired_size, WHISPER_MODEL) + desired_size = WHISPER_MODEL + async with self._lock: + if self.model is not None and self.model_size == desired_size: + return + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, self._load_blocking, desired_size) + + async def transcribe(self, audio: np.ndarray, language: str) -> tuple[str, float]: + if self.model is None: + return "", 0.0 + + def _run(): + segments, info = self.model.transcribe( + audio, language=language, beam_size=5, vad_filter=True, + ) + text = " ".join(seg.text.strip() for seg in segments) + return text, info.duration + + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, _run) + + +def ffmpeg_to_float32(audio_b64: str, mime_type: str) -> np.ndarray: + """Dekodiert beliebiges Audio-Format → 16kHz mono float32 PCM.""" + if "mp4" in mime_type or "m4a" in mime_type or "aac" in mime_type: + ext = ".mp4" + elif "wav" in mime_type: + ext = ".wav" + elif "ogg" in mime_type or "opus" in mime_type: + ext = ".ogg" + else: + ext = ".bin" + + in_fh = tempfile.NamedTemporaryFile(suffix=ext, delete=False) + try: + in_fh.write(base64.b64decode(audio_b64)) + in_fh.close() + out_path = in_fh.name + ".raw" + cmd = ["ffmpeg", "-y", "-i", in_fh.name, "-ar", "16000", "-ac", "1", "-f", "f32le", out_path] + result = subprocess.run(cmd, capture_output=True, timeout=30) + if result.returncode != 0: + logger.error("FFmpeg Fehler: %s", result.stderr.decode(errors="replace")[:300]) + return np.zeros(0, dtype=np.float32) + try: + return np.fromfile(out_path, dtype=np.float32) + finally: + try: + os.unlink(out_path) + except OSError: + pass + finally: + try: + os.unlink(in_fh.name) + except OSError: + pass + + +async def _send(ws, mtype: str, payload: dict) -> None: + try: + await ws.send(json.dumps({ + "type": mtype, + "payload": payload, + "timestamp": int(time.time() * 1000), + })) + except Exception as e: + logger.warning("Send fehlgeschlagen (%s): %s", mtype, e) + + +async def handle_stt_request(ws, payload: dict, runner: WhisperRunner) -> None: + request_id = payload.get("requestId", "") + audio_b64 = payload.get("audio", "") + mime_type = payload.get("mimeType", "audio/mp4") + model = payload.get("model") or WHISPER_MODEL + language = payload.get("language") or WHISPER_LANGUAGE + + if not audio_b64: + await _send(ws, "stt_response", {"requestId": request_id, "error": "no-audio"}) + return + + try: + t_load = time.time() + await runner.ensure_loaded(model) + load_ms = int((time.time() - t_load) * 1000) + + audio = ffmpeg_to_float32(audio_b64, mime_type) + if audio.size == 0: + await _send(ws, "stt_response", {"requestId": request_id, "error": "ffmpeg-failed"}) + return + duration_s = len(audio) / 16000.0 + logger.info("STT-Request: %.1fs Audio, model=%s, lang=%s", duration_s, runner.model_size, language) + + t_stt = time.time() + text, detected_duration = await runner.transcribe(audio, language) + stt_ms = int((time.time() - t_stt) * 1000) + + logger.info("STT-Ergebnis (%dms): '%s'", stt_ms, text[:100]) + + await _send(ws, "stt_response", { + "requestId": request_id, + "text": text.strip(), + "durationS": duration_s, + "sttMs": stt_ms, + "loadMs": load_ms, + "model": runner.model_size, + }) + except Exception as e: + logger.exception("STT-Request fehlgeschlagen") + await _send(ws, "stt_response", { + "requestId": request_id, + "error": str(e)[:200], + }) + + +async def run_loop(runner: WhisperRunner) -> None: + # Modell vorab laden damit erste Anfrage flott ist + try: + await runner.ensure_loaded(WHISPER_MODEL) + except Exception as e: + logger.error("Preload fehlgeschlagen: %s — Fortsetzung, wird bei erstem Request nachgeladen", e) + + use_tls = RVS_TLS + retry_s = 2 + tls_fallback_tried = False + + while True: + scheme = "wss" if use_tls else "ws" + url = f"{scheme}://{RVS_HOST}:{RVS_PORT}/ws?token={RVS_TOKEN}" + masked = url.replace(RVS_TOKEN, "***") if RVS_TOKEN else url + try: + logger.info("Verbinde zu RVS: %s", masked) + async with websockets.connect(url, ping_interval=20, ping_timeout=10) as ws: + logger.info("RVS verbunden") + retry_s = 2 + tls_fallback_tried = False + async for raw in ws: + try: + msg = json.loads(raw) + except Exception: + continue + mtype = msg.get("type", "") + payload = msg.get("payload", {}) or {} + + if mtype == "stt_request": + asyncio.create_task(handle_stt_request(ws, payload, runner)) + elif mtype == "config": + new_model = payload.get("whisperModel") + if new_model and new_model != runner.model_size: + logger.info("Config-Broadcast: Whisper-Modell → %s", new_model) + asyncio.create_task(runner.ensure_loaded(new_model)) + # andere Types (chat, heartbeat, ...) einfach ignorieren + except Exception as e: + logger.warning("Verbindung verloren: %s", e) + if use_tls and RVS_TLS_FALLBACK and not tls_fallback_tried: + logger.info("TLS-Verbindung fehlgeschlagen — Fallback auf ws://") + use_tls = False + tls_fallback_tried = True + continue + await asyncio.sleep(min(retry_s, 30)) + retry_s = min(retry_s * 2, 30) + + +async def main() -> None: + if not RVS_HOST: + logger.error("RVS_HOST ist nicht gesetzt — Abbruch") + sys.exit(1) + runner = WhisperRunner() + await run_loop(runner) + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + sys.exit(0) diff --git a/xtts/whisper/requirements.txt b/xtts/whisper/requirements.txt new file mode 100644 index 0000000..806c63c --- /dev/null +++ b/xtts/whisper/requirements.txt @@ -0,0 +1,3 @@ +faster-whisper==1.0.3 +websockets>=12.0 +numpy>=1.24