diff --git a/torchrec/distributed/benchmark/benchmark_base.py b/torchrec/distributed/benchmark/base.py similarity index 100% rename from torchrec/distributed/benchmark/benchmark_base.py rename to torchrec/distributed/benchmark/base.py diff --git a/torchrec/distributed/benchmark/benchmark_inference.py b/torchrec/distributed/benchmark/benchmark_inference.py index ff7bcf6c4..5ae91d66f 100644 --- a/torchrec/distributed/benchmark/benchmark_inference.py +++ b/torchrec/distributed/benchmark/benchmark_inference.py @@ -18,7 +18,7 @@ import torch -from torchrec.distributed.benchmark.benchmark_base import ( +from torchrec.distributed.benchmark.base import ( BenchmarkResult, CompileMode, DLRM_NUM_EMBEDDINGS_PER_FEATURE, diff --git a/torchrec/distributed/benchmark/benchmark_set_sharding_context_post_a2a.py b/torchrec/distributed/benchmark/benchmark_set_sharding_context_post_a2a.py index 4f9060791..a269fb807 100644 --- a/torchrec/distributed/benchmark/benchmark_set_sharding_context_post_a2a.py +++ b/torchrec/distributed/benchmark/benchmark_set_sharding_context_post_a2a.py @@ -13,7 +13,7 @@ import click import torch -from torchrec.distributed.benchmark.benchmark_base import benchmark_func +from torchrec.distributed.benchmark.base import benchmark_func from torchrec.distributed.embedding import EmbeddingCollectionContext from torchrec.distributed.embedding_sharding import _set_sharding_context_post_a2a from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext diff --git a/torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py b/torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py index df4d6d6e7..08cac5db1 100644 --- a/torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py +++ b/torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py @@ -26,7 +26,7 @@ ComputeDevice, SplitTableBatchedEmbeddingBagsCodegen, ) -from torchrec.distributed.benchmark.benchmark_base import benchmark_func +from torchrec.distributed.benchmark.base import benchmark_func from torchrec.distributed.test_utils.test_model import ModelInput from torchrec.modules.embedding_configs import EmbeddingBagConfig from torchrec.sparse.jagged_tensor import KeyedJaggedTensor diff --git a/torchrec/distributed/benchmark/benchmark_train.py b/torchrec/distributed/benchmark/benchmark_train.py index c35943791..f76ad9aed 100644 --- a/torchrec/distributed/benchmark/benchmark_train.py +++ b/torchrec/distributed/benchmark/benchmark_train.py @@ -17,7 +17,7 @@ from typing import List, Optional, Tuple import torch -from torchrec.distributed.benchmark.benchmark_base import ( +from torchrec.distributed.benchmark.base import ( BenchmarkResult, CompileMode, init_argparse_and_args, diff --git a/torchrec/distributed/benchmark/benchmark_train_pipeline.py b/torchrec/distributed/benchmark/benchmark_train_pipeline.py index 81d018950..4b28a70b6 100644 --- a/torchrec/distributed/benchmark/benchmark_train_pipeline.py +++ b/torchrec/distributed/benchmark/benchmark_train_pipeline.py @@ -16,7 +16,7 @@ OSS (external): python -m torchrec.distributed.benchmark.benchmark_train_pipeline --world_size=4 --pipeline=sparse --batch_size=10 -Adding New Model Support: +To support a new model in pipeline benchmark: See benchmark_pipeline_utils.py for step-by-step instructions. """ @@ -26,7 +26,7 @@ import torch from fbgemm_gpu.split_embedding_configs import EmbOptimType from torch import nn -from torchrec.distributed.benchmark.benchmark_base import ( +from torchrec.distributed.benchmark.base import ( benchmark_func, BenchmarkResult, cmd_conf, @@ -37,7 +37,6 @@ BaseModelConfig, create_model_config, generate_data, - generate_pipeline, generate_planner, generate_sharded_model_and_optimizer, ) @@ -49,9 +48,10 @@ MultiProcessContext, run_multi_process_func, ) +from torchrec.distributed.test_utils.table_config import EmbeddingTablesConfig from torchrec.distributed.test_utils.test_input import ModelInput from torchrec.distributed.test_utils.test_model import TestOverArchLarge -from torchrec.distributed.test_utils.test_tables import EmbeddingTablesConfig +from torchrec.distributed.test_utils.train_pipeline import PipelineConfig from torchrec.distributed.train_pipeline import TrainPipeline from torchrec.distributed.types import ShardingType from torchrec.modules.embedding_configs import EmbeddingBagConfig @@ -116,33 +116,6 @@ class RunOptions: export_stacks: bool = False -@dataclass -class PipelineConfig: - """ - Configuration for training pipelines. - - This class defines the parameters for configuring the training pipeline. - - Args: - pipeline (str): The type of training pipeline to use. Options include: - - "base": Basic training pipeline - - "sparse": Pipeline optimized for sparse operations - - "fused": Pipeline with fused sparse distribution - - "semi": Semi-synchronous training pipeline - - "prefetch": Pipeline with prefetching for sparse distribution - Default is "base". - emb_lookup_stream (str): The stream to use for embedding lookups. - Only used by certain pipeline types (e.g., "fused"). - Default is "data_dist". - apply_jit (bool): Whether to apply JIT (Just-In-Time) compilation to the model. - Default is False. - """ - - pipeline: str = "base" - emb_lookup_stream: str = "data_dist" - apply_jit: bool = False - - @dataclass class ModelSelectionConfig: model_name: str = "test_sparse_nn" @@ -279,13 +252,10 @@ def _func_to_benchmark( except StopIteration: break - pipeline = generate_pipeline( - pipeline_type=pipeline_config.pipeline, - emb_lookup_stream=pipeline_config.emb_lookup_stream, + pipeline = pipeline_config.generate_pipeline( model=sharded_model, opt=optimizer, device=ctx.device, - apply_jit=pipeline_config.apply_jit, ) pipeline.progress(iter(bench_inputs)) diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py index 80b155d0a..dca96323e 100644 --- a/torchrec/distributed/benchmark/benchmark_utils.py +++ b/torchrec/distributed/benchmark/benchmark_utils.py @@ -10,7 +10,7 @@ """ Utilities for benchmarking training pipelines with different model configurations. -Adding New Model Support: +To support a new model in pipeline benchmark: 1. Create config class inheriting from BaseModelConfig with generate_model() method 2. Add the model to model_configs dict in create_model_config() 3. Add model-specific params to ModelSelectionConfig and create_model_config's arguments in benchmark_train_pipeline.py @@ -39,15 +39,6 @@ TestTowerCollectionSparseNN, TestTowerSparseNN, ) -from torchrec.distributed.train_pipeline import ( - TrainPipelineBase, - TrainPipelineFusedSparseDist, - TrainPipelineSparseDist, -) -from torchrec.distributed.train_pipeline.train_pipelines import ( - PrefetchTrainPipelineSparseDist, - TrainPipelineSemiSync, -) from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType from torchrec.models.deepfm import SimpleDeepFMNNWrapper from torchrec.models.dlrm import DLRMWrapper @@ -249,80 +240,6 @@ def create_model_config(model_name: str, **kwargs) -> BaseModelConfig: return model_class(**filtered_kwargs) -def generate_pipeline( - pipeline_type: str, - emb_lookup_stream: str, - model: nn.Module, - opt: torch.optim.Optimizer, - device: torch.device, - apply_jit: bool = False, -) -> Union[TrainPipelineBase, TrainPipelineSparseDist]: - """ - Generate a training pipeline instance based on the configuration. - - This function creates and returns the appropriate training pipeline object - based on the pipeline type specified. Different pipeline types are optimized - for different training scenarios. - - Args: - pipeline_type (str): The type of training pipeline to use. Options include: - - "base": Basic training pipeline - - "sparse": Pipeline optimized for sparse operations - - "fused": Pipeline with fused sparse distribution - - "semi": Semi-synchronous training pipeline - - "prefetch": Pipeline with prefetching for sparse distribution - emb_lookup_stream (str): The stream to use for embedding lookups. - Only used by certain pipeline types (e.g., "fused"). - model (nn.Module): The model to be trained. - opt (torch.optim.Optimizer): The optimizer to use for training. - device (torch.device): The device to run the training on. - apply_jit (bool): Whether to apply JIT (Just-In-Time) compilation to the model. - Default is False. - - Returns: - Union[TrainPipelineBase, TrainPipelineSparseDist]: An instance of the - appropriate training pipeline class based on the configuration. - - Raises: - RuntimeError: If an unknown pipeline type is specified. - """ - - _pipeline_cls: Dict[ - str, Type[Union[TrainPipelineBase, TrainPipelineSparseDist]] - ] = { - "base": TrainPipelineBase, - "sparse": TrainPipelineSparseDist, - "fused": TrainPipelineFusedSparseDist, - "semi": TrainPipelineSemiSync, - "prefetch": PrefetchTrainPipelineSparseDist, - } - - if pipeline_type == "semi": - return TrainPipelineSemiSync( - model=model, - optimizer=opt, - device=device, - start_batch=0, - apply_jit=apply_jit, - ) - elif pipeline_type == "fused": - return TrainPipelineFusedSparseDist( - model=model, - optimizer=opt, - device=device, - emb_lookup_stream=emb_lookup_stream, - apply_jit=apply_jit, - ) - elif pipeline_type == "base": - assert apply_jit is False, "JIT is not supported for base pipeline" - - return TrainPipelineBase(model=model, optimizer=opt, device=device) - else: - Pipeline = _pipeline_cls[pipeline_type] - # pyre-ignore[28] - return Pipeline(model=model, optimizer=opt, device=device, apply_jit=apply_jit) - - def generate_data( tables: List[EmbeddingBagConfig], weighted_tables: List[EmbeddingBagConfig], diff --git a/torchrec/distributed/benchmark/embedding_collection_wrappers.py b/torchrec/distributed/benchmark/embedding_collection_wrappers.py index 78d967c52..3131b9216 100644 --- a/torchrec/distributed/benchmark/embedding_collection_wrappers.py +++ b/torchrec/distributed/benchmark/embedding_collection_wrappers.py @@ -57,12 +57,7 @@ from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor # Import the shared types and utilities from benchmark_utils -from .benchmark_base import ( - benchmark, - BenchmarkResult, - CompileMode, - multi_process_benchmark, -) +from .base import benchmark, BenchmarkResult, CompileMode, multi_process_benchmark logger: logging.Logger = logging.getLogger() diff --git a/torchrec/distributed/test_utils/test_tables.py b/torchrec/distributed/test_utils/table_config.py similarity index 98% rename from torchrec/distributed/test_utils/test_tables.py rename to torchrec/distributed/test_utils/table_config.py index e308947aa..2954ed085 100644 --- a/torchrec/distributed/test_utils/test_tables.py +++ b/torchrec/distributed/test_utils/table_config.py @@ -8,7 +8,7 @@ # pyre-strict from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List from torchrec.modules.embedding_configs import EmbeddingBagConfig diff --git a/torchrec/distributed/test_utils/train_pipeline.py b/torchrec/distributed/test_utils/train_pipeline.py new file mode 100644 index 000000000..49293fb2d --- /dev/null +++ b/torchrec/distributed/test_utils/train_pipeline.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from dataclasses import dataclass +from typing import Dict, Type, Union + +import torch +from torch import nn + +from torchrec.distributed.train_pipeline import ( + TrainPipelineBase, + TrainPipelineFusedSparseDist, + TrainPipelineSparseDist, +) +from torchrec.distributed.train_pipeline.train_pipelines import ( + PrefetchTrainPipelineSparseDist, + TrainPipelineSemiSync, +) + + +@dataclass +class PipelineConfig: + """ + Configuration for training pipelines. + + This class defines the parameters for configuring the training pipeline. + + Args: + pipeline (str): The type of training pipeline to use. Options include: + - "base": Basic training pipeline + - "sparse": Pipeline optimized for sparse operations + - "fused": Pipeline with fused sparse distribution + - "semi": Semi-synchronous training pipeline + - "prefetch": Pipeline with prefetching for sparse distribution + Default is "base". + emb_lookup_stream (str): The stream to use for embedding lookups. + Only used by certain pipeline types (e.g., "fused"). + Default is "data_dist". + apply_jit (bool): Whether to apply JIT (Just-In-Time) compilation to the model. + Default is False. + """ + + pipeline: str = "base" + emb_lookup_stream: str = "data_dist" + apply_jit: bool = False + + def generate_pipeline( + self, + model: nn.Module, + opt: torch.optim.Optimizer, + device: torch.device, + ) -> Union[TrainPipelineBase, TrainPipelineSparseDist]: + """ + Generate a training pipeline instance based on the configuration. + + This function creates and returns the appropriate training pipeline object + based on the pipeline type specified. Different pipeline types are optimized + for different training scenarios. + + Args: + pipeline_type (str): The type of training pipeline to use. Options include: + - "base": Basic training pipeline + - "sparse": Pipeline optimized for sparse operations + - "fused": Pipeline with fused sparse distribution + - "semi": Semi-synchronous training pipeline + - "prefetch": Pipeline with prefetching for sparse distribution + emb_lookup_stream (str): The stream to use for embedding lookups. + Only used by certain pipeline types (e.g., "fused"). + model (nn.Module): The model to be trained. + opt (torch.optim.Optimizer): The optimizer to use for training. + device (torch.device): The device to run the training on. + apply_jit (bool): Whether to apply JIT (Just-In-Time) compilation to the model. + Default is False. + + Returns: + Union[TrainPipelineBase, TrainPipelineSparseDist]: An instance of the + appropriate training pipeline class based on the configuration. + + Raises: + RuntimeError: If an unknown pipeline type is specified. + """ + + _pipeline_cls: Dict[ + str, Type[Union[TrainPipelineBase, TrainPipelineSparseDist]] + ] = { + "base": TrainPipelineBase, + "sparse": TrainPipelineSparseDist, + "fused": TrainPipelineFusedSparseDist, + "semi": TrainPipelineSemiSync, + "prefetch": PrefetchTrainPipelineSparseDist, + } + + if self.pipeline == "semi": + return TrainPipelineSemiSync( + model=model, + optimizer=opt, + device=device, + start_batch=0, + apply_jit=self.apply_jit, + ) + elif self.pipeline == "fused": + return TrainPipelineFusedSparseDist( + model=model, + optimizer=opt, + device=device, + emb_lookup_stream=self.emb_lookup_stream, + apply_jit=self.apply_jit, + ) + elif self.pipeline == "base": + assert self.apply_jit is False, "JIT is not supported for base pipeline" + + return TrainPipelineBase(model=model, optimizer=opt, device=device) + else: + Pipeline = _pipeline_cls[self.pipeline] + # pyre-ignore[28] + return Pipeline( + model=model, optimizer=opt, device=device, apply_jit=self.apply_jit + ) diff --git a/torchrec/sparse/tests/jagged_tensor_benchmark.py b/torchrec/sparse/tests/jagged_tensor_benchmark.py index 5d1d35068..3ced3e291 100644 --- a/torchrec/sparse/tests/jagged_tensor_benchmark.py +++ b/torchrec/sparse/tests/jagged_tensor_benchmark.py @@ -15,7 +15,7 @@ import click import torch -from torchrec.distributed.benchmark.benchmark_base import ( +from torchrec.distributed.benchmark.base import ( benchmark, BenchmarkResult, CPUMemoryStats, diff --git a/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py b/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py index 8d55f9a87..2bfb58d6e 100644 --- a/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py +++ b/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py @@ -20,7 +20,7 @@ # Otherwise will get error # NotImplementedError: fbgemm::permute_1D_sparse_data: We could not find the abstract impl for this operator. from fbgemm_gpu import sparse_ops # noqa: F401, E402 -from torchrec.distributed.benchmark.benchmark_base import ( +from torchrec.distributed.benchmark.base import ( BenchmarkResult, CPUMemoryStats, GPUMemoryStats,