diff --git a/scripts/performance/argument_parser.py b/scripts/performance/argument_parser.py index 4cdee11b14fc..f7ea0adcc11f 100644 --- a/scripts/performance/argument_parser.py +++ b/scripts/performance/argument_parser.py @@ -301,6 +301,14 @@ def bool_arg(arg): required=False, default=None, ) + parser.add_argument( + "-fsdp_db", + "--use_fsdp_double_buffer", + help="Enable FSDP double buffer. Disabled by default", + type=bool_arg, + required=False, + default=None, + ) parser.add_argument( "-ubr", "--use_user_buffer_registration", diff --git a/scripts/performance/helpers.py b/scripts/performance/helpers.py index b09e75215ca8..80303a5574dc 100644 --- a/scripts/performance/helpers.py +++ b/scripts/performance/helpers.py @@ -313,6 +313,12 @@ def set_perf_optimization_configs( if use_fsdp_double_buffer: assert use_mcore_fsdp == True, "use_fsdp_double_buffer requires use_mcore_fsdp to be True" + if use_user_buffer_registration: + assert use_mcore_fsdp == True, "use_user_buffer_registration requires use_mcore_fsdp to be True" + assert ( + use_fsdp_double_buffer is not False + ), "use_fsdp_double_buffer cannot be False when use_user_buffer_registration is True" + if use_mcore_fsdp and enable_cuda_graphs: logging.warning("Currently, cuda graphs are not supported with FSDP. Disabling cuda graphs.") enable_cuda_graphs = False @@ -360,6 +366,7 @@ def set_primary_perf_configs( etp_size: Optional[int] = None, enable_cuda_graphs: bool = False, use_mcore_fsdp: bool = False, + use_fsdp_double_buffer: bool = False, use_user_buffer_registration: bool = False, use_sharp: bool = False, recompute_layers: int = 0,