Extend support for transformers versions from 4.45.0 to 5.2.0
This PR updates the model implementation to support transformers versions 4.45.0 through 5.2.0, rather than only 4.49.0.
The model's custom code relied on internal transformers APIs that changed across versions, causing failures on versions other than 4.49.0.
- **`_update_causal_mask` removed in transformers 5.0** — `LlamaBidirectionalModel` overrode `_update_causal_mask()`, a private method on `LlamaModel` that was removed in transformers 5.0 in favor of `create_bidirectional_mask` from `transformers.masking_utils`.
- **`hidden_states[-1]` returns `None` on transformers >= 4.57** — `_extract_embeddings` accessed `outputs.hidden_states[-1]`, but the custom `LlamaBidirectionalModel.forward()` never populated the `hidden_states` tuple, only `last_hidden_state`. This worked incidentally on some versions but broke on 4.57+, where the base class internals changed.
- **`additional_special_tokens` unavailable on transformers 5.0** — The processor filtered `tokenizer.additional_special_tokens`, but in transformers 5.0 the tokenizer backend changed to `TokenizersBackend`, which doesn't expose this attribute.
- **Missing `self.post_init()` call** — `LlamaNemotronVLModel.__init__` didn't call `self.post_init()`, the standard transformers finalization step. This caused `AttributeError: 'LlamaNemotronVLModel' object has no attribute 'all_tied_weights_keys'` on transformers 5.0+.
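The `hidden_states` failure mode can be reproduced with a minimal stand-in (the classes and names below are hypothetical illustrations, not the actual model code):

```python
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class FakeOutput:
    # Stand-in for the model output: the custom forward() only ever sets
    # last_hidden_state, so hidden_states stays None.
    last_hidden_state: Optional[list] = None
    hidden_states: Optional[Tuple[list, ...]] = None

def fake_forward() -> FakeOutput:
    return FakeOutput(last_hidden_state=[0.1, 0.2, 0.3])

outputs = fake_forward()

# Old extraction: indexing a None tuple raises TypeError.
try:
    embedding = outputs.hidden_states[-1]
except TypeError:
    embedding = None

# New extraction: last_hidden_state is always populated by the custom forward().
embedding = outputs.last_hidden_state
print(embedding)  # [0.1, 0.2, 0.3]
```

Whether the old access raises or silently yields `None` depends on the transformers version, which is why the bug only surfaced on 4.57+.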
Changes
modeling_llama_nemotron_vl.py
- **Replaced the `_update_causal_mask` override with an explicit `forward()` and `_create_bidirectional_mask`** — Instead of hooking into a private method, `LlamaBidirectionalModel` now has its own `forward()` that constructs the bidirectional mask directly, dispatching to `transformers.masking_utils.create_bidirectional_mask` on 5.0+ and falling back to `_prepare_4d_attention_mask` on older versions.
- **Version-portable decoder layer calls** — Uses runtime introspection to detect API differences across transformers versions: `past_key_value` vs `past_key_values` parameter naming (changed in 4.56), the `DynamicCache` constructor signature, and tuple vs tensor returns from decoder layers.
- **Changed `_extract_embeddings` to use `outputs.last_hidden_state`** — Replaced `self(**batch, output_hidden_states=True).hidden_states[-1]` with `self(**batch).last_hidden_state`, which is always reliably populated by the custom `forward()`.
- **Changed the `forward()` return type to `BaseModelOutputWithPast`** — The parent returned `CausalLMOutputWithPast` (which has `logits` but no `last_hidden_state`). Since this is an embedding model that doesn't compute logits, `BaseModelOutputWithPast` is the correct output type and propagates `last_hidden_state` from the language model.
- **Added `self.post_init()`** — Standard transformers pattern that initializes internal bookkeeping (`all_tied_weights_keys`, weight tying, etc.), required by transformers 5.0+.
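The runtime-introspection approach for the `past_key_value`/`past_key_values` rename can be sketched as follows. This is a minimal illustration with hypothetical stand-in layers, not the PR's actual helper; the `DynamicCache` constructor handling is omitted:

```python
import inspect

def call_decoder_layer(layer, hidden_states, attention_mask=None, cache=None):
    """Call a decoder layer portably across transformers versions.

    Around transformers 4.56 the cache keyword was renamed from
    `past_key_value` to `past_key_values`, and some versions return a
    tuple (hidden_states, ...) while others return a bare tensor.
    """
    # Inspect the layer's forward() signature to pick the right kwarg name.
    params = inspect.signature(layer.forward).parameters
    cache_kwarg = "past_key_values" if "past_key_values" in params else "past_key_value"
    out = layer(hidden_states, attention_mask=attention_mask, **{cache_kwarg: cache})
    # Normalize the tuple-vs-tensor return difference.
    return out[0] if isinstance(out, tuple) else out

# Hypothetical stand-ins for old- and new-style decoder layers:
class OldLayer:
    def forward(self, hidden_states, attention_mask=None, past_key_value=None):
        return (hidden_states, None)  # tuple return
    __call__ = forward

class NewLayer:
    def forward(self, hidden_states, attention_mask=None, past_key_values=None):
        return hidden_states          # bare tensor return
    __call__ = forward

h = [1.0, 2.0]
assert call_decoder_layer(OldLayer(), h) == h
assert call_decoder_layer(NewLayer(), h) == h
```

Inspecting the signature at call time keeps a single code path working on every supported version without pinning to either kwarg name.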
processing_llama_nemotron_vl.py
- **Removed `additional_special_tokens` filtering** — The processor filtered out the `<box>`, `</box>`, `<ref>`, `</ref>` tokens from `additional_special_tokens`, which broke on transformers 5.0, where the attribute doesn't exist. Testing confirmed the filtering has no effect on embedding output (zero diff with and without it), so it was removed entirely.
Test results
All supported versions (4.45.0 through 5.2.0) produce zero diff against the reference embeddings (generated with transformers 4.49.0):
| Version | Result |
|---|---|
| 4.44.2 | FAIL — tokenizers crate can't parse tokenizer.json |
| 4.45.0 | PASS (zero diff) |
| 4.46.1 | PASS (zero diff) |
| 4.47.0 | PASS (zero diff) |
| 4.48.0 | PASS (zero diff) |
| 4.49.0 | PASS (reference) |
| 4.57.6 | PASS (zero diff) |
| 5.0.0 | PASS (zero diff) |
| 5.1.0 | PASS (zero diff) |
| 5.2.0 | PASS (zero diff) |