diff --git a/model/dit.py b/model/dit.py
index 6985c53..1968b47 100755
--- a/model/dit.py
+++ b/model/dit.py
@@ -127,7 +127,16 @@ def __init__(
         self.dim = dim
         self.depth = depth
 
-        llama_config = LlamaConfig(hidden_size=dim, intermediate_size=dim * ff_mult, hidden_act='silu', max_position_embeddings=self.max_frames)
+        # Must set num_attention_heads to match the DiT heads so the rotary embeddings get the correct head_dim:
+        # head_dim = hidden_size // num_attention_heads, so num_attention_heads = hidden_size // dim_head
+        num_attention_heads = dim // dim_head
+        llama_config = LlamaConfig(
+            hidden_size=dim,
+            intermediate_size=dim * ff_mult,
+            hidden_act='silu',
+            max_position_embeddings=self.max_frames,
+            num_attention_heads=num_attention_heads,  # Fix for transformers 4.49+
+        )
         llama_config._attn_implementation = 'sdpa'
         self.transformer_blocks = nn.ModuleList(
             [LlamaDecoderLayer(llama_config, layer_idx=i) for i in range(depth)]
@@ -211,7 +220,9 @@ def forward(
         )
 
         for i, block in enumerate(self.transformer_blocks):
-            x, *_ = block(x, attention_mask=attention_mask, position_embeddings=rotary_embed)
+            # Note: in transformers 5.0+, LlamaDecoderLayer returns a tensor, not a tuple;
+            # `x, *_ = block(...)` would then unpack along the tensor's first (batch) dimension.
+            x = block(x, attention_mask=attention_mask, position_embeddings=rotary_embed)
             if i < self.depth // 2:
                 x = x + self.text_fusion_linears[i](text_embed)
 
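
For reference, a minimal sketch (outside the patch) of the two issues it addresses. The dim/dim_head/ff_mult values below are hypothetical, not taken from the model:

    import torch
    from transformers import LlamaConfig

    dim, dim_head, ff_mult = 1024, 64, 2  # hypothetical DiT dimensions

    # With LlamaConfig's default num_attention_heads (32), the derived head_dim is wrong:
    default_cfg = LlamaConfig(hidden_size=dim, intermediate_size=dim * ff_mult, hidden_act='silu')
    assert default_cfg.hidden_size // default_cfg.num_attention_heads == 32  # != dim_head

    # Setting num_attention_heads = dim // dim_head restores the intended head_dim
    # used by the rotary embeddings:
    fixed_cfg = LlamaConfig(
        hidden_size=dim,
        intermediate_size=dim * ff_mult,
        hidden_act='silu',
        num_attention_heads=dim // dim_head,
    )
    assert fixed_cfg.hidden_size // fixed_cfg.num_attention_heads == dim_head  # 64

    # The second hunk guards against the return-type change: if a decoder layer returns
    # a bare tensor of shape (batch, seq, dim), `x, *_ = out` iterates over the batch
    # dimension instead of unpacking a tuple.
    out = torch.randn(2, 10, dim)
    x, *_ = out
    assert x.shape == (10, dim)     # batch dimension silently dropped -- the bug the patch avoids
    x = out
    assert x.shape == (2, 10, dim)  # correct handling when a tensor is returned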