Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def get_cuda_version():
]
cuda_version = version_line.split(" ")[-2].replace(",", "")
return "cu" + cuda_version.replace(".", "")
except Exception as e:
except (subprocess.CalledProcessError, FileNotFoundError):
return "no_cuda"


Expand All @@ -26,7 +26,7 @@ def get_cuda_version():
author_email="[email protected]",
packages=find_packages(),
install_requires=[
"torch>=2.1.0",
"torch==2.4.1",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you make it as torch>=2.4.1

"accelerate>=0.33.0",
"transformers>=4.39.1",
"sentencepiece>=0.1.99",
Expand Down
37 changes: 37 additions & 0 deletions tests/core/test_envs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import unittest
from unittest.mock import patch
import torch
from xfuser import envs

class TestEnvs(unittest.TestCase):
    """Device-selection tests for ``xfuser.envs`` with the accelerator probes mocked."""

    @patch('torch.cuda.is_available', return_value=True)
    def test_get_device_cuda(self, mock_is_available):
        """With CUDA reported available, the device is ``cuda:0`` and the name is ``cuda``."""
        device = envs.get_device(0)
        self.assertEqual(device.type, 'cuda')
        self.assertEqual(device.index, 0)
        device_name = envs.get_device_name()
        self.assertEqual(device_name, 'cuda')

    # ``envs.get_device`` consults ``_is_musa`` before ``_is_mps``, so the MUSA
    # probe must be forced to False as well; otherwise the outcome of this test
    # depends on the hardware of the machine running it.
    # Decorators are applied bottom-up, so the mock arguments arrive in the
    # reverse of the order in which the decorators are written.
    @patch('torch.cuda.is_available', return_value=False)
    @patch('xfuser.envs._is_musa', return_value=False)
    @patch('xfuser.envs._is_mps', return_value=True)
    def test_get_device_mps(self, mock_is_mps, mock_is_musa, mock_is_available):
        """With only MPS reported available, the device and the name are ``mps``."""
        device = envs.get_device(0)
        self.assertEqual(device.type, 'mps')
        device_name = envs.get_device_name()
        self.assertEqual(device_name, 'mps')
        # test that getting CUDA_VERSION does not raise an error
        cuda_version = envs.CUDA_VERSION
        self.assertIsNotNone(cuda_version)

    @patch('torch.cuda.is_available', return_value=False)
    @patch('xfuser.envs._is_mps', return_value=False)
    @patch('xfuser.envs._is_musa', return_value=False)
    def test_get_device_cpu(self, mock_is_musa, mock_is_mps, mock_is_available):
        """With no accelerator reported available, the device and the name are ``cpu``."""
        device = envs.get_device(0)
        self.assertEqual(device.type, 'cpu')
        device_name = envs.get_device_name()
        self.assertEqual(device_name, 'cpu')

# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
6 changes: 5 additions & 1 deletion xfuser/core/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,10 @@ def init_distributed_environment(
rank: int = -1,
distributed_init_method: str = "env://",
local_rank: int = -1,
backend: str = envs.get_torch_distributed_backend(),
backend: Optional[str] = None,
):
if backend is None:
backend = envs.get_torch_distributed_backend()
logger.debug(
"world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s",
world_size,
Expand Down Expand Up @@ -337,6 +339,8 @@ def initialize_model_parallel(
vae_parallel_size: int = 0,
backend: Optional[str] = None,
) -> None:
if backend is None:
backend = envs.get_torch_distributed_backend()
"""
Initialize model parallel groups.

Expand Down
32 changes: 23 additions & 9 deletions xfuser/core/long_ctx_attention/hybrid/attn_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,20 @@
from torch import Tensor

import torch.distributed
from yunchang import LongContextAttention
try:
from yunchang.kernels import AttnType
except ImportError:
raise ImportError("Please install yunchang 0.6.0 or later")

from yunchang.comm.all_to_all import SeqAllToAll4D
from yunchang.globals import HAS_SPARSE_SAGE_ATTENTION
if torch.cuda.is_available():
from yunchang import LongContextAttention
try:
from yunchang.kernels import AttnType
except ImportError:
raise ImportError("Please install yunchang 0.6.0 or later")

from yunchang.comm.all_to_all import SeqAllToAll4D
from yunchang.globals import HAS_SPARSE_SAGE_ATTENTION
else:
LongContextAttention = object
AttnType = None
HAS_SPARSE_SAGE_ATTENTION = False


from xfuser.logger import init_logger
Expand All @@ -33,7 +39,7 @@ def __init__(
use_pack_qkv: bool = False,
use_kv_cache: bool = False,
use_sync: bool = False,
attn_type: AttnType = AttnType.FA,
attn_type: AttnType = None,
attn_processor: torch.nn.Module = None,
q_descale=None,
k_descale=None,
Expand All @@ -57,6 +63,10 @@ def __init__(
use_sync=use_sync,
attn_type = attn_type,
)
if attn_type is None:
if torch.cuda.is_available():
from yunchang.kernels import AttnType
attn_type = AttnType.FA
self.use_kv_cache = use_kv_cache
self.q_descale = q_descale
self.k_descale = k_descale
Expand Down Expand Up @@ -227,8 +237,12 @@ def __init__(self,
ring_impl_type: str = "basic",
use_pack_qkv: bool = False,
use_kv_cache: bool = False,
attn_type: AttnType = AttnType.FA,
attn_type: AttnType = None,
attn_processor: torch.nn.Module = None):
if attn_type is None:
if torch.cuda.is_available():
from yunchang.kernels import AttnType
attn_type = AttnType.FA
super().__init__(scatter_idx, gather_idx, ring_impl_type, use_pack_qkv, use_kv_cache, attn_type, attn_processor)
# TODO need to check the attn_type
from xfuser.core.long_ctx_attention.ring import xdit_sana_ring_flash_attn_func
Expand Down
17 changes: 13 additions & 4 deletions xfuser/core/long_ctx_attention/ring/ring_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,26 @@

from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from xfuser.core.cache_manager.cache_manager import get_cache_manager
from yunchang.ring.utils import RingComm, update_out_and_lse
from yunchang.ring.ring_flash_attn import RingFlashAttnFunc
from yunchang.kernels import select_flash_attn_impl, AttnType
if torch.cuda.is_available():
from yunchang.ring.utils import RingComm, update_out_and_lse
from yunchang.ring.ring_flash_attn import RingFlashAttnFunc
from yunchang.kernels import select_flash_attn_impl, AttnType
else:
RingComm = object
RingFlashAttnFunc = object
AttnType = None
select_flash_attn_impl = None

try:
import flash_attn
from flash_attn.flash_attn_interface import _flash_attn_forward
except ImportError:
flash_attn = None
_flash_attn_forward = None
from yunchang.kernels.attention import pytorch_attn_forward
if torch.cuda.is_available():
from yunchang.kernels.attention import pytorch_attn_forward
else:
pytorch_attn_forward = None

def xdit_ring_flash_attn_forward(
process_group,
Expand Down
18 changes: 17 additions & 1 deletion xfuser/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,17 @@ def _is_musa():
return False


def _is_mps():
    """Return whether torch reports its MPS backend as available."""
    mps_backend = torch.backends.mps
    return mps_backend.is_available()


def get_device(local_rank: int) -> torch.device:
    """Select the torch device for this process.

    Backends are probed in the order CUDA, MUSA, MPS; the first one reported
    available is used, and CPU is the fallback.

    Args:
        local_rank: Device index used for the ``cuda`` and ``musa`` devices.
            It is not used for ``mps`` or ``cpu``.

    Returns:
        The ``torch.device`` for the first available backend.
    """
    if torch.cuda.is_available():
        return torch.device("cuda", local_rank)
    if _is_musa():
        return torch.device("musa", local_rank)
    # Neither remaining backend takes an index, so only the type string differs.
    device_type = "mps" if _is_mps() else "cpu"
    return torch.device(device_type)

Expand All @@ -77,6 +83,8 @@ def get_device_name() -> str:
return "cuda"
elif _is_musa():
return "musa"
elif _is_mps():
return "mps"
else:
return "cpu"

Expand All @@ -90,6 +98,8 @@ def get_device_version():
return torch.version.cuda
elif _is_musa():
return torch.version.musa
elif _is_mps():
return None
else:
raise NotImplementedError(
"No Accelerators(AMD/NV/MTT GPU, AMD MI instinct accelerators) available"
Expand All @@ -101,6 +111,8 @@ def get_torch_distributed_backend() -> str:
return "nccl"
elif _is_musa():
return "mccl"
elif _is_mps():
return "gloo"
else:
raise NotImplementedError(
"No Accelerators(AMD/NV/MTT GPU, AMD MI instinct accelerators) available"
Expand All @@ -110,7 +122,7 @@ def get_torch_distributed_backend() -> str:
variables: Dict[str, Callable[[], Any]] = {
# ================== Other Vars ==================
# used in version checking
"CUDA_VERSION": lambda: version.parse(get_device_version()),
"CUDA_VERSION": lambda: version.parse(get_device_version() or "0.0"),
"TORCH_VERSION": lambda: version.parse(
version.parse(torch.__version__).base_version
),
Expand Down Expand Up @@ -159,6 +171,8 @@ def initialize(self):
}

def check_flash_attn(self):
if not torch.cuda.is_available():
return False
if _is_musa():
logger.info(
"Flash Attention library is not supported on MUSA for the moment."
Expand All @@ -184,6 +198,8 @@ def check_flash_attn(self):
return False

def check_long_ctx_attn(self):
if not torch.cuda.is_available():
return False
try:
from yunchang import (
set_seq_parallel_pg,
Expand Down
13 changes: 7 additions & 6 deletions xfuser/model_executor/cache/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
get_sp_group,
get_sequence_parallel_world_size,
)
from xfuser.envs import get_device

import torch
from torch.nn import Module
Expand All @@ -18,8 +19,8 @@
class CacheContext(Module):
def __init__(self):
super().__init__()
self.register_buffer("default_coef", torch.tensor([1.0, 0.0]).cuda())
self.register_buffer("flux_coef", torch.tensor([498.651651, -283.781631, 55.8554382, -3.82021401, 0.264230861]).cuda())
self.register_buffer("default_coef", torch.tensor([1.0, 0.0]).to(get_device(0)))
self.register_buffer("flux_coef", torch.tensor([498.651651, -283.781631, 55.8554382, -3.82021401, 0.264230861]).to(get_device(0)))

self.register_buffer("original_hidden_states", None, persistent=False)
self.register_buffer("original_encoder_hidden_states", None, persistent=False)
Expand Down Expand Up @@ -90,14 +91,14 @@ def __init__(
self.transformer_blocks = torch.nn.ModuleList(transformer_blocks)
self.single_transformer_blocks = torch.nn.ModuleList(single_transformer_blocks) if single_transformer_blocks else None
self.transformer = transformer
self.register_buffer("cnt", torch.tensor(0).cuda())
self.register_buffer("accumulated_rel_l1_distance", torch.tensor([0.0]).cuda())
self.register_buffer("use_cache", torch.tensor(False, dtype=torch.bool).cuda())
self.register_buffer("cnt", torch.tensor(0).to(get_device(0)))
self.register_buffer("accumulated_rel_l1_distance", torch.tensor([0.0]).to(get_device(0)))
self.register_buffer("use_cache", torch.tensor(False, dtype=torch.bool).to(get_device(0)))

self.cache_context = CacheContext()
self.callback_handler = CallbackHandler(callbacks)

self.rel_l1_thresh = torch.tensor(rel_l1_thresh).cuda()
self.rel_l1_thresh = torch.tensor(rel_l1_thresh).to(get_device(0))
self.return_hidden_states_first = return_hidden_states_first
self.num_steps = num_steps
self.name = name
Expand Down
5 changes: 4 additions & 1 deletion xfuser/model_executor/layers/usp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@

from torch.distributed.tensor.experimental._attention import _templated_ring_attention

from yunchang.globals import PROCESS_GROUP
if torch.cuda.is_available():
from yunchang.globals import PROCESS_GROUP
else:
PROCESS_GROUP = None

from xfuser.core.distributed import (
get_sequence_parallel_world_size,
Expand Down
11 changes: 8 additions & 3 deletions xfuser/model_executor/layers/usp_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@

import torch.distributed._functional_collectives as ft_c

from yunchang.globals import PROCESS_GROUP
from yunchang.ring.ring_flash_attn import ring_flash_attn_forward
from yunchang.ring.ring_pytorch_attn import ring_pytorch_attn_func
if torch.cuda.is_available():
from yunchang.globals import PROCESS_GROUP
from yunchang.ring.ring_flash_attn import ring_flash_attn_forward
from yunchang.ring.ring_pytorch_attn import ring_pytorch_attn_func
else:
PROCESS_GROUP = None
ring_flash_attn_forward = None
ring_pytorch_attn_func = None

from xfuser.core.distributed import (
get_sequence_parallel_world_size,
Expand Down
14 changes: 9 additions & 5 deletions xfuser/model_executor/pipelines/base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
get_vae_parallel_group,
get_dit_group,
)
from xfuser.envs import (
get_device,
get_device_name,
)
from xfuser.core.fast_attention import (
get_fast_attn_enable,
initialize_fast_attn_state,
Expand Down Expand Up @@ -105,7 +109,7 @@ def reset_activation_cache(self):

def execute(self, output_type:str):
if self.vae is not None:
device = f"cuda:{get_world_group().local_rank}"
device = get_device(get_world_group().local_rank)
rank = get_world_group().rank
dit_parallel_size = self.dit_parallel_size
dtype = self.dtype
Expand Down Expand Up @@ -327,7 +331,7 @@ def prepare_run(
prompt=prompt,
use_resolution_binning=input_config.use_resolution_binning,
num_inference_steps=steps,
generator=torch.Generator(device="cuda").manual_seed(42),
generator=torch.Generator(device=get_device_name()).manual_seed(42),
output_type=input_config.output_type,
)
get_runtime_state().runtime_config.warmup_steps = warmup_steps
Expand All @@ -345,7 +349,7 @@ def latte_prepare_run(
# use_resolution_binning=input_config.use_resolution_binning,
num_inference_steps=steps,
output_type="latent",
generator=torch.Generator(device="cuda").manual_seed(42),
generator=torch.Generator(device=get_device_name()).manual_seed(42),
)
get_runtime_state().runtime_config.warmup_steps = warmup_steps

Expand Down Expand Up @@ -573,7 +577,7 @@ def gather_latents_for_vae(self, latents:torch.Tensor):
return latents

rank = get_world_group().rank
device = f"cuda:{get_world_group().local_rank}"
device = get_device(get_world_group().local_rank)
dit_parallel_size = get_dit_world_size()

# Gather only from DP last groups to the first VAE worker
Expand Down Expand Up @@ -609,7 +613,7 @@ def gather_broadcast_latents(self, latents:torch.Tensor):

# ---------gather latents from dp last group-----------
rank = get_world_group().rank
device = f"cuda:{get_world_group().local_rank}"
device = get_device(get_world_group().local_rank)

# all gather dp last group rank list
dp_rank_list = [torch.zeros(1, dtype=int, device=device) for _ in range(get_world_group().world_size)]
Expand Down