7 changes: 3 additions & 4 deletions csrc/gpu/get_output.cc
@@ -20,8 +20,7 @@

#include "paddle/extension.h"

#define MAX_BSZ 512
#define SPECULATE_MAX_BSZ 256
#define MAX_BSZ 256
#define MAX_DRAFT_TOKENS 6

template <int SIZE>
@@ -70,8 +69,8 @@ void GetOutput(const paddle::Tensor& x,
static struct MsgData<SIZE> msg_rcv;
GetOutputFunc<SIZE>(msg_rcv, x, rank_id, wait_flag);
} else {
constexpr int SIZE = SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS +
SPECULATE_MAX_BSZ +
constexpr int SIZE = MAX_BSZ * MAX_DRAFT_TOKENS +
MAX_BSZ +
2; // stop_flag, bsz, accept_num*bsz, tokens...
static struct MsgData<SIZE> specu_msg_rcv;
GetOutputFunc<SIZE>(specu_msg_rcv, x, rank_id, wait_flag);
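Note: with the unified macro, the sizes of the two IPC message layouts work out as follows. This is a minimal sketch (not part of the diff), assuming the macro values defined above; the MAX_BSZ + 2 size for the non-speculative path matches the reader-side shape used in llm/predict/predictor.py further down.

MAX_BSZ = 256
MAX_DRAFT_TOKENS = 6

# Non-speculative message: stop_flag, bsz, then one token per batch slot.
normal_size = MAX_BSZ + 2                                     # 258

# Speculative message: stop_flag, bsz, accept_num per batch slot,
# then up to MAX_DRAFT_TOKENS accepted tokens per batch slot.
speculative_size = MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2   # 1794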
13 changes: 6 additions & 7 deletions csrc/gpu/save_with_output_msg.cc
@@ -20,8 +20,7 @@

#include "paddle/extension.h"

#define MAX_BSZ 512
#define SPECULATE_MAX_BSZ 256
#define MAX_BSZ 256
#define MAX_DRAFT_TOKENS 6

template <int SIZE>
@@ -63,15 +62,15 @@ void SaveOutMsgFunc(MsgData<SIZE>& msg_sed, // NOLINT
msg_sed.mtype = 1;
msg_sed.mtext[0] = not_need_stop_data[0] ? 1 : -1;
msg_sed.mtext[1] = bsz;
for (int i = 2; i < SPECULATE_MAX_BSZ + 2; i++) {
for (int i = 2; i < MAX_BSZ + 2; i++) {
if (i - 2 >= bsz) {
msg_sed.mtext[i] = 0;
} else {
msg_sed.mtext[i] = (int)accept_num_data[i - 2];
}
}
for (int i = SPECULATE_MAX_BSZ + 2; i < SIZE; i++) {
int token_id = i - SPECULATE_MAX_BSZ - 2;
for (int i = MAX_BSZ + 2; i < SIZE; i++) {
int token_id = i - MAX_BSZ - 2;
int bid = token_id / MAX_DRAFT_TOKENS;
int local_token_id = token_id % MAX_DRAFT_TOKENS;
if (token_id / MAX_DRAFT_TOKENS >= bsz) {
@@ -97,8 +96,8 @@ void SaveOutMsg(const paddle::Tensor& x,
static struct MsgData<SIZE> msg_sed;
SaveOutMsgFunc<SIZE>(msg_sed, x, not_need_stop, accept_num, rank_id);
} else {
constexpr int SIZE = SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS +
SPECULATE_MAX_BSZ +
constexpr int SIZE = MAX_BSZ * MAX_DRAFT_TOKENS +
MAX_BSZ +
2; // stop_flag, bsz, accept_num*bsz, tokens...
static struct MsgData<SIZE> specu_msg_sed;
SaveOutMsgFunc<SIZE>(specu_msg_sed, x, not_need_stop, accept_num, rank_id);
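For reference, the packing scheme used by SaveOutMsgFunc above can be mimicked in plain Python. This is an illustrative sketch only; the helper name and the accept_tokens layout are assumptions, not part of the PR.

MAX_BSZ = 256
MAX_DRAFT_TOKENS = 6
SIZE = MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2

def pack_speculative_msg(not_need_stop, bsz, accept_num, accept_tokens):
    # accept_tokens is assumed to be a bsz x MAX_DRAFT_TOKENS list of token ids.
    mtext = [0] * SIZE
    mtext[0] = 1 if not_need_stop else -1       # stop flag
    mtext[1] = bsz                              # batch size
    for i in range(2, MAX_BSZ + 2):             # accept_num per batch slot, zero-padded
        mtext[i] = accept_num[i - 2] if i - 2 < bsz else 0
    for i in range(MAX_BSZ + 2, SIZE):          # accepted draft tokens, zero-padded
        token_id = i - MAX_BSZ - 2
        bid = token_id // MAX_DRAFT_TOKENS
        local_token_id = token_id % MAX_DRAFT_TOKENS
        mtext[i] = accept_tokens[bid][local_token_id] if bid < bsz else 0
    return mtext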
6 changes: 3 additions & 3 deletions llm/predict/predictor.py
@@ -45,7 +45,7 @@
PretrainedTokenizer,
)
from paddlenlp.trl import llm_utils
from paddlenlp.utils.env import MAX_BSZ, MAX_DRAFT_TOKENS, SPECULATE_MAX_BSZ
from paddlenlp.utils.env import MAX_BSZ, MAX_DRAFT_TOKENS
from paddlenlp.utils.import_utils import is_paddlenlp_ops_available
from paddlenlp.utils.log import logger

@@ -1039,7 +1039,7 @@ def predict(self, input_texts: list[str], return_tokens=False):
output_tensor_shape = [MAX_BSZ + 2, 1]
else:
read_res_func = llm_utils.speculate_read_res
output_tensor_shape = [SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1]
output_tensor_shape = [MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2, 1]

read_res_process = mp.Process(
target=read_res_func, args=[self.model_name_or_path, tensor_queue, result_queue, done_event]
@@ -1186,7 +1186,7 @@ def predict(self, input_texts: list[str], return_tokens=False):
output_tensor_shape = [MAX_BSZ + 2, 1]
else:
read_res_func = llm_utils.speculate_read_res
output_tensor_shape = [SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1]
output_tensor_shape = [MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2, 1]

read_res_process = mp.Process(
target=read_res_func, args=[self.model_name_or_path, tensor_queue, result_queue, done_event]
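The reader-side tensor shape has to match the size of the C++ message written in csrc/gpu/save_with_output_msg.cc; with the macro values above this works out as follows (illustration only):

MAX_BSZ, MAX_DRAFT_TOKENS = 256, 6
output_tensor_shape = [MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2, 1]   # [1794, 1]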
189 changes: 141 additions & 48 deletions paddlenlp/experimental/transformers/generation_utils.py
@@ -555,6 +555,24 @@ def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
seq_len = encoder_output.shape[1]
return paddle.ones([batch_size, seq_len], dtype="int64") * bos_token_id

def get_output_padding_offset(self, seq_lens_this_time, seq_lens_encoder, seq_lens_decoder):
"""
In the scenario of speculative decoding, the number of output tokens after rebuild_padding is no longer bsz,
so we need to calculate the output_padding_offset after rebuild_padding.
"""
from paddlenlp_ops import (
speculate_get_output_padding_offset,
speculate_get_seq_lens_output,
)

seq_lens_output = speculate_get_seq_lens_output(seq_lens_this_time, seq_lens_encoder, seq_lens_decoder)
out_token_num = paddle.sum(seq_lens_output)
output_cum_offsets_tmp = paddle.cumsum(self.max_seq_len - seq_lens_output)
output_padding_offset, output_cum_offsets = speculate_get_output_padding_offset(
output_cum_offsets_tmp, out_token_num, seq_lens_output, self.max_seq_len
)
return output_padding_offset, output_cum_offsets
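
(Rough illustration, not the actual kernels: the two speculate_get_* custom ops compute per-sequence output lengths and the resulting padding offsets. Assuming seq_lens_output holds the number of output tokens per sequence, the bookkeeping looks roughly like the pure-Python toy below; the exact semantics are defined by the ops themselves.)

def toy_output_offsets(seq_lens_output, max_seq_len):
    output_padding_offset = []   # per output token: padding slots skipped before it
    output_cum_offsets = []      # per sequence: cumulative padding before its block
    cum_pad = 0
    for n in seq_lens_output:
        output_cum_offsets.append(cum_pad)
        output_padding_offset.extend([cum_pad] * n)
        cum_pad += max_seq_len - n
    return output_padding_offset, output_cum_offsets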

@paddle.no_grad()
def generate(
self,
@@ -665,66 +683,141 @@ def _post_process_(
):
step_idx = model_kwargs["step_idx"]
logits = paddle.cast(outputs, paddle.float32)
from paddlenlp_ops import set_preids_token_penalty_multi_scores

set_preids_token_penalty_multi_scores(
model_kwargs["pre_ids"],
model_kwargs["input_ids"],
model_kwargs["seq_lens_encoder"],
model_kwargs["seq_lens_decoder"],
step_idx,
model_kwargs["stop_flags"],
logits,
penalty_score,
frequency_score,
presence_score,
temperature,
model_kwargs["bad_tokens"],
step_idx,
model_kwargs["min_dec_len"],
eos_token_id,
)
# TODO(Wanglongzhi2001): token_penalty of speculative decoding
if not is_speculative_decoding:
from paddlenlp_ops import set_preids_token_penalty_multi_scores

set_preids_token_penalty_multi_scores(
model_kwargs["pre_ids"],
model_kwargs["input_ids"],
model_kwargs["seq_lens_encoder"],
model_kwargs["seq_lens_decoder"],
step_idx,
model_kwargs["stop_flags"],
logits,
penalty_score,
frequency_score,
presence_score,
temperature,
model_kwargs["bad_tokens"],
step_idx,
model_kwargs["min_dec_len"],
eos_token_id,
)

# sample
probs = F.softmax(logits)

# compute next_tokens
if use_faster_top_p_sampling():
from paddlenlp_ops import top_p_sampling_reject
from paddlenlp_ops import save_output

next_tokens = top_p_sampling_reject(probs, top_p, 0)
# whether speculative decoding
if not is_speculative_decoding:

# compute next_tokens
if use_faster_top_p_sampling():
from paddlenlp_ops import top_p_sampling_reject

next_tokens = top_p_sampling_reject(probs, top_p, 0)
else:
_, next_tokens = paddle.tensor.top_p_sampling(probs, top_p)

if self.config.tensor_parallel_degree > 1:
paddle.distributed.broadcast(next_tokens, 0)

from paddlenlp_ops import update_inputs_v2

update_inputs_v2(
model_kwargs["stop_flags"],
model_kwargs["step_idx"],
model_kwargs["not_need_stop"],
model_kwargs["seq_lens_this_time"],
model_kwargs["seq_lens_encoder"],
model_kwargs["seq_lens_decoder"],
model_kwargs["max_dec_len"],
model_kwargs["input_ids"],
model_kwargs["stop_nums"],
next_tokens,
model_kwargs["is_block_step"],
eos_token_id,
model_kwargs["next_tokens"],
)

save_output(
next_tokens,
model_kwargs["not_need_stop"],
model_kwargs.get("accept_num", None), # only initialized in speculative decoding
self.config.tensor_parallel_rank,
)
return next_tokens
else:
_, next_tokens = paddle.tensor.top_p_sampling(probs, top_p)
from paddlenlp_ops import (
speculate_set_value_by_flags_and_idx,
speculate_verify_and_update,
top_p_candidates,
)

if self.config.tensor_parallel_degree > 1:
paddle.distributed.broadcast(next_tokens, 0)
verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
probs, top_p, model_kwargs["output_padding_offset"], self.max_candidate_len, self.max_seq_len
) # [token_num, max_candidate_len]

from paddlenlp_ops import update_inputs_v2
# Speculate Verify And Update
speculate_verify_and_update(
model_kwargs["accept_tokens"],
model_kwargs["accept_num"],
model_kwargs["step_idx"],
model_kwargs["seq_lens_encoder"],
model_kwargs["seq_lens_decoder"],
model_kwargs["stop_flags"],
model_kwargs["not_need_stop"],
model_kwargs[
"draft_tokens"
], # Both input and output: the last accepted token needs to be written to position 0.
model_kwargs["seq_lens_this_time"],
verify_tokens,
verify_scores,
model_kwargs["max_dec_len"],
eos_token_id,
model_kwargs["is_block_step"],
model_kwargs["output_cum_offsets"],
actual_candidate_len,
model_kwargs["actual_draft_token_num"],
top_p,
self.max_seq_len,
self.verify_window,
True, # enable_topp
)

update_inputs_v2(
model_kwargs["stop_flags"],
model_kwargs["step_idx"],
model_kwargs["not_need_stop"],
model_kwargs["seq_lens_this_time"],
model_kwargs["seq_lens_encoder"],
model_kwargs["seq_lens_decoder"],
model_kwargs["max_dec_len"],
model_kwargs["input_ids"],
model_kwargs["stop_nums"],
next_tokens,
model_kwargs["is_block_step"],
eos_token_id,
model_kwargs["next_tokens"],
)
from paddlenlp_ops import save_output
save_output(
model_kwargs["accept_tokens"],
model_kwargs["not_need_stop"],
model_kwargs["accept_num"],
self.config.tensor_parallel_rank,
)

save_output(
next_tokens,
model_kwargs["not_need_stop"],
model_kwargs.get("accept_tokens", None), # only initialized in speculative decoding
self.config.tensor_parallel_rank,
# If seq_lens_decoder is 0 (means stop), accept_num should be set to 0
model_kwargs["accept_num"][model_kwargs["seq_lens_decoder"] == 0] = 0

# Update pre_ids through accept tokens
speculate_set_value_by_flags_and_idx(
model_kwargs["pre_ids"],
model_kwargs["accept_tokens"],
model_kwargs["accept_num"],
model_kwargs["stop_flags"],
model_kwargs["seq_lens_this_time"],
model_kwargs["seq_lens_encoder"],
model_kwargs["seq_lens_decoder"],
model_kwargs["step_idx"],
)

is_speculative_decoding = model_kwargs.get("draft_tokens", None) is not None
if is_speculative_decoding:
# Prepare output padding offset
output_padding_offset, output_cum_offsets = self.get_output_padding_offset(
model_kwargs["seq_lens_this_time"], model_kwargs["seq_lens_encoder"], model_kwargs["seq_lens_decoder"]
)
return next_tokens
model_kwargs["output_padding_offset"] = output_padding_offset
model_kwargs["output_cum_offsets"] = output_cum_offsets

# encoder
outputs = _forward_(**model_kwargs) # [bs, 1, dim_embed]
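In summary, the speculative branch of _post_process_ added above runs through the following sequence; this is only a sketch of the flow shown in the diff, not additional code in the PR:

# 1. top_p_candidates                     -> candidate token ids/scores from the target-model probs
# 2. speculate_verify_and_update          -> accept/reject draft tokens, update seq lens and stop flags
# 3. save_output                          -> push accept_tokens / accept_num to the reader process
# 4. zero accept_num for stopped seqs     -> accept_num[seq_lens_decoder == 0] = 0
# 5. speculate_set_value_by_flags_and_idx -> write accepted tokens into pre_ids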