Fix max_seqlen type in vision attention for torch.compile + FA2 #44973

Open
andylizf wants to merge 1 commit into huggingface:main from andylizf:fix-qwen-vl-vision-attn-compile

Conversation


@andylizf andylizf commented Mar 24, 2026

What does this PR do?

Adds `.item()` to `max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()` in all vision attention modules that pass this value to `flash_attn_varlen_func`.
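For readers unfamiliar with the varlen API: `cu_seqlens` is a prefix sum of per-sequence lengths, so the maximum sequence length is the largest difference between consecutive entries. A stdlib-only sketch of that computation (the function name is illustrative, not the actual module code; with torch, `.max()` returns a 0-d tensor, which is why the PR adds `.item()` at this spot):

```python
# cu_seqlens holds cumulative boundaries: [0, len0, len0 + len1, ...].
# flash_attn_varlen_func wants the *maximum* per-sequence length as a plain int.
def max_seqlen_from_cu(cu_seqlens):
    # Consecutive differences recover the individual sequence lengths.
    lengths = [b - a for a, b in zip(cu_seqlens[:-1], cu_seqlens[1:])]
    # With Python ints this is already an int; in the modeling code, torch's
    # .max() yields a 0-d tensor instead, and .item() is the extra step.
    return max(lengths)

print(max_seqlen_from_cu([0, 3, 8, 9]))  # → 5
```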

Context

On released versions (e.g. 4.52.4), using `torch.compile` + `attn_implementation="flash_attention_2"` crashes because `max_seqlen` is a 0-d tensor and the flash_attn C++ op expects `int`:

TorchRuntimeError: flash_attn::_flash_attn_varlen_forward()
Expected a value of type 'int' for argument 'max_seqlen_q'
but instead found type 'FakeTensor'.

On main, this is already handled downstream by `_process_flash_attention_kwargs`, which converts via `.item()` when `is_tracing()` is `True` (as @JJJYmmm pointed out). So this change is defense-in-depth on main, but a necessary fix for released versions.

Adding `.item()` at the source is consistent with how `modeling_flash_attention_utils.py` documents the issue (lines 352-353):

# This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
# requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
max_length_q = max_length_q.item()

Note: Qwen3.5's `Qwen3_5VisionAttention` (line 1004) has the same pattern without `.item()` — it works on main only because of the downstream fix, not because of a different implementation.

Affected models (19 files)

Qwen2-VL, Qwen2.5-VL, Qwen3-VL, Qwen3.5, Qwen3.5-MoE, Qwen3-VL-MoE, Qwen2.5-Omni, Qwen3-Omni-MoE, GLM-4V, GLM-4V-MoE, GLM-Image, GLM-OCR, ERNIE-4.5-VL-MoE, PaddleOCR-VL, Video-LLaMA-3

Reproduction (on transformers ≤ 4.52.4)

import torch
torch.set_float32_matmul_precision("high")
from PIL import Image
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration

model = Qwen3VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen3-VL-Embedding-2B",
    dtype=torch.bfloat16, attn_implementation="flash_attention_2",
).cuda().eval()
model = torch.compile(model, mode="max-autotune-no-cudagraphs")

processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-Embedding-2B")
img = Image.new("RGB", (875, 1024), color=(128, 128, 128))
messages = [
    {"role": "system", "content": [{"type": "text", "text": "Describe."}]},
    {"role": "user", "content": [{"type": "image", "image": img}]},
]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=[img], return_tensors="pt", padding=True)
inputs = {k: v.to("cuda") if hasattr(v, "to") else v for k, v in inputs.items()}

with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)  # crashes on ≤4.52, works on main

Fixes #44962

…+ FA2 compatibility

When using `torch.compile` with `attn_implementation="flash_attention_2"`,
`max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()` produces a 0-d
tensor. The flash_attn C++ op expects `int` for `max_seqlen_q`/`max_seqlen_k`,
causing a TorchRuntimeError during Dynamo tracing with FakeTensors.

While `_process_flash_attention_kwargs` in `modeling_flash_attention_utils.py`
already handles this conversion for the text model path, adding `.item()`
at the source is more defensive and consistent.

This affects all VL models sharing this vision attention pattern:
Qwen2-VL, Qwen2.5-VL, Qwen3-VL, Qwen3.5, Qwen3.5-MoE, Qwen3-VL-MoE,
Qwen2.5-Omni, Qwen3-Omni-MoE, GLM-4V, GLM-4V-MoE, GLM-Image, GLM-OCR,
ERNIE-4.5-VL-MoE, PaddleOCR-VL, Video-LLaMA-3.

Fixes huggingface#44962
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: ernie4_5_vl_moe, glm4v, glm4v_moe, glm_image, glm_ocr, paddleocr_vl, qwen2_5_omni, qwen2_5_vl, qwen2_vl, qwen3_5, qwen3_5_moe, qwen3_omni_moe, qwen3_vl, qwen3_vl_moe, video_llama_3

Member

@zucchini-nlp zucchini-nlp left a comment


Sounds reasonable to me, cc @vasqu for attention

vasqu previously approved these changes Mar 25, 2026
Contributor

@vasqu vasqu left a comment


Yep, I only recently changed the way we handle this

# There is a limitation of the flash attention API, as the function `flash_attn_varlen_func`
# may require `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
#
# You can either set
# - Env: `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1`
# - Before compiling: `torch._dynamo.config.capture_scalar_outputs = True`
# to allow torch compile to handle scalar outputs in those cases.
same_max_seqlen = max_seqlen_q is max_seqlen_k  # to avoid 2x device syncs
if supports_mapping["max_seqlen_q"] and max_seqlen_q is not None:
    if not isinstance(max_seqlen_q, int) and is_tracing(max_seqlen_q):
        max_seqlen_q = max_seqlen_q.item()
    flash_kwargs["max_seqlen_q"] = max_seqlen_q
if supports_mapping["max_seqlen_k"] and max_seqlen_k is not None:
    if same_max_seqlen and flash_kwargs["max_seqlen_q"] is not None:
        max_seqlen_k = flash_kwargs["max_seqlen_q"]
    elif not isinstance(max_seqlen_k, int) and is_tracing(max_seqlen_k):
        max_seqlen_k = max_seqlen_k.item()
    flash_kwargs["max_seqlen_k"] = max_seqlen_k

The essence is to stay as device-sync friendly as possible by avoiding `.item()` calls wherever we can.

@vasqu vasqu dismissed their stale review March 25, 2026 14:04

I read it in the wrong order; I thought we were removing `.item()` calls

Contributor

@vasqu vasqu left a comment


I am a bit confused, because if I understand correctly this should only affect old v4 versions. I don't see a reason to add this to main, where we already handle it correctly downstream.

I refactored this precisely to avoid these device syncs where possible. I'm not sure how this benefits you: this PR won't change the already-released v4 versions, and on main there is no benefit (if anything, a disadvantage, since it adds device syncs on non-compile paths).

I would ask you to use v5 in this case, tbh. Is there any reason to resort to old v4 versions, or anything you need that's missing in v5?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


Development

Successfully merging this pull request may close these issues.

Qwen3VL/Qwen2.5VL VisionAttention breaks torch.compile with flash_attention_2
