
# Multitalker Parakeet Streaming -- Conversion Scripts

Scripts for exporting NVIDIA's multitalker-parakeet-streaming-0.6b-v1 NeMo checkpoint to ONNX format for use with parakeet-rs.

## Prerequisites

- Python 3.12 (PyTorch does not ship wheels for 3.14+)
- The `.nemo` checkpoint file (~2.4GB)

Download the model:

```bash
# Either from Hugging Face directly:
huggingface-cli download nvidia/multitalker-parakeet-streaming-0.6b-v1 \
  --local-dir multitalker-parakeet-streaming-0.6b-v1

# Or via git-lfs:
git clone https://huggingface.co/nvidia/multitalker-parakeet-streaming-0.6b-v1
```

## Setup

```bash
cd conversion_scripts
uv venv --python 3.12
uv pip install --python .venv/bin/python3.12 -r requirements.txt
source .venv/bin/activate
```

Note: you must target the venv Python explicitly when installing with `uv pip`, otherwise it may pick up a system Python version for which PyTorch has no wheels.

## Scripts

### `inspect_model.py`

Loads the `.nemo` checkpoint and prints all architecture parameters needed for export and Rust integration. Saves the results to `multitalker_config.json`.

```bash
python inspect_model.py ../multitalker-parakeet-streaming-0.6b-v1/multitalker-parakeet-streaming-0.6b-v1.nemo
```

### `export_multitalker.py`

Exports the model to ONNX with dynamic int8 quantisation. This is the main script.

```bash
# Export with int8 quantisation (default)
python export_multitalker.py

# Export fp32 only
python export_multitalker.py --no-quantise

# Specify model path and output directory
python export_multitalker.py \
  --nemo-path /path/to/model.nemo \
  --output-dir ../
```

#### Output files

| File | Size | Description |
|------|------|-------------|
| `encoder.onnx` | ~40MB | Encoder graph (references external weights) |
| `encoder.onnx.data` | ~2.3GB | Encoder weights (fp32) |
| `encoder.int8.onnx` | ~627MB | Encoder, dynamically quantised to int8 |
| `decoder_joint.onnx` | ~34MB | Decoder + joint network (fp32) |
| `decoder_joint.int8.onnx` | ~8.6MB | Decoder + joint, dynamically quantised to int8 |
| `tokenizer.model` | ~245KB | SentencePiece vocabulary |
| `multitalker_config.json` | <1KB | Model dimensions for Rust integration |

For inference with parakeet-rs, you need at minimum:

- `encoder.int8.onnx` (or `encoder.onnx` + `encoder.onnx.data`)
- `decoder_joint.int8.onnx` (or `decoder_joint.onnx`)
- `tokenizer.model`
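As a quick sanity check before wiring up inference, a small helper along these lines (hypothetical, not part of the repo) can verify that a model directory contains one usable variant of each required file:

```python
from pathlib import Path

def has_required_files(model_dir):
    """Return True if model_dir holds the minimum files for parakeet-rs.

    Accepts either the int8 or the fp32 encoder/decoder variants; the
    fp32 encoder additionally needs its external-weights .data file.
    """
    d = Path(model_dir)
    encoder_ok = (d / "encoder.int8.onnx").exists() or (
        (d / "encoder.onnx").exists() and (d / "encoder.onnx.data").exists()
    )
    decoder_ok = (d / "decoder_joint.int8.onnx").exists() or (
        d / "decoder_joint.onnx"
    ).exists()
    return encoder_ok and decoder_ok and (d / "tokenizer.model").exists()
```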

### `export_model.py`

Exports the standard (non-multitalker) Nemotron TDT model. Not used for the multitalker pipeline but kept for reference.

### `compare_models.py` / `build_preprocessor.py`

Utilities from the original TDT export. Not required for the multitalker pipeline.

## How the export works

The multitalker model injects speaker activity masks into the encoder via PyTorch forward hooks. During a standard NeMo ONNX export, these hooks read the targets from instance attributes, which become constants in the ONNX graph -- the speaker targets are baked in rather than exposed as inputs.

`export_multitalker.py` works around this with a wrapper module (`MultitalkerEncoderExport`) that:

  1. takes `spk_targets` and `bg_spk_targets` as explicit `forward()` parameters,
  2. sets them on the model instance before calling `encoder.forward_for_export()`, so that
  3. the hooks fire during tracing and the speaker targets follow the tensor data flow through the speaker kernel FFNs.
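The pattern can be sketched in a few lines. This is a deliberately simplified illustration: `FakeEncoder` stands in for the real NeMo encoder module, and the class names here are illustrative, not the actual export code.

```python
class FakeEncoder:
    """Stand-in for the NeMo encoder. In the real model, forward hooks
    read spk_targets / bg_spk_targets off the instance during tracing."""

    def forward_for_export(self, signal):
        # A real hook would inject the speaker kernel FFN output here;
        # a simple elementwise add is enough to show the data flow.
        return [x + t for x, t in zip(signal, self.spk_targets)]


class MultitalkerEncoderExportSketch:
    def __init__(self, encoder):
        self.encoder = encoder

    def forward(self, signal, spk_targets, bg_spk_targets):
        # 1. Targets arrive as explicit forward() parameters...
        # 2. ...and are set on the instance before the export forward,
        #    so the hooks see live tensor data rather than constants.
        self.encoder.spk_targets = spk_targets
        self.encoder.bg_spk_targets = bg_spk_targets
        return self.encoder.forward_for_export(signal)
```

Because the targets flow in as graph inputs during tracing, they surface as `spk_targets` / `bg_spk_targets` inputs in the exported ONNX model instead of being constant-folded.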

If the primary approach fails (e.g. the targets get constant-folded), a fallback (`ExplicitKernelEncoder`) removes the hooks entirely and replicates the kernel injection logic inline.

The decoder is wrapped separately (`DecoderJointExport`) because NeMo's internal `RNNTDecoderJoint.forward` takes positional arguments, which conflicts with NeMo's typed-methods decorator requiring keyword arguments.
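The bridging idea is simple: accept positional inputs (as ONNX tracing supplies them) and forward them as keywords. A minimal sketch, with a stand-in decoder rather than the real NeMo module (the names below are illustrative):

```python
class FakeDecoderJoint:
    """Stand-in for RNNTDecoderJoint: keyword-only, like NeMo's
    typed-methods decorator effectively requires."""

    def __call__(self, *, encoder_outputs, targets):
        return [e + t for e, t in zip(encoder_outputs, targets)]


class DecoderJointExportSketch:
    def __init__(self, decoder_joint):
        self.decoder_joint = decoder_joint

    def forward(self, encoder_outputs, targets):
        # Positional inputs from tracing are re-passed as keywords.
        return self.decoder_joint(encoder_outputs=encoder_outputs,
                                  targets=targets)
```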

## ONNX input/output reference

### Encoder

| Input | Shape | Description |
|-------|-------|-------------|
| `processed_signal` | `[1, 128, time]` | Mel spectrogram features |
| `processed_signal_length` | `[1]` | Number of valid mel frames |
| `cache_last_channel` | `[1, 24, 70, 1024]` | Attention cache (batch-first) |
| `cache_last_time` | `[1, 24, 1024, 8]` | Convolution cache (batch-first) |
| `cache_last_channel_len` | `[1]` | Current cache fill level |
| `spk_targets` | `[1, spk_time]` | Target speaker activity mask |
| `bg_spk_targets` | `[1, spk_time]` | Background speaker activity mask |

| Output | Shape | Description |
|--------|-------|-------------|
| `encoded` | `[1, 1024, encoded_time]` | Encoded features |
| `encoded_len` | `[1]` | Number of valid encoded frames |
| `cache_last_channel_next` | `[1, 24, 70, 1024]` | Updated attention cache |
| `cache_last_time_next` | `[1, 24, 1024, 8]` | Updated convolution cache |
| `cache_last_channel_len_next` | `[1]` | Updated cache fill level |

**Cache format note:** caches are batch-first `[batch, n_layers, ...]`. This differs from the standard Nemotron ONNX export, which uses `[n_layers, batch, ...]`. The difference arises because `forward_for_export()` internally transposes axes 0 and 1.
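Converting between the two layouts is a single axis swap. A small NumPy sketch with the encoder's attention-cache dimensions:

```python
import numpy as np

# Multitalker export: batch-first attention cache [batch, n_layers, ctx, d_model].
batch, n_layers, ctx, d_model = 1, 24, 70, 1024
cache_batch_first = np.zeros((batch, n_layers, ctx, d_model), dtype=np.float32)

# Standard Nemotron layout puts layers first: swap axes 0 and 1,
# which is what forward_for_export() does internally.
cache_layers_first = cache_batch_first.transpose(1, 0, 2, 3)
```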

### Decoder + Joint

| Input | Shape | Description |
|-------|-------|-------------|
| `encoder_outputs` | `[1, enc_time, 1024]` | Encoder output (transposed) |
| `targets` | `[1, 1]` | Previous token ID |
| `input_states_1` | `[2, 1, 640]` | LSTM hidden state |
| `input_states_2` | `[2, 1, 640]` | LSTM cell state |

| Output | Shape | Description |
|--------|-------|-------------|
| `outputs` | `[1, enc_time, 1, 1025]` | Joint logits (1024 vocab + 1 blank) |
| `prednet_lengths` | scalar | Prediction network sequence length |
| `states_1` | `[2, 1, 640]` | Updated LSTM hidden state |
| `states_2` | `[2, 1, 640]` | Updated LSTM cell state |
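To make the `[1, enc_time, 1, 1025]` output concrete, here is a hypothetical frame-wise argmax showing how the last axis splits into 1024 tokens plus a blank. The real decoder loop is more involved (TDT duration handling, feeding predictions and LSTM states back in), all of which is omitted here.

```python
import numpy as np

blank_id = 1024                        # the 1025th logit is the blank symbol
enc_time = 3

# Toy joint logits with a known best token per frame.
logits = np.full((1, enc_time, 1, 1025), -10.0, dtype=np.float32)
logits[0, 0, 0, 5] = 0.0               # frame 0 -> token 5
logits[0, 1, 0, blank_id] = 0.0        # frame 1 -> blank (emit nothing)
logits[0, 2, 0, 7] = 0.0               # frame 2 -> token 7

tokens = [int(np.argmax(logits[0, t, 0])) for t in range(enc_time)]
emitted = [t for t in tokens if t != blank_id]   # -> [5, 7]
```

Each emitted token ID would then be mapped back to text via `tokenizer.model`.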

## Known issues

- PyTorch >= 2.9 breaks NeMo ONNX export because `dynamo=True` becomes the default. Pin to `torch<2.9.0` (see `requirements.txt`).
- Python 3.14+ has no PyTorch wheels. Use Python 3.12.
- The fp32 encoder export initially writes weights as hundreds of scattered files, which are then consolidated into a single `.data` file. This consolidation step can use significant memory.

## Model architecture summary

```text
d_model:              1024
encoder_layers:       24
subsampling_factor:   8
left_context:         70 frames
conv_context:         8 (kernel_size - 1)
chunk_size:           112 mel frames (~1.12s)
pre_encode_cache:     9 frames
vocab_size:           1024 (+ 1 blank = 1025)
decoder_lstm_dim:     640
decoder_lstm_layers:  2
spk_kernel_layers:    [0] (injection at encoder layer 0 only)
spk_kernel_arch:      Linear(1024,1024) -> ReLU -> Dropout(0.5) -> Linear(1024,1024)
sample_rate:          16000 Hz
n_mels:               128
n_fft:                512
```
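The streaming timings these numbers imply can be checked with a little arithmetic. The 10 ms mel hop is an assumption (hop length is not listed in the summary), chosen because it makes 112 mel frames come out at the stated ~1.12 s:

```python
# Chunk timing implied by the architecture summary above.
sample_rate = 16_000
hop_ms = 10                        # ASSUMED mel hop length (not in the config)
chunk_mel_frames = 112
subsampling_factor = 8

chunk_seconds = chunk_mel_frames * hop_ms / 1000                    # 1.12 s
encoded_frames_per_chunk = chunk_mel_frames // subsampling_factor   # 14
```

So each streaming step consumes ~1.12 s of audio and yields 14 encoder frames, against which the decoder+joint is run.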