"""Qwen3VLNemotronEmbed Processor for query and document processing."""
| |
|
| | import math |
| | from typing import Any, Dict, List, Optional, Union |
| |
|
| | import torch |
| | from PIL import Image |
| | from transformers import Qwen3VLProcessor |
| |
|
| |
|
class Qwen3VLNemotronEmbedProcessor(Qwen3VLProcessor):
    """Processor for Qwen3VLNemotronEmbed that handles query/document processing.

    This processor extends Qwen3VLProcessor with methods for processing queries and
    documents for retrieval tasks.

    Args:
        image_processor: Image processor for vision inputs.
        tokenizer: Tokenizer for text inputs.
        chat_template: Optional chat template.
        q_max_length: Maximum length for query sequences (default: 512).
        p_max_length: Maximum length for passage/document sequences (default: 4096).
        query_prefix: Prefix to add to queries (default: "query:").
        passage_prefix: Prefix to add to passages (default: "passage:").
        reserve_tokens_for_images: Reserved tokens for image placeholders (default: 100).
    """

    attributes = ["image_processor", "tokenizer"]
    video_processor_class = None

    # Extra attributes (beyond the base processor's) that configure the
    # retrieval-specific behavior of this processor.
    processor_attributes = [
        "q_max_length",
        "p_max_length",
        "query_prefix",
        "passage_prefix",
        "reserve_tokens_for_images",
    ]

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        chat_template=None,
        q_max_length: int = 512,
        p_max_length: int = 4096,
        query_prefix: str = "query:",
        passage_prefix: str = "passage:",
        reserve_tokens_for_images: int = 100,
        **kwargs,
    ):
        # Forward chat_template positionally only when provided so the base
        # class can apply its own default resolution when it is None.
        if chat_template is not None:
            super().__init__(image_processor, tokenizer, chat_template, **kwargs)
        else:
            super().__init__(image_processor, tokenizer, **kwargs)

        self.q_max_length = q_max_length
        self.p_max_length = p_max_length
        self.query_prefix = query_prefix
        self.passage_prefix = passage_prefix
        self.reserve_tokens_for_images = reserve_tokens_for_images

        # Vision geometry cached from the image processor; used by
        # calculate_image_tokens() to predict per-image token counts.
        self.patch_size = self.image_processor.patch_size
        self.merge_size = self.image_processor.merge_size

    def apply_chat_template(
        self,
        conversation,
        chat_template=None,
        **kwargs,
    ) -> str:
        """Apply chat template to conversation.

        Delegates directly to the tokenizer's chat-template rendering.
        """
        return self.tokenizer.apply_chat_template(
            conversation,
            chat_template=chat_template,
            **kwargs,
        )

    @property
    def min_pixels(self) -> int:
        """Get min_pixels from image processor."""
        return self.image_processor.size["shortest_edge"]

    @property
    def max_pixels(self) -> int:
        """Get max_pixels from image processor."""
        return self.image_processor.size["longest_edge"]

    def calculate_image_tokens(
        self,
        image: Image.Image,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
    ) -> int:
        """Calculate the number of tokens an image will use after processing.

        Mirrors the smart-resize rounding the image processor applies: snap
        each side to a multiple of ``patch_size * merge_size``, then rescale
        if the pixel area falls outside [min_pixels, max_pixels].

        Args:
            image: PIL Image to calculate tokens for.
            min_pixels: Minimum pixels for resizing (uses processor default if None).
            max_pixels: Maximum pixels for resizing (uses processor default if None).

        Returns:
            Number of tokens the image will consume.
        """
        # Treat only None as "use the processor default"; the previous
        # `min_pixels or self.min_pixels` fallback conflated 0 with None.
        if min_pixels is None:
            min_pixels = self.min_pixels
        if max_pixels is None:
            max_pixels = self.max_pixels

        width, height = image.size
        factor = self.patch_size * self.merge_size

        # Snap both sides to the nearest multiple of the patch*merge factor.
        h_bar = round(height / factor) * factor
        w_bar = round(width / factor) * factor

        if h_bar * w_bar > max_pixels:
            # Too large: shrink while keeping the aspect ratio, rounding down
            # but never below one factor per side.
            beta = math.sqrt((height * width) / max_pixels)
            h_bar = max(factor, math.floor(height / beta / factor) * factor)
            w_bar = max(factor, math.floor(width / beta / factor) * factor)
        elif h_bar * w_bar < min_pixels:
            # Too small: grow while keeping the aspect ratio, rounding up.
            beta = math.sqrt(min_pixels / (height * width))
            h_bar = math.ceil(height * beta / factor) * factor
            w_bar = math.ceil(width * beta / factor) * factor

        grid_h = h_bar // self.patch_size
        grid_w = w_bar // self.patch_size
        num_patches = grid_h * grid_w
        # merge_size**2 patches are merged into a single visual token.
        return num_patches // (self.merge_size**2)

    def process_queries(
        self,
        queries: List[Union[str, dict]],
        padding: bool = True,
        truncation: bool = True,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: str = "pt",
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Process text queries for retrieval.

        Args:
            queries: List of query strings or dicts with "text" key.
            padding: Whether to pad sequences.
            truncation: Whether to truncate sequences.
            pad_to_multiple_of: Pad to a multiple of this value.
            return_tensors: Return tensor type ("pt" for PyTorch).

        Returns:
            Dictionary with input_ids, attention_mask, and other model inputs.
        """
        query_texts = []
        for query in queries:
            if isinstance(query, dict):
                query_text = query["text"]
            else:
                query_text = query

            # Prepend the configured retrieval prefix (e.g. "query:") when set.
            prefixed = f"{self.query_prefix} {query_text}" if self.query_prefix else query_text
            message = [
                {
                    "role": "user",
                    "content": [{"type": "text", "text": f"Query: {prefixed}"}],
                }
            ]
            query_text = self.apply_chat_template(
                message, tokenize=False, add_generation_prompt=True
            )
            query_texts.append(query_text)

        return self(
            text=query_texts,
            truncation=truncation,
            max_length=self.q_max_length,
            padding=padding,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            **kwargs,
        )

    def process_documents(
        self,
        documents: Union[Dict[str, List], List[Dict[str, Any]]],
        padding: bool = True,
        truncation: bool = True,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: str = "pt",
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Process image-text documents for retrieval.

        Args:
            documents: Either a dict with "image" and "text" keys containing lists,
                or a list of dicts each with "image" and "text" keys.
            padding: Whether to pad sequences.
            truncation: Whether to truncate sequences.
            pad_to_multiple_of: Pad to a multiple of this value.
            return_tensors: Return tensor type ("pt" for PyTorch).

        Returns:
            Dictionary with input_ids, attention_mask, pixel_values, and other model inputs.

        Raises:
            ValueError: If ``documents`` has an unsupported type, if the number
                of texts and images differ, or if ``p_max_length`` cannot fit
                the largest image plus the reserved image tokens.
        """
        if isinstance(documents, dict):
            images = documents["image"]
            texts = documents["text"]
            # Validate with a real exception: `assert` is stripped under -O.
            if len(texts) != len(images):
                raise ValueError("Number of texts must match number of images")
        elif isinstance(documents, list):
            images = [d["image"] for d in documents]
            texts = [d["text"] for d in documents]
        else:
            raise ValueError("documents must be a dict or list of dicts")

        if self.passage_prefix:
            texts = [f"{self.passage_prefix} {t}" for t in texts]

        # Budget: the worst-case image in the batch plus a fixed reserve must
        # leave room for text within p_max_length.
        image_tokens_list = [self.calculate_image_tokens(img) for img in images]
        max_image_tokens = max(image_tokens_list) if image_tokens_list else 0

        if self.p_max_length <= max_image_tokens + self.reserve_tokens_for_images:
            raise ValueError(
                f"p_max_length ({self.p_max_length}) is too small for max_image_tokens "
                f"({max_image_tokens}) + reserve ({self.reserve_tokens_for_images})"
            )
        available_text_tokens = (
            self.p_max_length - max_image_tokens - self.reserve_tokens_for_images
        )

        # Round the text budget down to the padding multiple so padded batches
        # never exceed the computed budget.
        if (
            pad_to_multiple_of is not None
            and available_text_tokens % pad_to_multiple_of != 0
        ):
            available_text_tokens = (
                available_text_tokens // pad_to_multiple_of
            ) * pad_to_multiple_of

        input_texts = []
        for text, image in zip(texts, images):
            message = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": text},
                    ],
                }
            ]
            input_text = self.apply_chat_template(
                message, tokenize=False, add_generation_prompt=True
            )
            input_texts.append(input_text)

        return self(
            text=input_texts,
            images=images,
            truncation=truncation,
            padding=padding,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            max_length=available_text_tokens,
            **kwargs,
        )
| |
|