diff --git a/xtts/docker-compose.yml b/xtts/docker-compose.yml index c6bf89d..24f4c12 100644 --- a/xtts/docker-compose.yml +++ b/xtts/docker-compose.yml @@ -39,7 +39,11 @@ services: - RVS_TLS_FALLBACK=${RVS_TLS_FALLBACK:-true} - RVS_TOKEN=${RVS_TOKEN} - F5TTS_MODEL=${F5TTS_MODEL:-F5TTS_v1_Base} + - F5TTS_CKPT_FILE=${F5TTS_CKPT_FILE:-} + - F5TTS_VOCAB_FILE=${F5TTS_VOCAB_FILE:-} - F5TTS_DEVICE=${F5TTS_DEVICE:-cuda} + - F5TTS_CFG_STRENGTH=${F5TTS_CFG_STRENGTH:-2.5} + - F5TTS_NFE_STEP=${F5TTS_NFE_STEP:-32} - VOICES_DIR=/voices restart: unless-stopped diff --git a/xtts/f5tts/bridge.py b/xtts/f5tts/bridge.py index 737b4ec..769932d 100644 --- a/xtts/f5tts/bridge.py +++ b/xtts/f5tts/bridge.py @@ -53,7 +53,14 @@ RVS_TLS_FALLBACK = os.getenv("RVS_TLS_FALLBACK", "true").lower() == "true" RVS_TOKEN = os.getenv("RVS_TOKEN", "").strip() F5TTS_MODEL = os.getenv("F5TTS_MODEL", "F5TTS_v1_Base") +F5TTS_CKPT_FILE = os.getenv("F5TTS_CKPT_FILE", "") # optional: HF-Repo oder lokales .pt +F5TTS_VOCAB_FILE = os.getenv("F5TTS_VOCAB_FILE", "") # optional: zugehoerige vocab.txt F5TTS_DEVICE = os.getenv("F5TTS_DEVICE", "cuda") +# cfg_strength: wie stark der Generator am Referenz-Voice klebt. +# Default F5-TTS = 2.0. Bei nicht-EN/CN Sprachen (Deutsch!) hilft >2.5 +# damit das Modell nicht in eine andere Sprache abrutscht. +F5TTS_CFG_STRENGTH = float(os.getenv("F5TTS_CFG_STRENGTH", "2.5")) +F5TTS_NFE_STEP = int(os.getenv("F5TTS_NFE_STEP", "32")) VOICES_DIR = Path(os.getenv("VOICES_DIR", "/voices")) PCM_CHUNK_BYTES = 8192 # ~170ms @ 24kHz mono s16 @@ -82,10 +89,17 @@ class F5Runner: def _load_blocking(self) -> None: cls = _get_f5tts_cls() - logger.info("Lade F5-TTS '%s' (device=%s)...", F5TTS_MODEL, F5TTS_DEVICE) + logger.info("Lade F5-TTS '%s' (device=%s, ckpt=%s)...", + F5TTS_MODEL, F5TTS_DEVICE, F5TTS_CKPT_FILE or "default") t0 = time.time() - self.model = cls(model=F5TTS_MODEL, device=F5TTS_DEVICE) - logger.info("F5-TTS geladen in %.1fs", time.time() - t0) + kwargs = {"model": F5TTS_MODEL, "device": F5TTS_DEVICE} + if F5TTS_CKPT_FILE: + kwargs["ckpt_file"] = F5TTS_CKPT_FILE + if F5TTS_VOCAB_FILE: + kwargs["vocab_file"] = F5TTS_VOCAB_FILE + self.model = cls(**kwargs) + logger.info("F5-TTS geladen in %.1fs (cfg_strength=%.1f, nfe=%d)", + time.time() - t0, F5TTS_CFG_STRENGTH, F5TTS_NFE_STEP) async def ensure_loaded(self) -> None: async with self._lock: @@ -95,12 +109,18 @@ class F5Runner: await loop.run_in_executor(None, self._load_blocking) def _infer_blocking(self, gen_text: str, ref_wav: str, ref_text: str) -> tuple[np.ndarray, int]: + # cfg_strength + nfe_step erhoeht damit das Modell nicht in andere + # Sprachen abrutscht (Bug bei Deutsch: rutscht ohne Verstaerkung + # gerne ins Spanische ab, weil F5TTS_v1_Base hauptsaechlich auf EN+CN + # trainiert ist). wav, sr, _ = self.model.infer( ref_file=ref_wav, ref_text=ref_text, gen_text=gen_text, remove_silence=True, seed=-1, + cfg_strength=F5TTS_CFG_STRENGTH, + nfe_step=F5TTS_NFE_STEP, ) # F5-TTS gibt float32 1D-Array — auf 24kHz sample-rate standard if not isinstance(wav, np.ndarray):