From 6a15062fab3b6b0c7113a16fcc10a7d4c780ab7f Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Thu, 4 Jul 2024 09:18:17 -0400
Subject: [PATCH 01/10] Dockerfile: use custom cache manager in Triton.

Co-authored-by: Chih-Chieh-Yang
Signed-off-by: Thomas Parnell
---
 Dockerfile                           | 10 ++++++++-
 triton_patch/custom_cache_manager.py | 32 ++++++++++++++++++++++++++++
 2 files changed, 41 insertions(+), 1 deletion(-)
 create mode 100644 triton_patch/custom_cache_manager.py

diff --git a/Dockerfile b/Dockerfile
index f571e8be421e..c821f2965ec5 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -166,6 +166,13 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
 RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamba \
     --mount=type=cache,target=/root/.cache/pip \
     python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir
+
+# workaround for https://github.com/vllm-project/vllm/issues/6103
+# until fixed in Triton upstream
+RUN --mount=type=bind,source=triton_patch,target=/context \
+    export TRITON_PATH=$(python3 -c "import triton; print(triton.__file__.strip(\"__init__.py\"))") \
+    && cp /context/custom_cache_manager.py ${TRITON_PATH}/runtime/custom_cache_manager.py \
+
 
 #################### vLLM installation IMAGE ####################
@@ -197,7 +204,8 @@ FROM vllm-base AS vllm-openai
 RUN --mount=type=cache,target=/root/.cache/pip \
     pip install accelerate hf_transfer 'modelscope!=1.15.0'
 
-ENV VLLM_USAGE_SOURCE production-docker-image
+ENV VLLM_USAGE_SOURCE production-docker-image \
+    TRITON_CACHE_MANAGER "triton.runtime.custom_cache_manager:CustomCacheManager"
 
 ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
 #################### OPENAI API SERVER ####################

diff --git a/triton_patch/custom_cache_manager.py b/triton_patch/custom_cache_manager.py
new file mode 100644
index 000000000000..c83ed5b6e865
--- /dev/null
+++ b/triton_patch/custom_cache_manager.py
@@ -0,0 +1,32 @@
+import os
+
+from triton.runtime.cache import (FileCacheManager, default_cache_dir,
+                                  default_dump_dir, default_override_dir)
+
+
+class CustomCacheManager(FileCacheManager):
+
+    def __init__(self, key, override=False, dump=False):
+        self.key = key
+        self.lock_path = None
+        if dump:
+            self.cache_dir = default_dump_dir()
+            self.cache_dir = os.path.join(self.cache_dir, self.key)
+            self.lock_path = os.path.join(self.cache_dir, "lock")
+            os.makedirs(self.cache_dir, exist_ok=True)
+        elif override:
+            self.cache_dir = default_override_dir()
+            self.cache_dir = os.path.join(self.cache_dir, self.key)
+        else:
+            # create cache directory if it doesn't exist
+            self.cache_dir = os.getenv("TRITON_CACHE_DIR",
+                                       "").strip() or default_cache_dir()
+            if self.cache_dir:
+                self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
+                self.cache_dir = os.path.join(self.cache_dir, self.key)
+                self.lock_path = os.path.join(self.cache_dir, "lock")
+                os.makedirs(self.cache_dir, exist_ok=True)
+            else:
+                raise RuntimeError("Could not create or locate cache dir")
+
+        print(f"Triton cache dir: {self.cache_dir=}")
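
The mechanism this first patch leans on is Triton's TRITON_CACHE_MANAGER
environment variable, which names a cache-manager class as a
"module:Class" string. As a rough illustration of that convention (a
sketch of the lookup, not Triton's actual code), such a spec can be
resolved with importlib:

    import importlib
    import os

    def resolve_cache_manager_cls(default_cls):
        # If TRITON_CACHE_MANAGER is unset, fall back to the default
        # manager class; otherwise import the "module:Class" spec
        # dynamically. resolve_cache_manager_cls is an illustrative
        # helper name, not a Triton API.
        spec = os.environ.get("TRITON_CACHE_MANAGER")
        if spec is None:
            return default_cls
        module_name, _, class_name = spec.partition(":")
        return getattr(importlib.import_module(module_name), class_name)
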
From 0d54387272482294708ad0e09dc0df71144dec11 Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Thu, 4 Jul 2024 09:24:36 -0400
Subject: [PATCH 02/10] Minor error in Dockerfile

Signed-off-by: Thomas Parnell
---
 Dockerfile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Dockerfile b/Dockerfile
index c821f2965ec5..357389dec3d4 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -171,7 +171,7 @@ RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamb
 # until fixed in Triton upstream
 RUN --mount=type=bind,source=triton_patch,target=/context \
     export TRITON_PATH=$(python3 -c "import triton; print(triton.__file__.strip(\"__init__.py\"))") \
-    && cp /context/custom_cache_manager.py ${TRITON_PATH}/runtime/custom_cache_manager.py \
+    && cp /context/custom_cache_manager.py ${TRITON_PATH}/runtime/custom_cache_manager.py
 
 
 #################### vLLM installation IMAGE ####################

From b8031659ecd65c49d1784351d15e97faf28b9a59 Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Thu, 4 Jul 2024 10:47:46 -0400
Subject: [PATCH 03/10] Include custom cache manager as part of vllm

Signed-off-by: Thomas Parnell
---
 Dockerfile                                | 10 +------
 .../layers/fused_moe/fused_moe.py         |  7 +++++
 vllm/triton_utils/custom_cache_manager.py | 30 +++++++++++++++++++
 3 files changed, 38 insertions(+), 9 deletions(-)
 create mode 100644 vllm/triton_utils/custom_cache_manager.py

diff --git a/Dockerfile b/Dockerfile
index 357389dec3d4..f571e8be421e 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -166,13 +166,6 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
 RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamba \
     --mount=type=cache,target=/root/.cache/pip \
     python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir
-
-# workaround for https://github.com/vllm-project/vllm/issues/6103
-# until fixed in Triton upstream
-RUN --mount=type=bind,source=triton_patch,target=/context \
-    export TRITON_PATH=$(python3 -c "import triton; print(triton.__file__.strip(\"__init__.py\"))") \
-    && cp /context/custom_cache_manager.py ${TRITON_PATH}/runtime/custom_cache_manager.py
-
 
 #################### vLLM installation IMAGE ####################
@@ -204,8 +197,7 @@ FROM vllm-base AS vllm-openai
 RUN --mount=type=cache,target=/root/.cache/pip \
     pip install accelerate hf_transfer 'modelscope!=1.15.0'
 
-ENV VLLM_USAGE_SOURCE production-docker-image \
-    TRITON_CACHE_MANAGER "triton.runtime.custom_cache_manager:CustomCacheManager"
+ENV VLLM_USAGE_SOURCE production-docker-image
 
 ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
 #################### OPENAI API SERVER ####################

diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 99a5c7d78a67..7d73590fe5c8 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -271,6 +271,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
         **config,
     )
 
+def maybe_set_triton_cache_manager(module: str) -> None:
+    cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None)
+    if cache_manager != module:
+        os.environ["TRITON_CACHE_MANAGER"] = module
 
 def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
     device_name = torch.cuda.get_device_name().replace(" ", "_")
@@ -428,6 +432,9 @@ def fused_experts(hidden_states: torch.Tensor,
     CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
     M = min(num_tokens, CHUNK_SIZE)
 
+    # workaround for https://github.com/vllm-project/vllm/issues/6103
+    maybe_set_triton_cache_manager("vllm.triton_utils.custom_cache_manager:CustomCacheManager")
+
     if override_config:
         config = override_config
     else:

diff --git a/vllm/triton_utils/custom_cache_manager.py b/vllm/triton_utils/custom_cache_manager.py
new file mode 100644
index 000000000000..5836249d79bb
--- /dev/null
+++ b/vllm/triton_utils/custom_cache_manager.py
@@ -0,0 +1,30 @@
+import os
+
+from triton.runtime.cache import (FileCacheManager, default_cache_dir,
+                                  default_dump_dir, default_override_dir)
+
+
+class CustomCacheManager(FileCacheManager):
+
+    def __init__(self, key, override=False, dump=False):
+        self.key = key
+        self.lock_path = None
+        if dump:
+            self.cache_dir = default_dump_dir()
+            self.cache_dir = os.path.join(self.cache_dir, self.key)
+            self.lock_path = os.path.join(self.cache_dir, "lock")
+            os.makedirs(self.cache_dir, exist_ok=True)
+        elif override:
+            self.cache_dir = default_override_dir()
+            self.cache_dir = os.path.join(self.cache_dir, self.key)
+        else:
+            # create cache directory if it doesn't exist
+            self.cache_dir = os.getenv("TRITON_CACHE_DIR",
+                                       "").strip() or default_cache_dir()
+            if self.cache_dir:
+                self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
+                self.cache_dir = os.path.join(self.cache_dir, self.key)
+                self.lock_path = os.path.join(self.cache_dir, "lock")
+                os.makedirs(self.cache_dir, exist_ok=True)
+            else:
+                raise RuntimeError("Could not create or locate cache dir")
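
With the manager now shipped inside vllm, instantiating it directly
shows the per-process behavior. A hypothetical usage sketch (the import
only works once PATCH 04 below adds the package __init__.py, and the
key string is a stand-in for the hash Triton normally passes in):

    import os

    from vllm.triton_utils.custom_cache_manager import CustomCacheManager

    # Each process derives its own cache directory, so concurrent
    # tensor-parallel workers no longer race on the same files.
    manager = CustomCacheManager("example_kernel_key")
    print(manager.cache_dir)  # e.g. .../.triton/cache_12345/example_kernel_key
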
From d3ef0d8d3e0d9330b26d60cab912011dbbbe7b21 Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Thu, 4 Jul 2024 10:48:10 -0400
Subject: [PATCH 04/10] Add __init__.py

Signed-off-by: Thomas Parnell
---
 vllm/triton_utils/__init__.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 vllm/triton_utils/__init__.py

diff --git a/vllm/triton_utils/__init__.py b/vllm/triton_utils/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1

From eb5c8926804459e93912e73e48d512786280b5e5 Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Thu, 4 Jul 2024 10:50:46 -0400
Subject: [PATCH 05/10] remove triton_patch dir

Signed-off-by: Thomas Parnell
---
 triton_patch/custom_cache_manager.py | 32 ----------------------------
 1 file changed, 32 deletions(-)
 delete mode 100644 triton_patch/custom_cache_manager.py

diff --git a/triton_patch/custom_cache_manager.py b/triton_patch/custom_cache_manager.py
deleted file mode 100644
index c83ed5b6e865..000000000000
--- a/triton_patch/custom_cache_manager.py
+++ /dev/null
@@ -1,32 +0,0 @@
-import os
-
-from triton.runtime.cache import (FileCacheManager, default_cache_dir,
-                                  default_dump_dir, default_override_dir)
-
-
-class CustomCacheManager(FileCacheManager):
-
-    def __init__(self, key, override=False, dump=False):
-        self.key = key
-        self.lock_path = None
-        if dump:
-            self.cache_dir = default_dump_dir()
-            self.cache_dir = os.path.join(self.cache_dir, self.key)
-            self.lock_path = os.path.join(self.cache_dir, "lock")
-            os.makedirs(self.cache_dir, exist_ok=True)
-        elif override:
-            self.cache_dir = default_override_dir()
-            self.cache_dir = os.path.join(self.cache_dir, self.key)
-        else:
-            # create cache directory if it doesn't exist
-            self.cache_dir = os.getenv("TRITON_CACHE_DIR",
-                                       "").strip() or default_cache_dir()
-            if self.cache_dir:
-                self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
-                self.cache_dir = os.path.join(self.cache_dir, self.key)
-                self.lock_path = os.path.join(self.cache_dir, "lock")
-                os.makedirs(self.cache_dir, exist_ok=True)
-            else:
-                raise RuntimeError("Could not create or locate cache dir")
-
-        print(f"Triton cache dir: {self.cache_dir=}")
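
The triton_patch directory removed here dates from the PATCH 01
Dockerfile hack, which located Triton via
triton.__file__.strip("__init__.py"). That spelling is fragile:
str.strip removes a *set* of leading/trailing characters, not a suffix
string, as a quick interpreter check shows:

    >>> "triton/__init__.py".strip("__init__.py")
    'riton/'
    >>> # the leading 't' is eaten because it is in the stripped set;
    >>> # absolute paths only survived because '/' is not in that set.
    >>> import os
    >>> os.path.dirname("/usr/lib/python3/dist-packages/triton/__init__.py")
    '/usr/lib/python3/dist-packages/triton'

Shipping the manager inside vllm sidesteps the path manipulation
entirely.
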
From 81eef8ab3a78b77d1a7eaa853489e24b8a56f08c Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Thu, 4 Jul 2024 10:54:20 -0400
Subject: [PATCH 06/10] Format

Signed-off-by: Thomas Parnell
---
 vllm/model_executor/layers/fused_moe/fused_moe.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 7d73590fe5c8..fe8ba59f699d 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -271,11 +271,13 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
         **config,
     )
 
+
 def maybe_set_triton_cache_manager(module: str) -> None:
     cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None)
     if cache_manager != module:
         os.environ["TRITON_CACHE_MANAGER"] = module
 
+
 def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
     device_name = torch.cuda.get_device_name().replace(" ", "_")
     dtype_selector = "" if not dtype else f",dtype={dtype}"
@@ -433,7 +435,8 @@ def fused_experts(hidden_states: torch.Tensor,
     M = min(num_tokens, CHUNK_SIZE)
 
     # workaround for https://github.com/vllm-project/vllm/issues/6103
-    maybe_set_triton_cache_manager("vllm.triton_utils.custom_cache_manager:CustomCacheManager")
+    maybe_set_triton_cache_manager(
+        "vllm.triton_utils.custom_cache_manager:CustomCacheManager")
 
     if override_config:
         config = override_config

From b0406452233ce029ce251b9819c3c2746773377f Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Thu, 4 Jul 2024 15:52:22 -0400
Subject: [PATCH 07/10] Address review comments

Signed-off-by: Thomas Parnell
---
 vllm/executor/multiproc_gpu_executor.py           |  4 ++++
 vllm/model_executor/layers/fused_moe/fused_moe.py | 10 ----------
 vllm/triton_utils/__init__.py                     |  6 ++++++
 vllm/triton_utils/custom_cache_manager.py         | 12 ++++++++++++
 4 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py
index dcde27973f8e..cddfeab32778 100644
--- a/vllm/executor/multiproc_gpu_executor.py
+++ b/vllm/executor/multiproc_gpu_executor.py
@@ -9,6 +9,7 @@
                                              ResultHandler, WorkerMonitor)
 from vllm.logger import init_logger
 from vllm.sequence import ExecuteModelRequest, SamplerOutput
+from vllm.triton_utils import maybe_set_triton_cache_manager
 from vllm.utils import (cuda_device_count_stateless,
                         error_on_invalid_device_count_status,
                         get_distributed_init_method, get_open_port,
@@ -42,6 +43,9 @@ def _init_executor(self) -> None:
         if "OMP_NUM_THREADS" not in os.environ:
             os.environ["OMP_NUM_THREADS"] = "1"
 
+        # workaround for https://github.com/vllm-project/vllm/issues/6103
+        maybe_set_triton_cache_manager()
+
         assert world_size <= cuda_device_count_stateless(), (
             "please set tensor_parallel_size to less than max local gpu count")
 

diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index fe8ba59f699d..99a5c7d78a67 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -272,12 +272,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
     )
 
 
-def maybe_set_triton_cache_manager(module: str) -> None:
-    cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None)
-    if cache_manager != module:
-        os.environ["TRITON_CACHE_MANAGER"] = module
-
-
 def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
     device_name = torch.cuda.get_device_name().replace(" ", "_")
     dtype_selector = "" if not dtype else f",dtype={dtype}"
@@ -434,10 +428,6 @@ def fused_experts(hidden_states: torch.Tensor,
     CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
     M = min(num_tokens, CHUNK_SIZE)
 
-    # workaround for https://github.com/vllm-project/vllm/issues/6103
-    maybe_set_triton_cache_manager(
-        "vllm.triton_utils.custom_cache_manager:CustomCacheManager")
-
     if override_config:
         config = override_config
     else:

diff --git a/vllm/triton_utils/__init__.py b/vllm/triton_utils/__init__.py
index e69de29bb2d1..09843e5d1f30 100644
--- a/vllm/triton_utils/__init__.py
+++ b/vllm/triton_utils/__init__.py
@@ -0,0 +1,6 @@
+from vllm.triton_utils.custom_cache_manager import (
+    maybe_set_triton_cache_manager)
+
+__all__ = [
+    "maybe_set_triton_cache_manager",
+]

diff --git a/vllm/triton_utils/custom_cache_manager.py b/vllm/triton_utils/custom_cache_manager.py
index 5836249d79bb..f1407bbe45b4 100644
--- a/vllm/triton_utils/custom_cache_manager.py
+++ b/vllm/triton_utils/custom_cache_manager.py
@@ -3,6 +3,18 @@
 from triton.runtime.cache import (FileCacheManager, default_cache_dir,
                                   default_dump_dir, default_override_dir)
 
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def maybe_set_triton_cache_manager() -> None:
+    cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None)
+    if cache_manager is None:
+        manager = "vllm.triton_utils.custom_cache_manager:CustomCacheManager"
+        logger.info("Setting Triton cache manager to: %s", manager)
+        os.environ["TRITON_CACHE_MANAGER"] = manager
+
 
 class CustomCacheManager(FileCacheManager):
 
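
Setting the variable in _init_executor, before the worker processes are
created, is what makes this reliable: children inherit the parent's
os.environ at creation time, so every tensor-parallel rank sees the
custom manager before Triton compiles its first kernel. A minimal
sketch of that inheritance (illustrative only, not vllm code):

    import multiprocessing as mp
    import os

    def report() -> None:
        # The child sees the value the parent set before creating it.
        print(os.environ.get("TRITON_CACHE_MANAGER"))

    if __name__ == "__main__":
        os.environ["TRITON_CACHE_MANAGER"] = (
            "vllm.triton_utils.custom_cache_manager:CustomCacheManager")
        worker = mp.Process(target=report)
        worker.start()  # prints the module:Class string set above
        worker.join()
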
From 4dd93673b01c062efbecc431ff73a522e2fb37b8 Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Thu, 4 Jul 2024 16:00:49 -0400
Subject: [PATCH 08/10] Only change cache manager for tp>1

Signed-off-by: Thomas Parnell
---
 vllm/executor/multiproc_gpu_executor.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py
index cddfeab32778..a0e248b2e199 100644
--- a/vllm/executor/multiproc_gpu_executor.py
+++ b/vllm/executor/multiproc_gpu_executor.py
@@ -44,7 +44,8 @@ def _init_executor(self) -> None:
             os.environ["OMP_NUM_THREADS"] = "1"
 
         # workaround for https://github.com/vllm-project/vllm/issues/6103
-        maybe_set_triton_cache_manager()
+        if world_size > 1:
+            maybe_set_triton_cache_manager()
 
         assert world_size <= cuda_device_count_stateless(), (
             "please set tensor_parallel_size to less than max local gpu count")

From 889d6dd5aecfb36a04c2cc697a579fac72061e54 Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Fri, 5 Jul 2024 14:56:22 -0400
Subject: [PATCH 09/10] Add some docstrings

Signed-off-by: Thomas Parnell
---
 vllm/triton_utils/custom_cache_manager.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/vllm/triton_utils/custom_cache_manager.py b/vllm/triton_utils/custom_cache_manager.py
index f1407bbe45b4..b0cbaaedf51d 100644
--- a/vllm/triton_utils/custom_cache_manager.py
+++ b/vllm/triton_utils/custom_cache_manager.py
@@ -9,6 +9,8 @@
 
 
 def maybe_set_triton_cache_manager() -> None:
+    """Set environment variable to tell Triton to use a
+    custom cache manager"""
     cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None)
     if cache_manager is None:
         manager = "vllm.triton_utils.custom_cache_manager:CustomCacheManager"
@@ -17,6 +19,11 @@ def maybe_set_triton_cache_manager() -> None:
 
 
 class CustomCacheManager(FileCacheManager):
+    """Re-implements Triton's cache manager, ensuring that a
+    unique cache directory is created for each process. This is
+    needed to avoid collisions when running with tp>1 and
+    using multi-processing as the distributed backend.
+    """
 
     def __init__(self, key, override=False, dump=False):
         self.key = key
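
The docstring's claim — one cache directory per process — can be
checked directly against the class itself. A sketch of such a check
(assuming a writable default cache location; the key string is again a
stand-in for Triton's kernel hash):

    import os

    from vllm.triton_utils.custom_cache_manager import CustomCacheManager

    def test_cache_dir_is_per_process() -> None:
        # The pid is baked into the directory name, so two concurrent
        # workers can never write to the same cache path.
        manager = CustomCacheManager("some_kernel_hash")
        assert f"_{os.getpid()}" in manager.cache_dir
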
+ """ def __init__(self, key, override=False, dump=False): self.key = key From 3307522289fdfefe323b6c00d0db696651989a2f Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 11 Jul 2024 04:51:14 -0400 Subject: [PATCH 10/10] Update docstring Signed-off-by: Thomas Parnell --- vllm/triton_utils/custom_cache_manager.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/triton_utils/custom_cache_manager.py b/vllm/triton_utils/custom_cache_manager.py index b0cbaaedf51d..17039d7ba24c 100644 --- a/vllm/triton_utils/custom_cache_manager.py +++ b/vllm/triton_utils/custom_cache_manager.py @@ -23,6 +23,10 @@ class CustomCacheManager(FileCacheManager): unique cache directory is created for each process. This is needed to avoid collisions when running with tp>1 and using multi-processing as the distributed backend. + + Note this issue was fixed by triton-lang/triton/pull/4295, + but the fix is not yet included in triton==v3.0.0. However, + it should be included in the subsequent version. """ def __init__(self, key, override=False, dump=False):