#!/usr/bin/env python3
"""
ARIA Whisper Bridge — laeuft auf der Gamebox (RTX 3060).

Empfaengt stt_request via RVS → FFmpeg-Konvertierung → faster-whisper auf GPU
→ sendet stt_response zurueck an die aria-bridge.

Env:
    RVS_HOST              Pflicht (ohne Wert bricht das Skript ab)
    RVS_PORT              Default: 443
    RVS_TLS               Default: true
    RVS_TLS_FALLBACK      Default: true (bei TLS-Verbindungsfehler Fallback auf ws://)
    RVS_TOKEN             Default: leer
    WHISPER_MODEL         Default: small
    WHISPER_DEVICE        Default: cuda
    WHISPER_COMPUTE_TYPE  Default: float16
    WHISPER_LANGUAGE      Default: de
"""

import asyncio
import base64
import json
import logging
import os
import subprocess
import sys
import tempfile
import time
from typing import Optional
from urllib.parse import quote

import numpy as np
import websockets
from faster_whisper import WhisperModel


# --- Logging -----------------------------------------------------------------
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger("whisper-bridge")


def _env_flag(name: str, default: str) -> bool:
    """Return True when env var *name* (or *default* if unset) is 'true', case-insensitively."""
    return os.getenv(name, default).lower() == "true"


# --- RVS relay connection ----------------------------------------------------
RVS_HOST = os.getenv("RVS_HOST", "").strip()
RVS_PORT = int(os.getenv("RVS_PORT", "443"))
RVS_TLS = _env_flag("RVS_TLS", "true")
# Whether a failed wss:// attempt may be retried over plain ws:// (see run_loop).
RVS_TLS_FALLBACK = _env_flag("RVS_TLS_FALLBACK", "true")
RVS_TOKEN = os.getenv("RVS_TOKEN", "").strip()

# --- Whisper defaults (a config broadcast may switch the model at runtime) ---
WHISPER_MODEL = os.getenv("WHISPER_MODEL", "small")
WHISPER_DEVICE = os.getenv("WHISPER_DEVICE", "cuda")
WHISPER_COMPUTE_TYPE = os.getenv("WHISPER_COMPUTE_TYPE", "float16")
WHISPER_LANGUAGE = os.getenv("WHISPER_LANGUAGE", "de")

# Model sizes accepted from remote config/requests; anything else is replaced
# by WHISPER_MODEL.
ALLOWED_MODELS = {"tiny", "base", "small", "medium", "large-v3"}


class WhisperRunner:
    """Own the faster-whisper model and hot-swap it when the requested size changes.

    ``ensure_loaded()`` (re)loads the model under an asyncio lock, so concurrent
    swap requests are serialized. ``transcribe()`` runs inference in the default
    thread-pool executor, so the event loop stays responsive.
    """

    def __init__(self) -> None:
        # Size of the loaded model (the env default until the first load).
        self.model_size: str = WHISPER_MODEL
        # None until the first successful load.
        self.model: Optional[WhisperModel] = None
        # Serializes loads so two swaps never run at the same time.
        self._lock = asyncio.Lock()

    def _load_blocking(self, size: str) -> None:
        """Load model *size* synchronously (called in a worker thread).

        The new model is constructed first and only then published together
        with its size, so if construction raises, the previous model stays.
        """
        logger.info(
            "Lade Whisper '%s' (device=%s, compute=%s)",
            size, WHISPER_DEVICE, WHISPER_COMPUTE_TYPE,
        )
        t0 = time.monotonic()
        model = WhisperModel(
            size, device=WHISPER_DEVICE, compute_type=WHISPER_COMPUTE_TYPE,
        )
        self.model, self.model_size = model, size
        logger.info("Whisper '%s' geladen in %.1fs", size, time.monotonic() - t0)

    async def ensure_loaded(self, desired_size: str) -> None:
        """Ensure model *desired_size* is loaded; a no-op if it already is.

        Sizes not in ALLOWED_MODELS are replaced by WHISPER_MODEL (with a
        warning). Exceptions raised while loading propagate to the caller.
        """
        if desired_size not in ALLOWED_MODELS:
            logger.warning("Ungueltiges Whisper-Modell '%s' — nutze %s", desired_size, WHISPER_MODEL)
            desired_size = WHISPER_MODEL
        async with self._lock:
            if self.model is not None and self.model_size == desired_size:
                return
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._load_blocking, desired_size)

    async def transcribe(self, audio: np.ndarray, language: str) -> tuple[str, float]:
        """Transcribe *audio* and return ``(text, duration_seconds)``.

        *audio* is the mono 16 kHz float32 PCM that ffmpeg_to_float32 produces.
        Returns ``("", 0.0)`` when no model is loaded. Segment texts are
        stripped, empty ones dropped, and the rest joined with single spaces.
        """
        # Snapshot the model so a concurrent hot-swap cannot change it between
        # this check and the executor call.
        model = self.model
        if model is None:
            return "", 0.0

        def _run() -> tuple[str, float]:
            segments, info = model.transcribe(
                audio, language=language, beam_size=5, vad_filter=True,
            )
            # faster-whisper yields segments lazily (per its docs), so the
            # decoding work happens while this generator is consumed.
            pieces = (seg.text.strip() for seg in segments)
            text = " ".join(piece for piece in pieces if piece)
            return text, info.duration

        return await asyncio.get_running_loop().run_in_executor(None, _run)


def ffmpeg_to_float32(audio_b64: str, mime_type: str) -> np.ndarray:
    """Decode base64 audio in any ffmpeg-readable format to 16 kHz mono float32 PCM.

    Args:
        audio_b64: Base64-encoded audio bytes.
        mime_type: MIME type from the request. It is only used to choose the
            temp-file suffix as a hint for ffmpeg; None/empty is accepted.

    Returns:
        1-D float32 array of samples. The array is empty if ffmpeg exits
        non-zero or exceeds the 30 s timeout.

    Raises:
        binascii.Error: if *audio_b64* is not valid base64.
        FileNotFoundError: if the ``ffmpeg`` binary is not on PATH.
    """
    mime = (mime_type or "").lower()
    if "mp4" in mime or "m4a" in mime or "aac" in mime:
        ext = ".mp4"
    elif "wav" in mime:
        ext = ".wav"
    elif "ogg" in mime or "opus" in mime:
        ext = ".ogg"
    elif "webm" in mime:
        ext = ".webm"
    else:
        ext = ".bin"

    # Decode before touching the filesystem so bad base64 leaves nothing behind.
    raw = base64.b64decode(audio_b64)

    in_fh = tempfile.NamedTemporaryFile(suffix=ext, delete=False)
    in_path = in_fh.name
    out_path = in_path + ".raw"
    try:
        with in_fh:
            in_fh.write(raw)
        cmd = [
            "ffmpeg", "-hide_banner", "-loglevel", "error", "-y",
            "-i", in_path,
            "-ar", "16000", "-ac", "1", "-f", "f32le",
            out_path,
        ]
        try:
            # stdin=DEVNULL: ffmpeg must never read from the daemon's stdin.
            result = subprocess.run(
                cmd, stdin=subprocess.DEVNULL, capture_output=True, timeout=30,
            )
        except subprocess.TimeoutExpired:
            logger.error("FFmpeg Timeout nach 30s (Eingabe %s)", ext)
            return np.zeros(0, dtype=np.float32)
        if result.returncode != 0:
            # With -loglevel error, stderr holds only the error text; keep the
            # tail, where ffmpeg reports the final failure reason.
            logger.error("FFmpeg Fehler: %s", result.stderr.decode(errors="replace").strip()[-300:])
            return np.zeros(0, dtype=np.float32)
        return np.fromfile(out_path, dtype=np.float32)
    finally:
        # ffmpeg can leave a partial output file behind on failure or timeout,
        # so both temp files are removed on every path.
        for path in (in_path, out_path):
            try:
                os.unlink(path)
            except OSError:
                pass


async def _send(ws, mtype: str, payload: dict) -> None:
    """Serialize one message envelope and send it on *ws*; never raises.

    Envelope shape: ``{"type": mtype, "payload": payload, "timestamp": <epoch ms>}``.
    Serialization and send errors are logged as warnings and swallowed, so
    callers can emit replies and status updates without guarding against a
    dropped connection.
    """
    envelope = {
        "type": mtype,
        "payload": payload,
        "timestamp": int(time.time() * 1000),
    }
    try:
        await ws.send(json.dumps(envelope))
    except Exception as err:
        logger.warning("Send fehlgeschlagen (%s): %s", mtype, err)


async def handle_stt_request(ws, payload: dict, runner: WhisperRunner) -> None:
    """Serve one ``stt_request``: load the model if needed, decode, transcribe, reply.

    Replies with one ``stt_response`` carrying the request's ``requestId``
    (unless the task is cancelled). On success it holds ``text``, ``durationS``,
    ``sttMs``, ``loadMs`` and ``model``; otherwise it holds ``error``.
    When the model must be loaded first, ``service_status`` loading → ready
    (or loading → error) is broadcast so the app can show its loading banner.
    """
    request_id = payload.get("requestId", "")
    audio_b64 = payload.get("audio", "")
    mime_type = payload.get("mimeType") or "audio/mp4"
    model = payload.get("model") or WHISPER_MODEL
    language = payload.get("language") or WHISPER_LANGUAGE

    if not audio_b64:
        await _send(ws, "stt_response", {"requestId": request_id, "error": "no-audio"})
        return

    # Normalize up front (ensure_loaded would substitute the same default).
    # Otherwise an unknown name never equals runner.model_size, and every
    # request would re-broadcast loading/ready.
    if not isinstance(model, str) or model not in ALLOWED_MODELS:
        logger.warning("Ungueltiges Whisper-Modell '%s' — nutze %s", model, WHISPER_MODEL)
        model = WHISPER_MODEL

    loop = asyncio.get_running_loop()
    try:
        t_load = time.monotonic()
        # The model may not be loaded yet (race: stt_request arrived before the
        # config broadcast), so broadcast loading → ready to raise the app banner.
        needs_load = runner.model is None or runner.model_size != model
        if needs_load:
            await _broadcast_status(ws, "loading", model=model)
        try:
            await runner.ensure_loaded(model)
        except Exception as e:
            if needs_load:
                # Clear the banner; the outer handler still sends the stt_response error.
                await _broadcast_status(ws, "error", error=str(e)[:200])
            raise
        load_ms = int((time.monotonic() - t_load) * 1000)
        if needs_load:
            await _broadcast_status(ws, "ready",
                                    model=runner.model_size,
                                    loadSeconds=load_ms / 1000.0)

        # ffmpeg is a blocking subprocess (up to 30 s). Run it off the event
        # loop so websocket pings and other requests keep flowing.
        audio = await loop.run_in_executor(None, ffmpeg_to_float32, audio_b64, mime_type)
        if audio.size == 0:
            await _send(ws, "stt_response", {"requestId": request_id, "error": "ffmpeg-failed"})
            return
        duration_s = len(audio) / 16000.0  # ffmpeg_to_float32 resamples to 16 kHz
        logger.info("STT-Request: %.1fs Audio, model=%s, lang=%s", duration_s, runner.model_size, language)

        t_stt = time.monotonic()
        text, _ = await runner.transcribe(audio, language)
        stt_ms = int((time.monotonic() - t_stt) * 1000)
        text = text.strip()

        logger.info("STT-Ergebnis (%dms): '%s'", stt_ms, text[:100])

        await _send(ws, "stt_response", {
            "requestId": request_id,
            "text": text,
            "durationS": duration_s,
            "sttMs": stt_ms,
            "loadMs": load_ms,
            "model": runner.model_size,
        })
    except Exception as e:
        logger.exception("STT-Request fehlgeschlagen")
        await _send(ws, "stt_response", {
            "requestId": request_id,
            "error": str(e)[:200],
        })


async def _broadcast_status(ws, state: str, **extra) -> None:
    """Send a ``service_status`` message for the whisper service.

    *state* is one of 'loading' | 'ready' | 'error'. Keyword arguments
    (e.g. ``model=``, ``loadSeconds=``, ``error=``) are merged into the
    payload after the defaults, so they may override them.
    """
    await _send(ws, "service_status", {"service": "whisper", "state": state, **extra})


# Strong references to fire-and-forget tasks. The event loop keeps only weak
# references, so an unreferenced task may be garbage-collected mid-execution.
_background_tasks: set[asyncio.Task] = set()


def _spawn(coro) -> asyncio.Task:
    """Schedule *coro* as a task and keep it referenced until it finishes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


def _rvs_urls(use_tls: bool) -> tuple[str, str]:
    """Return ``(url, masked_url)`` for the RVS websocket endpoint.

    The token is percent-encoded so characters such as '+', '&' or '#' survive
    the query string. The masked variant is safe to log.
    """
    scheme = "wss" if use_tls else "ws"
    base = f"{scheme}://{RVS_HOST}:{RVS_PORT}/ws?token="
    url = base + quote(RVS_TOKEN, safe="")
    masked = (base + "***") if RVS_TOKEN else url
    return url, masked


async def _swap_model_with_status(ws, runner: WhisperRunner, target: str) -> None:
    """Load *target*, bracketed by loading → ready/error status broadcasts."""
    await _broadcast_status(ws, "loading", model=target)
    t0 = time.monotonic()
    try:
        await runner.ensure_loaded(target)
    except Exception as e:
        logger.exception("Whisper-Modell '%s' konnte nicht geladen werden", target)
        await _broadcast_status(ws, "error", error=str(e)[:200])
        return
    await _broadcast_status(ws, "ready",
                            model=runner.model_size,
                            loadSeconds=time.monotonic() - t0)


def _dispatch_message(ws, raw, runner: WhisperRunner) -> None:
    """Route one incoming RVS frame. Malformed frames are ignored, never fatal."""
    try:
        msg = json.loads(raw)
    except (TypeError, ValueError):
        return
    if not isinstance(msg, dict):
        return
    mtype = msg.get("type", "")
    payload = msg.get("payload") or {}
    if not isinstance(payload, dict):
        logger.debug("Payload ist kein Objekt (type=%s) — ignoriert", mtype)
        return

    if mtype == "stt_request":
        req_id = str(payload.get("requestId", "?"))
        audio = payload.get("audio", "")
        audio_len = len(audio) if isinstance(audio, str) else 0
        # base64 turns 3 bytes into 4 chars: 1 KiB of audio ≈ 1365 chars.
        logger.info("stt_request empfangen (id=%s, %dKB Audio)", req_id[:8], audio_len // 1365)
        _spawn(handle_stt_request(ws, payload, runner))
    elif mtype == "config":
        new_model = payload.get("whisperModel") or WHISPER_MODEL
        # Normalize so an invalid name cannot trigger a swap on every broadcast.
        if not isinstance(new_model, str) or new_model not in ALLOWED_MODELS:
            logger.warning("Ungueltiges Whisper-Modell '%s' — nutze %s", new_model, WHISPER_MODEL)
            new_model = WHISPER_MODEL
        # Load when (a) nothing is loaded yet or (b) the model changes.
        if runner.model is None or new_model != runner.model_size:
            logger.info("Config-Broadcast: Whisper-Modell -> %s%s",
                        new_model,
                        " (initial)" if runner.model is None else " (Wechsel)")
            _spawn(_swap_model_with_status(ws, runner, new_model))
    else:
        # Debug-log everything else; this helps diagnose whether stt_request
        # gets through the RVS at all.
        logger.debug("Unbeachteter Type: %s", mtype)


async def run_loop(runner: WhisperRunner) -> None:
    """Keep a websocket connection to the RVS alive forever and dispatch its messages.

    Reconnect timing: after any disconnect or failed attempt, sleep with
    exponential backoff (2 s doubling, capped at 30 s). The backoff resets to
    2 s after a successful connect.

    TLS fallback (when RVS_TLS and RVS_TLS_FALLBACK are set): a failed wss://
    *connect* is retried immediately over ws://. If that also fails, the next
    round starts with wss:// again. A wss:// connection that was established
    and later drops is never downgraded.
    """
    use_tls = RVS_TLS
    retry_s = 2
    tls_fallback_tried = False

    while True:
        url, masked = _rvs_urls(use_tls)
        connected = False
        try:
            logger.info("Verbinde zu RVS: %s", masked)
            # max_size 50 MB: large stt_requests (voice-cloning WAVs as base64
            # can be several MB) must not exceed the frame limit, which would
            # kill the connection with 1009 'message too big'.
            async with websockets.connect(url, ping_interval=20, ping_timeout=10, max_size=50 * 1024 * 1024) as ws:
                connected = True
                logger.info("RVS verbunden")
                retry_s = 2
                tls_fallback_tried = False

                # No initial preload. Shortly after the RVS connect, the
                # aria-bridge broadcasts the persisted config (whisperModel).
                # We load only once it arrives; otherwise the model would load
                # twice (the env default, then the real one). If an
                # stt_request beats the config, ensure_loaded in the handler
                # loads the requested model.
                if runner.model is not None:
                    # Reconnected with the model already in memory: just report 'ready'.
                    await _broadcast_status(ws, "ready", model=runner.model_size)

                async for raw in ws:
                    _dispatch_message(ws, raw, runner)
            # Clean close by the server: back off instead of reconnecting in a tight loop.
            logger.warning("Verbindung verloren: vom RVS geschlossen")
        except Exception as e:
            logger.warning("Verbindung verloren: %s", e)
            if not connected and use_tls and RVS_TLS_FALLBACK and not tls_fallback_tried:
                logger.info("TLS-Verbindung fehlgeschlagen — Fallback auf ws://")
                use_tls = False
                tls_fallback_tried = True
                continue
            if not connected and tls_fallback_tried:
                # Plain ws:// failed as well. Do not stay stuck on the
                # downgrade; start the next round with the configured scheme.
                use_tls = RVS_TLS
                tls_fallback_tried = False

        await asyncio.sleep(retry_s)
        retry_s = min(retry_s * 2, 30)


async def main() -> None:
    """Entry point: check the configuration, then run the RVS loop forever.

    Exits the process with status 1 when RVS_HOST is not configured.
    """
    if RVS_HOST:
        await run_loop(WhisperRunner())
        return
    logger.error("RVS_HOST ist nicht gesetzt — Abbruch")
    sys.exit(1)


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop the bridge: exit with status 0
        # instead of printing a traceback.
        raise SystemExit(0)