| | """ |
| | # Copyright 2025 The HuggingFace Inc. team. |
| | # |
| | # Licensed under the Apache License, Version 2.0 (the "License"); |
| | # you may not use this file except in compliance with the License. |
| | # You may obtain a copy of the License at |
| | # |
| | # http://www.apache.org/licenses/LICENSE-2.0 |
| | # |
| | # Unless required by applicable law or agreed to in writing, software |
| | # distributed under the License is distributed on an "AS IS" BASIS, |
| | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| | # See the License for the specific language governing permissions and |
| | # limitations under the License. |
| | |
| | Modeling for ColQwen3 retrieval, aligned with the ColQwen2 reference implementation. |
| | """ |

from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn

from transformers import AutoModelForImageTextToText
from transformers.cache_utils import Cache
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig
from transformers.utils import ModelOutput, auto_docstring, can_return_tuple, logging

from .configuration_colqwen3 import ColQwen3Config

logger = logging.get_logger(__name__)


@auto_docstring
class ColQwen3PreTrainedModel(PreTrainedModel):
    config_class = ColQwen3Config
    base_model_prefix = "model"
    _no_split_modules = []
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_flex_attn = True

    def _init_weights(self, module):
        # Use the top-level initializer_range when present, otherwise fall back to the
        # nested text config (default 0.02).
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else getattr(self.config.text_config, "initializer_range", 0.02)
        )
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for ColQwen3 embeddings output.
    """
)
class ColQwen3ForRetrievalOutput(ModelOutput):
    r"""
    embeddings (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`):
        The projected, L2-normalized token-level embeddings of the model.
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance.
    """

    loss: Optional[torch.FloatTensor] = None
    embeddings: Optional[torch.Tensor] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None


@auto_docstring(
    custom_intro="""
    ColQwen3 retrieval model that mirrors the ColQwen2 late-interaction pipeline while using a Qwen3-VL backbone.
    """
)
class ColQwen3(ColQwen3PreTrainedModel):
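    # Regex remapping applied when loading checkpoints so that raw Qwen3-VL / ColQwen3 state
    # dict keys resolve to this wrapper's `vlm.` namespace (most specific patterns first).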
    _checkpoint_conversion_mapping = {
        r"^model\.visual": "vlm.model.visual",
        r"^model\.language_model": "vlm.model.language_model",
        r"^model\.": "vlm.model.",
        r"^visual": "vlm.model.visual",
        r"^language_model": "vlm.model.language_model",
        r"^custom_text_proj": "embedding_proj_layer",
    }
    config_class = ColQwen3Config
    model_type = ColQwen3Config.model_type

    def __init__(
        self,
        config: ColQwen3Config,
        attn_impl: Optional[str] = None,
        mask_non_image_embeddings: bool = False,
    ):
        """
        Args:
            config (`ColQwen3Config`):
                Configuration carrying the nested vision/text configs for the retrieval model.
            attn_impl (`str`, *optional*):
                Attention implementation forwarded to the VLM (e.g. `"flash_attention_2"`).
            mask_non_image_embeddings (`bool`, *optional*, defaults to `False`):
                Whether to zero out all non-image embeddings after projection.
        """
        super().__init__(config)
        self.config = config

        # Accept either nested `PretrainedConfig` objects or plain dicts for the sub-configs.
        vision_cfg = (
            config.vision_config.to_dict() if isinstance(config.vision_config, PretrainedConfig) else config.vision_config
        )
        text_cfg = config.text_config.to_dict() if isinstance(config.text_config, PretrainedConfig) else config.text_config
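
        # Rebuild a full Qwen3-VL config from the nested sub-configs; the token-id defaults
        # below are the standard Qwen-VL special token ids and are only used when the
        # ColQwen3 config does not carry its own values.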
        vlm_config = Qwen3VLConfig(
            text_config=text_cfg,
            vision_config=vision_cfg,
            image_token_id=getattr(config, "image_token_id", 151655),
            video_token_id=getattr(config, "video_token_id", 151656),
            vision_start_token_id=getattr(config, "vision_start_token_id", 151652),
            vision_end_token_id=getattr(config, "vision_end_token_id", 151653),
            tie_word_embeddings=getattr(config.text_config, "tie_word_embeddings", False),
        )
        self.vlm = AutoModelForImageTextToText.from_config(vlm_config)
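
        # ColBERT-style head: project each token's last hidden state to a compact retrieval
        # embedding space (`config.embed_dim`, typically 128 in ColQwen-family models).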
        self.embedding_dim = self.config.embed_dim
        self.embedding_proj_layer = nn.Linear(
            self.vlm.config.text_config.hidden_size,
            self.embedding_dim,
        )
        self.padding_side = getattr(config, "padding_side", "left")
        self.mask_non_image_embeddings = mask_non_image_embeddings
        self._tied_weights_keys = [f"vlm.{k}" for k in (self.vlm._tied_weights_keys or [])]

        self.post_init()

        if attn_impl is not None and hasattr(self.vlm, "set_attn_implementation"):
            self.vlm.set_attn_implementation(attn_impl)

    @classmethod
    def from_pretrained(cls, *args, config: Optional[ColQwen3Config] = None, **kwargs):
        # Apply the class-level checkpoint key remapping unless the caller supplies one.
        key_mapping = kwargs.pop("key_mapping", None)
        if key_mapping is None:
            key_mapping = getattr(cls, "_checkpoint_conversion_mapping", None)
        return super().from_pretrained(*args, config=config, **kwargs, key_mapping=key_mapping)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        pixel_values_videos: Optional[torch.Tensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
    ) -> ColQwen3ForRetrievalOutput:
        r"""
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width dimensions of the feature grid for each image fed to the LLM.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width dimensions of the feature grid for each video fed to the LLM.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # `return_dict` (tuple vs. ModelOutput) is handled by the @can_return_tuple decorator.
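
        # Run the Qwen3-VL backbone (vision tower + language model); a ModelOutput is requested
        # unconditionally so the attribute accesses below stay valid even when the caller asked
        # for a tuple.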
        vlm_output = self.vlm.model(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
        )

        vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None

        # Project every token's last hidden state to the retrieval embedding space, casting to
        # the projection layer's dtype (relevant under mixed precision).
        last_hidden_states = vlm_output[0]
        proj_dtype = self.embedding_proj_layer.weight.dtype
        embeddings = self.embedding_proj_layer(last_hidden_states.to(proj_dtype))

        # L2-normalize each token embedding (eps-clamped to avoid division by zero), then zero
        # out padding positions so they cannot contribute to late-interaction scores.
        denom = embeddings.norm(dim=-1, keepdim=True).clamp_min(torch.finfo(embeddings.dtype).eps)
        embeddings = embeddings / denom
        if attention_mask is not None:
            embeddings = embeddings * attention_mask.unsqueeze(-1)

        # Optionally keep only image-token embeddings, e.g. for pure visual document retrieval.
        if pixel_values is not None and self.mask_non_image_embeddings:
            image_mask = (input_ids == self.vlm.config.image_token_id).unsqueeze(-1)
            embeddings = embeddings * image_mask

        return ColQwen3ForRetrievalOutput(
            embeddings=embeddings,
            past_key_values=vlm_output.past_key_values,
            hidden_states=vlm_hidden_states,
            attentions=vlm_output.attentions,
        )

    def get_input_embeddings(self):
        return self.vlm.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.vlm.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.vlm.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.vlm.set_output_embeddings(new_embeddings)

    def tie_weights(self):
        return self.vlm.tie_weights()

    def resize_token_embeddings(
        self,
        new_num_tokens: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        mean_resizing: bool = True,
    ) -> nn.Embedding:
        model_embeds = self.vlm.resize_token_embeddings(
            new_num_tokens=new_num_tokens,
            pad_to_multiple_of=pad_to_multiple_of,
            mean_resizing=mean_resizing,
        )

        # Keep the wrapper config and the wrapped VLM config in sync with the new vocab size.
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.vlm.config.text_config.vocab_size = model_embeds.num_embeddings
        self.vlm.config.vocab_size = model_embeds.num_embeddings
        return model_embeds


__all__ = ["ColQwen3", "ColQwen3PreTrainedModel", "ColQwen3ForRetrievalOutput"]
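
# Late-interaction scoring sketch (not part of this module's API; names are illustrative).
# Both inputs are per-token embeddings produced by ColQwen3.forward, already L2-normalized:
#
#     import torch
#
#     def maxsim_score(query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor) -> torch.Tensor:
#         """query_embeddings: (num_query_tokens, dim); doc_embeddings: (num_doc_tokens, dim)."""
#         sim = query_embeddings @ doc_embeddings.T  # cosine similarities, (q_tokens, d_tokens)
#         return sim.max(dim=1).values.sum()  # MaxSim: best doc token per query token, summed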