Fix GPT2 attention scaling ignored in SDPA/FlashAttention #44397
OiPunk wants to merge 3 commits into huggingface:main from
Conversation
…ends GPT2Attention.forward() did not pass the `scaling` parameter to `attention_interface`, causing `scale_attn_weights` and `scale_attn_by_inverse_layer_idx` config options to be silently ignored when using SDPA or FlashAttention backends. Compute the combined scaling factor in __init__ (following the pattern used by LLaMA and other models) and forward it to the attention interface so all backends produce consistent results. Fixes huggingface#44380
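The combined scaling factor described above can be sketched as follows. This is an illustrative snippet, not the actual `modeling_gpt2.py` code; the `head_dim` and `layer_idx` values are made up:

```python
# Sketch of the combined scaling described in the PR (illustrative values,
# not the actual GPT2Attention.__init__ implementation).
head_dim = 64       # assumed per-head dimension
layer_idx = 3       # assumed layer index
scale_attn_weights = True
scale_attn_by_inverse_layer_idx = True

scaling = 1.0
if scale_attn_weights:
    scaling = head_dim ** -0.5          # 1/sqrt(d_k)
if scale_attn_by_inverse_layer_idx:
    scaling /= float(layer_idx + 1)     # extra 1/(layer_idx + 1)

print(scaling)  # prints 0.03125 for these values
```

Passing this single precomputed factor to the attention call is what makes all backends agree.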
…2Attention) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
vasqu
left a comment
Some smaller comments but we got lucky on the default case tbh. Very good finding
```python
if self.scale_attn_weights:
    self.scaling = self.head_dim**-0.5
```
I think we got lucky here because SDPA and FA will default to exactly this
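The "lucky default" can be checked numerically: PyTorch's `scaled_dot_product_attention` documents its default softmax scale as `1/sqrt(E)` for head dimension `E`, which coincides with what `scale_attn_weights` applies. A quick sanity check (plain Python, not HF code):

```python
import math

# SDPA's documented default softmax scale is 1 / sqrt(head_dim), which is
# exactly what GPT2's scale_attn_weights applies, so the missing `scaling`
# argument went unnoticed in the default configuration.
head_dim = 64  # assumed head dimension for illustration
sdpa_default_scale = 1.0 / math.sqrt(head_dim)
gpt2_default_scale = head_dim ** -0.5
assert math.isclose(sdpa_default_scale, gpt2_default_scale)
```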
```python
if self.scale_attn_by_inverse_layer_idx:
    self.scaling /= float(self.layer_idx + 1)
```
This is what was silently ignored then
Imo, we can just use self.scaling in the eager forward as well then and copy from bert or similar (meaning the eager forward). No need for extra treatment then
Thanks for the review! Addressed in the latest commit (04f9ba9):
- `self.scaling` is computed once in `__init__` and used in both the `_upcast_and_reordered_attn` path (via `baddbmm` alpha) and the standard `attention_interface` path (via `scaling=self.scaling`).
- The old per-forward scale factor computation in `_upcast_and_reordered_attn` is removed.

So no extra treatment: just `self.scaling` everywhere, same pattern as bert.
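The equivalence this reply relies on, applying the combined factor as a multiplier inside the matmul (baddbmm-style `alpha`) versus scaling the raw scores afterwards, can be checked with a toy example (made-up shapes, not the HF implementation):

```python
import numpy as np

# Toy check (made-up shapes, not the HF implementation): folding the combined
# scale into the matmul (as baddbmm's alpha does) equals scaling the raw
# QK^T scores afterwards.
rng = np.random.default_rng(0)
q = rng.standard_normal((2, 4, 8))   # (batch, seq_len, head_dim)
k = rng.standard_normal((2, 4, 8))
scaling = 8 ** -0.5 / float(3 + 1)   # head scale combined with 1/(layer_idx + 1)

scores_alpha = scaling * (q @ k.transpose(0, 2, 1))  # alpha-style: scale in the matmul
scores_post = (q @ k.transpose(0, 2, 1)) * scaling   # post-hoc scaling of the scores
assert np.allclose(scores_alpha, scores_post)
```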
Let's also address https://github.com/OiPunk/transformers/blob/04f9ba9ff0e3de8a8b21c801ba79509328ff14da/src/transformers/models/gpt2/modeling_gpt2.py#L54-L79?target=https://github.com then
We can be closer to Bert since we unify into self.scaling instead https://github.com/OiPunk/transformers/blob/04f9ba9ff0e3de8a8b21c801ba79509328ff14da/src/transformers/models/bert/modeling_bert.py#L115-L140?target=https://github.com
```python
)
result.loss.backward()

def test_gpt2_sdpa_matches_eager_with_scaling_configs(self):
```
Added! The FA test is at `test_gpt2_fa2_matches_eager_with_scaling_configs`; both tests use `model.set_attn_implementation()` to switch backends without reloading.
```python
config.scale_attn_by_inverse_layer_idx = True

# Eager attention (known-correct reference)
config._attn_implementation = "eager"
```
I'd rather use this after init, i.e. model.set_attn_implementation("eager")
```python
output_eager = model_eager(input_ids, token_type_ids=token_type_ids).logits

# SDPA attention (was buggy: ignored scaling configs)
config._attn_implementation = "sdpa"
model_sdpa = GPT2LMHeadModel(config).to(torch_device).eval()
model_sdpa.load_state_dict(model_eager.state_dict())
```
We don't need this reloading stuff; we just switch up the flags with the set_attn...
…ve tests
Per reviewer feedback:
- Replace inline scale_factor computation with self.scaling in _upcast_and_reordered_attn for both GPT2 and DecisionTransformer
- Use model.set_attn_implementation() instead of model reloading in tests
- Add FlashAttention2 vs eager comparison test
@vasqu Thanks for the review! Addressed all feedback.
[For maintainers] Suggested jobs to run (before merge): run-slow: decision_transformer, gpt2
vasqu
left a comment
Thanks for iterating, left another round of smaller comments to simplify a bit / follow some conventions
```python
with torch.no_grad():
    output_eager = model(input_ids, token_type_ids=token_type_ids).logits

# FlashAttention2
```
Suggested change:
```diff
- # FlashAttention2
+ # Flash Attention 2 (was buggy: ignored scaling configs)
```
```python
with torch.no_grad():
    output_fa2 = model(input_ids, token_type_ids=token_type_ids).logits

torch.testing.assert_close(output_eager, output_fa2, atol=5e-3, rtol=5e-3)
```
Suggested change:
```diff
- torch.testing.assert_close(output_eager, output_fa2, atol=5e-3, rtol=5e-3)
+ torch.testing.assert_close(output_eager, output_fa2, atol=1e-2, rtol=1e-2)
```
Any reason for that atol/rtol? I remember FA being flaky at times so raising it would be personally preferred
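For context, `torch.testing.assert_close` passes elementwise when `|actual - expected| <= atol + rtol * |expected|`, so the suggested bump roughly doubles the noise budget around values of magnitude 1. A plain-Python check of that criterion (not a torch call):

```python
# assert_close criterion: |actual - expected| <= atol + rtol * |expected|.
def within_tolerance(actual, expected, atol, rtol):
    return abs(actual - expected) <= atol + rtol * abs(expected)

# A 1.5e-2 deviation around a value of magnitude 1:
assert not within_tolerance(1.0 + 1.5e-2, 1.0, atol=5e-3, rtol=5e-3)  # old: budget 1e-2
assert within_tolerance(1.0 + 1.5e-2, 1.0, atol=1e-2, rtol=1e-2)      # raised: budget 2e-2
```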
```python
with torch.no_grad():
    output_sdpa = model(input_ids, token_type_ids=token_type_ids).logits

torch.testing.assert_close(output_eager, output_sdpa, atol=1e-5, rtol=1e-4)
```
Suggested change:
```diff
- torch.testing.assert_close(output_eager, output_sdpa, atol=1e-5, rtol=1e-4)
+ torch.testing.assert_close(output_eager, output_sdpa, atol=1e-4, rtol=1e-4)
```
same here, just a small raise to avoid flakiness
```python
torch.testing.assert_close(output_eager, output_sdpa, atol=1e-5, rtol=1e-4)

@require_torch_accelerator
```
Suggested change:
```diff
- @require_torch_accelerator
+ @require_torch_gpu
+ @mark.flash_attn_test
```
What does this PR do?
Fixes #44380

`GPT2Attention.forward()` did not pass the `scaling` parameter to `attention_interface`, causing the `scale_attn_weights` and `scale_attn_by_inverse_layer_idx` config options to be silently ignored when using SDPA or FlashAttention backends.

The eager attention implementation (`eager_attention_forward`) reads these flags directly from the module and applies scaling correctly. However, SDPA and FlashAttention rely on the `scaling` parameter passed to the attention function call, which GPT2 never provided.

Changes
- `modeling_gpt2.py`: Compute `self.scaling` in `GPT2Attention.__init__` by combining `scale_attn_weights` (1/√d_k) and `scale_attn_by_inverse_layer_idx` (1/(layer_idx+1)), following the same pattern used by LLaMA and other models.
- `modeling_gpt2.py`: Pass `scaling=self.scaling` to `attention_interface()` in `forward()`.
- `test_modeling_gpt2.py`: Add `test_gpt2_sdpa_matches_eager_with_scaling_configs`, which verifies that SDPA and eager produce equivalent outputs when using non-default scaling configs.

Behavior table (before → after)
Rows covered: `scale_attn_weights=True` (default), `scale_attn_weights=False`, `scale_attn_by_inverse_layer_idx=True`
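The effect described in this PR can be illustrated with a toy single-head attention (assumed shapes and helper names, not the actual `modeling_gpt2.py` code): a backend that applies only the default 1/√d_k scale diverges from one that receives the combined scaling factor.

```python
import numpy as np

# Toy single-head attention (illustrative, not modeling_gpt2.py) showing the
# reported bug: a backend applying only the default 1/sqrt(d_k) scale
# diverges from the eager path once scale_attn_by_inverse_layer_idx is on.
def attn(q, k, v, scaling):
    scores = (q @ k.T) * scaling
    scores -= scores.max(axis=-1, keepdims=True)   # numerically stable softmax
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v

rng = np.random.default_rng(1)
q, k, v = (rng.standard_normal((4, 8)) for _ in range(3))
head_dim, layer_idx = 8, 2
combined = head_dim ** -0.5 / (layer_idx + 1)  # scale_attn_weights with inverse layer idx
out_eager = attn(q, k, v, combined)            # eager path: combined scale applied
out_backend = attn(q, k, v, head_dim ** -0.5)  # buggy backend: layer-idx factor dropped
assert not np.allclose(out_eager, out_backend)
```

With the fix, the backend receives `combined` as well and the two outputs match.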