Re-network the DIT, fix some parameters, and simplify the model networking code #632
Changes from 11 commits
File: inference demo / benchmark script

@@ -12,12 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import paddle
from paddlenlp.trainer import set_seed

from ppdiffusers import DDIMScheduler, DiTPipeline

dtype = paddle.float32
Inference_Optimize = True
if Inference_Optimize:
    os.environ["Inference_Optimize"] = "True"
else:
    pass

dtype = paddle.float16
pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", paddle_dtype=dtype)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
set_seed(42)

@@ -27,4 +34,35 @@
image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
image.save("class_conditional_image_generation-dit-result.png")
# image.save("class_conditional_image_generation-dit-result.png")
image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]

import datetime
import time

warm_up_times = 5
repeat_times = 10
sum_time = 0.

for i in range(repeat_times):
    paddle.device.synchronize()
    starttime = datetime.datetime.now()
    image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
    paddle.device.synchronize()
    endtime = datetime.datetime.now()
    duringtime = endtime - starttime
    time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
    evet = "every_time: " + str(time_ms) + "ms\n\n"
    with open("/cwb/wenbin/PaddleMIX/ppdiffusers/examples/inference/Aibin/time_729.txt", "a") as time_file:
        time_file.write(evet)
    sum_time += time_ms

print("The ave end to end time : ", sum_time / repeat_times, "ms")
msg = "average_time: " + str(sum_time / repeat_times) + "ms\n\n"
print(msg)
with open("/cwb/wenbin/PaddleMIX/ppdiffusers/examples/inference/Aibin/time_729.txt", "a") as time_file:
    time_file.write(msg)

image.save("class_conditional_image_generation-dit-29.png")
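The timing loop above measures end-to-end latency with `datetime` and appends results to a hard-coded log path. As a point of comparison only, here is a minimal sketch of the same warm-up / synchronize / measure pattern using `time.perf_counter`; the `benchmark` helper name and its defaults are illustrative and not part of this PR.

```python
# Hypothetical helper (not part of the PR): times a GPU callable by
# synchronizing before starting and after stopping the clock, as the loop above does.
import time

import paddle


def benchmark(fn, warm_up_times=5, repeat_times=10):
    for _ in range(warm_up_times):   # discard warm-up runs (compilation, caches)
        fn()
    paddle.device.synchronize()      # make sure warm-up GPU work has finished
    total_ms = 0.0
    for _ in range(repeat_times):
        start = time.perf_counter()
        fn()
        paddle.device.synchronize()  # wait for the GPU before stopping the clock
        total_ms += (time.perf_counter() - start) * 1000.0
    return total_ms / repeat_times


# Usage, assuming `pipe` and `class_ids` are defined as in the script above:
# avg_ms = benchmark(lambda: pipe(class_labels=class_ids, num_inference_steps=25).images[0])
# print("average_time: %.2f ms" % avg_ms)
```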
File: model loading utilities (from_pretrained / _load_pretrained_model)

@@ -1050,6 +1050,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
        return model

    @classmethod
    def custom_modify_weight(cls, state_dict):
        pass

    @classmethod
    def _load_pretrained_model(
        cls,

@@ -1130,6 +1134,8 @@ def _find_mismatched_keys(
                    error_msgs.append(
                        f"Error size mismatch, {key_name} receives a shape {loaded_shape}, but the expected shape is {model_shape}."
                    )
        if os.getenv('Inference_Optimize'):
            cls.custom_modify_weight(state_dict)
        faster_set_state_dict(model_to_load, state_dict)

        missing_keys = sorted(list(set(expected_keys) - set(loaded_keys)))
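The `custom_modify_weight` classmethod added above is a no-op hook: when the `Inference_Optimize` environment variable is set, `_load_pretrained_model` calls it on the loaded `state_dict` right before `faster_set_state_dict`, so a subclass can rewrite weights into the layout its optimized modules expect. A minimal sketch of such an override follows; the subclass, the key names, and the import path are assumptions for illustration, not code from this PR.

```python
# Hypothetical subclass: fuse two projection weights under a single key so an
# optimized fused Linear can load them. Only the hook itself is sketched here.
import paddle

from ppdiffusers.models.modeling_utils import ModelMixin  # assumed import path


class MyOptimizedModel(ModelMixin):
    @classmethod
    def custom_modify_weight(cls, state_dict):
        # Runs only when os.getenv('Inference_Optimize') is truthy (see the hunk above).
        if "proj_a.weight" in state_dict and "proj_b.weight" in state_dict:
            state_dict["proj_ab.weight"] = paddle.concat(
                [state_dict.pop("proj_a.weight"), state_dict.pop("proj_b.weight")],
                axis=-1,
            )
```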
File: simplified_facebook_dit.py (new file)

@@ -0,0 +1,88 @@
from paddle import nn
import paddle
import paddle.nn.functional as F
import math


class Simplified_FacebookDIT(nn.Layer):
    def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int):
        super().__init__()
        self.num_layers = num_layers
        self.dtype = "float16"
        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.timestep_embedder_in_channels = 256
        self.timestep_embedder_time_embed_dim = 1152
        self.timestep_embedder_time_embed_dim_out = self.timestep_embedder_time_embed_dim
        self.LabelEmbedding_num_classes = 1001
        self.LabelEmbedding_num_hidden_size = 1152

        self.fcs0 = nn.LayerList([nn.Linear(self.timestep_embedder_in_channels,
                                            self.timestep_embedder_time_embed_dim) for i in range(self.num_layers)])
        self.fcs1 = nn.LayerList([nn.Linear(self.timestep_embedder_time_embed_dim,
                                            self.timestep_embedder_time_embed_dim_out) for i in range(self.num_layers)])
        self.fcs2 = nn.LayerList([nn.Linear(self.timestep_embedder_time_embed_dim,
                                            6 * self.timestep_embedder_time_embed_dim) for i in range(self.num_layers)])
        self.embs = nn.LayerList([nn.Embedding(self.LabelEmbedding_num_classes,
                                               self.LabelEmbedding_num_hidden_size) for i in range(self.num_layers)])

        self.qkv = nn.LayerList([nn.Linear(dim, dim * 3) for i in range(self.num_layers)])
        self.out_proj = nn.LayerList([nn.Linear(dim, dim) for i in range(self.num_layers)])
        self.ffn1 = nn.LayerList([nn.Linear(dim, dim * 4) for i in range(self.num_layers)])
        self.ffn2 = nn.LayerList([nn.Linear(dim * 4, dim) for i in range(self.num_layers)])

    @paddle.incubate.jit.inference(enable_new_ir=True,
                                   cache_static_model=True,
                                   exp_enable_use_cutlass=True,
                                   delete_pass_lists=["add_norm_fuse_pass"],
                                   )
    def forward(self, hidden_states, timesteps, class_labels):
        # below code are copied from PaddleMIX/ppdiffusers/ppdiffusers/models/embeddings.py
        num_channels = 256
        max_period = 10000
        downscale_freq_shift = 1
        half_dim = num_channels // 2
        exponent = -math.log(max_period) * paddle.arange(start=0, end=half_dim, dtype="float32")
        exponent = exponent / (half_dim - downscale_freq_shift)
        emb = paddle.exp(exponent)
        emb = timesteps[:, None].cast("float32") * emb[None, :]
        emb = paddle.concat([paddle.cos(emb), paddle.sin(emb)], axis=-1)
        common_emb = emb.cast(self.dtype)

        for i in range(self.num_layers):
            emb = self.fcs0[i](common_emb)
            emb = F.silu(emb)
            emb = self.fcs1[i](emb)
            emb = emb + self.embs[i](class_labels)
            emb = F.silu(emb)
            emb = self.fcs2[i](emb)
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1)
            import paddlemix
            norm_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(hidden_states, scale_msa, shift_msa)
            q, k, v = self.qkv[i](norm_hidden_states).chunk(3, axis=-1)
            b, s, h = q.shape
            q = q.reshape([b, s, self.num_attention_heads, self.attention_head_dim])
            k = k.reshape([b, s, self.num_attention_heads, self.attention_head_dim])
            v = v.reshape([b, s, self.num_attention_heads, self.attention_head_dim])

            norm_hidden_states = F.scaled_dot_product_attention_(q, k, v, scale=self.attention_head_dim**-0.5)
            norm_hidden_states = norm_hidden_states.reshape([b, s, self.dim])
            norm_hidden_states = self.out_proj[i](norm_hidden_states)

            # hidden_states = hidden_states + norm_hidden_states * gate_msa.reshape([b, 1, self.dim])
            # norm_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(hidden_states, scale_mlp, shift_mlp)
            hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(hidden_states, norm_hidden_states, gate_msa, scale_mlp, shift_mlp)

            norm_hidden_states = self.ffn1[i](norm_hidden_states)
            norm_hidden_states = F.gelu(norm_hidden_states, approximate=True)
            norm_hidden_states = self.ffn2[i](norm_hidden_states)

            hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape([b, 1, self.dim])

        return hidden_states
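For orientation, a construction-only sketch of the new module follows. The sizes are the DiT-XL/2 values implied by the weight-mapping loop in the Transformer2DModel changes below (28 layers; 16 heads of dimension 72, so an inner dim of 1152); the import path is inferred from the relative import used there, and both the sizes and the path should be treated as assumptions. Running the forward pass additionally requires `paddlemix` triton ops and a GPU.

```python
# Illustrative construction only; sizes assumed from the DiT-XL/2 configuration.
from ppdiffusers.models.simplified_facebook_dit import Simplified_FacebookDIT  # assumed path

model = Simplified_FacebookDIT(
    num_layers=28,           # matches range(28) in custom_modify_weight below
    dim=1152,                # inner_dim = num_attention_heads * attention_head_dim
    num_attention_heads=16,
    attention_head_dim=72,
)
# Expected forward inputs (per the code above):
#   hidden_states: [batch, seq_len, 1152], float16
#   timesteps:     [batch], integer tensor
#   class_labels:  [batch], integer class ids (the embedding table has 1001 entries)
```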
File: Transformer2DModel definition

@@ -28,11 +28,15 @@
    recompute_use_reentrant,
    use_old_recompute,
)
from .simplified_facebook_dit import Simplified_FacebookDIT

from .attention import BasicTransformerBlock
from .embeddings import CaptionProjection, PatchEmbed
from .lora import LoRACompatibleConv, LoRACompatibleLinear
from .modeling_utils import ModelMixin
from .normalization import AdaLayerNormSingle
import os


@dataclass

@@ -114,6 +118,8 @@ def __init__(
        self.inner_dim = inner_dim = num_attention_heads * attention_head_dim
        self.data_format = data_format

        self.Inference_Optimize = bool(os.getenv('Inference_Optimize'))

        conv_cls = nn.Conv2D if USE_PEFT_BACKEND else LoRACompatibleConv
        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear

@@ -213,6 +219,8 @@ def __init__(
                for d in range(num_layers)
            ]
        )
        if self.Inference_Optimize:
            self.Simplified_FacebookDIT = Simplified_FacebookDIT(num_layers, inner_dim, num_attention_heads, attention_head_dim)

        # 4. Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels

@@ -250,6 +258,7 @@ def __init__(
            self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):

@@ -384,41 +393,44 @@ def forward(
        batch_size = hidden_states.shape[0]
        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states.reshape([batch_size, -1, hidden_states.shape[-1]])

(removed by this PR: the previous unconditional loop over self.transformer_blocks)

        for block in self.transformer_blocks:
            if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute():

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False}
                hidden_states = recompute(
                    create_custom_forward(block),
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    cross_attention_kwargs,
                    class_labels,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=class_labels,
                )

(added by this PR: dispatch to the fused Simplified_FacebookDIT path when Inference_Optimize is set, otherwise run the original per-block loop)

        if self.Inference_Optimize:
            hidden_states = self.Simplified_FacebookDIT(hidden_states, timestep, class_labels)
        else:
            for block in self.transformer_blocks:
                if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute():

                    def create_custom_forward(module, return_dict=None):
                        def custom_forward(*inputs):
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                return module(*inputs)

                        return custom_forward

                    ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False}
                    hidden_states = recompute(
                        create_custom_forward(block),
                        hidden_states,
                        attention_mask,
                        encoder_hidden_states,
                        encoder_attention_mask,
                        timestep,
                        cross_attention_kwargs,
                        class_labels,
                        **ckpt_kwargs,
                    )
                else:
                    hidden_states = block(
                        hidden_states,
                        attention_mask=attention_mask,
                        encoder_hidden_states=encoder_hidden_states,
                        encoder_attention_mask=encoder_attention_mask,
                        timestep=timestep,
                        cross_attention_kwargs=cross_attention_kwargs,
                        class_labels=class_labels,
                    )

        # 3. Output
        if self.is_input_continuous:

@@ -482,3 +494,51 @@ def custom_forward(*inputs):
            return (output,)

        return Transformer2DModelOutput(sample=output)

    @classmethod
    def custom_modify_weight(cls, state_dict):
        for key in list(state_dict.keys()):
            if 'attn1.to_q.weight' in key or 'attn1.to_k.weight' in key or 'attn1.to_v.weight' in key:
                part = key.split('.')[-2]
                layer_id = key.split('.')[1]
                qkv_key_w = f'transformer_blocks.{layer_id}.attn1.to_qkv.weight'
                if part == 'to_q' and qkv_key_w not in state_dict:
                    state_dict[qkv_key_w] = state_dict.pop(key)
                elif part in ('to_k', 'to_v'):
                    qkv = state_dict.get(qkv_key_w)
                    if qkv is not None:
                        state_dict[qkv_key_w] = paddle.concat([qkv, state_dict.pop(key)], axis=-1)
            if 'attn1.to_q.bias' in key or 'attn1.to_k.bias' in key or 'attn1.to_v.bias' in key:
                part = key.split('.')[-2]
                layer_id = key.split('.')[1]
                qkv_key_b = f'transformer_blocks.{layer_id}.attn1.to_qkv.bias'
                if part == 'to_q' and qkv_key_b not in state_dict:
                    state_dict[qkv_key_b] = state_dict.pop(key)
                elif part in ('to_k', 'to_v'):
                    qkv = state_dict.get(qkv_key_b)
                    if qkv is not None:
                        state_dict[qkv_key_b] = paddle.concat([qkv, state_dict.pop(key)], axis=-1)

        map_from_my_dit = {}
        for i in range(28):
            map_from_my_dit[f'Simplified_FacebookDIT.qkv.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_qkv.weight'
            map_from_my_dit[f'Simplified_FacebookDIT.qkv.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_qkv.bias'
            map_from_my_dit[f'Simplified_FacebookDIT.out_proj.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_out.0.weight'
            map_from_my_dit[f'Simplified_FacebookDIT.out_proj.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_out.0.bias'
            map_from_my_dit[f'Simplified_FacebookDIT.ffn1.{i}.weight'] = f'transformer_blocks.{i}.ff.net.0.proj.weight'
            map_from_my_dit[f'Simplified_FacebookDIT.ffn1.{i}.bias'] = f'transformer_blocks.{i}.ff.net.0.proj.bias'
            map_from_my_dit[f'Simplified_FacebookDIT.ffn2.{i}.weight'] = f'transformer_blocks.{i}.ff.net.2.weight'
            map_from_my_dit[f'Simplified_FacebookDIT.ffn2.{i}.bias'] = f'transformer_blocks.{i}.ff.net.2.bias'

            map_from_my_dit[f'Simplified_FacebookDIT.fcs0.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.weight'
            map_from_my_dit[f'Simplified_FacebookDIT.fcs0.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.bias'
            map_from_my_dit[f'Simplified_FacebookDIT.fcs1.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.weight'
            map_from_my_dit[f'Simplified_FacebookDIT.fcs1.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.bias'
            map_from_my_dit[f'Simplified_FacebookDIT.fcs2.{i}.weight'] = f'transformer_blocks.{i}.norm1.linear.weight'
            map_from_my_dit[f'Simplified_FacebookDIT.fcs2.{i}.bias'] = f'transformer_blocks.{i}.norm1.linear.bias'

            map_from_my_dit[f'Simplified_FacebookDIT.embs.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.class_embedder.embedding_table.weight'

        for key in map_from_my_dit.keys():
            state_dict[key] = paddle.assign(state_dict[map_from_my_dit[key]])
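The `custom_modify_weight` override above builds fused `to_qkv` tensors by concatenating the separate `to_q` / `to_k` / `to_v` weights and biases along the last axis, which lines up with Paddle's `nn.Linear` storing weights as `[in_features, out_features]`. The following self-contained check (arbitrary sizes, not taken from the PR) illustrates why the fused projection, chunked back into three, reproduces the separate projections.

```python
# Verify that concatenating q/k/v weights along the last axis gives a single
# Linear whose output, chunked into three, matches the separate projections.
import paddle
from paddle import nn

dim = 8
x = paddle.randn([2, 4, dim])
to_q, to_k, to_v = nn.Linear(dim, dim), nn.Linear(dim, dim), nn.Linear(dim, dim)

to_qkv = nn.Linear(dim, dim * 3)
to_qkv.weight.set_value(paddle.concat([to_q.weight, to_k.weight, to_v.weight], axis=-1))
to_qkv.bias.set_value(paddle.concat([to_q.bias, to_k.bias, to_v.bias], axis=-1))

q, k, v = to_qkv(x).chunk(3, axis=-1)
assert paddle.allclose(q, to_q(x), atol=1e-5).item()
assert paddle.allclose(k, to_k(x), atol=1e-5).item()
assert paddle.allclose(v, to_v(x), atol=1e-5).item()
```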
Review comment: please move the import to the top of the file.
Reply: Changed! Thanks for the review suggestions, much appreciated.