diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py
index bd02a4c4f96..b57aa8931f7 100644
--- a/fastdeploy/cache_manager/cache_messager.py
+++ b/fastdeploy/cache_manager/cache_messager.py
@@ -132,7 +132,10 @@ def __init__(
         self.gpu_cache_kvs = gpu_cache_kvs
         self.rank = rank
         self.nranks = nranks
-        address = (pod_ip, engine_worker_queue_port)
+        if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
+            address = (pod_ip, engine_worker_queue_port)
+        else:
+            address = f"/dev/shm/fd_task_queue_{engine_worker_queue_port}.sock"
         self.engine_worker_queue = EngineWorkerQueue(
             address=address,
             is_server=False,
@@ -423,7 +426,10 @@ def __init__(
         self.gpu_cache_kvs = gpu_cache_kvs
         self.rank = rank
         self.nranks = nranks
-        address = (pod_ip, engine_worker_queue_port)
+        if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
+            address = (pod_ip, engine_worker_queue_port)
+        else:
+            address = f"/dev/shm/fd_task_queue_{engine_worker_queue_port}.sock"
         self.engine_worker_queue = EngineWorkerQueue(
             address=address,
             is_server=False,
diff --git a/fastdeploy/engine/async_llm.py b/fastdeploy/engine/async_llm.py
index 11f809b995d..936f834777f 100644
--- a/fastdeploy/engine/async_llm.py
+++ b/fastdeploy/engine/async_llm.py
@@ -923,10 +923,13 @@ def launch_components(self):
                     1,
                     self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
                 ):
-                    address = (
-                        self.cfg.master_ip,
-                        int(self.cfg.parallel_config.engine_worker_queue_port[i]),
-                    )
+                    if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
+                        address = (
+                            self.cfg.master_ip,
+                            int(self.cfg.parallel_config.engine_worker_queue_port[i]),
+                        )
+                    else:
+                        address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.engine_worker_queue_port[i]}.sock"
                     llm_logger.info(f"dp start queue service {address}")
                     self.dp_engine_worker_queue_server.append(
                         EngineWorkerQueue(
diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py
index abbab33f67e..b88a4d0f054 100644
--- a/fastdeploy/engine/common_engine.py
+++ b/fastdeploy/engine/common_engine.py
@@ -270,10 +270,16 @@ def start_worker_queue_service(self, start_queue):
         """
         start queue service for engine worker communication
         """
-        address = (
-            self.cfg.master_ip,
-            int(self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]),
-        )
+
+        if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
+            address = (
+                self.cfg.master_ip,
+                int(
+                    self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
+                ),
+            )
+        else:
+            address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]}.sock"

         if start_queue and (self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0"):
             self.llm_logger.info(f"Starting engine worker queue server service at {address}")
@@ -284,15 +290,18 @@ def start_worker_queue_service(self, start_queue):
                 local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
             )
             # Dynamically updates the port value if an anonymous port is used
-            self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id] = str(
-                self.engine_worker_queue_server.get_server_port()
-            )
-            address = (
-                self.cfg.master_ip,
-                int(
-                    self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
-                ),
-            )
+            if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
+                self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id] = (
+                    str(self.engine_worker_queue_server.get_server_port())
+                )
+                address = (
+                    self.cfg.master_ip,
+                    int(
+                        self.cfg.parallel_config.engine_worker_queue_port[
+                            self.cfg.parallel_config.local_data_parallel_id
+                        ]
+                    ),
+                )

         if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
             self.cache_task_queue = EngineCacheQueue(
diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index 5a4b7b39a1b..405bb0cc2be 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -707,7 +707,9 @@ def launch_components(self):
             for i in range(self.cfg.parallel_config.data_parallel_size):
                 request_queues_for_dp_ipc.append(multiprocessing.Queue())
             self.engine.scheduler.start(
-                self.cfg.node_rank * self.cfg.worker_num_per_node, request_queues_for_dp_ipc, result_queue_for_dp_ipc
+                self.cfg.node_rank * self.cfg.worker_num_per_node % self.cfg.worker_num_per_node,
+                request_queues_for_dp_ipc,
+                result_queue_for_dp_ipc,
             )

         if not envs.FD_ENABLE_MULTI_API_SERVER:
@@ -719,10 +721,14 @@ def launch_components(self):
                     1,
                     self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
                 ):
-                    address = (
-                        self.cfg.master_ip,
-                        int(self.cfg.parallel_config.engine_worker_queue_port[i]),
-                    )
+                    if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
+                        address = (
+                            self.cfg.master_ip,
+                            int(self.cfg.parallel_config.engine_worker_queue_port[i]),
+                        )
+                    else:
+                        address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.engine_worker_queue_port[i]}.sock"
+
                     llm_logger.info(f"dp start queue service {address}")
                     self.dp_engine_worker_queue_server.append(
                         EngineWorkerQueue(
diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py
index 96a36aa4813..743a35fc512 100644
--- a/fastdeploy/engine/request.py
+++ b/fastdeploy/engine/request.py
@@ -161,6 +161,7 @@ def __init__(
         self.extend_block_tables = []
         # dp
         self.dp_rank = dp_rank
+        self.llm_engine_recv_req_timestamp = time.time()

     @classmethod
     def from_dict(cls, d: dict):
diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py
index 34013d52e23..9601a6b3273 100644
--- a/fastdeploy/envs.py
+++ b/fastdeploy/envs.py
@@ -155,6 +155,8 @@
     "ENCODE_FEATURE_BOS_SK": lambda: os.getenv("ENCODE_FEATURE_BOS_SK"),
     # Enable offline perf test mode for PD disaggregation
     "FD_OFFLINE_PERF_TEST_FOR_PD": lambda: int(os.getenv("FD_OFFLINE_PERF_TEST_FOR_PD", "0")),
+    "FD_ENABLE_E2W_TENSOR_CONVERT": lambda: int(os.getenv("FD_ENABLE_E2W_TENSOR_CONVERT", "0")),
+    "FD_ENGINE_TASK_QUEUE_WITH_SHM": lambda: int(os.getenv("FD_ENGINE_TASK_QUEUE_WITH_SHM", "0")),
     # ep+tp strategy: "all_reduce" or "all_to_all"
     # all_reduce: qkv_linear + attn + out_linear + allreduce
     # all_to_all: allgather + qkv_linear + attn + all2all + out_linear
diff --git a/fastdeploy/inter_communicator/engine_worker_queue.py b/fastdeploy/inter_communicator/engine_worker_queue.py
index 7544db6fdc5..ffa9155bab7 100644
--- a/fastdeploy/inter_communicator/engine_worker_queue.py
+++ b/fastdeploy/inter_communicator/engine_worker_queue.py
@@ -485,24 +485,38 @@ def _connect_with_retry(self, max_retries: int = 5, interval: int = 3) -> None:
     @staticmethod
     def to_tensor(tasks):
         """
-        Convert NumPy arrays in multimodal inputs to PaddlePaddle tensors.
+        Convert NumPy arrays in multimodal inputs to Paddle tensors.

         Args:
-            tasks: List of tasks containing multimodal inputs.
+            tasks (tuple): ([request], bsz)
         """
+        if (not envs.FD_ENABLE_MAX_PREFILL) and (not envs.FD_ENABLE_E2W_TENSOR_CONVERT):
+            return
         try:
-            if envs.FD_ENABLE_MAX_PREFILL:
-                llm_logger.debug(f"Convert image to tensor, type: {type(tasks)}")
-                batch_tasks, _ = tasks
-                for task in batch_tasks:
-                    if not hasattr(task, "multimodal_inputs"):
+            batch_tasks, _ = tasks
+            for task in batch_tasks:
+                multimodal_inputs = getattr(task, "multimodal_inputs", None)
+                if not multimodal_inputs:
+                    continue
+                # tensor keys
+                tensor_keys = [
+                    "images",
+                    "patch_idx",
+                    "token_type_ids",
+                    "position_ids",
+                    "attention_mask_offset",
+                ]
+
+                llm_logger.debug(f"Converting multimodal inputs to tensor...{tensor_keys}")
+
+                for key in tensor_keys:
+                    value = multimodal_inputs.get(key)
+                    if value is None:
                         continue
-                    images = task.multimodal_inputs["images"]
-                    if isinstance(images, np.ndarray):
-                        llm_logger.debug(f"Convert image to tensor, shape: {images.shape}")
-                        task.multimodal_inputs["images"] = paddle.to_tensor(images)
+                    if not isinstance(value, paddle.Tensor):
+                        multimodal_inputs[key] = paddle.to_tensor(value)
         except Exception as e:
-            llm_logger.warning(f"Failed to convert to tensor: {e}")
+            llm_logger.warning(f"Tensor conversion failed: {type(e).__name__}: {e}")

     @staticmethod
     def to_numpy(tasks):
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py
index e4247e0a599..978345a89c2 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py
@@ -108,10 +108,22 @@ def init_ep(self, layer: nn.Layer) -> None:

             # For non-mixed ep
             phase = config.model_config.moe_phase.phase
-            if phase == "prefill":
-                self.ep_prefill_runner = self.EPPrefillRunner(**common_args)
+            if current_platform.is_cuda():
+                if phase == "prefill":
+                    self.ep_prefill_runner = self.EPPrefillRunner(
+                        **common_args,
+                        use_internode_ll_two_stage=layer.fd_config.parallel_config.use_internode_ll_two_stage,
+                    )
+                else:
+                    self.ep_decoder_runner = self.EPDecoderRunner(
+                        **common_args,
+                        use_internode_ll_two_stage=layer.fd_config.parallel_config.use_internode_ll_two_stage,
+                    )
             else:
-                self.ep_decoder_runner = self.EPDecoderRunner(**common_args)
+                if phase == "prefill":
+                    self.ep_prefill_runner = self.EPPrefillRunner(**common_args)
+                else:
+                    self.ep_decoder_runner = self.EPDecoderRunner(**common_args)

     def process_loaded_weights(self, layer, weights) -> None:
         """
diff --git a/fastdeploy/splitwise/splitwise_connector.py b/fastdeploy/splitwise/splitwise_connector.py
index a56ecefecd2..e5ad4ad8adc 100644
--- a/fastdeploy/splitwise/splitwise_connector.py
+++ b/fastdeploy/splitwise/splitwise_connector.py
@@ -330,8 +330,13 @@ def create_connection(self, port):
         Parameters:
             port (int): Port number.
         """
+        if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
+            address = ("0.0.0.0", int(port))
+        else:
+            address = f"/dev/shm/fd_task_queue_{port}.sock"
+
         self.connect_innode_instances[port] = EngineWorkerQueue(
-            address=("0.0.0.0", int(port)),
+            address=address,
             num_client=self.cfg.parallel_config.tensor_parallel_size,
             client_id=0,
         )
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index 5178e2b8ccc..53583962c65 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -569,6 +569,10 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
                 if request.sampling_params is not None and request.sampling_params.prompt_logprobs is not None:
                     self.prompt_logprobs_reqs[request.request_id] = request
                 has_prefill_task = True
+                if (
+                    self.fd_config.scheduler_config.splitwise_role == "decode"
+                ):  # In PD, we continue to decode after P generate first token
+                    self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
             elif request.task_type.value == RequestType.DECODE.value:  # decode task
                 logger.debug(f"Handle decode request {request} at idx {idx}")
                 encoder_block_num = len(request.block_tables)
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index eff28d904d9..44a8a2fbcaa 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -556,10 +556,13 @@ def init_device(self) -> None:

     def start_task_queue_service(self):
         # Initialize task queue
-        task_address = (
-            self.parallel_config.pod_ip,
-            self.parallel_config.engine_worker_queue_port,
-        )
+        if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
+            task_address = (
+                self.parallel_config.pod_ip,
+                self.parallel_config.engine_worker_queue_port,
+            )
+        else:
+            task_address = f"/dev/shm/fd_task_queue_{self.parallel_config.engine_worker_queue_port}.sock"
         logger.info(f"connect task queue address {task_address}")
         self.task_queue = TaskQueue(
             address=task_address,