diff --git a/tests/test_utils.py b/tests/test_utils.py index 7099f66356..9907ac8b01 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -68,10 +68,7 @@ def test_config(request: pytest.FixtureRequest): def generate_request_pair(idx: int, block_per_request, num_gpu_blocks, tokens_per_block, dp_size): """Generate a request pair with token_ids, block_ids, and dp_id""" start_idx = (idx * block_per_request) % num_gpu_blocks - if start_idx + block_per_request >= num_gpu_blocks: - start_idx = ( - (start_idx + block_per_request) % num_gpu_blocks - ) + assert start_idx + block_per_request <= num_gpu_blocks block_ids = torch.arange( start_idx, start_idx + block_per_request,