Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions fastdeploy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,8 +841,13 @@ def __init__(
Now don't support capture both decode-only and prefill-only"""
self.full_cuda_graph: bool = True

""" Maximum CUDA Graph capture size """
self.max_capture_size: int = None
""" Record maps mapped from real shape to captured size to reduce runtime overhead """
self.real_shape_to_captured_size: dict[int, int] = None
""" Whether to use shared memory pool for multi capture_size """
self.use_unique_memory_pool: bool = False

# CINN Config ...
if args is not None:
for key, value in args.items():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
import paddle.nn.layer
from paddle.device.cuda import graphs

if paddle.is_compiled_with_cuda():
from paddle.base.core import CUDAGraph

from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication import capture_custom_allreduce
Expand Down Expand Up @@ -96,6 +99,9 @@ def __init__(
self.cudagraph_capture_sizes = fd_config.graph_opt_config.cudagraph_capture_sizes
self.warm_up_size = fd_config.graph_opt_config.cudagraph_num_of_warmups
self.real_shape_to_captured_size = fd_config.graph_opt_config.real_shape_to_captured_size
self.unique_memory_pool_id = None
if self.fd_config.graph_opt_config.use_unique_memory_pool:
self.unique_memory_pool_id = CUDAGraph.gen_new_memory_pool_id()

self._create_entry_dict()

Expand Down Expand Up @@ -169,7 +175,11 @@ def __call__(self, **kwargs) -> List[paddle.Tensor] | paddle.Tensor:
input_addresses = [x.data_ptr() for (_, x) in kwargs.items() if isinstance(x, paddle.Tensor)]
entry.input_addresses = input_addresses

new_grpah = graphs.CUDAGraph()
new_grpah = (
graphs.CUDAGraph(pool_id=self.unique_memory_pool_id)
if self.fd_config.graph_opt_config.use_unique_memory_pool
else graphs.CUDAGraph()
)
paddle.device.synchronize()

# Capture
Expand Down
Loading