12 changes: 9 additions & 3 deletions custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu
@@ -38,14 +38,20 @@ __device__ int64_t topp_sampling_kernel(const int64_t *candidate_ids,
   const int tid = threadIdx.x;

   float sum_scores = 0.0f;
-  float rand_top_p = curand_uniform(dev_curand_states + tid) * topp;
+  for (int i = 0; i < candidate_len; i++) {
+    sum_scores += candidate_scores[i];
+  }
+  float tgt_topp = sum_scores < topp ? sum_scores : topp;
+
+  sum_scores = 0.0f;
+  float rand_top_p = curand_uniform(dev_curand_states + tid) * tgt_topp;
   for (int i = 0; i < candidate_len; i++) {
     sum_scores += candidate_scores[i];
     if (rand_top_p <= sum_scores) {
       return candidate_ids[i];
     }
   }
   return candidate_ids[0];
 }

 __global__ void setup_kernel(curandState_t *state, const uint64_t seed,
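This change fixes a sampling bias: the random threshold was previously drawn from [0, topp) before the candidate scores were summed, so when the candidates' total mass fell short of topp the prefix-sum scan could finish without ever crossing the threshold, and the kernel silently fell back to candidate_ids[0]. Clamping the threshold range to tgt_topp = min(sum_scores, topp) keeps it inside the prefix sums. A minimal host-side Python sketch of the corrected logic (names paraphrase the kernel's and are not FastDeploy API):

```python
import random

def top_p_sample(candidate_ids, candidate_scores, top_p):
    """Host-side mirror of the fixed topp_sampling_kernel logic."""
    total = sum(candidate_scores)
    # The fix: clamp the threshold range to the actual candidate mass so the
    # threshold always lands inside the prefix sums and the fallback below is
    # no longer a biased escape hatch toward the first candidate.
    tgt_top_p = min(total, top_p)
    threshold = random.uniform(0.0, tgt_top_p)
    running = 0.0
    for token_id, score in zip(candidate_ids, candidate_scores):
        running += score
        if threshold <= running:
            return token_id
    return candidate_ids[0]  # defensive fallback, as in the kernel

# Example: the candidate mass sums to 0.6 while top_p is 0.9. Before the fix,
# a threshold drawn in (0.6, 0.9) would always return candidate_ids[0].
print(top_p_sample([7, 3, 9], [0.3, 0.2, 0.1], 0.9))
```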
3 changes: 3 additions & 0 deletions custom_ops/gpu_ops/speculate_decoding/top_p_candidates.cu
@@ -467,6 +467,9 @@ __global__ void KeMatrixTopPBeamTopKFt(
         break;
       }
     }
+    if (top_p_value == 1.0 && actual_candidates_lens[token_id] == 0){
+      actual_candidates_lens[token_id] = max_cadidate_len;
+    }
   }
 }
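The added guard handles a floating-point edge case: with top_p == 1.0, the accumulated candidate probabilities can round to just under 1.0, so the cutoff test never fires and actual_candidates_lens stays 0; falling back to max_cadidate_len keeps the full candidate set. A hedged Python sketch of the selection step (the kernel's actual control flow differs; the accumulate-and-break condition here is an assumption):

```python
def top_p_candidate_len(sorted_probs, top_p, max_candidate_len):
    """Sketch of candidate-length selection with the new fallback guard."""
    cumulative = 0.0
    length = 0
    for i, p in enumerate(sorted_probs[:max_candidate_len]):
        cumulative += p
        if cumulative >= top_p:  # assumed cutoff condition
            length = i + 1
            break
    # New guard: with top_p == 1.0, rounding can keep `cumulative` just below
    # 1.0 so the break never fires and length stays 0; use the full candidate
    # set instead of an empty one.
    if top_p == 1.0 and length == 0:
        length = max_candidate_len
    return length

# Probabilities sum to 0.99999 < 1.0, so the guard kicks in and returns 3.
print(top_p_candidate_len([0.5, 0.3, 0.19999], 1.0, 3))
```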
7 changes: 7 additions & 0 deletions fastdeploy/envs.py
@@ -95,6 +95,13 @@
     "FD_FOR_TORCH_MODEL_FORMAT": lambda: bool(int(os.getenv("FD_FOR_TORCH_MODEL_FORMAT", "0"))),
     # force disable default chunked prefill
     "FD_DISABLE_CHUNKED_PREFILL": lambda: bool(int(os.getenv("FD_DISABLE_CHUNKED_PREFILL", "0"))),
+    # For separate setting of sampling parameters for speculative decoding
+    "FD_SPECULATE_SAMPLING_TOP_P": lambda: (
+        None if "FD_SPECULATE_SAMPLING_TOP_P" not in os.environ else float(os.environ["FD_SPECULATE_SAMPLING_TOP_P"])
+    ),
+    "FD_SPECULATE_SAMPLING_TOP_K": lambda: (
+        None if "FD_SPECULATE_SAMPLING_TOP_K" not in os.environ else float(os.environ["FD_SPECULATE_SAMPLING_TOP_K"])
+    ),
     "FD_ENABLE_INTERNAL_ADAPTER": lambda: int(os.getenv("FD_ENABLE_INTERNAL_ADAPTER", "0")),
     # LLMEngine receive requests port, used when FD_ENABLE_INTERNAL_ADAPTER=1
     "FD_ZMQ_RECV_REQUEST_SERVER_PORT": lambda: os.getenv("FD_ZMQ_RECV_REQUEST_SERVER_PORT", "8200"),
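Both new getters return None when the variable is unset, which lets downstream code distinguish "inherit from the target model" from any explicit override; note that TOP_K is also parsed with float(), matching the dtype of the tensor it later fills. A small standalone sketch of the None-when-unset behavior (it does not import fastdeploy; the lambda simply mirrors the entry above):

```python
import os

# Mirror of the FD_SPECULATE_SAMPLING_TOP_P getter from envs.py above.
get_spec_top_p = lambda: (
    None
    if "FD_SPECULATE_SAMPLING_TOP_P" not in os.environ
    else float(os.environ["FD_SPECULATE_SAMPLING_TOP_P"])
)

os.environ.pop("FD_SPECULATE_SAMPLING_TOP_P", None)
assert get_spec_top_p() is None          # unset: draft model inherits top_p
os.environ["FD_SPECULATE_SAMPLING_TOP_P"] = "0.8"
assert get_spec_top_p() == 0.8           # set: draft model overrides top_p
```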
12 changes: 10 additions & 2 deletions fastdeploy/spec_decode/mtp.py
@@ -303,8 +303,16 @@ def _init_model_inputs(self):
         )
         # self.model_inputs["caches"] = self.cache_kvs
         # Inherit generation hyperparameters from the main model for consistency
-        self.model_inputs["top_p"] = self.target_model_inputs["top_p"]
-        self.model_inputs["top_k"] = self.target_model_inputs["top_k"]
+        self.model_inputs["top_p"] = (
+            self.target_model_inputs["top_p"]
+            if envs.FD_SPECULATE_SAMPLING_TOP_P is None
+            else paddle.full_like(self.target_model_inputs["top_p"], envs.FD_SPECULATE_SAMPLING_TOP_P)
+        )
+        self.model_inputs["top_k"] = (
+            self.target_model_inputs["top_k"]
+            if envs.FD_SPECULATE_SAMPLING_TOP_K is None
+            else paddle.full_like(self.target_model_inputs["top_k"], envs.FD_SPECULATE_SAMPLING_TOP_K)
+        )
         self.model_inputs["temperature"] = self.target_model_inputs["temperature"]
         self.model_inputs["eos_token_id"] = self.target_model_inputs["eos_token_id"]
         self.model_inputs["penalty_score"] = self.target_model_inputs["penalty_score"]
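With this change, the MTP draft model inherits top_p/top_k from the target model only when the corresponding FD_SPECULATE_SAMPLING_* variable is unset; otherwise paddle.full_like broadcasts the scalar override across the per-request tensor, preserving its shape and dtype. A minimal sketch of this inherit-or-override pattern (the helper name and batch shape are illustrative, not FastDeploy code):

```python
import paddle

def resolve_sampling_tensor(target_tensor, override):
    """Sketch of the inherit-or-override pattern in _init_model_inputs.

    `override` stands in for envs.FD_SPECULATE_SAMPLING_TOP_P / _TOP_K:
    None means the draft model reuses the target model's per-request tensor;
    otherwise a same-shape, same-dtype tensor filled with the override value
    is created via paddle.full_like.
    """
    if override is None:
        return target_tensor
    return paddle.full_like(target_tensor, override)

# Hypothetical per-request top_p tensor for a batch of 4 requests.
target_top_p = paddle.full([4, 1], 0.7, dtype="float32")
print(resolve_sampling_tensor(target_top_p, None))  # inherits 0.7 everywhere
print(resolve_sampling_tensor(target_top_p, 1.0))   # overridden to 1.0
```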