
Fix FA kernel launch needs correct cuda device ctx in multi-gpu env#44967

Open
Qubitium wants to merge 5 commits intohuggingface:mainfrom
Qubitium:fix-fa-cuda-ctx

Conversation

Contributor

@Qubitium Qubitium commented Mar 24, 2026

What does this PR do?

Fix: FA kernel launches are currently not thread-safe under free-threaded Python (nogil) in a multi-GPU environment. This simple patch fixes the issue.

 # Set the correct CUDA context before launching the FlashAttention kernel.
 with torch.cuda.device(query.device):

Bind to the correct CUDA device context before launching the FA kernel.

This is much harder to replicate, but the patch is simple and, I think, obvious enough to point to the source of the potential bug.

This PR patches the context at the source of the FA kernel launch, but another option is to make sure the entire inference call (upstream) runs with a single context bound, so that any downstream calls are protected. So either move this up the call stack or leave it here; I leave the decision to the package maintainers.
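The idea behind the patch can be sketched without GPUs: CUDA's "current device" is per-thread state, so a worker thread that never binds it implicitly launches on the default device rather than the device that owns the tensor. The toy model below uses only the standard library; `cuda_device`, `launch_kernel`, and the thread-local slot are illustrative stand-ins, not real torch/CUDA APIs, but the context manager mirrors what `torch.cuda.device(query.device)` does (bind on entry, restore on exit).

```python
# Toy model of per-thread CUDA device binding. None of these names are real
# CUDA/torch APIs; they only model the behavior the patch relies on.
import contextlib
import threading

_tls = threading.local()  # models CUDA's per-thread current-device slot


def current_device() -> int:
    # A fresh thread that never set a device sees the default device 0,
    # just like a CUDA host thread with no explicit cudaSetDevice call.
    return getattr(_tls, "device", 0)


@contextlib.contextmanager
def cuda_device(index: int):
    # Analogue of `with torch.cuda.device(query.device):` --
    # bind the device for this thread, restore the previous one on exit.
    prev = current_device()
    _tls.device = index
    try:
        yield
    finally:
        _tls.device = prev


def launch_kernel(tensor_device: int) -> bool:
    # A "kernel launch" only succeeds when the thread's bound device
    # matches the device that owns the tensor's memory.
    return current_device() == tensor_device


results = {}


def worker(gpu: int, use_fix: bool):
    if use_fix:
        with cuda_device(gpu):
            results[(gpu, use_fix)] = launch_kernel(gpu)
    else:
        results[(gpu, use_fix)] = launch_kernel(gpu)


for gpu in (0, 1):
    for use_fix in (False, True):
        t = threading.Thread(target=worker, args=(gpu, use_fix))
        t.start()
        t.join()

print(results)
# Without the explicit binding, the thread working on "device 1" launches on
# the default device 0 and fails; with the binding, both devices succeed.
```

This is why the crash only shows up on the second GPU in a multi-threaded run: threads touching device 0 are accidentally correct because the default binding happens to match.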

Without this fix, my loop test (real code doing CB + Paged Attn + FA2) on 2 GPUs, 2 threads, and 2 models will sometimes, not always, emit FA stack traces of the form:

`None` tensor which is neither float16 nor bfloat16 crashes.  <-- not a direct quote, paraphrased

Code Agent Policy

  • I confirm that this is not a pure code agent PR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker @remi-or This is part of the same CB call stack: I run CB + Paged Attn + FA2 with two models on two threads, each on a different GPU, under nogil.

Note: This FA crash is much harder to reproduce. I am still trying to write a script that consistently reproduces this error.

Stack

    File "/root/vm314t/lib/python3.14t/site-packages/flash_attn/flash_attn_interface.py", line 165, in _flash_attn_varlen_forward
      out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(
                                             ~~~~~~~~~~~~~~~~~~~~~~~~~^
          q,
          ^^
      ...
          None,
          ^^^^^
          )
  RuntimeError: FlashAttention only support fp16 and bf16 data type

The above looks exactly like the stack trace I got, but the problem is that the above was generated by feeding the FA kernel fp32 tensors or a None tensor, which is not possible in my working code, which can no longer replicate this crash for whatever reason. Sigh.

Still working to reproduce this. I still believe this is related to a cross-CUDA-context issue, with memory cleared on the wrong CUDA context thread by Python or the CUDA memory allocator.

  • Possibility that the q tensor is float32 or any other type that is not bf16/fp16: 100% impossible.
  • Possibility that the q tensor is None: yes! If the memory handle was cleared/held on the wrong CUDA context thread.

