| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """Utility to convert weights to safetensors.""" |
| |
|
| | import argparse |
| |
|
| | import torch |
| |
|
| | from .configuration_embed1 import CosmosEmbed1Config |
| | from .modeling_embed1 import CosmosEmbed1 |
| |
|
| |
|
def parse_args():
    """Parse command-line options for the safetensors conversion script.

    Returns:
        argparse.Namespace with ``input_weights`` and ``output_weights``.
    """
    parser = argparse.ArgumentParser(description="Save model weights with optional format conversion and sharding.")
    # Both flags are required string paths; declare them data-driven to keep
    # flag/help pairs in one place.
    for flag, help_text in (
        ("--input_weights", "Path to the input .pt weights file"),
        ("--output_weights", "Path to the output directory where safetensors weights will be saved"),
    ):
        parser.add_argument(flag, type=str, required=True, help=help_text)
    return parser.parse_args()
| |
|
| |
|
def main():
    """Load a ``.pt`` checkpoint into CosmosEmbed1 and re-save it as sharded safetensors."""
    args = parse_args()
    # NOTE(review): requires a CUDA device; the freshly constructed model is
    # moved to GPU and cast to bfloat16 before loading weights.
    model = CosmosEmbed1(CosmosEmbed1Config()).to("cuda", dtype=torch.bfloat16)

    # Re-wrap these parameters as independent clones. Presumably the qformer
    # decoder weight/bias are tied to the word-embedding table in the model
    # definition; cloning breaks the tie so load_state_dict(strict=True) can
    # assign each checkpoint entry to its own tensor — TODO confirm against
    # modeling_embed1.
    model.qformer.cls.predictions.decoder.weight = torch.nn.Parameter(
        model.qformer.cls.predictions.decoder.weight.clone()
    )
    model.qformer.bert.embeddings.word_embeddings.weight = torch.nn.Parameter(
        model.qformer.bert.embeddings.word_embeddings.weight.clone()
    )
    model.qformer.cls.predictions.decoder.bias = torch.nn.Parameter(model.qformer.cls.predictions.decoder.bias.clone())
    model.qformer.cls.predictions.bias = torch.nn.Parameter(model.qformer.cls.predictions.bias.clone())

    # SECURITY NOTE(review): torch.load uses pickle under the hood — only run
    # this on trusted checkpoint files (consider weights_only=True on torch
    # versions that support it).
    with open(args.input_weights, "rb") as fp:
        state_dict = torch.load(fp)
    # strict=True: every checkpoint key must match a model parameter exactly.
    model.load_state_dict(state_dict, strict=True)

    # Write safetensors shards (<= 500MB each) plus the shard index to the
    # output directory.
    model.save_pretrained(
        args.output_weights,
        safe_serialization=True,
        max_shard_size="500MB",
    )


if __name__ == "__main__":
    main()
| |
|