Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions examples/verify_aim24.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,6 @@ def generate_requests(dataset: Dataset, field_name: str, data_format: str, trial
return requests







def main():
model_name = "Qwen/Qwen3-0.6B"
llm = sgl.Engine(model_path=model_name,
Expand Down
2 changes: 1 addition & 1 deletion examples/verify_algo.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env bash
set -e
# export CUDA_VISIBLE_DEVICES=0
export CUDA_VISIBLE_DEVICES=1

sparse_algos=(
"block_sparse_attention"
Expand Down
25 changes: 0 additions & 25 deletions examples/verify_algo_int8.sh

This file was deleted.

18 changes: 16 additions & 2 deletions examples/verify_algo_fp8.sh → examples/verify_algo_quant.sh
100755 → 100644
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env bash
set -e
# export CUDA_VISIBLE_DEVICES=0
export CUDA_VISIBLE_DEVICES=0

sparse_algos=(
"block_sparse_attention"
Expand All @@ -11,6 +11,20 @@ mkdir -p "${RESULTS_DIR}"
TIMESTAMP=$(date +%Y%m%d_%H%M%S)

for algo in "${sparse_algos[@]}"; do
OUTFILE="${RESULTS_DIR}/${algo}_int8_${TIMESTAMP}.log"
echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype int8"
echo ">>> Saving results to ${OUTFILE}"
{ time python verify_algo.py \
--trials 8 \
--topk-val 30 \
--vortex-module-name "${algo}" \
--model-name Qwen/Qwen3-1.7B \
--kv-cache-dtype int8 \
--mem 0.7 ; } \
2>&1 | tee "${OUTFILE}"
done

for algo in "${sparse_algos[@]}"; do
OUTFILE="${RESULTS_DIR}/${algo}_fp8_${TIMESTAMP}.log"
echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype fp8_e4m3"
echo ">>> Saving results to ${OUTFILE}"
Expand All @@ -22,4 +36,4 @@ TIMESTAMP=$(date +%Y%m%d_%H%M%S)
--kv-cache-dtype fp8_e4m3 \
--mem 0.7 ; } \
2>&1 | tee "${OUTFILE}"
done
done
4 changes: 4 additions & 0 deletions vortex_torch/cache/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ class Context(ContextBase):

# Quantization: quant_type (0=none, 1=int8, 2=e4m3, 3=e5m2),
# kv_scale (per-tensor fp8 scale), kv_scale_ptr (per-token int8 scale tensor)
# fp8_type: 0=none, 1=e4m3, 2=e5m2 (encoding for Triton kernels)
"quant_type",
"kv_scale",
"kv_scale_ptr",
"fp8_type",
)


Expand All @@ -49,6 +51,8 @@ def __init__(self) -> None:
object.__setattr__(self, name, 1.0) # identity scale for bf16
elif name == "kv_scale_ptr":
object.__setattr__(self, name, None) # per-token scale tensor (int8 only)
elif name == "fp8_type":
object.__setattr__(self, name, 0) # 0 = none (bf16 default)
else:
object.__setattr__(self, name, UNSET)

Expand Down
1 change: 1 addition & 0 deletions vortex_torch/flow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ def run_indexer_virtual(self, group_size: int, page_size: int, head_dim: int):
ctx.page_size = page_size
ctx.max_num_pages = 0
ctx.max_num_pages_per_request = 0
ctx.topk_type = "naive"

device = "cuda"
dtype = torch.bfloat16
Expand Down