Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 10 additions & 36 deletions csrc/gpu/speculate_decoding_kernels/ngram_match.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ void find_candidate_pred_tokens(const int64_t *input_ids,
int32_t *seq_lens_this_time,
int32_t *seq_lens_encoder,
int32_t *seq_lens_decoder,
int64_t *max_dec_len,
int64_t input_ids_stride,
int64_t pre_ids_stride,
int64_t draft_tokens_stride,
Expand All @@ -55,8 +56,8 @@ void find_candidate_pred_tokens(const int64_t *input_ids,
}
}
for (int batch_idx = 0; batch_idx < real_batch_size; batch_idx++) {
max_draft_tokens = draft_token_num[batch_idx];
// int local_draft_tokens = max_draft_tokens;
max_draft_tokens = std::min(static_cast<int64_t>(
draft_token_num[batch_idx]), max_dec_len[batch_idx] - step_idx[batch_idx] - 1);
if (seq_lens_encoder[batch_idx] > 0) {
continue;
} else if (seq_lens_decoder[batch_idx] == 0) {
Expand Down Expand Up @@ -90,14 +91,7 @@ void find_candidate_pred_tokens(const int64_t *input_ids,
continue;
}
const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size);
#ifdef _DEBUG
if (batch_idx == 0) {
for (int mm = 0; mm < ngram_size; mm++) {
printf("idx %d: %lld\n", mm, ngram[mm]);
}
}
printf("cur_input_ids_len %d\n", cur_input_ids_len);
#endif

// Iterate through sliding windows of size ngram_size
bool match_input = false;
for (int64_t i = 0; i <= cur_input_ids_len - ngram_size; ++i) {
Expand All @@ -114,13 +108,7 @@ void find_candidate_pred_tokens(const int64_t *input_ids,
int64_t end_idx = std::min(start_idx + max_draft_tokens, cur_input_ids_len);
if (start_idx >= end_idx)
continue;
#ifdef _DEBUG
printf("batch_idx:%d. ngram_size:%d. idx:%lld. \n", batch_idx, ngram_size, i);
printf("start:%d. end:%d.\n", start_idx, end_idx);
#endif
// Ensure we don't go beyond the length of input_ids and avoid self-match
// if (end_idx <= cur_input_ids_len && start_idx < cur_input_ids_len - ngram_size) {
// Return a pointer to the next num_pred_tokens

int64_t cur_draft_token_num = end_idx - start_idx;

seq_lens_this_time[batch_idx] = cur_draft_token_num + 1;
Expand All @@ -133,15 +121,10 @@ void find_candidate_pred_tokens(const int64_t *input_ids,
}
}
if (!match_input) {
#ifdef _DEBUG
printf("match_input is false so match output\n");
#endif
for (int64_t i = 0; i <= cur_step_idx - ngram_size; ++i) {
// Check if the current window matches the ngram
bool match = true;
#ifdef _DEBUG
printf("search %d token in pre_ids\n", i);
#endif

for (int j = 0; j < ngram_size; j++) {
if (ngram[j] != cur_pre_ids[i + j]) {
match = false;
Expand All @@ -150,26 +133,14 @@ void find_candidate_pred_tokens(const int64_t *input_ids,
}

if (match) {
#ifdef _DEBUG
printf("%d token in pre_ids matched\n", i);
#endif
int64_t start_idx = i + ngram_size;
int64_t end_idx = std::min(start_idx + max_draft_tokens, cur_step_idx);
int64_t cur_draft_token_num = end_idx - start_idx;
if (start_idx >= end_idx)
continue;

#ifdef _DEBUG
printf("cur_step_idx %d, start_idx %lld, end_idx %lld, cur_draft_token_num is %lld\n",
cur_step_idx,
start_idx,
end_idx,
cur_draft_token_num);
#endif

seq_lens_this_time[batch_idx] = cur_draft_token_num + 1;
memcpy(cur_draft_tokens + 1, cur_pre_ids + start_idx, sizeof(int64_t) * cur_draft_token_num);
// To break the current batch_idx for-loop
ngram_size = 0;
break;
}
Expand All @@ -188,6 +159,7 @@ void NgramMatch(const paddle::Tensor &input_ids,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &max_dec_len,
const int real_batch_size,
const int max_ngram_size,
const int max_draft_tokens) {
Expand All @@ -210,6 +182,7 @@ void NgramMatch(const paddle::Tensor &input_ids,
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
const_cast<int32_t *>(seq_lens_encoder.data<int32_t>()),
const_cast<int32_t *>(seq_lens_decoder.data<int32_t>()),
const_cast<int64_t *>(max_dec_len.data<int64_t>()),
input_ids_stride,
pre_ids_stride,
draft_tokens_stride,
Expand All @@ -227,7 +200,8 @@ PD_BUILD_OP(ngram_match)
"draft_tokens",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder"})
"seq_lens_decoder",
"max_dec_len"})
.Attrs({"real_batch_size: int", "max_ngram_size: int", "max_draft_tokens: int"})
.Outputs({"draft_tokens_out", "seq_lens_this_time_out"})
.SetKernelFn(PD_KERNEL(NgramMatch))
Expand Down
42 changes: 42 additions & 0 deletions csrc/gpu/speculate_decoding_kernels/speculate_clear_accept_nums.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

// Zero the accepted-token count for every batch slot whose decoder length is
// 0 (i.e. slots that are not actively decoding). One thread per batch slot.
//
// accept_num:       [max_bsz] in/out, accepted-token count per slot.
// seq_lens_decoder: [max_bsz] in, current decoder length per slot.
// max_bsz:          number of valid slots; threads beyond it exit early.
__global__ void speculate_clear_accept_nums_kernel(
                int* accept_num,
                const int* seq_lens_decoder,
                const int max_bsz
) {
    const int bid = threadIdx.x;
    if (bid >= max_bsz) return;
    // Only write when clearing is needed — avoids a redundant global-memory
    // store of accept_num[bid] back onto itself for active slots.
    if (seq_lens_decoder[bid] == 0) {
        accept_num[bid] = 0;
    }
}

// Host launcher: clears accept_num for inactive slots (seq_lens_decoder == 0).
// NOTE(review): a single 1024-thread block caps the supported batch size at
// 1024; slots beyond that would be silently skipped — confirm max_bsz <= 1024
// upstream.
void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
                              const paddle::Tensor& seq_lens_decoder
) {
    const int max_bsz = seq_lens_decoder.shape()[0];
    speculate_clear_accept_nums_kernel<<<1, 1024, 0, accept_num.stream()>>>(
        const_cast<int*>(accept_num.data<int>()),
        seq_lens_decoder.data<int>(), max_bsz);
}

// Op registration. The kernel mutates accept_num (seq_lens_decoder is
// read-only), so accept_num — not seq_lens_decoder — must be the declared
// in-place output; otherwise the framework tracks the wrong tensor as
// modified.
PD_BUILD_OP(speculate_clear_accept_nums)
    .Inputs({"accept_num",
             "seq_lens_decoder"})
    .Outputs({"accept_num_out"})
    .SetInplaceMap({{"accept_num", "accept_num_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateClearAcceptNums));
140 changes: 140 additions & 0 deletions csrc/gpu/speculate_decoding_kernels/speculate_update.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

// Per-step state update for speculative decoding; one thread per batch slot.
// - seq_lens_decoder[bid] advances by the accepted-token count (decode step)
//   or by the encoder length (first step after prefill, which is then zeroed).
// - actual_draft_token_nums[bid] adapts: grow by 2 (capped at
//   max_draft_tokens - 1) when every draft token was accepted, otherwise
//   shrink by 1 (floored at 1).
// - draft_tokens[bid][0] is seeded with the last accepted token for the next
//   draft round (unless the slot has stopped).
// - not_need_stop[0] becomes false only when every slot in [0, real_bsz) has
//   its stop flag set (block-stepped slots count as not stopped, as before).
template <int THREADBLOCK_SIZE>
__global__ void speculate_update(int *seq_lens_encoder,
                                 int *seq_lens_decoder,
                                 bool *not_need_stop,
                                 int64_t *draft_tokens,
                                 int *actual_draft_token_nums,
                                 const int64_t *accept_tokens,
                                 const int *accept_num,
                                 const bool *stop_flags,
                                 const int *seq_lens_this_time,
                                 const bool *is_block_step,
                                 const int real_bsz,
                                 const int max_draft_tokens) {
    const int bid = threadIdx.x;
    int stop_flag_now_int = 0;
    // Bounds check MUST come first: the previous guard
    // `!(is_block_step[bid] || bid >= real_bsz)` evaluated is_block_step[bid]
    // before the range test, and accept_num[bid] was loaded unconditionally —
    // out-of-range reads for threads past real_bsz when tensors are sized to
    // the real batch.
    if (bid < real_bsz && !is_block_step[bid]) {
        const int accept_num_now = accept_num[bid];
        if (stop_flags[bid]) {
            stop_flag_now_int = 1;
        }
        if (seq_lens_encoder[bid] == 0) {
            // Decode step: advance by the tokens verified this round.
            seq_lens_decoder[bid] += accept_num_now;
        }

        if (seq_lens_this_time[bid] > 1 && seq_lens_encoder[bid] == 0) {
            // Append mode: scale the next round's draft budget by whether the
            // verifier accepted the full draft.
            const int cur_draft_num = actual_draft_token_nums[bid];
            if (accept_num_now - 1 == cur_draft_num) {
                // Fully accepted: grow by 2, capped at max_draft_tokens - 1
                // (equivalent to the original three-way if chain).
                actual_draft_token_nums[bid] =
                    min(cur_draft_num + 2, max_draft_tokens - 1);
            } else {
                // Partially accepted: shrink, but keep at least one draft.
                actual_draft_token_nums[bid] = max(cur_draft_num - 1, 1);
            }
        }

        if (seq_lens_encoder[bid] != 0) {
            // First step after prefill: fold the encoder (prompt) length into
            // the decoder length and clear it.
            seq_lens_decoder[bid] += seq_lens_encoder[bid];
            seq_lens_encoder[bid] = 0;
        }
        if (!stop_flags[bid]) {
            // Seed slot 0 of the next draft round with the last accepted token.
            draft_tokens[bid * max_draft_tokens] =
                accept_tokens[bid * max_draft_tokens + accept_num_now - 1];
        }
        if (stop_flag_now_int) {
            seq_lens_decoder[bid] = 0;
        }
    }
    __syncthreads();
    typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    const int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);

    if (threadIdx.x == 0) {
        not_need_stop[0] = stop_sum < real_bsz;
    }
}

// Host launcher for speculate_update. Derives the batch size from
// seq_lens_this_time and the per-slot token stride from draft_tokens'
// second dimension.
// NOTE(review): a single 512-thread block caps the supported batch size at
// 512 — confirm real_bsz <= 512 upstream.
void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
                     const paddle::Tensor &seq_lens_decoder,
                     const paddle::Tensor &not_need_stop,
                     const paddle::Tensor &draft_tokens,
                     const paddle::Tensor &actual_draft_token_nums,
                     const paddle::Tensor &accept_tokens,
                     const paddle::Tensor &accept_num,
                     const paddle::Tensor &stop_flags,
                     const paddle::Tensor &seq_lens_this_time,
                     const paddle::Tensor &is_block_step) {
    const int real_bsz = static_cast<int>(seq_lens_this_time.shape()[0]);
    // shape() yields int64_t; cast explicitly instead of narrowing silently
    // at the kernel-argument boundary.
    const int max_draft_tokens = static_cast<int>(draft_tokens.shape()[1]);

    constexpr int BlockSize = 512;

    speculate_update<BlockSize><<<1, BlockSize, 0, accept_tokens.stream()>>>(
        const_cast<int *>(seq_lens_encoder.data<int>()),
        const_cast<int *>(seq_lens_decoder.data<int>()),
        const_cast<bool *>(not_need_stop.data<bool>()),
        const_cast<int64_t *>(draft_tokens.data<int64_t>()),
        const_cast<int *>(actual_draft_token_nums.data<int>()),
        accept_tokens.data<int64_t>(),
        accept_num.data<int>(),
        stop_flags.data<bool>(),
        seq_lens_this_time.data<int>(),
        is_block_step.data<bool>(),
        real_bsz,
        max_draft_tokens);
}

// Op registration. The five in-place pairs match exactly the tensors the
// kernel writes (seq_lens_encoder, seq_lens_decoder, not_need_stop,
// draft_tokens, actual_draft_token_nums); the remaining inputs are read-only.
PD_BUILD_OP(speculate_update)
    .Inputs({"seq_lens_encoder",
             "seq_lens_decoder",
             "not_need_stop",
             "draft_tokens",
             "actual_draft_token_nums",
             "accept_tokens",
             "accept_num",
             "stop_flags",
             "seq_lens_this_time",
             "is_block_step"})
    .Outputs({"seq_lens_encoder_out",
              "seq_lens_decoder_out",
              "not_need_stop_out",
              "draft_tokens_out",
              "actual_draft_token_nums_out"})
    .SetInplaceMap({{"seq_lens_encoder", "seq_lens_encoder_out"},
                    {"seq_lens_decoder", "seq_lens_decoder_out"},
                    {"not_need_stop", "not_need_stop_out"},
                    {"draft_tokens", "draft_tokens_out"},
                    {"actual_draft_token_nums", "actual_draft_token_nums_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateUpdate));
Loading