Re-network the DIT, fix some parameters, and simplify the model networking code #632
Conversation
Thanks for your contribution!
if qkv is not None:
    state_dict[qkv_key_b] = paddle.concat([qkv, state_dict.pop(key)], axis=-1)

for key in list(state_dict.keys()):
Change line 518 and below to:
map_from_my_dit = {}
for i in range(28):
    map_from_my_dit[f'tmp_ZKKFacebookDIT.qkv.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_qkv.weight'
    map_from_my_dit[f'tmp_ZKKFacebookDIT.qkv.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_qkv.bias'
    map_from_my_dit[f'tmp_ZKKFacebookDIT.out_proj.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_out.0.weight'
    map_from_my_dit[f'tmp_ZKKFacebookDIT.out_proj.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_out.0.bias'
    map_from_my_dit[f'tmp_ZKKFacebookDIT.ffn1.{i}.weight'] = f'transformer_blocks.{i}.ff.net.0.proj.weight'
    map_from_my_dit[f'tmp_ZKKFacebookDIT.ffn1.{i}.bias'] = f'transformer_blocks.{i}.ff.net.0.proj.bias'
    map_from_my_dit[f'tmp_ZKKFacebookDIT.ffn2.{i}.weight'] = f'transformer_blocks.{i}.ff.net.2.weight'
    map_from_my_dit[f'tmp_ZKKFacebookDIT.ffn2.{i}.bias'] = f'transformer_blocks.{i}.ff.net.2.bias'
    map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs0.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.weight'
    map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs0.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.bias'
    map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs1.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.weight'
    map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs1.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.bias'
    map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs2.{i}.weight'] = f'transformer_blocks.{i}.norm1.linear.weight'
    map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs2.{i}.bias'] = f'transformer_blocks.{i}.norm1.linear.bias'
    map_from_my_dit[f'tmp_ZKKFacebookDIT.embs.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.class_embedder.embedding_table.weight'
for key in map_from_my_dit.keys():
    state_dict[key] = paddle.assign(state_dict[map_from_my_dit[key]])
Changed!
Thanks for the review feedback!
def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int):
    super().__init__()
    self.num_layers = num_layers
    self.dtype = "float16"
self.dtype = "float16"改成可配置的。
Changed!
Thanks for the review feedback!
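A minimal sketch of the configurable-dtype version (the class body here is abbreviated and illustrative, not the full PR implementation):

```python
import paddle.nn as nn

class SimplifiedFacebookDIT(nn.Layer):
    def __init__(self, num_layers: int, dim: int, num_attention_heads: int,
                 attention_head_dim: int, dtype: str = "float16"):
        super().__init__()
        self.num_layers = num_layers
        # dtype is now a constructor argument rather than a hard-coded "float16",
        # so callers can pass "float32" when half precision is not wanted.
        self.dtype = dtype
```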
error_msgs.append(
    f"Error size mismatch, {key_name} receives a shape {loaded_shape}, but the expected shape is {model_shape}."
)
if os.getenv('Inference_Optimize'):
Remove this here and do the check in transformer_2d.py instead.
Changed!
Thanks for the review feedback!
    recompute_use_reentrant,
    use_old_recompute,
)
from .simplified_facebook_dit import Simplified_FacebookDIT
Rename `Simplified_FacebookDIT` to `SimplifiedFacebookDIT`.
Changed!
Thanks for the review feedback!
]
)
if self.Inference_Optimize:
    self.simplified_facebookDIT = SimplifiedFacebookDIT(num_layers, inner_dim, num_attention_heads, attention_head_dim)
Add `del self.transformer_blocks` here.
This change causes errors because the blocks are still called from other places, so it is left unchanged for now.
Thanks for the review feedback!
self.inner_dim = inner_dim = num_attention_heads * attention_head_dim
self.data_format = data_format

self.Inference_Optimize = bool(os.getenv('Inference_Optimize'))
self.Inference_Optimize = os.getenv('Inference_Optimize') == "True"
Changed!
Thanks for the review feedback!
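For reference, a small sketch of why the string comparison matters; the helper name here is hypothetical:

```python
import os

def inference_optimize_enabled() -> bool:
    # bool(os.getenv("Inference_Optimize")) is True for ANY non-empty value,
    # including "False"; the explicit comparison only enables the flag when the
    # environment variable is literally the string "True".
    return os.getenv("Inference_Optimize", "False") == "True"

print(inference_optimize_enabled())  # False unless Inference_Optimize="True" is exported
```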
return
map_from_my_dit = {}
for i in range(28):
    map_from_my_dit[f'simplified_facebookDIT.q.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_q.weight'
Try to reduce code duplication; for example, shared name prefixes should be extracted so later changes only need to touch one place.
Changed; part of the naming code has been folded.
Thanks for the review feedback!
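A minimal sketch of the kind of folding meant here, assuming the same `state_dict` layout as the snippet above; only a couple of entries are shown and the helper name is illustrative:

```python
import paddle

def remap_simplified_dit_weights(state_dict, num_layers=28):
    mapping = {}
    for i in range(num_layers):
        src = f"simplified_facebookDIT.q.{i}"   # prefix of the optimized module's params
        dst = f"transformer_blocks.{i}.attn1"   # prefix of the original block's params
        mapping[f"{src}.weight"] = f"{dst}.to_q.weight"
        mapping[f"{src}.bias"] = f"{dst}.to_q.bias"
        # ... k/v, out_proj, ffn and norm entries follow the same src/dst pattern
    for new_key, old_key in mapping.items():
        state_dict[new_key] = paddle.assign(state_dict[old_key])
    return state_dict
```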
from ppdiffusers import DDIMScheduler, DiTPipeline

dtype = paddle.float32
os.environ["Inference_Optimize"] = "False"
Changed!
Thanks for the review feedback!
# warmup
for i in range(5):
    image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
This is only here to run a benchmark; real users don't need the warmup. Consider adding a benchmark switch.
Changed; added switches for benchmark & inference_optimize.
Thanks for the review feedback!
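A hedged sketch of such a switch; the BENCHMARK environment-variable name is an assumption, and `pipe`/`class_ids` are the objects built earlier in the demo script:

```python
import os
import time

def run(pipe, class_ids, repeat_times=5):
    benchmark = os.getenv("BENCHMARK", "False") == "True"
    if not benchmark:
        # normal users: a single generation, no warmup and no timing loop
        return pipe(class_labels=class_ids, num_inference_steps=25).images[0]

    # warmup so one-time graph capture / kernel caching does not skew the timing
    for _ in range(5):
        image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]

    start = time.time()
    for _ in range(repeat_times):
        image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
    print(f"mean latency: {(time.time() - start) / repeat_times * 1000:.3f} ms")
    return image
```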
import datetime
import time
Changed!
Thanks for the review feedback!
image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
for i in range(repeat_times):
    image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
Same as above: only needed for benchmarking, not for normal use.
Changed!
Thanks for the review feedback!
enable_new_ir=True,
cache_static_model=False,
exp_enable_use_cutlass=True,
delete_pass_lists=["add_norm_fuse_pass"],
Adjusted with pre-commit!
Thanks for the review feedback!
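For context, a sketch of how the arguments quoted above are typically passed to `paddle.incubate.jit.inference` to wrap the transformer for static-graph inference; the wrapper function here is illustrative and the exact call site in this PR may differ:

```python
import paddle

def wrap_transformer_for_inference(transformer):
    # converts the dynamic-graph transformer to a static graph for inference
    return paddle.incubate.jit.inference(
        transformer,
        enable_new_ir=True,                         # run on the new IR executor
        cache_static_model=False,                   # do not reuse a cached static model
        exp_enable_use_cutlass=True,                # experimental CUTLASS kernels
        delete_pass_lists=["add_norm_fuse_pass"],   # skip this fusion pass
    )
```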
self.inner_dim = inner_dim = num_attention_heads * attention_head_dim
self.data_format = data_format

self.Inference_Optimize = os.getenv('Inference_Optimize') == "True"
Use `self.inference_optimize` to follow the naming convention.
Changed!
Thanks for the review feedback!
import paddle.nn.functional as F
import math

class SimplifiedFacebookDIT(nn.Layer):
Is it really necessary to simplify this module?
It is needed for the manual optimization.
The manual optimization re-composes the original dynamic-graph model network into a high-performance, streamlined form; this module also hoists the redundant computation out of the transformer loop, reducing part of the compute.
Thanks for the review feedback!
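An illustrative sketch (not the exact PR code) of what hoisting loop-invariant work out of the transformer loop can look like: the timestep/class conditioning embedding is the same for every block, so it is computed once by the caller and each block only applies its per-layer adaLN projection:

```python
import paddle
import paddle.nn as nn

class SimplifiedBlocks(nn.Layer):
    def __init__(self, num_layers: int, dim: int):
        super().__init__()
        self.num_layers = num_layers
        self.adaln_proj = nn.LayerList([nn.Linear(dim, 2 * dim) for _ in range(num_layers)])

    def forward(self, hidden_states, emb):
        # emb: shared conditioning embedding, computed once outside this loop
        for i in range(self.num_layers):
            shift, scale = paddle.chunk(self.adaln_proj[i](emb), 2, axis=-1)
            hidden_states = hidden_states * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
            # ... attention and feed-forward for block i would follow here
        return hidden_states
```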
self.proj_out = linear_cls(inner_dim, in_channels)
else:
    self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0, data_format=data_format)
self.proj_out = conv_cls(
Please ignore the formatting-only changes.
pre-commit was used to unify the formatting!
Thanks for the review feedback!
self.in_channels = in_channels

self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, epsilon=1e-6, data_format=data_format)
self.norm = nn.GroupNorm(
self.proj_in = linear_cls(in_channels, inner_dim)
else:
    self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0, data_format=data_format)
self.proj_in = conv_cls(
Latest optimization: re-network the DIT, simplify the original dynamic-graph model into a high-performance model network, use paddle.incubate.jit.inference for dynamic-to-static conversion, and remove the redundant parts in the loop. Currently facebook-DIT takes: 219.936 ms.