Simplify and support Sentence Transformers via SparseEncoder
Hello!
Pull Request overview
- Heavily simplify the bidirectional Qwen3 implementation by relying on https://github.com/huggingface/transformers/pull/43705 (requires transformers v5.2.0+)
- Support `SparseEncoder` from Sentence Transformers; the outputs match the original implementation
Details
This PR mirrors https://huggingface.co/naver/splade-code-06B/discussions/1. Apologies that I did not get this out earlier; I started the work but completely forgot to finish it over the weekend.
Usage is now much easier, while the old transformers code is unaffected:
```python
from sentence_transformers import SparseEncoder

model = SparseEncoder("naver/splade-code-8B", trust_remote_code=True, revision="refs/pr/1")

queries = [
    "SELECT *\nFROM Student\nWHERE Age = (\nSELECT MAX(Age)\nFROM Student\nWHERE Group = 'specific_group'\n)\nAND Group = 'specific_group';"
]
query_embeddings = model.encode(queries)
print(query_embeddings.shape)
# torch.Size([1, 151936])

sparsity = model.sparsity(query_embeddings)
print(sparsity)
# {'active_dims': 1120.0, 'sparsity_ratio': 0.9926284751474305}

decoded = model.decode(query_embeddings, top_k=10)
print(decoded)
# [[
#     ('Ġgroup', 2.34375),
#     ('Ġoldest', 2.28125),
#     ('Ġage', 2.25),
#     ('_group', 2.25),
#     ('ĠGroup', 2.171875),
#     ('ĠAge', 2.109375),
#     ('ĠMAX', 2.0625),
#     ('ĠStudent', 2.046875),
#     ('Ġspecific', 2.03125),
#     ('Ġstudent', 2.0),
# ]]
```
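As a sanity check on the numbers above: `sparsity_ratio` follows directly from `active_dims` and the vocabulary size, i.e. `1 - active_dims / vocab_size`. A minimal sketch (the helper below is illustrative, not the actual `SparseEncoder.sparsity` implementation):

```python
import torch

# Illustrative re-computation of the sparsity statistics (not the actual
# SparseEncoder.sparsity code): count the non-zero dimensions per embedding
# and relate that count to the vocabulary size.
def sparsity_stats(embeddings: torch.Tensor) -> dict:
    active_dims = (embeddings != 0).sum(dim=1).float().mean().item()
    return {
        "active_dims": active_dims,
        "sparsity_ratio": 1.0 - active_dims / embeddings.shape[1],
    }

# Toy tensor with 1120 active dimensions out of 151936 (the Qwen3 vocab size),
# matching the statistics printed above.
emb = torch.zeros(1, 151936)
emb[0, :1120] = 1.0
print(sparsity_stats(emb))
```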
It works with transformers>5.2.0 and sentence-transformers>5.0.0. If you install the `kernels` package, you can set `model_kwargs={"attn_implementation": "flash_attention_2"}` and the model will use a kernel from the Hub without you having to actually install flash-attn (which is always annoying).
It also works with e.g. `sdpa` or `eager`, unlike the current implementation, and is likely to keep working with future transformers versions, as it imports very little from transformers.
Note that the script above uses revision="refs/pr/1", so you can test it directly from this PR branch without having to check anything out locally or merge it.
As mentioned, the old Transformers code also gives the same result as before; e.g. when I run the baseline code with transformers locally, I get:
```
+--------------------------------------------------------------------+
|                        TOP ACTIVATED WORDS                         |
+--------------------------------------------------------------------+
* INPUT: SELECT *
FROM Student
WHERE Age = (
SELECT MAX(Age)
FROM Student
WHERE Group = 'specific_group'
)
AND Group = 'specific_group';

Ġgroup     | ████████████████████ 2.34
Ġoldest    | ███████████████████ 2.28
Ġage       | ███████████████████ 2.25
_group     | ███████████████████ 2.25
ĠGroup     | ██████████████████ 2.17
ĠAge       | ██████████████████ 2.11
ĠMAX       | █████████████████ 2.06
ĠStudent   | █████████████████ 2.05
Ġspecific  | █████████████████ 2.03
Ġstudent   | █████████████████ 2.00
```
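(Incidentally, the bar widths in that table are just each score scaled to 20 characters at the maximum score, rounded down. A hypothetical helper, not part of this PR, that reproduces the rendering:)

```python
# Hypothetical helper (not part of the PR): render (token, score) pairs as a
# bar chart, scaling each bar to `width` characters at the maximum score.
def render_bars(pairs, width=20):
    top = max(score for _, score in pairs)
    return "\n".join(
        f"{token:<10} | {'█' * int(width * score / top)} {score:.2f}"
        for token, score in pairs
    )

print(render_bars([("Ġgroup", 2.34375), ("Ġoldest", 2.28125), ("ĠGroup", 2.171875)]))
```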
P.s. I moved the LoRA adapter files to the main directory, as these files are commonly kept there. It's only a move; no weights were updated.
- Tom Aarsen
Thanks @tomaarsen !