@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 from dataclasses import dataclass
 from typing import Any, Dict, Optional
 
@@ -33,6 +34,7 @@
 from .lora import LoRACompatibleConv, LoRACompatibleLinear
 from .modeling_utils import ModelMixin
 from .normalization import AdaLayerNormSingle
+from .simplified_facebook_dit import SimplifiedFacebookDIT
 
 
 @dataclass
@@ -114,6 +116,8 @@ def __init__(
         self.inner_dim = inner_dim = num_attention_heads * attention_head_dim
         self.data_format = data_format
 
+        self.inference_optimize = os.getenv("INFERENCE_OPTIMIZE") == "True"
+
         conv_cls = nn.Conv2D if USE_PEFT_BACKEND else LoRACompatibleConv
         linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
 
@@ -217,6 +221,17 @@ def __init__(
                 for d in range(num_layers)
             ]
         )
+        if self.inference_optimize:
+            self.simplified_facebookdit = SimplifiedFacebookDIT(
+                num_layers, inner_dim, num_attention_heads, attention_head_dim
+            )
+            self.simplified_facebookdit = paddle.incubate.jit.inference(
+                self.simplified_facebookdit,
+                enable_new_ir=True,
+                cache_static_model=False,
+                exp_enable_use_cutlass=True,
+                delete_pass_lists=["add_norm_fuse_pass"],
+            )
 
         # 4. Define output layers
         self.out_channels = in_channels if out_channels is None else out_channels
@@ -392,40 +407,43 @@ def forward(
             encoder_hidden_states = self.caption_projection(encoder_hidden_states)
             encoder_hidden_states = encoder_hidden_states.reshape([batch_size, -1, hidden_states.shape[-1]])
 
-        for block in self.transformer_blocks:
-            if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute():
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False}
-                hidden_states = recompute(
-                    create_custom_forward(block),
-                    hidden_states,
-                    attention_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    timestep,
-                    cross_attention_kwargs,
-                    class_labels,
-                    **ckpt_kwargs,
-                )
-            else:
-                hidden_states = block(
-                    hidden_states,
-                    attention_mask=attention_mask,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    timestep=timestep,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    class_labels=class_labels,
-                )
+        if self.inference_optimize:
+            hidden_states = self.simplified_facebookdit(hidden_states, timestep, class_labels)
+        else:
+            for block in self.transformer_blocks:
+                if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute():
+
+                    def create_custom_forward(module, return_dict=None):
+                        def custom_forward(*inputs):
+                            if return_dict is not None:
+                                return module(*inputs, return_dict=return_dict)
+                            else:
+                                return module(*inputs)
+
+                        return custom_forward
+
+                    ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False}
+                    hidden_states = recompute(
+                        create_custom_forward(block),
+                        hidden_states,
+                        attention_mask,
+                        encoder_hidden_states,
+                        encoder_attention_mask,
+                        timestep,
+                        cross_attention_kwargs,
+                        class_labels,
+                        **ckpt_kwargs,
+                    )
+                else:
+                    hidden_states = block(
+                        hidden_states,
+                        attention_mask=attention_mask,
+                        encoder_hidden_states=encoder_hidden_states,
+                        encoder_attention_mask=encoder_attention_mask,
+                        timestep=timestep,
+                        cross_attention_kwargs=cross_attention_kwargs,
+                        class_labels=class_labels,
+                    )
 
         # 3. Output
         if self.is_input_continuous:
@@ -489,3 +507,32 @@ def custom_forward(*inputs):
             return (output,)
 
         return Transformer2DModelOutput(sample=output)
+
+    @classmethod
+    def custom_modify_weight(cls, state_dict):
+        if os.getenv("INFERENCE_OPTIMIZE") != "True":
+            return
+        for i in range(28):
+            map_from_my_dit = [
+                (f"q.{i}.weight", f"{i}.attn1.to_q.weight"),
+                (f"k.{i}.weight", f"{i}.attn1.to_k.weight"),
+                (f"v.{i}.weight", f"{i}.attn1.to_v.weight"),
+                (f"q.{i}.bias", f"{i}.attn1.to_q.bias"),
+                (f"k.{i}.bias", f"{i}.attn1.to_k.bias"),
+                (f"v.{i}.bias", f"{i}.attn1.to_v.bias"),
+                (f"out_proj.{i}.weight", f"{i}.attn1.to_out.0.weight"),
+                (f"out_proj.{i}.bias", f"{i}.attn1.to_out.0.bias"),
+                (f"ffn1.{i}.weight", f"{i}.ff.net.0.proj.weight"),
+                (f"ffn1.{i}.bias", f"{i}.ff.net.0.proj.bias"),
+                (f"ffn2.{i}.weight", f"{i}.ff.net.2.weight"),
+                (f"ffn2.{i}.bias", f"{i}.ff.net.2.bias"),
+                (f"fcs0.{i}.weight", f"{i}.norm1.emb.timestep_embedder.linear_1.weight"),
+                (f"fcs0.{i}.bias", f"{i}.norm1.emb.timestep_embedder.linear_1.bias"),
+                (f"fcs1.{i}.weight", f"{i}.norm1.emb.timestep_embedder.linear_2.weight"),
+                (f"fcs1.{i}.bias", f"{i}.norm1.emb.timestep_embedder.linear_2.bias"),
+                (f"fcs2.{i}.weight", f"{i}.norm1.linear.weight"),
+                (f"fcs2.{i}.bias", f"{i}.norm1.linear.bias"),
+                (f"embs.{i}.weight", f"{i}.norm1.emb.class_embedder.embedding_table.weight"),
+            ]
+            for to_, from_ in map_from_my_dit:
+                state_dict["simplified_facebookdit." + to_] = paddle.assign(state_dict["transformer_blocks." + from_])
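For anyone trying the fast path: `INFERENCE_OPTIMIZE` is read in `__init__`, so it must be exported before the transformer is constructed, and the weight remapping in `custom_modify_weight` is gated on the same variable. Below is a minimal usage sketch; the choice of `DiTPipeline`, the `facebook/DiT-XL-2-256` checkpoint, and the scheduler swap are illustrative assumptions, not something this diff prescribes, and they assume the ppdiffusers pipeline API mirrors the diffusers one.

```python
import os

# Must be set before the pipeline (and its Transformer2DModel) is built:
# __init__ reads INFERENCE_OPTIMIZE and only then creates and JIT-wraps
# the SimplifiedFacebookDIT fast path.
os.environ["INFERENCE_OPTIMIZE"] = "True"

import paddle
from ppdiffusers import DiTPipeline, DPMSolverMultistepScheduler

# Checkpoint and scheduler are illustrative choices, not mandated by this change.
pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", paddle_dtype=paddle.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

class_ids = pipe.get_label_ids(["golden retriever"])
image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
image.save("dit_inference_optimize.png")
```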
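The key mapping in `custom_modify_weight` also documents the parameter layout that `SimplifiedFacebookDIT` is expected to expose: per-block `q/k/v/out_proj` attention projections, `ffn1`/`ffn2` feed-forward weights, the adaLN timestep MLP (`fcs0`, `fcs1`), the shift/scale/gate projection (`fcs2`), and the class-embedding table (`embs`). The sketch below only mirrors that layout to make the renaming easier to follow; the sizes (256-dim timestep features, 4x FFN width, 6x adaLN chunks, 1001 class entries) are assumptions based on the stock facebook DiT configuration rather than taken from this diff, and the real module in `simplified_facebook_dit.py` also implements the fused forward pass.

```python
import paddle.nn as nn


class SimplifiedFacebookDITLayoutSketch(nn.Layer):
    """Illustrative parameter layout implied by custom_modify_weight; not the real module."""

    def __init__(self, num_layers, dim, num_attention_heads, attention_head_dim):
        super().__init__()
        timestep_dim = 256  # assumed width of the sinusoidal timestep features
        num_classes = 1001  # assumed: 1000 ImageNet classes + 1 null/dropout embedding

        def linear_list(d_in, d_out):
            return nn.LayerList([nn.Linear(d_in, d_out) for _ in range(num_layers)])

        # Self-attention projections, one per fused block ("q.{i}.*", "out_proj.{i}.*").
        self.q = linear_list(dim, dim)
        self.k = linear_list(dim, dim)
        self.v = linear_list(dim, dim)
        self.out_proj = linear_list(dim, dim)

        # Feed-forward weights ("ffn1.{i}.*", "ffn2.{i}.*").
        self.ffn1 = linear_list(dim, 4 * dim)
        self.ffn2 = linear_list(4 * dim, dim)

        # adaLN conditioning: timestep-embedder MLP, shift/scale/gate projection,
        # and the class-label embedding table ("fcs0/fcs1/fcs2/embs.{i}.*").
        self.fcs0 = linear_list(timestep_dim, dim)
        self.fcs1 = linear_list(dim, dim)
        self.fcs2 = linear_list(dim, 6 * dim)
        self.embs = nn.LayerList([nn.Embedding(num_classes, dim) for _ in range(num_layers)])
```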