| | |
| | |
| |
|
| | import base64 |
| | import os |
| | from io import BytesIO |
| | from typing import Any, Dict, List, Literal, Optional, Union, Tuple |
| | import dataclasses |
| | from dataclasses import field |
| |
|
| | import requests |
| | import torch |
| | import torchvision.transforms as T |
| | from PIL import Image |
| | from torchvision.transforms.functional import InterpolationMode |
| | from transformers import BatchEncoding, ProcessorMixin |
| |
|
| | IMAGENET_MEAN = (0.485, 0.456, 0.406) |
| | IMAGENET_STD = (0.229, 0.224, 0.225) |
| |
|
| | SIGLIP_MEAN = (0.5, 0.5, 0.5) |
| | SIGLIP_STD = (0.5, 0.5, 0.5) |
| |
|
| |
|
@dataclasses.dataclass
class Conversation:
    """Manages prompt construction with system messages and multi-turn dialogues."""

    # Text placed at the very start of the prompt (may be empty).
    system_message: str = ""
    # (first_role, second_role) markers prepended to each turn; empty by default.
    roles: Tuple[str, str] = ("", "")
    # Dialogue history as [role, message] pairs; message may be None/"" for an
    # open turn (get_prompt then emits the role with no separator).
    messages: List[List[str]] = field(default_factory=list)
    # Separator appended after the system message and each non-empty turn.
    sep: str = ""
    # Token ids that end generation; None means not configured.
    # Fix: was annotated `List[int]` with a None default — now Optional.
    stop_token_ids: Optional[List[int]] = None

    def get_prompt(self) -> str:
        """Construct the formatted prompt string from system message and dialogue history."""
        ret = self.system_message + self.sep
        for role, message in self.messages:
            if message:
                ret += role + message + self.sep
            else:
                # Open turn: role marker only, no message and no separator.
                ret += role
        return ret

    def append_message(self, role: str, message: str):
        """Add a message turn to the dialogue history."""
        self.messages.append([role, message])
| |
|
| |
|
def get_conv_template(name: str) -> Conversation:
    """Initialize a conversation instance with default configuration.

    NOTE(review): ``name`` is currently ignored — every template gets the
    same stop-token ids. Confirm whether per-name templates are intended.
    """
    default_stop_token_ids = [128259, 128001]
    return Conversation(stop_token_ids=default_stop_token_ids)
| |
|
| |
|
def load_image(image):
    """Resolve *image* into a PIL image.

    Accepts an already-loaded ``PIL.Image.Image``, an existing filesystem
    path, or a dict carrying one of the keys "disk_path", "base64", "url",
    or "bytes".

    Raises:
        ValueError: if the input matches none of the supported forms.
        requests.HTTPError: if a "url" fetch returns an error status.
    """
    if isinstance(image, Image.Image):
        return image
    elif isinstance(image, str) and os.path.exists(image):
        return Image.open(image)
    elif isinstance(image, dict):
        if "disk_path" in image:
            return Image.open(image["disk_path"])
        elif "base64" in image:
            return Image.open(BytesIO(base64.b64decode(image["base64"])))
        elif "url" in image:
            # Fix: bound the request so a dead server cannot hang
            # preprocessing forever, and fail fast on HTTP errors instead
            # of handing an error page to PIL.
            response = requests.get(image["url"], timeout=30)
            response.raise_for_status()
            return Image.open(BytesIO(response.content))
        elif "bytes" in image:
            return Image.open(BytesIO(image["bytes"]))
        else:
            raise ValueError(f"Invalid image: {image}")
    else:
        raise ValueError(f"Invalid image: {image}")
| |
|
| |
|
def build_transform(input_size, norm_type="imagenet"):
    """Build the per-tile preprocessing pipeline.

    Converts to RGB, resizes to a square of ``input_size``, converts to a
    tensor, and normalizes with the requested statistics.

    Args:
        input_size: Side length of the square output tile.
        norm_type: "imagenet" or "siglip" normalization statistics.

    Raises:
        ValueError: for an unrecognized ``norm_type``.
    """
    if norm_type == "imagenet":
        MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    elif norm_type == "siglip":
        MEAN, STD = SIGLIP_MEAN, SIGLIP_STD
    else:
        # Fix: previously an unknown norm_type crashed later with a
        # confusing NameError on MEAN/STD.
        raise ValueError(
            f"Unknown norm_type: {norm_type!r}. Must be 'imagenet' or 'siglip'."
        )

    transform = T.Compose(
        [
            T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=MEAN, std=STD),
        ]
    )
    return transform
| |
|
| |
|
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the tiling grid whose aspect ratio and area best match the image.

    Unlike a pure aspect-ratio match, the score also rewards grids whose
    total tile area is close to the original image area (the area term is
    capped at 0.6 so very large grids stop gaining from area alone).
    Returns (1, 1) when ``target_ratios`` is empty.
    """
    area = width * height

    def _score(ratio):
        cols, rows = ratio
        grid_aspect = cols / rows
        area_ratio = (cols * rows * image_size * image_size) / area
        aspect_closeness = min(grid_aspect / aspect_ratio, aspect_ratio / grid_aspect)
        return min(area_ratio, 0.6) * aspect_closeness

    # max() keeps the first of equal maxima, matching the original
    # strictly-greater replacement loop's tie-breaking.
    return max(target_ratios, key=_score, default=(1, 1))
| |
|
| |
|
def dynamic_preprocess(
    image, min_num=1, max_num=6, image_size=448, use_thumbnail=False
):
    """Split *image* into a grid of square ``image_size`` tiles.

    The grid shape (columns x rows) is chosen by find_closest_aspect_ratio
    from every combination whose tile count lies in [min_num, max_num].
    When ``use_thumbnail`` is set and more than one tile was produced, a
    resized copy of the whole image is appended as a global view.
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Candidate (cols, rows) grids whose tile count is in range, ordered by
    # total tile count.
    candidate_grids = {
        (cols, rows)
        for total in range(min_num, max_num + 1)
        for cols in range(1, total + 1)
        for rows in range(1, total + 1)
        if min_num <= cols * rows <= max_num
    }
    candidate_grids = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

    grid_cols, grid_rows = find_closest_aspect_ratio(
        aspect_ratio, candidate_grids, orig_width, orig_height, image_size
    )

    target_width = image_size * grid_cols
    target_height = image_size * grid_rows
    blocks = grid_cols * grid_rows

    # Resize to the exact grid footprint, then crop out each tile.
    resized_img = image.resize((target_width, target_height))
    cols = target_width // image_size
    processed_images = []
    for idx in range(blocks):
        left = (idx % cols) * image_size
        top = (idx // cols) * image_size
        processed_images.append(
            resized_img.crop((left, top, left + image_size, top + image_size))
        )
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        processed_images.append(image.resize((image_size, image_size)))
    return processed_images
| |
|
| |
|
class LlamaNemotronVLProcessor(ProcessorMixin):
    """Processor pairing a left-padded text tokenizer with dynamic image tiling.

    Queries are plain text; documents may carry an image that is split into
    up to ``max_input_tiles`` square tiles, normalized to bfloat16 tensors,
    and referenced in the prompt through an ``<img><IMG_CONTEXT>...</img>``
    token span sized to the number of tiles.
    """

    attributes = ["tokenizer"]
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        tokenizer: Any,
        q_max_length: Optional[int] = None,
        p_max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        query_prefix: str = "query:",
        passage_prefix: str = "passage:",
        max_input_tiles: int = 6,
        num_image_token: int = 128258,
        dynamic_image_size: bool = True,
        image_size: int = 512,
        use_thumbnail: bool = True,
        template: str = "bidirectional-llama-retriever",
        num_channels: int = 3,
        norm_type: str = "siglip",
        system_message: str = "",
        padding: Union[bool, str] = True,
        **kwargs,
    ):
        """Store preprocessing configuration and adopt the tokenizer.

        Args:
            tokenizer: HuggingFace tokenizer; its padding side is forced to
                "left" and "pixel_values" is appended to its model input names.
            q_max_length: Max token length for queries; falls back to the
                tokenizer's model_max_length when truncating.
            p_max_length: Max token length for documents/passages.
            pad_to_multiple_of: Optional padding multiple for the tokenizer.
            query_prefix: Prefix prepended to every query string.
            passage_prefix: Prefix prepended to every document string.
            max_input_tiles: Upper bound on tiles per image.
            num_image_token: Number of <IMG_CONTEXT> tokens emitted per tile.
            dynamic_image_size: Tile images dynamically; otherwise one tile
                per image.
            image_size: Side length of each square tile.
            use_thumbnail: Append a whole-image thumbnail tile when tiling
                produced more than one tile.
            template: Conversation template name (passed to get_conv_template,
                which currently ignores it).
            num_channels: Expected image channel count.
            norm_type: "imagenet" or "siglip" normalization statistics.
            system_message: System message placed at the start of prompts.
            padding: Default tokenizer padding strategy.
        """
        # Left padding keeps the final token position aligned across a batch,
        # which decoder-only embedding pooling relies on.
        tokenizer.padding_side = "left"
        tokenizer.model_input_names = tokenizer.model_input_names + ["pixel_values"]
        self.tokenizer = tokenizer

        self.q_max_length = q_max_length
        self.p_max_length = p_max_length
        self.pad_to_multiple_of = pad_to_multiple_of
        self.query_prefix = query_prefix
        self.passage_prefix = passage_prefix
        self.max_input_tiles = max_input_tiles
        self.num_image_token = num_image_token
        self.dynamic_image_size = dynamic_image_size
        self.image_size = image_size
        self.use_thumbnail = use_thumbnail
        self.template = template
        self.num_channels = num_channels
        self.norm_type = norm_type
        self.system_message = system_message
        self.padding = padding

        super().__init__(self.tokenizer)

    def process_documents(
        self,
        documents: Union[Dict, List[Dict]],
        return_tensors: Literal["pt", "np"] = "pt",
        padding: bool | str | None = None,
        truncation: bool = True,
        pixel_values_layout: Literal["per_image", "flat_tiles"] = "flat_tiles",
        **kwargs,
    ) -> Dict[str, Any]:
        """Process documents into model inputs with tokenized text and pixel values.

        Args:
            documents: Either a dict with "images" and "texts" lists, or a list of
                dicts each with "image" and "text" keys. Images can be PIL Images,
                file paths, or None/empty string for text-only documents.
            return_tensors: Output format — "pt" for PyTorch tensors, "np" for numpy arrays.
            padding: Padding strategy passed to the tokenizer. Defaults to the value
                set in the processor constructor.
            truncation: Whether to truncate sequences to p_max_length.
            pixel_values_layout: How to structure the pixel values output:
                - "flat_tiles": All image tiles concatenated into a single tensor of shape
                  (total_tiles, C, H, W). Different images may contribute different numbers
                  of tiles. None if no images are present. This is the format expected by
                  the model's forward() method.
                - "per_image": A list aligned with the input documents, where each entry
                  is either a tensor of shape (num_tiles, C, H, W) or None.

        Returns:
            Dict with "input_ids", "attention_mask", and "pixel_values".

        Raises:
            ValueError: for invalid ``return_tensors``, ``pixel_values_layout``,
                mismatched texts/images lengths, or an unsupported ``documents`` type.
        """
        if return_tensors not in ("pt", "np"):
            raise ValueError(
                f"Invalid return_tensors: {return_tensors!r}. Must be 'pt' or 'np'."
            )

        if isinstance(documents, dict):
            images = documents["images"]
            texts = documents["texts"]
            # Fix: explicit error instead of `assert`, which is stripped under -O.
            if len(texts) != len(images):
                raise ValueError(
                    f"Mismatched documents: {len(texts)} texts vs {len(images)} images."
                )
        elif isinstance(documents, list):
            images = [pair["image"] for pair in documents]
            texts = [pair["text"] for pair in documents]
        else:
            raise ValueError("The documents need to be a dict or list of dicts")

        # Build the per-document text content; record which indices have images.
        contents = []
        pil_images_by_idx = {}
        max_input_tile_by_idx = {}
        for idx, (image, text) in enumerate(zip(images, texts)):
            prefix = ""
            if image is not None and image != "":
                pil_images_by_idx[idx] = load_image(image)
                prefix = "<image>"
                max_input_tile_by_idx[idx] = self.max_input_tiles

            content = text
            if prefix != "":
                content = prefix + " " + content
            if self.passage_prefix:
                content = self.passage_prefix + " " + content
            contents.append(content)

        # Internal invariant: both maps are populated together above.
        assert len(max_input_tile_by_idx) == len(pil_images_by_idx), (
            "The number of max_input_tile_by_idx and pil_images_by_idx should be the same."
        )

        transform = build_transform(
            input_size=self.image_size, norm_type=self.norm_type
        )

        template = get_conv_template(self.template)
        template.system_message = self.system_message

        IMG_START_TOKEN = "<img>"
        IMG_END_TOKEN = "</img>"
        IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"

        content_prompts = []
        pixel_values_list = []
        for i, content in enumerate(contents):
            pil_image = pil_images_by_idx.get(i)
            max_input_tiles = max_input_tile_by_idx.get(i)
            if pil_image is not None:
                if self.dynamic_image_size:
                    image_tiles = dynamic_preprocess(
                        pil_image,
                        image_size=self.image_size,
                        max_num=max_input_tiles,
                        use_thumbnail=self.use_thumbnail,
                    )
                else:
                    image_tiles = [pil_image]

                pixel_values = [transform(item) for item in image_tiles]
                pixel_values = torch.stack(pixel_values).to(dtype=torch.bfloat16)
            else:
                pixel_values = None
            pixel_values_list.append(pixel_values)

            # Guarantee a placeholder exists when the document has an image.
            if pixel_values is not None and "<image>" not in content:
                content = "<image> " + content

            # One template object is reused for the whole batch; reset the
            # dialogue history for each document.
            template.messages.clear()

            template.append_message(template.roles[0], content)
            template.append_message(template.roles[1], None)
            content_prompt = template.get_prompt()

            if pixel_values is not None:
                # Expand the placeholder into one <IMG_CONTEXT> run per tile.
                num_patches = pixel_values.shape[0]
                image_tokens = (
                    IMG_START_TOKEN
                    + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
                    + IMG_END_TOKEN
                )
                content_prompt = content_prompt.replace("<image>", image_tokens, 1)
            else:
                content_prompt = content_prompt.replace("<image>", "", 1)

            content_prompts.append(content_prompt)

        max_length = None
        if truncation:
            max_length = self.p_max_length or self.tokenizer.model_max_length

        if padding is None:
            padding = self.padding

        model_inputs = self.tokenizer(
            content_prompts,
            truncation=truncation,
            max_length=max_length,
            padding=padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=return_tensors,
        )

        if pixel_values_layout == "flat_tiles":
            pixel_values_list = [pv for pv in pixel_values_list if pv is not None]
            if len(pixel_values_list) > 1:
                pixel_values_squeezed = torch.cat(pixel_values_list, dim=0)
            elif len(pixel_values_list) == 1:
                pixel_values_squeezed = pixel_values_list[0]
            else:
                pixel_values_squeezed = None

            if pixel_values_squeezed is not None and return_tensors == "np":
                # numpy has no bfloat16; downcast to float16 for the np path.
                pixel_values_return_value = (
                    pixel_values_squeezed.to(dtype=torch.float16).cpu().numpy()
                )
            else:
                pixel_values_return_value = pixel_values_squeezed
        elif pixel_values_layout == "per_image":
            if return_tensors == "np":
                pixel_values_return_value = [
                    pv.to(dtype=torch.float16).cpu().numpy() if pv is not None else None
                    for pv in pixel_values_list
                ]
            else:
                pixel_values_return_value = pixel_values_list
        else:
            # Fix: the message previously named a nonexistent 'squeezed' option.
            raise ValueError(
                f"Invalid pixel_values_layout: {pixel_values_layout!r}. "
                f"Must be 'flat_tiles' or 'per_image'."
            )

        batch_docs = {
            "input_ids": model_inputs["input_ids"],
            "attention_mask": model_inputs["attention_mask"],
            "pixel_values": pixel_values_return_value,
        }

        return batch_docs

    def process_queries(
        self,
        queries: List[str],
        return_tensors: Literal["pt", "np"] = "pt",
        padding: bool | str | None = None,
        truncation: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """Process queries into model inputs with tokenized text.

        Args:
            queries: List of query strings.
            return_tensors: Output format — "pt" for PyTorch tensors, "np" for numpy arrays.
            padding: Padding strategy passed to the tokenizer. Defaults to the value
                set in the processor constructor.
            truncation: Whether to truncate sequences to q_max_length.

        Returns:
            BatchEncoding with "input_ids" and "attention_mask".
        """
        if return_tensors not in ("pt", "np"):
            raise ValueError(
                f"Invalid return_tensors: {return_tensors!r}. Must be 'pt' or 'np'."
            )

        template = get_conv_template(self.template)
        template.system_message = self.system_message

        query_prompts = []
        for query in queries:
            if self.query_prefix:
                query = f"{self.query_prefix} {query}"

            # One template object reused across the batch; reset per query.
            template.messages.clear()

            template.append_message(template.roles[0], query)
            template.append_message(template.roles[1], None)
            query_prompt = template.get_prompt()

            query_prompts.append(query_prompt)

        max_length = None
        if truncation:
            max_length = self.q_max_length or self.tokenizer.model_max_length

        if padding is None:
            padding = self.padding

        batch_query = self.tokenizer(
            query_prompts,
            truncation=truncation,
            max_length=max_length,
            padding=padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=return_tensors,
        )

        return batch_query

    def process_queries_documents_biencoder(self, features: Dict, **kwargs) -> Dict[str, Any]:
        """Collate bi-encoder training features into one merged batch.

        Args:
            features: Iterable of dicts, each with a "question" string plus
                "doc_text" and "doc_image" lists holding that question's
                candidate passages (all groups are flattened into one
                document batch).

        Returns:
            Dict with query tensors under "q_"-prefixed keys, document
            tensors under "d_"-prefixed keys, and a "labels" tensor of zeros.
        """
        queries = []
        pos_neg_text_batch = []
        pos_neg_image_batch = []
        for feature in features:
            queries.append(feature["question"])
            pos_neg_text_batch.extend(feature["doc_text"])
            pos_neg_image_batch.extend(feature["doc_image"])

        query_batch_dict = self.process_queries(queries, **kwargs)
        doc_batch_dict = self.process_documents(
            {"images": pos_neg_image_batch, "texts": pos_neg_text_batch}, **kwargs
        )

        merged_batch_dict = self.merge_batch_dict(query_batch_dict, doc_batch_dict)
        merged_batch_dict = self.add_dummy_labels(queries, merged_batch_dict)
        return merged_batch_dict

    def merge_batch_dict(self, query_batch_dict, doc_batch_dict):
        """Merge query and document encodings into one dict with q_/d_ prefixed keys.

        Note: the input dicts are emptied as entries are moved, so the
        un-prefixed references are released immediately.
        """
        q_prefix, d_prefix = "q_", "d_"

        merged_batch_dict = {}
        for k in list(query_batch_dict.keys()):
            merged_batch_dict[q_prefix + k] = query_batch_dict[k]
            del query_batch_dict[k]
        for k in list(doc_batch_dict.keys()):
            merged_batch_dict[d_prefix + k] = doc_batch_dict[k]
            del doc_batch_dict[k]
        return merged_batch_dict

    def add_dummy_labels(self, questions, merged_batch_dict):
        """Attach an all-zero long "labels" tensor with one entry per question.

        NOTE(review): the zeros presumably index each query's positive
        document within its group — confirm against the training loss.
        """
        labels = torch.zeros(len(questions), dtype=torch.long)
        merged_batch_dict["labels"] = labels
        return merged_batch_dict
| |
|