Commit 3e24fe7 (verified) by smcleod · Parent: 9d9ac36
Create conversion_scripts/README.md (1 file changed, +162 lines)
# Multitalker Parakeet Streaming -- Conversion Scripts

Scripts for exporting NVIDIA's [multitalker-parakeet-streaming-0.6b-v1](https://huggingface.co/nvidia/multitalker-parakeet-streaming-0.6b-v1) NeMo checkpoint to ONNX format for use with [parakeet-rs](https://github.com/altunenes/parakeet-rs).

## Prerequisites

- Python 3.12 (PyTorch does not ship wheels for Python 3.14+)
- The `.nemo` checkpoint file (~2.4GB)

Download the model:

```bash
# Either from Hugging Face directly:
huggingface-cli download nvidia/multitalker-parakeet-streaming-0.6b-v1 \
  --local-dir multitalker-parakeet-streaming-0.6b-v1

# Or via git-lfs:
git clone https://huggingface.co/nvidia/multitalker-parakeet-streaming-0.6b-v1
```

## Setup

```bash
cd conversion_scripts
uv venv --python 3.12
uv pip install --python .venv/bin/python3.12 -r requirements.txt
source .venv/bin/activate
```

> **Note:** You must target the venv Python explicitly when installing with `uv pip` to avoid it picking up a system Python version that lacks PyTorch wheels.

## Scripts

### inspect_model.py

Loads the `.nemo` checkpoint and prints all architecture parameters needed for export and Rust integration. Saves the results to `multitalker_config.json`.

```bash
python inspect_model.py ../multitalker-parakeet-streaming-0.6b-v1/multitalker-parakeet-streaming-0.6b-v1.nemo
```
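
A `.nemo` checkpoint is a tar archive that carries a `model_config.yaml` alongside the weights, so the basic dimensions can be pulled out with the standard library alone. A minimal sketch of that idea (the helper name and the flat `key: value` line scan are illustrative, not taken from `inspect_model.py`):

```python
import re
import tarfile

def read_nemo_config(nemo_path, keys=("d_model", "n_layers", "vocab_size")):
    """Extract flat scalar fields from the model_config.yaml stored
    inside a .nemo tar archive, without installing NeMo."""
    with tarfile.open(nemo_path, "r:*") as tar:  # r:* auto-detects compression
        member = next(m for m in tar.getmembers()
                      if m.name.endswith("model_config.yaml"))
        text = tar.extractfile(member).read().decode("utf-8")
    config = {}
    for key in keys:
        # naive line scan -- only catches top-level "key: value" scalars
        match = re.search(rf"^\s*{key}:\s*(\S+)\s*$", text, flags=re.M)
        if match:
            config[key] = match.group(1)
    return config
```

Nested config sections need a real YAML parser; `inspect_model.py` goes further and loads the full model through NeMo instead.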

### export_multitalker.py

Exports the model to ONNX with dynamic int8 quantisation. This is the main script.

```bash
# Export with int8 quantisation (default)
python export_multitalker.py

# Export fp32 only
python export_multitalker.py --no-quantise

# Specify model path and output directory
python export_multitalker.py \
  --nemo-path /path/to/model.nemo \
  --output-dir ../
```

#### Output files

| File                      | Size   | Description                                    |
| ------------------------- | ------ | ---------------------------------------------- |
| `encoder.onnx`            | ~40MB  | Encoder graph (references external weights)    |
| `encoder.onnx.data`       | ~2.3GB | Encoder weights (fp32)                         |
| `encoder.int8.onnx`       | ~627MB | Encoder, dynamically quantised to int8         |
| `decoder_joint.onnx`      | ~34MB  | Decoder + joint network (fp32)                 |
| `decoder_joint.int8.onnx` | ~8.6MB | Decoder + joint, dynamically quantised to int8 |
| `tokenizer.model`         | ~245KB | SentencePiece vocabulary                       |
| `multitalker_config.json` | <1KB   | Model dimensions for Rust integration          |

For inference with parakeet-rs, you need at minimum:
- `encoder.int8.onnx` (or `encoder.onnx` + `encoder.onnx.data`)
- `decoder_joint.int8.onnx` (or `decoder_joint.onnx`)
- `tokenizer.model`

### export_model.py

Exports the standard (non-multitalker) Nemotron TDT model. Not used for the multitalker pipeline but kept for reference.

### compare_models.py / build_preprocessor.py

Utilities from the original TDT export. Not required for the multitalker pipeline.

## How the export works

The multitalker model injects speaker activity masks into the encoder via PyTorch forward hooks. During a standard NeMo ONNX export, these hooks read from instance attributes, which become constants in the ONNX graph -- the speaker targets are baked in rather than exposed as inputs.

`export_multitalker.py` works around this with a wrapper module (`MultitalkerEncoderExport`) that:

1. Takes `spk_targets` and `bg_spk_targets` as explicit `forward()` parameters
2. Sets them on the model instance before calling `encoder.forward_for_export()`
3. Lets the hooks fire during tracing, so the speaker targets follow the tensor data flow through the speaker kernel FFNs

If the primary approach fails (e.g. the targets get constant-folded), a fallback (`ExplicitKernelEncoder`) removes the hooks entirely and replicates the kernel injection logic inline.
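
Schematically, the wrapper pattern looks like this (plain-Python stand-ins: the class names mirror the script, but the bodies and the toy mask arithmetic are illustrative assumptions, not NeMo code):

```python
class HookedEncoder:
    """Stand-in for the NeMo encoder: its 'hook' reads speaker targets
    from instance attributes. If those attributes are set once outside
    tracing, the tracer records their values as constants."""
    def __init__(self):
        self.spk_targets = None
        self.bg_spk_targets = None

    def forward_for_export(self, features):
        # placeholder for the real hook: mix speaker masks into features
        return [f + s - b for f, s, b in
                zip(features, self.spk_targets, self.bg_spk_targets)]


class MultitalkerEncoderExport:
    """Export wrapper: the speaker targets become explicit forward()
    parameters and are assigned immediately before the traced call, so
    the hooks see live graph inputs instead of baked-in constants."""
    def __init__(self, encoder):
        self.encoder = encoder

    def forward(self, features, spk_targets, bg_spk_targets):
        self.encoder.spk_targets = spk_targets             # step 2
        self.encoder.bg_spk_targets = bg_spk_targets
        return self.encoder.forward_for_export(features)   # hooks fire here
```

Tracing `MultitalkerEncoderExport.forward` is what lets `spk_targets` and `bg_spk_targets` surface as ONNX graph inputs.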

The decoder is wrapped separately (`DecoderJointExport`) because NeMo's internal `RNNTDecoderJoint.forward` takes positional arguments, while NeMo's own typed-methods decorator requires keyword arguments.
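
The decoder-side fix is a thin adapter in the same spirit; a sketch (the inner keyword-argument signature here is invented for illustration and does not match NeMo's real `RNNTDecoderJoint` API):

```python
class DecoderJointExport:
    """Adapter exposing a plain positional forward() while calling the
    wrapped decoder+joint with the keyword arguments its typed-method
    decorator insists on."""
    def __init__(self, decoder_joint):
        self.decoder_joint = decoder_joint

    def forward(self, encoder_outputs, targets, input_states_1, input_states_2):
        # translate tracer-friendly positional args into a kwargs-only call
        return self.decoder_joint(
            encoder_outputs=encoder_outputs,
            targets=targets,
            input_states=(input_states_1, input_states_2),
        )
```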

## ONNX input/output reference

### Encoder

| Input                     | Shape               | Description                      |
| ------------------------- | ------------------- | -------------------------------- |
| `processed_signal`        | `[1, 128, time]`    | Mel spectrogram features         |
| `processed_signal_length` | `[1]`               | Number of valid mel frames       |
| `cache_last_channel`      | `[1, 24, 70, 1024]` | Attention cache (batch-first)    |
| `cache_last_time`         | `[1, 24, 1024, 8]`  | Convolution cache (batch-first)  |
| `cache_last_channel_len`  | `[1]`               | Current cache fill level         |
| `spk_targets`             | `[1, spk_time]`     | Target speaker activity mask     |
| `bg_spk_targets`          | `[1, spk_time]`     | Background speaker activity mask |

| Output                        | Shape                     | Description                    |
| ----------------------------- | ------------------------- | ------------------------------ |
| `encoded`                     | `[1, 1024, encoded_time]` | Encoded features               |
| `encoded_len`                 | `[1]`                     | Number of valid encoded frames |
| `cache_last_channel_next`     | `[1, 24, 70, 1024]`       | Updated attention cache        |
| `cache_last_time_next`        | `[1, 24, 1024, 8]`        | Updated convolution cache      |
| `cache_last_channel_len_next` | `[1]`                     | Updated cache fill level       |

> **Cache format note:** Caches are batch-first `[batch, n_layers, ...]`. This differs from the standard Nemotron ONNX export, which uses `[n_layers, batch, ...]`. The difference arises because `forward_for_export()` internally transposes axes 0 and 1.
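
For the first streaming chunk, a caller has to allocate zeroed caches with exactly these shapes. A sketch of the input feed (the dtypes and the all-ones/all-zeros speaker masks are assumptions; check the exported graph for the actual types):

```python
import numpy as np

def init_encoder_inputs(mel_frames, spk_frames):
    """Zero-initialised encoder inputs for the first streaming chunk,
    using the shapes from the table above (batch size 1)."""
    return {
        "processed_signal": np.zeros((1, 128, mel_frames), dtype=np.float32),
        "processed_signal_length": np.array([mel_frames], dtype=np.int64),
        "cache_last_channel": np.zeros((1, 24, 70, 1024), dtype=np.float32),
        "cache_last_time": np.zeros((1, 24, 1024, 8), dtype=np.float32),
        "cache_last_channel_len": np.zeros((1,), dtype=np.int64),
        "spk_targets": np.ones((1, spk_frames), dtype=np.float32),     # target active
        "bg_spk_targets": np.zeros((1, spk_frames), dtype=np.float32), # no background
    }

# Feeding these to onnxruntime would look roughly like (not run here):
#   session = onnxruntime.InferenceSession("encoder.int8.onnx")
#   outputs = session.run(None, init_encoder_inputs(112, spk_frames))
```

Between chunks, the three `*_next` outputs are fed back in as the next chunk's cache inputs.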

### Decoder + Joint

| Input             | Shape                 | Description                 |
| ----------------- | --------------------- | --------------------------- |
| `encoder_outputs` | `[1, enc_time, 1024]` | Encoder output (transposed) |
| `targets`         | `[1, 1]`              | Previous token ID           |
| `input_states_1`  | `[2, 1, 640]`         | LSTM hidden state           |
| `input_states_2`  | `[2, 1, 640]`         | LSTM cell state             |

| Output            | Shape                    | Description                         |
| ----------------- | ------------------------ | ----------------------------------- |
| `outputs`         | `[1, enc_time, 1, 1025]` | Joint logits (1024 vocab + 1 blank) |
| `prednet_lengths` | scalar                   | Prediction network sequence length  |
| `states_1`        | `[2, 1, 640]`            | Updated LSTM hidden state           |
| `states_2`        | `[2, 1, 640]`            | Updated LSTM cell state             |
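
One way to consume the `[1, enc_time, 1, 1025]` logits is a greedy step per encoder frame. A sketch only (it assumes the blank token sits at the last index, consistent with "1024 vocab + 1 blank"; the real decoding loop in parakeet-rs is more involved):

```python
import numpy as np

BLANK_ID = 1024  # vocab ids 0..1023, blank appended last (assumption)

def greedy_step(joint_logits, t):
    """Argmax over the 1025 joint logits for encoder frame t; None means
    'blank emitted, advance to the next frame'."""
    token = int(joint_logits[0, t, 0].argmax())
    return None if token == BLANK_ID else token
```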

## Known issues

- **PyTorch >= 2.9** breaks NeMo ONNX export because `dynamo=True` became the default. Pin to `torch<2.9.0` (see `requirements.txt`).
- **Python 3.14+** has no PyTorch wheels. Use Python 3.12.
- The fp32 encoder export writes its weights as hundreds of scattered files, which are then consolidated into a single `.data` file. This consolidation step can use significant memory.

## Model architecture summary

```
d_model: 1024
encoder_layers: 24
subsampling_factor: 8
left_context: 70 frames
conv_context: 8 (kernel_size - 1)
chunk_size: 112 mel frames (~1.12s)
pre_encode_cache: 9 frames
vocab_size: 1024 (+ 1 blank = 1025)
decoder_lstm_dim: 640
decoder_lstm_layers: 2
spk_kernel_layers: [0] (injection at encoder layer 0 only)
spk_kernel_arch: Linear(1024,1024) -> ReLU -> Dropout(0.5) -> Linear(1024,1024)
sample_rate: 16000 Hz
n_mels: 128
n_fft: 512
```
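
The chunk timing above checks out against the usual 10 ms mel hop (the hop length is an assumption; it is not stated in the summary):

```python
SAMPLE_RATE = 16_000
HOP_SAMPLES = 160          # 10 ms hop at 16 kHz (assumed NeMo default)
CHUNK_MEL_FRAMES = 112
SUBSAMPLING_FACTOR = 8

chunk_seconds = CHUNK_MEL_FRAMES * HOP_SAMPLES / SAMPLE_RATE
encoded_per_chunk = CHUNK_MEL_FRAMES // SUBSAMPLING_FACTOR

print(chunk_seconds)       # 1.12, matching "~1.12s" above
print(encoded_per_chunk)   # 14 encoded frames per chunk, i.e. 80 ms each
```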