Fix max_seqlen type in vision attention for torch.compile + FA2#44973
andylizf wants to merge 1 commit into huggingface:main
Conversation
…+ FA2 compatibility

When using `torch.compile` with `attn_implementation="flash_attention_2"`, `max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()` produces a 0-d tensor. The flash_attn C++ op expects `int` for `max_seqlen_q`/`max_seqlen_k`, causing a TorchRuntimeError during Dynamo tracing with FakeTensors.

While `_process_flash_attention_kwargs` in `modeling_flash_attention_utils.py` already handles this conversion for the text-model path, adding `.item()` at the source is more defensive and consistent.

This affects all VL models sharing this vision attention pattern: Qwen2-VL, Qwen2.5-VL, Qwen3-VL, Qwen3.5, Qwen3.5-MoE, Qwen3-VL-MoE, Qwen2.5-Omni, Qwen3-Omni-MoE, GLM-4V, GLM-4V-MoE, GLM-Image, GLM-OCR, ERNIE-4.5-VL-MoE, PaddleOCR-VL, Video-LLaMA-3.

Fixes huggingface#44962
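The one-line change described above can be sketched as follows. This is an illustrative helper, not the literal diff; the computation is inlined in each model's vision attention module:

```python
import torch

def max_seqlen_from_cu_seqlens(cu_seqlens: torch.Tensor) -> int:
    """Illustrative helper: derive the maximum segment length from cumulative
    sequence lengths, returning a Python int as the flash_attn C++ op expects.

    `cu_seqlens` uses the standard varlen-attention layout: a 1-D tensor of
    cumulative lengths, e.g. [0, 3, 7, 9] for segments of length 3, 4, 2.
    """
    # Without .item() this is a 0-d tensor, which breaks Dynamo tracing
    # with FakeTensors when the kernel signature requires `int`.
    return (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

cu = torch.tensor([0, 3, 7, 9])
print(max_seqlen_from_cu_seqlens(cu))  # 4
```

The trade-off, discussed below in the review, is that `.item()` forces a device-to-host sync on every call, including eager (non-compile) paths.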
[For maintainers] Suggested jobs to run (before merge): run-slow: ernie4_5_vl_moe, glm4v, glm4v_moe, glm_image, glm_ocr, paddleocr_vl, qwen2_5_omni, qwen2_5_vl, qwen2_vl, qwen3_5, qwen3_5_moe, qwen3_omni_moe, qwen3_vl, qwen3_vl_moe, video_llama_3
zucchini-nlp
left a comment
Sounds reasonable to me, cc @vasqu for attention
vasqu
left a comment
Yep, I only recently changed the way we handle this
transformers/src/transformers/modeling_flash_attention_utils.py
Lines 648 to 666 in 2f62491
The essence is to be as device sync friendly as we can by avoiding .item() calls where we can.
I read it in the wrong order; I thought we had removed the `.item()` calls.
I am a bit confused, because iiuc this should only affect old v4 versions. I don't see a reason to add this to main, where we already handle this correctly downstream.
I refactored this in order to avoid these device syncs where possible. I'm not sure how this benefits you: this PR won't change the already-released v4 versions, and on main we gain nothing from adding it (more so a disadvantage, since it adds device syncs on non-compile paths).
I would ask you to use v5 in this case, tbh. Is there any reason to stick with an old v4 version, or anything you need that is missing in v5?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
What does this PR do?
Adds `.item()` to `max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()` in all vision attention modules that pass this value to `flash_attn_varlen_func`.

Context
On released versions (e.g. 4.52.4), using `torch.compile` + `attn_implementation="flash_attention_2"` crashes because `max_seqlen` is a 0-d tensor and the flash_attn C++ op expects `int`:

On main, this is already handled downstream by `_process_flash_attention_kwargs`, which converts via `.item()` when `is_tracing()` is True (as @JJJYmmm pointed out). So this change is defense-in-depth on main, but a necessary fix for released versions.

Adding `.item()` at the source is consistent with how `modeling_flash_attention_utils.py` documents the issue (lines 352-353):

Note: Qwen3.5's `Qwen3_5VisionAttention` (line 1004) has the same pattern without `.item()`; it works on main only because of the downstream fix, not because of a different implementation.

Affected models (19 files)
Qwen2-VL, Qwen2.5-VL, Qwen3-VL, Qwen3.5, Qwen3.5-MoE, Qwen3-VL-MoE, Qwen2.5-Omni, Qwen3-Omni-MoE, GLM-4V, GLM-4V-MoE, GLM-Image, GLM-OCR, ERNIE-4.5-VL-MoE, PaddleOCR-VL, Video-LLaMA-3
Reproduction (on transformers ≤ 4.52.4)
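The original reproduction snippet did not survive extraction. A minimal CPU-only sketch of the underlying type mismatch (no GPU or flash-attn install required; shapes are illustrative):

```python
import torch

# The vision attention path computes max_seqlen like this:
cu_seqlens = torch.tensor([0, 16, 48, 64])
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

# It is a 0-d tensor, not a Python int -- exactly what the flash_attn
# C++ op rejects when Dynamo traces it as a FakeTensor.
print(type(max_seqlen))             # <class 'torch.Tensor'>
print(max_seqlen.dim())             # 0
print(isinstance(max_seqlen, int))  # False

# The fix: materialize the scalar at the source.
print(type(max_seqlen.item()))      # <class 'int'>
```

The actual crash only reproduces with `torch.compile` and a flash-attn-2 install on GPU; this sketch just shows why the kernel's `int` signature is not satisfied.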
Fixes #44962