diff --git a/xtts/whisper/bridge.py b/xtts/whisper/bridge.py index ae589de..34d6646 100644 --- a/xtts/whisper/bridge.py +++ b/xtts/whisper/bridge.py @@ -2,8 +2,19 @@ """ ARIA Whisper Bridge — laeuft auf der Gamebox (RTX 3060). -Empfaengt stt_request via RVS → FFmpeg-Konvertierung → faster-whisper auf GPU -→ sendet stt_response zurueck an die aria-bridge. +Zwei Modi: + +1) Legacy One-Shot: stt_request mit komplettem Audio (mp4/wav/ogg base64) + → ffmpeg → faster-whisper → stt_response. Bleibt fuer Fallback/alte App. + +2) Streaming + ML-Endpointer (neu): App schickt live PCM-Chunks waehrend + der Aufnahme. Bridge transkribiert alle ~700ms auf dem Ringbuffer und + feuert stt_endpoint sobald der Transkript-String N ms nicht mehr + waechst. Ersetzt dB/VAD-Stille — endpointet auf SEMANTISCHE Stille, + funktioniert im Auto / mit Musik im Hintergrund. + + Erwartetes PCM-Format vom App-Native-Modul: 16 kHz mono s16le (genau + das was OpenWakeWord/AudioRecord schon liefert — kein Resampling). Env: RVS_HOST, RVS_PORT, RVS_TLS, RVS_TLS_FALLBACK, RVS_TOKEN @@ -21,6 +32,7 @@ import subprocess import sys import tempfile import time +from dataclasses import dataclass, field from typing import Optional import numpy as np @@ -47,6 +59,13 @@ WHISPER_LANGUAGE = os.getenv("WHISPER_LANGUAGE", "de") ALLOWED_MODELS = {"tiny", "base", "small", "medium", "large-v3"} +# Streaming-Parameter (Defaults — koennen pro Session vom App-Payload ueberschrieben werden) +STREAM_TRANSCRIBE_INTERVAL_MS = 700 # alle 700ms transkribieren waehrend Stream laeuft +STREAM_DEFAULT_ENDPOINT_MS = 1500 # nach 1.5s ohne neuen Text → Endpoint +STREAM_DEFAULT_HARD_CAP_MS = 60000 # nach 60s Audio: harter Cut egal was +STREAM_MIN_AUDIO_MS = 600 # erst transkribieren wenn min 600ms Audio da +STREAM_SESSION_TTL_S = 120 # tote Sessions nach 2 min aufraeumen + class WhisperRunner: """Haelt das Whisper-Modell. 
def pcm_s16le_to_float32(pcm_bytes: bytes) -> np.ndarray:
    """Convert 16-bit signed little-endian PCM into float32 samples in [-1, 1).

    Args:
        pcm_bytes: Raw PCM data (any bytes-like object). A trailing odd byte,
            i.e. half of a sample that has not fully arrived yet, is ignored.

    Returns:
        A one-dimensional float32 array with one value per complete sample.
        Empty input (or a single stray byte) yields an empty array.
    """
    # Only complete 2-byte samples can be decoded. Streaming chunks may be cut
    # in the middle of a sample, so the buffer length is not guaranteed even.
    sample_count = len(pcm_bytes) // 2
    if sample_count == 0:
        return np.zeros(0, dtype=np.float32)
    # "<i2" pins little-endian int16 regardless of the host byte order;
    # `count` makes frombuffer read only the complete samples instead of
    # raising on a buffer whose size is not a multiple of the item size.
    samples = np.frombuffer(pcm_bytes, dtype="<i2", count=sample_count)
    return samples.astype(np.float32) / 32768.0
def start_session(self, payload: dict) -> Optional[StreamSession]:
    """Create and register a streaming session from a stt_stream_start payload.

    Numeric fields that are missing, malformed or out of range fall back to
    their defaults instead of raising, so a bad client payload cannot
    propagate an exception into the caller.

    Args:
        payload: Message payload. Reads "requestId" (required), and the
            optional keys "audioRequestId", "language", "model",
            "endpointMs", "hardCapMs", "voice", "speed", "interrupted",
            "location" and "sampleRate".

    Returns:
        The new StreamSession, or None when "requestId" is missing, empty
        or not a string. An existing session with the same id is replaced.
    """

    def _positive_int(key: str, default: int) -> int:
        # Parse payload[key] as a positive int; anything else -> default.
        # OverflowError covers JSON "Infinity" and similar float inputs.
        try:
            value = int(payload.get(key) or default)
        except (TypeError, ValueError, OverflowError):
            return default
        return value if value > 0 else default

    raw_id = payload.get("requestId")
    # A non-string id cannot be stripped or sliced for logging; treat it
    # like a missing id.
    request_id = raw_id.strip() if isinstance(raw_id, str) else ""
    if not request_id:
        logger.warning("stt_stream_start ohne requestId — ignoriert")
        return None
    if request_id in self._sessions:
        logger.warning("stt_stream_start: requestId %s schon aktiv — alte Session wird ersetzt",
                       request_id[:8])

    try:
        speed = float(payload.get("speed") or 1.0)
    except (TypeError, ValueError, OverflowError):
        speed = 1.0

    session = StreamSession(
        request_id=request_id,
        audio_request_id=payload.get("audioRequestId", "") or "",
        language=payload.get("language") or WHISPER_LANGUAGE,
        model=payload.get("model") or self.runner.model_size or WHISPER_MODEL,
        endpoint_ms=_positive_int("endpointMs", STREAM_DEFAULT_ENDPOINT_MS),
        hard_cap_ms=_positive_int("hardCapMs", STREAM_DEFAULT_HARD_CAP_MS),
        voice=payload.get("voice", "") or "",
        speed=speed,
        interrupted=bool(payload.get("interrupted", False)),
        location=payload.get("location") or None,
        # A non-positive rate would break the buffer-duration arithmetic.
        sample_rate=_positive_int("sampleRate", 16000),
    )
    self._sessions[request_id] = session
    logger.info("Stream-Session offen: id=%s lang=%s model=%s endpointMs=%d hardCapMs=%d voice=%r",
                request_id[:8], session.language, session.model,
                session.endpoint_ms, session.hard_cap_ms, session.voice or "(default)")
    return session
async def run_endpointer(self) -> None:
    """Background loop: roughly every 200 ms, tick every session, then expire idle ones.

    Never returns on its own. An exception raised while ticking one session
    is logged and does not stop the loop or the remaining sessions.
    """
    logger.info("Endpointer-Loop gestartet (transcribe-interval=%dms, default-endpoint=%dms)",
                STREAM_TRANSCRIBE_INTERVAL_MS, STREAM_DEFAULT_ENDPOINT_MS)
    while True:
        await asyncio.sleep(0.2)
        tick_time = time.time()

        # Iterate over a snapshot: a tick may drop its session from the
        # dict, which must not happen while iterating the live mapping.
        for session_id, session in tuple(self._sessions.items()):
            try:
                await self._tick_session(session, tick_time)
            except Exception:
                logger.exception("Endpointer-Tick crashed (session=%s)", session_id[:8])

        # Collect sessions that received no chunk for longer than the TTL
        # first, then drop them, so the dict is not mutated mid-iteration.
        expired = [
            (session_id, tick_time - session.last_chunk_at)
            for session_id, session in self._sessions.items()
            if tick_time - session.last_chunk_at > STREAM_SESSION_TTL_S
        ]
        for session_id, idle_s in expired:
            logger.info("Stream %s: TTL ueberschritten (ohne Daten seit %.0fs) — drop",
                        session_id[:8], idle_s)
            self.drop(session_id)
async def _finalize(self, sess: StreamSession, ws, reason: str) -> None:
    """Run the final transcription for a session and report the result.

    Transcribes the complete buffer with beam_size=5, sends "stt_endpoint"
    and "stt_stream_done" over ``ws`` and removes the session. Both messages
    are sent even when decoding or transcription fails; in that case the
    text falls back to the last partial transcript.

    Args:
        sess: Session to finalize. A session that was already finalized
            (``endpoint_sent`` set) is left untouched.
        ws: Connection the result messages are sent on.
        reason: Why the session ended; echoed back in both messages.
    """
    if sess.endpoint_sent:
        return
    # Set before the first await so a second call for the same session
    # returns early at the guard above.
    sess.endpoint_sent = True

    final_text = ""
    stt_ms = 0
    duration_s = 0.0
    audio = None
    try:
        pcm = bytes(sess.pcm_buffer)
        # A chunk boundary can split a 16-bit sample; drop the dangling
        # byte so the decode cannot fail on an odd-sized buffer.
        if len(pcm) % 2:
            pcm = pcm[:-1]
        audio = pcm_s16le_to_float32(pcm)
    except Exception:
        # The flag above is already set, so this call is the only chance to
        # answer the client — fall through and report the last partial.
        logger.exception("Stream %s: PCM-Decode fehlgeschlagen", sess.request_id[:8])
        final_text = sess.last_partial

    if audio is not None and audio.size == 0:
        logger.info("Stream %s: leere Audio-Daten — final text leer", sess.request_id[:8])
    elif audio is not None:
        t0 = time.time()
        try:
            final_text, _dur = await self.runner.transcribe(
                audio, sess.language, beam_size=5, vad_filter=True,
            )
        except Exception:
            logger.exception("Stream %s: Final-Transcribe crashed", sess.request_id[:8])
            final_text = sess.last_partial  # fall back to the last partial
        stt_ms = int((time.time() - t0) * 1000)
        # Use the session's own rate (as _buffer_duration_ms does) rather
        # than a hard-coded 16 kHz; guard against a non-positive value.
        sample_rate = sess.sample_rate if sess.sample_rate > 0 else 16000
        duration_s = audio.size / float(sample_rate)
    final_text = final_text.strip()

    logger.info("Stream %s: FINAL (reason=%s, %.1fs Audio, %dms): %r",
                sess.request_id[:8], reason, duration_s, stt_ms, final_text[:120])

    # stt_endpoint carries the final text plus the per-session fields
    # (voice, speed, interrupted, optional location) captured at start.
    endpoint_payload = {
        "requestId": sess.request_id,
        "audioRequestId": sess.audio_request_id,
        "text": final_text,
        "reason": reason,
        "durationS": duration_s,
        "sttMs": stt_ms,
        "voice": sess.voice,
        "speed": sess.speed,
        "interrupted": sess.interrupted,
    }
    if sess.location:
        endpoint_payload["location"] = sess.location
    await _send(ws, "stt_endpoint", endpoint_payload)

    # stt_stream_done signals that this stream is finished.
    await _send(ws, "stt_stream_done", {
        "requestId": sess.request_id,
        "audioRequestId": sess.audio_request_id,
        "text": final_text,
        "reason": reason,
    })

    self.drop(sess.request_id)
grosse stt_request (Voice-Cloning-WAVs als - # base64 koennen mehrere MB werden) nicht das Frame-Limit sprengen - # und die Verbindung mit 1009 'message too big' killen. async with websockets.connect(url, ping_interval=20, ping_timeout=10, max_size=50 * 1024 * 1024) as ws: logger.info("RVS verbunden") retry_s = 2 tls_fallback_tried = False + sessions.attach_ws(ws) - # Initialer Status-Broadcast — uebertont alten "ready"-State - # im App/Diagnostic Banner (sonst denkt der User noch alles ist - # gut von vorher). Wenn Modell schon geladen → ready, sonst - # loading mit aktuellem (Default-)Namen. - # Plus: config_request an aria-bridge — wir wissen nicht ob - # sie auch grad reconnected hat oder schon laenger online ist. async def _initial_handshake(): try: if runner.model is not None: @@ -259,9 +555,41 @@ async def run_loop(runner: WhisperRunner) -> None: logger.info("stt_request empfangen (id=%s, %dKB Audio)", req_id[:8] if req_id != "?" else "?", audio_len // 1365) asyncio.create_task(handle_stt_request(ws, payload, runner)) + + elif mtype == "stt_stream_start": + # Ggf. Modell sicherstellen — sonst antwortet der erste + # transcribe-Call mit Leerstring weil Model None. + target_model = payload.get("model") or runner.model_size or WHISPER_MODEL + needs_load = (runner.model is None) or (target_model != runner.model_size) + if needs_load: + async def _load_then_start(p, target): + await _broadcast_status(ws, "loading", model=target) + try: + await runner.ensure_loaded(target) + await _broadcast_status(ws, "ready", model=runner.model_size) + except Exception as e: + await _broadcast_status(ws, "error", error=str(e)[:200]) + return + sessions.start_session(p) + asyncio.create_task(_load_then_start(payload, target_model)) + else: + sessions.start_session(payload) + + elif mtype == "stt_audio_chunk": + ok = sessions.feed_chunk(payload) + if not ok: + # Sehr verbose im Schlimmstfall — debug-Level reicht. 
+ logger.debug("stt_audio_chunk: unbekannte/closed session %s", + payload.get("requestId", "")[:8]) + + elif mtype == "stt_stream_end": + req_id = payload.get("requestId", "") + logger.info("stt_stream_end empfangen: id=%s reason=%s", + req_id[:8], payload.get("reason", "")) + sessions.end_session(req_id) + elif mtype == "config": new_model = payload.get("whisperModel") or WHISPER_MODEL - # Laden wenn (a) noch nix geladen, oder (b) Modell wechselt needs_load = (runner.model is None) or (new_model != runner.model_size) if needs_load: logger.info("Config-Broadcast: Whisper-Modell -> %s%s", @@ -280,11 +608,10 @@ async def run_loop(runner: WhisperRunner) -> None: await _broadcast_status(ws, "error", error=str(e)[:200]) asyncio.create_task(_swap_with_status(new_model)) else: - # Alle anderen Nachrichten debug-loggen — hilft beim Diagnostizieren, - # ob stt_request ueberhaupt durch den RVS kommt logger.debug("Unbeachteter Type: %s", mtype) except Exception as e: logger.warning("Verbindung verloren: %s", e) + sessions.detach_ws() if use_tls and RVS_TLS_FALLBACK and not tls_fallback_tried: logger.info("TLS-Verbindung fehlgeschlagen — Fallback auf ws://") use_tls = False @@ -292,10 +619,6 @@ async def run_loop(runner: WhisperRunner) -> None: continue await asyncio.sleep(min(retry_s, 30)) retry_s = min(retry_s * 2, 30) - # Sticky-Fallback verhindern: nach jedem Disconnect-Cycle wieder - # mit wss anfangen. Sonst klebt der Client nach einem temporaeren - # TLS-Hick auf ws:// fest und kommt nie mehr auf wss zurueck — - # genau das Problem das die App + Bridge frueher schon hatten. use_tls = RVS_TLS tls_fallback_tried = False @@ -305,7 +628,11 @@ async def main() -> None: logger.error("RVS_HOST ist nicht gesetzt — Abbruch") sys.exit(1) runner = WhisperRunner() - await run_loop(runner) + sessions = SessionManager(runner) + # Endpointer-Loop nebenbei laufen lassen — er pruefst _ws is None und + # schlaeft solange das nicht gesetzt ist. 
+ asyncio.create_task(sessions.run_endpointer()) + await run_loop(runner, sessions) if __name__ == "__main__":