5 changes: 5 additions & 0 deletions xfuser/core/distributed/group_coordinator.py
@@ -18,6 +18,11 @@
except ModuleNotFoundError:
    pass

import xfuser.envs as envs
if envs._is_npu():
    print("torch.npu synchronize")
    from torch.npu import synchronize

import xfuser.envs as envs
from xfuser.logger import init_logger

10 changes: 10 additions & 0 deletions xfuser/core/distributed/parallel_state.py
@@ -22,6 +22,11 @@
except ModuleNotFoundError:
    pass

try:
    from torch.npu import set_device, device_count
except ModuleNotFoundError:
    pass

from .utils import RankGenerator

env_info = envs.PACKAGES_CHECKER.get_packages_info()
@@ -396,6 +401,11 @@ def initialize_model_parallel(
f"sequence_parallel_degree is not equal to ring_degree * ulysses_degree, {sequence_parallel_degree} != {ring_degree} * {ulysses_degree}"
)

# FIXME: Since the async p2p communication operation of NPU is not same as cuda in torch,
# the pipefusion is not ready for npu yet
if envs._is_npu():
assert pipeline_parallel_degree == 1, "Current pipefusion is not ready for NPU"

dit_parallel_size = (
data_parallel_degree
* classifier_free_guidance_degree
4 changes: 4 additions & 0 deletions xfuser/core/distributed/runtime_state.py
@@ -17,6 +17,10 @@
except ModuleNotFoundError:
    pass

import xfuser.envs as envs
if envs._is_npu():
    from torch.npu import manual_seed as device_manual_seed
    from torch.npu import manual_seed_all as device_manual_seed_all
from xfuser.config.config import (
    ParallelConfig,
    RuntimeConfig,
4 changes: 4 additions & 0 deletions xfuser/core/utils/timer.py
@@ -9,6 +9,10 @@
except ModuleNotFoundError:
    pass

import xfuser.envs as envs
if envs._is_npu():
    from torch.npu import synchronize

def gpu_timer_decorator(func):
    def wrapper(*args, **kwargs):
        synchronize()
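
The timer hunk is truncated after the first synchronize() call. For context, here is a minimal sketch of what a device-agnostic timing decorator of this shape typically looks like; the _sync() helper and the timing/printing details are illustrative, not the exact body of gpu_timer_decorator in this PR.

import functools
import time

import torch


def _sync():
    # Illustrative helper: flush queued kernels on whichever accelerator is present,
    # mirroring the conditional "from torch.npu import synchronize" import above.
    if hasattr(torch, "npu") and torch.npu.is_available():
        torch.npu.synchronize()
    elif torch.cuda.is_available():
        torch.cuda.synchronize()


def gpu_timer_decorator(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        _sync()  # make sure pending work is finished before timing starts
        start = time.time()
        result = func(*args, **kwargs)
        _sync()  # wait for the wrapped call's kernels to complete
        print(f"{func.__name__} took {time.time() - start:.3f} s")
        return result

    return wrapper
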
28 changes: 25 additions & 3 deletions xfuser/envs.py
@@ -67,24 +67,36 @@ def _is_mps():
    return torch.backends.mps.is_available()


def _is_npu():
    try:
        if hasattr(torch, "npu") and torch.npu.is_available():
            return True
    except ModuleNotFoundError:
        return False


def get_device(local_rank: int) -> torch.device:
-    if torch.cuda.is_available():
+    if _is_cuda():
        return torch.device("cuda", local_rank)
    elif _is_musa():
        return torch.device("musa", local_rank)
    elif _is_mps():
        return torch.device("mps")
    elif _is_npu():
        return torch.device("npu", local_rank)
    else:
        return torch.device("cpu")


def get_device_name() -> str:
-    if torch.cuda.is_available():
+    if _is_cuda():
        return "cuda"
    elif _is_musa():
        return "musa"
    elif _is_mps():
        return "mps"
    elif _is_npu():
        return "npu"
    else:
        return "cpu"

@@ -100,19 +112,23 @@ def get_device_version():
        return torch.version.musa
    elif _is_mps():
        return None
    elif _is_npu():
        return None
    else:
        raise NotImplementedError(
            "No Accelerators(AMD/NV/MTT GPU, AMD MI instinct accelerators) available"
        )


def get_torch_distributed_backend() -> str:
-    if torch.cuda.is_available():
+    if _is_cuda():
        return "nccl"
    elif _is_musa():
        return "mccl"
    elif _is_mps():
        return "gloo"
    elif _is_npu():
        return "hccl"
    else:
        raise NotImplementedError(
            "No Accelerators(AMD/NV/MTT GPU, AMD MI instinct accelerators) available"
@@ -191,6 +207,12 @@ def check_aiter(self):
    def check_flash_attn(self):
        if not torch.cuda.is_available():
            return False

        # Check if torch_npu is available
        if _is_npu():
            logger.info("flash_attn is not ready on torch_npu for now")
            return False

        if _is_musa():
            logger.info(
                "Flash Attention library is not supported on MUSA for the moment."
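
The new _is_npu() branch slots into the same dispatch chain as CUDA, MUSA and MPS, so callers can stay backend-agnostic. Below is a minimal sketch of how downstream code might combine these helpers when setting up a process; the init_distributed function name and the bare init_process_group call are illustrative, not part of xfuser.

import torch
import torch.distributed as dist

import xfuser.envs as envs


def init_distributed(local_rank: int) -> torch.device:
    # Resolve the accelerator once; get_device() falls back to CPU if none is found.
    device = envs.get_device(local_rank)            # e.g. npu:0, cuda:0, or cpu
    backend = envs.get_torch_distributed_backend()  # hccl on NPU, nccl on CUDA, mccl on MUSA

    # Bind the process to its local device before creating communicators.
    if envs.get_device_name() == "npu":
        torch.npu.set_device(local_rank)            # requires torch_npu to be installed
    elif envs.get_device_name() == "cuda":
        torch.cuda.set_device(local_rank)

    dist.init_process_group(backend=backend)
    return device
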
4 changes: 4 additions & 0 deletions xfuser/model_executor/layers/feedforward.py
@@ -11,6 +11,10 @@
except ModuleNotFoundError:
    pass

import xfuser.envs as envs
if envs._is_npu():
    from torch.npu import empty_cache

from xfuser.core.distributed.parallel_state import (
    get_tensor_model_parallel_world_size,
    get_tensor_model_parallel_rank,
4 changes: 3 additions & 1 deletion xfuser/model_executor/pipelines/pipeline_flux.py
@@ -40,6 +40,7 @@
from xfuser.core.distributed.group_coordinator import GroupCoordinator
from .base_pipeline import xFuserPipelineBaseWrapper
from .register import xFuserPipelineWrapperRegister
from ...envs import _is_npu

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm
@@ -75,13 +76,14 @@ def prepare_run(
prompt = [""] * input_config.batch_size if input_config.batch_size > 1 else ""
warmup_steps = get_runtime_state().runtime_config.warmup_steps
get_runtime_state().runtime_config.warmup_steps = sync_steps
device = "npu" if _is_npu() else "cuda"
self.__call__(
height=input_config.height,
width=input_config.width,
prompt=prompt,
num_inference_steps=steps,
max_sequence_length=input_config.max_sequence_length,
generator=torch.Generator(device="cuda").manual_seed(42),
generator=torch.Generator(device=device).manual_seed(42),
output_type=input_config.output_type,
)
get_runtime_state().runtime_config.warmup_steps = warmup_steps
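
Both prepare_run hunks apply the same idea: resolve the generator device from _is_npu() instead of hardcoding "cuda". A small standalone sketch of that pattern follows; the helper name is hypothetical, and building a Generator on "npu" assumes torch_npu is installed.

import torch

from xfuser.envs import _is_npu


def make_warmup_generator(seed: int = 42) -> torch.Generator:
    # Hypothetical helper mirroring the pattern in both prepare_run hunks:
    # seed the generator on the accelerator that is actually present.
    device = "npu" if _is_npu() else "cuda"
    return torch.Generator(device=device).manual_seed(seed)
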
Changes to a second pipeline wrapper (file name not captured in this view)
@@ -41,6 +41,7 @@
)
from .base_pipeline import xFuserPipelineBaseWrapper
from .register import xFuserPipelineWrapperRegister
from ...envs import _is_npu

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm
@@ -74,12 +75,15 @@ def prepare_run(
prompt = [""] * input_config.batch_size if input_config.batch_size > 1 else ""
warmup_steps = get_runtime_state().runtime_config.warmup_steps
get_runtime_state().runtime_config.warmup_steps = sync_steps
device = "cuda"
if _is_npu():
device = "npu"
self.__call__(
height=input_config.height,
width=input_config.width,
prompt=prompt,
num_inference_steps=steps,
generator=torch.Generator(device="cuda").manual_seed(42),
generator=torch.Generator(device=device).manual_seed(42),
output_type=input_config.output_type,
)
get_runtime_state().runtime_config.warmup_steps = warmup_steps