
Fix GPT2 attention scaling ignored in SDPA/FlashAttention #44397

Open
OiPunk wants to merge 3 commits into huggingface:main from OiPunk:codex/transformers-44380-gpt2-sdpa-scaling

Conversation


@OiPunk (Contributor) commented Mar 2, 2026

What does this PR do?

Fixes #44380

`GPT2Attention.forward()` did not pass the `scaling` parameter to `attention_interface`, causing the `scale_attn_weights` and `scale_attn_by_inverse_layer_idx` config options to be silently ignored when using the SDPA or FlashAttention backends.

The eager attention implementation (`eager_attention_forward`) reads these flags directly from the module and applies scaling correctly. However, SDPA and FlashAttention rely on the `scaling` parameter passed to the attention function call, which GPT2 never provided.

Changes

  1. `modeling_gpt2.py`: Compute `self.scaling` in `GPT2Attention.__init__` by combining `scale_attn_weights` (1/√d_k) and `scale_attn_by_inverse_layer_idx` (1/(layer_idx+1)), following the same pattern used by LLaMA and other models.
  2. `modeling_gpt2.py`: Pass `scaling=self.scaling` to `attention_interface()` in `forward()`.
  3. `test_modeling_gpt2.py`: Add `test_gpt2_sdpa_matches_eager_with_scaling_configs`, which verifies that SDPA and eager produce equivalent outputs when using non-default scaling configs.
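The scaling computation in change 1 can be sketched as follows. This is a minimal standalone illustration, not the exact diff; `GPT2AttentionSketch` is a hypothetical name:

```python
class GPT2AttentionSketch:
    """Minimal sketch of the combined scaling factor described above."""

    def __init__(self, head_dim, layer_idx,
                 scale_attn_weights=True, scale_attn_by_inverse_layer_idx=False):
        self.scaling = 1.0
        if scale_attn_weights:
            self.scaling = head_dim ** -0.5       # 1/sqrt(d_k)
        if scale_attn_by_inverse_layer_idx:
            self.scaling /= float(layer_idx + 1)  # 1/(layer_idx + 1)

# forward() then hands this factor to the selected backend, along the lines of:
#   attention_interface(self, query, key, value, ..., scaling=self.scaling)
```

With `head_dim=64` and `layer_idx=3`, enabling both flags gives `0.125 / 4 = 0.03125`, which every backend now receives explicitly instead of each applying (or skipping) its own default.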

Behavior table (before → after)

| Config | Eager | SDPA before fix | SDPA after fix |
| --- | --- | --- | --- |
| `scale_attn_weights=True` (default) | ÷√d_k | ÷√d_k (PyTorch default, coincidental match) | ÷√d_k ✓ |
| `scale_attn_weights=False` | No scaling | ÷√d_k (wrong) | No scaling ✓ |
| `scale_attn_by_inverse_layer_idx=True` | ÷(layer+1) | Ignored | ÷(layer+1) ✓ |
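The `scale_attn_weights=False` row can be reproduced numerically: a backend that silently falls back to its default 1/√d_k scale diverges from eager, while forwarding an explicit scale of 1.0 restores agreement. A numpy sketch with illustrative shapes (not the PR's actual test):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8))  # (seq_len, head_dim=8)
k = rng.standard_normal((4, 8))

# eager with scale_attn_weights=False: raw, unscaled scores
eager = softmax(q @ k.T)
# backend silently applying its default 1/sqrt(d_k) (the pre-fix behavior)
backend_default = softmax((q @ k.T) * 8 ** -0.5)
# backend receiving the explicit scaling=1.0 (the post-fix behavior)
backend_fixed = softmax((q @ k.T) * 1.0)

assert not np.allclose(eager, backend_default)  # the silent divergence
assert np.allclose(eager, backend_fixed)
```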

OiPunk and others added 2 commits March 3, 2026 00:14
…ends

GPT2Attention.forward() did not pass the `scaling` parameter to
`attention_interface`, causing `scale_attn_weights` and
`scale_attn_by_inverse_layer_idx` config options to be silently
ignored when using SDPA or FlashAttention backends.

Compute the combined scaling factor in __init__ (following the pattern
used by LLaMA and other models) and forward it to the attention
interface so all backends produce consistent results.

Fixes huggingface#44380
…2Attention)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

@vasqu (Contributor) left a comment


Some smaller comments but we got lucky on the default case tbh. Very good finding

Comment on lines +105 to +106
if self.scale_attn_weights:
self.scaling = self.head_dim**-0.5
Contributor

I think we got lucky here because SDPA and FA will default to exactly this

Comment on lines +107 to +108
if self.scale_attn_by_inverse_layer_idx:
self.scaling /= float(self.layer_idx + 1)
Contributor

This is what was silently ignored then

Contributor

Imo, we can just use self.scaling in the eager forward as well then and copy from bert or similar (meaning the eager forward). No need for extra treatment then

Contributor Author

Thanks for the review! Addressed in the latest commit (04f9ba9):

  • self.scaling is computed once in __init__ and used in both the _upcast_and_reordered_attn path (via baddbmm alpha) and the standard attention_interface path (via scaling=self.scaling).
  • The old per-forward scale factor computation in _upcast_and_reordered_attn is removed.

So no extra treatment — just self.scaling everywhere, same pattern as bert.
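The equivalence this reply relies on, folding `self.scaling` into the matmul (what `torch.baddbmm`'s `alpha` argument does) versus applying a per-forward scale factor separately, can be checked with a small numpy sketch. Shapes and config values here are illustrative, not from the PR:

```python
import numpy as np

rng = np.random.default_rng(0)
q = rng.standard_normal((2, 4, 8))  # (batch, seq_len, head_dim)
k = rng.standard_normal((2, 4, 8))
scaling = (8 ** -0.5) / float(3 + 1)  # both config flags on, layer_idx=3

# scale folded into the matmul, as baddbmm's alpha does
scores_alpha = scaling * np.einsum("bqd,bkd->bqk", q, k)
# the old per-forward approach: scale the queries first, then matmul
scores_pre = np.einsum("bqd,bkd->bqk", q * scaling, k)

assert np.allclose(scores_alpha, scores_pre)
```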

Contributor

)
result.loss.backward()

def test_gpt2_sdpa_matches_eager_with_scaling_configs(self):
Contributor

Could we also check FA

Contributor Author

Added! The FA test is at `test_gpt2_fa2_matches_eager_with_scaling_configs` — both tests use `model.set_attn_implementation()` to switch backends without reloading.

config.scale_attn_by_inverse_layer_idx = True

# Eager attention (known-correct reference)
config._attn_implementation = "eager"
Contributor

I'd rather use this after init, i.e. model.set_attn_implementation("eager")

output_eager = model_eager(input_ids, token_type_ids=token_type_ids).logits

# SDPA attention (was buggy: ignored scaling configs)
config._attn_implementation = "sdpa"
Contributor

Same here

Comment on lines +264 to +265
model_sdpa = GPT2LMHeadModel(config).to(torch_device).eval()
model_sdpa.load_state_dict(model_eager.state_dict())
Contributor

We don't need this reloading stuff; we just switch up the flags with the set_attn...

…ve tests

Per reviewer feedback:
- Replace inline scale_factor computation with self.scaling in
  _upcast_and_reordered_attn for both GPT2 and DecisionTransformer
- Use model.set_attn_implementation() instead of model reloading in tests
- Add FlashAttention2 vs eager comparison test
@OiPunk (Contributor, Author) commented Mar 3, 2026

@vasqu Thanks for the review! Addressed all feedback:

  1. `self.scaling` in `_upcast_and_reordered_attn`: Replaced the inline `scale_factor` computation with `self.scaling` in both GPT2 and DecisionTransformer. All attention paths now consistently use `self.scaling`, computed once in `__init__`.

  2. `set_attn_implementation` in tests: Switched from model reloading with `config._attn_implementation` to `model.set_attn_implementation("eager")`/`"sdpa"` — single model, flag swap.

  3. FlashAttention2 test: Added `test_gpt2_fa2_matches_eager_with_scaling_configs`, gated behind `@require_torch_accelerator` and `@require_flash_attn`.

github-actions bot commented Mar 3, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: decision_transformer, gpt2


@vasqu (Contributor) left a comment


Thanks for iterating! Left another round of smaller comments to simplify a bit and follow some conventions.

with torch.no_grad():
output_eager = model(input_ids, token_type_ids=token_type_ids).logits

# FlashAttention2
Contributor

Suggested change:
- # FlashAttention2
+ # Flash Attention 2 (was buggy: ignored scaling configs)

with torch.no_grad():
output_fa2 = model(input_ids, token_type_ids=token_type_ids).logits

torch.testing.assert_close(output_eager, output_fa2, atol=5e-3, rtol=5e-3)
Contributor

Suggested change:
- torch.testing.assert_close(output_eager, output_fa2, atol=5e-3, rtol=5e-3)
+ torch.testing.assert_close(output_eager, output_fa2, atol=1e-2, rtol=1e-2)

Any reason for that atol/rtol? I remember FA being flaky at times, so raising it would be my personal preference.

with torch.no_grad():
output_sdpa = model(input_ids, token_type_ids=token_type_ids).logits

torch.testing.assert_close(output_eager, output_sdpa, atol=1e-5, rtol=1e-4)
Contributor

Suggested change:
- torch.testing.assert_close(output_eager, output_sdpa, atol=1e-5, rtol=1e-4)
+ torch.testing.assert_close(output_eager, output_sdpa, atol=1e-4, rtol=1e-4)

same here, just a small raise to avoid flakiness

Contributor


torch.testing.assert_close(output_eager, output_sdpa, atol=1e-5, rtol=1e-4)

@require_torch_accelerator
Contributor

Suggested change:
- @require_torch_accelerator
+ @require_torch_gpu
+ @mark.flash_attn_test



Development

Successfully merging this pull request may close these issues.

GPT2 attention scaling config is ignored when using SDPA / FlashAttention backends

2 participants