"""PyTorch Sarvam MoE model.""" import math from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F from torch import nn from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_attn_mask_utils import ( AttentionMaskConverter, _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, ) from transformers.modeling_outputs import MoeModelOutputWithPast from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from transformers.modeling_utils import PreTrainedModel from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS from transformers.utils import ( is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging, ) from transformers.generation.utils import GenerationMixin from dataclasses import dataclass from transformers.utils import ModelOutput if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input from .configuration_sarvam_moe import SarvamMoEConfig logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "SarvamMoEConfig" @dataclass class SarvamMoECausalLMOutputWithPast(ModelOutput): loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None past_key_values: Optional[Cache] = None hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None attentions: Optional[tuple[torch.FloatTensor, ...]] = None z_loss: Optional[torch.FloatTensor] = None aux_loss: Optional[torch.FloatTensor] = None router_logits: Optional[tuple[torch.FloatTensor]] = None class SarvamMoEModelOutputWithPast(MoeModelOutputWithPast): pass def _get_unpad_data(attention_mask): seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) return indices, cu_seqlens, max_seqlen_in_batch def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) def _make_causal_mask( input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 ): return AttentionMaskConverter._make_causal_mask( input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length ) class SarvamMoERMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) ALL_LAYERNORM_LAYERS.append(SarvamMoERMSNorm) class SarvamMoERotaryEmbedding(nn.Module): def __init__(self, config: SarvamMoEConfig, device=None): super().__init__() self.config = config self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is None: self.rope_type = "default" inv_freq, 


class SarvamMoERotaryEmbedding(nn.Module):
    def __init__(self, config: SarvamMoEConfig, device=None):
        super().__init__()
        self.config = config
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is None:
            self.rope_type = "default"
            inv_freq, self.attention_scaling = self.compute_default_rope_parameters(config, device)
        else:
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
            if self.rope_type == "default":
                inv_freq, self.attention_scaling = self.compute_default_rope_parameters(config, device)
            else:
                rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
                inv_freq, self.attention_scaling = rope_init_fn(config, device)

        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @staticmethod
    def compute_default_rope_parameters(
        config: SarvamMoEConfig,
        device: Optional[torch.device] = None,
        seq_len: Optional[int] = None,
    ) -> Tuple[torch.Tensor, float]:
        """
        Default RoPE parameters (classic rotary embedding).

        Mirrors HF's default implementation: uses `rope_theta` and `head_dim` and
        returns `(inv_freq, attention_scaling)`.
        """
        base = config.rope_theta
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64, device=device).to(dtype=torch.float32) / dim)
        )
        attention_factor = 1.0
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    rotary_dim = cos.shape[-1]
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
    q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
    k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
    q_embed = torch.cat([q_embed, q_pass], dim=-1)
    k_embed = torch.cat([k_embed, k_pass], dim=-1)
    return q_embed, k_embed


class SarvamMoEMLP(nn.Module):
    def __init__(self, config: SarvamMoEConfig, intermediate_size: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
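
# Router: scores experts with a sigmoid over a linear gate, then applies
# group-limited top-k selection (DeepSeek-V3-style routing): the experts are
# partitioned into `n_group` groups, each group is scored by the sum of its
# top-2 expert scores, only the best `topk_group` groups stay eligible, and
# the final top-k experts are chosen within them. Illustrative example (values
# not from the original source): with num_experts=8, n_group=2, topk_group=1,
# top_k=2, the 8 experts form 2 groups of 4; the group with the larger
# top-2 score sum is kept and the 2 best experts inside it are selected.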


class SarvamMoEGate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.num_experts
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(torch.empty((self.num_experts, self.gating_dim)))
        self.routed_scaling_factor = config.routed_scaling_factor
        self.score_function = config.score_function
        # Ideally, we would register `expert_bias` as a buffer, but vLLM complains about it.
        # self.register_buffer("expert_bias", torch.zeros((self.num_experts)))
        self.expert_bias = nn.Parameter(torch.zeros((self.num_experts)), requires_grad=False)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def group_limited_topk(self, scores: torch.Tensor):
        num_tokens, _ = scores.size()
        # Score each group by the sum of its top-2 expert scores.
        group_scores = scores.view(num_tokens, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(num_tokens, self.n_group, self.num_experts // self.n_group)
            .reshape(num_tokens, -1)
        )
        masked_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))
        probs, top_indices = torch.topk(masked_scores, k=self.top_k, dim=-1)
        return probs, top_indices

    def forward(self, hidden_states):
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
        scores = torch.sigmoid(logits.float()).type_as(logits)
        # The (frozen) bias only influences routing, not the combination weights.
        scores_for_routing = scores + self.expert_bias
        _, topk_idx = self.group_limited_topk(scores_for_routing)
        scores = torch.gather(scores, dim=1, index=topk_idx).type_as(logits)

        topk_weight = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if self.top_k > 1 else scores
        topk_weight = topk_weight * self.routed_scaling_factor
        return topk_idx, topk_weight, logits


class SarvamMoEExperts(nn.ModuleList):
    def __init__(self, config: SarvamMoEConfig):
        # One MLP per expert.
        experts = [
            SarvamMoEMLP(config=config, intermediate_size=config.moe_intermediate_size)
            for _ in range(config.num_experts)
        ]
        super().__init__(experts)
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok

    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.LongTensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """
        hidden_states: (tokens, hidden_size), i.e. (batch * seq, hidden_size)
        top_k_index: (tokens, top_k)
        top_k_weights: (tokens, top_k)
        """
        tokens, hidden_dim = hidden_states.shape
        flat_topk_idx = top_k_index.view(-1)

        if self.training:
            # Training path: duplicate each token once per selected expert, run the
            # matching expert on each copy, then combine with the routing weights.
            x = hidden_states.repeat_interleave(self.num_experts_per_tok, dim=0)
            y = torch.empty_like(x)
            for i, expert in enumerate(self):
                mask = flat_topk_idx == i
                if mask.any():
                    y[mask] = expert(x[mask])
            y = (y.view(*top_k_weights.shape, -1) * top_k_weights.unsqueeze(-1)).sum(dim=1)
            return y.to(hidden_states.dtype)

        # Inference path: sort token copies by expert id so each expert processes one
        # contiguous chunk, then scatter the outputs back and combine with the weights.
        num_experts = len(self)
        cnts = top_k_index.new_zeros((tokens, num_experts))
        cnts.scatter_(1, top_k_index, 1)
        tokens_per_expert = cnts.sum(dim=0)
        idxs = top_k_index.view(-1).argsort()
        sorted_tokens = hidden_states[idxs // top_k_index.shape[1]]
        tokens_per_expert = tokens_per_expert.cpu().numpy().tolist()

        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self[i]
            tokens_for_expert = sorted_tokens[start_idx:end_idx]
            expert_out = expert(tokens_for_expert)
            outputs.append(expert_out.to(hidden_states.device))
            start_idx = end_idx

        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (
            new_x.view(*top_k_index.shape, -1)
            .type(top_k_weights.dtype)
            .mul_(top_k_weights.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )
        return final_out
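
# MoE block: each token's output is the weighted sum of its top-k routed
# experts plus (optionally) an always-on shared expert, i.e. roughly
#     y = sum_k w_k * expert_{i_k}(x) + shared_experts(x)
# where (i_k, w_k) come from SarvamMoEGate and the shared expert is a single
# SarvamMoEMLP with `moe_intermediate_size * num_shared_experts` hidden units.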


class SarvamMoESparseMoeBlock(nn.Module):
    def __init__(self, config: SarvamMoEConfig):
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok
        self.experts = SarvamMoEExperts(config)
        self.gate = SarvamMoEGate(config)
        if config.num_shared_experts is not None:
            self.shared_experts = SarvamMoEMLP(
                config=config,
                intermediate_size=config.moe_intermediate_size * config.num_shared_experts,
            )

    def forward(self, hidden_states):
        identity = hidden_states
        bsz, seq_len, h = hidden_states.shape
        topk_idx, topk_weight, router_logits = self.gate(hidden_states)

        # Flatten batch and sequence dimensions for the experts.
        flat_hidden = hidden_states.view(-1, h)
        flat_topk_idx = topk_idx.view(-1, topk_idx.shape[-1])
        flat_topk_weight = topk_weight.view(-1, topk_weight.shape[-1])

        y = self.experts(flat_hidden, flat_topk_idx, flat_topk_weight)
        y = y.view(bsz, seq_len, h)

        if self.config.num_shared_experts is not None:
            y = y + self.shared_experts(identity)

        # Router logits shape: (bsz, seq_len, num_experts).
        router_info = (
            router_logits.view(bsz, seq_len, -1),
            topk_idx.view(bsz, seq_len, -1),
        )
        return y, router_info


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
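
# Grouped-query attention (GQA): `num_key_value_heads` K/V heads are shared
# across `num_attention_heads` query heads; `repeat_kv` expands K/V from
# (batch, num_kv_heads, seq, head_dim) to (batch, num_heads, seq, head_dim).
# Illustrative example (values not from the original source): with 32 query
# heads and 8 KV heads, n_rep = 4 and each KV head serves 4 query heads.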


class SarvamMoEAttention(nn.Module):
    is_causal = True  # Critical flag for the vLLM / Transformers attention backends.

    def __init__(self, config: SarvamMoEConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a "
                "`layer_idx` when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim or self.hidden_size // self.num_heads
        partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
        self.rope_dim = int(self.head_dim * partial_rotary_factor)
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        self.query_key_value = nn.Linear(
            self.hidden_size,
            (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
            bias=config.use_qkv_bias,
        )
        if self.config.use_qk_norm:
            self.query_layernorm = SarvamMoERMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.key_layernorm = SarvamMoERMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.use_bias)
        self.scaling = self.head_dim**-0.5

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
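
    # The QKV projection is fused: `query_key_value` produces
    # (num_heads + 2 * num_key_value_heads) * head_dim features per token, which
    # `forward` reshapes to (bsz, q_len, num_heads + 2 * num_kv_heads, head_dim)
    # and splits into Q (num_heads), K (num_kv_heads) and V (num_kv_heads).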
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ):
        bsz, q_len, _ = hidden_states.size()

        qkv = self.query_key_value(hidden_states)
        qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
        query_states, key_states, value_states = qkv.split(
            [self.num_heads, self.num_key_value_heads, self.num_key_value_heads],
            dim=-2,
        )
        query_states = query_states.transpose(1, 2).contiguous()
        key_states = key_states.transpose(1, 2).contiguous()
        value_states = value_states.transpose(1, 2).contiguous()

        if self.config.use_qk_norm:
            query_states = self.query_layernorm(query_states)
            key_states = self.key_layernorm(key_states)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError("When using cache, SarvamMoEAttention must be initialized with layer_idx.")
            cache_kwargs = {"sin": sin, "cos": cos}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # NOTE: vLLM will set config._attn_implementation = "vllm".
        if self.config._attn_implementation == "vllm":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
            attn_output, attn_weights = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask,
                dropout=0.0 if not self.training else self.attention_dropout,
                scaling=self.scaling,
                **kwargs,
            )
            # The vLLM backend may return [B, H, L, Dh], [B, L, hidden] or [B*L, hidden].
            if attn_output.dim() == 4:
                # [B, H, L, Dh] -> [B, L, hidden]
                attn_output = attn_output.transpose(1, 2).contiguous()
                attn_output = attn_output.view(bsz, q_len, -1)
            elif attn_output.dim() == 3:
                if attn_output.shape[0] != bsz or attn_output.shape[1] != q_len:
                    raise ValueError(
                        f"Unexpected vLLM attention output shape {attn_output.shape}, "
                        f"expected (bsz={bsz}, q_len={q_len}, hidden=*)"
                    )
            elif attn_output.dim() == 2:
                attn_output = attn_output.view(bsz, q_len, -1)
            else:
                raise ValueError(
                    f"Unsupported vLLM attention output rank {attn_output.dim()} with shape {attn_output.shape}"
                )
            attn_output = self.dense(attn_output)
            if not output_attentions:
                attn_weights = None
            return attn_output, attn_weights, past_key_value

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        kv_seq_len = key_states.shape[-2]
        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.dense(attn_output)

        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value
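
# Flash Attention 2 backend: identical math to the eager path, but K/V are kept
# un-repeated (flash-attn handles GQA natively when the query head count is a
# multiple of the KV head count) and padded batches are unpadded into
# variable-length sequences before the fused kernel runs.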


class SarvamMoEFlashAttention2(SarvamMoEAttention):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # flash-attn < 2.1 uses a top-left-aligned causal mask; >= 2.1 uses a bottom-right-aligned one.
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # Flash attention kernels do not return attention weights.
        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        qkv = self.query_key_value(hidden_states)
        qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
        query_states, key_states, value_states = qkv.split(
            [self.num_heads, self.num_key_value_heads, self.num_key_value_heads], dim=-2
        )
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        if self.config.use_qk_norm:
            query_states = self.query_layernorm(query_states)
            key_states = self.key_layernorm(key_states)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # flash-attn expects (batch, seq, heads, head_dim).
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            elif torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            else:
                target_dtype = self.query_key_value.weight.dtype

            logger.warning_once(
                f"The input hidden states seem to have been silently cast to float32; this might be related to"
                f" having upcast embedding or layer norm layers in float32. We will cast the input back to"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = self._flash_attention_forward(
            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
        )

        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.dense(attn_output)

        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value

    def _flash_attention_forward(
        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
    ):
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            causal = self.is_causal and query_length != 1

        if attention_mask is not None:
            batch_size = query_states.shape[0]
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            attn_output = flash_attn_func(
                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
            )

        return attn_output
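
    # `_upad_input` strips padding tokens so flash_attn_varlen_func sees one packed
    # stream of tokens plus cumulative sequence-length offsets. Illustrative example
    # (not from the original source): masks [[1, 1, 0], [1, 1, 1]] become a packed
    # batch of 5 tokens with cu_seqlens = [0, 2, 5] and max_seqlen_in_batch = 3.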
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=query_layer.device)
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )


class SarvamMoESdpaAttention(SarvamMoEAttention):
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # SDPA cannot return attention weights; fall back to the eager implementation.
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                **kwargs,
            )

        bsz, q_len, _ = hidden_states.size()

        qkv = self.query_key_value(hidden_states)
        qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
        query_states, key_states, value_states = qkv.split(
            [self.num_heads, self.num_key_value_heads, self.num_key_value_heads], dim=-2
        )
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        if self.config.use_qk_norm:
            query_states = self.query_layernorm(query_states)
            key_states = self.key_layernorm(key_states)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        if attention_mask is not None:
            kv_seq_len = key_states.shape[-2]
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

        # SDPA with a custom mask can dispatch to kernels that require contiguous inputs on CUDA.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=self.is_causal and attention_mask is None and q_len > 1,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.dense(attn_output)

        return attn_output, None, past_key_value


ATTENTION_CLASSES = {
    "eager": SarvamMoEAttention,
    "flash_attention_2": SarvamMoEFlashAttention2,
    "sdpa": SarvamMoESdpaAttention,
    "vllm": SarvamMoEAttention,
}
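
# Each decoder layer picks its attention backend from ATTENTION_CLASSES based
# on config._attn_implementation; "vllm" maps to the base class, whose forward
# routes through ALL_ATTENTION_FUNCTIONS when vLLM sets that implementation.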


class SarvamMoEDecoderLayer(nn.Module):
    def __init__(self, config: SarvamMoEConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.attention = ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
        self.mlp = (
            SarvamMoESparseMoeBlock(config)
            if (config.num_experts is not None and layer_idx >= config.first_k_dense_replace)
            else SarvamMoEMLP(config=config, intermediate_size=config.intermediate_size)
        )
        self.input_layernorm = SarvamMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = SarvamMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        hidden_states, self_attn_weights, present_key_value = self.attention(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            position_embeddings=position_embeddings,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        if isinstance(hidden_states, tuple):
            hidden_states, router_logits = hidden_states
        else:
            router_logits = None
        hidden_states = residual + hidden_states.to(residual.device)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        if use_cache:
            outputs += (present_key_value,)
        if output_router_logits:
            outputs += (router_logits,)
        return outputs


class SarvamMoEPreTrainedModel(PreTrainedModel):
    config_class = SarvamMoEConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SarvamMoEDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
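
# Backbone: token embeddings -> num_hidden_layers decoder layers -> final RMSNorm.
# Layers with index < first_k_dense_replace use a dense SarvamMoEMLP; the
# remaining layers use the sparse MoE block (see SarvamMoEDecoderLayer above).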


class SarvamMoEModel(SarvamMoEPreTrainedModel):
    _supports_attention_backend = True

    def __init__(self, config: SarvamMoEConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [SarvamMoEDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._use_sdpa = config._attn_implementation == "sdpa"
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.norm = SarvamMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = SarvamMoERotaryEmbedding(config=config)

        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.word_embeddings

    def set_input_embeddings(self, value):
        self.word_embeddings = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, SarvamMoEModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
                )
                use_cache = False

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        if position_ids is None:
            position_ids = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
            position_ids = position_ids.unsqueeze(0)

        if self._use_flash_attention_2:
            # Flash attention takes a 2D padding mask, or no mask at all if nothing is padded.
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        elif self._use_sdpa and not output_attentions:
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_seen_tokens,
            )
        else:
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, (batch_size, seq_length), inputs_embeds, past_seen_tokens
            )

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_router_logits = () if output_router_logits else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    output_router_logits,
                    use_cache,
                    position_embeddings,
                    **kwargs,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    output_router_logits=output_router_logits,
                    use_cache=use_cache,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
            if output_attentions:
                all_self_attns += (layer_outputs[1],)
            if output_router_logits and layer_outputs[-1] is not None:
                all_router_logits += (layer_outputs[-1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
                if v is not None
            )
        return SarvamMoEModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            router_logits=all_router_logits,
        )
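
# Causal LM head on top of the backbone. `lm_head` projects the final hidden
# states to vocabulary logits; `_tied_weights_keys` lets Transformers tie it
# to the input embeddings when the config requests weight tying.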


class SarvamMoEForCausalLM(SarvamMoEPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: SarvamMoEConfig):
        super().__init__(config)
        self.model = SarvamMoEModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.word_embeddings

    def set_input_embeddings(self, value):
        self.model.word_embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, SarvamMoECausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=return_dict,
            **kwargs,
        )

        loss = None
        aux_loss = None
        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        if labels is not None:
            loss = self.loss_function(logits, labels, self.config.vocab_size, **kwargs)

        if not return_dict:
            output = (logits,) + outputs[1:]
            if output_router_logits:
                output = (aux_loss,) + output
            return (loss,) + output if loss is not None else output

        return SarvamMoECausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            aux_loss=aux_loss,
            router_logits=outputs.router_logits,
        )
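

# Minimal usage sketch (illustrative; the checkpoint id below is hypothetical
# and assumes this module ships as `trust_remote_code` custom model code):
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("org/sarvam-moe", trust_remote_code=True)
#     model = AutoModelForCausalLM.from_pretrained("org/sarvam-moe", trust_remote_code=True)
#     inputs = tokenizer("Hello", return_tensors="pt")
#     generated = model.generate(**inputs, max_new_tokens=32)
#     print(tokenizer.decode(generated[0], skip_special_tokens=True))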