"""Generate video with a RAM-quantized LTX-2.3 model via ltx-2-mlx.

The script loads the mixed-precision model directory produced by
reformat_ltx_for_pipeline.py into dgrauet/ltx-2-mlx's DistilledPipeline,
replacing the default apply_quantization with a per-layer mixed-precision
version that handles our variable-bits allocation.

Usage:
    python experiments/flux_phase1/generate_ltx.py \\
        --model-dir results/ltx-2.3/model_dir_12gb/ \\
        --prompt "A cat walks through a field of flowers" \\
        --output output.mp4

    # Specify resolution and frame count
    python experiments/flux_phase1/generate_ltx.py \\
        --model-dir results/ltx-2.3/model_dir_12gb/ \\
        --prompt "Ocean waves at sunset" \\
        --height 480 --width 704 --num-frames 97 \\
        --output ocean.mp4

    # Quiet (no progress output)
    python experiments/flux_phase1/generate_ltx.py \\
        --model-dir results/ltx-2.3/model_dir_12gb/ \\
        --prompt "..." --output out.mp4 --quiet
"""

import argparse
import contextlib
import os
import sys
from collections import defaultdict
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn

# Bit widths accepted by the detection step below; anything else is treated
# as "could not determine" and the layer is left unquantized.
_SUPPORTED_BITS = (2, 3, 4, 5, 6, 8)

# Files that must be present in --model-dir before the pipeline is loaded.
_REQUIRED_FILES = (
    "transformer-distilled.safetensors",
    "connector.safetensors",
    "vae_decoder.safetensors",
    "audio_vae.safetensors",
    "vocoder.safetensors",
)


def apply_mixed_precision_quantization(
    model: nn.Module,
    weights: dict,
    group_size: int = 64,
) -> None:
    """Per-layer mixed-precision quantization from a weight dict.

    Unlike ltx_core_mlx's apply_quantization (which uses a single detected
    bit width for all layers), this version detects each layer's bits from
    its packed weight shape and applies nn.quantize once per unique bit
    width.

    Layers that have .scales but whose bits can't be determined are skipped
    (kept as nn.Linear — they will fail at load_weights if shapes mismatch,
    which surfaces any genuine key errors). That includes layers whose
    scales tensor has a zero-sized last axis, and every layer when
    ``group_size`` is not positive, instead of raising ZeroDivisionError.

    Args:
        model: Module tree to quantize in place.
        weights: Mapping from parameter key to array; only ``.shape`` of the
            ``<layer>.weight`` / ``<layer>.scales`` entries is read here.
        group_size: Quantization group size used both for bit detection and
            for the ``nn.quantize`` calls.
    """
    suffix = ".scales"
    layer_bits: dict[str, int] = {}
    for key in weights:
        if not key.endswith(suffix):
            continue
        layer = key[: -len(suffix)]
        w_key = layer + ".weight"
        if w_key not in weights:
            continue
        w_cols = weights[w_key].shape[-1]
        s_cols = weights[key].shape[-1]
        denom = s_cols * group_size
        if denom <= 0:
            # Degenerate shapes: bit width is undefined, so skip the layer.
            continue
        # w_cols counts packed 32-bit words; s_cols * group_size is the
        # unpacked input width, so the ratio is the bits per element.
        bits = round(w_cols * 32 / denom)
        if bits in _SUPPORTED_BITS:
            layer_bits[layer] = bits

    if not layer_bits:
        return

    bits_to_layers: dict[int, set] = defaultdict(set)
    for layer, b in layer_bits.items():
        bits_to_layers[b].add(layer)

    for bits, layers in sorted(bits_to_layers.items()):
        # Bind the current set as a default so each predicate keeps its own.
        def _predicate(path: str, module: nn.Module, _layers=layers) -> bool:
            return path in _layers and isinstance(module, nn.Linear)

        nn.quantize(
            model, group_size=group_size, bits=bits, class_predicate=_predicate
        )

    # Counts are of layers detected from the weight dict, grouped by bits.
    total = sum(len(v) for v in bits_to_layers.values())
    dist = {b: len(v) for b, v in sorted(bits_to_layers.items())}
    print(f" Mixed-precision quantization: {total} layers — {dist}", flush=True)


def _patch_pipeline_quantization(group_size: int = 64):
    """Monkeypatch ltx_core_mlx to use our mixed-precision quantizer.

    Replaces ``ltx_core_mlx.utils.weights.apply_quantization`` and, when the
    orchestration module holds its own reference, that one as well.

    Args:
        group_size: Default group size for the patched function when the
            caller does not pass one.
    """
    import ltx_core_mlx.utils.weights as wm
    import ltx_pipelines_mlx.utils._orchestration as orch

    def _patched_apply(model, weights, group_size=group_size, bits=None):
        # ``bits`` is accepted for call compatibility but ignored: the bit
        # width is detected per layer from the weight shapes.
        apply_mixed_precision_quantization(model, weights, group_size)

    wm.apply_quantization = _patched_apply
    # Also patch the reference in _orchestration (it imports apply_quantization
    # at module level in some builds)
    if hasattr(orch, "apply_quantization"):
        orch.apply_quantization = _patched_apply


@contextlib.contextmanager
def _silence_output(quiet: bool):
    """Redirect Python-level stdout/stderr to the null device when *quiet*.

    Exceptions raised inside the block still propagate; by the time a
    traceback is printed the real streams have been restored. Output written
    directly to the OS file descriptors by native code is not affected.
    """
    if not quiet:
        yield
        return
    with open(os.devnull, "w") as sink:
        with contextlib.redirect_stdout(sink), contextlib.redirect_stderr(sink):
            yield


def _parse_args():
    """Build the command-line parser and parse ``sys.argv``."""
    p = argparse.ArgumentParser(
        description="Generate video with RAM-quantized LTX-2.3."
    )
    p.add_argument("--model-dir", required=True,
                   help="Model directory from reformat_ltx_for_pipeline.py")
    p.add_argument("--prompt", required=True,
                   help="Text prompt for video generation")
    p.add_argument("--output", default="output.mp4",
                   help="Output video path")
    p.add_argument("--height", type=int, default=480)
    p.add_argument("--width", type=int, default=704)
    p.add_argument("--num-frames", type=int, default=97)
    p.add_argument("--frame-rate", type=float, default=24.0)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--stage1-steps", type=int, default=None)
    p.add_argument("--stage2-steps", type=int, default=None)
    p.add_argument("--gemma-model", default="mlx-community/gemma-3-12b-it-4bit",
                   help="Gemma model ID for text encoding")
    p.add_argument("--low-memory", action="store_true", default=True,
                   help="Aggressive memory management (default: on)")
    p.add_argument("--no-low-memory", dest="low_memory", action="store_false")
    p.add_argument("--quiet", action="store_true",
                   help="Suppress pipeline progress output")
    return p.parse_args()


def main():
    """Validate the model directory, then generate, decode and save a video.

    Exits with status 1 (message on stderr) when the model directory is
    missing, is not a directory, or lacks any required file. With
    ``--quiet``, status messages and pipeline output written through
    Python's stdout/stderr are suppressed; validation errors and uncaught
    exceptions are still reported.
    """
    args = _parse_args()

    model_dir = Path(args.model_dir)
    # is_dir() rather than exists(): a regular file would otherwise fall
    # through to a misleading "missing files" report.
    if not model_dir.is_dir():
        print(
            f"Error: model dir {model_dir} not found or not a directory",
            file=sys.stderr,
        )
        sys.exit(1)

    missing = [f for f in _REQUIRED_FILES if not (model_dir / f).exists()]
    if missing:
        print(f"Error: missing files in {model_dir}: {missing}", file=sys.stderr)
        print("Run reformat_ltx_for_pipeline.py first.", file=sys.stderr)
        sys.exit(1)

    # Everything below runs under the (optional) silencer, so the prints here
    # and whatever the pipeline prints are dropped together in quiet mode.
    with _silence_output(args.quiet):
        # Patch before importing the pipeline
        print("Patching quantization to use per-layer mixed-precision…")
        _patch_pipeline_quantization()

        from ltx_pipelines_mlx.distilled import DistilledPipeline

        print(f"Loading pipeline from {model_dir}…")
        pipeline = DistilledPipeline(
            model_dir=str(model_dir),
            gemma_model_id=args.gemma_model,
            low_memory=args.low_memory,
        )

        print(f"\nGenerating: '{args.prompt}'")
        print(
            f" Resolution: {args.height}×{args.width}, "
            f"{args.num_frames} frames @ {args.frame_rate} fps"
        )

        video_latent, audio_latent = pipeline.generate_two_stage(
            prompt=args.prompt,
            height=args.height,
            width=args.width,
            num_frames=args.num_frames,
            frame_rate=args.frame_rate,
            seed=args.seed,
            stage1_steps=args.stage1_steps,
            stage2_steps=args.stage2_steps,
        )

        print(f"\nDecoding and saving → {args.output}")
        out = pipeline._decode_and_save_video(
            video_latent, audio_latent, args.output, frame_rate=args.frame_rate
        )
        print(f"Done: {out}")


if __name__ == "__main__":
    main()