#!/usr/bin/env python3 """ ARIA FLUX-Bridge — laeuft auf der Gamebox (RTX 3060). Empfaengt flux_request via RVS → FLUX.1-dev/-schnell auf GPU → sendet flux_response mit base64-PNG zurueck an die aria-bridge. Diese speichert die Datei nach /shared/uploads/ und ARIA referenziert sie mit [FILE: ...]-Marker in ihrer Antwort. 12 GB VRAM auf der 3060 reichen fuer FLUX.1-dev nur mit `enable_model_cpu_offload()` — sonst OOM. Setze FLUX_OFFLOAD=sequential fuer Maximal-Sparsamkeit (langsamer) oder FLUX_OFFLOAD=none wenn die GPU genug VRAM hat (z.B. spaeter 4090). Env: RVS_HOST, RVS_PORT, RVS_TLS, RVS_TLS_FALLBACK, RVS_TOKEN FLUX_MODEL Default: black-forest-labs/FLUX.1-dev Alt: black-forest-labs/FLUX.1-schnell (4-Step, Apache-2.0) FLUX_DEVICE Default: cuda FLUX_DTYPE Default: bfloat16 (alt: float16) FLUX_OFFLOAD Default: model (alt: sequential | none) FLUX_MAX_STEPS Default: 50 FLUX_MAX_DIM Default: 1536 """ import asyncio import base64 import io import json import logging import os import sys import time import uuid from typing import Optional import websockets logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", datefmt="%H:%M:%S", ) logger = logging.getLogger("flux-bridge") # HuggingFace/Torch download-Logs daempfen logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("urllib3").setLevel(logging.WARNING) RVS_HOST = os.getenv("RVS_HOST", "").strip() RVS_PORT = int(os.getenv("RVS_PORT", "443")) RVS_TLS = os.getenv("RVS_TLS", "true").lower() == "true" RVS_TLS_FALLBACK = os.getenv("RVS_TLS_FALLBACK", "true").lower() == "true" RVS_TOKEN = os.getenv("RVS_TOKEN", "").strip() # Bootstrap-Fallback: nur relevant wenn beim allerersten Start KEIN # Diagnostic-config-Broadcast eintrifft UND der erste Render-Request # auch kein 'model' enthaelt. Default 'schnell', weil Apache-2.0 # (kein HF-Token noetig) — Stefan stellt sein gewuenschtes Default ueber # Diagnostic ein. ENV ist also nur fuer den extremen Edge-Case da, in # der .env.example absichtlich nicht mehr dokumentiert. FLUX_MODEL = os.getenv("FLUX_MODEL", "black-forest-labs/FLUX.1-schnell").strip() FLUX_DEVICE = os.getenv("FLUX_DEVICE", "cuda").strip() FLUX_DTYPE = os.getenv("FLUX_DTYPE", "bfloat16").strip().lower() FLUX_OFFLOAD = os.getenv("FLUX_OFFLOAD", "model").strip().lower() FLUX_MAX_STEPS = int(os.getenv("FLUX_MAX_STEPS", "50")) FLUX_MAX_DIM = int(os.getenv("FLUX_MAX_DIM", "1536")) # FLUX-dev native: guidance=3.5, steps=28. FLUX-schnell: guidance=0.0, steps=4. DEFAULT_STEPS_DEV = 28 DEFAULT_STEPS_SCHNELL = 4 DEFAULT_GUIDANCE_DEV = 3.5 DEFAULT_GUIDANCE_SCHNELL = 0.0 # Mapping fuer das User-facing Tag → HF-Modell-ID. Stefan stellt in Diagnostic # nur 'dev' / 'schnell' ein; FLUX_MODEL aus der env kann zwar eine custom-ID # sein (Bootstrap), wird aber beim ersten config-Broadcast normalerweise # durch die Diagnostic-Wahl uebersteuert. MODEL_TAGS: dict[str, str] = { "dev": "black-forest-labs/FLUX.1-dev", "schnell": "black-forest-labs/FLUX.1-schnell", } def _tag_to_model_id(tag: str) -> str: """Mappt 'dev'/'schnell' auf HF-ID. Andere Strings werden 1:1 durchgereicht (custom-IDs aus FLUX_MODEL env). Leere/ungueltige Werte → FLUX_MODEL Default.""" if not tag: return FLUX_MODEL t = tag.strip() return MODEL_TAGS.get(t, t) def _is_schnell(model_id: str) -> bool: return "schnell" in model_id.lower() def _is_model_cached(model_id: str) -> bool: """Prueft ob ein HF-Modell-Snapshot lokal im hf-cache vorhanden ist. HF speichert unter ~/.cache/huggingface/hub/models--{org}--{name}/snapshots/{rev}/. Wenn das snapshots-Verzeichnis nicht existiert oder leer ist → Erst-Download steht an (24+ GB fuer FLUX.1-dev, 24+ GB fuer FLUX.1-schnell — Stefan kriegt dann nen Hinweis im Banner). """ if not model_id: return False cache_root = os.environ.get("HF_HOME") or os.path.expanduser("~/.cache/huggingface") safe = "models--" + model_id.replace("/", "--") snapshots = os.path.join(cache_root, "hub", safe, "snapshots") if not os.path.isdir(snapshots): return False try: for rev in os.listdir(snapshots): rev_dir = os.path.join(snapshots, rev) if os.path.isdir(rev_dir) and any(os.scandir(rev_dir)): return True except OSError: return False return False def _torch_dtype(): """Lazy-resolve damit Torch erst beim Modell-Laden importiert wird.""" import torch return {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}\ .get(FLUX_DTYPE, torch.bfloat16) def _snap_dim(v: int, default: int = 1024) -> int: """FLUX braucht Multiples von 16 (sicher: 64). Clamp + Snap.""" try: n = int(v) except (TypeError, ValueError): n = default n = max(256, min(FLUX_MAX_DIM, n)) # Auf naechstes Vielfaches von 64 abrunden n = (n // 64) * 64 return max(256, n) class FluxRunner: """Haelt EINE FLUX-Pipeline. Bei Modell-Wechsel wird die alte verworfen und die neue geladen (~15-30 s aus HF-Cache, keine Re-Downloads). Pro Request kann ein 'dev'/'schnell'-Tag mitkommen; ohne Angabe wird `default_model_id` genommen (steht Bootstrap auf FLUX_MODEL, wird beim ersten config-Broadcast von der aria-bridge auf die Diagnostic-Wahl aktualisiert). """ def __init__(self) -> None: self.pipe = None self._lock = asyncio.Lock() # Aktuell geladenes Modell — leer solange noch nix geladen wurde. self.model_id: str = "" # Was bei einem Request OHNE explizite model-Angabe benutzt wird. # Wird durch Diagnostic-config gesetzt; FLUX_MODEL bleibt nur als # Edge-Case-Fallback wenn weder Config noch Request einen Wert nennen. self.default_model_id: str = FLUX_MODEL self.last_load_seconds: float = 0.0 # True wenn der letzte _load_blocking einen Fresh-Download triggern # musste (Modell war nicht im HF-Cache). Wird vom Caller geprueft # und in den 'ready'-service_status als freshlyDownloaded gesetzt. self.last_load_was_download: bool = False def _load_blocking(self, model_id: str) -> None: import torch from diffusers import FluxPipeline # Alte Pipeline freigeben damit der HF-Loader VRAM/RAM kriegt if self.pipe is not None: logger.info("Verwerfe alte Pipeline '%s'", self.model_id) try: del self.pipe except Exception: pass self.pipe = None try: torch.cuda.empty_cache() except Exception: pass import gc gc.collect() was_cached = _is_model_cached(model_id) self.last_load_was_download = not was_cached if not was_cached: logger.warning("FLUX '%s' nicht im HF-Cache — Erst-Download steht bevor (kann 5-10 min dauern).", model_id) logger.info("Lade FLUX '%s' (dtype=%s, offload=%s, cached=%s)...", model_id, FLUX_DTYPE, FLUX_OFFLOAD, was_cached) t0 = time.time() pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=_torch_dtype()) if FLUX_OFFLOAD == "sequential": pipe.enable_sequential_cpu_offload() elif FLUX_OFFLOAD == "none": pipe.to(FLUX_DEVICE) else: # "model" — default, Sweet-Spot fuer 12 GB Karten pipe.enable_model_cpu_offload() # VAE-Tiling spart VRAM bei grossen Bildern (>1024) try: pipe.vae.enable_tiling() except Exception: pass self.pipe = pipe self.model_id = model_id self.last_load_seconds = time.time() - t0 logger.info("FLUX '%s' geladen in %.1fs", model_id, self.last_load_seconds) try: torch.cuda.empty_cache() except Exception: pass async def ensure_loaded(self, model_id: Optional[str] = None) -> bool: """Stellt sicher dass die richtige Pipeline geladen ist. Wenn ein anderes Modell gewuenscht ist als gerade aktiv, wird geswappt. Returns True wenn ein Swap/Load stattgefunden hat.""" target = model_id or self.default_model_id or FLUX_MODEL async with self._lock: if self.pipe is not None and self.model_id == target: return False loop = asyncio.get_event_loop() await loop.run_in_executor(None, self._load_blocking, target) return True def _generate_blocking(self, prompt: str, width: int, height: int, steps: int, guidance: float, seed: Optional[int]) -> bytes: import torch gen = None if seed is not None and seed >= 0: gen = torch.Generator(device=FLUX_DEVICE).manual_seed(int(seed)) logger.info("Render (%s): %dx%d, steps=%d, guidance=%.2f, seed=%s, prompt=%r", self.model_id, width, height, steps, guidance, seed, prompt[:80]) out = self.pipe( prompt=prompt, width=width, height=height, num_inference_steps=steps, guidance_scale=guidance, generator=gen, ) image = out.images[0] buf = io.BytesIO() image.save(buf, format="PNG", optimize=True) png_bytes = buf.getvalue() # VRAM zurueckgeben fuer den naechsten Render try: torch.cuda.empty_cache() except Exception: pass return png_bytes async def generate(self, prompt: str, width: int, height: int, steps: int, guidance: float, seed: Optional[int], model_id: Optional[str] = None) -> bytes: await self.ensure_loaded(model_id) loop = asyncio.get_event_loop() return await loop.run_in_executor( None, self._generate_blocking, prompt, width, height, steps, guidance, seed, ) # ── Helpers ───────────────────────────────────────────────── async def _send(ws, mtype: str, payload: dict) -> None: try: await ws.send(json.dumps({ "type": mtype, "payload": payload, "timestamp": int(time.time() * 1000), })) except Exception as e: logger.warning("Send fehlgeschlagen (%s): %s", mtype, e) async def _broadcast_status(ws, state: str, **extra) -> None: """Sendet service_status fuer das Flux-Modul. state: 'loading' | 'ready' | 'error'.""" payload = {"service": "flux", "state": state} payload.update(extra) await _send(ws, "service_status", payload) # ── Flux-Request Queue ────────────────────────────────────── # Eine GPU, ein Render gleichzeitig. Parallele Requests OOM-en sonst. _flux_queue: "asyncio.Queue[tuple]" = asyncio.Queue() def _resolve_request(payload: dict, runner: FluxRunner) -> tuple[str, int, int, int, float, Optional[int], str]: """Liest Felder aus dem flux_request payload + clampt auf Caps. Returns (prompt, width, height, steps, guidance, seed, resolved_model_id). """ prompt = (payload.get("prompt") or "").strip() if not prompt: raise ValueError("prompt fehlt") if len(prompt) > 2000: prompt = prompt[:2000] width = _snap_dim(payload.get("width", 1024)) height = _snap_dim(payload.get("height", 1024)) # Modell-Wahl: explizit per Request > runner.default_model_id > FLUX_MODEL. req_model = (payload.get("model") or "").strip() resolved_model_id = _tag_to_model_id(req_model) if req_model else (runner.default_model_id or FLUX_MODEL) schnell = _is_schnell(resolved_model_id) default_steps = DEFAULT_STEPS_SCHNELL if schnell else DEFAULT_STEPS_DEV default_guidance = DEFAULT_GUIDANCE_SCHNELL if schnell else DEFAULT_GUIDANCE_DEV try: steps = int(payload.get("steps", default_steps)) except (TypeError, ValueError): steps = default_steps steps = max(1, min(FLUX_MAX_STEPS, steps)) try: guidance = float(payload.get("guidance_scale", default_guidance)) except (TypeError, ValueError): guidance = default_guidance if not (0.0 <= guidance <= 20.0): guidance = default_guidance seed = payload.get("seed") if seed is not None: try: seed = int(seed) except (TypeError, ValueError): seed = None return prompt, width, height, steps, guidance, seed, resolved_model_id async def _flux_worker(ws, runner: FluxRunner) -> None: """Serialisiert Renders — eine GPU, ein Bild gleichzeitig.""" while True: payload = await _flux_queue.get() request_id = payload.get("requestId") or str(uuid.uuid4()) try: await _do_render(ws, runner, payload, request_id) except Exception: logger.exception("Flux-Worker Fehler") await _send(ws, "flux_response", { "requestId": request_id, "error": "internal error", }) finally: _flux_queue.task_done() async def _do_render(ws, runner: FluxRunner, payload: dict, request_id: str) -> None: t0 = time.time() try: prompt, width, height, steps, guidance, seed, target_model_id = _resolve_request(payload, runner) except ValueError as e: logger.warning("flux_request invalid: %s", e) await _send(ws, "flux_response", {"requestId": request_id, "error": str(e)}) return # Modell-Swap noetig? Status broadcasten damit Diagnostic-Banner es zeigt. swap_needed = (runner.pipe is None or runner.model_id != target_model_id) will_download = swap_needed and not _is_model_cached(target_model_id) if swap_needed: await _broadcast_status(ws, "loading", model=target_model_id, downloading=will_download) await _send(ws, "flux_response", { "requestId": request_id, "state": "switching_model", "model": target_model_id, "downloading": will_download, }) # Progress-Ping: User soll sehen dass was passiert (Render >30s realistisch) await _send(ws, "flux_response", { "requestId": request_id, "state": "rendering", "width": width, "height": height, "steps": steps, "model": target_model_id, }) try: png = await runner.generate(prompt, width, height, steps, guidance, seed, model_id=target_model_id) except Exception as e: logger.exception("FLUX Render-Fehler") await _send(ws, "flux_response", {"requestId": request_id, "error": str(e)[:200]}) if swap_needed: await _broadcast_status(ws, "error", error=str(e)[:200]) return if swap_needed: await _broadcast_status(ws, "ready", model=runner.model_id, loadSeconds=runner.last_load_seconds, freshlyDownloaded=runner.last_load_was_download) dt = time.time() - t0 b64 = base64.b64encode(png).decode("ascii") logger.info("Render fertig: %dx%d, %d KB PNG, %.1fs (%s)", width, height, len(png) // 1024, dt, runner.model_id) await _send(ws, "flux_response", { "requestId": request_id, "state": "done", "base64": b64, "mimeType": "image/png", "width": width, "height": height, "steps": steps, "guidance": guidance, "seed": seed, "model": runner.model_id, "renderSeconds": round(dt, 2), "sizeBytes": len(png), }) # ── Haupt-Loop ────────────────────────────────────────────── async def run_loop(runner: FluxRunner) -> None: use_tls = RVS_TLS retry_s = 2 tls_fallback_tried = False while True: scheme = "wss" if use_tls else "ws" url = f"{scheme}://{RVS_HOST}:{RVS_PORT}/ws?token={RVS_TOKEN}" masked = url.replace(RVS_TOKEN, "***") if RVS_TOKEN else url try: logger.info("Verbinde zu RVS: %s", masked) # max_size 100 MB damit ein 4 MP PNG (~5-10 MB → ~13 MB base64) # locker reinpasst. Mit dem RVS-Limit (100 MB) konsistent. async with websockets.connect(url, ping_interval=20, ping_timeout=10, max_size=100 * 1024 * 1024) as ws: logger.info("RVS verbunden") retry_s = 2 tls_fallback_tried = False async def _load_with_status(): """Bei Connect KEIN Eager-Load — wir fragen erst die Diagnostic-Config ab. Welches Modell tatsaechlich geladen wird entscheidet sich entweder durch den config-Broadcast (kommt direkt danach) oder durch den ersten flux_request. Bis dahin gibt's keinen service_status, das Banner taucht erst auf wenn wir wirklich was laden.""" try: if runner.pipe is not None: # Pipeline ueberlebt nur Container-Lifetime; hier # also nur falls schon ein Modell aktiv ist (Reconnect). await _broadcast_status(ws, "ready", model=runner.model_id, loadSeconds=runner.last_load_seconds) logger.info("Initial: sende config_request an aria-bridge " "(kein Eager-Load, warte auf Diagnostic-Wahl)") await _send(ws, "config_request", {"service": "flux"}) except Exception as e: logger.exception("Initial-Setup crashed: %s", e) try: await _broadcast_status(ws, "error", error=str(e)[:200]) except Exception: pass asyncio.create_task(_load_with_status()) worker = asyncio.create_task(_flux_worker(ws, runner)) async def _apply_default_change(new_tag: str): """Wechselt den Default. Wenn ein anderes Modell als aktuell aktiv gewuenscht ist, wird eager geladen — der naechste Render ist dann ohne Swap-Delay.""" new_model_id = _tag_to_model_id(new_tag) runner.default_model_id = new_model_id if runner.model_id == new_model_id: logger.info("[config] Default-Modell bleibt: %s", new_model_id) return will_download = not _is_model_cached(new_model_id) logger.info("[config] Default-Modell wechselt: %s → %s (download=%s)", runner.model_id or "(none)", new_model_id, will_download) try: await _broadcast_status(ws, "loading", model=new_model_id, downloading=will_download) await runner.ensure_loaded(new_model_id) await _broadcast_status(ws, "ready", model=runner.model_id, loadSeconds=runner.last_load_seconds, freshlyDownloaded=runner.last_load_was_download) except Exception as e: logger.exception("Modell-Swap fehlgeschlagen") try: await _broadcast_status(ws, "error", error=str(e)[:200]) except Exception: pass try: async for raw in ws: try: msg = json.loads(raw) except Exception: continue mtype = msg.get("type", "") payload = msg.get("payload", {}) or {} if mtype == "flux_request": await _flux_queue.put(payload) elif mtype == "config": # Diagnostic-Broadcast (oder aria-bridge nach Reconnect). # HuggingFace-Token MUSS vor dem Modell-Swap gesetzt sein, # weil FluxPipeline.from_pretrained den Token aus der env # liest. Reihenfolge im selben Tick gewaehrleistet das. if "huggingfaceToken" in payload: tok = (payload.get("huggingfaceToken") or "").strip() if tok: os.environ["HF_TOKEN"] = tok os.environ["HUGGING_FACE_HUB_TOKEN"] = tok logger.info("[config] HF-Token gesetzt (len=%d)", len(tok)) else: os.environ.pop("HF_TOKEN", None) os.environ.pop("HUGGING_FACE_HUB_TOKEN", None) logger.info("[config] HF-Token entfernt (leerer Wert)") tag = (payload.get("fluxDefaultModel") or "").strip() if tag: asyncio.create_task(_apply_default_change(tag)) finally: worker.cancel() try: await worker except asyncio.CancelledError: pass except Exception as e: logger.warning("Verbindung verloren: %s", e) if use_tls and RVS_TLS_FALLBACK and not tls_fallback_tried: logger.info("TLS fehlgeschlagen — Fallback auf ws://") use_tls = False tls_fallback_tried = True continue await asyncio.sleep(min(retry_s, 30)) retry_s = min(retry_s * 2, 30) async def main() -> None: if not RVS_HOST: logger.error("RVS_HOST nicht gesetzt — Abbruch") sys.exit(1) runner = FluxRunner() await run_loop(runner) if __name__ == "__main__": try: asyncio.run(main()) except KeyboardInterrupt: sys.exit(0)