# nemotron-colembed-vl-4b-v2 / processing_qwen3_vl_nemotron_embed.py
# Provenance (model-hub page metadata): uploaded by nvidia-oliver-holworthy;
# last commit a0e3da5 — "Extend transformers version compatibility 4.57.x through 5.1.x (#3)"
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Qwen3VLNemotronEmbed Processor for query and document processing."""
import math
from typing import Any, Dict, List, Optional, Union
import torch
from PIL import Image
from transformers import Qwen3VLProcessor
class Qwen3VLNemotronEmbedProcessor(Qwen3VLProcessor):
    """Processor for Qwen3VLNemotronEmbed that handles query/document processing.

    This processor extends Qwen3VLProcessor with methods for processing queries and
    documents for retrieval tasks.

    Args:
        image_processor: Image processor for vision inputs.
        tokenizer: Tokenizer for text inputs.
        chat_template: Optional chat template.
        q_max_length: Maximum length for query sequences (default: 512).
        p_max_length: Maximum length for passage/document sequences (default: 4096).
        query_prefix: Prefix to add to queries (default: "query:").
        passage_prefix: Prefix to add to passages (default: "passage:").
        reserve_tokens_for_images: Reserved tokens for image placeholders (default: 100).
    """

    # Override parent to exclude video_processor (not used by this model) and avoid
    # type-check failures on transformers <5.0 where BaseVideoProcessor doesn't exist.
    attributes = ["image_processor", "tokenizer"]
    video_processor_class = None

    # Extra attributes to serialize alongside the standard processor config.
    processor_attributes = [
        "q_max_length",
        "p_max_length",
        "query_prefix",
        "passage_prefix",
        "reserve_tokens_for_images",
    ]

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        chat_template=None,
        q_max_length: int = 512,
        p_max_length: int = 4096,
        query_prefix: str = "query:",
        passage_prefix: str = "passage:",
        reserve_tokens_for_images: int = 100,
        **kwargs,
    ):
        # Only pass chat_template positionally when provided, so the parent can
        # fall back to its own default template resolution otherwise.
        if chat_template is not None:
            super().__init__(image_processor, tokenizer, chat_template, **kwargs)
        else:
            super().__init__(image_processor, tokenizer, **kwargs)
        self.q_max_length = q_max_length
        self.p_max_length = p_max_length
        self.query_prefix = query_prefix
        self.passage_prefix = passage_prefix
        self.reserve_tokens_for_images = reserve_tokens_for_images
        # Cache vision geometry used by calculate_image_tokens.
        self.patch_size = self.image_processor.patch_size
        self.merge_size = self.image_processor.merge_size

    def apply_chat_template(
        self,
        conversation,
        chat_template=None,
        **kwargs,
    ) -> str:
        """Apply the tokenizer's chat template to a conversation.

        Args:
            conversation: Conversation in the tokenizer's chat-message format.
            chat_template: Optional template override; tokenizer default if None.

        Returns:
            The templated conversation (a string when ``tokenize=False`` is passed).
        """
        return self.tokenizer.apply_chat_template(
            conversation,
            chat_template=chat_template,
            **kwargs,
        )

    @property
    def min_pixels(self) -> int:
        """Minimum pixel budget, read from the image processor's size config."""
        return self.image_processor.size["shortest_edge"]

    @property
    def max_pixels(self) -> int:
        """Maximum pixel budget, read from the image processor's size config."""
        return self.image_processor.size["longest_edge"]

    def calculate_image_tokens(
        self,
        image: Image.Image,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
    ) -> int:
        """Calculate the number of tokens an image will use after processing.

        Mirrors the smart-resize rounding applied by the image processor: the
        image is snapped to a multiple of ``patch_size * merge_size`` on each
        side, then rescaled to fit within [min_pixels, max_pixels].

        Args:
            image: PIL Image to calculate tokens for.
            min_pixels: Minimum pixels for resizing (uses processor default if None).
            max_pixels: Maximum pixels for resizing (uses processor default if None).

        Returns:
            Number of tokens the image will consume.
        """
        # Use `is None` (not `or`) so an explicit 0 is respected rather than
        # silently replaced by the processor default.
        if min_pixels is None:
            min_pixels = self.min_pixels
        if max_pixels is None:
            max_pixels = self.max_pixels
        width, height = image.size
        # Each output token covers a (patch_size * merge_size)^2 pixel area.
        factor = self.patch_size * self.merge_size
        # Snap each side to the nearest multiple of `factor`.
        h_bar = round(height / factor) * factor
        w_bar = round(width / factor) * factor
        if h_bar * w_bar > max_pixels:
            # Too large: shrink (floor) while preserving aspect ratio,
            # never below one factor-sized tile per side.
            beta = math.sqrt((height * width) / max_pixels)
            h_bar = max(factor, math.floor(height / beta / factor) * factor)
            w_bar = max(factor, math.floor(width / beta / factor) * factor)
        elif h_bar * w_bar < min_pixels:
            # Too small: grow (ceil) while preserving aspect ratio.
            beta = math.sqrt(min_pixels / (height * width))
            h_bar = math.ceil(height * beta / factor) * factor
            w_bar = math.ceil(width * beta / factor) * factor
        grid_h = h_bar // self.patch_size
        grid_w = w_bar // self.patch_size
        num_patches = grid_h * grid_w
        # merge_size^2 patches are merged into one token.
        return num_patches // (self.merge_size**2)

    def process_queries(
        self,
        queries: List[Union[str, dict]],
        padding: bool = True,
        truncation: bool = True,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: str = "pt",
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Process text queries for retrieval.

        Args:
            queries: List of query strings or dicts with "text" key.
            padding: Whether to pad sequences.
            truncation: Whether to truncate sequences.
            pad_to_multiple_of: Pad to a multiple of this value.
            return_tensors: Return tensor type ("pt" for PyTorch).

        Returns:
            Dictionary with input_ids, attention_mask, and other model inputs.
        """
        query_texts = []
        for query in queries:
            raw_text = query["text"] if isinstance(query, dict) else query
            prefixed = f"{self.query_prefix} {raw_text}" if self.query_prefix else raw_text
            message = [
                {
                    "role": "user",
                    "content": [{"type": "text", "text": f"Query: {prefixed}"}],
                }
            ]
            templated = self.apply_chat_template(
                message, tokenize=False, add_generation_prompt=True
            )
            query_texts.append(templated)
        return self(
            text=query_texts,
            truncation=truncation,
            max_length=self.q_max_length,
            padding=padding,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            **kwargs,
        )

    def process_documents(
        self,
        documents: Union[Dict[str, List], List[Dict[str, Any]]],
        padding: bool = True,
        truncation: bool = True,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: str = "pt",
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Process image-text documents for retrieval.

        Args:
            documents: Either a dict with "image" and "text" keys containing lists,
                or a list of dicts each with "image" and "text" keys.
            padding: Whether to pad sequences.
            truncation: Whether to truncate sequences.
            pad_to_multiple_of: Pad to a multiple of this value.
            return_tensors: Return tensor type ("pt" for PyTorch).

        Returns:
            Dictionary with input_ids, attention_mask, pixel_values, and other model inputs.

        Raises:
            ValueError: If ``documents`` has an unsupported type, if text/image
                counts mismatch, or if ``p_max_length`` cannot accommodate the
                largest image plus the reserved token budget.
        """
        if isinstance(documents, dict):
            images = documents["image"]
            texts = documents["text"]
            # Raise (not assert) so validation survives `python -O`.
            if len(texts) != len(images):
                raise ValueError("Number of texts must match number of images")
        elif isinstance(documents, list):
            images = [d["image"] for d in documents]
            texts = [d["text"] for d in documents]
        else:
            raise ValueError("documents must be a dict or list of dicts")
        if self.passage_prefix:
            texts = [f"{self.passage_prefix} {t}" for t in texts]
        # Budget: the largest image in the batch plus a fixed reserve must fit
        # inside p_max_length, with the remainder available for text tokens.
        image_tokens_list = [self.calculate_image_tokens(img) for img in images]
        max_image_tokens = max(image_tokens_list) if image_tokens_list else 0
        if self.p_max_length <= max_image_tokens + self.reserve_tokens_for_images:
            raise ValueError(
                f"p_max_length ({self.p_max_length}) is too small for max_image_tokens "
                f"({max_image_tokens}) + reserve ({self.reserve_tokens_for_images})"
            )
        available_text_tokens = (
            self.p_max_length - max_image_tokens - self.reserve_tokens_for_images
        )
        # Round the text budget down to a multiple so padded batches align.
        if (
            pad_to_multiple_of is not None
            and available_text_tokens % pad_to_multiple_of != 0
        ):
            available_text_tokens = (
                available_text_tokens // pad_to_multiple_of
            ) * pad_to_multiple_of
        input_texts = []
        for text, image in zip(texts, images):
            message = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": text},
                    ],
                }
            ]
            input_text = self.apply_chat_template(
                message, tokenize=False, add_generation_prompt=True
            )
            input_texts.append(input_text)
        return self(
            text=input_texts,
            images=images,
            truncation=truncation,
            padding=padding,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            max_length=available_text_tokens,
            **kwargs,
        )