# Copyright 2026 Sam McLeod
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
"""Export the LLM embed_tokens table as a tiny single-Gather ONNX graph.

Every LLM-side graph in this repo (`prompt_encode`, `decode_step`, `editor`)
takes `inputs_embeds [..., 2048]` rather than `input_ids`. The Python export
path runs `model.llm.model.embed_tokens(text_ids)` upstream of the graph and
splices in audio embeddings at `<|audio|>` positions. Without the embedding
table itself, downstream consumers (parakeet-rs, CoreML, plain ONNX Runtime)
cannot run end-to-end.

This script ships the embedding lookup as its own ONNX graph in three tiers:

  - fp32 : Gather(weight_fp32, input_ids)                               822 MB
  - fp16w: Gather(weight_fp16, input_ids) -> Cast(FP16->FP32)           411 MB
  - int8 : Gather(weight_int8) -> Cast -> Mul(per-row scale, gathered)  206 MB

All three produce FP32 [B, N, 2048] output. The fp32 tier is expected to match
the PyTorch embedding exactly; fp16w and int8 match within a tight numeric
tolerance (measured and written to the per-variant parity report). Op set is
opset 20 / IR 10 / `ai.onnx`-only, matching the rest of the bundle.

Per-variant model class differs:

  - base / plus : `model.language_model.model.embed_tokens`
                  (AutoModelForSpeechSeq2Seq)
  - nar         : `model.llm.model.embed_tokens`
                  (upstream NAR class via AutoModel + trust_remote_code)

Usage:
    HF_HOME=$TMPDIR/hf_home HF_MODULES_CACHE=$TMPDIR/hf_modules \\
        uv run python src/export_embed_tokens.py --variant base
    uv run python src/export_embed_tokens.py --variant plus
    uv run python src/export_embed_tokens.py --variant nar
    uv run python src/export_embed_tokens.py --variant all
"""

from __future__ import annotations

import argparse
import json
import time
from pathlib import Path
from typing import Any

import numpy as np
import onnx
import torch
import torch.nn as nn
from onnx import TensorProto, helper, numpy_helper

REPO_ROOT = Path(__file__).resolve().parents[1]
MODELS = REPO_ROOT / "models"
EXPORTS = REPO_ROOT / "exports"

# Per-variant checkpoint location, output directory, loader kind, and the
# dotted attribute path (relative to the loaded model) of the embedding module.
VARIANT_CONFIG = {
    "base": {
        "model_dir": MODELS / "granite-speech-4.1-2b",
        "out_dir": EXPORTS / "granite-speech-4.1-2b",
        "loader": "speech_seq2seq",
        "embed_attr": "language_model.model.embed_tokens",
    },
    "plus": {
        "model_dir": MODELS / "granite-speech-4.1-2b-plus",
        "out_dir": EXPORTS / "granite-speech-4.1-2b-plus",
        "loader": "speech_seq2seq",
        "embed_attr": "language_model.model.embed_tokens",
    },
    "nar": {
        "model_dir": MODELS / "granite-speech-4.1-2b-nar",
        "out_dir": EXPORTS / "granite-speech-4.1-2b-nar",
        "loader": "nar",
        "embed_attr": "llm.model.embed_tokens",
    },
}


# ---------------------------------------------------------------------------
# Loaders
# ---------------------------------------------------------------------------


def _load_speech_seq2seq(model_dir: Path) -> nn.Module:
    """Load a base/plus checkpoint as AutoModelForSpeechSeq2Seq (eager attention, FP32, eval mode)."""
    from transformers import AutoModelForSpeechSeq2Seq

    print(f" loading AutoModelForSpeechSeq2Seq from {model_dir} (eager, fp32)")
    t0 = time.time()
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        str(model_dir),
        torch_dtype=torch.float32,
        attn_implementation="eager",
    )
    model.eval()
    model = model.to(torch.float32)
    print(f" loaded in {time.time() - t0:.1f}s")
    return model


def _load_nar(model_dir: Path) -> nn.Module:
    """NAR uses the upstream NLE class loaded via `trust_remote_code`.

    The config nests `attn_implementation=flash_attention_2` which fails on
    this machine; we override it to eager up-front, matching
    `export_nar_editor.py`.

    Raises:
        FileNotFoundError: if the local `granite-4.0-1b-base` checkout that the
            config's `llm_name` is pointed at does not exist.
    """
    from transformers import AutoConfig, AutoModel

    granite_local = MODELS / "granite-4.0-1b-base"
    if not granite_local.exists():
        raise FileNotFoundError(
            f"Expected local Granite 4.0 base at {granite_local}; "
            "run `hf download ibm-granite/granite-4.0-1b-base ...` first."
        )

    config = AutoConfig.from_pretrained(str(model_dir), trust_remote_code=True)
    config.llm_name = str(granite_local)
    config.attn_implementation = "eager"
    config._attn_implementation = "eager"
    # The attention choice is also nested in sub-configs; force eager there too.
    for sub_attr in ("llm_config", "encoder_config", "projector_config"):
        sub = getattr(config, sub_attr, None)
        if sub is None:
            continue
        for attr in ("attn_implementation", "_attn_implementation"):
            try:
                setattr(sub, attr, "eager")
            except Exception:
                # Best-effort: a sub-config may refuse one of these attributes;
                # skip it rather than abort the whole export.
                pass

    print(f" loading AutoModel (trust_remote_code, eager) from {model_dir}")
    t0 = time.time()
    model = AutoModel.from_pretrained(
        str(model_dir),
        trust_remote_code=True,
        torch_dtype=torch.float32,
        attn_implementation="eager",
        config=config,
    )
    model.eval()
    model = model.to(torch.float32)
    print(f" loaded in {time.time() - t0:.1f}s")
    return model


def _resolve_attr(obj: Any, dotted: str) -> Any:
    """Follow a dotted attribute path (e.g. "llm.model.embed_tokens") from `obj`."""
    cur = obj
    for part in dotted.split("."):
        cur = getattr(cur, part)
    return cur


def load_embed_tokens(variant: str) -> tuple[np.ndarray, dict[str, Any]]:
    """Load `variant`'s checkpoint and return its embedding weight as FP32 NumPy.

    Returns:
        (weight, info): `weight` is the 2-D `embed_tokens.weight` table
        ([vocab_size, hidden_size], float32); `info` records the variant, the
        attribute path used and the table dimensions.

    Raises:
        KeyError: `variant` is not in VARIANT_CONFIG.
        ValueError: the variant names an unknown loader.
    """
    cfg = VARIANT_CONFIG[variant]
    if cfg["loader"] == "speech_seq2seq":
        model = _load_speech_seq2seq(cfg["model_dir"])
    elif cfg["loader"] == "nar":
        model = _load_nar(cfg["model_dir"])
    else:
        raise ValueError(f"unknown loader: {cfg['loader']}")

    embed = _resolve_attr(model, cfg["embed_attr"])
    weight = embed.weight.detach().to(torch.float32).cpu().numpy()
    info = {
        "variant": variant,
        "embed_attr": cfg["embed_attr"],
        "vocab_size": int(weight.shape[0]),
        "hidden_size": int(weight.shape[1]),
    }
    print(f" embed_tokens.weight shape={list(weight.shape)} dtype={weight.dtype}")
    return weight, info


# ---------------------------------------------------------------------------
# Graph builders
# ---------------------------------------------------------------------------


def _make_input(name: str, dtype: int, shape: list[Any]) -> onnx.ValueInfoProto:
    """Build a tensor ValueInfoProto; used for both graph inputs and outputs."""
    return helper.make_tensor_value_info(name, dtype, shape)


def _sidecar_path(out_path: Path) -> Path:
    """Return the external-data sidecar path (`<model file name>_data`) next to `out_path`."""
    return out_path.with_name(out_path.name + "_data")


def _save_with_single_sidecar(model: onnx.ModelProto, out_path: Path) -> None:
    """Save `model` to `out_path` with all initializers >= 1 KiB in one sidecar, then check it.

    Any existing model file and sidecar at the destination are deleted first:
    onnx's external-data writer appends tensors to an existing sidecar file
    rather than truncating it, so a stale sidecar would keep growing.
    """
    sidecar = _sidecar_path(out_path)
    sidecar.unlink(missing_ok=True)
    out_path.unlink(missing_ok=True)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    onnx.save_model(
        model,
        str(out_path),
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=sidecar.name,
        size_threshold=1024,
        convert_attribute=False,
    )
    # Validate the file as written to disk (model + its external-data reference).
    onnx.checker.check_model(str(out_path), full_check=False)


def _opset_imports() -> list[onnx.OperatorSetIdProto]:
    """Default-domain (`ai.onnx`) opset 20 only."""
    return [helper.make_opsetid("", 20)]


def _producer() -> dict[str, str]:
    """Producer metadata passed as keyword arguments to `helper.make_model`."""
    return {
        "producer_name": "granite-speech-4.1-onnx",
        "producer_version": "embed_tokens-1",
    }


def build_fp32(weight: np.ndarray) -> onnx.ModelProto:
    """Single Gather op over the FP32 weight.

    Graph: input_ids INT64 [B, N] -> inputs_embeds FP32 [B, N, hidden].
    """
    weight_init = numpy_helper.from_array(
        weight.astype(np.float32, copy=False), name="embed_tokens.weight"
    )
    input_ids = _make_input("input_ids", TensorProto.INT64, ["B", "N"])
    output = _make_input("inputs_embeds", TensorProto.FLOAT, ["B", "N", weight.shape[1]])
    gather = helper.make_node(
        "Gather",
        inputs=["embed_tokens.weight", "input_ids"],
        outputs=["inputs_embeds"],
        axis=0,
        name="EmbedTokens_Gather",
    )
    graph = helper.make_graph(
        [gather],
        "embed_tokens_fp32",
        [input_ids],
        [output],
        [weight_init],
    )
    return helper.make_model(
        graph,
        opset_imports=_opset_imports(),
        ir_version=10,
        **_producer(),
    )


def build_fp16w(weight: np.ndarray) -> onnx.ModelProto:
    """Weight stored as FP16; Gather the requested rows, then Cast them to FP32.

    Output is FP32 and no arithmetic is done in FP16. The Gather runs first
    (the same order the int8 tier uses) so only the gathered [B, N, D] rows
    are converted. A Cast applied directly to the initializer is a candidate
    for constant folding by a runtime (e.g. ONNX Runtime's graph optimizer),
    which would turn the 411 MB FP16 table back into an 822 MB FP32 one in
    memory. Cast is elementwise, so the result is identical either way.
    """
    weight_init = numpy_helper.from_array(
        weight.astype(np.float16, copy=False), name="embed_tokens.weight"
    )
    input_ids = _make_input("input_ids", TensorProto.INT64, ["B", "N"])
    output = _make_input("inputs_embeds", TensorProto.FLOAT, ["B", "N", weight.shape[1]])
    gather = helper.make_node(
        "Gather",
        inputs=["embed_tokens.weight", "input_ids"],
        outputs=["embed_tokens.weight_fp16_gathered"],
        axis=0,
        name="EmbedTokens_Gather",
    )
    cast_node = helper.make_node(
        "Cast",
        inputs=["embed_tokens.weight_fp16_gathered"],
        outputs=["inputs_embeds"],
        to=TensorProto.FLOAT,
        name="EmbedTokens_Cast_fp16w",
    )
    graph = helper.make_graph(
        [gather, cast_node],
        "embed_tokens_fp16w",
        [input_ids],
        [output],
        [weight_init],
    )
    return helper.make_model(
        graph,
        opset_imports=_opset_imports(),
        ir_version=10,
        **_producer(),
    )


def _per_row_int8_quantise(weight: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Symmetric per-row INT8 quantisation. Returns (int8_weight, fp32_scales).

    Per-row scale = max(|row|) / 127, with max(|row|) floored at 1e-12 so the
    division below never divides by zero. An all-zero row therefore gets a
    tiny scale (about 7.9e-15, not 1.0); it still encodes as zeros and
    dequantises back to exactly zero. Codes are rounded and clipped to
    [-127, 127], so -128 is never used.

    This mirrors the per-output-channel layout `quantize_dynamic` uses for
    MatMul weights, except over rows of the embedding matrix instead of
    columns of a transposed weight.
    """
    if weight.dtype != np.float32:
        weight = weight.astype(np.float32, copy=False)
    abs_max = np.maximum(np.abs(weight).max(axis=1), 1e-12)
    scales = (abs_max / 127.0).astype(np.float32)
    int_w = np.round(weight / scales[:, None])
    int_w = np.clip(int_w, -127, 127).astype(np.int8)
    return int_w, scales


def build_int8(weight: np.ndarray) -> tuple[onnx.ModelProto, dict[str, Any]]:
    """Per-row INT8 weights. The graph dequantises only rows actually gathered.

    Topology (only `ai.onnx` ops):
        Gather(weight_int8, input_ids, axis=0) -> int8 [B, N, D]
        Cast(int8 -> fp32)                     -> [B, N, D]
        Gather(scales_fp32, input_ids, axis=0) -> [B, N]
        Unsqueeze(axes=[-1])                   -> [B, N, 1]
        Mul                                    -> [B, N, D]

    Returns:
        (model, quant_stats) where quant_stats holds max / mean / p99 absolute
        reconstruction error over the whole table.
    """
    int8_w, scales = _per_row_int8_quantise(weight)

    # Reconstruction error stats (on the full table; consumers gather a subset).
    recon = int8_w.astype(np.float32) * scales[:, None]
    err = np.abs(recon - weight.astype(np.float32, copy=False))
    quant_stats = {
        "max_abs_err": float(err.max()),
        "mean_abs_err": float(err.mean()),
        "p99_abs_err": float(np.quantile(err, 0.99)),
    }
    print(
        f" int8 quant per-row: max_abs_err={quant_stats['max_abs_err']:.3e} "
        f"mean={quant_stats['mean_abs_err']:.3e} p99={quant_stats['p99_abs_err']:.3e}"
    )

    weight_int8_init = numpy_helper.from_array(int8_w, name="embed_tokens.weight_int8")
    scales_init = numpy_helper.from_array(scales, name="embed_tokens.scales")
    # Constant initializer for the Unsqueeze axes input (opset 13+ moved axes
    # from attribute to input).
    axes_init = numpy_helper.from_array(
        np.array([-1], dtype=np.int64), name="EmbedTokens_unsqueeze_axes"
    )

    input_ids = _make_input("input_ids", TensorProto.INT64, ["B", "N"])
    output = _make_input("inputs_embeds", TensorProto.FLOAT, ["B", "N", weight.shape[1]])

    nodes = [
        helper.make_node(
            "Gather",
            inputs=["embed_tokens.weight_int8", "input_ids"],
            outputs=["embed_tokens.weight_int8_gathered"],
            axis=0,
            name="EmbedTokens_Gather_weight",
        ),
        helper.make_node(
            "Cast",
            inputs=["embed_tokens.weight_int8_gathered"],
            outputs=["embed_tokens.weight_fp32_gathered"],
            to=TensorProto.FLOAT,
            name="EmbedTokens_Cast_int8_to_fp32",
        ),
        helper.make_node(
            "Gather",
            inputs=["embed_tokens.scales", "input_ids"],
            outputs=["embed_tokens.scales_gathered"],
            axis=0,
            name="EmbedTokens_Gather_scales",
        ),
        helper.make_node(
            "Unsqueeze",
            inputs=["embed_tokens.scales_gathered", "EmbedTokens_unsqueeze_axes"],
            outputs=["embed_tokens.scales_expanded"],
            name="EmbedTokens_Unsqueeze_scales",
        ),
        helper.make_node(
            "Mul",
            inputs=["embed_tokens.weight_fp32_gathered", "embed_tokens.scales_expanded"],
            outputs=["inputs_embeds"],
            name="EmbedTokens_Mul",
        ),
    ]
    graph = helper.make_graph(
        nodes,
        "embed_tokens_int8",
        [input_ids],
        [output],
        [weight_int8_init, scales_init, axes_init],
    )
    model = helper.make_model(
        graph,
        opset_imports=_opset_imports(),
        ir_version=10,
        **_producer(),
    )
    return model, quant_stats


# ---------------------------------------------------------------------------
# Parity
# ---------------------------------------------------------------------------


def parity_check(weight: np.ndarray, out_paths: dict[str, Path]) -> dict[str, Any]:
    """Compare each saved tier (CPU ONNX Runtime) against a PyTorch nn.Embedding.

    Tiers whose file does not exist are skipped. Returns, per tier, the output
    shape/dtype and max / mean / p99 absolute error. This only reports; it
    does not raise on a large error.
    """
    import onnxruntime as ort

    # Reference: PyTorch nn.Embedding output. Sample IDs must fit the variant's
    # vocab (NAR's LLM is smaller than the AR vocabs).
    V = int(weight.shape[0])
    ref_emb = nn.Embedding.from_pretrained(torch.from_numpy(weight), freeze=True)
    candidates = [0, 1, 100, 1024, 50000, 100352, 100351, V - 1]
    sample_ids = torch.tensor([[i for i in candidates if 0 <= i < V]], dtype=torch.long)
    with torch.inference_mode():
        ref = ref_emb(sample_ids).detach().float().cpu().numpy()
    ids_np = sample_ids.numpy().astype(np.int64)

    so = ort.SessionOptions()
    results: dict[str, Any] = {}
    for tier, path in out_paths.items():
        if not path.exists():
            continue
        sess = ort.InferenceSession(str(path), so, providers=["CPUExecutionProvider"])
        got = sess.run(None, {"input_ids": ids_np})[0]
        diff = np.abs(got - ref)
        results[tier] = {
            "shape": list(got.shape),
            "dtype": str(got.dtype),
            "max_abs_err": float(diff.max()),
            "mean_abs_err": float(diff.mean()),
            "p99_abs_err": float(np.quantile(diff, 0.99)),
        }
        print(
            f" parity {tier:>5}: shape={list(got.shape)} max_abs_err={diff.max():.3e} "
            f"mean={diff.mean():.3e}"
        )
    return results


# ---------------------------------------------------------------------------
# Driver
# ---------------------------------------------------------------------------


def export_variant(variant: str, tiers: list[str]) -> dict[str, Any]:
    """Export the requested tiers for one variant and write `embed_tokens_parity.json`.

    Returns the summary dict that is also written to the parity report.
    """
    cfg = VARIANT_CONFIG[variant]
    out_dir = cfg["out_dir"]
    out_dir.mkdir(parents=True, exist_ok=True)
    print(f"\n=== exporting embed_tokens for variant '{variant}' ===")

    weight, info = load_embed_tokens(variant)

    out_paths: dict[str, Path] = {}
    quant_stats: dict[str, Any] | None = None

    if "fp32" in tiers:
        out = out_dir / "embed_tokens.onnx"
        print(f" -> {out}")
        model = build_fp32(weight)
        _save_with_single_sidecar(model, out)
        out_paths["fp32"] = out
        print(
            f" {out.name} ({out.stat().st_size / 1e6:.1f} MB onnx, "
            f"{_sidecar_path(out).stat().st_size / 1e9:.2f} GB sidecar)"
        )

    if "fp16w" in tiers:
        out = out_dir / "embed_tokens_fp16w.onnx"
        print(f" -> {out}")
        model = build_fp16w(weight)
        _save_with_single_sidecar(model, out)
        out_paths["fp16w"] = out
        print(f" {out.name} ({_sidecar_path(out).stat().st_size / 1e9:.2f} GB sidecar)")

    if "int8" in tiers:
        out = out_dir / "embed_tokens_int8.onnx"
        print(f" -> {out}")
        model, quant_stats = build_int8(weight)
        _save_with_single_sidecar(model, out)
        out_paths["int8"] = out
        print(f" {out.name} ({_sidecar_path(out).stat().st_size / 1e9:.2f} GB sidecar)")

    parity = parity_check(weight, out_paths)

    summary = {
        "variant": variant,
        "info": info,
        "tiers": {tier: str(p.relative_to(REPO_ROOT)) for tier, p in out_paths.items()},
        "parity": parity,
    }
    if quant_stats is not None:
        summary["int8_table_quant_stats"] = quant_stats

    parity_path = out_dir / "embed_tokens_parity.json"
    parity_path.write_text(json.dumps(summary, indent=2) + "\n", encoding="utf-8")
    print(f" wrote parity report -> {parity_path.relative_to(REPO_ROOT)}")
    return summary


def main() -> int:
    """CLI entry point: export the selected tiers for one or all variants."""
    parser = argparse.ArgumentParser(
        description="Export the LLM embed_tokens table as standalone ONNX graphs.",
    )
    parser.add_argument(
        "--variant",
        choices=["base", "plus", "nar", "all"],
        default="all",
        help="model variant to export (default: all)",
    )
    parser.add_argument(
        "--tiers",
        default="fp32,int8,fp16w",
        help="comma-separated subset of fp32,int8,fp16w (default: all three)",
    )
    args = parser.parse_args()

    tiers = [t.strip() for t in args.tiers.split(",") if t.strip()]
    if not tiers:
        # Without this guard we would load a full checkpoint and export nothing.
        raise SystemExit("no tiers selected (expected a subset of fp32,int8,fp16w)")
    invalid = sorted(set(tiers) - {"fp32", "int8", "fp16w"})
    if invalid:
        raise SystemExit(f"unknown tiers: {invalid}")

    variants = ["base", "plus", "nar"] if args.variant == "all" else [args.variant]
    for variant in variants:
        export_variant(variant, tiers)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())