import math

import torch
import torch.nn as nn

from .modules import (
    film_modulate,
    unpatchify,
    PatchEmbed,
    PE_wrapper,
    TimestepEmbedder,
    FeedForward,
    RMSNorm,
)
from .attention import Attention


def _zero_init(layer):
    """Set both the weight and the bias of ``layer`` to zero, in place."""
    nn.init.constant_(layer.weight, 0)
    nn.init.constant_(layer.bias, 0)


class AdaLN(nn.Module):
    """Produce the six per-block modulation tensors from the time embedding.

    Supported ``ada_mode`` values:

    * ``'ada'``: a per-block linear layer maps ``time_token`` to the six
      modulation vectors.
    * ``'ada_single'``: a shared ``time_ada`` is supplied by the caller and a
      learned per-block ``(6, dim)`` table is added to it.
    * ``'ada_sola'`` / ``'ada_sola_bias'``: a shared ``time_ada`` plus a
      low-rank (``lora_a`` -> ``lora_b``) per-block correction scaled by
      ``alpha / r``; the ``_bias`` variant also adds the learned table.

    Any other mode raises ``NotImplementedError``.
    """

    def __init__(self, dim, ada_mode='ada', r=None, alpha=None):
        """Create the parameters required by ``ada_mode``.

        Args:
            dim: Channel width of each of the six modulation vectors.
            ada_mode: One of the modes listed in the class docstring.
            r: Low-rank size; used (and required) only by the sola modes.
            alpha: Low-rank scaling numerator; used only by the sola modes.

        Raises:
            NotImplementedError: If ``ada_mode`` is not recognised.
        """
        super().__init__()
        self.ada_mode = ada_mode
        self.scale_shift_table = None
        if ada_mode == 'ada':
            self.time_ada = nn.Linear(dim, 6 * dim, bias=True)
        elif ada_mode == 'ada_single':
            self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
        elif ada_mode in ['ada_sola', 'ada_sola_bias']:
            self.lora_a = nn.Linear(dim, r * 6, bias=False)
            self.lora_b = nn.Linear(r * 6, dim * 6, bias=False)
            self.scaling = alpha / r
            if ada_mode == 'ada_sola_bias':
                self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
        else:
            raise NotImplementedError

    def forward(self, time_token=None, time_ada=None):
        """Return the modulation tensor reshaped to ``(B, 6, -1)``.

        Args:
            time_token: Time embedding; required by ``'ada'`` and the sola
                modes.
            time_ada: Shared modulation computed by the caller; must be
                ``None`` in ``'ada'`` mode and is required otherwise.

        Raises:
            NotImplementedError: If ``self.ada_mode`` is not recognised.
        """
        if self.ada_mode == 'ada':
            assert time_ada is None
            B = time_token.shape[0]
            time_ada = self.time_ada(time_token).reshape(B, 6, -1)
        elif self.ada_mode == 'ada_single':
            B = time_ada.shape[0]
            time_ada = time_ada.reshape(B, 6, -1)
            time_ada = self.scale_shift_table[None] + time_ada
        elif self.ada_mode in ['ada_sola', 'ada_sola_bias']:
            B = time_ada.shape[0]
            time_ada_lora = self.lora_b(self.lora_a(time_token)) * self.scaling
            time_ada = time_ada + time_ada_lora
            time_ada = time_ada.reshape(B, 6, -1)
            if self.scale_shift_table is not None:
                time_ada = self.scale_shift_table[None] + time_ada
        else:
            raise NotImplementedError
        return time_ada


class DiTBlock(nn.Module):
    """Transformer block: optional skip merge, self-attention, optional
    cross-attention and a feed-forward layer, each with a residual connection.

    When ``time_fusion`` is anything other than ``'token'`` the block owns an
    ``AdaLN`` module and modulates/gates the self-attention and feed-forward
    branches with it.
    """

    def __init__(
        self,
        dim,
        context_dim=None,
        num_heads=8,
        mlp_ratio=4.,
        qkv_bias=False,
        qk_scale=None,
        qk_norm=None,
        act_layer='gelu',
        norm_layer=nn.LayerNorm,
        time_fusion='none',
        ada_sola_rank=None,
        ada_sola_alpha=None,
        skip=False,
        skip_norm=False,
        rope_mode='none',
        context_norm=False,
        use_checkpoint=False
    ):
        """Build the sub-layers.

        Args:
            dim: Token channel width.
            context_dim: Width of the cross-attention context; ``None``
                disables cross-attention.
            num_heads, qkv_bias, qk_scale, qk_norm, rope_mode: Forwarded to
                ``Attention`` (cross-attention always uses
                ``rope_mode='none'``).
            mlp_ratio, act_layer: Forwarded to ``FeedForward`` as ``mult`` and
                ``activation_fn``.
            norm_layer: Callable that builds a norm layer from a width.
            time_fusion: ``'token'`` disables AdaLN; any other value is passed
                to ``AdaLN`` as ``ada_mode``.
            ada_sola_rank, ada_sola_alpha: Forwarded to ``AdaLN`` as ``r`` and
                ``alpha``.
            skip: If true, the block merges a skip tensor through a linear
                layer of shape ``2 * dim -> dim``.
            skip_norm: If true (and ``skip``), normalise the concatenation
                before the linear layer.
            context_norm: If true, normalise the context before
                cross-attention.
            use_checkpoint: Run the block under
                ``torch.utils.checkpoint.checkpoint``.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            qk_norm=qk_norm,
            rope_mode=rope_mode
        )

        if context_dim is not None:
            self.use_context = True
            self.cross_attn = Attention(
                dim=dim,
                num_heads=num_heads,
                context_dim=context_dim,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                qk_norm=qk_norm,
                rope_mode='none'
            )
            self.norm2 = norm_layer(dim)
            if context_norm:
                self.norm_context = norm_layer(context_dim)
            else:
                self.norm_context = nn.Identity()
        else:
            self.use_context = False

        self.norm3 = norm_layer(dim)
        self.mlp = FeedForward(
            dim=dim, mult=mlp_ratio, activation_fn=act_layer, dropout=0
        )

        self.use_adanorm = time_fusion != 'token'
        if self.use_adanorm:
            self.adaln = AdaLN(
                dim,
                ada_mode=time_fusion,
                r=ada_sola_rank,
                alpha=ada_sola_alpha
            )

        if skip:
            self.skip_norm = norm_layer(2 * dim) if skip_norm else nn.Identity()
            self.skip_linear = nn.Linear(2 * dim, dim)
        else:
            self.skip_linear = None

        self.use_checkpoint = use_checkpoint

    def forward(
        self,
        x,
        time_token=None,
        time_ada=None,
        skip=None,
        context=None,
        x_mask=None,
        context_mask=None,
        extras=None
    ):
        """Run ``_forward``, under activation checkpointing if enabled."""
        if self.use_checkpoint:
            from torch.utils.checkpoint import checkpoint
            return checkpoint(
                self._forward,
                x,
                time_token,
                time_ada,
                skip,
                context,
                x_mask,
                context_mask,
                extras,
                use_reentrant=False
            )
        return self._forward(
            x, time_token, time_ada, skip, context, x_mask, context_mask,
            extras
        )

    def _merge_skip(self, x, skip):
        """Concatenate ``skip`` onto ``x`` (last dim) and project back."""
        if self.skip_linear is None:
            return x
        assert skip is not None
        cat = torch.cat([x, skip], dim=-1)
        cat = self.skip_norm(cat)
        return self.skip_linear(cat)

    def _modulation(self, time_token, time_ada):
        """Return the six AdaLN chunks, or six ``None`` without AdaLN."""
        if not self.use_adanorm:
            return (None, ) * 6
        time_ada = self.adaln(time_token, time_ada)
        return time_ada.chunk(6, dim=1)

    def _msa_gate(self, gate_msa):
        """Map the raw attention gate chunk to the residual multiplier."""
        return 1 - gate_msa

    def _self_attend(self, x, x_mask, extras, shift_msa, scale_msa, gate_msa):
        """Add the (optionally modulated and gated) self-attention branch."""
        if self.use_adanorm:
            x_norm = film_modulate(
                self.norm1(x), shift=shift_msa, scale=scale_msa
            )
            return x + self._msa_gate(gate_msa) * self.attn(
                x_norm, context=None, context_mask=x_mask, extras=extras
            )
        return x + self.attn(
            self.norm1(x), context=None, context_mask=x_mask, extras=extras
        )

    def _cross_attend(self, x, context, context_mask, extras):
        """Add the cross-attention branch when the block has one."""
        if not self.use_context:
            return x
        assert context is not None
        return x + self.cross_attn(
            x=self.norm2(x),
            context=self.norm_context(context),
            context_mask=context_mask,
            extras=extras
        )

    def _feed_forward(self, x, shift_mlp, scale_mlp, gate_mlp):
        """Add the (optionally modulated and gated) feed-forward branch."""
        if self.use_adanorm:
            x_norm = film_modulate(
                self.norm3(x), shift=shift_mlp, scale=scale_mlp
            )
            return x + (1 - gate_mlp) * self.mlp(x_norm)
        return x + self.mlp(self.norm3(x))

    def _forward(
        self,
        x,
        time_token=None,
        time_ada=None,
        skip=None,
        context=None,
        x_mask=None,
        context_mask=None,
        extras=None
    ):
        """Apply skip merge, self-attention, cross-attention and MLP in order.

        ``x_mask`` is handed to self-attention as its ``context_mask``;
        ``extras`` is forwarded unchanged to both attention layers.
        """
        x = self._merge_skip(x, skip)
        (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
         gate_mlp) = self._modulation(time_token, time_ada)
        x = self._self_attend(
            x, x_mask, extras, shift_msa, scale_msa, gate_msa
        )
        x = self._cross_attend(x, context, context_mask, extras)
        x = self._feed_forward(x, shift_mlp, scale_mlp, gate_mlp)
        return x


class FinalBlock(nn.Module):
    """Output head: norm (optionally modulated), linear projection to patch
    values, ``unpatchify`` and an optional 3-tap convolution."""

    def __init__(
        self,
        embed_dim,
        patch_size,
        in_chans,
        img_size,
        input_type='2d',
        norm_layer=nn.LayerNorm,
        use_conv=True,
        use_adanorm=True
    ):
        """Build the head.

        Args:
            embed_dim: Token channel width.
            patch_size: Patch edge length (squared for ``'2d'``).
            in_chans: Number of output channels.
            img_size: Passed unchanged to ``unpatchify``.
            input_type: ``'2d'`` or ``'1d'``.
            norm_layer: Callable that builds a norm layer from a width.
            use_conv: If true, finish with a kernel-3, padding-1 convolution;
                otherwise with an identity.
            use_adanorm: If true, ``forward`` expects ``time_ada`` and applies
                shift/scale modulation after the norm.

        Raises:
            NotImplementedError: If ``input_type`` is not ``'1d'``/``'2d'``
                (previously this silently produced a half-built module).
        """
        super().__init__()
        self.in_chans = in_chans
        self.img_size = img_size
        self.input_type = input_type

        self.norm = norm_layer(embed_dim)
        self.use_adanorm = use_adanorm

        if input_type == '2d':
            self.patch_dim = patch_size**2 * in_chans
            conv_cls = nn.Conv2d
        elif input_type == '1d':
            self.patch_dim = patch_size * in_chans
            conv_cls = nn.Conv1d
        else:
            raise NotImplementedError

        self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
        if use_conv:
            self.final_layer = conv_cls(
                self.in_chans, self.in_chans, 3, padding=1
            )
        else:
            self.final_layer = nn.Identity()

    def forward(self, x, time_ada=None, extras=0):
        """Drop the first ``extras`` tokens and decode the rest.

        Args:
            x: Token tensor; sliced as ``x[:, extras:, :]``.
            time_ada: Modulation tensor reshaped to ``(B, 2, -1)`` and split
                into shift and scale; required when ``use_adanorm`` is true.
            extras: Number of leading tokens to discard.
        """
        B = x.shape[0]
        x = x[:, extras:, :]
        if self.use_adanorm:
            shift, scale = time_ada.reshape(B, 2, -1).chunk(2, dim=1)
            x = film_modulate(self.norm(x), shift, scale)
        else:
            x = self.norm(x)
        x = self.linear(x)
        x = unpatchify(x, self.in_chans, self.input_type, self.img_size)
        x = self.final_layer(x)
        return x


class UDiT(nn.Module):
    """U-shaped stack of ``DiTBlock``s: ``depth // 2`` input blocks, one middle
    block and ``depth // 2`` output blocks, with optional long skip
    connections from input to output blocks."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        input_type='2d',
        out_chans=None,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.,
        qkv_bias=False,
        qk_scale=None,
        qk_norm=None,
        act_layer='gelu',
        norm_layer='layernorm',
        context_norm=False,
        use_checkpoint=False,
        time_fusion='token',
        ada_sola_rank=None,
        ada_sola_alpha=None,
        cls_dim=None,
        context_dim=768,
        context_fusion='concat',
        context_max_length=128,
        context_pe_method='sinu',
        pe_method='abs',
        rope_mode='none',
        use_conv=True,
        skip=True,
        skip_norm=True
    ):
        """Build the model; see ``_build`` for what each argument controls."""
        super().__init__()
        self._build(
            block_cls=DiTBlock,
            block_kwargs={},
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            input_type=input_type,
            out_chans=out_chans,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            qk_norm=qk_norm,
            act_layer=act_layer,
            norm_layer=norm_layer,
            context_norm=context_norm,
            use_checkpoint=use_checkpoint,
            time_fusion=time_fusion,
            ada_sola_rank=ada_sola_rank,
            ada_sola_alpha=ada_sola_alpha,
            cls_dim=cls_dim,
            context_dim=context_dim,
            context_fusion=context_fusion,
            context_max_length=context_max_length,
            context_pe_method=context_pe_method,
            pe_method=pe_method,
            rope_mode=rope_mode,
            use_conv=use_conv,
            skip=skip,
            skip_norm=skip_norm,
        )

    @staticmethod
    def _resolve_patch_layout(img_size, patch_size, input_type):
        """Return ``(num_patches, img_size)`` for the given input type.

        For ``'2d'`` an integer ``img_size`` is expanded to a square
        ``(img_size, img_size)`` tuple, because the patch count indexes
        ``img_size[0]`` and ``img_size[1]`` (the int default used to crash).
        """
        if input_type == '2d':
            if isinstance(img_size, int):
                img_size = (img_size, img_size)
            num_patches = (img_size[0] // patch_size) * (
                img_size[1] // patch_size
            )
        elif input_type == '1d':
            num_patches = img_size // patch_size
        else:
            raise NotImplementedError
        return num_patches, img_size

    def _build(
        self, *, block_cls, block_kwargs, img_size, patch_size, in_chans,
        input_type, out_chans, embed_dim, depth, num_heads, mlp_ratio,
        qkv_bias, qk_scale, qk_norm, act_layer, norm_layer, context_norm,
        use_checkpoint, time_fusion, ada_sola_rank, ada_sola_alpha, cls_dim,
        context_dim, context_fusion, context_max_length, context_pe_method,
        pe_method, rope_mode, use_conv, skip, skip_norm
    ):
        """Create every sub-module and initialise the weights.

        Shared by ``UDiT`` and its subclasses; ``nn.Module.__init__`` must
        already have run.

        Args:
            block_cls: Class used for the in/mid/out blocks.
            block_kwargs: Extra keyword arguments passed to every block.
            time_fusion: ``'token'`` prepends time (and class) tokens to the
                sequence; ``'ada'``, ``'ada_single'``, ``'ada_sola'`` and
                ``'ada_sola_bias'`` use AdaLN modulation instead.
            cls_dim: Width of the optional class-conditioning input.
            context_dim: Width of the optional context input; ``None``
                disables context.
            context_fusion: ``'concat'``/``'joint'`` prepend the context to
                the sequence; ``'cross'`` uses cross-attention in each block.
            norm_layer: ``'layernorm'`` or ``'rmsnorm'``.
            skip, skip_norm: Applied to the output blocks only.
            (remaining arguments are forwarded to the sub-modules.)

        Raises:
            NotImplementedError: For an unknown ``input_type``,
                ``time_fusion``, ``context_fusion`` or ``norm_layer``.
        """
        self.num_features = self.embed_dim = embed_dim
        self.in_chans = in_chans
        self.input_type = input_type
        num_patches, img_size = self._resolve_patch_layout(
            img_size, patch_size, input_type
        )

        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            input_type=input_type
        )
        out_chans = in_chans if out_chans is None else out_chans
        self.out_chans = out_chans

        self.rope = rope_mode
        self.x_pe = PE_wrapper(
            dim=embed_dim, method=pe_method, length=num_patches
        )

        # --- time / class conditioning -----------------------------------
        self.time_embed = TimestepEmbedder(embed_dim)
        self.time_fusion = time_fusion
        self.use_adanorm = False

        if cls_dim is not None:
            self.cls_embed = nn.Sequential(
                nn.Linear(cls_dim, embed_dim, bias=True),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=True),
            )
        else:
            self.cls_embed = None

        if time_fusion == 'token':
            # One prefix token for time, plus one for the class if present.
            self.extras = 2 if self.cls_embed is not None else 1
            self.time_pe = PE_wrapper(
                dim=embed_dim, method='abs', length=self.extras
            )
        elif time_fusion in ['ada', 'ada_single', 'ada_sola', 'ada_sola_bias']:
            self.use_adanorm = True
            self.time_act = nn.SiLU()
            self.extras = 0
            self.time_ada_final = nn.Linear(
                embed_dim, 2 * embed_dim, bias=True
            )
            if time_fusion in ['ada_single', 'ada_sola', 'ada_sola_bias']:
                # Shared modulation consumed by every block's AdaLN.
                self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
            else:
                self.time_ada = None
        else:
            raise NotImplementedError

        # --- context conditioning ----------------------------------------
        self.use_context = False
        self.context_cross = False
        self.context_max_length = context_max_length
        self.context_fusion = 'none'
        if context_dim is not None:
            self.use_context = True
            self.context_embed = nn.Sequential(
                nn.Linear(context_dim, embed_dim, bias=True),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=True),
            )
            self.context_fusion = context_fusion
            if context_fusion == 'concat' or context_fusion == 'joint':
                self.extras += context_max_length
                self.context_pe = PE_wrapper(
                    dim=embed_dim,
                    method=context_pe_method,
                    length=context_max_length
                )
                # Context lives in the sequence; blocks get no cross-attn.
                context_dim = None
            elif context_fusion == 'cross':
                self.context_pe = PE_wrapper(
                    dim=embed_dim,
                    method=context_pe_method,
                    length=context_max_length
                )
                self.context_cross = True
                context_dim = embed_dim
            else:
                raise NotImplementedError

        self.use_skip = skip

        if norm_layer == 'layernorm':
            norm_layer = nn.LayerNorm
        elif norm_layer == 'rmsnorm':
            norm_layer = RMSNorm
        else:
            raise NotImplementedError

        # --- transformer blocks (built in in -> mid -> out order) ---------
        shared = dict(
            dim=embed_dim,
            context_dim=context_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            qk_norm=qk_norm,
            act_layer=act_layer,
            norm_layer=norm_layer,
            time_fusion=time_fusion,
            ada_sola_rank=ada_sola_rank,
            ada_sola_alpha=ada_sola_alpha,
            rope_mode=self.rope,
            context_norm=context_norm,
            use_checkpoint=use_checkpoint,
            **block_kwargs
        )
        self.in_blocks = nn.ModuleList([
            block_cls(skip=False, skip_norm=False, **shared)
            for _ in range(depth // 2)
        ])
        self.mid_block = block_cls(skip=False, skip_norm=False, **shared)
        self.out_blocks = nn.ModuleList([
            block_cls(skip=skip, skip_norm=skip_norm, **shared)
            for _ in range(depth // 2)
        ])

        self.use_conv = use_conv
        self.final_block = FinalBlock(
            embed_dim=embed_dim,
            patch_size=patch_size,
            img_size=img_size,
            in_chans=out_chans,
            input_type=input_type,
            norm_layer=norm_layer,
            use_conv=use_conv,
            use_adanorm=self.use_adanorm
        )
        self.initialize_weights()

    def _all_blocks(self):
        """Yield every block in input -> middle -> output order."""
        yield from self.in_blocks
        yield self.mid_block
        yield from self.out_blocks

    def _init_ada(self):
        """Initialise the AdaLN-related layers for the active time fusion.

        The shared/final modulation projections are zeroed. In ``'ada'`` mode
        each block's own projection is zeroed; in the sola modes each block's
        ``lora_a`` gets Kaiming-uniform init and ``lora_b`` is zeroed.
        """
        if self.time_fusion == 'ada':
            _zero_init(self.time_ada_final)
            for block in self._all_blocks():
                _zero_init(block.adaln.time_ada)
        elif self.time_fusion == 'ada_single':
            _zero_init(self.time_ada)
            _zero_init(self.time_ada_final)
        elif self.time_fusion in ['ada_sola', 'ada_sola_bias']:
            _zero_init(self.time_ada)
            _zero_init(self.time_ada_final)
            for block in self._all_blocks():
                nn.init.kaiming_uniform_(
                    block.adaln.lora_a.weight, a=math.sqrt(5)
                )
                nn.init.constant_(block.adaln.lora_b.weight, 0)

    def initialize_weights(self):
        """Initialise all weights.

        Every ``nn.Linear`` gets Xavier-uniform weights and zero bias, then
        the patch projection, AdaLN layers, cross-attention output
        projections, the last class-embedding layer (AdaLN modes only) and the
        final convolution receive their specific initialisation.
        """
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Treat the patch projection like a linear layer over flattened
        # patches.
        w = self.patch_embed.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.patch_embed.proj.bias, 0)

        if self.use_adanorm:
            self._init_ada()

        if self.context_cross:
            # Zeroed output projection: cross-attention starts as a no-op.
            for block in self._all_blocks():
                _zero_init(block.cross_attn.proj)

        if self.cls_embed is not None and self.use_adanorm:
            _zero_init(self.cls_embed[-1])

        if self.use_conv:
            nn.init.xavier_uniform_(self.final_block.final_layer.weight)
            nn.init.constant_(self.final_block.final_layer.bias, 0)

    def _concat_x_context(self, x, context, x_mask=None, context_mask=None):
        """Prepend ``context`` to ``x`` along dim 1 and merge their masks.

        Missing masks are replaced by all-true boolean masks. ``context`` must
        have exactly ``context_max_length`` entries along dim ``-2``.

        Returns:
            ``(x, x_mask)`` with the context first in both.
        """
        assert context.shape[-2] == self.context_max_length
        B = x.shape[0]
        if x_mask is None:
            x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
        if context_mask is None:
            context_mask = torch.ones(
                B, context.shape[-2], device=context.device
            ).bool()
        x_mask = torch.cat([context_mask, x_mask], dim=1)
        x = torch.cat((context, x), dim=1)
        return x, x_mask

    def _forward_impl(
        self, x, timesteps, context, x_mask, context_mask, cls_token,
        controlnet_skips, **block_kwargs
    ):
        """Shared forward pass; ``block_kwargs`` go to every block call."""
        if timesteps.dim() == 0:
            # Broadcast a scalar timestep over the batch.
            timesteps = timesteps.expand(x.shape[0]
                                         ).to(x.device, dtype=torch.long)

        x = self.patch_embed(x)
        x = self.x_pe(x)
        B = x.shape[0]

        if self.use_context:
            context_token = self.context_embed(context)
            context_token = self.context_pe(context_token)
            if self.context_fusion == 'concat' or self.context_fusion == 'joint':
                x, x_mask = self._concat_x_context(
                    x=x,
                    context=context_token,
                    x_mask=x_mask,
                    context_mask=context_mask
                )
                # Context is now part of the sequence; no cross-attention.
                context_token, context_mask = None, None
        else:
            context_token, context_mask = None, None

        time_token = self.time_embed(timesteps)
        if self.cls_embed is not None:
            cls_token = self.cls_embed(cls_token)
        time_ada = None
        time_ada_final = None
        if self.use_adanorm:
            if self.cls_embed is not None:
                time_token = time_token + cls_token
            time_token = self.time_act(time_token)
            time_ada_final = self.time_ada_final(time_token)
            if self.time_ada is not None:
                time_ada = self.time_ada(time_token)
        else:
            # Token fusion: prepend time (and class) tokens to the sequence.
            time_token = time_token.unsqueeze(dim=1)
            if self.cls_embed is not None:
                cls_token = cls_token.unsqueeze(dim=1)
                time_token = torch.cat([time_token, cls_token], dim=1)
            time_token = self.time_pe(time_token)
            x = torch.cat((time_token, x), dim=1)
            if x_mask is not None:
                x_mask = torch.cat([
                    torch.ones(B, time_token.shape[1],
                               device=x_mask.device).bool(), x_mask
                ],
                                   dim=1)
            time_token = None

        # Work on a copy so the caller's list is not consumed by ``pop``.
        controlnet_skips = list(controlnet_skips) if controlnet_skips else None

        skips = []
        for blk in self.in_blocks:
            x = blk(
                x=x,
                time_token=time_token,
                time_ada=time_ada,
                skip=None,
                context=context_token,
                x_mask=x_mask,
                context_mask=context_mask,
                extras=self.extras,
                **block_kwargs
            )
            if self.use_skip:
                skips.append(x)

        x = self.mid_block(
            x=x,
            time_token=time_token,
            time_ada=time_ada,
            skip=None,
            context=context_token,
            x_mask=x_mask,
            context_mask=context_mask,
            extras=self.extras,
            **block_kwargs
        )

        for blk in self.out_blocks:
            if self.use_skip:
                # Last-in-first-out pairing of input and output blocks.
                skip = skips.pop()
                if controlnet_skips:
                    skip = skip + controlnet_skips.pop()
            else:
                skip = None
                if controlnet_skips:
                    x = x + controlnet_skips.pop()
            x = blk(
                x=x,
                time_token=time_token,
                time_ada=time_ada,
                skip=skip,
                context=context_token,
                x_mask=x_mask,
                context_mask=context_mask,
                extras=self.extras,
                **block_kwargs
            )

        x = self.final_block(x, time_ada=time_ada_final, extras=self.extras)
        return x

    def forward(
        self,
        x,
        timesteps,
        context,
        x_mask=None,
        context_mask=None,
        cls_token=None,
        controlnet_skips=None,
    ):
        """Run the model.

        Args:
            x: Input passed to ``patch_embed``.
            timesteps: Timestep tensor; a 0-dim tensor is expanded to the
                batch size.
            context: Context features; used only when the model was built
                with a ``context_dim``.
            x_mask: Optional mask over the patch tokens.
            context_mask: Optional mask over the context tokens.
            cls_token: Class-conditioning input; used only when the model was
                built with a ``cls_dim``.
            controlnet_skips: Optional sequence of residuals, consumed from
                the end, one per output block. The caller's sequence is not
                modified.

        Returns:
            The output of ``final_block``.
        """
        return self._forward_impl(
            x, timesteps, context, x_mask, context_mask, cls_token,
            controlnet_skips
        )


class LayerFusionDiTBlock(DiTBlock):
    """``DiTBlock`` that fuses a time-aligned context into the token stream
    between self-attention and cross-attention.

    ``ta_context_fusion='add'`` projects the context to ``dim`` and adds it;
    ``'concat'`` concatenates it on the channel axis and projects back to
    ``dim``. Any other value leaves the tokens unchanged.
    """

    def __init__(
        self,
        dim,
        ta_context_dim,
        ta_context_norm=False,
        context_dim=None,
        num_heads=8,
        mlp_ratio=4.,
        qkv_bias=False,
        qk_scale=None,
        qk_norm=None,
        act_layer='gelu',
        norm_layer=nn.LayerNorm,
        ta_context_fusion='add',
        time_fusion='none',
        ada_sola_rank=None,
        ada_sola_alpha=None,
        skip=False,
        skip_norm=False,
        rope_mode='none',
        context_norm=False,
        use_checkpoint=False
    ):
        """Build the base block plus the time-aligned-context fusion layers.

        Args:
            ta_context_dim: Channel width of the time-aligned context.
            ta_context_norm: If true, normalise before the fusion projection
                (the context for ``'add'``, the concatenation for
                ``'concat'``).
            ta_context_fusion: ``'add'`` or ``'concat'``.
            (all other arguments are as in ``DiTBlock``.)
        """
        super().__init__(
            dim=dim,
            context_dim=context_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            qk_norm=qk_norm,
            act_layer=act_layer,
            norm_layer=norm_layer,
            time_fusion=time_fusion,
            ada_sola_rank=ada_sola_rank,
            ada_sola_alpha=ada_sola_alpha,
            skip=skip,
            skip_norm=skip_norm,
            rope_mode=rope_mode,
            context_norm=context_norm,
            use_checkpoint=use_checkpoint
        )
        self.ta_context_fusion = ta_context_fusion
        # Holds the flag first; replaced by a module for 'add'/'concat'.
        self.ta_context_norm = ta_context_norm
        if self.ta_context_fusion == "add":
            self.ta_context_projection = nn.Linear(
                ta_context_dim, dim, bias=False
            )
            self.ta_context_norm = norm_layer(
                ta_context_dim
            ) if self.ta_context_norm else nn.Identity()
        elif self.ta_context_fusion == "concat":
            self.ta_context_projection = nn.Linear(ta_context_dim + dim, dim)
            self.ta_context_norm = norm_layer(
                ta_context_dim + dim
            ) if self.ta_context_norm else nn.Identity()

    def forward(
        self,
        x,
        time_aligned_context,
        time_token=None,
        time_ada=None,
        skip=None,
        context=None,
        x_mask=None,
        context_mask=None,
        extras=None
    ):
        """Run ``_forward``, under activation checkpointing if enabled."""
        if self.use_checkpoint:
            from torch.utils.checkpoint import checkpoint
            return checkpoint(
                self._forward,
                x,
                time_aligned_context,
                time_token,
                time_ada,
                skip,
                context,
                x_mask,
                context_mask,
                extras,
                use_reentrant=False
            )
        return self._forward(
            x,
            time_aligned_context,
            time_token,
            time_ada,
            skip,
            context,
            x_mask,
            context_mask,
            extras,
        )

    def _msa_gate(self, gate_msa):
        """Squash the attention gate with ``tanh``.

        NOTE(review): only the attention gate is squashed; the MLP gate keeps
        the plain ``1 - gate`` of the base block. Confirm this asymmetry is
        intended.
        """
        return torch.tanh(1 - gate_msa)

    @staticmethod
    def _left_pad_to(seq, length):
        """Zero-pad ``seq`` at the start of dim 1 up to ``length`` tokens.

        Pads by the full difference, so any number of prefix tokens in ``x``
        is handled (the previous code always padded by exactly one).
        Assumes ``seq`` is 3-D with tokens on dim 1 - TODO confirm.
        """
        missing = length - seq.size(1)
        if missing > 0:
            seq = nn.functional.pad(seq, (0, 0, missing, 0))
        return seq

    def _fuse_time_aligned_context(self, x, time_aligned_context):
        """Merge ``time_aligned_context`` into ``x`` per ``ta_context_fusion``."""
        if self.ta_context_fusion == "add":
            time_aligned_context = self.ta_context_projection(
                self.ta_context_norm(time_aligned_context)
            )
            time_aligned_context = self._left_pad_to(
                time_aligned_context, x.size(1)
            )
            return x + time_aligned_context
        if self.ta_context_fusion == "concat":
            time_aligned_context = self._left_pad_to(
                time_aligned_context, x.size(1)
            )
            cat = torch.cat([x, time_aligned_context], dim=-1)
            cat = self.ta_context_norm(cat)
            return self.ta_context_projection(cat)
        return x

    def _forward(
        self,
        x,
        time_aligned_context,
        time_token=None,
        time_ada=None,
        skip=None,
        context=None,
        x_mask=None,
        context_mask=None,
        extras=None
    ):
        """Base block pipeline with the time-aligned fusion inserted after
        self-attention and before cross-attention."""
        x = self._merge_skip(x, skip)
        (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
         gate_mlp) = self._modulation(time_token, time_ada)
        x = self._self_attend(
            x, x_mask, extras, shift_msa, scale_msa, gate_msa
        )
        x = self._fuse_time_aligned_context(x, time_aligned_context)
        x = self._cross_attend(x, context, context_mask, extras)
        x = self._feed_forward(x, shift_mlp, scale_mlp, gate_mlp)
        return x


class LayerFusionAudioDiT(UDiT):
    """``UDiT`` built from ``LayerFusionDiTBlock``s; every block additionally
    receives a ``time_aligned_context`` tensor."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        input_type='2d',
        out_chans=None,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=False,
        qk_scale=None,
        qk_norm=None,
        act_layer='gelu',
        norm_layer='layernorm',
        context_norm=False,
        use_checkpoint=False,
        time_fusion='token',
        ada_sola_rank=None,
        ada_sola_alpha=None,
        cls_dim=None,
        ta_context_dim=768,
        ta_context_fusion='concat',
        ta_context_norm=True,
        context_dim=768,
        context_fusion='concat',
        context_max_length=128,
        context_pe_method='sinu',
        pe_method='abs',
        rope_mode='none',
        use_conv=True,
        skip=True,
        skip_norm=True
    ):
        """Build the model.

        ``ta_context_dim``, ``ta_context_fusion`` and ``ta_context_norm`` are
        forwarded to every ``LayerFusionDiTBlock``; all other arguments are as
        in ``UDiT``.
        """
        # Skip UDiT.__init__ on purpose: it would build plain DiTBlocks.
        nn.Module.__init__(self)
        self._build(
            block_cls=LayerFusionDiTBlock,
            block_kwargs=dict(
                ta_context_dim=ta_context_dim,
                ta_context_fusion=ta_context_fusion,
                ta_context_norm=ta_context_norm,
            ),
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            input_type=input_type,
            out_chans=out_chans,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            qk_norm=qk_norm,
            act_layer=act_layer,
            norm_layer=norm_layer,
            context_norm=context_norm,
            use_checkpoint=use_checkpoint,
            time_fusion=time_fusion,
            ada_sola_rank=ada_sola_rank,
            ada_sola_alpha=ada_sola_alpha,
            cls_dim=cls_dim,
            context_dim=context_dim,
            context_fusion=context_fusion,
            context_max_length=context_max_length,
            context_pe_method=context_pe_method,
            pe_method=pe_method,
            rope_mode=rope_mode,
            use_conv=use_conv,
            skip=skip,
            skip_norm=skip_norm,
        )

    def forward(
        self,
        x,
        timesteps,
        time_aligned_context,
        context,
        x_mask=None,
        context_mask=None,
        cls_token=None,
        controlnet_skips=None,
    ):
        """Run the model.

        Identical to ``UDiT.forward`` except that ``time_aligned_context`` is
        passed to every block.
        """
        return self._forward_impl(
            x,
            timesteps,
            context,
            x_mask,
            context_mask,
            cls_token,
            controlnet_skips,
            time_aligned_context=time_aligned_context
        )