This is a sentence-transformers model finetuned from flax-sentence-embeddings/all_datasets_v4_MiniLM-L6 on the json dataset. It maps sentences and paragraphs to a 384-dimensional dense vector space and can be used for semantic textual similarity, semantic search, paraphrase mining, text classification, clustering, and more.
Full model architecture:

```
SentenceTransformer(
  (0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: BertModel
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Normalize()
)
```
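The architecture is a three-stage pipeline. BERT produces token embeddings, mean pooling averages them over non-padding tokens (`pooling_mode_mean_tokens: True`), and `Normalize()` scales the result to unit length. Below is a minimal pure-Python sketch of the last two stages. It assumes toy 4-dimensional token vectors in place of the model's real 384-dimensional BERT outputs. `mean_pool` and `l2_normalize` are hypothetical helper names, not sentence-transformers functions.

```python
import math
from typing import List


def mean_pool(token_embeddings: List[List[float]], attention_mask: List[int]) -> List[float]:
    """Average the token vectors whose attention-mask entry is 1; padding is ignored."""
    dim = len(token_embeddings[0])
    summed = [0.0] * dim
    count = 0
    for vec, keep in zip(token_embeddings, attention_mask):
        if keep:
            count += 1
            for i, x in enumerate(vec):
                summed[i] += x
    count = max(count, 1)  # guard against an all-padding input
    return [s / count for s in summed]


def l2_normalize(vec: List[float]) -> List[float]:
    """Scale a vector to unit Euclidean length, the role of the Normalize() module."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else vec


# Hypothetical 3-token sequence whose last token is padding (mask 0).
tokens = [[1.0, 2.0, 0.0, 0.0],
          [3.0, 0.0, 0.0, 4.0],
          [9.0, 9.0, 9.0, 9.0]]
mask = [1, 1, 0]

pooled = mean_pool(tokens, mask)
embedding = l2_normalize(pooled)
print(pooled)     # [2.0, 1.0, 0.0, 2.0]
print(embedding)  # approximately [0.667, 0.333, 0.0, 0.667]
```

Note that the padding row of 9s has no effect on the pooled vector, because its mask entry is 0.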
First install the Sentence Transformers library:
```bash
pip install -U sentence-transformers
```
Then you can load this model and run inference.
```python
from sentence_transformers import SentenceTransformer

# Download from the 🤗 Hub
model = SentenceTransformer("FareedKhan/flax-sentence-embeddings_all_datasets_v4_MiniLM-L6_FareedKhan_prime_synthetic_data_2k_10_64")
# Run inference
sentences = [
    '\n\nDiarrhea, a condition characterized by the passage of loose, watery, and often more than five times a day, is a common ailment affecting individuals of all ages. It is typically acute when it lasts for a few days to a week or recurrent when it persists for more than four weeks. While acute diarrhea often resolves on its own and is usually not a cause for concern, recurrent or chronic forms require medical attention due to the risk of dehydration and nutrient deficiencies. \n\n### Causes\n\nDiarrhea can be caused by various factors, including:\n\n1. **Viral',
    'Could you assist in identifying a condition linked to congenital secretory diarrhea, similar to intractable diarrhea of infancy, given my symptoms of persistent, salty watery diarrhea, hyponatremia, abnormal body pH, and reliance on parenteral nutrition due to chronic dehydration?',
    'Could you describe the specific effects or phenotypes associated with acute hydrops in patients with the subtype of keratoconus?',
]
embeddings = model.encode(sentences)
print(embeddings.shape)
# [3, 384]

# Get the similarity scores for the embeddings
similarities = model.similarity(embeddings, embeddings)
print(similarities.shape)
# [3, 3]
```
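The final `Normalize()` module makes every embedding unit length. Cosine is the default similarity function in sentence-transformers, so the scores from `model.similarity` coincide with plain dot products. The sketch below shows what the `[3, 3]` matrix holds and how the same scores rank a corpus for semantic search. It is pure Python and assumes hypothetical 3-dimensional vectors in place of real model outputs. `cosine`, `similarity_matrix`, and `top_k` are illustrative names, not library API.

```python
import math
from typing import List

Vector = List[float]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity: dot product divided by the product of the norms."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def similarity_matrix(embs: List[Vector]) -> List[List[float]]:
    """Pairwise scores; an n x n matrix like the [3, 3] output of model.similarity."""
    return [[cosine(a, b) for b in embs] for a in embs]


def top_k(query: Vector, corpus: List[Vector], k: int) -> List[int]:
    """Indices of the k corpus vectors most similar to the query, best first."""
    scores = [cosine(query, doc) for doc in corpus]
    return sorted(range(len(corpus)), key=lambda i: scores[i], reverse=True)[:k]


# Hypothetical stand-ins for three 384-dimensional sentence embeddings.
embs = [[1.0, 0.0, 0.0],
        [0.8, 0.6, 0.0],
        [0.0, 0.0, 1.0]]
sims = similarity_matrix(embs)
print(len(sims), len(sims[0]))            # 3 3
print(top_k([0.9, 0.1, 0.0], embs, k=2))  # [0, 1]
```

The matrix is symmetric with ones on the diagonal, since every sentence is maximally similar to itself.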
Information retrieval evaluation on the `dim_384` dataset, computed with `InformationRetrievalEvaluator`:

| Metric | Value |
|---|---|
| cosine_accuracy@1 | 0.3614 |
| cosine_accuracy@3 | 0.3861 |
| cosine_accuracy@5 | 0.4257 |
| cosine_accuracy@10 | 0.4653 |
| cosine_precision@1 | 0.3614 |
| cosine_precision@3 | 0.1287 |
| cosine_precision@5 | 0.0851 |
| cosine_precision@10 | 0.0465 |
| cosine_recall@1 | 0.3614 |
| cosine_recall@3 | 0.3861 |
| cosine_recall@5 | 0.4257 |
| cosine_recall@10 | 0.4653 |
| cosine_ndcg@10 | 0.407 |
| cosine_mrr@10 | 0.3891 |
| cosine_map@100 | 0.396 |
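The numbers in this table fit a setup where each query has exactly one relevant passage. Recall@k equals accuracy@k at every cutoff, and precision@k is accuracy@k divided by k (for example, 0.3861 / 3 ≈ 0.1287). The sketch below applies the usual binary-relevance definitions of these metrics to a hypothetical toy ranking. It is an illustration of the definitions, not the evaluator's own code; `ir_metrics` and the document ids are made up.

```python
import math
from typing import Dict, List, Set


def ir_metrics(ranked: Dict[str, List[str]],
               relevant: Dict[str, Set[str]],
               ks=(1, 3, 5, 10)) -> Dict[str, float]:
    """Average binary-relevance retrieval metrics over all queries.

    ranked[q] lists document ids best-first; relevant[q] must be non-empty.
    """
    totals: Dict[str, float] = {}

    def add(name: str, value: float) -> None:
        totals[name] = totals.get(name, 0.0) + value

    for qid, docs in ranked.items():
        rel = relevant[qid]
        for k in ks:
            hits = sum(1 for d in docs[:k] if d in rel)
            add(f"accuracy@{k}", 1.0 if hits else 0.0)  # any relevant doc in the top k
            add(f"precision@{k}", hits / k)
            add(f"recall@{k}", hits / len(rel))
        top10 = docs[:10]
        # MRR@10: reciprocal rank of the first relevant doc in the top 10 (0 if none).
        first = next((r for r, d in enumerate(top10, start=1) if d in rel), None)
        add("mrr@10", 1.0 / first if first else 0.0)
        # NDCG@10 with gain 1 for relevant docs and a log2(rank + 1) discount.
        dcg = sum(1.0 / math.log2(r + 1) for r, d in enumerate(top10, start=1) if d in rel)
        idcg = sum(1.0 / math.log2(r + 1) for r in range(1, min(len(rel), 10) + 1))
        add("ndcg@10", dcg / idcg)
        # MAP@100: precision at each relevant hit, normalized by min(|rel|, 100).
        found, precision_sum = 0, 0.0
        for r, d in enumerate(docs[:100], start=1):
            if d in rel:
                found += 1
                precision_sum += found / r
        add("map@100", precision_sum / min(len(rel), 100))

    return {name: total / len(ranked) for name, total in totals.items()}


# Hypothetical toy run: two queries, one relevant passage each.
ranked = {"q1": ["d3", "d1", "d7"], "q2": ["d2", "d5", "d4"]}
relevant = {"q1": {"d1"}, "q2": {"d9"}}
m = ir_metrics(ranked, relevant)
print(m["accuracy@3"], m["recall@3"], m["mrr@10"])  # 0.5 0.5 0.25
```

With a single relevant document per query, MAP@100 is the reciprocal rank of that document within the top 100. That is why it can only match or exceed MRR@10, as 0.396 vs 0.3891 does in the table.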
The json dataset has two columns, `positive` and `anchor`:

| | positive | anchor |
|---|---|---|
| type | string | string |

Samples:

| positive | anchor |
|---|---|
| | Which pharmacological agents with antioxidant properties have the potential to disrupt the PCSK9-LDLR interaction by affecting the gene or protein players in this pathway? |
| | What is the name of the gynecological condition that arises due to blocked Bartholin's glands and involves cyst formation, falling under the broader category of women's reproductive health issues? |
| | What condition could be associated with the use of Capsaicin cream, peripheral neuropathy, and symptoms similar to sciatica? |

Loss: `MatryoshkaLoss` with these parameters:

```json
{
    "loss": "MultipleNegativesRankingLoss",
    "matryoshka_dims": [
        384
    ],
    "matryoshka_weights": [
        1
    ],
    "n_dims_per_step": -1
}
```
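The two pieces of this configuration fit together as follows. MultipleNegativesRankingLoss treats each (anchor, positive) pair in a batch as a classification problem: for anchor i, positive i is the target, and every other positive in the batch acts as an in-batch negative. MatryoshkaLoss computes that inner loss on embeddings truncated to each size in `matryoshka_dims` and sums the results weighted by `matryoshka_weights`. The sketch below is pure Python and makes three assumptions. It uses cosine similarity with a scale of 20 (the library's defaults, since the config above does not list them), it uses toy 4-dimensional vectors, and its function names are illustrative rather than the sentence-transformers implementation.

```python
import math
from typing import List, Sequence

Vector = List[float]


def cos_sim(a: Vector, b: Vector) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def mnrl_loss(anchors: List[Vector], positives: List[Vector], scale: float = 20.0) -> float:
    """Mean cross-entropy where anchor i must pick positive i among all batch positives."""
    n = len(anchors)
    total = 0.0
    for i in range(n):
        logits = [scale * cos_sim(anchors[i], positives[j]) for j in range(n)]
        peak = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = peak + math.log(sum(math.exp(x - peak) for x in logits))
        total += log_sum_exp - logits[i]  # negative log-softmax of the target index i
    return total / n


def matryoshka_loss(anchors: List[Vector], positives: List[Vector],
                    dims: Sequence[int], weights: Sequence[float]) -> float:
    """Weighted sum of the inner loss on embeddings truncated to each dim."""
    return sum(
        w * mnrl_loss([a[:d] for a in anchors], [p[:d] for p in positives])
        for d, w in zip(dims, weights)
    )


# Hypothetical batch of two (anchor, positive) pairs.
anchors = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
positives = [[0.9, 0.1, 0.0, 0.0], [0.1, 0.9, 0.0, 0.0]]

# A single dim equal to the full width with weight 1, as in this card's config.
full = matryoshka_loss(anchors, positives, dims=[4], weights=[1.0])
print(full == mnrl_loss(anchors, positives))  # True

# Hypothetical nested setting for comparison: also train a 2-dim prefix.
print(matryoshka_loss(anchors, positives, dims=[4, 2], weights=[1.0, 1.0]))
```

Because this card uses `matryoshka_dims: [384]` with weight `[1]`, the only truncation is to the model's full output size. The objective therefore reduces to plain MultipleNegativesRankingLoss.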
Non-default hyperparameters:

- `eval_strategy`: epoch
- `per_device_train_batch_size`: 64
- `learning_rate`: 1e-05
- `num_train_epochs`: 10
- `warmup_ratio`: 0.1
- `bf16`: True
- `tf32`: False
- `load_best_model_at_end`: True

All hyperparameters:

- `overwrite_output_dir`: False
- `do_predict`: False
- `eval_strategy`: epoch
- `prediction_loss_only`: True
- `per_device_train_batch_size`: 64
- `per_device_eval_batch_size`: 8
- `per_gpu_train_batch_size`: None
- `per_gpu_eval_batch_size`: None
- `gradient_accumulation_steps`: 1
- `eval_accumulation_steps`: None
- `torch_empty_cache_steps`: None
- `learning_rate`: 1e-05
- `weight_decay`: 0.0
- `adam_beta1`: 0.9
- `adam_beta2`: 0.999
- `adam_epsilon`: 1e-08
- `max_grad_norm`: 1.0
- `num_train_epochs`: 10
- `max_steps`: -1
- `lr_scheduler_type`: linear
- `lr_scheduler_kwargs`: {}
- `warmup_ratio`: 0.1
- `warmup_steps`: 0
- `log_level`: passive
- `log_level_replica`: warning
- `log_on_each_node`: True
- `logging_nan_inf_filter`: True
- `save_safetensors`: True
- `save_on_each_node`: False
- `save_only_model`: False
- `restore_callback_states_from_checkpoint`: False
- `no_cuda`: False
- `use_cpu`: False
- `use_mps_device`: False
- `seed`: 42
- `data_seed`: None
- `jit_mode_eval`: False
- `use_ipex`: False
- `bf16`: True
- `fp16`: False
- `fp16_opt_level`: O1
- `half_precision_backend`: auto
- `bf16_full_eval`: False
- `fp16_full_eval`: False
- `tf32`: False
- `local_rank`: 0
- `ddp_backend`: None
- `tpu_num_cores`: None
- `tpu_metrics_debug`: False
- `debug`: []
- `dataloader_drop_last`: False
- `dataloader_num_workers`: 0
- `dataloader_prefetch_factor`: None
- `past_index`: -1
- `disable_tqdm`: False
- `remove_unused_columns`: True
- `label_names`: None
- `load_best_model_at_end`: True
- `ignore_data_skip`: False
- `fsdp`: []
- `fsdp_min_num_params`: 0
- `fsdp_config`: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
- `fsdp_transformer_layer_cls_to_wrap`: None
- `accelerator_config`: {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}
- `deepspeed`: None
- `label_smoothing_factor`: 0.0
- `optim`: adamw_torch
- `optim_args`: None
- `adafactor`: False
- `group_by_length`: False
- `length_column_name`: length
- `ddp_find_unused_parameters`: None
- `ddp_bucket_cap_mb`: None
- `ddp_broadcast_buffers`: False
- `dataloader_pin_memory`: True
- `dataloader_persistent_workers`: False
- `skip_memory_metrics`: True
- `use_legacy_prediction_loop`: False
- `push_to_hub`: False
- `resume_from_checkpoint`: None
- `hub_model_id`: None
- `hub_strategy`: every_save
- `hub_private_repo`: False
- `hub_always_push`: False
- `gradient_checkpointing`: False
- `gradient_checkpointing_kwargs`: None
- `include_inputs_for_metrics`: False
- `eval_do_concat_batches`: True
- `fp16_backend`: auto
- `push_to_hub_model_id`: None
- `push_to_hub_organization`: None
- `mp_parameters`:
- `auto_find_batch_size`: False
- `full_determinism`: False
- `torchdynamo`: None
- `ray_scope`: last
- `ddp_timeout`: 1800
- `torch_compile`: False
- `torch_compile_backend`: None
- `torch_compile_mode`: None
- `dispatch_batches`: None
- `split_batches`: None
- `include_tokens_per_second`: False
- `include_num_input_tokens_seen`: False
- `neftune_noise_alpha`: None
- `optim_target_modules`: None
- `batch_eval_metrics`: False
- `eval_on_start`: False
- `use_liger_kernel`: False
- `eval_use_gather_object`: False
- `batch_sampler`: batch_sampler
- `multi_dataset_batch_sampler`: proportional

Training logs:

| Epoch | Step | Training Loss | dim_384_cosine_map@100 |
|---|---|---|---|
| 0 | 0 | - | 0.3614 |
| 0.3448 | 10 | 2.117 | - |
| 0.6897 | 20 | 2.1255 | - |
| 1.0 | 29 | - | 0.3855 |
| 1.0345 | 30 | 1.9375 | - |
| 1.3793 | 40 | 1.7987 | - |
| 1.7241 | 50 | 1.7494 | - |
| 2.0 | 58 | - | 0.3901 |
| 2.0690 | 60 | 1.7517 | - |
| 2.4138 | 70 | 1.676 | - |
| 2.7586 | 80 | 1.608 | - |
| 3.0 | 87 | - | 0.3934 |
| 3.1034 | 90 | 1.5923 | - |
| 3.4483 | 100 | 1.5095 | - |
| 3.7931 | 110 | 1.5735 | - |
| 4.0 | 116 | - | 0.3910 |
| 4.1379 | 120 | 1.3643 | - |
| 4.4828 | 130 | 1.4395 | - |
| 4.8276 | 140 | 1.3595 | - |
| 5.0 | 145 | - | 0.3884 |
| 5.1724 | 150 | 1.3365 | - |
| 5.5172 | 160 | 1.3506 | - |
| 5.8621 | 170 | 1.3279 | - |
| 6.0 | 174 | - | 0.3957 |
| 6.2069 | 180 | 1.3075 | - |
| 6.5517 | 190 | 1.3138 | - |
| 6.8966 | 200 | 1.2749 | - |
| 7.0 | 203 | - | 0.3979 |
| 7.2414 | 210 | 1.1725 | - |
| 7.5862 | 220 | 1.2696 | - |
| 7.9310 | 230 | 1.2487 | - |
| 8.0 | 232 | - | 0.3986 |
| 8.2759 | 240 | 1.1558 | - |
| 8.6207 | 250 | 1.2447 | - |
| 8.9655 | 260 | 1.2566 | - |
| 9.0 | 261 | - | 0.3964 |
| 9.3103 | 270 | 1.2493 | - |
| 9.6552 | 280 | 1.2697 | - |
| 10.0 | 290 | 1.079 | 0.3960 |
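These settings determine the learning-rate schedule. The log reaches epoch 1.0 at step 29, and combined with `num_train_epochs: 10`, `learning_rate: 1e-05`, `warmup_ratio: 0.1`, and `lr_scheduler_type: linear`, that implies 290 optimizer steps with 29 warmup steps. The sketch below reconstructs that schedule and the log's Epoch column from the Step column. It assumes Hugging Face's `linear` rule (linear warmup to the base rate, then linear decay to zero) and warmup steps computed as `ceil(total_steps * warmup_ratio)`. It is not taken from the training code, and `linear_lr` is an illustrative name.

```python
import math

# Values from this card: 29 steps per epoch (epoch 1.0 is logged at step 29),
# num_train_epochs 10, learning_rate 1e-05, warmup_ratio 0.1.
STEPS_PER_EPOCH = 29
NUM_EPOCHS = 10
BASE_LR = 1e-05
WARMUP_RATIO = 0.1

total_steps = STEPS_PER_EPOCH * NUM_EPOCHS            # 290, the last logged step
warmup_steps = math.ceil(total_steps * WARMUP_RATIO)  # 29 (assumed rounding rule)


def linear_lr(step: int) -> float:
    """Linear warmup from 0 to BASE_LR, then linear decay to 0 at total_steps."""
    if step < warmup_steps:
        return BASE_LR * step / max(1, warmup_steps)
    return BASE_LR * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


# Recompute the log's Epoch column and the assumed learning rate at a few steps.
for step in (10, 29, 150, 290):
    print(step, round(step / STEPS_PER_EPOCH, 4), linear_lr(step))
```

The recomputed epochs (for example, 10 / 29 ≈ 0.3448 and 150 / 29 ≈ 5.1724) match the Epoch column in the log above.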

```bibtex
@inproceedings{reimers-2019-sentence-bert,
    title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2019",
    publisher = "Association for Computational Linguistics",
    url = "https://arxiv.org/abs/1908.10084",
}

@misc{kusupati2024matryoshka,
    title={Matryoshka Representation Learning},
    author={Aditya Kusupati and Gantavya Bhatt and Aniket Rege and Matthew Wallingford and Aditya Sinha and Vivek Ramanujan and William Howard-Snyder and Kaifeng Chen and Sham Kakade and Prateek Jain and Ali Farhadi},
    year={2024},
    eprint={2205.13147},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}

@misc{henderson2017efficient,
    title={Efficient Natural Language Response Suggestion for Smart Reply},
    author={Matthew Henderson and Rami Al-Rfou and Brian Strope and Yun-hsuan Sung and Laszlo Lukacs and Ruiqi Guo and Sanjiv Kumar and Balint Miklos and Ray Kurzweil},
    year={2017},
    eprint={1705.00652},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
```