""" ๐ŸŽต AudioGhost AI - Hugging Face Spaces ZeroGPU Edition AI-Powered Audio Separation using SAM-Audio This application is optimized for Hugging Face ZeroGPU deployment. GPU is dynamically allocated only during inference and released after. """ import gradio as gr import torch import torchaudio import gc import time import tempfile import os from pathlib import Path from typing import Tuple, Optional import types # ZeroGPU support - handle import gracefully try: import spaces ZEROGPU_AVAILABLE = True except ImportError: ZEROGPU_AVAILABLE = False print("[INFO] 'spaces' module not found. ZeroGPU decorator will be skipped.") # ========== CONFIGURATION ========== MODEL_SIZES = { "small": "facebook/sam-audio-small", "base": "facebook/sam-audio-base", "large": "facebook/sam-audio-large" } # Chunk settings for memory efficiency CHUNK_DURATION_MAP = { "small": 30.0, "base": 25.0, "large": 15.0 } # ========== MODEL CACHING ========== _model_cache = {} _processor_cache = {} def show_gpu_memory(label: str = "") -> str: """Show GPU memory stats""" if torch.cuda.is_available(): allocated = torch.cuda.memory_allocated() / 1024**3 reserved = torch.cuda.memory_reserved() / 1024**3 return f"[GPU {label}] Allocated: {allocated:.2f}GB | Reserved: {reserved:.2f}GB" return "[GPU] Not available" def create_lite_model(model_name: str): """ Create a memory-optimized SAM Audio model. Optimizations: - Removes vision_encoder (~2GB VRAM saved) - Removes visual_ranker (~2GB saved) - Removes text_ranker (~2GB saved) - Removes span_predictor (~1-2GB saved) Total: Reduces VRAM from ~11GB to ~4-5GB """ from sam_audio import SAMAudio, SAMAudioProcessor print(f"[LITE] Loading {model_name}...") # Load model model = SAMAudio.from_pretrained(model_name, torch_dtype=torch.float32) processor = SAMAudioProcessor.from_pretrained(model_name) print("[LITE] Optimizing model for low VRAM...") # Get vision encoder dim before deleting vision_dim = model.vision_encoder.dim if hasattr(model.vision_encoder, 'dim') else 1024 # Delete heavy components del model.vision_encoder gc.collect() print(" โœ“ Removed vision_encoder") # Store the dim for _get_video_features model._vision_encoder_dim = vision_dim # Replace _get_video_features to not use vision_encoder def _get_video_features_lite(self, video, audio_features): B, T, _ = audio_features.shape return audio_features.new_zeros(B, self._vision_encoder_dim, T) model._get_video_features = types.MethodType(_get_video_features_lite, model) # Delete rankers components_to_remove = [ ('visual_ranker', 'visual_ranker'), ('text_ranker', 'text_ranker'), ('span_predictor', 'span_predictor'), ('span_predictor_transform', 'span_predictor_transform') ] for attr_name, display_name in components_to_remove: if hasattr(model, attr_name) and getattr(model, attr_name) is not None: delattr(model, attr_name) setattr(model, attr_name, None) gc.collect() print(f" โœ“ Removed {display_name}") # Set to eval mode model = model.eval() print("[LITE] Model optimization complete!") return model, processor def get_or_load_model(model_size: str): """Get cached model or load new one""" global _model_cache, _processor_cache model_name = MODEL_SIZES.get(model_size, MODEL_SIZES["base"]) if model_name not in _model_cache: print(f"[CACHE] Loading new model: {model_name}") # Clear old models to save memory if len(_model_cache) > 0: print(f"[CACHE] Clearing {len(_model_cache)} old model(s)...") _model_cache.clear() _processor_cache.clear() gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() model, processor = create_lite_model(model_name) _model_cache[model_name] = model _processor_cache[model_name] = processor else: print(f"[CACHE] Using cached model: {model_name}") return _model_cache[model_name], _processor_cache[model_name] def process_audio_chunks( audio_tensor: torch.Tensor, sample_rate: int, model, processor, description: str, device: str, dtype: torch.dtype, chunk_duration: float, progress_callback=None ) -> Tuple[torch.Tensor, torch.Tensor]: """Process audio in chunks for memory efficiency""" max_chunk_samples = int(sample_rate * chunk_duration) audio_1d = audio_tensor.squeeze(0) if audio_1d.shape[-1] <= max_chunk_samples: # Process as single batch if progress_callback: progress_callback(0.5, "Processing audio...") batch = processor( audios=[audio_1d.unsqueeze(0)], descriptions=[description] ).to(device) with torch.inference_mode(): with torch.amp.autocast(device_type="cuda", enabled=(device == "cuda")): result = model.separate( batch, predict_spans=False, reranking_candidates=1 ) target = result.target[0].cpu() residual = result.residual[0].cpu() del batch, result return target, residual # Chunked processing audio_on_device = audio_1d.to(device, dtype) chunks = torch.split(audio_on_device, max_chunk_samples, dim=-1) total_chunks = len(chunks) print(f"[CHUNKS] Processing {total_chunks} chunks of {chunk_duration}s each") out_target = [] out_residual = [] for i, chunk in enumerate(chunks): if progress_callback: chunk_progress = 0.3 + (i / total_chunks) * 0.5 progress_callback(chunk_progress, f"Processing chunk {i+1}/{total_chunks}...") # Skip very short chunks (< 1 second) if chunk.shape[-1] < sample_rate: print(f"[CHUNKS] Skipping chunk {i+1} (too short)") continue batch = processor( audios=[chunk.unsqueeze(0)], descriptions=[description] ).to(device) with torch.inference_mode(): with torch.amp.autocast(device_type="cuda", enabled=(device == "cuda")): result = model.separate( batch, predict_spans=False, reranking_candidates=1 ) out_target.append(result.target[0].cpu()) out_residual.append(result.residual[0].cpu()) del batch, result torch.cuda.empty_cache() # Concatenate all chunks target = torch.cat(out_target, dim=-1) residual = torch.cat(out_residual, dim=-1) del out_target, out_residual, chunks, audio_on_device return target, residual # ========== MAIN GPU FUNCTION (ZeroGPU) ========== # Conditional decorator - only apply if spaces module is available def gpu_decorator(func): """Apply @spaces.GPU decorator if available, otherwise return function as-is""" if ZEROGPU_AVAILABLE: return spaces.GPU(duration=180)(func) return func @gpu_decorator def separate_audio_gpu( audio_path: str, description: str, model_size: str = "base", progress=gr.Progress() ) -> Tuple[Optional[str], Optional[str], str]: """ Main separation function decorated with @spaces.GPU for ZeroGPU. GPU is allocated when this function is called and released when it returns. This enables cost-effective GPU usage on Hugging Face Spaces. Args: audio_path: Path to input audio file description: Text description of sound to separate model_size: Model size (small/base/large) progress: Gradio progress callback Returns: Tuple of (ghost_audio_path, clean_audio_path, status_message) """ start_time = time.time() # Setup device device = "cuda" if torch.cuda.is_available() else "cpu" dtype = torch.bfloat16 if device == "cuda" else torch.float32 print(f"\n{'='*60}") print(f"[ZeroGPU] Starting separation") print(f"[ZeroGPU] Device: {device}, Dtype: {dtype}") print(f"[ZeroGPU] Model: {model_size}") print(f"[ZeroGPU] Description: '{description}'") print(show_gpu_memory("Initial")) print(f"{'='*60}\n") try: progress(0.1, "Loading model...") # Get or load model model, processor = get_or_load_model(model_size) # Move model to GPU model = model.to(device, dtype) print(show_gpu_memory("After model load")) progress(0.2, "Loading audio...") # Load and preprocess audio sample_rate = processor.audio_sampling_rate audio, orig_sr = torchaudio.load(audio_path) if orig_sr != sample_rate: resampler = torchaudio.transforms.Resample(orig_sr, sample_rate) audio = resampler(audio) # Convert to mono if stereo if audio.shape[0] > 1: audio = audio.mean(dim=0, keepdim=True) audio_duration = audio.shape[1] / sample_rate print(f"[ZeroGPU] Audio duration: {audio_duration:.2f}s") # Get chunk duration for this model size chunk_duration = CHUNK_DURATION_MAP.get(model_size, 25.0) progress(0.3, "Running separation...") # Process audio target, residual = process_audio_chunks( audio_tensor=audio, sample_rate=sample_rate, model=model, processor=processor, description=description, device=device, dtype=dtype, chunk_duration=chunk_duration, progress_callback=lambda p, m: progress(p, m) ) progress(0.85, "Saving results...") # Clamp and format output target_audio = target.clamp(-1, 1).float().unsqueeze(0) residual_audio = residual.clamp(-1, 1).float().unsqueeze(0) # Save to temp files ghost_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) clean_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) torchaudio.save(ghost_file.name, target_audio, sample_rate) torchaudio.save(clean_file.name, residual_audio, sample_rate) ghost_file.close() clean_file.close() # Cleanup - move model back to CPU and free GPU memory progress(0.95, "Cleaning up GPU...") model.cpu() model.to(torch.float32) del target, residual, target_audio, residual_audio, audio gc.collect() torch.cuda.empty_cache() print(show_gpu_memory("After cleanup")) processing_time = time.time() - start_time speed = audio_duration / processing_time if processing_time > 0 else 0 status = ( f"โœ… Separation complete!\n" f"๐Ÿ“Š Audio: {audio_duration:.1f}s | " f"โฑ๏ธ Time: {processing_time:.1f}s | " f"โšก Speed: {speed:.2f}x realtime" ) print(f"\n[ZeroGPU] {status}\n") progress(1.0, "Done!") return ghost_file.name, clean_file.name, status except Exception as e: # Cleanup on error try: if 'model' in dir(): model.cpu() except: pass gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() error_msg = f"โŒ Error: {str(e)}" print(f"[ZeroGPU] {error_msg}") return None, None, error_msg def process_separation(audio, description, model_size): """Wrapper function for Gradio interface""" if audio is None: return None, None, "โš ๏ธ Please upload an audio file" if not description or description.strip() == "": return None, None, "โš ๏ธ Please enter a description of the sound to separate" ghost_path, clean_path, status = separate_audio_gpu( audio_path=audio, description=description.strip(), model_size=model_size ) return ghost_path, clean_path, status # ========== GRADIO INTERFACE ========== EXAMPLE_PROMPTS = [ "singing voice", "human speech", "drums", "guitar", "piano", "bass", "background noise", "wind noise", "crowd noise", "birds chirping", "dog barking", "car engine" ] # Custom CSS for premium look custom_css = """ .gradio-container { font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif !important; } .main-title { text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-size: 3rem !important; font-weight: 800 !important; margin-bottom: 0.5rem !important; } .subtitle { text-align: center; color: #6b7280; font-size: 1.1rem; margin-bottom: 2rem; } .status-box { padding: 1rem; border-radius: 0.75rem; background: linear-gradient(135deg, #f0f9ff 0%, #e0f2fe 100%); border: 1px solid #bae6fd; font-family: 'JetBrains Mono', monospace; } .example-btn { border-radius: 9999px !important; font-size: 0.875rem !important; } footer { text-align: center; margin-top: 2rem; padding: 1rem; color: #9ca3af; font-size: 0.875rem; } """ # Build the Gradio UI with gr.Blocks( title="AudioGhost AI | AI Audio Separation", theme=gr.themes.Soft( primary_hue="purple", secondary_hue="blue", neutral_hue="slate" ), css=custom_css ) as demo: # Header gr.HTML("""

๐Ÿ‘ป AudioGhost AI

AI-Powered Audio Separation using SAM-Audio

""") # Main content with gr.Row(equal_height=True): # Left column - Input with gr.Column(scale=1): gr.Markdown("### ๐Ÿ“ค Input") audio_input = gr.Audio( label="Upload Audio File", type="filepath", sources=["upload", "microphone"], format="wav" ) description_input = gr.Textbox( label="Sound Description", placeholder="Describe the sound you want to extract (e.g., 'singing voice', 'drums', 'background noise')", lines=2, max_lines=4 ) # Example prompts gr.Markdown("**Quick Examples:**") with gr.Row(): for prompt in EXAMPLE_PROMPTS[:6]: gr.Button( prompt, size="sm", variant="secondary" ).click( lambda p=prompt: p, outputs=description_input ) with gr.Row(): for prompt in EXAMPLE_PROMPTS[6:]: gr.Button( prompt, size="sm", variant="secondary" ).click( lambda p=prompt: p, outputs=description_input ) model_choice = gr.Radio( choices=[ ("๐Ÿš€ Small (Fastest)", "small"), ("โš–๏ธ Base (Balanced)", "base"), ("๐ŸŽฏ Large (Best Quality)", "large") ], value="base", label="Model Size", info="Larger models = better quality but slower processing" ) submit_btn = gr.Button( "๐ŸŽต Separate Audio", variant="primary", size="lg" ) # Right column - Output with gr.Column(scale=1): gr.Markdown("### ๐Ÿ“ฅ Output") ghost_output = gr.Audio( label="๐Ÿ‘ป Extracted Audio (Ghost)", type="filepath", interactive=False ) clean_output = gr.Audio( label="๐Ÿงน Residual Audio (Clean)", type="filepath", interactive=False ) status_output = gr.Textbox( label="Status", interactive=False, lines=3, elem_classes=["status-box"] ) # Connect the button submit_btn.click( fn=process_separation, inputs=[audio_input, description_input, model_choice], outputs=[ghost_output, clean_output, status_output] ) # Footer gr.HTML(""" """) # Info accordion with gr.Accordion("โ„น๏ธ About AudioGhost AI", open=False): gr.Markdown(""" ## How It Works AudioGhost AI uses **SAM-Audio (Segment Anything Model for Audio)** from Meta AI Research to separate specific sounds from audio files based on text descriptions. ### Features: - ๐ŸŽฏ **Text-Guided Separation**: Describe what you want to extract - ๐Ÿš€ **Multiple Model Sizes**: Choose between speed and quality - ๐Ÿ’พ **Memory Optimized**: Processes long audio in chunks - โšก **ZeroGPU Powered**: Dynamic GPU allocation for efficiency ### Tips for Best Results: 1. Be specific in your description (e.g., "female singing voice" vs just "voice") 2. Use the Base model for most use cases 3. For very long audio (>5 min), the Large model may timeout ### Limitations: - Maximum audio length depends on model size and complexity - Some complex mixtures may not separate perfectly - GPU time is limited to 180 seconds per request """) # Launch configuration if __name__ == "__main__": demo.queue(max_size=10) demo.launch( server_name="0.0.0.0", server_port=7860, share=False )