# Multitalker Parakeet Streaming -- Conversion Scripts

Scripts for exporting NVIDIA's [multitalker-parakeet-streaming-0.6b-v1](https://huggingface.co/nvidia/multitalker-parakeet-streaming-0.6b-v1) NeMo checkpoint to ONNX format for use with [parakeet-rs](https://github.com/altunenes/parakeet-rs).

## Prerequisites

- Python 3.12 (PyTorch does not ship wheels for 3.14+)
- The `.nemo` checkpoint file (~2.4GB)

Download the model:

```bash
# Either from Hugging Face directly:
huggingface-cli download nvidia/multitalker-parakeet-streaming-0.6b-v1 \
  --local-dir multitalker-parakeet-streaming-0.6b-v1

# Or via git-lfs (run `git lfs install` once beforehand so the weights are fetched):
git clone https://huggingface.co/nvidia/multitalker-parakeet-streaming-0.6b-v1
```

## Setup

```bash
cd conversion_scripts
uv venv --python 3.12
uv pip install --python .venv/bin/python3.12 -r requirements.txt
source .venv/bin/activate
```

> **Note:** Target the venv Python explicitly when installing with `uv pip`. Otherwise it may pick up a system Python version that lacks PyTorch wheels.

## Scripts

### inspect_model.py

Loads the `.nemo` checkpoint, prints all architecture parameters needed for export and Rust integration, and saves them to `multitalker_config.json`.

```bash
python inspect_model.py ../multitalker-parakeet-streaming-0.6b-v1/multitalker-parakeet-streaming-0.6b-v1.nemo
```

### export_multitalker.py

The main script. Exports the model to ONNX and, by default, also writes dynamically int8-quantised copies of the encoder and decoder.
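Dynamic quantisation stores weights as int8 and quantises activations on the fly at inference time, which is why the int8 files in the output table are roughly a quarter the size of their fp32 counterparts (e.g. ~34MB vs ~8.6MB for the decoder). The sketch below illustrates only the weight side, using a single symmetric scale per tensor. It is a simplified assumption, not the code `export_multitalker.py` runs: real quantisers (ONNX Runtime's, for example) may use per-channel scales or a reduced range. The helper names are hypothetical.

```python
from typing import List, Tuple


def quantise_int8(weights: List[float]) -> Tuple[List[int], float]:
    """Symmetric per-tensor int8 quantisation (hypothetical helper).

    One scale maps the largest absolute weight to 127; every weight is then
    rounded to the nearest integer step and clamped to [-127, 127].
    """
    max_abs = max((abs(w) for w in weights), default=0.0)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0  # avoid dividing by zero
    quantised = [max(-127, min(127, round(w / scale))) for w in weights]
    return quantised, scale


def dequantise_int8(quantised: List[int], scale: float) -> List[float]:
    """Recover approximate fp32 values; each one is off by about scale / 2 at most."""
    return [q * scale for q in quantised]


if __name__ == "__main__":
    weights = [0.5, -1.27, 0.0, 0.013]
    q, scale = quantise_int8(weights)
    restored = dequantise_int8(q, scale)
    print(q, restored)
    # fp32 needs 4 bytes per weight and int8 needs 1: roughly a 4x size reduction
    print(f"{len(weights) * 4} bytes fp32 -> {len(q)} bytes int8")
```

The trade-off is precision: small weights such as `0.013` collapse onto the nearest multiple of the scale, which is why quantised models can differ slightly in accuracy from the fp32 export.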
```bash
# Export with int8 quantisation (default)
python export_multitalker.py

# Export fp32 only
python export_multitalker.py --no-quantise

# Specify model path and output directory
python export_multitalker.py \
  --nemo-path /path/to/model.nemo \
  --output-dir ../
```

#### Output files

| File                      | Size   | Description                                    |
| ------------------------- | ------ | ---------------------------------------------- |
| `encoder.onnx`            | ~40MB  | Encoder graph (references external weights)    |
| `encoder.onnx.data`       | ~2.3GB | Encoder weights (fp32)                         |
| `encoder.int8.onnx`       | ~627MB | Encoder, dynamically quantised to int8         |
| `decoder_joint.onnx`      | ~34MB  | Decoder + joint network (fp32)                 |
| `decoder_joint.int8.onnx` | ~8.6MB | Decoder + joint, dynamically quantised to int8 |
| `tokenizer.model`         | ~245KB | SentencePiece vocabulary                       |
| `multitalker_config.json` | <1KB   | Model dimensions for Rust integration          |

For inference with parakeet-rs, you need at minimum:

- `encoder.int8.onnx` (or `encoder.onnx` + `encoder.onnx.data`)
- `decoder_joint.int8.onnx` (or `decoder_joint.onnx`)
- `tokenizer.model`

### export_model.py

Exports the standard (non-multitalker) Nemotron TDT model. Not used by the multitalker pipeline; kept for reference.

### compare_models.py / build_preprocessor.py

Utilities from the original TDT export. Not required for the multitalker pipeline.

## How the export works

The multitalker model injects speaker activity masks into the encoder via PyTorch forward hooks. During a standard NeMo ONNX export, these hooks read from instance attributes, which end up as constants in the ONNX graph: the speaker targets are baked in rather than exposed as inputs.

`export_multitalker.py` works around this with a wrapper module (`MultitalkerEncoderExport`) that:

1. Takes `spk_targets` and `bg_spk_targets` as explicit `forward()` parameters.
2. Sets them on the model instance before calling `encoder.forward_for_export()`.
3. Lets the hooks fire during tracing, so the traced graph follows the tensor data flow from these inputs through the speaker kernel FFNs.

If this primary approach fails (e.g. the targets get constant-folded), a fallback (`ExplicitKernelEncoder`) removes the hooks entirely and replicates the kernel injection logic inline.

The decoder is wrapped separately (`DecoderJointExport`) because NeMo's internal `RNNTDecoderJoint.forward` passes positional arguments, which conflicts with NeMo's own typed-methods decorator that requires keyword arguments.

## ONNX input/output reference

### Encoder

| Input                     | Shape               | Description                      |
| ------------------------- | ------------------- | -------------------------------- |
| `processed_signal`        | `[1, 128, time]`    | Mel spectrogram features         |
| `processed_signal_length` | `[1]`               | Number of valid mel frames       |
| `cache_last_channel`      | `[1, 24, 70, 1024]` | Attention cache (batch-first)    |
| `cache_last_time`         | `[1, 24, 1024, 8]`  | Convolution cache (batch-first)  |
| `cache_last_channel_len`  | `[1]`               | Current cache fill level         |
| `spk_targets`             | `[1, spk_time]`     | Target speaker activity mask     |
| `bg_spk_targets`          | `[1, spk_time]`     | Background speaker activity mask |

| Output                        | Shape                     | Description                    |
| ----------------------------- | ------------------------- | ------------------------------ |
| `encoded`                     | `[1, 1024, encoded_time]` | Encoded features               |
| `encoded_len`                 | `[1]`                     | Number of valid encoded frames |
| `cache_last_channel_next`     | `[1, 24, 70, 1024]`       | Updated attention cache        |
| `cache_last_time_next`        | `[1, 24, 1024, 8]`        | Updated convolution cache      |
| `cache_last_channel_len_next` | `[1]`                     | Updated cache fill level       |

> **Cache format note:** Caches are batch-first, `[batch, n_layers, ...]`. This differs from the standard Nemotron ONNX export, which uses `[n_layers, batch, ...]`. The difference arises because `forward_for_export()` internally transposes axes 0 and 1.
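In streaming use, each encoder call consumes one chunk of mel features plus the three cache tensors, and the `*_next` outputs become the cache inputs for the following chunk. The sketch below shows only that state threading, using the input and output names from the encoder tables. `run_encoder` is a hypothetical callable that takes a name-to-tensor dict and returns a name-to-tensor dict; with ONNX Runtime it could wrap `InferenceSession.run`, as noted in the trailing comment. The initial cache values (for instance zero tensors of the listed shapes with a zero fill level) and how `spk_targets`/`bg_spk_targets` are produced per chunk are assumptions left to the caller.

```python
from typing import Any, Callable, Dict, Iterable, List

Tensor = Any  # e.g. a NumPy array when wrapping onnxruntime

# Each cache input is refilled from the matching output of the previous call.
CACHE_INPUT_TO_OUTPUT = {
    "cache_last_channel": "cache_last_channel_next",
    "cache_last_time": "cache_last_time_next",
    "cache_last_channel_len": "cache_last_channel_len_next",
}


def stream_encoder(
    chunks: Iterable[Dict[str, Tensor]],
    run_encoder: Callable[[Dict[str, Tensor]], Dict[str, Tensor]],
    initial_cache: Dict[str, Tensor],
) -> List[Tensor]:
    """Encode chunks in order, threading the caches between calls (hypothetical helper).

    Each chunk dict must hold processed_signal, processed_signal_length,
    spk_targets and bg_spk_targets. Returns the `encoded` output of every chunk.
    """
    cache = dict(initial_cache)
    encoded_per_chunk = []
    for chunk in chunks:
        feeds = {**chunk, **cache}  # all seven encoder inputs
        outputs = run_encoder(feeds)
        encoded_per_chunk.append(outputs["encoded"])
        # The updated caches become the inputs for the next chunk.
        cache = {name: outputs[out] for name, out in CACHE_INPUT_TO_OUTPUT.items()}
    return encoded_per_chunk


# Possible ONNX Runtime wiring (not executed here):
#   import onnxruntime as ort
#   sess = ort.InferenceSession("encoder.int8.onnx")
#   names = [o.name for o in sess.get_outputs()]
#   run_encoder = lambda feeds: dict(zip(names, sess.run(names, feeds)))
```

Keeping the cache as an explicit value passed between calls, rather than hidden inside a session object, makes it straightforward to reset the stream or run several independent streams side by side.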
### Decoder + Joint

| Input             | Shape                 | Description                 |
| ----------------- | --------------------- | --------------------------- |
| `encoder_outputs` | `[1, enc_time, 1024]` | Encoder output (transposed) |
| `targets`         | `[1, 1]`              | Previous token ID           |
| `input_states_1`  | `[2, 1, 640]`         | LSTM hidden state           |
| `input_states_2`  | `[2, 1, 640]`         | LSTM cell state             |

| Output            | Shape                    | Description                         |
| ----------------- | ------------------------ | ----------------------------------- |
| `outputs`         | `[1, enc_time, 1, 1025]` | Joint logits (1024 vocab + 1 blank) |
| `prednet_lengths` | scalar                   | Prediction network sequence length  |
| `states_1`        | `[2, 1, 640]`            | Updated LSTM hidden state           |
| `states_2`        | `[2, 1, 640]`            | Updated LSTM cell state             |

## Known issues

- **PyTorch >= 2.9** breaks NeMo ONNX export because `torch.onnx.export` now defaults to `dynamo=True`. Pin to `torch<2.9.0` (see `requirements.txt`).
- **Python 3.14+** has no PyTorch wheels. Use Python 3.12.
- The fp32 encoder export writes its weights as hundreds of scattered files, which are then consolidated into a single `.data` file. Because the encoder weights total ~2.3GB, this consolidation step can use significant memory.

## Model architecture summary

```
d_model:              1024
encoder_layers:       24
subsampling_factor:   8
left_context:         70 frames
conv_context:         8 (kernel_size - 1)
chunk_size:           112 mel frames (~1.12s)
pre_encode_cache:     9 frames
vocab_size:           1024 (+ 1 blank = 1025)
decoder_lstm_dim:     640
decoder_lstm_layers:  2
spk_kernel_layers:    [0] (injection at encoder layer 0 only)
spk_kernel_arch:      Linear(1024,1024) -> ReLU -> Dropout(0.5) -> Linear(1024,1024)
sample_rate:          16000 Hz
n_mels:               128
n_fft:                512
```
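The timing figures in the summary fit together as follows. Assuming a 10 ms mel hop (160 samples at 16 kHz, which is what "112 mel frames ~1.12s" implies), each chunk covers 17,920 samples and yields 112 / 8 = 14 encoder frames of 80 ms each, and the 70-frame attention left context (the third axis of `cache_last_channel`) spans 5.6s of audio. The dataclass below is a hypothetical helper that derives these numbers; the hop length is an assumption, not a value read from the checkpoint.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StreamingGeometry:
    """Derived timing for the streaming encoder (hypothetical helper)."""

    sample_rate: int = 16000
    hop_samples: int = 160  # assumption: 10 ms hop, implied by 112 frames ~ 1.12 s
    subsampling_factor: int = 8
    chunk_mel_frames: int = 112
    left_context_frames: int = 70  # counted in encoder (post-subsampling) frames

    @property
    def chunk_samples(self) -> int:
        return self.chunk_mel_frames * self.hop_samples

    @property
    def chunk_seconds(self) -> float:
        return self.chunk_samples / self.sample_rate

    @property
    def encoded_frames_per_chunk(self) -> int:
        return self.chunk_mel_frames // self.subsampling_factor

    @property
    def encoder_frame_seconds(self) -> float:
        # One encoder frame covers subsampling_factor mel hops.
        return self.hop_samples * self.subsampling_factor / self.sample_rate

    @property
    def left_context_seconds(self) -> float:
        return self.left_context_frames * self.encoder_frame_seconds


if __name__ == "__main__":
    g = StreamingGeometry()
    print(g.chunk_samples, g.chunk_seconds, g.encoded_frames_per_chunk)
    print(g.encoder_frame_seconds, g.left_context_seconds)
```

These derived values are handy when sizing audio buffers on the Rust side: each encoder call needs one chunk's worth of samples and produces a fixed number of encoder frames for the decoder to consume.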