# llama-nemotron-embed-vl-1b-v2 / processing_llama_nemotron_vl.py
# (Hub page header) nvidia-oliver-holworthy — "Add convenience options to
# processor methods (#3)", commit 859e1f2
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0.
import base64
import os
from io import BytesIO
from typing import Any, Dict, List, Literal, Optional, Union, Tuple
import dataclasses
from dataclasses import field
import requests
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import BatchEncoding, ProcessorMixin
# Per-channel mean/std statistics for ImageNet-pretrained vision backbones.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
# SigLIP-style normalization: maps [0, 1] pixel values into [-1, 1].
SIGLIP_MEAN = (0.5, 0.5, 0.5)
SIGLIP_STD = (0.5, 0.5, 0.5)
@dataclasses.dataclass
class Conversation:
    """Manages prompt construction with system messages and multi-turn dialogues."""

    # System instruction prepended to prompts
    system_message: str = ""
    # Role identifiers for dialogue turns
    roles: Tuple[str, str] = ("", "")
    # Message history as (role, content) pairs; content may be None for an
    # open assistant turn (see append_message / get_prompt).
    messages: List[List[Optional[str]]] = field(default_factory=list)
    # Separator token between messages
    sep: str = ""
    # Token IDs that trigger generation stopping (None means unset).
    # Annotated Optional: the default really is None, not a list.
    stop_token_ids: Optional[List[int]] = None

    def get_prompt(self) -> str:
        """Construct the formatted prompt string from system message and dialogue history.

        Turns with a falsy message (e.g. None) contribute only their role
        marker and no separator, which leaves the prompt open for generation.
        """
        ret = self.system_message + self.sep
        for role, message in self.messages:
            if message:
                ret += role + message + self.sep
            else:
                ret += role
        return ret

    def append_message(self, role: str, message: Optional[str]):
        """Add a message turn to the dialogue history.

        Args:
            role: Role marker for this turn.
            message: Turn content, or None to leave the turn open.
        """
        self.messages.append([role, message])
def get_conv_template(name: str) -> Conversation:
    """Create a fresh Conversation with the default stop-token configuration.

    The ``name`` argument is accepted for interface compatibility; it does not
    currently select between different templates.
    """
    default_stop_ids = [128259, 128001]
    conv = Conversation(stop_token_ids=default_stop_ids)
    return conv
def load_image(image):
    """Load *image* into a PIL Image.

    Args:
        image: One of:
            - a ``PIL.Image.Image`` (returned unchanged),
            - a filesystem path string that exists on disk,
            - a dict with one of the keys ``"disk_path"``, ``"base64"``,
              ``"url"``, or ``"bytes"``.

    Returns:
        A ``PIL.Image.Image``.

    Raises:
        ValueError: If the input matches none of the supported forms.
    """
    if isinstance(image, Image.Image):
        return image
    if isinstance(image, str) and os.path.exists(image):
        return Image.open(image)
    if isinstance(image, dict):
        if "disk_path" in image:
            return Image.open(image["disk_path"])
        if "base64" in image:
            return Image.open(BytesIO(base64.b64decode(image["base64"])))
        if "url" in image:
            # A timeout prevents hanging indefinitely on an unresponsive host.
            response = requests.get(image["url"], timeout=30)
            return Image.open(BytesIO(response.content))
        if "bytes" in image:
            return Image.open(BytesIO(image["bytes"]))
    raise ValueError(f"Invalid image: {image}")
def build_transform(input_size, norm_type="imagenet"):
    """Build the torchvision preprocessing pipeline for one image tile.

    The pipeline converts to RGB, resizes to a square of ``input_size``
    (bicubic), converts to a tensor, and normalizes with the statistics
    selected by ``norm_type``.

    Args:
        input_size: Side length of the square output tile in pixels.
        norm_type: Normalization statistics, ``"imagenet"`` or ``"siglip"``.

    Returns:
        A ``torchvision.transforms.Compose`` callable.

    Raises:
        ValueError: If ``norm_type`` is not a recognized option (previously
            this fell through to an UnboundLocalError).
    """
    if norm_type == "imagenet":
        mean, std = IMAGENET_MEAN, IMAGENET_STD
    elif norm_type == "siglip":
        mean, std = SIGLIP_MEAN, SIGLIP_STD
    else:
        raise ValueError(
            f"Unknown norm_type: {norm_type!r}. Must be 'imagenet' or 'siglip'."
        )
    transform = T.Compose(
        [
            T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ]
    )
    return transform
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the tiling grid that best matches the image's aspect ratio and area.

    Each candidate ``(cols, rows)`` is scored by how close its aspect ratio is
    to the image's, weighted by how much of the original image area the tiled
    canvas would cover (the coverage term is capped at 0.6, so covering more
    than 60% of the original area earns no extra credit).

    Returns the best-scoring candidate, or ``(1, 1)`` when ``target_ratios``
    is empty. Ties keep the earliest candidate.
    """
    original_area = width * height

    def _score(grid):
        cols, rows = grid
        candidate_ar = cols / rows
        coverage = (cols * rows * image_size * image_size) / original_area
        closeness = min(candidate_ar / aspect_ratio, aspect_ratio / candidate_ar)
        return min(coverage, 0.6) * closeness

    # max() keeps the first maximum, matching the original strict-> update rule.
    return max(target_ratios, key=_score, default=(1, 1))
def dynamic_preprocess(
    image, min_num=1, max_num=6, image_size=448, use_thumbnail=False
):
    """Split ``image`` into square ``image_size`` tiles along a chosen grid.

    A candidate grid ``(cols, rows)`` is selected via
    :func:`find_closest_aspect_ratio`, the image is resized to fill that grid,
    and each cell is cropped out as one tile. When ``use_thumbnail`` is set
    and more than one tile was produced, a full-image thumbnail is appended.

    Returns:
        List of PIL image tiles.
    """
    orig_w, orig_h = image.size
    aspect_ratio = orig_w / orig_h

    # Enumerate every (cols, rows) grid whose tile count lies in [min_num, max_num].
    candidate_grids = {
        (cols, rows)
        for total in range(min_num, max_num + 1)
        for cols in range(1, total + 1)
        for rows in range(1, total + 1)
        if min_num <= cols * rows <= max_num
    }
    candidate_grids = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

    # Choose the grid closest to the image's aspect ratio (area-aware).
    grid_cols, grid_rows = find_closest_aspect_ratio(
        aspect_ratio, candidate_grids, orig_w, orig_h, image_size
    )
    canvas_w = image_size * grid_cols
    canvas_h = image_size * grid_rows
    num_tiles = grid_cols * grid_rows

    # Resize to exactly fill the grid, then crop each cell out.
    resized = image.resize((canvas_w, canvas_h))
    cols_per_row = canvas_w // image_size
    tiles = []
    for tile_idx in range(num_tiles):
        col = tile_idx % cols_per_row
        row = tile_idx // cols_per_row
        box = (
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )
        tiles.append(resized.crop(box))
    assert len(tiles) == num_tiles

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles
class LlamaNemotronVLProcessor(ProcessorMixin):
    """Prepares model inputs for the Llama Nemotron VL embedding retriever.

    Queries are plain text; documents may pair text with an image. Each image
    is tiled via ``dynamic_preprocess``, normalized, and referenced in the
    prompt by expanding a single ``<image>`` placeholder into
    ``<img><IMG_CONTEXT>...</img>`` tokens (one ``<IMG_CONTEXT>`` run per tile).
    """

    attributes = ["tokenizer"]
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        tokenizer: Any,
        q_max_length: Optional[int] = None,
        p_max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        query_prefix: str = "query:",
        passage_prefix: str = "passage:",
        max_input_tiles: int = 6,
        num_image_token: int = 128258,
        dynamic_image_size: bool = True,
        image_size: int = 512,
        use_thumbnail: bool = True,
        template: str = "bidirectional-llama-retriever",
        num_channels: int = 3,
        norm_type: str = "siglip",
        system_message: str = "",
        padding: Union[bool, str] = True,
        **kwargs,
    ):
        """Create the processor.

        Args:
            tokenizer: HuggingFace tokenizer used for all text encoding.
                Its padding side is forced to "left" and "pixel_values" is
                added to its model input names.
            q_max_length: Truncation length for queries; falls back to the
                tokenizer's model_max_length when None.
            p_max_length: Truncation length for documents; same fallback.
            pad_to_multiple_of: Optional padding multiple for the tokenizer.
            query_prefix: String prepended to every query.
            passage_prefix: String prepended to every document.
            max_input_tiles: Upper bound on tiles produced per image.
            num_image_token: Number of <IMG_CONTEXT> tokens emitted per tile.
            dynamic_image_size: Tile images by aspect ratio instead of using
                a single resized image.
            image_size: Side length of each square tile in pixels.
            use_thumbnail: Append a full-image thumbnail tile when an image
                is split into multiple tiles.
            template: Conversation template name (see get_conv_template).
            num_channels: Expected image channel count.
            norm_type: Normalization statistics ("imagenet" or "siglip").
            system_message: System message placed at the start of each prompt.
            padding: Default tokenizer padding strategy.
        """
        # Force left padding for all tokenization done by this processor.
        tokenizer.padding_side = "left"
        tokenizer.model_input_names = tokenizer.model_input_names + ["pixel_values"]
        self.tokenizer = tokenizer
        self.q_max_length = q_max_length
        self.p_max_length = p_max_length
        self.pad_to_multiple_of = pad_to_multiple_of
        self.query_prefix = query_prefix
        self.passage_prefix = passage_prefix
        self.max_input_tiles = max_input_tiles
        self.num_image_token = num_image_token
        self.dynamic_image_size = dynamic_image_size
        self.image_size = image_size
        self.use_thumbnail = use_thumbnail
        self.template = template
        self.num_channels = num_channels
        self.norm_type = norm_type
        self.system_message = system_message
        self.padding = padding
        super().__init__(self.tokenizer)

    def process_documents(
        self,
        documents: Union[Dict, List[Dict]],
        return_tensors: Literal["pt", "np"] = "pt",
        padding: bool | str | None = None,
        truncation: bool = True,
        pixel_values_layout: Literal["per_image", "flat_tiles"] = "flat_tiles",
        **kwargs,
    ) -> Dict[str, Any]:
        """Process documents into model inputs with tokenized text and pixel values.

        Args:
            documents: Either a dict with "images" and "texts" lists, or a list of
                dicts each with "image" and "text" keys. Images can be PIL Images,
                file paths, or None/empty string for text-only documents.
            return_tensors: Output format — "pt" for PyTorch tensors, "np" for numpy arrays.
            padding: Padding strategy passed to the tokenizer. Defaults to the value
                set in the processor constructor.
            truncation: Whether to truncate sequences to p_max_length.
            pixel_values_layout: How to structure the pixel values output:
                - "flat_tiles": All image tiles concatenated into a single tensor of shape
                  (total_tiles, C, H, W). Different images may contribute different numbers
                  of tiles. None if no images are present. This is the format expected by
                  the model's forward() method.
                - "per_image": A list aligned with the input documents, where each entry
                  is either a tensor of shape (num_tiles, C, H, W) or None.

        Returns:
            Dict with "input_ids", "attention_mask", and "pixel_values".

        Raises:
            ValueError: On an invalid return_tensors, pixel_values_layout, a
                mismatched images/texts length, or a malformed documents arg.
        """
        if return_tensors not in ("pt", "np"):
            raise ValueError(
                f"Invalid return_tensors: {return_tensors!r}. Must be 'pt' or 'np'."
            )
        # Validate up front so a bad layout fails before any image/tokenizer work.
        # (Also fixes the old error message, which named a nonexistent
        # 'squeezed' option instead of 'flat_tiles'.)
        if pixel_values_layout not in ("flat_tiles", "per_image"):
            raise ValueError(
                f"Invalid pixel_values_layout: {pixel_values_layout!r}. "
                f"Must be 'flat_tiles' or 'per_image'."
            )
        if isinstance(documents, dict):
            images = documents["images"]
            texts = documents["texts"]
            # Raise (not assert): input validation must survive `python -O`.
            if len(texts) != len(images):
                raise ValueError(
                    "documents['texts'] and documents['images'] must have the same length"
                )
        elif isinstance(documents, list):
            images = [pair["image"] for pair in documents]
            texts = [pair["text"] for pair in documents]
        else:
            raise ValueError("The documents need to be a dict or list of dicts")

        # Build the raw text content per document; remember which documents
        # carry an image (keyed by position).
        contents = []
        pil_images_by_idx = {}
        max_input_tile_by_idx = {}
        for idx, (image, text) in enumerate(zip(images, texts)):
            prefix = ""
            if image is not None and image != "":
                pil_images_by_idx[idx] = load_image(image)
                prefix = "<image>"
                max_input_tile_by_idx[idx] = self.max_input_tiles
            # TODO: Order is hardcoded and different than before. No \n after <image>
            content = text
            if prefix != "":
                content = prefix + " " + content
            if self.passage_prefix:
                content = self.passage_prefix + " " + content
            contents.append(content)
        assert len(max_input_tile_by_idx) == len(pil_images_by_idx), (
            "The number of max_input_tile_by_idx and pil_images_by_idx should be the same."
        )

        transform = build_transform(
            input_size=self.image_size, norm_type=self.norm_type
        )
        template = get_conv_template(self.template)
        template.system_message = self.system_message
        IMG_START_TOKEN = "<img>"
        IMG_END_TOKEN = "</img>"
        IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
        content_prompts = []
        pixel_values_list = []
        for i, content in enumerate(contents):
            pil_image = pil_images_by_idx.get(i)
            max_input_tiles = max_input_tile_by_idx.get(i)
            if pil_image is not None:
                if self.dynamic_image_size:
                    image_tiles = dynamic_preprocess(
                        pil_image,
                        image_size=self.image_size,
                        max_num=max_input_tiles,
                        use_thumbnail=self.use_thumbnail,
                    )
                else:
                    image_tiles = [pil_image]
                pixel_values = [transform(item) for item in image_tiles]
                pixel_values = torch.stack(pixel_values).to(dtype=torch.bfloat16)
            else:
                pixel_values = None
            pixel_values_list.append(pixel_values)
            if pixel_values is not None and "<image>" not in content:
                content = "<image> " + content
            # Resetting conversation messages for this document
            template.messages.clear()
            # TODO: do we need this template?
            template.append_message(template.roles[0], content)  # user
            template.append_message(template.roles[1], None)  # assistant
            content_prompt = template.get_prompt()
            if pixel_values is not None:
                # Expand the single <image> placeholder into num_image_token
                # context tokens per tile, bracketed by <img>...</img>.
                num_patches = pixel_values.shape[0]
                image_tokens = (
                    IMG_START_TOKEN
                    + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
                    + IMG_END_TOKEN
                )
                content_prompt = content_prompt.replace("<image>", image_tokens, 1)
            else:
                content_prompt = content_prompt.replace("<image>", "", 1)
            content_prompts.append(content_prompt)

        max_length = None
        if truncation:
            max_length = self.p_max_length or self.tokenizer.model_max_length
        if padding is None:
            padding = self.padding
        model_inputs = self.tokenizer(
            content_prompts,
            truncation=truncation,
            max_length=max_length,
            padding=padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=return_tensors,
        )

        if pixel_values_layout == "flat_tiles":
            present = [pv for pv in pixel_values_list if pv is not None]
            if len(present) > 1:
                pixel_values_squeezed = torch.cat(present, dim=0)
            elif len(present) == 1:
                pixel_values_squeezed = present[0]
            else:
                pixel_values_squeezed = None
            if pixel_values_squeezed is not None and return_tensors == "np":
                # bfloat16 has no numpy counterpart, so route through float16.
                pixel_values_return_value = (
                    pixel_values_squeezed.to(dtype=torch.float16).cpu().numpy()
                )
            else:
                pixel_values_return_value = pixel_values_squeezed
        else:  # "per_image" — validated at the top of the method
            if return_tensors == "np":
                pixel_values_return_value = [
                    pv.to(dtype=torch.float16).cpu().numpy() if pv is not None else None
                    for pv in pixel_values_list
                ]
            else:
                pixel_values_return_value = pixel_values_list

        batch_docs = {
            "input_ids": model_inputs["input_ids"],
            "attention_mask": model_inputs["attention_mask"],
            "pixel_values": pixel_values_return_value,
        }
        return batch_docs

    def process_queries(
        self,
        queries: List[str],
        return_tensors: Literal["pt", "np"] = "pt",
        padding: bool | str | None = None,
        truncation: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """Process queries into model inputs with tokenized text.

        Args:
            queries: List of query strings.
            return_tensors: Output format — "pt" for PyTorch tensors, "np" for numpy arrays.
            padding: Padding strategy passed to the tokenizer. Defaults to the value
                set in the processor constructor.
            truncation: Whether to truncate sequences to q_max_length.

        Returns:
            Dict with "input_ids" and "attention_mask".

        Raises:
            ValueError: If return_tensors is not "pt" or "np".
        """
        if return_tensors not in ("pt", "np"):
            raise ValueError(
                f"Invalid return_tensors: {return_tensors!r}. Must be 'pt' or 'np'."
            )
        template = get_conv_template(self.template)
        template.system_message = self.system_message
        query_prompts = []
        for query in queries:
            if self.query_prefix:
                query = f"{self.query_prefix} {query}"
            # Resetting conversation messages for this query
            template.messages.clear()
            template.append_message(template.roles[0], query)  # user
            template.append_message(template.roles[1], None)  # assistant
            query_prompt = template.get_prompt()
            query_prompts.append(query_prompt)
        max_length = None
        if truncation:
            max_length = self.q_max_length or self.tokenizer.model_max_length
        if padding is None:
            padding = self.padding
        batch_query = self.tokenizer(
            query_prompts,
            truncation=truncation,
            max_length=max_length,
            padding=padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=return_tensors,
        )
        return batch_query

    def process_queries_documents_biencoder(self, features: Dict, **kwargs) -> Dict[str, Any]:
        """Tokenize query/document groups for bi-encoder training.

        Args:
            features: List of dicts, each with keys:
                - "question": the query string,
                - "doc_text": list of document texts (positives + negatives),
                - "doc_image": matching list of document images (entries may
                  be None for text-only documents).
                All documents across the batch are flattened into one pass
                through process_documents.
            **kwargs: Forwarded to process_queries and process_documents.

        Returns:
            Dict with "q_"-prefixed query inputs, "d_"-prefixed document
            inputs, and a dummy "labels" tensor.
        """
        queries = []
        pos_neg_text_batch = []
        pos_neg_image_batch = []
        for feature in features:
            queries.append(feature["question"])
            pos_neg_text_batch.extend(feature["doc_text"])
            pos_neg_image_batch.extend(feature["doc_image"])
        query_batch_dict = self.process_queries(queries, **kwargs)
        doc_batch_dict = self.process_documents(
            {"images": pos_neg_image_batch, "texts": pos_neg_text_batch}, **kwargs
        )
        merged_batch_dict = self.merge_batch_dict(query_batch_dict, doc_batch_dict)
        merged_batch_dict = self.add_dummy_labels(queries, merged_batch_dict)
        return merged_batch_dict

    def merge_batch_dict(self, query_batch_dict, doc_batch_dict):
        """Merge query and document batches into a single dict.

        Query keys get a "q_" prefix and document keys a "d_" prefix.

        NOTE: both input dicts are emptied as their entries are moved into
        the merged dict.
        """
        q_prefix, d_prefix = "q_", "d_"
        # merge into a single BatchEncoding by adding prefix
        merged_batch_dict = {}
        for k in list(query_batch_dict.keys()):
            merged_batch_dict[q_prefix + k] = query_batch_dict[k]
            del query_batch_dict[k]
        for k in list(doc_batch_dict.keys()):
            merged_batch_dict[d_prefix + k] = doc_batch_dict[k]
            del doc_batch_dict[k]
        return merged_batch_dict

    def add_dummy_labels(self, questions, merged_batch_dict):
        """Attach a zero "labels" tensor (one entry per question).

        Dummy placeholder so downstream code that expects a "labels" field
        works; the values are never used to compute a loss.
        """
        labels = torch.zeros(len(questions), dtype=torch.long)
        merged_batch_dict["labels"] = labels
        return merged_batch_dict