12 changes: 5 additions & 7 deletions custom_ops/gpu_ops/cpp_extensions.cc
@@ -540,10 +540,10 @@ std::vector<paddle::Tensor> count_tokens_per_expert_func(
const paddle::Tensor& topk_ids,
int64_t num_experts,
bool compute_padded_cumsum = false);
void GetPositionIdsAndMaskEncoderBatch(const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& position_ids);
void GetPositionIds(const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& position_ids);

std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const paddle::Tensor& kv_nope,
@@ -1639,9 +1639,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("is_zp_float"));
#endif

m.def("get_position_ids_and_mask_encoder_batch",
&GetPositionIdsAndMaskEncoderBatch,
"get_position_ids_and_mask_encoder_batch function");
m.def("get_position_ids", &GetPositionIds, "get_position_ids function");

/**
* cutlass_scaled_mm.cu
67 changes: 67 additions & 0 deletions custom_ops/gpu_ops/get_position_ids.cu
@@ -0,0 +1,67 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"
#include "paddle/extension.h"

__global__ void GetPositionIdsKernel(const int* __restrict__ seq_lens_encoder,
const int* __restrict__ seq_lens_decoder,

🟡 Suggestion: seq_lens_encoder is passed into GetPositionIdsKernel but is never read in the function body (every offset in the kernel is computed from seq_lens_this_time and seq_lens_decoder).

Consider removing the parameter, or adding a comment explaining why it is kept (e.g., ABI/API compatibility). If it stays, add (void)seq_lens_encoder; to silence compiler warnings.

// Current: parameter declared but never read
__global__ void GetPositionIdsKernel(const int* __restrict__ seq_lens_encoder,  // unused
                                     const int* __restrict__ seq_lens_decoder,
                                     ...

const int* __restrict__ seq_lens_this_time,
int* __restrict__ position_ids,
const int bsz) {
int current_bid = threadIdx.x;
if (current_bid >= bsz) return;

// Calculate the offset of the current batch in the position_ids buffer
int buffer_offset = 0;
for (int i = 0; i < current_bid; i++) {
buffer_offset += seq_lens_this_time[i];
}

// Calculate the token offset in the current batch
int token_offset = seq_lens_decoder[current_bid];
int token_num_this_batch = seq_lens_this_time[current_bid];
if (token_num_this_batch == 0) return;

// Write position ids for current batch
#pragma unroll
for (int i = 0; i < token_num_this_batch; i++) {
position_ids[buffer_offset + i] = token_offset + i;
}
}

void GetPositionIds(const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& position_ids) {
const int bsz = seq_lens_this_time.shape()[0];
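// NOTE (editor): single-block launch below, so this assumes bsz does not exceed
// the per-block thread limit (1024 on current NVIDIA GPUs).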

GetPositionIdsKernel<<<1, bsz, 0, position_ids.stream()>>>(
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
seq_lens_this_time.data<int>(),
const_cast<int*>(position_ids.data<int>()),
bsz);
}

PD_BUILD_STATIC_OP(get_position_ids)
.Inputs({
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"position_ids",
})
.Outputs({"position_ids_out"})
.SetInplaceMap({{"position_ids", "position_ids_out"}})
.SetKernelFn(PD_KERNEL(GetPositionIds));
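
For reference, a minimal NumPy sketch (illustrative only, not part of this PR) of what GetPositionIdsKernel computes: each batch b writes a contiguous run of positions starting at seq_lens_decoder[b], with run length seq_lens_this_time[b], packed back-to-back in batch order. The Python import path in the trailing comment is an assumption.

import numpy as np

def get_position_ids_reference(seq_lens_decoder: np.ndarray,
                               seq_lens_this_time: np.ndarray) -> np.ndarray:
    """CPU reference for GetPositionIdsKernel: per-batch position runs, packed flat."""
    runs = [
        np.arange(start, start + n, dtype=np.int32)  # token_offset + i for i in [0, n)
        for start, n in zip(seq_lens_decoder, seq_lens_this_time)
        if n > 0  # batches with seq_lens_this_time == 0 write nothing
    ]
    return np.concatenate(runs) if runs else np.zeros(0, dtype=np.int32)

# e.g. seq_lens_decoder=[5, 0, 12], seq_lens_this_time=[1, 0, 3]
# -> position_ids = [5, 12, 13, 14]

# After building the extension, the renamed op is exposed as "get_position_ids";
# the exact import path below is an assumption:
# from fastdeploy.model_executor.ops.gpu import get_position_ids
# get_position_ids(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, position_ids)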
79 changes: 0 additions & 79 deletions custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu

This file was deleted.

6 changes: 3 additions & 3 deletions custom_ops/setup_ops.py
@@ -250,7 +250,7 @@ def find_end_files(directory, end_str):
"gpu_ops/speculate_decoding/speculate_step.cu",
"gpu_ops/speculate_decoding/speculate_step_system_cache.cu",
"gpu_ops/speculate_decoding/speculate_update_v3.cu",
"gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
"gpu_ops/get_position_ids.cu",
"gpu_ops/fused_rotary_position_encoding.cu",
"gpu_ops/step_reschedule.cu",
]
@@ -326,7 +326,7 @@ def find_end_files(directory, end_str):
"gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
"gpu_ops/sample_kernels/top_k_renorm_probs.cu",
"gpu_ops/sample_kernels/min_p_sampling_from_probs.cu",
"gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
"gpu_ops/get_position_ids.cu",
"gpu_ops/fused_rotary_position_encoding.cu",
"gpu_ops/noaux_tc.cu",
"gpu_ops/noaux_tc_redundant.cu",
@@ -687,7 +687,7 @@ def find_end_files(directory, end_str):
"gpu_ops/fused_rotary_position_encoding.cu",
"gpu_ops/text_image_gather_scatter.cu",
"gpu_ops/text_image_index_out.cu",
"gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
"gpu_ops/get_position_ids.cu",
"gpu_ops/limit_thinking_content_length.cu",
"gpu_ops/update_attn_mask_offsets.cu",
"gpu_ops/append_attn/mla_cache_kernel.cu",
107 changes: 50 additions & 57 deletions fastdeploy/model_executor/layers/moe/routing_indices_cache.py
@@ -110,9 +110,8 @@ class RoutedExpertsCapturer:
Does NOT manage request lifecycle — that is handled by RoutingCacheManager on the Engine side.
"""

def __init__(self, fd_config: FDConfig, block_table, total_block_num):
def __init__(self, fd_config: FDConfig, total_block_num: int):
self.fd_config = fd_config
self.block_table = block_table
self.max_num_seqs = fd_config.scheduler_config.max_num_seqs

# Read routing params from centralized config
@@ -125,20 +124,23 @@ def __init__(self, fd_config: FDConfig, block_table, total_block_num):
logger.info(f"[R3] RoutedExpertsCapturer config: {rrc}")

self._init_routing_cache(dtype=self.routing_dtype, total_block_num=total_block_num)
self.pending_update_positions = None

def _init_routing_cache(self, dtype: str, total_block_num: int):
"""Initialize GPU transient buffer and prepare lazy SharedMemory attach."""
"""Initialize GPU transient buffer, staging buffers, and CPU pinned buffers."""
max_num_kv_tokens = total_block_num * self.fd_config.cache_config.block_size

# Small GPU transient buffer: only current step's token routing
# TODO(Chengyanfu): Use max_num_batched_tokens to replace get_max_chunk_tokens()
max_num_batched_tokens = self.fd_config.get_max_chunk_tokens()
self.gpu_routing_buffer = paddle.full(
shape=[max_num_batched_tokens, self.num_moe_layers, self.moe_top_k],
fill_value=-1,
dtype=dtype,
)
shape = [max_num_batched_tokens, self.num_moe_layers, self.moe_top_k]

self.gpu_routing_buffer = paddle.full(shape=shape, fill_value=-1, dtype=dtype)
self.routing_staging_buf = paddle.full(shape=shape, fill_value=-1, dtype=dtype)
self.slot_mapping_staging_buf = paddle.zeros([max_num_batched_tokens], dtype=paddle.int64)

self.cpu_routing_buf = paddle.zeros(shape, dtype=dtype).pin_memory()
self.cpu_slot_mapping_buf = paddle.zeros([max_num_batched_tokens], dtype=paddle.int64).pin_memory()
self._pending_save = None # {"num_tokens": int}

# Lazy attach to SharedMemory routing_host_buffer (created by Engine after profiling)
self.routing_host_view = None
@@ -173,67 +175,58 @@ def _try_attach_routing_host_view():
"Routing capture will be skipped."
)

def save_captured_routing(self, num_tokens: int, slot_mapping: np.ndarray):
def prepare_pending_save(self, num_tokens: int, slot_mapping_gpu: paddle.Tensor):
"""
Enqueue D2D + async D2H for routing data and slot_mapping.
Must be called before post_process_event.record().
All ops are enqueued on the current CUDA stream; CPU returns immediately.

1. D2D (non-blocking): gpu_routing_buffer → routing_staging_buf
2. D2D (non-blocking): slot_mapping_gpu → slot_mapping_staging_buf
3. async D2H: routing_staging_buf → cpu_routing_buf
4. async D2H: slot_mapping_staging_buf → cpu_slot_mapping_buf
"""
if num_tokens > 0:
# D2D: GPU → staging
self.routing_staging_buf.copy_(self.gpu_routing_buffer, False)
self.slot_mapping_staging_buf.copy_(slot_mapping_gpu, False)
# async D2H: staging → CPU pinned
self.cpu_routing_buf.copy_(self.routing_staging_buf, False)
self.cpu_slot_mapping_buf.copy_(self.slot_mapping_staging_buf, False)
self._pending_save = {"num_tokens": num_tokens}
else:
self._pending_save = None

def flush_pending_save(self):
"""
After forward, scatter GPU buffer routing data to routing_host_buffer.
Called in step gap (post_process), not during forward. CUDAGraph compatible.
Pure CPU operation. Called after post_process_event.synchronize(),
which guarantees all D2D and D2H transfers have completed.
Scatter from CPU pinned buffers to SharedMemory.
"""
assert slot_mapping.shape[0] == num_tokens
if num_tokens == 0:
pending = self._pending_save
if pending is None:
return

# Lazy attach to SharedMemory (Engine creates it after profiling completes)
if self.routing_host_view is None and not self._routing_host_view_attach_attempted:
self._try_attach_routing_host_view()
self._pending_save = None

if self.routing_host_view is None:
return
if not self._routing_host_view_attach_attempted:
self._try_attach_routing_host_view()
if self.routing_host_view is None:
return

# D2H copy: GPU → CPU numpy, then scatter to SharedMemory
data = self.gpu_routing_buffer[:num_tokens].cpu().numpy()
self.routing_host_view.scatter(slot_mapping, data)
num_tokens = pending["num_tokens"]
data = self.cpu_routing_buf[:num_tokens].numpy()
slot_np = self.cpu_slot_mapping_buf[:num_tokens].numpy()

def compute_slot_mapping_flat(self, positions) -> np.ndarray:
"""
Compute flat slot_mapping for all tokens in the step.
Returns a 1D numpy array of slot indices.
"""
all_slots = []
block_size = self.fd_config.cache_config.block_size
for batch_id, position in enumerate(positions):
if len(position) == 0:
continue
block_table_indices = position // block_size
token_block_ids = self.block_table[batch_id, block_table_indices]
block_offset = position % block_size
token_cache_ids = np.array(token_block_ids) * block_size + block_offset
all_slots.append(token_cache_ids)
if all_slots:
return np.concatenate(all_slots)
return np.array([], dtype=np.int64)

def get_token_positions(self, seq_lens_decoder, seq_lens_this_time):
"""Get token position of each sequence in a batch."""
starts = seq_lens_decoder.numpy()
increase_num = seq_lens_this_time.numpy()

positions = []
for i in range(seq_lens_this_time.shape[0]):
if increase_num[i] == 0:
positions.append([])
continue
repeated_base = np.repeat(starts[i], increase_num[i])
positions.append(repeated_base + np.arange(0, increase_num[i]))

return positions
self.routing_host_view.scatter(slot_np, data)

def get_gpu_routing_buffer(self) -> paddle.Tensor:
return self.gpu_routing_buffer

def clear(self):
"""Clear GPU buffer and pending positions. Used during RL round cleanup."""
"""Clear GPU buffer and pending save state. Used during RL round cleanup."""
self.gpu_routing_buffer.fill_(-1)
self.pending_update_positions = None
self._pending_save = None


# Backward compatibility alias
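
A hedged sketch of the call-order contract described in the docstrings above. The event name post_process_event comes from those docstrings; the driver function, variable names, and the use of paddle.device.cuda.Event are assumptions, not code from this PR.

import paddle

def save_routing_for_step(capturer, share_inputs, post_process_event: paddle.device.cuda.Event):
    # Hypothetical driver; shows only the ordering between prepare and flush.
    num_tokens = int(share_inputs["ids_remove_padding"].shape[0])
    slot_mapping_gpu = share_inputs["slot_mapping_buffer"]
    capturer.prepare_pending_save(num_tokens, slot_mapping_gpu)  # enqueue D2D + async D2H
    post_process_event.record()       # record only after the copies are enqueued
    post_process_event.synchronize()  # guarantees all D2D/D2H transfers completed
    capturer.flush_pending_save()     # pure CPU: pinned buffers -> SharedMemory scatter

In the PR itself these steps are split across post_process and the following step gap; the sketch collapses them to make the ordering explicit.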
28 changes: 6 additions & 22 deletions fastdeploy/model_executor/pre_and_post_process.py
@@ -339,18 +339,10 @@ def post_process_normal(

# Routing replay
if routing_replay_manager is not None:
# Trigger lazy SharedMemory attach if not yet attempted
routing_replay_manager._try_attach_routing_host_view()
# GPU transient buffer → SharedMemory routing_host_buffer
slot_mapping_flat = routing_replay_manager.compute_slot_mapping_flat(
positions=routing_replay_manager.pending_update_positions
)
num_tokens = len(slot_mapping_flat)
slot_mapping_gpu = share_inputs["slot_mapping_buffer"]
num_tokens = int(share_inputs["ids_remove_padding"].shape[0])
if routing_replay_manager.tp_rank == 0:
routing_replay_manager.save_captured_routing(
num_tokens=num_tokens,
slot_mapping=slot_mapping_flat,
)
routing_replay_manager.prepare_pending_save(num_tokens, slot_mapping_gpu)

# 2. Update the input buffer of the model
with paddle.framework._no_check_dy2st_diff():
@@ -521,18 +513,10 @@ def post_process_speculate(

# Routing replay
if routing_replay_manager is not None:
# Trigger lazy SharedMemory attach if not yet attempted
routing_replay_manager._try_attach_routing_host_view()
# GPU transient buffer → SharedMemory routing_host_buffer
slot_mapping_flat = routing_replay_manager.compute_slot_mapping_flat(
positions=routing_replay_manager.pending_update_positions
)
num_tokens = len(slot_mapping_flat)
slot_mapping_gpu = share_inputs["slot_mapping_buffer"]
num_tokens = int(share_inputs["ids_remove_padding"].shape[0])
if routing_replay_manager.tp_rank == 0:
routing_replay_manager.save_captured_routing(
num_tokens=num_tokens,
slot_mapping=slot_mapping_flat,
)
routing_replay_manager.prepare_pending_save(num_tokens, slot_mapping_gpu)

# Unified state update: merges speculate_update + speculate_set_value_by_flags_and_idx
# into a single kernel launch. Handles EOS detection, max_dec_len truncation, step_idx