51 changes: 49 additions & 2 deletions ppdiffusers/ppdiffusers/patches/paddle_patch.py
@@ -351,9 +351,44 @@ def to(self=None, device=None, dtype=None, blocking=None):

nn.Layer.to = to

from ..utils.import_utils import is_ppxformers_available
from ..utils.import_utils import is_ppxformers_available, is_npu_available

if is_ppxformers_available():
if is_npu_available():
    for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
        if lib.endswith(".so"):
            paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(
                lib
            )
    from paddle.base import core
    def scaled_dot_product_attention_npu(query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        training=True,
        name=None,
        fixed_seed_offset=None,
        return_softmax=False,
        is_triangle_upper_mask=True,
    ):
        out = core.eager._run_custom_op(
            "flash_attention_npu",
            query,
            key,
            value,
            fixed_seed_offset,
            attn_mask,
            dropout_p,
            is_causal,
            return_softmax,
            not training,
            is_triangle_upper_mask,
        )[0]
        return out
    paddle.nn.functional.scaled_dot_product_attention_npu = scaled_dot_product_attention_npu

if is_ppxformers_available() or is_npu_available():
    from paddle.incubate.nn.memory_efficient_attention import memory_efficient_attention

    try:
@@ -392,6 +427,8 @@ def scaled_dot_product_attention_(
            attention_op = "cutlass"
            if is_support_flash_attention and query.dtype not in [paddle.float32]:
                attention_op = "flash"
            elif is_npu_available() and query.dtype not in [paddle.float32]:
                attention_op = "flash_npu"
        else:
            if attention_op == "flash" and flash_attn_error is not None:
                raise OSError(flash_attn_error)
@@ -473,6 +510,16 @@ def scaled_dot_product_attention_(
                is_causal=bool(is_causal),
                training=training,
            )
        elif attention_op == "flash_npu":
            output = paddle.nn.functional.scaled_dot_product_attention_npu(
                query,
                key,
                value,
                attn_mask=None if is_causal else attn_mask,
                dropout_p=dropout_p if training else 0.0,
                is_causal=bool(is_causal),
                training=training,
            )
        else:
            raise ValueError(
                "ppxformers's attention_op should be in ['auto', 'math', 'cutlass', `memory_efficient`, 'flash']."
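Taken together, these changes register the Ascend NPU flash-attention custom op at import time and add a "flash_npu" branch to the patched scaled_dot_product_attention_ dispatcher, selected automatically for non-float32 inputs on an NPU device. Below is a minimal usage sketch, assuming this PR is applied, a Paddle build with the NPU custom-device plugin (with CUSTOM_DEVICE_ROOT pointing at its .so kernels), and that the patch exposes the dispatcher as paddle.nn.functional.scaled_dot_product_attention_; the shapes are purely illustrative:

import paddle
import ppdiffusers  # assumed to apply the paddle_patch monkey patches on import

paddle.set_device("npu")  # requires the NPU custom-device plugin

# [batch, seq_len, num_heads, head_dim]; fp16 so the non-float32 "flash_npu" branch is eligible
q = paddle.randn([1, 1024, 8, 64]).astype(paddle.float16)
k = paddle.randn([1, 1024, 8, 64]).astype(paddle.float16)
v = paddle.randn([1, 1024, 8, 64]).astype(paddle.float16)

# With attention_op="auto" the patched dispatcher is expected to pick "flash_npu" here;
# the op name can also be passed explicitly.
out = paddle.nn.functional.scaled_dot_product_attention_(
    q, k, v, attention_op="auto", is_causal=True
)
print(out.shape)  # [1, 1024, 8, 64]

On non-NPU devices the existing cutlass/flash selection is unchanged, since the new branch is only reached when is_npu_available() returns True.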
2 changes: 2 additions & 0 deletions ppdiffusers/ppdiffusers/utils/import_utils.py
Original file line number Diff line number Diff line change
@@ -375,6 +375,8 @@ def is_scipy_available():
def is_librosa_available():
    return _librosa_available

def is_npu_available():
    return paddle.device.get_device().startswith("npu")

def is_ppxformers_available():
    USE_PPXFORMERS = str2bool(os.getenv("USE_PPXFORMERS", True))
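The new helper only inspects the currently selected device string, so it reports True once paddle.set_device("npu") has succeeded. A quick sanity check, assuming a Paddle build with the NPU custom-device plugin installed (set_device("npu") raises otherwise):

import paddle
from ppdiffusers.utils.import_utils import is_npu_available

paddle.set_device("npu")   # active device string becomes e.g. "npu:0"
print(is_npu_available())  # True, since the device string starts with "npu"

paddle.set_device("cpu")
print(is_npu_available())  # False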