File tree Expand file tree Collapse file tree
ppdiffusers/examples/class_conditional_image_generation/DiT/diffusion Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1616from .dit import DiT
1717from .dit_llama import DiT_Llama
1818from .respace import SpacedDiffusion , space_timesteps
19- from .trainer import LatentDiffusionTrainer , LatentDiffusionAutoTrainer
19+ from .trainer import LatentDiffusionTrainer
20+ try :
21+ from paddlenlp .trainer .auto_trainer import AutoTrainer
22+ from .trainer_auto import LatentDiffusionAutoTrainer
23+ except :
24+ print (f'please install paddlepaddle-gpu>=3.0.0b2 if using auto trainer' )
25+
2026from .trainer_args import (
2127 DataArguments ,
2228 ModelArguments ,
Original file line number Diff line number Diff line change 2323from paddle .distributed import fleet
2424from paddle .io import get_worker_info
2525from paddlenlp .trainer import Trainer
26- from paddlenlp .trainer .auto_trainer import AutoTrainer
2726from paddlenlp .trainer .integrations import (
2827 INTEGRATION_TO_CALLBACK ,
2928 TrainerCallback ,
@@ -295,20 +294,6 @@ def __impl__():
295294
296295 return __impl__
297296
298- class LatentDiffusionAutoTrainer (AutoTrainer ):
299- def __init__ (self , * args , ** kwargs ):
300- super ().__init__ (* args , ** kwargs )
301-
302- def _get_meshes_for_loader (self ):
303- def _get_mesh (pp_idx = 0 ):
304- return fleet .auto .get_mesh ().get_mesh_with_dim ("pp" )[pp_idx ]
305-
306- return _get_mesh (0 ) # label_id is not label
307-
308- def _wrap_for_dist_loader (self , train_dataloader ):
309- dist_loader = super ()._wrap_for_dist_loader (train_dataloader )
310- dist_loader ._input_keys = ["latents" , "label_id" ]
311- return dist_loader
312297
313298class LatentDiffusionTrainer (Trainer ):
314299 def __init__ (self , ** kwargs ):
You can’t perform that action at this time.
0 commit comments