From 6d11cebd6b69ff2e44a7cac1f48e2e081e6de259 Mon Sep 17 00:00:00 2001
From: xander1421
Date: Sun, 1 Feb 2026 04:37:14 +0200
Subject: [PATCH] fix: Compatibility with transformers 4.49+ and 5.0+

Two fixes for newer transformers versions:

1. Add explicit num_attention_heads to LlamaConfig
   - transformers 4.49+ requires this for correct head_dim calculation
   - head_dim = hidden_size // num_attention_heads
   - Without this, rotary embeddings have wrong dimensions

2. Fix LlamaDecoderLayer output handling for transformers 5.0+
   - In transformers 5.0, LlamaDecoderLayer returns a tensor directly
   - Previously it returned a tuple (hidden_states, ...)
   - Using `x, *_ = block(...)` on a tensor iterates over the first dimension,
     effectively doing x = tensor[0] and losing the batch dimension
   - Changed to `x = block(...)`, which works for both old and new versions

Tested with:
- Python 3.14
- PyTorch 2.7
- transformers 5.0.0

Co-Authored-By: Claude Opus 4.5
---
 model/dit.py | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/model/dit.py b/model/dit.py
index 6985c53..1968b47 100755
--- a/model/dit.py
+++ b/model/dit.py
@@ -127,7 +127,16 @@ def __init__(
         self.dim = dim
         self.depth = depth
-        llama_config = LlamaConfig(hidden_size=dim, intermediate_size=dim * ff_mult, hidden_act='silu', max_position_embeddings=self.max_frames)
+        # Must set num_attention_heads to match DiT heads for correct head_dim in rotary embeddings
+        # head_dim = hidden_size // num_attention_heads, so num_attention_heads = hidden_size // dim_head
+        num_attention_heads = dim // dim_head
+        llama_config = LlamaConfig(
+            hidden_size=dim,
+            intermediate_size=dim * ff_mult,
+            hidden_act='silu',
+            max_position_embeddings=self.max_frames,
+            num_attention_heads=num_attention_heads,  # Fix for transformers 4.49+
+        )
         llama_config._attn_implementation = 'sdpa'
         self.transformer_blocks = nn.ModuleList(
             [LlamaDecoderLayer(llama_config, layer_idx=i) for i in range(depth)]
         )
@@ -211,7 +220,9 @@ def forward(
         )

         for i, block in enumerate(self.transformer_blocks):
-            x, *_ = block(x, attention_mask=attention_mask, position_embeddings=rotary_embed)
+            # Note: In transformers 5.0+, LlamaDecoderLayer returns tensor, not tuple
+            # Using `x, *_ = block(...)` would unpack the tensor's first dimension
+            x = block(x, attention_mask=attention_mask, position_embeddings=rotary_embed)

             if i < self.depth // 2:
                 x = x + self.text_fusion_linears[i](text_embed)
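
Note (illustrative, not part of the patch): a minimal Python sketch of the unpacking
pitfall described in the commit message. The tensor shape below is assumed purely for
the example.

    import torch

    hidden_states = torch.randn(2, 100, 512)  # assumed (batch, seq_len, hidden_size)

    x, *_ = hidden_states  # star-unpacking iterates over dim 0, so x == hidden_states[0]
    print(x.shape)         # torch.Size([100, 512])  -- batch dimension lost

    x = hidden_states      # plain assignment keeps the full tensor
    print(x.shape)         # torch.Size([2, 100, 512])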