
Commit 75b52b6

Add NPU support for one model in one node
Signed-off-by: ChenTaoyu-SJTU <[email protected]>
1 parent cd06115 commit 75b52b6

File tree

3 files changed: +33 −4 lines

xfuser/core/distributed/parallel_state.py

Lines changed: 5 additions & 0 deletions
@@ -22,6 +22,11 @@
 except ModuleNotFoundError:
     pass
 
+try:
+    from torch.npu import set_device, device_count
+except ModuleNotFoundError:
+    pass
+
 from .utils import RankGenerator
 
 env_info = envs.PACKAGES_CHECKER.get_packages_info()
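
Note (not part of the commit): the guarded import above mirrors the pattern this file already uses for other optional backends. Below is a minimal sketch of the same probe, assuming the Ascend torch_npu plugin is what makes torch.npu importable; HAS_NPU is a hypothetical name used only for illustration.

# Minimal sketch (assumption): probe for torch.npu and degrade cleanly when
# the Ascend torch_npu plugin is not installed.
try:
    from torch.npu import set_device, device_count
    HAS_NPU = True
except ModuleNotFoundError:
    # No Ascend plugin installed; CUDA/MUSA/MPS code paths remain usable.
    HAS_NPU = False

if HAS_NPU:
    print(f"visible NPU devices: {device_count()}")
    set_device(0)  # bind this process to the first NPU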

xfuser/envs.py

Lines changed: 25 additions & 3 deletions
@@ -67,24 +67,36 @@ def _is_mps():
     return torch.backends.mps.is_available()
 
 
+def _is_npu():
+    try:
+        if hasattr(torch, "npu") and torch.npu.is_available():
+            return True
+    except ModuleNotFoundError:
+        return False
+
+
 def get_device(local_rank: int) -> torch.device:
-    if torch.cuda.is_available():
+    if _is_cuda():
         return torch.device("cuda", local_rank)
     elif _is_musa():
         return torch.device("musa", local_rank)
     elif _is_mps():
         return torch.device("mps")
+    elif _is_npu():
+        return torch.device("npu", local_rank)
     else:
         return torch.device("cpu")
 
 
 def get_device_name() -> str:
-    if torch.cuda.is_available():
+    if _is_cuda():
         return "cuda"
     elif _is_musa():
         return "musa"
     elif _is_mps():
         return "mps"
+    elif _is_npu():
+        return "npu"
     else:
         return "cpu"
 
@@ -100,19 +112,23 @@ def get_device_version():
         return torch.version.musa
     elif _is_mps():
         return None
+    elif _is_npu():
+        return None
     else:
         raise NotImplementedError(
             "No Accelerators(AMD/NV/MTT GPU, AMD MI instinct accelerators) available"
         )
 
 
 def get_torch_distributed_backend() -> str:
-    if torch.cuda.is_available():
+    if _is_cuda():
         return "nccl"
     elif _is_musa():
         return "mccl"
     elif _is_mps():
         return "gloo"
+    elif _is_npu():
+        return "hccl"
     else:
         raise NotImplementedError(
             "No Accelerators(AMD/NV/MTT GPU, AMD MI instinct accelerators) available"
@@ -191,6 +207,12 @@ def check_aiter(self):
     def check_flash_attn(self):
         if not torch.cuda.is_available():
             return False
+
+        # Check if torch_npu is available
+        if _is_npu():
+            logger.info("flash_attn is not ready on torch_npu for now")
+            return False
+
         if _is_musa():
             logger.info(
                 "Flash Attention library is not supported on MUSA for the moment."

xfuser/model_executor/pipelines/pipeline_flux.py

Lines changed: 3 additions & 1 deletion
@@ -40,6 +40,7 @@
 from xfuser.core.distributed.group_coordinator import GroupCoordinator
 from .base_pipeline import xFuserPipelineBaseWrapper
 from .register import xFuserPipelineWrapperRegister
+from ...envs import _is_npu
 
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
@@ -75,13 +76,14 @@ def prepare_run(
         prompt = [""] * input_config.batch_size if input_config.batch_size > 1 else ""
         warmup_steps = get_runtime_state().runtime_config.warmup_steps
         get_runtime_state().runtime_config.warmup_steps = sync_steps
+        device = "npu" if _is_npu() else "cuda"
         self.__call__(
             height=input_config.height,
             width=input_config.width,
             prompt=prompt,
             num_inference_steps=steps,
             max_sequence_length=input_config.max_sequence_length,
-            generator=torch.Generator(device="cuda").manual_seed(42),
+            generator=torch.Generator(device=device).manual_seed(42),
             output_type=input_config.output_type,
         )
         get_runtime_state().runtime_config.warmup_steps = warmup_steps
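
Note (illustrative, not part of the commit): the only behavioral change here is that the warm-up generator is no longer hard-wired to CUDA. Below is a standalone sketch of the same pattern, assuming torch_npu is installed so that torch.Generator accepts an "npu" device string.

# Standalone sketch of the generator fix: torch.Generator(device="cuda") would
# fail on an NPU-only machine, so the device string follows the detected backend.
import torch
from xfuser.envs import _is_npu

device = "npu" if _is_npu() else "cuda"
generator = torch.Generator(device=device).manual_seed(42)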

0 commit comments
