Merged
23 changes: 18 additions & 5 deletions paddlemix/triton_ops/triton_ops.py
@@ -839,6 +839,13 @@ def paddle_fused_adaLN(x, mha_out, gate, hidd, scale, shift, weight, bias, epsil
seq_size = x.shape[1]
N_npo2 = triton.next_power_of_2(N)

# baseline.
if os.getenv("INFERENCE_OPTIMIZE_TRITON") is None:
resi_out_paddle = mha_out * gate_msa.unsqueeze(axis=1) + x
norm_hidden_states = paddle.nn.functional.layer_norm(resi_out_paddle, [N], weight, bias, epsilon)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
return resi_out_paddle, norm_hidden_states

op_name = "triton_fused_adaLN_scale_residual"
op_name += get_dtype_str(x.dtype)
op_name += f"_{N_npo2}_{weight_attr}_{bias_attr}"
@@ -865,9 +872,9 @@ def paddle_fused_adaLN(x, mha_out, gate, hidd, scale, shift, weight, bias, epsil
shift_mlp,
resi_out,
adaLN_out,
M,
-1,
N,
seq_size,
-1,
epsilon,
N_npo2=N_npo2,
weight_attr=weight_attr,
@@ -1072,7 +1079,13 @@ def modulate(x, shift, scale):
M = x.shape[0] * x.shape[1]
N = x.shape[2]
seq_size = x.shape[1]
BLOCK_SIZE = min(1024, triton.next_power_of_2(N))
BLOCK_SIZE = triton.next_power_of_2(N)

# baseline.
if os.getenv("INFERENCE_OPTIMIZE_TRITON") is None:
norm_hidden_states = paddle.nn.functional.layer_norm(x, [N], weight, bias, epsilon)
norm_hidden_states = norm_hidden_states * (1 + scale[:, None]) + shift[:, None]
return norm_hidden_states

op_name = "triton_adaptive_layer_norm"
op_name += get_dtype_str(x.dtype)
@@ -1096,9 +1109,9 @@ def modulate(x, shift, scale):
y,
y,
y,
M,
-1,
N,
seq_size,
-1,
epsilon,
BLOCK_SIZE=BLOCK_SIZE,
weight_attr=weight_attr,
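The branches added above dispatch to the Triton kernels only when `INFERENCE_OPTIMIZE_TRITON` is set, and otherwise fall back to eager Paddle ops. A minimal standalone sketch of the fused-adaLN fallback (the name `adaLN_baseline` and the toy shapes are assumptions, not part of this PR):

```python
import os

import paddle

def adaLN_baseline(x, mha_out, gate_msa, scale_mlp, shift_mlp, weight, bias, epsilon=1e-06):
    # Residual-add the gated attention output, then layer-norm and modulate,
    # mirroring the eager branch of triton_fused_adaLN_scale_residual above.
    N = x.shape[2]
    resi_out = mha_out * gate_msa.unsqueeze(axis=1) + x
    norm = paddle.nn.functional.layer_norm(resi_out, [N], weight, bias, epsilon)
    norm = norm * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
    return resi_out, norm

# The Triton path is taken only when the variable is exported, e.g.
#   export INFERENCE_OPTIMIZE_TRITON=1
if os.getenv("INFERENCE_OPTIMIZE_TRITON") is None:
    B, S, D = 2, 16, 64  # toy shapes for illustration
    x = paddle.randn([B, S, D])
    mha_out = paddle.randn([B, S, D])
    gate, scale, shift = (paddle.randn([B, D]) for _ in range(3))
    w, b = paddle.ones([D]), paddle.zeros([D])
    resi_out, adaLN_out = adaLN_baseline(x, mha_out, gate, scale, shift, w, b)
```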
@@ -215,6 +215,29 @@ image = pipe(class_labels=class_ids, num_inference_steps=25, generator=generator
image.save("result_DiT_golden_retriever.png")
```

### 2.3 High-Performance Inference with Paddle Inference

- Paddle Inference provides a high-performance inference implementation for the DiT model, improving inference performance by more than 80%.
The inference steps are as follows:
```shell
# Install the develop (nightly) build of paddle
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu123/
# Install triton
python -m pip install triton

```

> **Review comment:** Won't this also install torch and a whole chain of dependencies?
>
> **Author reply:** The documentation has been revised into a more complete introduction that explains how to adapt Triton for Paddle. Thanks!
One-command inference:
```shell
python ppdiffusers/examples/inference/class_conditional_image_generation-dit.py --inference_optimize 1
```

- Performance measured on an NVIDIA A100-SXM4-40GB:

| Paddle Inference | TensorRT-LLM | Paddle dynamic graph |
| --------------- | ------------ | ------------ |
| 219 ms | 242 ms | 1200 ms |

## Citation
```
ppdiffusers/examples/inference/class_conditional_image_generation-dit.py
@@ -59,6 +59,24 @@ def parse_args():
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
set_seed(42)

if args.inference_optimize:
# optimize the transformer using paddle.incubate.jit.inference
pipe.transformer = paddle.incubate.jit.inference(
pipe.transformer,
enable_new_ir=True,
save_model_dir="./tmp/dit",
cache_static_model=True,
exp_enable_use_cutlass=True,
delete_pass_lists=["add_norm_fuse_pass"],
)
pipe.vae.decode = paddle.incubate.jit.inference(
pipe.vae.decode,
enable_new_ir=True,
save_model_dir="./tmp/dit/vae",
cache_static_model=True,
exp_enable_use_cutlass=True,
)

words = ["golden retriever"] # class_ids [207]
class_ids = pipe.get_label_ids(words)
image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
@@ -71,15 +89,15 @@ def parse_args():

repeat_times = 5

paddle.device.synchronize()
starttime = datetime.datetime.now()
for i in range(repeat_times):
paddle.device.synchronize()
starttime = datetime.datetime.now()
image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
paddle.device.synchronize()
endtime = datetime.datetime.now()
paddle.device.synchronize()
endtime = datetime.datetime.now()

duringtime = endtime - starttime
time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
print("The ave end to end time : ", time_ms / repeat_times, "ms")
duringtime = endtime - starttime
time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
print("The this end to end time : ", time_ms, "ms")

image.save("class_conditional_image_generation-dit-result.png")
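The benchmarking change above moves the `paddle.device.synchronize()` calls inside the loop so each pipeline invocation is timed individually instead of averaged. A sketch of that pattern as a reusable helper (the name `time_per_call` is an assumption):

```python
import datetime

import paddle

def time_per_call(fn, repeat_times=5):
    # Synchronize before and after each call so queued device work is
    # fully included in that iteration's wall-clock measurement.
    times_ms = []
    for _ in range(repeat_times):
        paddle.device.synchronize()
        start = datetime.datetime.now()
        fn()
        paddle.device.synchronize()
        delta = datetime.datetime.now() - start
        times_ms.append(delta.seconds * 1000 + delta.microseconds / 1000.0)
    return times_ms

# Usage, assuming `pipe` and `class_ids` from the script above:
# times = time_per_call(lambda: pipe(class_labels=class_ids, num_inference_steps=25))
```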
42 changes: 19 additions & 23 deletions ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py
@@ -13,7 +13,6 @@
# limitations under the License.

import math
import os

import paddle
import paddle.nn.functional as F
@@ -84,6 +83,10 @@ def forward(self, hidden_states, timesteps, class_labels):
emb = paddle.concat([paddle.cos(emb), paddle.sin(emb)], axis=-1)
common_emb = emb.cast(hidden_states.dtype)

last_ffn_output = None
last_hidden_states = None
last_gate_mlp = None

for i in range(self.num_layers):
emb = self.fcs0[i](common_emb)
emb = F.silu(emb)
@@ -94,44 +97,37 @@
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1)
import paddlemix

if os.getenv("INFERENCE_OPTIMIZE_TRITON"):
if last_ffn_output is None:
norm_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(
hidden_states, scale_msa, shift_msa, epsilon=1e-06
)
else:
norm_hidden_states = self.norm(
hidden_states,
hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(
last_hidden_states, last_ffn_output, last_gate_mlp, scale_msa, shift_msa, epsilon=1e-06
)
norm_hidden_states = norm_hidden_states * (1 + scale_msa[:, None]) + shift_msa[:, None]

q = self.q[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim])
k = self.k[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim])
v = self.v[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim])

norm_hidden_states = F.scaled_dot_product_attention_(q, k, v, scale=self.head_dim**-0.5)
norm_hidden_states = norm_hidden_states.reshape(
[norm_hidden_states.shape[0], norm_hidden_states.shape[1], self.dim]
)
norm_hidden_states = norm_hidden_states.reshape([0, 0, self.dim])
norm_hidden_states = self.out_proj[i](norm_hidden_states)
if os.getenv("INFERENCE_OPTIMIZE_TRITON"):
hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(
hidden_states, norm_hidden_states, gate_msa, scale_mlp, shift_mlp, epsilon=1e-05
)
else:
hidden_states = hidden_states + norm_hidden_states * gate_msa.reshape(
[norm_hidden_states.shape[0], 1, self.dim]
)
norm_hidden_states = self.norm1(
hidden_states,
)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(
hidden_states, norm_hidden_states, gate_msa, scale_mlp, shift_mlp, epsilon=1e-05
)

norm_hidden_states = self.ffn1[i](norm_hidden_states)
norm_hidden_states = F.gelu(norm_hidden_states, approximate=True)
norm_hidden_states = self.ffn2[i](norm_hidden_states)

hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape(
[norm_hidden_states.shape[0], 1, self.dim]
)
last_ffn_output = norm_hidden_states
last_hidden_states = hidden_states
last_gate_mlp = gate_mlp

hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape(
[norm_hidden_states.shape[0], 1, self.dim]
)

return hidden_states
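The refactor above defers each layer's final residual add (`hidden_states + ffn_output * gate_mlp`) into the next layer's adaptive LayerNorm, so both can run as the single `fused_adaLN_scale_residual` Triton kernel; the loop-carried `last_*` variables hold the operands until then, and the last residual is applied after the loop. A pure-Paddle sketch of that control flow (the helper `adaLN` stands in for the Triton kernels; shapes are toy assumptions):

```python
import paddle
import paddle.nn.functional as F

def adaLN(x, scale, shift, eps=1e-06):
    # Eager stand-in for the fused Triton kernels.
    n = F.layer_norm(x, [x.shape[2]], epsilon=eps)
    return n * (1 + scale[:, None]) + shift[:, None]

num_layers, B, S, D = 3, 2, 16, 64
hidden_states = paddle.randn([B, S, D])
last_ffn_output = last_hidden_states = last_gate_mlp = None

for i in range(num_layers):
    scale_msa, shift_msa, gate_mlp = (paddle.randn([B, D]) for _ in range(3))
    if last_ffn_output is None:
        norm = adaLN(hidden_states, scale_msa, shift_msa)
    else:
        # Previous layer's residual add, fused with this layer's norm
        # in the Triton version.
        hidden_states = last_hidden_states + last_ffn_output * last_gate_mlp.unsqueeze(1)
        norm = adaLN(hidden_states, scale_msa, shift_msa)
    ffn_output = norm  # stand-in for this layer's attention + FFN stack
    last_ffn_output, last_hidden_states, last_gate_mlp = ffn_output, hidden_states, gate_mlp

# The final residual has no following layer to fuse into, so it is
# applied explicitly after the loop, as in the diff.
hidden_states = last_hidden_states + last_ffn_output * last_gate_mlp.unsqueeze(1)
```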
10 changes: 2 additions & 8 deletions ppdiffusers/ppdiffusers/models/transformer_2d.py
@@ -225,13 +225,6 @@ def __init__(
self.simplified_facebookdit = SimplifiedFacebookDIT(
num_layers, inner_dim, num_attention_heads, attention_head_dim
)
self.simplified_facebookdit = paddle.incubate.jit.inference(
self.simplified_facebookdit,
enable_new_ir=True,
cache_static_model=False,
exp_enable_use_cutlass=True,
delete_pass_lists=["add_norm_fuse_pass"],
)

# 4. Define output layers
self.out_channels = in_channels if out_channels is None else out_channels
@@ -498,7 +491,8 @@ def custom_forward(*inputs):
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = paddle.einsum("nhwpqc->nchpwq", hidden_states)
# hidden_states = paddle.einsum("nhwpqc->nchpwq", hidden_states)
hidden_states = hidden_states.transpose([0, 5, 1, 3, 2, 4])
output = hidden_states.reshape(
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
)
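The `einsum` → `transpose` swap above is a pure axis permutation: `"nhwpqc->nchpwq"` reorders input axes `(n, h, w, p, q, c)` to `(n, c, h, p, w, q)`, i.e. output axis order `[0, 5, 1, 3, 2, 4]`. A quick equivalence check (toy shape assumed):

```python
import paddle

x = paddle.randn([2, 4, 4, 2, 2, 8])  # (n, h, w, p, q, c)
a = paddle.einsum("nhwpqc->nchpwq", x)
b = x.transpose([0, 5, 1, 3, 2, 4])
assert bool(paddle.allclose(a, b))
print(a.shape)  # [2, 8, 4, 2, 4, 2]
```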
2 changes: 1 addition & 1 deletion ppdiffusers/ppdiffusers/patches/paddle_patch.py
@@ -463,7 +463,7 @@ def scaled_dot_product_attention_(
pre_cache_length=0,
).transpose([0, 2, 1, 3])
elif attention_op == "flash":
with requires_grad_and_without_random(query, key, value):
with requires_grad_and_without_random(query, key, value, stop_gradient=False):
output = paddle.nn.functional.scaled_dot_product_attention(
query,
key,
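For context on the `stop_gradient=False` argument above: in Paddle, `stop_gradient=True` (the default for data tensors) excludes a tensor from autograd, so `stop_gradient=False` is the analogue of `requires_grad=True` in other frameworks. An illustration of the flag itself (not of the patched helper):

```python
import paddle

x = paddle.randn([2, 3])  # created with stop_gradient=True by default
x.stop_gradient = False   # opt the tensor into autograd
y = (x * x).sum()
y.backward()
print(x.grad.shape)       # [2, 3]
```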
17 changes: 13 additions & 4 deletions ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py
@@ -191,9 +191,12 @@ def __call__(
]
)
# predict noise model_output
noise_pred = self.transformer(
latent_model_input, timestep=timesteps, class_labels=class_labels_input
).sample
noise_pred_out = self.transformer(latent_model_input, timestep=timesteps, class_labels=class_labels_input)
if paddle.incubate.jit.is_inference_mode(self.transformer):
# self.transformer runs in Paddle Inference.
noise_pred = noise_pred_out
else:
noise_pred = noise_pred_out.sample

# perform guidance
if guidance_scale > 1:
Expand Down Expand Up @@ -222,7 +225,13 @@ def __call__(
latents = latent_model_input

latents = 1 / self.vae.config.scaling_factor * latents
samples = self.vae.decode(latents).sample

samples_out = self.vae.decode(latents)
if paddle.incubate.jit.is_inference_mode(self.vae.decode):
# self.vae.decode runs in Paddle Inference.
samples = samples_out
else:
samples = samples_out.sample

samples = (samples / 2 + 0.5).clip(0, 1)

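Both call sites above use the same unwrap pattern: a module wrapped by `paddle.incubate.jit.inference` returns plain tensors, while the eager module returns a model-output object with a `.sample` field. A hypothetical helper capturing the pattern (`unwrap_output` is not part of this PR):

```python
import paddle

def unwrap_output(module, out):
    # jit.inference-wrapped callables return Tensors directly; eager
    # modules return an output dataclass exposing .sample.
    if paddle.incubate.jit.is_inference_mode(module):
        return out
    return out.sample

# e.g.
# noise_pred = unwrap_output(
#     self.transformer,
#     self.transformer(latent_model_input, timestep=timesteps, class_labels=class_labels_input),
# )
```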