Jolia / text_encoder_jolia.py
SovanK's picture
Upload folder using huggingface_hub
6858e35 verified
Raw
History Blame Contribute Delete
3.89 kB
"""Paired text encoder for Jolia's zero-shot CLIP head.
Loads the Hugging Face text encoder Jolia was trained against
(default ``Qwen/Qwen3-Embedding-8B``) and exposes a one-call ``__call__``
that returns last-token-pooled embeddings ready to feed into
:meth:`JoliaModel.encode_text`.
Pooling matches the training-time path
(``LastTokenPoolingModelWrapper`` in ``rarm/dinov2/train/multimodal_wrapper.py``):
last-token pooling that handles both left- and right-padding via the
attention mask. Tokenizer settings match the cache pipeline
(``padding="max_length"``, ``truncation=True``).
Decoupled from :class:`JoliaModel` on purpose — the text encoder is the
heavy piece (Qwen3-Embedding-8B is ~18 GB), so loading it is opt-in.
Usage::

    from text_encoder_jolia import JoliaTextEncoder

    te = JoliaTextEncoder.from_pretrained("Qwen/Qwen3-Embedding-8B").eval()
    pooled = te(["liver lesion", "normal liver"])  # (N, text_embed_dim)
"""
from __future__ import annotations
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
class JoliaTextEncoder(nn.Module):
"""Tokenizer + Qwen3-style causal LM with last-token attention-mask pooling.
Args:
model: A loaded causal LM (e.g. ``Qwen/Qwen3-Embedding-8B``). Must expose
``last_hidden_state`` and accept ``attention_mask``.
tokenizer: The matching HuggingFace tokenizer.
context_length: ``max_length`` for tokenization (Jolia was trained at 512).
"""
def __init__(self, model: nn.Module, tokenizer: object, context_length: int = 512) -> None:
super().__init__()
self.model = model
self.tokenizer = tokenizer
self.context_length = context_length
@classmethod
def from_pretrained(
cls,
model_id: str = "Qwen/Qwen3-Embedding-8B",
context_length: int = 512,
dtype: torch.dtype | str | None = None,
device_map: str | dict | None = None,
) -> "JoliaTextEncoder":
"""Convenience loader for the paired text encoder.
Heads-up: ``Qwen/Qwen3-Embedding-8B`` is ~18 GB on disk — first use
downloads to ``~/.cache/huggingface``. Pass ``dtype=torch.bfloat16``
and/or ``device_map="auto"`` to fit comfortably on a single GPU.
"""
kwargs: dict = {}
if dtype is not None:
kwargs["dtype"] = dtype
if device_map is not None:
kwargs["device_map"] = device_map
model = AutoModel.from_pretrained(model_id, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_id)
return cls(model=model, tokenizer=tokenizer, context_length=context_length)
@torch.no_grad()
def forward(self, texts: list[str]) -> torch.Tensor:
"""Encode a list of strings -> ``(N, hidden_size)`` last-token-pooled features."""
device = next(self.model.parameters()).device
encoded = self.tokenizer(
texts,
padding="max_length",
truncation=True,
max_length=self.context_length,
return_tensors="pt",
).to(device)
outputs = self.model(**encoded)
return _last_token_pool(outputs.last_hidden_state, encoded["attention_mask"]).float()
def _last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
"""Last-token pool that handles both left and right padding.
Mirrors the training-time wrapper in ``multimodal_wrapper.py`` and the
pooling Qwen3-Embedding's own model card recommends.
"""
left_padding = bool((attention_mask[:, -1].sum() == attention_mask.size(0)).item())
if left_padding:
return last_hidden_states[:, -1]
seq_lengths = attention_mask.sum(dim=1) - 1
batch_idx = torch.arange(last_hidden_states.size(0), device=last_hidden_states.device)
return last_hidden_states[batch_idx, seq_lengths]