41999c2304
Neue RVS-Messages auf der Whisper-Bridge:
stt_stream_start {requestId, audioRequestId, language?, model?,
endpointMs?=1500, hardCapMs?=60000, voice, speed,
interrupted, location, sampleRate?=16000}
stt_audio_chunk {requestId, pcm: base64-s16le, seq}
stt_stream_end {requestId, reason}
stt_partial (Bridge→App, alle ~700ms, fuer Live-UI-Feedback)
stt_endpoint (Bridge→App+aria-bridge, finaler Text + alle Echos)
stt_stream_done (Bridge→App, signalisiert Session-Ende)
Endpointer-Logik:
- alle 700ms transkribiert die Bridge den Ringbuffer (beam_size=1, schnell)
- waechst der Transkript-String → Stagnations-Timer wird zurueckgesetzt
- waechst er endpointMs lang nicht mehr → Session wird finalisiert
- bei hardCapMs (60s) sowieso finalisiert egal ob stagnierend
- Final-Transcribe nochmal mit beam_size=5 fuer Qualitaet
- stt_endpoint enthaelt voice/speed/interrupted/location echos,
damit aria-bridge in Phase 2 direkt an Brain weiterleiten kann
Legacy stt_request (One-Shot mit base64-mp4/wav) bleibt unveraendert
als Fallback.
Default-Parameter (alle vom App-Payload uebersteuerbar):
STREAM_TRANSCRIBE_INTERVAL_MS = 700 (Throttle)
STREAM_DEFAULT_ENDPOINT_MS = 1500 (Stille = kein neuer Text)
STREAM_DEFAULT_HARD_CAP_MS = 60000 (Schmerzgrenze)
STREAM_MIN_AUDIO_MS = 600 (erst transkribieren ab N Audio)
STREAM_SESSION_TTL_S = 120 (tote Sessions aufraeumen)
Ersetzt den dB/VAD-Stille-Trigger auf der App-Seite — Endpointer
hoert auf SEMANTISCHE Stille (kein neuer Text), nicht akustische.
Funktioniert im Auto, mit Musik im Hintergrund und in lauten
Umgebungen, in denen VAD versagt.
643 lines
28 KiB
Python
643 lines
28 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
ARIA Whisper Bridge — laeuft auf der Gamebox (RTX 3060).
|
|
|
|
Zwei Modi:
|
|
|
|
1) Legacy One-Shot: stt_request mit komplettem Audio (mp4/wav/ogg base64)
|
|
→ ffmpeg → faster-whisper → stt_response. Bleibt fuer Fallback/alte App.
|
|
|
|
2) Streaming + ML-Endpointer (neu): App schickt live PCM-Chunks waehrend
|
|
der Aufnahme. Bridge transkribiert alle ~700ms auf dem Ringbuffer und
|
|
feuert stt_endpoint sobald der Transkript-String N ms nicht mehr
|
|
waechst. Ersetzt dB/VAD-Stille — endpointet auf SEMANTISCHE Stille,
|
|
funktioniert im Auto / mit Musik im Hintergrund.
|
|
|
|
Erwartetes PCM-Format vom App-Native-Modul: 16 kHz mono s16le (genau
|
|
das was OpenWakeWord/AudioRecord schon liefert — kein Resampling).
|
|
|
|
Env:
|
|
RVS_HOST, RVS_PORT, RVS_TLS, RVS_TLS_FALLBACK, RVS_TOKEN
|
|
WHISPER_MODEL Default: small
|
|
WHISPER_DEVICE Default: cuda
|
|
WHISPER_COMPUTE_TYPE Default: float16
|
|
WHISPER_LANGUAGE Default: de
|
|
"""
|
|
import asyncio
|
|
import base64
|
|
import json
|
|
import logging
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
import tempfile
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from typing import Optional
|
|
|
|
import numpy as np
|
|
import websockets
|
|
from faster_whisper import WhisperModel
|
|
|
|
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger("whisper-bridge")


def _env_flag(name: str, default: str) -> bool:
    """Return True iff env var *name* (or *default* when unset) is "true", case-insensitively."""
    return os.getenv(name, default).lower() == "true"


# RVS relay connection.
RVS_HOST = os.getenv("RVS_HOST", "").strip()
RVS_PORT = int(os.getenv("RVS_PORT", "443"))
RVS_TLS = _env_flag("RVS_TLS", "true")
RVS_TLS_FALLBACK = _env_flag("RVS_TLS_FALLBACK", "true")
RVS_TOKEN = os.getenv("RVS_TOKEN", "").strip()

# Whisper model selection (faster-whisper).
WHISPER_MODEL = os.getenv("WHISPER_MODEL", "small")
WHISPER_DEVICE = os.getenv("WHISPER_DEVICE", "cuda")
WHISPER_COMPUTE_TYPE = os.getenv("WHISPER_COMPUTE_TYPE", "float16")
WHISPER_LANGUAGE = os.getenv("WHISPER_LANGUAGE", "de")

# Model sizes accepted by WhisperRunner.ensure_loaded(); anything else falls back to WHISPER_MODEL.
ALLOWED_MODELS = {"tiny", "base", "small", "medium", "large-v3"}

# Streaming defaults. Endpoint and hard cap can be overridden per session by
# the app payload (endpointMs / hardCapMs); the other values are fixed here.
STREAM_TRANSCRIBE_INTERVAL_MS = 700      # re-transcribe a running stream at most this often
STREAM_DEFAULT_ENDPOINT_MS = 1_500       # this long without new text -> endpoint
STREAM_DEFAULT_HARD_CAP_MS = 60 * 1_000  # force an endpoint after this much wall-clock time
STREAM_MIN_AUDIO_MS = 600                # don't transcribe before this much audio is buffered
STREAM_SESSION_TTL_S = 2 * 60            # drop sessions that received no chunk for this long
|
|
|
|
|
|
class WhisperRunner:
    """Holds the faster-whisper model and serializes access to it.

    The model is loaded lazily and hot-swapped by ensure_loaded() when a
    different size is requested. All transcriptions go through one lock.
    """

    def __init__(self) -> None:
        # Size of the loaded model; before the first load this is the configured default.
        self.model_size: str = WHISPER_MODEL
        self.model: Optional[WhisperModel] = None
        # Serializes model (re)loads.
        self._lock = asyncio.Lock()
        # Serializes transcribe() calls — faster-whisper is not safe to run in
        # parallel on one GPU instance, and this also limits VRAM fragmentation.
        self._transcribe_lock = asyncio.Lock()

    def _load_blocking(self, size: str) -> None:
        """Load model *size* synchronously (called in a worker thread).

        ``self.model`` is only replaced after the constructor returns, so a
        failed load leaves the previously loaded model in place.
        """
        logger.info(
            "Lade Whisper '%s' (device=%s, compute=%s)",
            size, WHISPER_DEVICE, WHISPER_COMPUTE_TYPE,
        )
        t0 = time.time()
        self.model = WhisperModel(
            size, device=WHISPER_DEVICE, compute_type=WHISPER_COMPUTE_TYPE,
        )
        self.model_size = size
        logger.info("Whisper '%s' geladen in %.1fs", size, time.time() - t0)

    async def ensure_loaded(self, desired_size: str) -> None:
        """Ensure model *desired_size* is loaded, loading/swapping it if needed.

        Sizes not in ALLOWED_MODELS are replaced by WHISPER_MODEL (with a
        warning). Exceptions raised while loading propagate to the caller.
        """
        if desired_size not in ALLOWED_MODELS:
            logger.warning("Ungueltiges Whisper-Modell '%s' — nutze %s", desired_size, WHISPER_MODEL)
            desired_size = WHISPER_MODEL
        async with self._lock:
            if self.model is not None and self.model_size == desired_size:
                return
            # get_running_loop() is the recommended API inside coroutines.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._load_blocking, desired_size)

    async def transcribe(self, audio: np.ndarray, language: str,
                         beam_size: int = 5, vad_filter: bool = True) -> tuple[str, float]:
        """Transcribe *audio* and return ``(text, duration_s)``.

        Returns ``("", 0.0)`` while no model is loaded. Segment texts are
        stripped and joined with single spaces. The segment generator is
        consumed inside the worker thread, so decoding never runs on the
        event loop.
        """
        if self.model is None:
            return "", 0.0

        def _run() -> tuple[str, float]:
            segments, info = self.model.transcribe(
                audio, language=language, beam_size=beam_size, vad_filter=vad_filter,
            )
            text = " ".join(seg.text.strip() for seg in segments)
            return text, info.duration

        loop = asyncio.get_running_loop()
        async with self._transcribe_lock:
            return await loop.run_in_executor(None, _run)
|
|
|
|
|
|
def ffmpeg_to_float32(audio_b64: str, mime_type: str) -> np.ndarray:
    """Decode base64 audio of (almost) any format to 16 kHz mono float32 PCM.

    *mime_type* only selects the temp-file suffix as a hint for ffmpeg; the
    match is case-insensitive and a missing/None value is treated as unknown.

    Returns an empty array if ffmpeg exits with a non-zero status.
    Raises ``subprocess.TimeoutExpired`` after 30 s, ``FileNotFoundError``
    if ffmpeg is not installed, and base64 decoding errors (``binascii.Error``)
    propagate. Both temporary files are removed in every case.
    """
    mime = (mime_type or "").lower()
    if "mp4" in mime or "m4a" in mime or "aac" in mime:
        ext = ".mp4"
    elif "wav" in mime:
        ext = ".wav"
    elif "ogg" in mime or "opus" in mime:
        ext = ".ogg"
    else:
        ext = ".bin"

    in_fh = tempfile.NamedTemporaryFile(suffix=ext, delete=False)
    in_path = in_fh.name
    out_path = in_path + ".raw"
    try:
        with in_fh:
            in_fh.write(base64.b64decode(audio_b64))
        cmd = ["ffmpeg", "-y", "-i", in_path, "-ar", "16000", "-ac", "1", "-f", "f32le", out_path]
        # stdin=DEVNULL: ffmpeg must not read (or block on) the daemon's stdin.
        result = subprocess.run(cmd, capture_output=True, stdin=subprocess.DEVNULL, timeout=30)
        if result.returncode != 0:
            logger.error("FFmpeg Fehler: %s", result.stderr.decode(errors="replace")[:300])
            return np.zeros(0, dtype=np.float32)
        return np.fromfile(out_path, dtype=np.float32)
    finally:
        # Clean up both files, also when ffmpeg failed/timed out after
        # partially writing the output file.
        for path in (in_path, out_path):
            try:
                os.unlink(path)
            except OSError:
                pass
|
|
|
|
|
|
def pcm_s16le_to_float32(pcm_bytes: bytes) -> np.ndarray:
    """Convert 16-bit signed little-endian PCM to float32 samples in [-1, 1).

    Accepts any bytes-like object (bytes, bytearray, memoryview). A trailing
    odd byte — an incomplete sample, e.g. from a chunk boundary — is ignored
    instead of making numpy raise. Empty/None input yields an empty array.
    """
    if not pcm_bytes:
        return np.zeros(0, dtype=np.float32)
    n_samples = len(pcm_bytes) // 2
    if n_samples == 0:
        return np.zeros(0, dtype=np.float32)
    # count= lets frombuffer ignore the leftover byte; "<i2" pins little-endian.
    samples = np.frombuffer(pcm_bytes, dtype="<i2", count=n_samples)
    return samples.astype(np.float32) / 32768.0
|
|
|
|
|
|
async def _send(ws, mtype: str, payload: dict) -> None:
|
|
try:
|
|
await ws.send(json.dumps({
|
|
"type": mtype,
|
|
"payload": payload,
|
|
"timestamp": int(time.time() * 1000),
|
|
}))
|
|
except Exception as e:
|
|
logger.warning("Send fehlgeschlagen (%s): %s", mtype, e)
|
|
|
|
|
|
# ──────────────────────────────────────────────────────────────
|
|
# STREAMING-SESSIONS
|
|
# ──────────────────────────────────────────────────────────────
|
|
|
|
@dataclass
class StreamSession:
    """Mutable state of one in-flight streaming STT request (one requestId)."""

    # --- identity and endpointer parameters (from stt_stream_start) ---
    request_id: str
    audio_request_id: str
    language: str
    model: str
    endpoint_ms: int
    hard_cap_ms: int

    # --- echo fields, copied into stt_endpoint ---
    voice: str = field(default="")  # TTS voice override chosen by the app
    speed: float = field(default=1.0)
    interrupted: bool = field(default=False)  # barge-in flag
    location: Optional[dict] = field(default=None)
    sample_rate: int = field(default=16000)

    # --- runtime state ---
    pcm_buffer: bytearray = field(default_factory=bytearray)  # raw s16le bytes received so far
    started_at: float = field(default_factory=time.time)
    last_chunk_at: float = field(default_factory=time.time)  # basis for the TTL sweep
    last_partial: str = field(default="")  # most recent non-empty partial transcript
    last_growth_at: float = field(default=0.0)  # 0.0 means no text recognized yet
    last_transcribe_at: float = field(default=0.0)
    closed: bool = field(default=False)  # set once stt_stream_end was received
    endpoint_sent: bool = field(default=False)  # ensures the endpoint fires only once
|
|
|
|
|
|
class SessionManager:
    """Owns all active streaming-STT sessions and runs the endpointer loop.

    Sessions are keyed by the app's requestId. run_endpointer() polls every
    session about every 200 ms: it re-transcribes the buffered audio
    (throttled), sends ``stt_partial`` when the text changes, and finalizes
    the session (``stt_endpoint`` + ``stt_stream_done``) after ``endpoint_ms``
    without a text change, after ``hard_cap_ms`` of wall-clock time, or once
    ``stt_stream_end`` was received.
    """

    def __init__(self, runner: WhisperRunner) -> None:
        self.runner = runner
        self._sessions: dict[str, StreamSession] = {}
        self._ws = None  # set by run_loop via attach_ws()
        self._loop_task: Optional[asyncio.Task] = None

    def attach_ws(self, ws) -> None:
        """Route all outgoing session messages through *ws*."""
        self._ws = ws

    def detach_ws(self) -> None:
        """Forget the current connection; endpointer ticks do nothing until attach_ws().

        Sessions survive a disconnect, so a reconnect could keep feeding them
        if the app resends the same requestId. Since the app starts a new
        recording after reconnecting, orphaned sessions are normally removed
        by the TTL sweep (STREAM_SESSION_TTL_S).
        """
        self._ws = None

    @staticmethod
    def _request_id(value) -> str:
        """Normalize a requestId (any JSON type or None) to a stripped string."""
        return str(value or "").strip()

    @staticmethod
    def _positive_int(value, default: int) -> int:
        """Parse *value* as an int > 0; return *default* if missing, invalid or <= 0."""
        try:
            parsed = int(value or default)
        except (TypeError, ValueError):
            return default
        return parsed if parsed > 0 else default

    def start_session(self, payload: dict) -> Optional[StreamSession]:
        """Open a session from an ``stt_stream_start`` payload.

        An existing session with the same requestId is replaced. Returns the
        new session, or None when requestId is missing. Invalid numeric fields
        fall back to defaults instead of raising, because an exception here
        would propagate to the caller (the WebSocket receive loop).
        """
        request_id = self._request_id(payload.get("requestId"))
        if not request_id:
            logger.warning("stt_stream_start ohne requestId — ignoriert")
            return None
        if request_id in self._sessions:
            logger.warning("stt_stream_start: requestId %s schon aktiv — alte Session wird ersetzt",
                           request_id[:8])
        try:
            speed = float(payload.get("speed") or 1.0)
        except (TypeError, ValueError):
            speed = 1.0
        session = StreamSession(
            request_id=request_id,
            audio_request_id=payload.get("audioRequestId", "") or "",
            language=payload.get("language") or WHISPER_LANGUAGE,
            model=payload.get("model") or self.runner.model_size or WHISPER_MODEL,
            endpoint_ms=self._positive_int(payload.get("endpointMs"), STREAM_DEFAULT_ENDPOINT_MS),
            hard_cap_ms=self._positive_int(payload.get("hardCapMs"), STREAM_DEFAULT_HARD_CAP_MS),
            voice=payload.get("voice", "") or "",
            speed=speed,
            interrupted=bool(payload.get("interrupted", False)),
            location=payload.get("location") or None,
            # Must be > 0: it is a divisor in the duration computations.
            sample_rate=self._positive_int(payload.get("sampleRate"), 16000),
        )
        self._sessions[request_id] = session
        logger.info("Stream-Session offen: id=%s lang=%s model=%s endpointMs=%d hardCapMs=%d voice=%r",
                    request_id[:8], session.language, session.model,
                    session.endpoint_ms, session.hard_cap_ms, session.voice or "(default)")
        return session

    def feed_chunk(self, payload: dict) -> bool:
        """Append one base64 PCM chunk to its session.

        Returns False for unknown/closed sessions, empty chunks, or invalid
        base64; True once the bytes were appended.
        """
        request_id = self._request_id(payload.get("requestId"))
        session = self._sessions.get(request_id)
        if session is None or session.closed:
            return False
        pcm_b64 = payload.get("pcm", "")
        if not pcm_b64:
            return False
        try:
            pcm_bytes = base64.b64decode(pcm_b64)
        except Exception:
            logger.warning("Stream %s: ungueltige base64-PCM-Daten", request_id[:8])
            return False
        session.pcm_buffer.extend(pcm_bytes)
        session.last_chunk_at = time.time()
        return True

    def end_session(self, request_id: str) -> Optional[StreamSession]:
        """Mark a session closed (``stt_stream_end``); return it, or None if unknown.

        The endpointer loop performs the final transcription and cleanup.
        """
        session = self._sessions.get(self._request_id(request_id))
        if session is None:
            return None
        session.closed = True
        return session

    def drop(self, request_id: str, session: Optional[StreamSession] = None) -> None:
        """Remove a session by id.

        If *session* is given, the entry is only removed while it still maps
        to that exact object, so a newer session that reused the requestId
        survives.
        """
        if session is not None and self._sessions.get(request_id) is not session:
            return
        self._sessions.pop(request_id, None)

    async def run_endpointer(self) -> None:
        """Background loop: tick every session about every 200 ms, then sweep stale ones."""
        logger.info("Endpointer-Loop gestartet (transcribe-interval=%dms, default-endpoint=%dms)",
                    STREAM_TRANSCRIBE_INTERVAL_MS, STREAM_DEFAULT_ENDPOINT_MS)
        while True:
            await asyncio.sleep(0.2)
            now = time.time()
            # Iterate over a snapshot: ticks may drop sessions from the dict.
            for sid, sess in list(self._sessions.items()):
                try:
                    await self._tick_session(sess, now)
                except Exception:
                    logger.exception("Endpointer-Tick crashed (session=%s)", sid[:8])

            # TTL sweep: sessions that received no chunk for STREAM_SESSION_TTL_S.
            for sid, sess in list(self._sessions.items()):
                if now - sess.last_chunk_at > STREAM_SESSION_TTL_S:
                    logger.info("Stream %s: TTL ueberschritten (ohne Daten seit %.0fs) — drop",
                                sid[:8], now - sess.last_chunk_at)
                    self.drop(sid, sess)

    async def _tick_session(self, sess: StreamSession, now: float) -> None:
        """Advance one session's endpointer state by one poll step."""
        ws = self._ws
        if ws is None:
            return  # disconnected — paused until reconnect

        # Hard cap (wall-clock since stt_stream_start) -> finalize regardless of text.
        elapsed_ms = (now - sess.started_at) * 1000.0
        if elapsed_ms > sess.hard_cap_ms and not sess.endpoint_sent and not sess.closed:
            logger.info("Stream %s: HardCap %dms erreicht — forciere Endpoint",
                        sess.request_id[:8], sess.hard_cap_ms)
            await self._finalize(sess, ws, reason="hardcap")
            return

        # stt_stream_end received -> finalize with the buffered audio.
        if sess.closed and not sess.endpoint_sent:
            if self.runner.model is None and elapsed_ms <= sess.hard_cap_ms:
                # The initial model load is still running: wait (bounded by the
                # hard cap) so the buffer is actually transcribed instead of
                # finalizing with an empty text.
                return
            await self._finalize(sess, ws, reason="stream_end")
            return

        if self._buffer_duration_ms(sess) < STREAM_MIN_AUDIO_MS:
            return  # too little audio for a first transcription

        if (now - sess.last_transcribe_at) * 1000.0 < STREAM_TRANSCRIBE_INTERVAL_MS:
            return  # throttle partial transcriptions
        sess.last_transcribe_at = now

        try:
            audio = self._buffer_audio(sess)
        except Exception:
            logger.exception("Stream %s: PCM-Decode fehlgeschlagen", sess.request_id[:8])
            return

        try:
            # beam_size=1 for partials: latency over accuracy. _finalize uses beam_size=5.
            text, _dur = await self.runner.transcribe(audio, sess.language, beam_size=1, vad_filter=True)
        except Exception:
            logger.exception("Stream %s: Partial-Transcribe crashed", sess.request_id[:8])
            return

        text = text.strip()
        if text and text != sess.last_partial:
            # Text changed -> reset the stagnation timer and push live feedback
            # for UIs that show the running transcript.
            sess.last_partial = text
            sess.last_growth_at = now
            await _send(ws, "stt_partial", {
                "requestId": sess.request_id,
                "audioRequestId": sess.audio_request_id,
                "text": text,
            })
            return

        if sess.last_growth_at == 0.0:
            # No text recognized yet: nothing to endpoint on. Such a session ends
            # via the hard cap, stt_stream_end, or the TTL sweep.
            return

        silence_ms = (now - sess.last_growth_at) * 1000.0
        if silence_ms >= sess.endpoint_ms and not sess.endpoint_sent:
            logger.info("Stream %s: Endpoint nach %dms ohne neuen Text — Text=%r",
                        sess.request_id[:8], int(silence_ms), sess.last_partial[:80])
            await self._finalize(sess, ws, reason="endpoint")

    def _buffer_duration_ms(self, sess: StreamSession) -> float:
        """Duration of the buffered s16le mono audio in milliseconds."""
        samples = len(sess.pcm_buffer) // 2  # 2 bytes per 16-bit sample
        if samples == 0 or sess.sample_rate <= 0:
            return 0.0
        return (samples / sess.sample_rate) * 1000.0

    @staticmethod
    def _buffer_audio(sess: StreamSession) -> np.ndarray:
        """Snapshot the buffer as float32 audio, dropping a trailing incomplete sample."""
        data = bytes(sess.pcm_buffer)
        return pcm_s16le_to_float32(data[: len(data) - len(data) % 2])

    async def _final_transcribe(self, sess: StreamSession) -> tuple[str, int, float]:
        """High-quality pass (beam_size=5) over the whole buffer.

        Returns ``(text, stt_ms, duration_s)``. Falls back to the last partial
        if decoding or transcription fails.
        """
        try:
            audio = self._buffer_audio(sess)
        except Exception:
            logger.exception("Stream %s: PCM-Decode fehlgeschlagen", sess.request_id[:8])
            return sess.last_partial.strip(), 0, 0.0
        if audio.size == 0:
            logger.info("Stream %s: leere Audio-Daten — final text leer", sess.request_id[:8])
            return "", 0, 0.0
        t0 = time.time()
        try:
            text, _dur = await self.runner.transcribe(audio, sess.language, beam_size=5, vad_filter=True)
        except Exception:
            logger.exception("Stream %s: Final-Transcribe crashed", sess.request_id[:8])
            text = sess.last_partial  # fall back to the last partial
        stt_ms = int((time.time() - t0) * 1000)
        rate = sess.sample_rate if sess.sample_rate > 0 else 16000
        return text.strip(), stt_ms, audio.size / float(rate)

    async def _finalize(self, sess: StreamSession, ws, reason: str) -> None:
        """Final transcription, then send stt_endpoint + stt_stream_done and drop the session.

        Runs at most once per session. The session is dropped even if
        something below raises, so it cannot linger until the TTL sweep.
        """
        if sess.endpoint_sent:
            return
        sess.endpoint_sent = True
        try:
            final_text, stt_ms, duration_s = await self._final_transcribe(sess)
            logger.info("Stream %s: FINAL (reason=%s, %.1fs Audio, %dms): %r",
                        sess.request_id[:8], reason, duration_s, stt_ms, final_text[:120])

            # stt_endpoint is the event aria-bridge listens on for the Brain
            # shortcut. It carries the echo fields that previously travelled with
            # the 'audio' message, avoiding the audio round trip
            # (App -> aria-bridge -> whisper -> aria-bridge).
            endpoint_payload = {
                "requestId": sess.request_id,
                "audioRequestId": sess.audio_request_id,
                "text": final_text,
                "reason": reason,
                "durationS": duration_s,
                "sttMs": stt_ms,
                "voice": sess.voice,
                "speed": sess.speed,
                "interrupted": sess.interrupted,
            }
            if sess.location:
                endpoint_payload["location"] = sess.location
            await _send(ws, "stt_endpoint", endpoint_payload)

            # stt_stream_done tells the app to reset its recording state machine
            # to armed (mic off, wake word possibly back on).
            await _send(ws, "stt_stream_done", {
                "requestId": sess.request_id,
                "audioRequestId": sess.audio_request_id,
                "text": final_text,
                "reason": reason,
            })
        finally:
            # Identity-checked: a new stt_stream_start may have reused the
            # requestId while we were awaiting the transcription.
            self.drop(sess.request_id, sess)
|
|
|
|
|
|
# ──────────────────────────────────────────────────────────────
|
|
# LEGACY ONE-SHOT (unveraendert)
|
|
# ──────────────────────────────────────────────────────────────
|
|
|
|
async def handle_stt_request(ws, payload: dict, runner: WhisperRunner) -> None:
    """Answer a legacy one-shot ``stt_request`` with an ``stt_response``.

    Loads/swaps the requested model if needed, broadcasting service_status
    "loading"/"ready" (or "error" if the load fails). The complete audio is
    decoded via ffmpeg in a worker thread and then transcribed. The reply
    carries text and timings; every failure is reported as
    ``{"requestId", "error"}``.
    """
    request_id = payload.get("requestId", "")
    audio_b64 = payload.get("audio", "")
    # `or ""`: an explicit null mimeType must not crash the suffix detection.
    mime_type = payload.get("mimeType", "audio/mp4") or ""
    model = payload.get("model") or (runner.model_size if runner.model is not None else WHISPER_MODEL)
    language = payload.get("language") or WHISPER_LANGUAGE

    if not audio_b64:
        await _send(ws, "stt_response", {"requestId": request_id, "error": "no-audio"})
        return

    try:
        t_load = time.time()
        needs_load = runner.model is None or runner.model_size != model
        if needs_load:
            await _broadcast_status(ws, "loading", model=model)
        try:
            await runner.ensure_loaded(model)
        except Exception as e:
            if needs_load:
                # Don't leave clients showing "loading" forever.
                await _broadcast_status(ws, "error", error=str(e)[:200])
            raise
        load_ms = int((time.time() - t_load) * 1000)
        if needs_load:
            await _broadcast_status(ws, "ready",
                                    model=runner.model_size,
                                    loadSeconds=load_ms / 1000.0)

        # base64 decoding and the ffmpeg subprocess (up to 30 s) are blocking;
        # run them in a worker thread so the event loop (WebSocket keepalive,
        # endpointer task) keeps running.
        loop = asyncio.get_running_loop()
        audio = await loop.run_in_executor(None, ffmpeg_to_float32, audio_b64, mime_type)
        if audio.size == 0:
            await _send(ws, "stt_response", {"requestId": request_id, "error": "ffmpeg-failed"})
            return
        duration_s = len(audio) / 16000.0  # ffmpeg resamples to 16 kHz
        logger.info("STT-Request: %.1fs Audio, model=%s, lang=%s", duration_s, runner.model_size, language)

        t_stt = time.time()
        text, _detected_duration = await runner.transcribe(audio, language)
        stt_ms = int((time.time() - t_stt) * 1000)

        logger.info("STT-Ergebnis (%dms): '%s'", stt_ms, text[:100])

        await _send(ws, "stt_response", {
            "requestId": request_id,
            "text": text.strip(),
            "durationS": duration_s,
            "sttMs": stt_ms,
            "loadMs": load_ms,
            "model": runner.model_size,
        })
    except Exception as e:
        logger.exception("STT-Request fehlgeschlagen")
        await _send(ws, "stt_response", {
            "requestId": request_id,
            "error": str(e)[:200],
        })
|
|
|
|
|
|
async def _broadcast_status(ws, state: str, **extra) -> None:
    """Publish a ``service_status`` message for the whisper service.

    *state* is 'loading', 'ready' or 'error'. Any keyword arguments are
    merged into the payload after the fixed keys (e.g. model=, loadSeconds=,
    error=).
    """
    status = {"service": "whisper", "state": state, **extra}
    await _send(ws, "service_status", status)
|
|
|
|
|
|
# ──────────────────────────────────────────────────────────────
|
|
# WS-LOOP
|
|
# ──────────────────────────────────────────────────────────────
|
|
|
|
async def run_loop(runner: WhisperRunner, sessions: SessionManager) -> None:
    """Connect to the RVS relay, dispatch incoming messages, and reconnect forever.

    Backoff starts at 2 s, doubles up to 30 s, and resets after every
    successful connect. If RVS_TLS and RVS_TLS_FALLBACK are both set, a
    connection *attempt* that fails over wss:// is retried once over plain
    ws:// before backing off. A connection that was established and later
    dropped (or closed by the server) is retried over the configured scheme
    after the backoff delay; it is not treated as a TLS failure.
    """
    use_tls = RVS_TLS
    retry_s = 2
    tls_fallback_tried = False
    # The event loop keeps only weak references to tasks; hold strong ones so
    # fire-and-forget handlers cannot be garbage-collected mid-execution.
    background_tasks: set[asyncio.Task] = set()

    def _spawn(coro) -> None:
        task = asyncio.create_task(coro)
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)

    while True:
        scheme = "wss" if use_tls else "ws"
        url = f"{scheme}://{RVS_HOST}:{RVS_PORT}/ws?token={RVS_TOKEN}"
        masked = url.replace(RVS_TOKEN, "***") if RVS_TOKEN else url
        connected = False
        try:
            logger.info("Verbinde zu RVS: %s", masked)
            async with websockets.connect(url, ping_interval=20, ping_timeout=10,
                                          max_size=50 * 1024 * 1024) as ws:
                connected = True
                logger.info("RVS verbunden")
                retry_s = 2
                tls_fallback_tried = False
                sessions.attach_ws(ws)

                async def _initial_handshake():
                    try:
                        if runner.model is not None:
                            logger.info("Initial: broadcaste ready (Modell schon im RAM: %s)", runner.model_size)
                            await _broadcast_status(ws, "ready", model=runner.model_size)
                        else:
                            init_model = runner.model_size or WHISPER_MODEL
                            logger.info("Initial: broadcaste loading (model=%s)", init_model)
                            await _broadcast_status(ws, "loading", model=init_model)
                        logger.info("Initial: sende config_request an aria-bridge")
                        await _send(ws, "config_request", {"service": "whisper"})
                    except Exception as e:
                        logger.exception("Initial-Handshake crashed: %s", e)

                _spawn(_initial_handshake())

                async for raw in ws:
                    try:
                        msg = json.loads(raw)
                    except Exception:
                        continue
                    if not isinstance(msg, dict):
                        continue
                    mtype = msg.get("type", "")
                    payload = msg.get("payload", {}) or {}
                    if not isinstance(payload, dict):
                        logger.debug("Nachricht '%s' mit ungueltigem Payload ignoriert", mtype)
                        continue

                    # One malformed message must not tear down the whole connection.
                    try:
                        if mtype == "stt_request":
                            req_id = payload.get("requestId", "?")
                            audio_len = len(payload.get("audio", ""))
                            # base64 inflates by 4/3, so //1365 approximates decoded KiB.
                            logger.info("stt_request empfangen (id=%s, %dKB Audio)",
                                        req_id[:8] if req_id != "?" else "?", audio_len // 1365)
                            _spawn(handle_stt_request(ws, payload, runner))

                        elif mtype == "stt_stream_start":
                            target_model = payload.get("model") or runner.model_size or WHISPER_MODEL
                            # Open the session immediately so chunks arriving while a
                            # model loads are buffered instead of dropped. While no
                            # model is loaded, partial transcriptions return "".
                            sessions.start_session(payload)
                            if runner.model is None or target_model != runner.model_size:
                                async def _load_for_stream(target):
                                    await _broadcast_status(ws, "loading", model=target)
                                    try:
                                        await runner.ensure_loaded(target)
                                        await _broadcast_status(ws, "ready", model=runner.model_size)
                                    except Exception as e:
                                        await _broadcast_status(ws, "error", error=str(e)[:200])
                                _spawn(_load_for_stream(target_model))

                        elif mtype == "stt_audio_chunk":
                            if not sessions.feed_chunk(payload):
                                # Potentially very chatty — debug level is enough.
                                logger.debug("stt_audio_chunk: unbekannte/closed session %s",
                                             str(payload.get("requestId", ""))[:8])

                        elif mtype == "stt_stream_end":
                            req_id = payload.get("requestId", "")
                            logger.info("stt_stream_end empfangen: id=%s reason=%s",
                                        str(req_id)[:8], payload.get("reason", ""))
                            sessions.end_session(req_id)

                        elif mtype == "config":
                            new_model = payload.get("whisperModel") or WHISPER_MODEL
                            if runner.model is None or new_model != runner.model_size:
                                logger.info("Config-Broadcast: Whisper-Modell -> %s%s",
                                            new_model,
                                            " (initial)" if runner.model is None else " (Wechsel)")

                                async def _swap_with_status(target):
                                    await _broadcast_status(ws, "loading", model=target)
                                    try:
                                        t0 = time.time()
                                        await runner.ensure_loaded(target)
                                        elapsed = time.time() - t0
                                        await _broadcast_status(ws, "ready",
                                                                model=runner.model_size,
                                                                loadSeconds=elapsed)
                                    except Exception as e:
                                        await _broadcast_status(ws, "error", error=str(e)[:200])
                                _spawn(_swap_with_status(new_model))

                        else:
                            logger.debug("Unbeachteter Type: %s", mtype)
                    except Exception:
                        logger.exception("Verarbeitung von '%s' fehlgeschlagen", mtype)
            # Clean close by the server: fall through to the backoff below
            # instead of reconnecting in a tight loop.
            logger.warning("Verbindung vom Server geschlossen")
        except Exception as e:
            logger.warning("Verbindung verloren: %s", e)
        finally:
            sessions.detach_ws()

        # Only a failed *attempt* counts as a TLS problem; never downgrade to
        # plaintext (exposing the token) just because an established TLS
        # connection dropped.
        if not connected and use_tls and RVS_TLS_FALLBACK and not tls_fallback_tried:
            logger.info("TLS-Verbindung fehlgeschlagen — Fallback auf ws://")
            use_tls = False
            tls_fallback_tried = True
            continue
        await asyncio.sleep(min(retry_s, 30))
        retry_s = min(retry_s * 2, 30)
        use_tls = RVS_TLS
        tls_fallback_tried = False
|
|
|
|
|
|
async def main() -> None:
    """Validate the config, start the endpointer loop, and run the RVS loop forever."""
    if not RVS_HOST:
        logger.error("RVS_HOST ist nicht gesetzt — Abbruch")
        sys.exit(1)
    runner = WhisperRunner()
    sessions = SessionManager(runner)
    # Run the endpointer alongside the WS loop; its ticks are no-ops while no
    # WebSocket is attached. Keep a strong reference: asyncio only holds weak
    # references to tasks, so an unreferenced task may be garbage-collected.
    endpointer = asyncio.create_task(sessions.run_endpointer())

    def _on_endpointer_done(task: asyncio.Task) -> None:
        # run_endpointer never returns on its own; surface an unexpected crash.
        if not task.cancelled() and task.exception() is not None:
            logger.error("Endpointer-Loop beendet: %r", task.exception())

    endpointer.add_done_callback(_on_endpointer_done)
    try:
        await run_loop(runner, sessions)
    finally:
        endpointer.cancel()
|
|
|
|
|
|
if __name__ == "__main__":
    # Ctrl+C is the normal way to stop the bridge: exit with status 0
    # instead of printing a KeyboardInterrupt traceback.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        raise SystemExit(0)
|