Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 38 additions & 21 deletions csrc/gpu/step.cu
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,19 @@ __global__ void free_and_dispatch_block(bool *stop_flags,
}
} else if (seq_lens_decoder[tid] != 0 && block_table_now[seq_lens_decoder[tid] / block_size] == -1) {
// 统计需要分配block的位置和总数
#ifdef DEBUG_STEP
printf("step seq_id:%d, ##### pin 1 #####\n", tid);
#endif
const int ori_need_block_len = atomicAdd(need_block_len, 1);
need_block_list[ori_need_block_len] = tid;
#ifdef DEBUG_STEP
printf("seq_id: %d need block\n", tid);
#endif
}
}
#ifdef DEBUG_STEP
printf("step seq_id:%d, ##### pin 2 #####\n", tid);
#endif
__syncthreads();
if (tid == 0) {
printf("need_block_len: %d, free_list_len: %d\n", need_block_len[0], free_list_len[0]);
Expand Down Expand Up @@ -102,7 +108,9 @@ __global__ void free_and_dispatch_block(bool *stop_flags,
}
__syncthreads();
}

#ifdef DEBUG_STEP
printf("step seq_id:%d, ##### pin 3 #####\n", tid);
#endif
// 为需要block的位置分配block,每个位置分配一个block
if (tid < need_block_len[0]) {
const int need_block_id = need_block_list[tid];
Expand All @@ -116,37 +124,46 @@ __global__ void free_and_dispatch_block(bool *stop_flags,
need_block_list[tid] = -1;
}
__syncthreads();

#ifdef DEBUG_STEP
printf("step seq_id:%d, ##### pin 4 #####\n", tid);
#endif
// 计算可以复原的query id
if (tid == 0) {
int ori_free_list_len = free_list_len[0];
int ori_step_len = step_len[0];
#ifdef DEBUG_STEP
printf("ori_step_len %d\n", ori_step_len);
int ori_step_block_id = step_block_list[ori_step_len - 1];
int tmp_used_len = used_list_len[ori_step_block_id];
// 比之前调度时多分配一个block,防止马上恢复刚调度的query(比如回收的seq_id在need_block_list中)
int used_len = tmp_used_len < max_decoder_block_num ? tmp_used_len + 1 : tmp_used_len;
while (ori_step_len > 0 && ori_free_list_len >= used_len) {
#endif
if (ori_step_len > 0) {
int ori_step_block_id = step_block_list[ori_step_len - 1];
int tmp_used_len = used_list_len[ori_step_block_id];
// 比之前调度时多分配一个block,防止马上恢复刚调度的query(比如回收的seq_id在need_block_list中)
int used_len = tmp_used_len < max_decoder_block_num ? tmp_used_len + 1 : tmp_used_len;
while (ori_step_len > 0 && ori_free_list_len >= used_len) {
#ifdef DEBUG_STEP
printf("recover seq_id: %d, free_list_len: %d, used_list_len: %d\n",
ori_step_block_id, ori_free_list_len, used_len);
printf("recover seq_id: %d, free_list_len: %d, used_list_len: %d\n",
ori_step_block_id, ori_free_list_len, used_len);
#endif
recover_block_list[recover_len[0]] = ori_step_block_id;
is_block_step[ori_step_block_id] = false;
used_list_len[ori_step_block_id] = used_len;
ori_free_list_len -= used_len;
step_block_list[ori_step_len - 1] = -1;
step_len[0] -= 1;
recover_len[0] += 1;
ori_step_len = step_len[0];
if (ori_step_len > 0) {
ori_step_block_id = step_block_list[ori_step_len - 1];
tmp_used_len = used_list_len[ori_step_block_id];
used_len = tmp_used_len < max_decoder_block_num ? tmp_used_len + 1 : tmp_used_len;
recover_block_list[recover_len[0]] = ori_step_block_id;
is_block_step[ori_step_block_id] = false;
used_list_len[ori_step_block_id] = used_len;
ori_free_list_len -= used_len;
step_block_list[ori_step_len - 1] = -1;
step_len[0] -= 1;
recover_len[0] += 1;
ori_step_len = step_len[0];
if (ori_step_len > 0) {
ori_step_block_id = step_block_list[ori_step_len - 1];
tmp_used_len = used_list_len[ori_step_block_id];
used_len = tmp_used_len < max_decoder_block_num ? tmp_used_len + 1 : tmp_used_len;
}
}
}
need_block_len[0] = 0;
}
#ifdef DEBUG_STEP
printf("step seq_id:%d, ##### pin 5 #####\n", tid);
#endif
}

// 根据上一步计算出的可以复原的query_id进行状态恢复
Expand Down