Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 137 additions & 0 deletions examples/run_amdgpu_1x8.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
#!/usr/bin/bash
# Launch one example script on a single node with 8 AMD GPUs.
#
# Environment overrides honoured by this section:
#   PP          pipefusion (pipeline) parallel degree, default 2
#   TP          tensor parallel degree, default 1
#   CP          ring degree, default 1
#   MODEL_TYPE  model to run, default "Pixart-alpha"

# Repository root = parent directory of the directory holding this script.
SCRIPT_ROOT=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )/.." &> /dev/null && pwd )

# Trace every command so the log shows the exact launch line.
set -x

# Single node: disable the InfiniBand transport.
export NCCL_IB_DISABLE=1
export RCCL_IB_DISABLE=1

# Enable DMA buffer support.
export NCCL_ENABLE_DMABUF_SUPPORT=1
export RCCL_ENABLE_DMABUF_SUPPORT=1

export NCCL_DEBUG=INFO
export RCCL_DEBUG=INFO

# (SHARP would be enabled here for multi-node configurations; nothing is set
# for this single-node script.)

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
MASTER_ADDR=localhost
# Pick a random rendezvous port so concurrent launches do not collide.
MASTER_PORT=$(shuf -n 1 -i 10000-65535)
NNODES=1
NODE_RANK=0

OUTPUT_BASEPATH="$SCRIPT_ROOT"

# pipeline parallel
PP=${PP:-2}

# tensor parallel
# NOTE(review): TP is not referenced anywhere else in this script.
TP=${TP:-1}

# ring degree
CP=${CP:-1}

# Prepend the working directory; avoid a trailing ':' when PYTHONPATH is unset.
export PYTHONPATH="$PWD${PYTHONPATH:+:$PYTHONPATH}"

# Select the model type from Pixart-alpha, Pixart-sigma, Sd3, Flux, or HunyuanDiT.
# The model is downloaded to a specified location on disk,
# or you can simply use the model's ID on Hugging Face,
# which will then be downloaded to the default cache path on Hugging Face.
# The value can be overridden from the caller's environment, e.g.
#   MODEL_TYPE=Flux bash examples/run_amdgpu_1x8.sh
export MODEL_TYPE="${MODEL_TYPE:-Pixart-alpha}"

CFG_ARGS="--use_cfg_parallel"

# Map MODEL_TYPE to the example script, the model location and the number of
# inference steps; abort on an unknown model type.
case "$MODEL_TYPE" in
    Pixart-alpha)
        export SCRIPT=pixartalpha_example.py
        export MODEL_ID="/mnt/models/SD/PixArt-XL-2-1024-MS"
        export INFERENCE_STEP=20
        ;;
    Pixart-sigma)
        export SCRIPT=pixartsigma_example.py
        export MODEL_ID="/cfs/dit/PixArt-Sigma-XL-2-2K-MS"
        export INFERENCE_STEP=20
        ;;
    Sd3)
        export SCRIPT=sd3_example.py
        export MODEL_ID="/mnt/models/SD/stable-diffusion-3-medium-diffusers"
        export INFERENCE_STEP=20
        ;;
    Flux)
        export SCRIPT=flux_example.py
        export MODEL_ID="/mnt/models/SD/FLUX.1-schnell"
        export INFERENCE_STEP=4
        # Flux does not apply cfg
        export CFG_ARGS=""
        ;;
    HunyuanDiT)
        export SCRIPT=hunyuandit_example.py
        export MODEL_ID="/mnt/models/SD/HunyuanDiT-v1.2-Diffusers"
        export INFERENCE_STEP=20
        ;;
    *)
        echo "Invalid MODEL_TYPE: $MODEL_TYPE"
        exit 1
        ;;
esac


# Output directory in the current working directory.
# NOTE(review): presumably where the example scripts write images; confirm.
mkdir -p ./results

# To sweep resolutions or GPU counts, wrap everything below in loops over
# HEIGHT and N_GPUS.
HEIGHT=1024
N_GPUS=8

# torchrun launcher options (left unquoted at the call site on purpose so the
# string is split into separate arguments).
DISTR_ARGS="--nproc_per_node $N_GPUS --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"

# Square image of HEIGHT x HEIGHT.
TASK_ARGS="--height $HEIGHT --width $HEIGHT --no_use_resolution_binning"

# On 8 gpus, pp=2, ulysses=2, ring=1, cfg_parallel=2 (split batch)
PARALLEL_ARGS="--pipefusion_parallel_degree $PP --ulysses_degree 2 --ring_degree $CP"

# Flux only supports SP, do not set the pipefusion degree
if [ "$MODEL_TYPE" = "Flux" ]; then
    PARALLEL_ARGS="--ulysses_degree $N_GPUS"
elif [ "$MODEL_TYPE" = "HunyuanDiT" ]; then
    echo "change PP from $PP to 1"
    PP=1
    # NOTE(review): ulysses=8 together with --use_cfg_parallel looks like it
    # needs more than 8 ranks; confirm against the launcher's degree check.
    PARALLEL_ARGS="--pipefusion_parallel_degree $PP --ulysses_degree 8 --ring_degree $CP"
fi

echo "PARALLEL ARGS : ${PARALLEL_ARGS}"

# By default, num_pipeline_patch = pipefusion_degree, and you can tune this
# parameter to achieve optimal performance, e.g.
#   PIPEFUSION_ARGS="--num_pipeline_patch 8" bash examples/run_amdgpu_1x8.sh
PIPEFUSION_ARGS=${PIPEFUSION_ARGS:-}

# For high-resolution images, use the latent output type to avoid running the
# vae module (useful when measuring speed), e.g.
#   OUTPUT_ARGS="--output_type latent" bash examples/run_amdgpu_1x8.sh
OUTPUT_ARGS=${OUTPUT_ARGS:-}

LOG_DIR="${OUTPUT_BASEPATH}/log/${MODEL_TYPE}"
mkdir -p "$LOG_DIR"

# The *_ARGS variables are intentionally unquoted: each holds several
# whitespace-separated options (or is empty). Paths are quoted so a checkout
# or model location containing spaces still works.
torchrun $DISTR_ARGS "$SCRIPT_ROOT/examples/$SCRIPT" \
    --model "$MODEL_ID" \
    $PARALLEL_ARGS \
    $TASK_ARGS \
    $PIPEFUSION_ARGS \
    $OUTPUT_ARGS \
    --num_inference_steps "$INFERENCE_STEP" \
    --warmup_steps 0 \
    --prompt "A small dog" \
    $CFG_ARGS \
    &> "${LOG_DIR}/${NODE_RANK}.log"


14 changes: 14 additions & 0 deletions requirements-rocm.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# torch>=2.3.0,

numpy==1.24.4
scipy>=1.8.1
diffusers>=0.31.0
transformers>=4.39.1
sentencepiece>=0.1.99
accelerate==0.33.0
beautifulsoup4>=4.12.3
distvae
ftfy>=6.2.0

# long context attention (pytorch native impl on top of flash attention)
yunchang @ git+https://github.com/yiakwy-xpu-ml-framework-team/xDiT-long-context-attention-fork.git@add_amd_gpu_suppport
120 changes: 106 additions & 14 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import logging
from setuptools import find_packages, setup
import subprocess

from typing import List

def get_cuda_version():
try:
Expand All @@ -13,19 +15,92 @@ def get_cuda_version():
except (subprocess.CalledProcessError, FileNotFoundError):
return "no_cuda"

# torch is optional at build time: the base environment may not ship it, or its
# native libraries may fail to load. Pre-bind both names so later code can test
# them against None instead of hitting a NameError.
torch = None
ROCM_HOME = None
try:
    import torch
    from torch.utils.cpp_extension import ROCM_HOME
except Exception as exc:  # not a bare except: let Ctrl-C / SystemExit propagate
    print("base env does not provide torch distribution")
    print(f"  reason: {exc!r}")

if __name__ == "__main__":
with open("README.md", "r") as f:
long_description = f.read()
fp = open("xfuser/__version__.py", "r").read()
version = eval(fp.strip().split()[-1])
## Constants
# Regex searched for in the output of `hipcc --version`; group 1 captures the
# HIP version string.
HIP_VERSION_PAT = r'HIP version: (\S+)'
# Directory whose existence is used to decide whether the ROCm SDK is installed;
# the rocminfo binary is also looked up under it.
HIP_SDK_ROOT = "/opt/rocm"
# Currently only MI30X (MI308X, MI300XA) datacenter accelerators are supported.
# NOTE(review): this list is not referenced by the code visible here; confirm
# where (or whether) it is enforced.
ALLOWED_AMDGPU_ARCHS = ["gfx942"]

setup(
name="xfuser",
author="xDiT Team",
author_email="[email protected]",
packages=find_packages(),
install_requires=[
# Directory containing this file; get_path() joins relative paths onto it.
ROOT_DIR = os.path.dirname(__file__)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

## ROCM helper
def _is_hip() -> bool:
    """Return True when the build environment is a usable ROCm (HIP) setup.

    All of the following must hold:

    * the SDK directory ``HIP_SDK_ROOT`` exists,
    * torch can be imported and reports a HIP version,
    * ``rocminfo`` runs successfully and lists at least one ``gfx*`` target.

    Returns:
        True if every check passes, False otherwise. Failures are reported
        with a printed message rather than an exception so that callers can
        fall back to the non-ROCm path.
    """
    # Local imports: the module's own torch import is best-effort and may have
    # failed, and ``re`` is not among the file-level imports.
    import re

    sdk_root = HIP_SDK_ROOT
    # os.path.isdir follows symlinks, so a symlinked SDK root is accepted.
    if not os.path.isdir(sdk_root):
        return False

    try:
        import torch
    except Exception:
        # No usable torch in the base environment: cannot be a ROCm build.
        return False

    if torch.version.hip is None:
        return False

    # Run rocminfo directly and filter its output here. Passing "| grep ..."
    # as an argument does nothing without a shell: it would be handed to
    # rocminfo as a literal string.
    rocminfo = os.path.join(sdk_root, "bin", "rocminfo")
    try:
        result = subprocess.run(
            [rocminfo],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )
    except OSError:
        # Binary missing or not executable.
        print("Use AMD pytorch, but no devices found!")
        return False

    if result.returncode != 0:
        print("Use AMD pytorch, but no devices found!")
        return False

    # First GPU architecture name reported, e.g. "gfx942".
    arch_match = re.search(r"gfx\w+", result.stdout)
    if arch_match is None:
        print("Use AMD pytorch, but no devices found!")
        return False

    print(f"target AMD gpu arch {arch_match.group(0)}")
    return True

def get_hipcc_rocm_version():
    """Return the HIP version reported by ``hipcc --version``.

    Returns:
        The version string captured by ``HIP_VERSION_PAT``, or None when
        ``hipcc`` cannot be run, exits with a non-zero status, or its output
        does not contain a HIP version line.

    Raises:
        AssertionError: if the environment is not a ROCm (HIP) setup.
    """
    # ``re`` is not among the file-level imports, so import it here.
    import re

    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not _is_hip():
        raise AssertionError(
            "get_hipcc_rocm_version() requires a ROCm (HIP) environment"
        )

    try:
        result = subprocess.run(
            ["hipcc", "--version"],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )
    except OSError:
        # hipcc is not on PATH (or not executable).
        print("Error running 'hipcc --version'")
        return None

    # Check if the command was executed successfully
    if result.returncode != 0:
        print("Error running 'hipcc --version'")
        return None

    # Extract the version using a regular expression
    match = re.search(HIP_VERSION_PAT, result.stdout)
    if match is None:
        print("Could not find HIP version in the output")
        return None
    return match.group(1)

def get_path(*filepath) -> str:
    """Join the given path components onto ROOT_DIR and return the result."""
    components = (ROOT_DIR, *filepath)
    return os.path.join(*components)
def get_requirements() -> List[str]:
"""Get Python package dependencies from requirements.txt."""

def _read_requirements(filename: str) -> List[str]:
    """Read a requirements file and return its requirement specifiers.

    Lines of the form ``-r other.txt`` are expanded recursively. Blank lines
    and full-line ``#`` comments are skipped so they are never returned as
    requirements.

    Args:
        filename: Path of the requirements file, relative to ROOT_DIR.

    Returns:
        The requirement strings in file order, with includes expanded in place.
    """
    with open(get_path(filename), encoding="utf-8") as f:
        raw_lines = f.read().splitlines()

    resolved_requirements = []
    for raw_line in raw_lines:
        line = raw_line.strip()
        # Only full-line comments are dropped; a '#' inside a line may be part
        # of a URL fragment and is left untouched.
        if not line or line.startswith("#"):
            continue
        if line.startswith("-r "):
            resolved_requirements += _read_requirements(line.split()[1])
        else:
            resolved_requirements.append(line)
    return resolved_requirements

if _is_hip():
requirements = _read_requirements("requirements-rocm.txt")
extras_require = {}
else:
requirements = [
"torch>=2.4.1",
"accelerate>=0.33.0",
"transformers>=4.39.1",
Expand All @@ -34,7 +109,7 @@ def get_cuda_version():
"distvae",
"yunchang>=0.6.0",
"einops",
],
]
extras_require={
"diffusers": [
"diffusers>=0.31.0", # NOTE: diffusers>=0.32.0.dev is necessary for CogVideoX and Flux
Expand All @@ -59,7 +134,24 @@ def get_cuda_version():
"imageio",
"imageio-ffmpeg"
]
},
}
return requirements, extras_require

if __name__ == "__main__":
with open("README.md", "r") as f:
long_description = f.read()
fp = open("xfuser/__version__.py", "r").read()
version = eval(fp.strip().split()[-1])

requirements, extra_requirements = get_requirements()

setup(
name="xfuser",
author="xDiT Team",
author_email="[email protected]",
packages=find_packages(),
install_requires=requirements,
extras_require=extra_requirements,
url="https://github.com/xdit-project/xDiT.",
description="A Scalable Inference Engine for Diffusion Transformers (DiTs) on Multiple Computing Devices",
long_description=long_description,
Expand Down
3 changes: 3 additions & 0 deletions xfuser/core/long_ctx_attention/ring/ring_flash_attn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import List
import math
import torch
from flash_attn.flash_attn_interface import _flash_attn_forward
from yunchang.ring.utils import RingComm, update_out_and_lse, get_default_args
from yunchang.ring.ring_flash_attn import RingFlashAttnFunc
import torch.nn.functional as F

from xfuser.core.long_ctx_attention import xFuserLongContextAttention
Expand Down
4 changes: 0 additions & 4 deletions xfuser/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,14 @@
"XDIT_LOGGING_LEVEL": lambda: os.getenv("XDIT_LOGGING_LEVEL", "INFO"),
}


def _is_hip():
has_rocm = torch.version.hip is not None
return has_rocm


def _is_cuda():
has_cuda = torch.version.cuda is not None
return has_cuda


def _is_musa():
try:
if hasattr(torch, "musa") and torch.musa.is_available():
Expand Down Expand Up @@ -118,7 +115,6 @@ def get_torch_distributed_backend() -> str:
"No Accelerators(AMD/NV/MTT GPU, AMD MI instinct accelerators) available"
)


variables: Dict[str, Callable[[], Any]] = {
# ================== Other Vars ==================
# used in version checking
Expand Down