
Commit 2182063

alpha0422 authored and mmarcinkiewicz committed
[Flux] Add cuda_graph_scope and cache image ids for full iteration cuda graph.
Signed-off-by: Wil Kong <[email protected]>
1 parent 7d5c2fa · commit 2182063

File tree

1 file changed: +3 −0 lines changed

  • nemo/collections/diffusion/models/flux/model.py

nemo/collections/diffusion/models/flux/model.py

Lines changed: 3 additions & 0 deletions
@@ -16,6 +16,7 @@
 import os
 from contextlib import nullcontext
 from dataclasses import dataclass, field
+from functools import lru_cache
 from pathlib import Path
 from typing import Callable, Optional
 
@@ -97,6 +98,7 @@ class FluxConfig(TransformerConfig, io.IOMixin):
     use_cpu_initialization: bool = True
     gradient_accumulation_fusion: bool = False
     enable_cuda_graph: bool = False
+    cuda_graph_scope: Optional[str] = None  # full, full_iteration
     use_te_rng_tracker: bool = False
     cuda_graph_warmup_steps: int = 2
 
@@ -731,6 +733,7 @@ def _unpack_latents(self, latents, height, width):
 
         return latents
 
+    @lru_cache
     def _prepare_latent_image_ids(
         self, batch_size: int, height: int, width: int, device: torch.device, dtype: torch.dtype
     ):
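
A note on the cached helper: with @lru_cache, _prepare_latent_image_ids computes its result once per distinct (self, batch_size, height, width, device, dtype) tuple and returns the identical tensor object on later calls, which keeps the image-id buffer stable across steps when the full iteration is captured as a CUDA graph. A minimal sketch of the pattern follows; the class name and the id layout (a halved-resolution grid with 3-component ids, as in common Flux latent packing) are illustrative assumptions, since the method body is not part of this diff.

from functools import lru_cache

import torch


class LatentIdCacheSketch:
    # Illustrative stand-in for the Flux model class; not the NeMo code.

    @lru_cache
    def _prepare_latent_image_ids(
        self, batch_size: int, height: int, width: int, device: torch.device, dtype: torch.dtype
    ):
        # batch_size is unused here but still participates in the cache key.
        # Build positional ids for the packed latent grid once; repeat calls
        # with the same arguments return the same cached tensor, so a
        # full-iteration CUDA graph replays against a stable buffer.
        ids = torch.zeros(height // 2, width // 2, 3, device=device, dtype=dtype)
        ids[..., 1] += torch.arange(height // 2, device=device, dtype=dtype)[:, None]
        ids[..., 2] += torch.arange(width // 2, device=device, dtype=dtype)[None, :]
        return ids.reshape(-1, 3)


pipe = LatentIdCacheSketch()
a = pipe._prepare_latent_image_ids(1, 64, 64, torch.device("cpu"), torch.float32)
b = pipe._prepare_latent_image_ids(1, 64, 64, torch.device("cpu"), torch.float32)
assert a is b  # the second call hits the cache and returns the identical object

One caveat of lru_cache on an instance method is that self is part of the cache key and is held by a strong reference, so cached entries keep the instance (and its device tensors) alive until the cache is cleared via cache_clear(). For CUDA graph capture that pinning is arguably the point: the cached ids must outlive the captured graph.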
