# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0.
import base64
import dataclasses
import os
from dataclasses import field
from io import BytesIO
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import requests
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import BatchEncoding, ProcessorMixin

# Per-channel normalization statistics for the two supported vision backbones.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
SIGLIP_MEAN = (0.5, 0.5, 0.5)
SIGLIP_STD = (0.5, 0.5, 0.5)


@dataclasses.dataclass
class Conversation:
    """Manages prompt construction with system messages and multi-turn dialogues."""

    # System instruction prepended to prompts
    system_message: str = ""
    # Role identifiers for dialogue turns
    roles: Tuple[str, str] = ("", "")
    # Message history as (role, content) pairs
    messages: List[List[str]] = field(default_factory=list)
    # Separator token between messages
    sep: str = ""
    # Token IDs that trigger generation stopping
    stop_token_ids: Optional[List[int]] = None

    def get_prompt(self) -> str:
        """Construct the formatted prompt string from system message and dialogue history."""
        ret = self.system_message + self.sep
        for role, message in self.messages:
            if message:
                ret += role + message + self.sep
            else:
                # A turn without content (e.g. the pending assistant turn)
                # contributes only its role marker, with no trailing separator.
                ret += role
        return ret

    def append_message(self, role: str, message: str):
        """Add a message turn to the dialogue history."""
        self.messages.append([role, message])


def get_conv_template(name: str) -> Conversation:
    """Initialize a conversation instance with default configuration.

    Args:
        name: Template identifier. Currently unused — every name yields the
            same default configuration.

    Returns:
        A fresh ``Conversation`` with the default stop token IDs.
    """
    return Conversation(
        stop_token_ids=[128259, 128001],
    )


def load_image(image) -> Image.Image:
    """Load a PIL image from one of several supported representations.

    Args:
        image: A ``PIL.Image.Image`` (returned unchanged), a filesystem path,
            or a dict with one of the keys ``"disk_path"``, ``"base64"``,
            ``"url"``, or ``"bytes"``.

    Returns:
        The loaded ``PIL.Image.Image``.

    Raises:
        ValueError: If ``image`` is not one of the supported forms.
    """
    if isinstance(image, Image.Image):
        return image
    elif isinstance(image, str) and os.path.exists(image):
        return Image.open(image)
    elif isinstance(image, dict):
        if "disk_path" in image:
            return Image.open(image["disk_path"])
        elif "base64" in image:
            return Image.open(BytesIO(base64.b64decode(image["base64"])))
        elif "url" in image:
            response = requests.get(image["url"])
            return Image.open(BytesIO(response.content))
        elif "bytes" in image:
            return Image.open(BytesIO(image["bytes"]))
        else:
            raise ValueError(f"Invalid image: {image}")
    else:
        raise ValueError(f"Invalid image: {image}")


def build_transform(input_size, norm_type="imagenet"):
    """Build the torchvision preprocessing pipeline for a square input image.

    Args:
        input_size: Target height/width in pixels.
        norm_type: Either ``"imagenet"`` or ``"siglip"``; selects the
            normalization statistics.

    Returns:
        A ``torchvision.transforms.Compose`` that converts to RGB, resizes to
        ``input_size`` x ``input_size`` (bicubic), converts to a tensor, and
        normalizes.

    Raises:
        ValueError: If ``norm_type`` is not recognized.
    """
    if norm_type == "imagenet":
        MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    elif norm_type == "siglip":
        MEAN, STD = SIGLIP_MEAN, SIGLIP_STD
    else:
        # Previously an unknown norm_type fell through and produced a
        # confusing NameError on MEAN below; fail fast with a clear message.
        raise ValueError(
            f"Invalid norm_type: {norm_type!r}. Must be 'imagenet' or 'siglip'."
        )
    transform = T.Compose(
        [
            T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=MEAN, std=STD),
        ]
    )
    return transform


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the tiling grid that best matches the image's aspect ratio and area.

    The previous version mainly focused on the aspect ratio; this version also
    rewards grids whose total tile area covers a larger fraction of the
    original image, capped at 60% (covering more than 60% of the original area
    adds no further benefit).

    Args:
        aspect_ratio: width / height of the original image.
        target_ratios: Iterable of candidate (cols, rows) grids.
        width: Original image width in pixels.
        height: Original image height in pixels.
        image_size: Side length of a single square tile in pixels.

    Returns:
        The best (cols, rows) tuple from ``target_ratios``.
    """
    best_factor = float("-inf")
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        area_ratio = (ratio[0] * ratio[1] * image_size * image_size) / area
        # new area > 60% of original image area is enough.
        factor_based_on_area_n_ratio = min(area_ratio, 0.6) * min(
            target_aspect_ratio / aspect_ratio, aspect_ratio / target_aspect_ratio
        )
        if factor_based_on_area_n_ratio > best_factor:
            best_factor = factor_based_on_area_n_ratio
            best_ratio = ratio
    return best_ratio


def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
    """Split an image into a grid of square tiles matched to its aspect ratio.

    Args:
        image: Source ``PIL.Image.Image``.
        min_num: Minimum number of tiles in the chosen grid.
        max_num: Maximum number of tiles in the chosen grid.
        image_size: Side length of each square tile in pixels.
        use_thumbnail: If True and more than one tile was produced, append a
            single ``image_size`` x ``image_size`` thumbnail of the whole image.

    Returns:
        List of ``PIL.Image.Image`` tiles (plus the optional thumbnail).
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Enumerate every (cols, rows) grid whose tile count lies in
    # [min_num, max_num], then order by tile count.
    target_ratios = set(
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num
    )
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size
    )

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image, then crop it into a row-major grid of square tiles
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    cols = target_width // image_size  # number of tiles per row (loop-invariant)
    for i in range(blocks):
        box = (
            (i % cols) * image_size,
            (i // cols) * image_size,
            ((i % cols) + 1) * image_size,
            ((i // cols) + 1) * image_size,
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images


class LlamaNemotronVLProcessor(ProcessorMixin):
    """Processor for a Llama-Nemotron vision-language retriever.

    Wraps a tokenizer and image preprocessing (dynamic tiling + normalization)
    to turn queries and (optionally image-bearing) documents into model inputs
    for a bi-encoder retrieval setup.
    """

    attributes = ["tokenizer"]
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        tokenizer: Any,
        q_max_length: Optional[int] = None,
        p_max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        query_prefix: str = "query:",
        passage_prefix: str = "passage:",
        max_input_tiles: int = 6,
        num_image_token: int = 128258,
        dynamic_image_size: bool = True,
        image_size: int = 512,
        use_thumbnail: bool = True,
        template: str = "bidirectional-llama-retriever",
        num_channels: int = 3,
        norm_type: str = "siglip",
        system_message: str = "",
        padding: Union[bool, str] = True,
        **kwargs,
    ):
        """Store processing configuration and configure the tokenizer.

        Args:
            tokenizer: HuggingFace tokenizer; padding side is forced to "left"
                and "pixel_values" is appended to its model input names.
            q_max_length: Max token length for queries (falls back to the
                tokenizer's model_max_length when None).
            p_max_length: Max token length for passages/documents.
            pad_to_multiple_of: Optional padding multiple passed to the tokenizer.
            query_prefix: Prefix prepended to every query string.
            passage_prefix: Prefix prepended to every passage string.
            max_input_tiles: Upper bound on tiles per image for dynamic tiling.
            num_image_token: Number of image-context tokens emitted per tile.
                NOTE(review): the default 128258 looks like a token ID rather
                than a per-tile token count — confirm against the model config.
            dynamic_image_size: Whether to tile images dynamically vs. a single tile.
            image_size: Side length of each square tile in pixels.
            use_thumbnail: Whether dynamic tiling appends a whole-image thumbnail.
            template: Conversation template name passed to get_conv_template.
            num_channels: Expected number of image channels.
            norm_type: "imagenet" or "siglip" normalization statistics.
            system_message: System message injected into the conversation template.
            padding: Default tokenizer padding strategy.
        """
        tokenizer.padding_side = "left"
        tokenizer.model_input_names = tokenizer.model_input_names + ["pixel_values"]
        self.tokenizer = tokenizer
        self.q_max_length = q_max_length
        self.p_max_length = p_max_length
        self.pad_to_multiple_of = pad_to_multiple_of
        self.query_prefix = query_prefix
        self.passage_prefix = passage_prefix
        self.max_input_tiles = max_input_tiles
        self.num_image_token = num_image_token
        self.dynamic_image_size = dynamic_image_size
        self.image_size = image_size
        self.use_thumbnail = use_thumbnail
        self.template = template
        self.num_channels = num_channels
        self.norm_type = norm_type
        self.system_message = system_message
        self.padding = padding
        super().__init__(self.tokenizer)

    def process_documents(
        self,
        documents: Union[Dict, List[Dict]],
        return_tensors: Literal["pt", "np"] = "pt",
        padding: bool | str | None = None,
        truncation: bool = True,
        pixel_values_layout: Literal["per_image", "flat_tiles"] = "flat_tiles",
        **kwargs,
    ) -> Dict[str, Any]:
        """Process documents into model inputs with tokenized text and pixel values.

        Args:
            documents: Either a dict with "images" and "texts" lists, or a list of
                dicts each with "image" and "text" keys. Images can be PIL Images,
                file paths, or None/empty string for text-only documents.
            return_tensors: Output format — "pt" for PyTorch tensors, "np" for
                numpy arrays.
            padding: Padding strategy passed to the tokenizer. Defaults to the
                value set in the processor constructor.
            truncation: Whether to truncate sequences to p_max_length.
            pixel_values_layout: How to structure the pixel values output:
                - "flat_tiles": All image tiles concatenated into a single tensor
                  of shape (total_tiles, C, H, W). Different images may contribute
                  different numbers of tiles. None if no images are present. This
                  is the format expected by the model's forward() method.
                - "per_image": A list aligned with the input documents, where each
                  entry is either a tensor of shape (num_tiles, C, H, W) or None.

        Returns:
            Dict with "input_ids", "attention_mask", and "pixel_values".

        Raises:
            ValueError: If ``return_tensors``, ``documents``, or
                ``pixel_values_layout`` is invalid, or the texts/images lists
                have mismatched lengths.
        """
        if return_tensors not in ("pt", "np"):
            raise ValueError(
                f"Invalid return_tensors: {return_tensors!r}. Must be 'pt' or 'np'."
            )
        if isinstance(documents, dict):
            images = documents["images"]
            texts = documents["texts"]
            # Was an `assert`; raise explicitly so the check survives `python -O`.
            if len(texts) != len(images):
                raise ValueError(
                    "documents['texts'] and documents['images'] must have the same length"
                )
        elif isinstance(documents, list):
            images = [pair["image"] for pair in documents]
            texts = [pair["text"] for pair in documents]
        else:
            raise ValueError("The documents need to be a dict or list of dicts")

        contents = []
        pil_images_by_idx = {}
        max_input_tile_by_idx = {}
        for idx, (image, text) in enumerate(zip(images, texts)):
            prefix = ""
            if image is not None and image != "":
                pil_images_by_idx[idx] = load_image(image)
                # NOTE(review): `prefix` is assigned the empty string here (and
                # the IMG_* tokens below are empty too). These look like
                # special-token literals stripped from the source — confirm
                # against the original model repository.
                prefix = ""
                max_input_tile_by_idx[idx] = self.max_input_tiles
            # ToDo: Order is hardcoded and different than before. No \n after
            content = text
            if prefix != "":
                content = prefix + " " + content
            if self.passage_prefix:
                content = self.passage_prefix + " " + content
            contents.append(content)
        assert len(max_input_tile_by_idx) == len(pil_images_by_idx), (
            "The number of max_input_tile_by_idx and pil_images_by_idx should be the same."
        )

        transform = build_transform(input_size=self.image_size, norm_type=self.norm_type)
        template = get_conv_template(self.template)
        template.system_message = self.system_message
        # NOTE(review): these token literals are empty strings; with empty
        # tokens the str.replace calls below insert at position 0 / no-op.
        # They appear to be stripped special tokens — verify.
        IMG_START_TOKEN = ""
        IMG_END_TOKEN = ""
        IMG_CONTEXT_TOKEN = ""
        content_prompts = []
        pixel_values_list = []
        for i, content in enumerate(contents):
            pil_image = pil_images_by_idx.get(i)
            max_input_tiles = max_input_tile_by_idx.get(i)
            if pil_image is not None:
                if self.dynamic_image_size:
                    image_tiles = dynamic_preprocess(
                        pil_image,
                        image_size=self.image_size,
                        max_num=max_input_tiles,
                        use_thumbnail=self.use_thumbnail,
                    )
                else:
                    image_tiles = [pil_image]
                pixel_values = [transform(item) for item in image_tiles]
                pixel_values = torch.stack(pixel_values).to(dtype=torch.bfloat16)
            else:
                pixel_values = None
            pixel_values_list.append(pixel_values)

            if pixel_values is not None and "" not in content:
                content = " " + content

            # Resetting conversation messages
            template.messages.clear()
            # TODO: do we need this template?
            template.append_message(template.roles[0], content)  # user
            template.append_message(template.roles[1], None)  # assistant
            content_prompt = template.get_prompt()
            if pixel_values is not None:
                # Expand the image placeholder into start/context*N/end tokens,
                # N = num_image_token per tile.
                num_patches = pixel_values.shape[0]
                image_tokens = (
                    IMG_START_TOKEN
                    + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
                    + IMG_END_TOKEN
                )
                content_prompt = content_prompt.replace("", image_tokens, 1)
            else:
                content_prompt = content_prompt.replace("", "", 1)
            content_prompts.append(content_prompt)

        max_length = None
        if truncation:
            max_length = self.p_max_length or self.tokenizer.model_max_length
        if padding is None:
            padding = self.padding
        model_inputs = self.tokenizer(
            content_prompts,
            truncation=truncation,
            max_length=max_length,
            padding=padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=return_tensors,
        )

        if pixel_values_layout == "flat_tiles":
            pixel_values_list = [pv for pv in pixel_values_list if pv is not None]
            if len(pixel_values_list) > 1:
                pixel_values_squeezed = torch.cat(pixel_values_list, dim=0)
            elif len(pixel_values_list) == 1:
                pixel_values_squeezed = pixel_values_list[0]
            else:
                pixel_values_squeezed = None
            if pixel_values_squeezed is not None and return_tensors == "np":
                # numpy has no bfloat16; downcast via float16 before export.
                pixel_values_return_value = (
                    pixel_values_squeezed.to(dtype=torch.float16).cpu().numpy()
                )
            else:
                pixel_values_return_value = pixel_values_squeezed
        elif pixel_values_layout == "per_image":
            if return_tensors == "np":
                pixel_values_return_value = [
                    pv.to(dtype=torch.float16).cpu().numpy() if pv is not None else None
                    for pv in pixel_values_list
                ]
            else:
                pixel_values_return_value = pixel_values_list
        else:
            # Fixed error message: the valid values are 'flat_tiles' and
            # 'per_image' (the old text claimed 'squeezed').
            raise ValueError(
                f"Invalid pixel_values_layout: {pixel_values_layout!r}. "
                f"Must be 'flat_tiles' or 'per_image'."
            )

        batch_docs = {
            "input_ids": model_inputs["input_ids"],
            "attention_mask": model_inputs["attention_mask"],
            "pixel_values": pixel_values_return_value,
        }
        return batch_docs

    def process_queries(
        self,
        queries: List[str],
        return_tensors: Literal["pt", "np"] = "pt",
        padding: bool | str | None = None,
        truncation: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """Process queries into model inputs with tokenized text.

        Args:
            queries: List of query strings.
            return_tensors: Output format — "pt" for PyTorch tensors, "np" for
                numpy arrays.
            padding: Padding strategy passed to the tokenizer. Defaults to the
                value set in the processor constructor.
            truncation: Whether to truncate sequences to q_max_length.

        Returns:
            Dict with "input_ids" and "attention_mask".

        Raises:
            ValueError: If ``return_tensors`` is not 'pt' or 'np'.
        """
        if return_tensors not in ("pt", "np"):
            raise ValueError(
                f"Invalid return_tensors: {return_tensors!r}. Must be 'pt' or 'np'."
            )
        template = get_conv_template(self.template)
        template.system_message = self.system_message
        query_prompts = []
        for query in queries:
            if self.query_prefix:
                query = f"{self.query_prefix} {query}"
            # Resetting conversation messages
            template.messages.clear()
            template.append_message(template.roles[0], query)  # user
            template.append_message(template.roles[1], None)  # assistant
            query_prompt = template.get_prompt()
            query_prompts.append(query_prompt)

        max_length = None
        if truncation:
            max_length = self.q_max_length or self.tokenizer.model_max_length
        if padding is None:
            padding = self.padding
        batch_query = self.tokenizer(
            query_prompts,
            truncation=truncation,
            max_length=max_length,
            padding=padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=return_tensors,
        )
        return batch_query

    def process_queries_documents_biencoder(self, features: Dict, **kwargs) -> Dict[str, Any]:
        """Build one merged bi-encoder batch from query/document training features.

        Args:
            features: List of dicts, each with:
                - "question": the query string,
                - "doc_text": list of positive/negative passage texts,
                - "doc_image": list of images aligned with "doc_text".
            **kwargs: Forwarded to both process_queries and process_documents.

        Returns:
            Dict with query fields prefixed "q_", document fields prefixed
            "d_", and a dummy "labels" tensor (see add_dummy_labels).
        """
        queries = []
        pos_neg_text_batch = []
        pos_neg_image_batch = []
        for feature in features:
            queries.append(feature["question"])
            pos_neg_text_batch.extend(feature["doc_text"])
            pos_neg_image_batch.extend(feature["doc_image"])
        query_batch_dict = self.process_queries(queries, **kwargs)
        doc_batch_dict = self.process_documents(
            {"images": pos_neg_image_batch, "texts": pos_neg_text_batch}, **kwargs
        )
        merged_batch_dict = self.merge_batch_dict(query_batch_dict, doc_batch_dict)
        merged_batch_dict = self.add_dummy_labels(queries, merged_batch_dict)
        return merged_batch_dict

    def merge_batch_dict(self, query_batch_dict, doc_batch_dict):
        """Merge query and document batches into one dict with "q_"/"d_" prefixes.

        Note: both input dicts are emptied in the process (entries are moved,
        not copied).
        """
        q_prefix, d_prefix = "q_", "d_"
        # merge into a single BatchEncoding by adding prefix
        merged_batch_dict = {}
        for k in list(query_batch_dict.keys()):
            merged_batch_dict[q_prefix + k] = query_batch_dict[k]
            del query_batch_dict[k]
        for k in list(doc_batch_dict.keys()):
            merged_batch_dict[d_prefix + k] = doc_batch_dict[k]
            del doc_batch_dict[k]
        return merged_batch_dict

    def add_dummy_labels(self, questions, merged_batch_dict):
        """Attach a zero "labels" tensor (one entry per question).

        Dummy placeholder for field "labels"; it is not used to compute loss.
        """
        labels = torch.zeros(len(questions), dtype=torch.long)
        merged_batch_dict["labels"] = labels
        return merged_batch_dict