| |
| |
| |
| |
| |
| |
| |
| from argparse import Namespace |
| from typing import Dict, Any |
|
|
| import torch |
|
|
| from .adaptor_generic import GenericAdaptor, AdaptorBase |
|
|
# Free-form, string-keyed configuration mapping used in adaptor construction signatures.
dict_t = Dict[str, Any]
# String-keyed mapping of tensors used in adaptor construction signatures.
# NOTE(review): presumably a model state dict (parameter name -> tensor); confirm against callers.
state_t = Dict[str, torch.Tensor]
|
|
|
|
class AdaptorRegistry:
    """Name-keyed registry of adaptor factory callables.

    Factories are recorded through the decorator returned by
    :meth:`register_adaptor` and invoked by :meth:`create_adaptor`.
    Names that were never registered fall back to ``GenericAdaptor``.
    """

    def __init__(self):
        # name -> factory callable invoked as factory(main_config, adaptor_config, state)
        self._registry = {}

    def register_adaptor(self, name):
        """Return a decorator that records the decorated callable under ``name``.

        The decorated callable is handed back unchanged, so the decorator can
        sit directly on a class or function definition.

        Raises:
            ValueError: when the returned decorator is applied and ``name``
                already has a factory. The check runs at decoration time,
                not when ``register_adaptor`` itself is called.
        """

        def _record(factory_function):
            already_taken = name in self._registry
            if already_taken:
                raise ValueError(f"Model '{name}' already registered")
            self._registry[name] = factory_function
            return factory_function

        return _record

    def create_adaptor(
        self,
        name,
        main_config: Namespace,
        adaptor_config: dict_t,
        state: state_t,
    ) -> AdaptorBase:
        """Build the adaptor registered under ``name``.

        Args:
            name: Registry key to look up.
            main_config: Configuration namespace, forwarded unchanged.
            adaptor_config: Adaptor-specific settings, forwarded unchanged.
            state: String-keyed tensor mapping, forwarded unchanged.

        Returns:
            The result of the registered factory called with the three
            arguments above, or a ``GenericAdaptor`` built from the same
            arguments when ``name`` has no registered factory.
        """
        try:
            factory = self._registry[name]
        except KeyError:
            # An unregistered name is not an error: use the generic fallback.
            return GenericAdaptor(main_config, adaptor_config, state)
        # Called outside the try so a KeyError raised by the factory propagates.
        return factory(main_config, adaptor_config, state)


# Single module-level registry instance used for registration and lookup.
adaptor_registry = AdaptorRegistry()
|
|