| import torch |
| from torch import nn |
| from omegaconf import OmegaConf |
| import numpy as np |
| from huggingface_hub import hf_hub_download |
| import os |
| from audiocraft.encodec import EncodecModel |
| from audiocraft.lm import LMModel |
| from audiocraft.seanet import SEANetDecoder |
| from audiocraft.vq import ResidualVectorQuantizer |
|
|
|
|
# Number of copies of the prompt -- and, separately, of the empty string --
# placed in the text-conditioning batch built by AudioGen.generate. The same
# factor also appears in the divisor of that method's max_tokens budget.
N_REPEAT: int = 2
|
|
def _shift(x):
    """Circularly shift each item along the first axis of ``x``, in place.

    For every index ``i`` over dim 0, ``x[i]`` is rolled along dim 2 of ``x``
    by a random offset drawn from numpy's global RNG in the half-open range
    ``[int(0.24 * n), int(max(1, 0.74 * n)))`` with ``n = x.shape[2]``. All
    entries of ``x[i]`` share the same offset.

    Args:
        x: Tensor with at least 3 dims; presumably (batch, channels, time)
            audio -- TODO confirm against callers.

    Returns:
        The same tensor object ``x``, modified in place.
    """
    n = x.shape[2]  # loop-invariant: every slice has the same length
    # Explicit int bounds: identical to what randint derives from the float
    # bounds, without relying on implicit float->int truncation.
    low = int(0.24 * n)
    # max(1, ...) keeps the range non-empty (high > low) when n <= 1.
    high = int(max(1, 0.74 * n))
    for i in range(x.shape[0]):
        offset = np.random.randint(low, high)
        # torch.roll returns a new tensor, so writing back into x is safe.
        x[i, :, :] = torch.roll(x[i], offset, dims=1)
    return x
|
|
class AudioGen(torch.nn.Module):
    """Text-to-audio generator built from the ``facebook/audiogen-medium`` weights.

    Construction downloads two checkpoints from the Hugging Face Hub and builds:

    * ``compression_model`` -- an ``EncodecModel`` (sample_rate 16000,
      1 channel, frame_rate 50) used by :meth:`generate` to decode tokens;
    * ``lm`` -- an ``LMModel`` that produces those tokens from text prompts.

    Both sub-models are switched to eval mode.
    """

    # Hub source shared by both checkpoint downloads.
    _REPO_ID = 'facebook/audiogen-medium'
    _LIBRARY_NAME = "audiocraft"
    _LIBRARY_VERSION = '1.3.0a1'

    def __init__(self):
        super().__init__()
        # Same order as before: compression model first, then the LM.
        self.compression_model = self._load_compression_model()
        self.lm = self._load_lm()

    @classmethod
    def _download(cls, filename):
        """Return the local path of ``filename`` fetched from the AudioGen Hub repo.

        The cache directory is ``$AUDIOCRAFT_CACHE_DIR`` when set, else ``None``.
        """
        return hf_hub_download(
            repo_id=cls._REPO_ID,
            filename=filename,
            cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
            library_name=cls._LIBRARY_NAME,
            library_version=cls._LIBRARY_VERSION)

    @classmethod
    def _load_compression_model(cls):
        """Build the Encodec model and load the ``compression_state_dict.bin`` weights."""
        # NOTE(review): torch.load unpickles the downloaded file; acceptable
        # only because the source repo is trusted.
        pkg = torch.load(cls._download("compression_state_dict.bin"), map_location='cpu')
        model = EncodecModel(decoder=SEANetDecoder(),
                             quantizer=ResidualVectorQuantizer(),
                             frame_rate=50,
                             renormalize=False,
                             sample_rate=16000,
                             channels=1,
                             causal=False)
        # strict=False: mismatched keys are silently ignored.
        # NOTE(review): consider inspecting the returned missing_keys.
        model.load_state_dict(pkg['best_state'], strict=False)
        model.eval()
        return model

    @classmethod
    def _load_lm(cls):
        """Build the token language model and load the ``state_dict.bin`` weights."""
        # NOTE(review): torch.load unpickles the downloaded file (trusted repo).
        pkg = torch.load(cls._download("state_dict.bin"), map_location='cpu')
        best = pkg['best_state']
        # The checkpoint stores the text-conditioner projection under the
        # conditioner path; LMModel loads it (strict) under ``t5.output_proj``.
        best['t5.output_proj.weight'] = best.pop('condition_provider.conditioners.description.output_proj.weight')
        best['t5.output_proj.bias'] = best.pop('condition_provider.conditioners.description.output_proj.bias')
        lm = LMModel()
        lm.load_state_dict(best, strict=True)
        lm.eval()
        return lm

    @torch.no_grad()
    def generate(self,
                 prompt='dogs mewo',
                 duration=2.24,
                 seed=42,
                 ):
        """Generate audio for ``prompt`` and return it as one flat 1-D tensor.

        Args:
            prompt: Text description to condition generation on.
            duration: Target length in seconds. It sets ``self.lm.n_draw`` and
                the ``max_tokens`` budget passed to the LM.
            seed: Value passed to ``torch.manual_seed`` before sampling.
                Defaults to 42, the previously hard-coded value. Pass ``None``
                to leave torch's global RNG state untouched.

        Returns:
            The decoded output after seven random circular shifts, flattened
            with ``reshape(-1)``.
        """
        if seed is not None:
            # Reseeds torch's *process-global* RNG.
            torch.manual_seed(seed)
        self.lm.n_draw = int(duration / 12) + 1
        # N_REPEAT copies of the prompt followed by N_REPEAT empty prompts
        # (presumably the unconditional half for classifier-free guidance --
        # confirm against LMModel).
        text_condition = [prompt] * N_REPEAT + [''] * N_REPEAT
        max_tokens = int(duration / (N_REPEAT * self.lm.n_draw) * self.compression_model.frame_rate)

        with torch.autocast(device_type='cuda', dtype=torch.float16):
            gen_tokens = self.lm.generate(
                text_condition=text_condition,
                max_tokens=max_tokens,
            )
        # NOTE(review): decoding runs outside the autocast region -- confirm
        # this matches the intended precision.
        x = self.compression_model.decode(gen_tokens, None)

        # Seven successive random circular shifts (numpy RNG, see _shift).
        for _ in range(7):
            x = _shift(x)

        return x.reshape(-1)
|
|