Integrate with Sentence Transformers v5.4

#3
by tomaarsen HF Staff - opened

Hello!

Pull Request overview

  • Integrate OmniEmbed-v0.1 with Sentence Transformers v5.4, so it can be loaded via SentenceTransformer("Tevatron/OmniEmbed-v0.1") and used for text, image, audio, and video retrieval out of the box. Requires transformers>=5.6.0.

Preface

I want to share that this integration was AI-assisted, but led by me, with the goal of mirroring the original Transformers integration behind a simpler, more standard interface.

Details

The integration is a config-only wrapper around the stock pipeline:

Transformer(any-to-any) -> Pooling(lasttoken) -> Normalize

Sentence Transformers' any-to-any task routes through AutoModelForMultimodalLM, which as of transformers>=5.6.0 natively resolves the qwen2_5_omni_thinker model_type into Qwen2_5OmniThinkerForConditionalGeneration. transformers' PEFT auto-loader then attaches the existing LoRA adapter on top of the base weights (no trust_remote_code=True and no custom modeling_*.py needed).
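
To illustrate, the base-model resolution described above amounts to roughly the following; this is a sketch assuming the transformers>=5.6.0 behavior, not the exact Sentence Transformers internals:

from transformers import AutoModelForMultimodalLM

# model_type "qwen2_5_omni_thinker" resolves to
# Qwen2_5OmniThinkerForConditionalGeneration; with peft installed, the LoRA
# adapter shipped in the repo is attached on top of the base weights.
base = AutoModelForMultimodalLM.from_pretrained("Tevatron/OmniEmbed-v0.1")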

A few non-obvious pieces worth calling out:

  • The chat-template suffix <|endoftext|> is important for the last-token pooling. The existing README appends it to the templated string manually. For Sentence Transformers, a dedicated template is registered at additional_chat_templates/sentence_transformers.jinja that ends with <|im_start|>assistant\n<|endoftext|>, wired in via processing_kwargs.chat_template.chat_template = "sentence_transformers" in sentence_bert_config.json. The rendered scaffold is sketched after this list.
  • The legacy chat_template.json was converted to a modern chat_template.jinja (same content) because Qwen2_5OmniProcessor raises ValueError: Cannot load chat template due to conflicting files if both chat_template.json and an additional_chat_templates/ directory exist. Any existing code calling processor.apply_chat_template(...) keeps the same rendered output.
  • Pooling / extraction: this is a ForConditionalGeneration model, so its output exposes hidden_states (a tuple) rather than last_hidden_state. sentence_bert_config.json uses "method_output_name": ["hidden_states", -1] per modality so ST auto-enables output_hidden_states and indexes correctly.
  • Prompts: the existing README uses "Query: " as a query prefix and no prefix for documents, so config_sentence_transformers.json has prompts = {"query": "Query: ", "document": ""} and default_prompt_name = null. Users call model.encode_query(...) / model.encode_document(...); documents can be plain strings (paths/URLs for image/audio/video also work directly), or dicts like {"text": ..., "image": ..., "audio": ..., "video": ...} for combined-modality inputs (sketched after this list).
  • Video memory: the processor defaults produce 1000+ tokens per video on a 7B model, which OOMs 24 GB GPUs. The README snippet shows a runtime tune (fps=1, max_pixels=64*28*28) plus batch_size=1 in encode_document(...) as a copy-paste-safe starting point; users on bigger hardware can raise both.
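
To make the pooling-token detail concrete: per the template described above, a query renders to something like the following, where the "..." stand in for the actual system prompt and query text:

rendered_query = (
    "<|im_start|>system\n...<|im_end|>\n"
    "<|im_start|>user\nQuery: ...<|im_end|>\n"
    "<|im_start|>assistant\n<|endoftext|>"
)
# Last-token pooling reads the hidden state of that final <|endoftext|>.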
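
And a minimal sketch of the combined-modality inputs mentioned above (the example.com URLs are placeholders, not real assets):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("Tevatron/OmniEmbed-v0.1", revision="refs/pr/3")
documents = [
    "A plain text document",                    # text-only
    "https://example.com/diagram.png",          # single modality via path/URL
    {"text": "Figure 1: model overview",        # combined modalities in one dict
     "image": "https://example.com/diagram.png"},
]
document_embeddings = model.encode_document(documents, batch_size=1)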

Retrieval rankings match the existing transformers reproduction across all four examples (video / audio / image / multilingual text) in the README. Absolute cosine similarities agree within bf16 precision on text; image/audio/video drift by up to ~0.02 because Sentence Transformers and the README's process_mm_info path differ in how they handle image resize and audio clips longer than Whisper's 300-second window.

Added files

  • modules.json: three-module pipeline (stock Transformer, Pooling, Normalize); a sketch of the equivalent manual construction follows this list.
  • sentence_bert_config.json: transformer_task="any-to-any", per-modality method_output_name=["hidden_states", -1], module_output_name="token_embeddings", processing_kwargs.chat_template.chat_template="sentence_transformers", message.format="structured".
  • config_sentence_transformers.json: prompts, similarity_fn_name="cosine".
  • 1_Pooling/config.json: pooling_mode="lasttoken", embedding_dimension=3584, include_prompt=true.
  • additional_chat_templates/sentence_transformers.jinja: minimal chat template producing the exact system/user/assistant scaffold plus the trailing <|endoftext|> pooling token, for text / image / audio / video content blocks.
  • chat_template.jinja: modern replacement for the deleted legacy chat_template.json, with identical content.
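
As referenced above, the three-module stack that modules.json wires together is roughly equivalent to this manual construction; illustrative only, since the shipped JSON files (including the any-to-any task routing in sentence_bert_config.json) already do the wiring:

from sentence_transformers import SentenceTransformer, models

transformer = models.Transformer("Tevatron/OmniEmbed-v0.1")
pooling = models.Pooling(
    word_embedding_dimension=3584,  # hidden size of the 7B thinker
    pooling_mode="lasttoken",
    include_prompt=True,
)
model = SentenceTransformer(modules=[transformer, pooling, models.Normalize()])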

Modified files

  • README.md: added a "Using Sentence Transformers" subsection at the top of Usage (the existing usage content is preserved as "Using Transformers"), mirroring the same four retrieval examples plus a "Multimodal Inputs" note for combined-modality dicts. YAML frontmatter updated with library_name: sentence-transformers and tags: [sentence-transformers, peft, multimodal, feature-extraction].

Deleted files

  • chat_template.json: replaced with chat_template.jinja (same template content). Required because Qwen2_5OmniProcessor rejects a repo that ships both the legacy JSON and an additional_chat_templates/ directory.
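
Example usage

For reference, here is the image retrieval example from the new "Using Sentence Transformers" README section:
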
import torch
from sentence_transformers import SentenceTransformer

model = SentenceTransformer(
    "Tevatron/OmniEmbed-v0.1",
    model_kwargs={
        "torch_dtype": torch.bfloat16,
        "attn_implementation": "flash_attention_2",  # pip install kernels; recommended but not mandatory
    },
    revision="refs/pr/3",
)

# For video on smaller GPUs, cap the processor up front:
model[0].processing_kwargs.update({
    "video": {"max_pixels": 64 * 28 * 28, "do_sample_frames": True, "fps": 1},
})

example_query = "How many input modality does Qwen2.5-Omni support?"
example_images = [
    "https://huggingface.co/Tevatron/OmniEmbed-v0.1/resolve/main/assets/qwen2.5omni_hgf.png",
    "https://huggingface.co/Tevatron/OmniEmbed-v0.1/resolve/main/assets/llama4_hgf.png",
]
query_embedding = model.encode_query(example_query)
document_embeddings = model.encode_document(example_images, batch_size=1)
print(model.similarity(query_embedding, document_embeddings))
# Image similarities: tensor([[0.4682, 0.2956]])

I added the revision argument to this snippet so that you can run it before the PR is merged.

Note that none of the existing behavior is affected: the original transformers snippet in the README keeps working as-is, the adapter weights are unchanged, and loading via Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained("Tevatron/OmniEmbed-v0.1", ...) still works. This PR only adds an additional way to run the model in a familiar, library-level format.

Happy to tweak anything you'd like changed. Please let me know if you have any questions or feedback!

  • Tom Aarsen
tomaarsen changed pull request status to open
MrLight changed pull request status to merged
