""" ๐ต AudioGhost AI - Hugging Face Spaces ZeroGPU Edition AI-Powered Audio Separation using SAM-Audio This application is optimized for Hugging Face ZeroGPU deployment. GPU is dynamically allocated only during inference and released after. """ import gradio as gr import torch import torchaudio import gc import time import tempfile import os from pathlib import Path from typing import Tuple, Optional import types # ZeroGPU support - handle import gracefully try: import spaces ZEROGPU_AVAILABLE = True except ImportError: ZEROGPU_AVAILABLE = False print("[INFO] 'spaces' module not found. ZeroGPU decorator will be skipped.") # ========== CONFIGURATION ========== MODEL_SIZES = { "small": "facebook/sam-audio-small", "base": "facebook/sam-audio-base", "large": "facebook/sam-audio-large" } # Chunk settings for memory efficiency CHUNK_DURATION_MAP = { "small": 30.0, "base": 25.0, "large": 15.0 } # ========== MODEL CACHING ========== _model_cache = {} _processor_cache = {} def show_gpu_memory(label: str = "") -> str: """Show GPU memory stats""" if torch.cuda.is_available(): allocated = torch.cuda.memory_allocated() / 1024**3 reserved = torch.cuda.memory_reserved() / 1024**3 return f"[GPU {label}] Allocated: {allocated:.2f}GB | Reserved: {reserved:.2f}GB" return "[GPU] Not available" def create_lite_model(model_name: str): """ Create a memory-optimized SAM Audio model. Optimizations: - Removes vision_encoder (~2GB VRAM saved) - Removes visual_ranker (~2GB saved) - Removes text_ranker (~2GB saved) - Removes span_predictor (~1-2GB saved) Total: Reduces VRAM from ~11GB to ~4-5GB """ from sam_audio import SAMAudio, SAMAudioProcessor print(f"[LITE] Loading {model_name}...") # Load model model = SAMAudio.from_pretrained(model_name, torch_dtype=torch.float32) processor = SAMAudioProcessor.from_pretrained(model_name) print("[LITE] Optimizing model for low VRAM...") # Get vision encoder dim before deleting vision_dim = model.vision_encoder.dim if hasattr(model.vision_encoder, 'dim') else 1024 # Delete heavy components del model.vision_encoder gc.collect() print(" โ Removed vision_encoder") # Store the dim for _get_video_features model._vision_encoder_dim = vision_dim # Replace _get_video_features to not use vision_encoder def _get_video_features_lite(self, video, audio_features): B, T, _ = audio_features.shape return audio_features.new_zeros(B, self._vision_encoder_dim, T) model._get_video_features = types.MethodType(_get_video_features_lite, model) # Delete rankers components_to_remove = [ ('visual_ranker', 'visual_ranker'), ('text_ranker', 'text_ranker'), ('span_predictor', 'span_predictor'), ('span_predictor_transform', 'span_predictor_transform') ] for attr_name, display_name in components_to_remove: if hasattr(model, attr_name) and getattr(model, attr_name) is not None: delattr(model, attr_name) setattr(model, attr_name, None) gc.collect() print(f" โ Removed {display_name}") # Set to eval mode model = model.eval() print("[LITE] Model optimization complete!") return model, processor def get_or_load_model(model_size: str): """Get cached model or load new one""" global _model_cache, _processor_cache model_name = MODEL_SIZES.get(model_size, MODEL_SIZES["base"]) if model_name not in _model_cache: print(f"[CACHE] Loading new model: {model_name}") # Clear old models to save memory if len(_model_cache) > 0: print(f"[CACHE] Clearing {len(_model_cache)} old model(s)...") _model_cache.clear() _processor_cache.clear() gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() model, processor = create_lite_model(model_name) _model_cache[model_name] = model _processor_cache[model_name] = processor else: print(f"[CACHE] Using cached model: {model_name}") return _model_cache[model_name], _processor_cache[model_name] def process_audio_chunks( audio_tensor: torch.Tensor, sample_rate: int, model, processor, description: str, device: str, dtype: torch.dtype, chunk_duration: float, progress_callback=None ) -> Tuple[torch.Tensor, torch.Tensor]: """Process audio in chunks for memory efficiency""" max_chunk_samples = int(sample_rate * chunk_duration) audio_1d = audio_tensor.squeeze(0) if audio_1d.shape[-1] <= max_chunk_samples: # Process as single batch if progress_callback: progress_callback(0.5, "Processing audio...") batch = processor( audios=[audio_1d.unsqueeze(0)], descriptions=[description] ).to(device) with torch.inference_mode(): with torch.amp.autocast(device_type="cuda", enabled=(device == "cuda")): result = model.separate( batch, predict_spans=False, reranking_candidates=1 ) target = result.target[0].cpu() residual = result.residual[0].cpu() del batch, result return target, residual # Chunked processing audio_on_device = audio_1d.to(device, dtype) chunks = torch.split(audio_on_device, max_chunk_samples, dim=-1) total_chunks = len(chunks) print(f"[CHUNKS] Processing {total_chunks} chunks of {chunk_duration}s each") out_target = [] out_residual = [] for i, chunk in enumerate(chunks): if progress_callback: chunk_progress = 0.3 + (i / total_chunks) * 0.5 progress_callback(chunk_progress, f"Processing chunk {i+1}/{total_chunks}...") # Skip very short chunks (< 1 second) if chunk.shape[-1] < sample_rate: print(f"[CHUNKS] Skipping chunk {i+1} (too short)") continue batch = processor( audios=[chunk.unsqueeze(0)], descriptions=[description] ).to(device) with torch.inference_mode(): with torch.amp.autocast(device_type="cuda", enabled=(device == "cuda")): result = model.separate( batch, predict_spans=False, reranking_candidates=1 ) out_target.append(result.target[0].cpu()) out_residual.append(result.residual[0].cpu()) del batch, result torch.cuda.empty_cache() # Concatenate all chunks target = torch.cat(out_target, dim=-1) residual = torch.cat(out_residual, dim=-1) del out_target, out_residual, chunks, audio_on_device return target, residual # ========== MAIN GPU FUNCTION (ZeroGPU) ========== # Conditional decorator - only apply if spaces module is available def gpu_decorator(func): """Apply @spaces.GPU decorator if available, otherwise return function as-is""" if ZEROGPU_AVAILABLE: return spaces.GPU(duration=180)(func) return func @gpu_decorator def separate_audio_gpu( audio_path: str, description: str, model_size: str = "base", progress=gr.Progress() ) -> Tuple[Optional[str], Optional[str], str]: """ Main separation function decorated with @spaces.GPU for ZeroGPU. GPU is allocated when this function is called and released when it returns. This enables cost-effective GPU usage on Hugging Face Spaces. Args: audio_path: Path to input audio file description: Text description of sound to separate model_size: Model size (small/base/large) progress: Gradio progress callback Returns: Tuple of (ghost_audio_path, clean_audio_path, status_message) """ start_time = time.time() # Setup device device = "cuda" if torch.cuda.is_available() else "cpu" dtype = torch.bfloat16 if device == "cuda" else torch.float32 print(f"\n{'='*60}") print(f"[ZeroGPU] Starting separation") print(f"[ZeroGPU] Device: {device}, Dtype: {dtype}") print(f"[ZeroGPU] Model: {model_size}") print(f"[ZeroGPU] Description: '{description}'") print(show_gpu_memory("Initial")) print(f"{'='*60}\n") try: progress(0.1, "Loading model...") # Get or load model model, processor = get_or_load_model(model_size) # Move model to GPU model = model.to(device, dtype) print(show_gpu_memory("After model load")) progress(0.2, "Loading audio...") # Load and preprocess audio sample_rate = processor.audio_sampling_rate audio, orig_sr = torchaudio.load(audio_path) if orig_sr != sample_rate: resampler = torchaudio.transforms.Resample(orig_sr, sample_rate) audio = resampler(audio) # Convert to mono if stereo if audio.shape[0] > 1: audio = audio.mean(dim=0, keepdim=True) audio_duration = audio.shape[1] / sample_rate print(f"[ZeroGPU] Audio duration: {audio_duration:.2f}s") # Get chunk duration for this model size chunk_duration = CHUNK_DURATION_MAP.get(model_size, 25.0) progress(0.3, "Running separation...") # Process audio target, residual = process_audio_chunks( audio_tensor=audio, sample_rate=sample_rate, model=model, processor=processor, description=description, device=device, dtype=dtype, chunk_duration=chunk_duration, progress_callback=lambda p, m: progress(p, m) ) progress(0.85, "Saving results...") # Clamp and format output target_audio = target.clamp(-1, 1).float().unsqueeze(0) residual_audio = residual.clamp(-1, 1).float().unsqueeze(0) # Save to temp files ghost_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) clean_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) torchaudio.save(ghost_file.name, target_audio, sample_rate) torchaudio.save(clean_file.name, residual_audio, sample_rate) ghost_file.close() clean_file.close() # Cleanup - move model back to CPU and free GPU memory progress(0.95, "Cleaning up GPU...") model.cpu() model.to(torch.float32) del target, residual, target_audio, residual_audio, audio gc.collect() torch.cuda.empty_cache() print(show_gpu_memory("After cleanup")) processing_time = time.time() - start_time speed = audio_duration / processing_time if processing_time > 0 else 0 status = ( f"โ Separation complete!\n" f"๐ Audio: {audio_duration:.1f}s | " f"โฑ๏ธ Time: {processing_time:.1f}s | " f"โก Speed: {speed:.2f}x realtime" ) print(f"\n[ZeroGPU] {status}\n") progress(1.0, "Done!") return ghost_file.name, clean_file.name, status except Exception as e: # Cleanup on error try: if 'model' in dir(): model.cpu() except: pass gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() error_msg = f"โ Error: {str(e)}" print(f"[ZeroGPU] {error_msg}") return None, None, error_msg def process_separation(audio, description, model_size): """Wrapper function for Gradio interface""" if audio is None: return None, None, "โ ๏ธ Please upload an audio file" if not description or description.strip() == "": return None, None, "โ ๏ธ Please enter a description of the sound to separate" ghost_path, clean_path, status = separate_audio_gpu( audio_path=audio, description=description.strip(), model_size=model_size ) return ghost_path, clean_path, status # ========== GRADIO INTERFACE ========== EXAMPLE_PROMPTS = [ "singing voice", "human speech", "drums", "guitar", "piano", "bass", "background noise", "wind noise", "crowd noise", "birds chirping", "dog barking", "car engine" ] # Custom CSS for premium look custom_css = """ .gradio-container { font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif !important; } .main-title { text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-size: 3rem !important; font-weight: 800 !important; margin-bottom: 0.5rem !important; } .subtitle { text-align: center; color: #6b7280; font-size: 1.1rem; margin-bottom: 2rem; } .status-box { padding: 1rem; border-radius: 0.75rem; background: linear-gradient(135deg, #f0f9ff 0%, #e0f2fe 100%); border: 1px solid #bae6fd; font-family: 'JetBrains Mono', monospace; } .example-btn { border-radius: 9999px !important; font-size: 0.875rem !important; } footer { text-align: center; margin-top: 2rem; padding: 1rem; color: #9ca3af; font-size: 0.875rem; } """ # Build the Gradio UI with gr.Blocks( title="AudioGhost AI | AI Audio Separation", theme=gr.themes.Soft( primary_hue="purple", secondary_hue="blue", neutral_hue="slate" ), css=custom_css ) as demo: # Header gr.HTML("""
AI-Powered Audio Separation using SAM-Audio