feat(flux): Modell-Wahl per Diagnostic + raw/switch-Keywords + Download-Hinweis

Diagnostic-Einstellungen fuer FLUX:
- Default-Modell (dev | schnell) — wird via RVS gepusht, flux-bridge
  hot-swappt die Pipeline aus dem HF-Cache (~15-30s)
- Raw-Keyword (Default 'flux') — Pipe-Modus, Brain leitet Stefans Text
  1:1 als prompt durch, kein Rewriting/Beautify
- Switch-Keyword (Default 'fix') — zwingt das ANDERE Modell als Default

Brain-Tool flux_generate um model + raw erweitert, System-Prompt-Block
mit den aktuellen Diagnostic-Settings + Whisper-Toleranz-Hinweis.

Kein eager Bootstrap-Load: flux-bridge wartet auf config oder ersten
Request. Bei erstem HF-Download zeigt Banner "laedt erstmalig runter"
mit Pfeil-Icon, Toast in der App wenn fertig.

FLUX_MODEL aus der .env entfernt (Steuerung jetzt komplett ueber
Diagnostic). HF_TOKEN-Kommentar erklaert warum trotz lokaler Inference
noetig (HF Gate-Mechanismus fuer FLUX.1-dev).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-16 23:11:22 +02:00
parent 7e53dcfed3
commit 2d348aeec7
8 changed files with 440 additions and 72 deletions
+189 -37
View File
@@ -51,7 +51,13 @@ RVS_TLS = os.getenv("RVS_TLS", "true").lower() == "true"
RVS_TLS_FALLBACK = os.getenv("RVS_TLS_FALLBACK", "true").lower() == "true"
RVS_TOKEN = os.getenv("RVS_TOKEN", "").strip()
FLUX_MODEL = os.getenv("FLUX_MODEL", "black-forest-labs/FLUX.1-dev").strip()
# Bootstrap-Fallback: nur relevant wenn beim allerersten Start KEIN
# Diagnostic-config-Broadcast eintrifft UND der erste Render-Request
# auch kein 'model' enthaelt. Default 'schnell', weil Apache-2.0
# (kein HF-Token noetig) — Stefan stellt sein gewuenschtes Default ueber
# Diagnostic ein. ENV ist also nur fuer den extremen Edge-Case da, in
# der .env.example absichtlich nicht mehr dokumentiert.
FLUX_MODEL = os.getenv("FLUX_MODEL", "black-forest-labs/FLUX.1-schnell").strip()
FLUX_DEVICE = os.getenv("FLUX_DEVICE", "cuda").strip()
FLUX_DTYPE = os.getenv("FLUX_DTYPE", "bfloat16").strip().lower()
FLUX_OFFLOAD = os.getenv("FLUX_OFFLOAD", "model").strip().lower()
@@ -64,11 +70,54 @@ DEFAULT_STEPS_SCHNELL = 4
DEFAULT_GUIDANCE_DEV = 3.5
DEFAULT_GUIDANCE_SCHNELL = 0.0
# Mapping fuer das User-facing Tag → HF-Modell-ID. Stefan stellt in Diagnostic
# nur 'dev' / 'schnell' ein; FLUX_MODEL aus der env kann zwar eine custom-ID
# sein (Bootstrap), wird aber beim ersten config-Broadcast normalerweise
# durch die Diagnostic-Wahl uebersteuert.
MODEL_TAGS: dict[str, str] = {
"dev": "black-forest-labs/FLUX.1-dev",
"schnell": "black-forest-labs/FLUX.1-schnell",
}
def _tag_to_model_id(tag: str) -> str:
"""Mappt 'dev'/'schnell' auf HF-ID. Andere Strings werden 1:1 durchgereicht
(custom-IDs aus FLUX_MODEL env). Leere/ungueltige Werte → FLUX_MODEL Default."""
if not tag:
return FLUX_MODEL
t = tag.strip()
return MODEL_TAGS.get(t, t)
def _is_schnell(model_id: str) -> bool:
return "schnell" in model_id.lower()
def _is_model_cached(model_id: str) -> bool:
"""Prueft ob ein HF-Modell-Snapshot lokal im hf-cache vorhanden ist.
HF speichert unter ~/.cache/huggingface/hub/models--{org}--{name}/snapshots/{rev}/.
Wenn das snapshots-Verzeichnis nicht existiert oder leer ist → Erst-Download
steht an (24+ GB fuer FLUX.1-dev, 24+ GB fuer FLUX.1-schnell — Stefan kriegt
dann nen Hinweis im Banner).
"""
if not model_id:
return False
cache_root = os.environ.get("HF_HOME") or os.path.expanduser("~/.cache/huggingface")
safe = "models--" + model_id.replace("/", "--")
snapshots = os.path.join(cache_root, "hub", safe, "snapshots")
if not os.path.isdir(snapshots):
return False
try:
for rev in os.listdir(snapshots):
rev_dir = os.path.join(snapshots, rev)
if os.path.isdir(rev_dir) and any(os.scandir(rev_dir)):
return True
except OSError:
return False
return False
def _torch_dtype():
"""Lazy-resolve damit Torch erst beim Modell-Laden importiert wird."""
import torch
@@ -89,27 +138,58 @@ def _snap_dim(v: int, default: int = 1024) -> int:
class FluxRunner:
"""Haelt die FLUX-Pipeline. Synthese laeuft im Executor (blocking).
"""Haelt EINE FLUX-Pipeline. Bei Modell-Wechsel wird die alte verworfen
und die neue geladen (~15-30 s aus HF-Cache, keine Re-Downloads).
GPU ist die Engstelle — wir serialisieren via Queue im Caller, hier
nur Single-Lock fuer load. Ein Render auf der 3060 dauert je nach
Steps/Aufloesung 20-90 s.
Pro Request kann ein 'dev'/'schnell'-Tag mitkommen; ohne Angabe wird
`default_model_id` genommen (steht Bootstrap auf FLUX_MODEL, wird beim
ersten config-Broadcast von der aria-bridge auf die Diagnostic-Wahl
aktualisiert).
"""
def __init__(self) -> None:
self.pipe = None
self._lock = asyncio.Lock()
self.model_id: str = FLUX_MODEL
# Aktuell geladenes Modell — leer solange noch nix geladen wurde.
self.model_id: str = ""
# Was bei einem Request OHNE explizite model-Angabe benutzt wird.
# Wird durch Diagnostic-config gesetzt; FLUX_MODEL bleibt nur als
# Edge-Case-Fallback wenn weder Config noch Request einen Wert nennen.
self.default_model_id: str = FLUX_MODEL
self.last_load_seconds: float = 0.0
# True wenn der letzte _load_blocking einen Fresh-Download triggern
# musste (Modell war nicht im HF-Cache). Wird vom Caller geprueft
# und in den 'ready'-service_status als freshlyDownloaded gesetzt.
self.last_load_was_download: bool = False
def _load_blocking(self) -> None:
def _load_blocking(self, model_id: str) -> None:
import torch
from diffusers import FluxPipeline
logger.info("Lade FLUX '%s' (dtype=%s, offload=%s)...",
self.model_id, FLUX_DTYPE, FLUX_OFFLOAD)
# Alte Pipeline freigeben damit der HF-Loader VRAM/RAM kriegt
if self.pipe is not None:
logger.info("Verwerfe alte Pipeline '%s'", self.model_id)
try:
del self.pipe
except Exception:
pass
self.pipe = None
try:
torch.cuda.empty_cache()
except Exception:
pass
import gc
gc.collect()
was_cached = _is_model_cached(model_id)
self.last_load_was_download = not was_cached
if not was_cached:
logger.warning("FLUX '%s' nicht im HF-Cache — Erst-Download steht bevor (kann 5-10 min dauern).",
model_id)
logger.info("Lade FLUX '%s' (dtype=%s, offload=%s, cached=%s)...",
model_id, FLUX_DTYPE, FLUX_OFFLOAD, was_cached)
t0 = time.time()
pipe = FluxPipeline.from_pretrained(self.model_id, torch_dtype=_torch_dtype())
pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=_torch_dtype())
if FLUX_OFFLOAD == "sequential":
pipe.enable_sequential_cpu_offload()
@@ -125,20 +205,25 @@ class FluxRunner:
pass
self.pipe = pipe
self.model_id = model_id
self.last_load_seconds = time.time() - t0
logger.info("FLUX geladen in %.1fs", self.last_load_seconds)
# CUDA-Cache nach dem Load aufraeumen
logger.info("FLUX '%s' geladen in %.1fs", model_id, self.last_load_seconds)
try:
torch.cuda.empty_cache()
except Exception:
pass
async def ensure_loaded(self) -> None:
async def ensure_loaded(self, model_id: Optional[str] = None) -> bool:
"""Stellt sicher dass die richtige Pipeline geladen ist. Wenn ein
anderes Modell gewuenscht ist als gerade aktiv, wird geswappt.
Returns True wenn ein Swap/Load stattgefunden hat."""
target = model_id or self.default_model_id or FLUX_MODEL
async with self._lock:
if self.pipe is not None:
return
if self.pipe is not None and self.model_id == target:
return False
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, self._load_blocking)
await loop.run_in_executor(None, self._load_blocking, target)
return True
def _generate_blocking(self, prompt: str, width: int, height: int,
steps: int, guidance: float, seed: Optional[int]) -> bytes:
@@ -147,8 +232,8 @@ class FluxRunner:
if seed is not None and seed >= 0:
gen = torch.Generator(device=FLUX_DEVICE).manual_seed(int(seed))
logger.info("Render: %dx%d, steps=%d, guidance=%.2f, seed=%s, prompt=%r",
width, height, steps, guidance, seed, prompt[:80])
logger.info("Render (%s): %dx%d, steps=%d, guidance=%.2f, seed=%s, prompt=%r",
self.model_id, width, height, steps, guidance, seed, prompt[:80])
out = self.pipe(
prompt=prompt,
width=width,
@@ -169,8 +254,9 @@ class FluxRunner:
return png_bytes
async def generate(self, prompt: str, width: int, height: int,
steps: int, guidance: float, seed: Optional[int]) -> bytes:
await self.ensure_loaded()
steps: int, guidance: float, seed: Optional[int],
model_id: Optional[str] = None) -> bytes:
await self.ensure_loaded(model_id)
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None, self._generate_blocking, prompt, width, height, steps, guidance, seed,
@@ -205,8 +291,10 @@ async def _broadcast_status(ws, state: str, **extra) -> None:
_flux_queue: "asyncio.Queue[tuple]" = asyncio.Queue()
def _resolve_request(payload: dict, runner: FluxRunner) -> tuple[str, int, int, int, float, Optional[int]]:
"""Liest Felder aus dem flux_request payload + clampt auf Caps."""
def _resolve_request(payload: dict, runner: FluxRunner) -> tuple[str, int, int, int, float, Optional[int], str]:
"""Liest Felder aus dem flux_request payload + clampt auf Caps.
Returns (prompt, width, height, steps, guidance, seed, resolved_model_id).
"""
prompt = (payload.get("prompt") or "").strip()
if not prompt:
raise ValueError("prompt fehlt")
@@ -216,7 +304,11 @@ def _resolve_request(payload: dict, runner: FluxRunner) -> tuple[str, int, int,
width = _snap_dim(payload.get("width", 1024))
height = _snap_dim(payload.get("height", 1024))
schnell = _is_schnell(runner.model_id)
# Modell-Wahl: explizit per Request > runner.default_model_id > FLUX_MODEL.
req_model = (payload.get("model") or "").strip()
resolved_model_id = _tag_to_model_id(req_model) if req_model else (runner.default_model_id or FLUX_MODEL)
schnell = _is_schnell(resolved_model_id)
default_steps = DEFAULT_STEPS_SCHNELL if schnell else DEFAULT_STEPS_DEV
default_guidance = DEFAULT_GUIDANCE_SCHNELL if schnell else DEFAULT_GUIDANCE_DEV
@@ -240,7 +332,7 @@ def _resolve_request(payload: dict, runner: FluxRunner) -> tuple[str, int, int,
except (TypeError, ValueError):
seed = None
return prompt, width, height, steps, guidance, seed
return prompt, width, height, steps, guidance, seed, resolved_model_id
async def _flux_worker(ws, runner: FluxRunner) -> None:
@@ -263,29 +355,53 @@ async def _flux_worker(ws, runner: FluxRunner) -> None:
async def _do_render(ws, runner: FluxRunner, payload: dict, request_id: str) -> None:
t0 = time.time()
try:
prompt, width, height, steps, guidance, seed = _resolve_request(payload, runner)
prompt, width, height, steps, guidance, seed, target_model_id = _resolve_request(payload, runner)
except ValueError as e:
logger.warning("flux_request invalid: %s", e)
await _send(ws, "flux_response", {"requestId": request_id, "error": str(e)})
return
# Modell-Swap noetig? Status broadcasten damit Diagnostic-Banner es zeigt.
swap_needed = (runner.pipe is None or runner.model_id != target_model_id)
will_download = swap_needed and not _is_model_cached(target_model_id)
if swap_needed:
await _broadcast_status(ws, "loading", model=target_model_id,
downloading=will_download)
await _send(ws, "flux_response", {
"requestId": request_id,
"state": "switching_model",
"model": target_model_id,
"downloading": will_download,
})
# Progress-Ping: User soll sehen dass was passiert (Render >30s realistisch)
await _send(ws, "flux_response", {
"requestId": request_id,
"state": "rendering",
"width": width, "height": height, "steps": steps,
"model": target_model_id,
})
try:
png = await runner.generate(prompt, width, height, steps, guidance, seed)
png = await runner.generate(prompt, width, height, steps, guidance, seed,
model_id=target_model_id)
except Exception as e:
logger.exception("FLUX Render-Fehler")
await _send(ws, "flux_response", {"requestId": request_id, "error": str(e)[:200]})
if swap_needed:
await _broadcast_status(ws, "error", error=str(e)[:200])
return
if swap_needed:
await _broadcast_status(ws, "ready",
model=runner.model_id,
loadSeconds=runner.last_load_seconds,
freshlyDownloaded=runner.last_load_was_download)
dt = time.time() - t0
b64 = base64.b64encode(png).decode("ascii")
logger.info("Render fertig: %dx%d, %d KB PNG, %.1fs", width, height, len(png) // 1024, dt)
logger.info("Render fertig: %dx%d, %d KB PNG, %.1fs (%s)",
width, height, len(png) // 1024, dt, runner.model_id)
await _send(ws, "flux_response", {
"requestId": request_id,
@@ -327,22 +443,24 @@ async def run_loop(runner: FluxRunner) -> None:
tls_fallback_tried = False
async def _load_with_status():
"""Bei Connect KEIN Eager-Load — wir fragen erst die
Diagnostic-Config ab. Welches Modell tatsaechlich geladen
wird entscheidet sich entweder durch den config-Broadcast
(kommt direkt danach) oder durch den ersten flux_request.
Bis dahin gibt's keinen service_status, das Banner taucht
erst auf wenn wir wirklich was laden."""
try:
if runner.pipe is not None:
logger.info("Initial: broadcaste ready (Pipeline schon im RAM: %s)",
runner.model_id)
await _broadcast_status(ws, "ready",
model=runner.model_id,
loadSeconds=runner.last_load_seconds)
else:
logger.info("Initial: broadcaste loading + lade '%s'", runner.model_id)
await _broadcast_status(ws, "loading", model=runner.model_id)
await runner.ensure_loaded()
# Pipeline ueberlebt nur Container-Lifetime; hier
# also nur falls schon ein Modell aktiv ist (Reconnect).
await _broadcast_status(ws, "ready",
model=runner.model_id,
loadSeconds=runner.last_load_seconds)
logger.info("Initial: sende config_request an aria-bridge "
"(kein Eager-Load, warte auf Diagnostic-Wahl)")
await _send(ws, "config_request", {"service": "flux"})
except Exception as e:
logger.exception("Initial-Load crashed: %s", e)
logger.exception("Initial-Setup crashed: %s", e)
try:
await _broadcast_status(ws, "error", error=str(e)[:200])
except Exception:
@@ -351,6 +469,33 @@ async def run_loop(runner: FluxRunner) -> None:
worker = asyncio.create_task(_flux_worker(ws, runner))
async def _apply_default_change(new_tag: str):
"""Wechselt den Default. Wenn ein anderes Modell als aktuell
aktiv gewuenscht ist, wird eager geladen — der naechste
Render ist dann ohne Swap-Delay."""
new_model_id = _tag_to_model_id(new_tag)
runner.default_model_id = new_model_id
if runner.model_id == new_model_id:
logger.info("[config] Default-Modell bleibt: %s", new_model_id)
return
will_download = not _is_model_cached(new_model_id)
logger.info("[config] Default-Modell wechselt: %s%s (download=%s)",
runner.model_id or "(none)", new_model_id, will_download)
try:
await _broadcast_status(ws, "loading", model=new_model_id,
downloading=will_download)
await runner.ensure_loaded(new_model_id)
await _broadcast_status(ws, "ready",
model=runner.model_id,
loadSeconds=runner.last_load_seconds,
freshlyDownloaded=runner.last_load_was_download)
except Exception as e:
logger.exception("Modell-Swap fehlgeschlagen")
try:
await _broadcast_status(ws, "error", error=str(e)[:200])
except Exception:
pass
try:
async for raw in ws:
try:
@@ -362,6 +507,13 @@ async def run_loop(runner: FluxRunner) -> None:
if mtype == "flux_request":
await _flux_queue.put(payload)
elif mtype == "config":
# Diagnostic-Broadcast (oder aria-bridge nach Reconnect).
# Wir interessieren uns nur fuer fluxDefaultModel — die
# Keywords nutzt das Brain, nicht wir.
tag = (payload.get("fluxDefaultModel") or "").strip()
if tag:
asyncio.create_task(_apply_default_change(tag))
finally:
worker.cancel()
try: