| |
| |
| |
| |
| |
| |
| |
| from argparse import Namespace |
|
|
| import torch |
| from torch import nn |
| import torch.nn.functional as F |
|
|
| from .adaptor_registry import adaptor_registry, dict_t, state_t |
|
|
| from .adaptor_generic import GenericAdaptor |
|
|
|
|
class OpenCLIP_RADIO(GenericAdaptor):
    """Adaptor pairing a RADIO vision backbone with an OpenCLIP text tower.

    The OpenCLIP model's visual tower is discarded; only its text encoder
    and tokenizer are kept (the image features come from RADIO itself).
    """

    def __init__(self, main_config: Namespace, adaptor_config: dict_t, state: state_t):
        super().__init__(main_config, adaptor_config, state)

        # Lazy import so open_clip is only required when this adaptor is used.
        import open_clip

        model_name = adaptor_config['model']

        self.oc_model = open_clip.create_model_from_pretrained(
            model_name=model_name,
            pretrained=adaptor_config['pretrained'],
            return_transform=False,
        )
        # Drop the visual tower — RADIO supplies the image-side features.
        self.oc_model.visual = None

        self.tokenizer = open_clip.get_tokenizer(model_name=model_name)

    def encode_text(self, text, normalize: bool = False):
        """Encode tokenized text with the OpenCLIP text tower.

        Args:
            text: Token tensor as produced by ``self.tokenizer``.
            normalize: If True, OpenCLIP L2-normalizes the output embeddings.
        """
        return self.oc_model.encode_text(text, normalize=normalize)
|
|
|
|
@adaptor_registry.register_adaptor("open_clip")
def create_open_clip_adaptor(main_config: Namespace, adaptor_config: dict_t, state: state_t):
    """Registry factory for the ``"open_clip"`` adaptor key.

    Simply forwards all arguments to the OpenCLIP_RADIO constructor.
    """
    return OpenCLIP_RADIO(main_config, adaptor_config, state)
|
|