Add support for transformers 4.44 through 5.0+
Browse filesRefactor LlamaBidirectionalModel to handle API changes across transformers
versions using introspection-based detection.
Key changes:
- Replace _update_causal_mask override with full forward() implementation
(required since 4.53 when _update_causal_mask was removed)
- Remove forced eager attention mode that broke SDPA/flash attention in 4.48+
- Use create_bidirectional_mask when available (5.0+), fall back to
_prepare_4d_attention_mask for older versions
- Handle decoder layer return type (tuple vs tensor) and cache parameter
name (past_key_value vs past_key_values) differences via introspection
- Add comprehensive docstrings and type hints
Tested with transformers 4.44, 4.47.1, 4.48, 4.53, 4.56, 4.57.6, and 5.0.0.
Signed-off-by: Oliver Holworthy <nvidia-oliver-holworthy@users.noreply.huggingface.co>
Update comment about API difference detection
- llama_bidirectional_model.py +201 -21
- pooling.py +9 -6
|
@@ -1,46 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
-
from transformers.cache_utils import Cache
|
| 3 |
-
from transformers.
|
| 4 |
from transformers.models.llama.configuration_llama import LlamaConfig
|
| 5 |
-
from transformers.models.llama.modeling_llama import
|
| 6 |
-
LlamaModel,
|
| 7 |
-
)
|
| 8 |
from transformers.utils import logging
|
| 9 |
|
| 10 |
-
|
| 11 |
logger = logging.get_logger(__name__)
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
class LlamaBidirectionalConfig(LlamaConfig):
|
|
|
|
|
|
|
| 15 |
model_type = "llama_bidirec"
|
| 16 |
|
| 17 |
def __init__(
|
| 18 |
-
self, pooling="avg", temperature=1.0, **kwargs
|
| 19 |
-
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
self.pooling = pooling
|
| 21 |
self.temperature = temperature
|
| 22 |
-
super().__init__(**kwargs
|
| 23 |
|
| 24 |
|
| 25 |
class LlamaBidirectionalModel(LlamaModel):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
config_class = LlamaBidirectionalConfig
|
| 27 |
|
| 28 |
-
def __init__(self, config: LlamaConfig):
|
| 29 |
super().__init__(config)
|
| 30 |
for layer in self.layers:
|
| 31 |
layer.self_attn.is_causal = False
|
| 32 |
-
self.config._attn_implementation = "eager"
|
| 33 |
|
| 34 |
-
def
|
| 35 |
self,
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
|
|
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0.
|
| 3 |
+
"""
|
| 4 |
+
Bidirectional Llama model for embedding tasks.
|
| 5 |
+
|
| 6 |
+
This module provides a modified LlamaModel that uses bidirectional (non-causal)
|
| 7 |
+
attention, suitable for generating embeddings where each token should attend
|
| 8 |
+
to all other tokens in the sequence.
|
| 9 |
+
|
| 10 |
+
Supports transformers version 4.44 and above with a unified forward() implementation.
|
| 11 |
+
|
| 12 |
+
Version compatibility notes:
|
| 13 |
+
- transformers 4.47: Setting _attn_implementation in __init__ had no effect due to
|
| 14 |
+
attention initialization order
|
| 15 |
+
- transformers 4.48+: Attention refactor (transformers#35235) activated the
|
| 16 |
+
_attn_implementation setting, which defaulted to "eager" instead of "sdpa"
|
| 17 |
+
- transformers < 4.53: LlamaModel has _update_causal_mask method that can be overridden
|
| 18 |
+
- transformers 4.53+: _update_causal_mask removed; masking moved to masking_utils module,
|
| 19 |
+
necessitating a full forward() override for custom attention masks
|
| 20 |
+
- transformers < 4.54: Decoder layer returns tuple, uses past_key_value (singular)
|
| 21 |
+
- transformers 4.54-4.55: Decoder layer returns tensor, uses past_key_value (singular)
|
| 22 |
+
- transformers 4.56+: Decoder layer returns tensor, uses past_key_values (plural),
|
| 23 |
+
DynamicCache accepts config parameter
|
| 24 |
+
- transformers 5.0+: Has native create_bidirectional_mask in masking_utils
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
import inspect
|
| 28 |
+
|
| 29 |
import torch
|
| 30 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 31 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast
|
| 32 |
from transformers.models.llama.configuration_llama import LlamaConfig
|
| 33 |
+
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
|
|
|
|
|
|
|
| 34 |
from transformers.utils import logging
|
| 35 |
|
|
|
|
| 36 |
logger = logging.get_logger(__name__)
|
| 37 |
|
| 38 |
+
# Check if native create_bidirectional_mask exists (transformers >= 5.0)
|
| 39 |
+
try:
|
| 40 |
+
from transformers.masking_utils import create_bidirectional_mask
|
| 41 |
+
|
| 42 |
+
_HAS_NATIVE_BIDIRECTIONAL_MASK = True
|
| 43 |
+
except ImportError:
|
| 44 |
+
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
|
| 45 |
+
|
| 46 |
+
_HAS_NATIVE_BIDIRECTIONAL_MASK = False
|
| 47 |
+
|
| 48 |
+
# Detect API differences via introspection
|
| 49 |
+
_decoder_forward_params = inspect.signature(LlamaDecoderLayer.forward).parameters
|
| 50 |
+
_dynamic_cache_init_params = inspect.signature(DynamicCache.__init__).parameters
|
| 51 |
+
|
| 52 |
+
# past_key_value (singular) in < 4.56, past_key_values (plural) in >= 4.56
|
| 53 |
+
_USE_PLURAL_CACHE_PARAM = "past_key_values" in _decoder_forward_params
|
| 54 |
+
# DynamicCache accepts config parameter in >= 4.56
|
| 55 |
+
_DYNAMIC_CACHE_ACCEPTS_CONFIG = "config" in _dynamic_cache_init_params
|
| 56 |
+
|
| 57 |
|
| 58 |
class LlamaBidirectionalConfig(LlamaConfig):
|
| 59 |
+
"""Configuration for LlamaBidirectionalModel with pooling and temperature settings."""
|
| 60 |
+
|
| 61 |
model_type = "llama_bidirec"
|
| 62 |
|
| 63 |
def __init__(
|
| 64 |
+
self, pooling: str = "avg", temperature: float = 1.0, **kwargs
|
| 65 |
+
) -> None:
|
| 66 |
+
"""
|
| 67 |
+
Initialize bidirectional Llama configuration.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
pooling: Pooling strategy for embeddings ("avg", "cls", "last", etc.)
|
| 71 |
+
temperature: Temperature scaling for embeddings
|
| 72 |
+
**kwargs: Additional arguments passed to LlamaConfig
|
| 73 |
+
"""
|
| 74 |
self.pooling = pooling
|
| 75 |
self.temperature = temperature
|
| 76 |
+
super().__init__(**kwargs)
|
| 77 |
|
| 78 |
|
| 79 |
class LlamaBidirectionalModel(LlamaModel):
|
| 80 |
+
"""
|
| 81 |
+
LlamaModel modified to use bidirectional (non-causal) attention.
|
| 82 |
+
|
| 83 |
+
In standard Llama, each token can only attend to previous tokens (causal attention).
|
| 84 |
+
This model removes that restriction, allowing each token to attend to all tokens
|
| 85 |
+
in the sequence, which is useful for embedding tasks.
|
| 86 |
+
|
| 87 |
+
The key modifications are:
|
| 88 |
+
1. Setting is_causal=False on all attention layers
|
| 89 |
+
2. Using a bidirectional attention mask instead of causal mask
|
| 90 |
+
"""
|
| 91 |
+
|
| 92 |
config_class = LlamaBidirectionalConfig
|
| 93 |
|
| 94 |
+
def __init__(self, config: LlamaConfig) -> None:
|
| 95 |
super().__init__(config)
|
| 96 |
for layer in self.layers:
|
| 97 |
layer.self_attn.is_causal = False
|
|
|
|
| 98 |
|
| 99 |
+
def _create_bidirectional_mask(
|
| 100 |
self,
|
| 101 |
+
input_embeds: torch.Tensor,
|
| 102 |
+
attention_mask: torch.Tensor | None,
|
| 103 |
+
) -> torch.Tensor | None:
|
| 104 |
+
"""
|
| 105 |
+
Create bidirectional attention mask.
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
input_embeds: Input embeddings tensor of shape (batch_size, seq_len, hidden_size)
|
| 109 |
+
attention_mask: Optional 2D attention mask of shape (batch_size, seq_len)
|
| 110 |
+
where 1 indicates tokens to attend to and 0 indicates masked tokens
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
4D attention mask suitable for the attention implementation, or None
|
| 114 |
+
if no masking is needed
|
| 115 |
+
"""
|
| 116 |
+
if attention_mask is None:
|
| 117 |
+
return None
|
| 118 |
+
|
| 119 |
+
if _HAS_NATIVE_BIDIRECTIONAL_MASK:
|
| 120 |
+
return create_bidirectional_mask(
|
| 121 |
+
config=self.config,
|
| 122 |
+
input_embeds=input_embeds,
|
| 123 |
+
attention_mask=attention_mask,
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
# Fallback for transformers < 5.0 without create_bidirectional_mask
|
| 127 |
+
|
| 128 |
+
# Flash attention handles 2D masks internally; only pass mask if there
|
| 129 |
+
# are actually masked tokens (zeros), otherwise return None for efficiency
|
| 130 |
+
if getattr(self.config, "_attn_implementation", None) == "flash_attention_2":
|
| 131 |
+
has_masked_tokens = (attention_mask == 0).any()
|
| 132 |
+
return attention_mask if has_masked_tokens else None
|
| 133 |
+
|
| 134 |
+
return _prepare_4d_attention_mask(attention_mask, input_embeds.dtype)
|
| 135 |
+
|
| 136 |
+
def forward(
|
| 137 |
+
self,
|
| 138 |
+
input_ids: torch.LongTensor | None = None,
|
| 139 |
+
attention_mask: torch.Tensor | None = None,
|
| 140 |
+
position_ids: torch.LongTensor | None = None,
|
| 141 |
+
past_key_values: Cache | None = None,
|
| 142 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
| 143 |
+
cache_position: torch.LongTensor | None = None,
|
| 144 |
+
use_cache: bool | None = None,
|
| 145 |
+
**kwargs,
|
| 146 |
+
) -> BaseModelOutputWithPast:
|
| 147 |
+
"""
|
| 148 |
+
Forward pass with bidirectional attention.
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
input_ids: Input token IDs of shape (batch_size, seq_len)
|
| 152 |
+
attention_mask: Attention mask of shape (batch_size, seq_len)
|
| 153 |
+
position_ids: Position IDs for rotary embeddings
|
| 154 |
+
past_key_values: Cached key/value states for incremental decoding
|
| 155 |
+
inputs_embeds: Pre-computed input embeddings (alternative to input_ids)
|
| 156 |
+
cache_position: Position indices for cache updates
|
| 157 |
+
use_cache: Whether to return cached key/value states
|
| 158 |
+
**kwargs: Additional arguments passed to decoder layers
|
| 159 |
+
|
| 160 |
+
Returns:
|
| 161 |
+
BaseModelOutputWithPast containing last_hidden_state and past_key_values
|
| 162 |
+
"""
|
| 163 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 164 |
+
raise ValueError(
|
| 165 |
+
"You must specify exactly one of input_ids or inputs_embeds"
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
if inputs_embeds is None:
|
| 169 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 170 |
+
|
| 171 |
+
# Initialize cache if needed
|
| 172 |
+
if use_cache and past_key_values is None:
|
| 173 |
+
if _DYNAMIC_CACHE_ACCEPTS_CONFIG:
|
| 174 |
+
past_key_values = DynamicCache(config=self.config)
|
| 175 |
+
else:
|
| 176 |
+
past_key_values = DynamicCache()
|
| 177 |
+
|
| 178 |
+
if cache_position is None:
|
| 179 |
+
past_seen_tokens = (
|
| 180 |
+
past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 181 |
+
)
|
| 182 |
+
cache_position = torch.arange(
|
| 183 |
+
past_seen_tokens,
|
| 184 |
+
past_seen_tokens + inputs_embeds.shape[1],
|
| 185 |
+
device=inputs_embeds.device,
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
if position_ids is None:
|
| 189 |
+
position_ids = cache_position.unsqueeze(0)
|
| 190 |
+
|
| 191 |
+
bidirectional_mask = self._create_bidirectional_mask(
|
| 192 |
+
inputs_embeds, attention_mask
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
hidden_states = inputs_embeds
|
| 196 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 197 |
+
|
| 198 |
+
# Build decoder layer kwargs with correct cache parameter name
|
| 199 |
+
# (past_key_value in < 4.56, past_key_values in >= 4.56)
|
| 200 |
+
layer_kwargs = {
|
| 201 |
+
"attention_mask": bidirectional_mask,
|
| 202 |
+
"position_ids": position_ids,
|
| 203 |
+
"use_cache": use_cache,
|
| 204 |
+
"cache_position": cache_position,
|
| 205 |
+
"position_embeddings": position_embeddings,
|
| 206 |
+
}
|
| 207 |
+
if _USE_PLURAL_CACHE_PARAM:
|
| 208 |
+
layer_kwargs["past_key_values"] = past_key_values
|
| 209 |
+
else:
|
| 210 |
+
layer_kwargs["past_key_value"] = past_key_values
|
| 211 |
+
|
| 212 |
+
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 213 |
+
layer_outputs = decoder_layer(hidden_states, **layer_kwargs)
|
| 214 |
+
|
| 215 |
+
# Decoder returns tuple in < 4.54, tensor in >= 4.54
|
| 216 |
+
if isinstance(layer_outputs, tuple):
|
| 217 |
+
hidden_states = layer_outputs[0]
|
| 218 |
+
else:
|
| 219 |
+
hidden_states = layer_outputs
|
| 220 |
|
| 221 |
+
hidden_states = self.norm(hidden_states)
|
| 222 |
|
| 223 |
+
return BaseModelOutputWithPast(
|
| 224 |
+
last_hidden_state=hidden_states,
|
| 225 |
+
past_key_values=past_key_values,
|
| 226 |
+
)
|
|
@@ -1,9 +1,10 @@
|
|
|
|
|
|
|
|
| 1 |
from torch import Tensor
|
| 2 |
import torch
|
| 3 |
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
pool_type: str) -> Tensor:
|
| 7 |
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
|
| 8 |
|
| 9 |
if pool_type == "avg":
|
|
@@ -13,14 +14,16 @@ def pool(last_hidden_states: Tensor,
|
|
| 13 |
elif pool_type == "cls":
|
| 14 |
emb = last_hidden[:, 0]
|
| 15 |
elif pool_type == "last":
|
| 16 |
-
left_padding =
|
| 17 |
if left_padding:
|
| 18 |
emb = last_hidden[:, -1]
|
| 19 |
else:
|
| 20 |
sequence_lengths = attention_mask.sum(dim=1) - 1
|
| 21 |
batch_size = last_hidden.shape[0]
|
| 22 |
-
emb = last_hidden[
|
|
|
|
|
|
|
| 23 |
else:
|
| 24 |
raise ValueError(f"pool_type {pool_type} not supported")
|
| 25 |
|
| 26 |
-
return emb
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0.
|
| 3 |
from torch import Tensor
|
| 4 |
import torch
|
| 5 |
|
| 6 |
+
|
| 7 |
+
def pool(last_hidden_states: Tensor, attention_mask: Tensor, pool_type: str) -> Tensor:
|
|
|
|
| 8 |
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
|
| 9 |
|
| 10 |
if pool_type == "avg":
|
|
|
|
| 14 |
elif pool_type == "cls":
|
| 15 |
emb = last_hidden[:, 0]
|
| 16 |
elif pool_type == "last":
|
| 17 |
+
left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
|
| 18 |
if left_padding:
|
| 19 |
emb = last_hidden[:, -1]
|
| 20 |
else:
|
| 21 |
sequence_lengths = attention_mask.sum(dim=1) - 1
|
| 22 |
batch_size = last_hidden.shape[0]
|
| 23 |
+
emb = last_hidden[
|
| 24 |
+
torch.arange(batch_size, device=last_hidden.device), sequence_lengths
|
| 25 |
+
]
|
| 26 |
else:
|
| 27 |
raise ValueError(f"pool_type {pool_type} not supported")
|
| 28 |
|
| 29 |
+
return emb
|