Commit aeee830

Re-network the DIT, fix some parameters, and simplify the model networking code

2 parents: 279a720 + c4f8242

File tree

4 files changed (+280, -36 lines)

ppdiffusers/examples/inference/class_conditional_image_generation-dit.py

Lines changed: 57 additions & 2 deletions
@@ -12,19 +12,74 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import argparse
+import datetime
+import os
+
 import paddle
 from paddlenlp.trainer import set_seed

 from ppdiffusers import DDIMScheduler, DiTPipeline

-dtype = paddle.float32
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description=" Use PaddleMIX to accelerate the Diffusion Transformer image generation model."
+    )
+    parser.add_argument(
+        "--benchmark",
+        type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
+        default=False,
+        help="if benchmark is set to True, measure inference performance",
+    )
+    parser.add_argument(
+        "--inference_optimize",
+        type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
+        default=False,
+        help="If inference_optimize is set to True, all optimizations except Triton are enabled.",
+    )
+    parser.add_argument(
+        "--inference_optimize_triton",
+        type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
+        default=True,
+        help="If inference_optimize_triton is set to True, Triton operator optimized inference is enabled.",
+    )
+    return parser.parse_args()
+
+
+args = parse_args()
+
+if args.inference_optimize:
+    os.environ["INFERENCE_OPTIMIZE"] = "True"
+if args.inference_optimize_triton:
+    os.environ["INFERENCE_OPTIMIZE_TRITON"] = "True"
+
+dtype = paddle.float16
 pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", paddle_dtype=dtype)
 pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
 set_seed(42)

 words = ["golden retriever"]  # class_ids [207]
 class_ids = pipe.get_label_ids(words)
+image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]

+if args.benchmark:
+
+    # warmup
+    for i in range(5):
+        image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
+
+    repeat_times = 5
+
+    paddle.device.synchronize()
+    starttime = datetime.datetime.now()
+    for i in range(repeat_times):
+        image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
+    paddle.device.synchronize()
+    endtime = datetime.datetime.now()
+
+    duringtime = endtime - starttime
+    time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
+    print("The ave end to end time : ", time_ms / repeat_times, "ms")

-image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
 image.save("class_conditional_image_generation-dit-result.png")
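
The three boolean flags above drive the optimized path: a run such as
python class_conditional_image_generation-dit.py --inference_optimize True --inference_optimize_triton True --benchmark True
sets the INFERENCE_OPTIMIZE / INFERENCE_OPTIMIZE_TRITON environment variables before the pipeline is built, then times 5 warmup and 5 measured generations. Because the flags use a string-to-bool lambda instead of argparse's store_true action, only the values "true", "1" and "yes" (case-insensitive) enable them. A minimal, self-contained sketch of that parsing behavior (not part of the commit):

    import argparse

    # Same flag definition as in the script above: the type lambda maps
    # "true"/"1"/"yes" (any case) to True and every other string to False.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--benchmark",
        type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
        default=False,
    )

    print(parser.parse_args(["--benchmark", "True"]).benchmark)   # True
    print(parser.parse_args(["--benchmark", "yes"]).benchmark)    # True
    print(parser.parse_args(["--benchmark", "off"]).benchmark)    # False
    print(parser.parse_args([]).benchmark)                        # False (default)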

ppdiffusers/ppdiffusers/models/modeling_utils.py

Lines changed: 5 additions & 0 deletions
@@ -1050,6 +1050,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

         return model

+    @classmethod
+    def custom_modify_weight(cls, state_dict):
+        pass
+
     @classmethod
     def _load_pretrained_model(
         cls,
@@ -1130,6 +1134,7 @@ def _find_mismatched_keys(
                    error_msgs.append(
                        f"Error size mismatch, {key_name} receives a shape {loaded_shape}, but the expected shape is {model_shape}."
                    )
+            cls.custom_modify_weight(state_dict)
             faster_set_state_dict(model_to_load, state_dict)

         missing_keys = sorted(list(set(expected_keys) - set(loaded_keys)))
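
The new custom_modify_weight classmethod is a deliberately empty hook: _load_pretrained_model now calls it on the fully loaded state dict right before faster_set_state_dict copies the tensors into the model, so a subclass can rename or duplicate checkpoint keys at load time without touching the generic loading logic (Transformer2DModel does exactly that later in this commit). A hypothetical sketch of how a subclass could use the hook; the class name and keys are made up, and the import assumes ModelMixin is exported from the ppdiffusers package the same way diffusers exports it:

    import paddle
    from ppdiffusers import ModelMixin


    class MyFusedModel(ModelMixin):
        @classmethod
        def custom_modify_weight(cls, state_dict):
            # Alias an existing checkpoint tensor under a new key so a renamed
            # or fused submodule can pick it up; the dict is edited in place
            # before faster_set_state_dict pushes it into the model.
            if "proj.weight" in state_dict:
                state_dict["fused_proj.weight"] = paddle.assign(state_dict["proj.weight"])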

ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py

Lines changed: 137 additions & 0 deletions (new file)
@@ -0,0 +1,137 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import os
+
+import paddle
+import paddle.nn.functional as F
+from paddle import nn
+
+
+class SimplifiedFacebookDIT(nn.Layer):
+    def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int):
+        super().__init__()
+        self.num_layers = num_layers
+        self.dim = dim
+        self.heads_num = num_attention_heads
+        self.head_dim = attention_head_dim
+        self.timestep_embedder_in_channels = 256
+        self.timestep_embedder_time_embed_dim = 1152
+        self.timestep_embedder_time_embed_dim_out = self.timestep_embedder_time_embed_dim
+        self.LabelEmbedding_num_classes = 1001
+        self.LabelEmbedding_num_hidden_size = 1152
+
+        self.fcs0 = nn.LayerList(
+            [
+                nn.Linear(self.timestep_embedder_in_channels, self.timestep_embedder_time_embed_dim)
+                for i in range(num_layers)
+            ]
+        )
+
+        self.fcs1 = nn.LayerList(
+            [
+                nn.Linear(self.timestep_embedder_time_embed_dim, self.timestep_embedder_time_embed_dim_out)
+                for i in range(num_layers)
+            ]
+        )
+
+        self.fcs2 = nn.LayerList(
+            [
+                nn.Linear(self.timestep_embedder_time_embed_dim, 6 * self.timestep_embedder_time_embed_dim)
+                for i in range(num_layers)
+            ]
+        )
+
+        self.embs = nn.LayerList(
+            [
+                nn.Embedding(self.LabelEmbedding_num_classes, self.LabelEmbedding_num_hidden_size)
+                for i in range(num_layers)
+            ]
+        )
+
+        self.q = nn.LayerList([nn.Linear(dim, dim) for i in range(num_layers)])
+        self.k = nn.LayerList([nn.Linear(dim, dim) for i in range(num_layers)])
+        self.v = nn.LayerList([nn.Linear(dim, dim) for i in range(num_layers)])
+        self.out_proj = nn.LayerList([nn.Linear(dim, dim) for i in range(num_layers)])
+        self.ffn1 = nn.LayerList([nn.Linear(dim, dim * 4) for i in range(num_layers)])
+        self.ffn2 = nn.LayerList([nn.Linear(dim * 4, dim) for i in range(num_layers)])
+        self.norm = nn.LayerNorm(1152, epsilon=1e-06, weight_attr=False, bias_attr=False)
+        self.norm1 = nn.LayerNorm(1152, epsilon=1e-05, weight_attr=False, bias_attr=False)
+
+    def forward(self, hidden_states, timesteps, class_labels):
+
+        # below code are copied from PaddleMIX/ppdiffusers/ppdiffusers/models/embeddings.py
+        num_channels = 256
+        max_period = 10000
+        downscale_freq_shift = 1
+        half_dim = num_channels // 2
+        exponent = -math.log(max_period) * paddle.arange(start=0, end=half_dim, dtype="float32")
+        exponent = exponent / (half_dim - downscale_freq_shift)
+        emb = paddle.exp(exponent)
+        emb = timesteps[:, None].cast("float32") * emb[None, :]
+        emb = paddle.concat([paddle.cos(emb), paddle.sin(emb)], axis=-1)
+        common_emb = emb.cast(hidden_states.dtype)
+
+        for i in range(self.num_layers):
+            emb = self.fcs0[i](common_emb)
+            emb = F.silu(emb)
+            emb = self.fcs1[i](emb)
+            emb = emb + self.embs[i](class_labels)
+            emb = F.silu(emb)
+            emb = self.fcs2[i](emb)
+            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1)
+            import paddlemix
+
+            if os.getenv("INFERENCE_OPTIMIZE_TRITON"):
+                norm_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(
+                    hidden_states, scale_msa, shift_msa, epsilon=1e-06
+                )
+            else:
+                norm_hidden_states = self.norm(
+                    hidden_states,
+                )
+                norm_hidden_states = norm_hidden_states * (1 + scale_msa[:, None]) + shift_msa[:, None]
+
+            q = self.q[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim])
+            k = self.k[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim])
+            v = self.v[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim])
+
+            norm_hidden_states = F.scaled_dot_product_attention_(q, k, v, scale=self.head_dim**-0.5)
+            norm_hidden_states = norm_hidden_states.reshape(
+                [norm_hidden_states.shape[0], norm_hidden_states.shape[1], self.dim]
+            )
+            norm_hidden_states = self.out_proj[i](norm_hidden_states)
+            if os.getenv("INFERENCE_OPTIMIZE_TRITON"):
+                hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(
+                    hidden_states, norm_hidden_states, gate_msa, scale_mlp, shift_mlp, epsilon=1e-05
+                )
+            else:
+                hidden_states = hidden_states + norm_hidden_states * gate_msa.reshape(
+                    [norm_hidden_states.shape[0], 1, self.dim]
+                )
+                norm_hidden_states = self.norm1(
+                    hidden_states,
+                )
+                norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+            norm_hidden_states = self.ffn1[i](norm_hidden_states)
+            norm_hidden_states = F.gelu(norm_hidden_states, approximate=True)
+            norm_hidden_states = self.ffn2[i](norm_hidden_states)
+
+            hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape(
+                [norm_hidden_states.shape[0], 1, self.dim]
+            )
+
+        return hidden_states
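
SimplifiedFacebookDIT flattens the per-block adaLN-Zero DiT computation (timestep embedding, class embedding, modulation, self-attention, feed-forward) into plain LayerLists so the whole stack can be swapped in as one module and captured by paddle.incubate.jit.inference in transformer_2d.py below. A minimal shape-check sketch, not part of the commit: the sizes (28 layers, width 1152, 16 heads of 72 dims, 256 tokens) are inferred from the hard-coded 1152 and the range(28) remapping loop rather than stated in the diff, the import path is assumed from the relative import used later, and running it requires a Paddle build that provides F.scaled_dot_product_attention_ plus an installed paddlemix (the module imports it inside forward even when the Triton branch is disabled):

    import paddle
    from ppdiffusers.models.simplified_facebook_dit import SimplifiedFacebookDIT

    # Assumed DiT-XL/2-like configuration.
    block = SimplifiedFacebookDIT(num_layers=28, dim=1152, num_attention_heads=16, attention_head_dim=72)

    hidden_states = paddle.randn([2, 256, 1152])                 # [batch, tokens, width]
    timesteps = paddle.to_tensor([999, 500], dtype="int64")      # one timestep per sample
    class_labels = paddle.to_tensor([207, 207], dtype="int64")   # ImageNet class ids

    out = block(hidden_states, timesteps, class_labels)
    print(out.shape)  # [2, 256, 1152]: the fused stack is shape-preserving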

ppdiffusers/ppdiffusers/models/transformer_2d.py

Lines changed: 81 additions & 34 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 from dataclasses import dataclass
 from typing import Any, Dict, Optional

@@ -33,6 +34,7 @@
 from .lora import LoRACompatibleConv, LoRACompatibleLinear
 from .modeling_utils import ModelMixin
 from .normalization import AdaLayerNormSingle
+from .simplified_facebook_dit import SimplifiedFacebookDIT


 @dataclass
@@ -114,6 +116,8 @@ def __init__(
         self.inner_dim = inner_dim = num_attention_heads * attention_head_dim
         self.data_format = data_format

+        self.inference_optimize = os.getenv("INFERENCE_OPTIMIZE") == "True"
+
         conv_cls = nn.Conv2D if USE_PEFT_BACKEND else LoRACompatibleConv
         linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear

@@ -217,6 +221,17 @@ def __init__(
                 for d in range(num_layers)
             ]
         )
+        if self.inference_optimize:
+            self.simplified_facebookdit = SimplifiedFacebookDIT(
+                num_layers, inner_dim, num_attention_heads, attention_head_dim
+            )
+            self.simplified_facebookdit = paddle.incubate.jit.inference(
+                self.simplified_facebookdit,
+                enable_new_ir=True,
+                cache_static_model=False,
+                exp_enable_use_cutlass=True,
+                delete_pass_lists=["add_norm_fuse_pass"],
+            )

         # 4. Define output layers
         self.out_channels = in_channels if out_channels is None else out_channels
@@ -392,40 +407,43 @@ def forward(
             encoder_hidden_states = self.caption_projection(encoder_hidden_states)
             encoder_hidden_states = encoder_hidden_states.reshape([batch_size, -1, hidden_states.shape[-1]])

-        for block in self.transformer_blocks:
-            if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute():
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False}
-                hidden_states = recompute(
-                    create_custom_forward(block),
-                    hidden_states,
-                    attention_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    timestep,
-                    cross_attention_kwargs,
-                    class_labels,
-                    **ckpt_kwargs,
-                )
-            else:
-                hidden_states = block(
-                    hidden_states,
-                    attention_mask=attention_mask,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    timestep=timestep,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    class_labels=class_labels,
-                )
+        if self.inference_optimize:
+            hidden_states = self.simplified_facebookdit(hidden_states, timestep, class_labels)
+        else:
+            for block in self.transformer_blocks:
+                if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute():
+
+                    def create_custom_forward(module, return_dict=None):
+                        def custom_forward(*inputs):
+                            if return_dict is not None:
+                                return module(*inputs, return_dict=return_dict)
+                            else:
+                                return module(*inputs)
+
+                        return custom_forward
+
+                    ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False}
+                    hidden_states = recompute(
+                        create_custom_forward(block),
+                        hidden_states,
+                        attention_mask,
+                        encoder_hidden_states,
+                        encoder_attention_mask,
+                        timestep,
+                        cross_attention_kwargs,
+                        class_labels,
+                        **ckpt_kwargs,
+                    )
+                else:
+                    hidden_states = block(
+                        hidden_states,
+                        attention_mask=attention_mask,
+                        encoder_hidden_states=encoder_hidden_states,
+                        encoder_attention_mask=encoder_attention_mask,
+                        timestep=timestep,
+                        cross_attention_kwargs=cross_attention_kwargs,
+                        class_labels=class_labels,
+                    )

         # 3. Output
         if self.is_input_continuous:
@@ -489,3 +507,32 @@ def custom_forward(*inputs):
             return (output,)

         return Transformer2DModelOutput(sample=output)
+
+    @classmethod
+    def custom_modify_weight(cls, state_dict):
+        if os.getenv("INFERENCE_OPTIMIZE") != "True":
+            return
+        for i in range(28):
+            map_from_my_dit = [
+                (f"q.{i}.weight", f"{i}.attn1.to_q.weight"),
+                (f"k.{i}.weight", f"{i}.attn1.to_k.weight"),
+                (f"v.{i}.weight", f"{i}.attn1.to_v.weight"),
+                (f"q.{i}.bias", f"{i}.attn1.to_q.bias"),
+                (f"k.{i}.bias", f"{i}.attn1.to_k.bias"),
+                (f"v.{i}.bias", f"{i}.attn1.to_v.bias"),
+                (f"out_proj.{i}.weight", f"{i}.attn1.to_out.0.weight"),
+                (f"out_proj.{i}.bias", f"{i}.attn1.to_out.0.bias"),
+                (f"ffn1.{i}.weight", f"{i}.ff.net.0.proj.weight"),
+                (f"ffn1.{i}.bias", f"{i}.ff.net.0.proj.bias"),
+                (f"ffn2.{i}.weight", f"{i}.ff.net.2.weight"),
+                (f"ffn2.{i}.bias", f"{i}.ff.net.2.bias"),
+                (f"fcs0.{i}.weight", f"{i}.norm1.emb.timestep_embedder.linear_1.weight"),
+                (f"fcs0.{i}.bias", f"{i}.norm1.emb.timestep_embedder.linear_1.bias"),
+                (f"fcs1.{i}.weight", f"{i}.norm1.emb.timestep_embedder.linear_2.weight"),
+                (f"fcs1.{i}.bias", f"{i}.norm1.emb.timestep_embedder.linear_2.bias"),
+                (f"fcs2.{i}.weight", f"{i}.norm1.linear.weight"),
+                (f"fcs2.{i}.bias", f"{i}.norm1.linear.bias"),
+                (f"embs.{i}.weight", f"{i}.norm1.emb.class_embedder.embedding_table.weight"),
+            ]
+            for to_, from_ in map_from_my_dit:
+                state_dict["simplified_facebookdit." + to_] = paddle.assign(state_dict["transformer_blocks." + from_])
