Skip to content

Conversation

@ChenTaoyu-SJTU
Copy link
Contributor

@ChenTaoyu-SJTU ChenTaoyu-SJTU commented Sep 18, 2025

Refer to the issue: #567

This is the first PR for support Ascend NPU platform, we implement the flux model in a single NPU run. Here is the main change illustration:

  1. Add NPU env info in xfuser/envs.py
  2. use NPU specific method in xfuser/core/distributed/parallel_state.py and xfuser/model_executor/pipelines/pipeline_flux.py

The additional environment needed:

torch==2.7.1
torch_npu==2.7.1.dev20250724

This PR confirm the usage of tp and dp in npu xDiT. The following file and command can be verify the correction:

  1. /root/Workplace/xDiT_example/launch.sh
set -xe

torchrun --nproc_per_node=4 --start_method=spawn /root/Workplace/xDiT_example/sd3.py \
--model /root/.cache/modelscope/hub/models/stabilityai/stable-diffusion-3-medium-diffusers \
--height 1024 --width 1024 --no_use_resolution_binning --guidance_scale 3.5 \
--num_inference_steps 50 \
--warmup_steps 1 \
--prompt "brown dog laying on the ground with a metal bowl in front of him." "A small cat." "A good man" \
--tensor_parallel_degree 2 \
--data_parallel_degree 2
  1. /root/Workplace/xDiT_example/sd3.py :
import time
import os
import torch
import torch.distributed
from transformers import T5EncoderModel
from xfuser import xFuserStableDiffusion3Pipeline, xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
    get_world_group,
    is_dp_last_group,
    get_data_parallel_rank,
    get_runtime_state,
)
from xfuser.core.distributed.parallel_state import get_data_parallel_world_size

def main():
    
    parser = FlexibleArgumentParser(description="xFuser Arguments")
    args = xFuserArgs.add_cli_args(parser).parse_args()
    engine_args = xFuserArgs.from_cli_args(args)
    engine_config, input_config = engine_args.create_config()
    local_rank = get_world_group().local_rank
    text_encoder_3 = T5EncoderModel.from_pretrained(engine_config.model_config.model, subfolder="text_encoder_3", torch_dtype=torch.float16)
    if args.use_fp8_t5_encoder:
        from optimum.quanto import freeze, qfloat8, quantize
        print(f"rank {local_rank} quantizing text encoder 2")
        quantize(text_encoder_3, weights=qfloat8)
        freeze(text_encoder_3)

    pipe = xFuserStableDiffusion3Pipeline.from_pretrained(
        pretrained_model_name_or_path=engine_config.model_config.model,
        engine_config=engine_config,
        torch_dtype=torch.float16,
        text_encoder_3=text_encoder_3,
    ).to(f"npu:{local_rank}")

    parameter_peak_memory = torch.npu.max_memory_allocated(device=f"npu:{local_rank}")

    pipe.prepare_run(input_config)
    
    torch.npu.reset_peak_memory_stats()
    start_time = time.time()
    output = pipe(
        height=input_config.height,
        width=input_config.width,
        prompt=input_config.prompt,
        num_inference_steps=input_config.num_inference_steps,
        output_type=input_config.output_type,
        guidance_scale=input_config.guidance_scale,
        generator=torch.Generator(device="npu").manual_seed(input_config.seed),
    )
    end_time = time.time()
    elapsed_time = end_time - start_time
    peak_memory = torch.npu.max_memory_allocated(device=f"npu:{local_rank}")

    parallel_info = (
        f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
        f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
        f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
    )
    if input_config.output_type == "pil":
        dp_group_index = get_data_parallel_rank()
        num_dp_groups = get_data_parallel_world_size()
        dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
        if pipe.is_dp_last_group():
            if not os.path.exists("results"):
                os.mkdir("results")
            for i, image in enumerate(output.images):
                image_rank = dp_group_index * dp_batch_size + i
                image.save(
                    f"./results/stable_diffusion_3_result_{parallel_info}_{image_rank}.png"
                )
                print(
                    f"image {i} saved to ./results/stable_diffusion_3_result_{parallel_info}_{image_rank}.png"
                )

    if get_world_group().rank == get_world_group().world_size - 1:
        print(
            f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, peak memory: {peak_memory/1e9:.2f} GB"
        )

    get_runtime_state().destroy_distributed_env()


if __name__ == "__main__":
    main()

And run the command:

chmod +x /root/Workplace/xDiT_example/launch.sh
/root/Workplace/xDiT_example/launch.sh

@ChenTaoyu-SJTU ChenTaoyu-SJTU changed the title Add NPU support for one model in one node Add NPU support for one model in single card Sep 19, 2025
Copy link
Collaborator

@feifeibear feifeibear left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@ChenTaoyu-SJTU
Copy link
Contributor Author

@feifeibear The TP and DP parallel is ready for npu now. I also tried pipefusion, but it doesn't work for npu currently, since its async p2p communication behavior is not same as cuda. We plan to support the USP and pipefusion after this PR merged first, what do you think?

command and Result:
image

@feifeibear feifeibear merged commit e559fe8 into xdit-project:main Oct 13, 2025
This was referenced Oct 17, 2025
@GuangyuZhu04
Copy link

Do you have any Ascend docker environment available that can currently run this code?

@jcaraban jcaraban mentioned this pull request Nov 12, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants