Spaces: Running on Zero
| from pathlib import Path | |
| from typing import Generator | |
| import librosa | |
| import numpy as np | |
| import torch | |
| from neucodec import NeuCodec, DistillNeuCodec | |
| from utils.phonemize_text import phonemize_with_dict | |
| from collections import defaultdict | |
| from concurrent.futures import ThreadPoolExecutor | |
| import re | |
| import gc | |
| # ============================================================================ | |
| # Shared Utilities | |
| # ============================================================================ | |
| def _linear_overlap_add(frames: list[np.ndarray], stride: int) -> np.ndarray: | |
| """Linear overlap-add for smooth audio concatenation""" | |
| assert len(frames) | |
| dtype = frames[0].dtype | |
| shape = frames[0].shape[:-1] | |
| total_size = 0 | |
| for i, frame in enumerate(frames): | |
| frame_end = stride * i + frame.shape[-1] | |
| total_size = max(total_size, frame_end) | |
| sum_weight = np.zeros(total_size, dtype=dtype) | |
| out = np.zeros(*shape, total_size, dtype=dtype) | |
| offset: int = 0 | |
| for frame in frames: | |
| frame_length = frame.shape[-1] | |
| t = np.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1] | |
| weight = np.abs(0.5 - (t - 0.5)) | |
| out[..., offset : offset + frame_length] += weight * frame | |
| sum_weight[offset : offset + frame_length] += weight | |
| offset += stride | |
| assert sum_weight.min() > 0 | |
| return out / sum_weight | |
| def _compile_codec_with_triton(codec): | |
| """Compile codec with Triton for faster decoding (Windows/Linux compatible)""" | |
| try: | |
| import triton | |
| if hasattr(codec, 'dec') and hasattr(codec.dec, 'resblocks'): | |
| if len(codec.dec.resblocks) > 2: | |
| codec.dec.resblocks[2].forward = torch.compile( | |
| codec.dec.resblocks[2].forward, | |
| mode="reduce-overhead", | |
| dynamic=True | |
| ) | |
| print(" ✅ Triton compilation enabled for codec") | |
| return True | |
| except ImportError: | |
| print(" ⚠️ Triton not found. Install for faster speed:") | |
| print(" • Linux: pip install triton") | |
| print(" • Windows: pip install triton-windows") | |
| print(" (Optional but recommended)") | |
| return False | |
| # ============================================================================ | |
| # VieNeuTTS - Standard implementation (CPU/GPU compatible) | |
| # Supports: PyTorch Transformers, GGUF/GGML quantized models | |
| # ============================================================================ | |
class VieNeuTTS:
    """
    Standard VieNeu-TTS implementation.

    Supports:
    - PyTorch + Transformers backend (CPU/GPU)
    - GGUF quantized models via llama-cpp-python (CPU optimized)

    Use this for:
    - CPU-only environments
    - Standard PyTorch workflows
    - GGUF quantized models
    """

    def __init__(
        self,
        backbone_repo="pnnbao-ump/VieNeu-TTS",
        backbone_device="cpu",
        codec_repo="neuphonic/neucodec",
        codec_device="cpu",
    ):
        """
        Initialize VieNeu-TTS.

        Args:
            backbone_repo: Model repository or path to GGUF file.
            backbone_device: Device for backbone ('cpu', 'cuda', 'gpu').
            codec_repo: Codec repository.
            codec_device: Device for codec.
        """
        # Constants
        self.sample_rate = 24_000  # output waveform sample rate (Hz)
        self.max_context = 2048    # max token context for the LM backbone
        self.hop_length = 480      # audio samples per speech token
        # Streaming geometry (counts are in speech tokens/frames):
        self.streaming_overlap_frames = 1     # cross-fade overlap between chunks
        self.streaming_frames_per_chunk = 25  # new frames emitted per chunk
        self.streaming_lookforward = 5        # future context for each decode
        self.streaming_lookback = 50          # past context for each decode
        self.streaming_stride_samples = self.streaming_frames_per_chunk * self.hop_length
        # Flags
        self._is_quantized_model = False  # True when backbone is GGUF/llama.cpp
        self._is_onnx_codec = False       # True when codec is the ONNX decoder
        # HF tokenizer (set only for the transformers backend)
        self.tokenizer = None
        # Load models
        self._load_backbone(backbone_repo, backbone_device)
        self._load_codec(codec_repo, codec_device)

    def _load_backbone(self, backbone_repo, backbone_device):
        """Load the LM backbone.

        Uses llama.cpp when the repo name mentions 'gguf', otherwise a
        transformers ``AutoModelForCausalLM`` on the requested device.
        """
        # MPS device validation
        if backbone_device == "mps":
            if not torch.backends.mps.is_available():
                print("Warning: MPS not available, falling back to CPU")
                backbone_device = "cpu"
        print(f"Loading backbone from: {backbone_repo} on {backbone_device} ...")
        # NOTE(review): the endswith("gguf") clause is subsumed by the
        # substring check after it.
        if backbone_repo.lower().endswith("gguf") or "gguf" in backbone_repo.lower():
            try:
                from llama_cpp import Llama
            except ImportError as e:
                raise ImportError(
                    "Failed to import `llama_cpp`. "
                    "Xem hướng dẫn cài đặt llama_cpp_python phiên bản tối thiểu 0.3.16 tại: https://llama-cpp-python.readthedocs.io/en/latest/"
                ) from e
            self.backbone = Llama.from_pretrained(
                repo_id=backbone_repo,
                filename="*.gguf",
                verbose=False,
                n_gpu_layers=-1 if backbone_device == "gpu" else 0,  # -1 = all layers on GPU
                n_ctx=self.max_context,
                mlock=True,  # lock weights in RAM to avoid swapping
                flash_attn=True if backbone_device == "gpu" else False,
            )
            self._is_quantized_model = True
        else:
            from transformers import AutoTokenizer, AutoModelForCausalLM
            self.tokenizer = AutoTokenizer.from_pretrained(backbone_repo)
            self.backbone = AutoModelForCausalLM.from_pretrained(backbone_repo).to(
                torch.device(backbone_device)
            )

    def _load_codec(self, codec_repo, codec_device):
        """Load the neural audio codec matching ``codec_repo``.

        Raises:
            ValueError: Unknown repo, or ONNX decoder requested on non-CPU.
            ImportError: ONNX decoder support is missing.
        """
        # MPS device validation
        if codec_device == "mps":
            if not torch.backends.mps.is_available():
                print("Warning: MPS not available for codec, falling back to CPU")
                codec_device = "cpu"
        print(f"Loading codec from: {codec_repo} on {codec_device} ...")
        match codec_repo:
            case "neuphonic/neucodec":
                self.codec = NeuCodec.from_pretrained(codec_repo)
                self.codec.eval().to(codec_device)
            case "neuphonic/distill-neucodec":
                self.codec = DistillNeuCodec.from_pretrained(codec_repo)
                self.codec.eval().to(codec_device)
            case "neuphonic/neucodec-onnx-decoder-int8":
                if codec_device != "cpu":
                    raise ValueError("Onnx decoder only currently runs on CPU.")
                try:
                    from neucodec import NeuCodecOnnxDecoder
                except ImportError as e:
                    raise ImportError(
                        "Failed to import the onnx decoder."
                        "Ensure you have onnxruntime installed as well as neucodec >= 0.0.4."
                    ) from e
                self.codec = NeuCodecOnnxDecoder.from_pretrained(codec_repo)
                self._is_onnx_codec = True
            case _:
                raise ValueError(f"Unsupported codec repository: {codec_repo}")

    def encode_reference(self, ref_audio_path: str | Path):
        """Encode reference audio to codec codes.

        The file is loaded as 16 kHz mono (the codec's encoder input rate).

        Returns:
            1-D tensor of speech codes for the reference audio.
        """
        wav, _ = librosa.load(ref_audio_path, sr=16000, mono=True)
        wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0)  # [1, 1, T]
        with torch.no_grad():
            ref_codes = self.codec.encode_code(audio_or_path=wav_tensor).squeeze(0).squeeze(0)
        return ref_codes

    def infer(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> np.ndarray:
        """
        Perform inference to generate speech from text using the TTS model and reference audio.

        Args:
            text (str): Input text to be converted to speech.
            ref_codes (np.ndarray | torch.Tensor): Encoded reference.
            ref_text (str): Reference text for reference audio.

        Returns:
            np.ndarray: Generated speech waveform.
        """
        # Generate tokens with whichever backbone was loaded
        if self._is_quantized_model:
            output_str = self._infer_ggml(ref_codes, ref_text, text)
        else:
            prompt_ids = self._apply_chat_template(ref_codes, ref_text, text)
            output_str = self._infer_torch(prompt_ids)
        # Decode speech-token string to waveform
        wav = self._decode(output_str)
        return wav

    def infer_stream(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> Generator[np.ndarray, None, None]:
        """
        Perform streaming inference to generate speech from text using the TTS model and reference audio.

        Only supported with the GGUF/llama.cpp backbone.

        Args:
            text (str): Input text to be converted to speech.
            ref_codes (np.ndarray | torch.Tensor): Encoded reference.
            ref_text (str): Reference text for reference audio.

        Yields:
            np.ndarray: Generated speech waveform chunks.

        Raises:
            NotImplementedError: When the transformers backend is active.
        """
        if self._is_quantized_model:
            return self._infer_stream_ggml(ref_codes, ref_text, text)
        else:
            raise NotImplementedError("Streaming is not implemented for the torch backend!")

    def _decode(self, codes: str):
        """Decode a string of ``<|speech_N|>`` tokens to an audio waveform.

        Args:
            codes: Model output containing speech tokens.

        Returns:
            1-D numpy waveform.

        Raises:
            ValueError: If no speech tokens are present in ``codes``.
        """
        # Extract speech token IDs using regex
        speech_ids = [int(num) for num in re.findall(r"<\|speech_(\d+)\|>", codes)]
        if len(speech_ids) == 0:
            raise ValueError(
                "No valid speech tokens found in the output. "
                "The model may not have generated proper speech tokens."
            )
        # Onnx decode
        if self._is_onnx_codec:
            codes = np.array(speech_ids, dtype=np.int32)[np.newaxis, np.newaxis, :]
            recon = self.codec.decode_code(codes)
        # Torch decode
        else:
            with torch.no_grad():
                codes = torch.tensor(speech_ids, dtype=torch.long)[None, None, :].to(
                    self.codec.device
                )
                recon = self.codec.decode_code(codes).cpu().numpy()
        return recon[0, 0, :]

    def _apply_chat_template(self, ref_codes: list[int], ref_text: str, input_text: str) -> list[int]:
        """Build the token-id prompt for the transformers backend.

        The chat template's ``<|TEXT_REPLACE|>`` placeholder is replaced by
        the phonemized reference+input text, and ``<|SPEECH_REPLACE|>`` by
        the reference speech codes, so generation continues the codes.
        """
        input_text = phonemize_with_dict(ref_text) + " " + phonemize_with_dict(input_text)
        speech_replace = self.tokenizer.convert_tokens_to_ids("<|SPEECH_REPLACE|>")
        speech_gen_start = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_START|>")
        text_replace = self.tokenizer.convert_tokens_to_ids("<|TEXT_REPLACE|>")
        text_prompt_start = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_START|>")
        text_prompt_end = self.tokenizer.convert_tokens_to_ids("<|TEXT_PROMPT_END|>")
        input_ids = self.tokenizer.encode(input_text, add_special_tokens=False)
        chat = """user: Convert the text to speech:<|TEXT_REPLACE|>\nassistant:<|SPEECH_REPLACE|>"""
        ids = self.tokenizer.encode(chat)
        text_replace_idx = ids.index(text_replace)
        # Splice the phonemized text between the prompt start/end markers.
        ids = (
            ids[:text_replace_idx]
            + [text_prompt_start]
            + input_ids
            + [text_prompt_end]
            + ids[text_replace_idx + 1 :]  # noqa
        )
        speech_replace_idx = ids.index(speech_replace)
        codes_str = "".join([f"<|speech_{i}|>" for i in ref_codes])
        codes = self.tokenizer.encode(codes_str, add_special_tokens=False)
        # Everything after the reference codes is left for the model to generate.
        ids = ids[:speech_replace_idx] + [speech_gen_start] + list(codes)
        return ids

    def _infer_torch(self, prompt_ids: list[int]) -> str:
        """Sample from the transformers backbone and return the decoded
        continuation (a speech-token string)."""
        prompt_tensor = torch.tensor(prompt_ids).unsqueeze(0).to(self.backbone.device)
        speech_end_id = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
        with torch.no_grad():
            output_tokens = self.backbone.generate(
                prompt_tensor,
                max_length=self.max_context,
                eos_token_id=speech_end_id,  # stop at end-of-speech marker
                do_sample=True,
                temperature=1.0,
                top_k=50,
                use_cache=True,
                min_new_tokens=50,
            )
        # Strip the prompt: decode only newly generated tokens.
        input_length = prompt_tensor.shape[-1]
        output_str = self.tokenizer.decode(
            output_tokens[0, input_length:].cpu().numpy().tolist(), add_special_tokens=False
        )
        return output_str

    def _infer_ggml(self, ref_codes: list[int], ref_text: str, input_text: str) -> str:
        """Run non-streaming generation with the llama.cpp (GGUF) backbone
        and return the generated speech-token string."""
        ref_text = phonemize_with_dict(ref_text)
        input_text = phonemize_with_dict(input_text)
        codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
        prompt = (
            f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
            f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
        )
        output = self.backbone(
            prompt,
            max_tokens=self.max_context,
            temperature=0.7,
            top_k=50,
            stop=["<|SPEECH_GENERATION_END|>"],
        )
        output_str = output["choices"][0]["text"]
        return output_str

    def _infer_stream_ggml(self, ref_codes: torch.Tensor, ref_text: str, input_text: str) -> Generator[np.ndarray, None, None]:
        """Stream audio chunks from the llama.cpp backbone.

        Generated tokens are buffered and decoded in fixed-size windows with
        lookback/lookforward context; overlapping decodes are blended with
        ``_linear_overlap_add`` and only the newly produced samples are
        yielded each time.
        """
        ref_text = phonemize_with_dict(ref_text)
        input_text = phonemize_with_dict(input_text)
        codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
        prompt = (
            f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
            f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
        )
        audio_cache: list[np.ndarray] = []  # decoded chunks fed to overlap-add
        # Seed the cache with the reference codes so early chunks have
        # lookback context; they are never re-emitted as audio.
        token_cache: list[str] = [f"<|speech_{idx}|>" for idx in ref_codes]
        n_decoded_samples: int = 0  # samples already yielded to the caller
        n_decoded_tokens: int = len(ref_codes)  # tokens already rendered
        # NOTE(review): each streamed item is appended as one cache entry;
        # the window arithmetic assumes one speech token per entry -- confirm
        # llama.cpp never emits multiple speech tokens in a single delta.
        for item in self.backbone(
            prompt,
            max_tokens=self.max_context,
            temperature=0.7,
            top_k=50,
            stop=["<|SPEECH_GENERATION_END|>"],
            stream=True
        ):
            output_str = item["choices"][0]["text"]
            token_cache.append(output_str)
            # Decode once a full chunk plus lookforward context is buffered.
            if len(token_cache[n_decoded_tokens:]) >= self.streaming_frames_per_chunk + self.streaming_lookforward:
                # decode chunk: window = lookback + overlap + chunk + lookforward + overlap
                tokens_start = max(
                    n_decoded_tokens
                    - self.streaming_lookback
                    - self.streaming_overlap_frames,
                    0
                )
                tokens_end = (
                    n_decoded_tokens
                    + self.streaming_frames_per_chunk
                    + self.streaming_lookforward
                    + self.streaming_overlap_frames
                )
                # Sample offsets of the kept chunk inside the decoded window.
                sample_start = (
                    n_decoded_tokens - tokens_start
                ) * self.hop_length
                sample_end = (
                    sample_start
                    + (self.streaming_frames_per_chunk + 2 * self.streaming_overlap_frames) * self.hop_length
                )
                curr_codes = token_cache[tokens_start:tokens_end]
                recon = self._decode("".join(curr_codes))
                recon = recon[sample_start:sample_end]
                audio_cache.append(recon)
                # postprocess: blend all chunks, then yield only the new samples
                processed_recon = _linear_overlap_add(
                    audio_cache, stride=self.streaming_stride_samples
                )
                new_samples_end = len(audio_cache) * self.streaming_stride_samples
                processed_recon = processed_recon[
                    n_decoded_samples:new_samples_end
                ]
                n_decoded_samples = new_samples_end
                n_decoded_tokens += self.streaming_frames_per_chunk
                yield processed_recon
        # final decoding handled separately as non-constant chunk size
        remaining_tokens = len(token_cache) - n_decoded_tokens
        if len(token_cache) > n_decoded_tokens:
            tokens_start = max(
                len(token_cache)
                - (self.streaming_lookback + self.streaming_overlap_frames + remaining_tokens),
                0
            )
            sample_start = (
                len(token_cache)
                - tokens_start
                - remaining_tokens
                - self.streaming_overlap_frames
            ) * self.hop_length
            curr_codes = token_cache[tokens_start:]
            recon = self._decode("".join(curr_codes))
            recon = recon[sample_start:]
            audio_cache.append(recon)
            processed_recon = _linear_overlap_add(audio_cache, stride=self.streaming_stride_samples)
            processed_recon = processed_recon[n_decoded_samples:]
            yield processed_recon
| # ============================================================================ | |
| # FastVieNeuTTS - GPU-optimized implementation | |
| # Requires: LMDeploy with CUDA | |
| # ============================================================================ | |
class FastVieNeuTTS:
    """
    GPU-optimized VieNeu-TTS using LMDeploy TurbomindEngine.
    """

    def __init__(
        self,
        backbone_repo="pnnbao-ump/VieNeu-TTS",
        backbone_device="cuda",
        codec_repo="neuphonic/neucodec",
        codec_device="cuda",
        memory_util=0.3,
        tp=1,
        enable_prefix_caching=True,
        quant_policy=0,
        enable_triton=True,
        max_batch_size=8,
    ):
        """
        Initialize FastVieNeuTTS with LMDeploy backend and optimizations.

        Args:
            backbone_repo: Model repository.
            backbone_device: Device for backbone (must be CUDA).
            codec_repo: Codec repository.
            codec_device: Device for codec.
            memory_util: GPU memory utilization (0.0-1.0).
            tp: Tensor parallel size for multi-GPU.
            enable_prefix_caching: Enable prefix caching for faster batch processing.
            quant_policy: KV cache quantization (0=off, 8=int8, 4=int4).
            enable_triton: Enable Triton compilation for codec.
            max_batch_size: Maximum batch size for inference (prevent GPU overload).

        Raises:
            ValueError: If ``backbone_device`` is not a CUDA device.
        """
        if backbone_device != "cuda" and not backbone_device.startswith("cuda:"):
            raise ValueError("LMDeploy backend requires CUDA device")
        # Constants
        self.sample_rate = 24_000  # output waveform sample rate (Hz)
        self.max_context = 2048    # max token context for the LM backbone
        self.hop_length = 480      # audio samples per speech token
        # Streaming geometry (in speech tokens); larger chunk than VieNeuTTS.
        self.streaming_overlap_frames = 1
        self.streaming_frames_per_chunk = 50
        self.streaming_lookforward = 5
        self.streaming_lookback = 50
        self.streaming_stride_samples = self.streaming_frames_per_chunk * self.hop_length
        self.max_batch_size = max_batch_size
        self._ref_cache = {}  # cache_key -> {'codes', 'ref_text'}
        self.stored_dict = defaultdict(dict)  # str(user_id) -> {'codes', 'ref_text'}
        # Flags
        self._is_onnx_codec = False
        self._triton_enabled = False
        # Load models
        self._load_backbone_lmdeploy(backbone_repo, memory_util, tp, enable_prefix_caching, quant_policy)
        self._load_codec(codec_repo, codec_device, enable_triton)
        self._warmup_model()
        print("✅ FastVieNeuTTS with optimizations loaded successfully!")
        print(f" Max batch size: {self.max_batch_size} (adjustable to prevent GPU overload)")

    def _load_backbone_lmdeploy(self, repo, memory_util, tp, enable_prefix_caching, quant_policy):
        """Load backbone using LMDeploy's TurbomindEngine"""
        print(f"Loading backbone with LMDeploy from: {repo}")
        try:
            from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig
        except ImportError as e:
            raise ImportError(
                "Failed to import `lmdeploy`. "
                "Xem hướng dẫn cài đặt lmdeploy để tối ưu hiệu suất GPU tại: https://github.com/pnnbao97/VieNeu-TTS"
            ) from e
        backend_config = TurbomindEngineConfig(
            cache_max_entry_count=memory_util,  # fraction of GPU memory for KV cache
            tp=tp,
            enable_prefix_caching=enable_prefix_caching,
            dtype='bfloat16',
            quant_policy=quant_policy
        )
        self.backbone = pipeline(repo, backend_config=backend_config)
        # Default sampling configuration shared by all inference paths.
        self.gen_config = GenerationConfig(
            top_p=0.95,
            top_k=50,
            temperature=0.7,
            max_new_tokens=2048,
            do_sample=True,
            min_new_tokens=40,
        )
        print(f" LMDeploy TurbomindEngine initialized")
        print(f" - Memory util: {memory_util}")
        print(f" - Tensor Parallel: {tp}")
        print(f" - Prefix caching: {enable_prefix_caching}")
        print(f" - KV quant: {quant_policy} ({'Enabled' if quant_policy > 0 else 'Disabled'})")

    def _load_codec(self, codec_repo, codec_device, enable_triton):
        """Load codec with optional Triton compilation"""
        print(f"Loading codec from: {codec_repo} on {codec_device}")
        match codec_repo:
            case "neuphonic/neucodec":
                self.codec = NeuCodec.from_pretrained(codec_repo)
                self.codec.eval().to(codec_device)
            case "neuphonic/distill-neucodec":
                self.codec = DistillNeuCodec.from_pretrained(codec_repo)
                self.codec.eval().to(codec_device)
            case "neuphonic/neucodec-onnx-decoder-int8":
                if codec_device != "cpu":
                    raise ValueError("ONNX decoder only runs on CPU")
                try:
                    from neucodec import NeuCodecOnnxDecoder
                except ImportError as e:
                    raise ImportError(
                        "Failed to import ONNX decoder. "
                        "Ensure onnxruntime and neucodec >= 0.0.4 are installed."
                    ) from e
                self.codec = NeuCodecOnnxDecoder.from_pretrained(codec_repo)
                self._is_onnx_codec = True
            case _:
                raise ValueError(f"Unsupported codec repository: {codec_repo}")
        # torch.compile is only useful for the torch codec on an accelerator.
        if enable_triton and not self._is_onnx_codec and codec_device != "cpu":
            self._triton_enabled = _compile_codec_with_triton(self.codec)

    def _warmup_model(self):
        """Warmup inference pipeline to reduce first-token latency.

        Failures are deliberately non-fatal: warmup is an optimization only.
        """
        print("🔥 Warming up model...")
        try:
            dummy_codes = list(range(10))
            dummy_prompt = self._format_prompt(dummy_codes, "warmup", "test")
            _ = self.backbone([dummy_prompt], gen_config=self.gen_config, do_preprocess=False)
            print(" ✅ Warmup complete")
        except Exception as e:
            print(f" ⚠️ Warmup failed (non-critical): {e}")

    def encode_reference(self, ref_audio_path: str | Path):
        """Encode reference audio to codec codes (loaded as 16 kHz mono)."""
        wav, _ = librosa.load(ref_audio_path, sr=16000, mono=True)
        wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0)  # [1, 1, T]
        with torch.no_grad():
            ref_codes = self.codec.encode_code(audio_or_path=wav_tensor).squeeze(0).squeeze(0)
        return ref_codes

    def get_cached_reference(self, voice_name: str, audio_path: str, ref_text: str = None):
        """
        Get or create cached reference codes.

        Args:
            voice_name: Unique identifier for this voice.
            audio_path: Path to reference audio.
            ref_text: Optional reference text (stored with codes).

        Returns:
            ref_codes: Encoded reference codes.
        """
        cache_key = f"{voice_name}_{audio_path}"
        # NOTE(review): on a cache hit the ref_text argument is ignored; a
        # changed ref_text for the same key is not updated -- confirm intended.
        if cache_key not in self._ref_cache:
            ref_codes = self.encode_reference(audio_path)
            self._ref_cache[cache_key] = {
                'codes': ref_codes,
                'ref_text': ref_text
            }
        return self._ref_cache[cache_key]['codes']

    def add_speaker(self, user_id: int, audio_file: str, ref_text: str):
        """
        Add a speaker to the stored dictionary for easy access.

        Args:
            user_id: Unique user ID.
            audio_file: Reference audio file path.
            ref_text: Reference text.

        Returns:
            user_id: The user ID for use in streaming.
        """
        codes = self.encode_reference(audio_file)
        # Normalize codes to a flat Python list of ints before storing.
        if isinstance(codes, torch.Tensor):
            codes = codes.cpu().numpy()
        if isinstance(codes, np.ndarray):
            codes = codes.flatten().tolist()
        self.stored_dict[f"{user_id}"]['codes'] = codes
        self.stored_dict[f"{user_id}"]['ref_text'] = ref_text
        return user_id

    def _decode(self, codes: str):
        """Decode a ``<|speech_N|>`` token string to an audio waveform.

        Raises:
            ValueError: If no speech tokens are present.
        """
        speech_ids = [int(num) for num in re.findall(r"<\|speech_(\d+)\|>", codes)]
        if len(speech_ids) == 0:
            raise ValueError("No valid speech tokens found in output")
        if self._is_onnx_codec:
            codes = np.array(speech_ids, dtype=np.int32)[np.newaxis, np.newaxis, :]
            recon = self.codec.decode_code(codes)
        else:
            with torch.no_grad():
                codes = torch.tensor(speech_ids, dtype=torch.long)[None, None, :].to(
                    self.codec.device
                )
                recon = self.codec.decode_code(codes).cpu().numpy()
        return recon[0, 0, :]

    def _decode_batch(self, codes_list: list[str], max_workers: int = None):
        """
        Decode multiple code strings in parallel.

        Args:
            codes_list: List of code strings to decode.
            max_workers: Number of parallel workers (auto-tuned if None).

        Returns:
            List of decoded audio arrays (input order preserved).
        """
        # Auto-tune workers based on GPU memory and batch size
        if max_workers is None:
            if torch.cuda.is_available():
                gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
                # 1 worker per 4GB VRAM, max 4 workers
                max_workers = min(max(1, int(gpu_mem_gb / 4)), 4)
            else:
                max_workers = 2
        # For small batches, use sequential to avoid overhead
        if len(codes_list) <= 2:
            return [self._decode(codes) for codes in codes_list]
        # Parallel decoding with controlled workers
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(self._decode, codes) for codes in codes_list]
            results = [f.result() for f in futures]
        return results

    def _format_prompt(self, ref_codes: list[int], ref_text: str, input_text: str) -> str:
        """Format the text prompt for LMDeploy: phonemized reference+input
        text followed by the reference speech codes."""
        ref_text_phones = phonemize_with_dict(ref_text)
        input_text_phones = phonemize_with_dict(input_text)
        codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
        prompt = (
            f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text_phones} {input_text_phones}"
            f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
        )
        return prompt

    def infer(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> np.ndarray:
        """
        Single inference.

        Args:
            text: Input text to synthesize.
            ref_codes: Encoded reference audio codes.
            ref_text: Reference text for reference audio.

        Returns:
            Generated speech waveform as numpy array.
        """
        # Normalize reference codes to a flat Python list of ints.
        if isinstance(ref_codes, torch.Tensor):
            ref_codes = ref_codes.cpu().numpy()
        if isinstance(ref_codes, np.ndarray):
            ref_codes = ref_codes.flatten().tolist()
        prompt = self._format_prompt(ref_codes, ref_text, text)
        # Use LMDeploy pipeline for generation
        responses = self.backbone([prompt], gen_config=self.gen_config, do_preprocess=False)
        output_str = responses[0].text
        # Decode to audio
        wav = self._decode(output_str)
        return wav

    def infer_batch(self, texts: list[str], ref_codes: np.ndarray | torch.Tensor, ref_text: str, max_batch_size: int = None) -> list[np.ndarray]:
        """
        Batch inference for multiple texts.

        Args:
            texts: Texts to synthesize (a single string is accepted too).
            ref_codes: Encoded reference audio codes.
            ref_text: Reference text for reference audio.
            max_batch_size: Per-call override of ``self.max_batch_size``.

        Returns:
            List of generated waveforms, one per input text.
        """
        if max_batch_size is None:
            max_batch_size = self.max_batch_size
        if not isinstance(texts, list):
            texts = [texts]
        # Normalize reference codes to a flat Python list of ints.
        if isinstance(ref_codes, torch.Tensor):
            ref_codes = ref_codes.cpu().numpy()
        if isinstance(ref_codes, np.ndarray):
            ref_codes = ref_codes.flatten().tolist()
        all_wavs = []
        for i in range(0, len(texts), max_batch_size):
            batch_texts = texts[i:i+max_batch_size]
            prompts = [self._format_prompt(ref_codes, ref_text, text) for text in batch_texts]
            responses = self.backbone(prompts, gen_config=self.gen_config, do_preprocess=False)
            batch_codes = [response.text for response in responses]
            # Parallel decoding only pays off for larger batches.
            if len(batch_codes) > 3:
                batch_wavs = self._decode_batch(batch_codes)
            else:
                batch_wavs = [self._decode(codes) for codes in batch_codes]
            all_wavs.extend(batch_wavs)
            # Free GPU memory between batches (skipped after the final one).
            if i + max_batch_size < len(texts):
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
        return all_wavs

    def infer_stream(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> Generator[np.ndarray, None, None]:
        """
        Streaming inference with low latency.

        Args:
            text: Input text to synthesize.
            ref_codes: Encoded reference audio codes.
            ref_text: Reference text for reference audio.

        Yields:
            Audio chunks as numpy arrays.
        """
        # Normalize reference codes to a flat Python list of ints.
        if isinstance(ref_codes, torch.Tensor):
            ref_codes = ref_codes.cpu().numpy()
        if isinstance(ref_codes, np.ndarray):
            ref_codes = ref_codes.flatten().tolist()
        prompt = self._format_prompt(ref_codes, ref_text, text)
        audio_cache = []  # decoded chunks fed to overlap-add
        # Seed with reference codes so early chunks have lookback context.
        token_cache = [f"<|speech_{idx}|>" for idx in ref_codes]
        n_decoded_samples = 0  # samples already yielded to the caller
        n_decoded_tokens = len(ref_codes)  # tokens already rendered
        for response in self.backbone.stream_infer([prompt], gen_config=self.gen_config, do_preprocess=False):
            output_str = response.text
            # Extract new tokens
            # NOTE(review): this assumes response.text is cumulative, and the
            # join over token_cache grows every step (quadratic cost) --
            # confirm the lmdeploy stream semantics and consider tracking the
            # consumed length instead.
            new_tokens = output_str[len("".join(token_cache[len(ref_codes):])):] if len(token_cache) > len(ref_codes) else output_str
            if new_tokens:
                token_cache.append(new_tokens)
            # Check if we have enough tokens to decode a chunk
            if len(token_cache[n_decoded_tokens:]) >= self.streaming_frames_per_chunk + self.streaming_lookforward:
                # Decode chunk with context: lookback + chunk + lookforward (+ overlap)
                tokens_start = max(
                    n_decoded_tokens - self.streaming_lookback - self.streaming_overlap_frames,
                    0
                )
                tokens_end = (
                    n_decoded_tokens
                    + self.streaming_frames_per_chunk
                    + self.streaming_lookforward
                    + self.streaming_overlap_frames
                )
                # Sample offsets of the kept chunk inside the decoded window.
                sample_start = (n_decoded_tokens - tokens_start) * self.hop_length
                sample_end = (
                    sample_start
                    + (self.streaming_frames_per_chunk + 2 * self.streaming_overlap_frames) * self.hop_length
                )
                curr_codes = token_cache[tokens_start:tokens_end]
                recon = self._decode("".join(curr_codes))
                recon = recon[sample_start:sample_end]
                audio_cache.append(recon)
                # Overlap-add processing; yield only the newly produced samples.
                processed_recon = _linear_overlap_add(
                    audio_cache, stride=self.streaming_stride_samples
                )
                new_samples_end = len(audio_cache) * self.streaming_stride_samples
                processed_recon = processed_recon[n_decoded_samples:new_samples_end]
                n_decoded_samples = new_samples_end
                n_decoded_tokens += self.streaming_frames_per_chunk
                yield processed_recon
        # Final chunk (variable size, so handled outside the loop)
        remaining_tokens = len(token_cache) - n_decoded_tokens
        if remaining_tokens > 0:
            tokens_start = max(
                len(token_cache) - (self.streaming_lookback + self.streaming_overlap_frames + remaining_tokens),
                0
            )
            sample_start = (
                len(token_cache) - tokens_start - remaining_tokens - self.streaming_overlap_frames
            ) * self.hop_length
            curr_codes = token_cache[tokens_start:]
            recon = self._decode("".join(curr_codes))
            recon = recon[sample_start:]
            audio_cache.append(recon)
            processed_recon = _linear_overlap_add(audio_cache, stride=self.streaming_stride_samples)
            processed_recon = processed_recon[n_decoded_samples:]
            yield processed_recon

    def cleanup_memory(self):
        """Clean up GPU memory"""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        print("🧹 Memory cleaned up")

    def get_optimization_stats(self) -> dict:
        """
        Get current optimization statistics.

        Returns:
            Dictionary with optimization info.
        """
        return {
            'triton_enabled': self._triton_enabled,
            'max_batch_size': self.max_batch_size,
            'cached_references': len(self._ref_cache),
            'active_sessions': len(self.stored_dict),
            # NOTE(review): quant_policy was passed to TurbomindEngineConfig,
            # not GenerationConfig, so this lookup likely always yields the
            # default 0 -- verify against lmdeploy's GenerationConfig fields.
            'kv_quant': self.gen_config.__dict__.get('quant_policy', 0),
            'prefix_caching': True,  # Always enabled in our config
        }