"""Paired text encoder for Jolia's zero-shot CLIP head. Loads the Hugging Face text encoder Jolia was trained against (default ``Qwen/Qwen3-Embedding-8B``) and exposes a one-call ``__call__`` that returns last-token-pooled embeddings ready to feed into :meth:`JoliaModel.encode_text`. Pooling matches the training-time path (``LastTokenPoolingModelWrapper`` in ``rarm/dinov2/train/multimodal_wrapper.py``): last-token pooling that handles both left- and right-padding via the attention mask. Tokenizer settings match the cache pipeline (``padding="max_length"``, ``truncation=True``). Decoupled from :class:`JoliaModel` on purpose — the text encoder is the heavy piece (Qwen3-Embedding-8B is ~18 GB), so loading it is opt-in. Usage:: from text_encoder_jolia import JoliaTextEncoder te = JoliaTextEncoder.from_pretrained("Qwen/Qwen3-Embedding-8B").eval() pooled = te(["liver lesion", "normal liver"]) # (N, text_embed_dim) """ from __future__ import annotations import torch import torch.nn as nn from transformers import AutoModel, AutoTokenizer class JoliaTextEncoder(nn.Module): """Tokenizer + Qwen3-style causal LM with last-token attention-mask pooling. Args: model: A loaded causal LM (e.g. ``Qwen/Qwen3-Embedding-8B``). Must expose ``last_hidden_state`` and accept ``attention_mask``. tokenizer: The matching HuggingFace tokenizer. context_length: ``max_length`` for tokenization (Jolia was trained at 512). """ def __init__(self, model: nn.Module, tokenizer: object, context_length: int = 512) -> None: super().__init__() self.model = model self.tokenizer = tokenizer self.context_length = context_length @classmethod def from_pretrained( cls, model_id: str = "Qwen/Qwen3-Embedding-8B", context_length: int = 512, dtype: torch.dtype | str | None = None, device_map: str | dict | None = None, ) -> "JoliaTextEncoder": """Convenience loader for the paired text encoder. Heads-up: ``Qwen/Qwen3-Embedding-8B`` is ~18 GB on disk — first use downloads to ``~/.cache/huggingface``. Pass ``dtype=torch.bfloat16`` and/or ``device_map="auto"`` to fit comfortably on a single GPU. """ kwargs: dict = {} if dtype is not None: kwargs["dtype"] = dtype if device_map is not None: kwargs["device_map"] = device_map model = AutoModel.from_pretrained(model_id, **kwargs) tokenizer = AutoTokenizer.from_pretrained(model_id) return cls(model=model, tokenizer=tokenizer, context_length=context_length) @torch.no_grad() def forward(self, texts: list[str]) -> torch.Tensor: """Encode a list of strings -> ``(N, hidden_size)`` last-token-pooled features.""" device = next(self.model.parameters()).device encoded = self.tokenizer( texts, padding="max_length", truncation=True, max_length=self.context_length, return_tensors="pt", ).to(device) outputs = self.model(**encoded) return _last_token_pool(outputs.last_hidden_state, encoded["attention_mask"]).float() def _last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: """Last-token pool that handles both left and right padding. Mirrors the training-time wrapper in ``multimodal_wrapper.py`` and the pooling Qwen3-Embedding's own model card recommends. """ left_padding = bool((attention_mask[:, -1].sum() == attention_mask.size(0)).item()) if left_padding: return last_hidden_states[:, -1] seq_lengths = attention_mask.sum(dim=1) - 1 batch_idx = torch.arange(last_hidden_states.size(0), device=last_hidden_states.device) return last_hidden_states[batch_idx, seq_lengths]