Commit ccba9d5

Kristian Sikiric (ksikiric) authored
Aiter round mode control (#590)
* Adding AITER round mode control to USP attention calls.
* Added guarding so that the AITER call does not break for earlier AITER commits that do not support changing the round mode.
* Added a more robust way of checking whether the round mode is available; also removed code duplication in the round-mode availability check when calling AITER flash attention.
* Changed aiter.ops.mha.flash_attn_func back to aiter.flash_attn_func, as this was a mistake and should not have been changed in the first place.
* Indentation fix.

Co-authored-by: Kristian Sikiric <[email protected]>
1 parent 78cb759 commit ccba9d5

File tree

xfuser/model_executor/layers (1 file changed: +17, −4 lines)

xfuser/model_executor/layers/usp.py

Lines changed: 17 additions & 4 deletions
```diff
@@ -30,6 +30,14 @@
 HAS_AITER = env_info["has_aiter"]
 if HAS_AITER:
     import aiter
+    import inspect
+    try:
+        HAS_ROUND_MODE = inspect.signature(aiter.flash_attn_func).parameters.get("how_v3_bf16_cvt") is not None
+    except (AttributeError, TypeError):
+        HAS_ROUND_MODE = False
+    if HAS_ROUND_MODE:
+        import os
+        HOW_V3_BF16_CVT = int(os.environ.get("HOW_V3_BF16_CVT", "2"))
 
 aten = torch.ops.aten
 
```
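Since `HOW_V3_BF16_CVT` is read once at module import time, the rounding mode has to be in the environment before usp.py is imported. A minimal usage sketch, assuming the module path shown in the file tree above; the value `1` is purely illustrative (`2` stays the default when the variable is unset):

```python
import os

# Choose the bf16 conversion rounding mode for the AITER flash attention
# call. This must happen before usp.py is imported, because the module
# reads HOW_V3_BF16_CVT exactly once at import time (default "2").
os.environ["HOW_V3_BF16_CVT"] = "1"  # illustrative value

# Import after setting the variable so the override takes effect.
from xfuser.model_executor.layers import usp  # noqa: E402
```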

```diff
@@ -175,14 +183,19 @@ def _aiter_attn_call(query, key, value, dropout_p, is_causal):
     query = torch.permute(query, [0, 2, 1, 3]).contiguous()
     key = torch.permute(key, [0, 2, 1, 3]).contiguous()
     value = torch.permute(value, [0, 2, 1, 3]).contiguous()
+    attn_kwargs = {
+        "dropout_p": dropout_p,
+        "causal": is_causal,
+        "return_attn_probs": False,
+        "return_lse": True,
+    }
+    if HAS_ROUND_MODE:
+        attn_kwargs["how_v3_bf16_cvt"] = HOW_V3_BF16_CVT
     output, softmax_lse = aiter.flash_attn_func(
         query,
         key,
         value,
-        dropout_p=dropout_p,
-        causal=is_causal,
-        return_attn_probs=False,
-        return_lse=True
+        **attn_kwargs
     )
     output = torch.permute(output, [0, 2, 1, 3])
     return output, softmax_lse
```
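The guard pattern generalizes beyond AITER: probe a callable's signature for an optional keyword and forward it only when it is explicitly declared, so older library builds keep working unchanged. A minimal standalone sketch of the same idea; `supports_kwarg` is a hypothetical helper, not part of the commit (the commit also catches AttributeError, which covers `aiter.flash_attn_func` itself being absent on old builds):

```python
import inspect

def supports_kwarg(fn, name: str) -> bool:
    """Return True if `fn` explicitly declares a parameter called `name`.

    inspect.signature can raise TypeError (unsupported callables) or
    ValueError (no signature found, e.g. some C extensions); both are
    treated as "keyword not supported", matching the commit's
    conservative guard. Keywords only reachable via **kwargs are also
    reported as unsupported, which errs on the safe side.
    """
    try:
        return name in inspect.signature(fn).parameters
    except (TypeError, ValueError):
        return False

# Example: forward the rounding-mode flag only when the installed
# kernel accepts it, mirroring the attn_kwargs construction above.
#
#   if supports_kwarg(aiter.flash_attn_func, "how_v3_bf16_cvt"):
#       attn_kwargs["how_v3_bf16_cvt"] = HOW_V3_BF16_CVT
```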
