Simplify and support Sentence Transformers via SparseEncoder

#1
by tomaarsen HF Staff - opened

Hello!

Pull Request overview

Details

This PR mirrors https://huggingface.co/naver/splade-code-06B/discussions/1. Apologies that I did not get this out earlier; I started the work but forgot to finish it over the weekend.
Usage is now much simpler, while the old transformers code is unaffected:

from sentence_transformers import SparseEncoder

model = SparseEncoder("naver/splade-code-8B", trust_remote_code=True, revision="refs/pr/1")

queries = [
    "SELECT *\nFROM Student\nWHERE Age = (\nSELECT MAX(Age)\nFROM Student\nWHERE Group = 'specific_group'\n)\nAND Group = 'specific_group';"
]

query_embeddings = model.encode(queries)
print(query_embeddings.shape)
# torch.Size([1, 151936])

sparsity = model.sparsity(query_embeddings)
print(sparsity)
# {'active_dims': 1120.0, 'sparsity_ratio': 0.9926284751474305}

decoded = model.decode(query_embeddings, top_k=10)
print(decoded)
# [[
#     ('Ġgroup', 2.34375),
#     ('Ġoldest', 2.28125),
#     ('Ġage', 2.25),
#     ('_group', 2.25),
#     ('ĠGroup', 2.171875),
#     ('ĠAge', 2.109375),
#     ('ĠMAX', 2.0625),
#     ('ĠStudent', 2.046875),
#     ('Ġspecific', 2.03125),
#     ('Ġstudent', 2.0),
# ]]
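For context on the sparsity numbers above: active_dims counts the non-zero dimensions of the embedding, and sparsity_ratio is the fraction of zeros. A minimal pure-Python sketch of what these statistics mean (not the library internals):

```python
def sparsity_stats(embedding, vocab_size):
    """Count non-zero dimensions and the fraction of zeros.

    `embedding` is a flat list of floats of length `vocab_size`;
    this mirrors the meaning of SparseEncoder.sparsity() output.
    """
    active_dims = sum(1 for value in embedding if value != 0.0)
    return {
        "active_dims": float(active_dims),
        "sparsity_ratio": 1.0 - active_dims / vocab_size,
    }

# Toy example: 2 active dimensions out of 8.
print(sparsity_stats([0.0, 1.2, 0.0, 0.0, 3.4, 0.0, 0.0, 0.0], 8))
# {'active_dims': 2.0, 'sparsity_ratio': 0.75}
```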

And it works with transformers>5.2.0 and sentence-transformers>5.0.0. If you install the kernels package, you can set model_kwargs={"attn_implementation": "flash_attention_2"} and it will use a kernel from the Hub without having to actually install flash-attn (which is always annoying).

It also works with e.g. sdpa or eager, unlike the current implementation, and is likely to keep working with future transformers versions as it imports very little from transformers.
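As a sketch of how the backend choice plugs in: only attn_implementation itself is a real transformers model kwarg; the helper below is my own illustration, not part of the library.

```python
def make_model_kwargs(backend: str) -> dict:
    """Build a model_kwargs dict for SparseEncoder(...).

    Hypothetical convenience helper; the backend names below follow
    the transformers attn_implementation options mentioned above.
    """
    allowed = {"eager", "sdpa", "flash_attention_2"}
    if backend not in allowed:
        raise ValueError(f"unknown attention backend: {backend!r}")
    return {"attn_implementation": backend}

# e.g. SparseEncoder("naver/splade-code-8B", trust_remote_code=True,
#                    model_kwargs=make_model_kwargs("sdpa"))
print(make_model_kwargs("sdpa"))
# {'attn_implementation': 'sdpa'}
```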

Note that the above script has revision="refs/pr/1" so you can test it directly from this PR branch without having to check anything out locally or merge it.

As mentioned, the old transformers code also gives the same result as before; e.g., when I run the baseline code with transformers locally, I get:

+--------------------------------------------------------------------+
|                        TOP ACTIVATED WORDS                         |
+--------------------------------------------------------------------+


* INPUT: SELECT *
FROM Student
WHERE Age = (
SELECT MAX(Age)
FROM Student
WHERE Group = 'specific_group'
)
AND Group = 'specific_group';

Ġgroup                     | ████████████████████ 2.34
Ġoldest                    | ███████████████████ 2.28
Ġage                       | ███████████████████ 2.25
_group                     | ███████████████████ 2.25
ĠGroup                     | ██████████████████ 2.17
ĠAge                       | ██████████████████ 2.11
ĠMAX                       | █████████████████ 2.06
ĠStudent                   | █████████████████ 2.05
Ġspecific                  | █████████████████ 2.03
Ġstudent                   | █████████████████ 2.00
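A bar chart like the one above is easy to reproduce from decode() output with a few lines of plain Python; a rough sketch (the bar width and column padding are my own choices):

```python
def render_bars(decoded, width=20):
    """Render (token, score) pairs as an ASCII bar chart.

    Bars are scaled so the highest score gets `width` blocks.
    """
    max_score = max(score for _, score in decoded)
    lines = []
    for token, score in decoded:
        bar = "█" * round(width * score / max_score)
        lines.append(f"{token:<26} | {bar} {score:.2f}")
    return "\n".join(lines)

print(render_bars([("Ġgroup", 2.34375), ("Ġoldest", 2.28125)]))
```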

P.S. I moved the LoRA adapter files to the main directory, as these files are commonly kept there. It's only a move; no weights were updated.

cc @slupart @sclincha

  • Tom Aarsen
tomaarsen changed pull request status to open
NAVER LABS Europe org

Thanks @tomaarsen !

sclincha changed pull request status to merged
