
Conversation

Member

@Isotr0py Isotr0py commented Sep 1, 2025

Purpose

  • After Update PyTorch to 2.8.0 #20358, we upgraded Triton to 3.4.0, which fixed the issue that broke FP16 computation on pre-Ampere GPUs, so we can now fully enable V1 on Volta and Turing.

Test Plan

python examples/offline_inference/basic/basic.py
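
For reference, the script run here is roughly equivalent to the following minimal sketch of vLLM's offline LLM API (the exact contents of basic.py may differ; the sampling parameters below are an assumption, and the prompts mirror the Generated Outputs shown under Test Result):

from vllm import LLM, SamplingParams

# The four prompts seen in the Generated Outputs section below.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Sampling settings are an assumption; the shipped example may use different values.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# OPT-125m loads in FP16 here (see dtype=torch.float16 in the engine config below),
# which is exactly the pre-Ampere path this PR re-enables on V1.
llm = LLM(model="facebook/opt-125m")
for output in llm.generate(prompts, sampling_params):
    print(f"Prompt:    {output.prompt!r}")
    print(f"Output:    {output.outputs[0].text!r}")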

Test Result

Confirmed on a T4 machine:

(EngineCore_0 pid=1333) INFO 09-01 05:34:08 [core.py:75] Initializing a V1 LLM engine (v0.10.1rc2.dev405+g8c742a66d) with config: model='facebook/opt-125m', speculative_config=None, tokenizer='facebook/opt-125m', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=facebook/opt-125m, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, pooler_config=None, compilation_config={"level":3,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":[],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output","vllm.mamba_mixer2","vllm.mamba_mixer","vllm.short_conv","vllm.linear_attention"],"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"cudagraph_mode":1,"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"pass_config":{},"max_capture_size":512,"local_cache_dir":null}
(EngineCore_0 pid=1333) ERROR 09-01 05:34:11 [fa_utils.py:57] Cannot use FA version 2 is not supported due to FA2 is only supported on devices with compute capability >= 8
[W901 05:34:22.565738035 socket.cpp:200] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W901 05:34:32.576528242 socket.cpp:200] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W901 05:34:32.577286626 ProcessGroupNCCL.cpp:981] Warning: TORCH_NCCL_AVOID_RECORD_STREAMS is the default now, this environment variable is thus deprecated. (function operator())
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
(EngineCore_0 pid=1333) INFO 09-01 05:34:32 [parallel_state.py:1134] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
(EngineCore_0 pid=1333) WARNING 09-01 05:34:32 [topk_topp_sampler.py:69] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
(EngineCore_0 pid=1333) INFO 09-01 05:34:32 [gpu_model_runner.py:1926] Starting to load model facebook/opt-125m...
(EngineCore_0 pid=1333) INFO 09-01 05:34:32 [gpu_model_runner.py:1958] Loading model from scratch...
(EngineCore_0 pid=1333) INFO 09-01 05:34:32 [cuda.py:334] Using FlexAttention backend on V1 engine.
(EngineCore_0 pid=1333) INFO 09-01 05:34:33 [weight_utils.py:304] Using model weights format ['*.safetensors', '*.bin', '*.pt']
pytorch_model.bin: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 251M/251M [00:01<00:00, 196MB/s]
(EngineCore_0 pid=1333) INFO 09-01 05:34:34 [weight_utils.py:325] Time spent downloading weights for facebook/opt-125m: 1.631956 seconds
Loading pt checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  3.39it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  3.39it/s]
(EngineCore_0 pid=1333) 
(EngineCore_0 pid=1333) INFO 09-01 05:34:35 [default_loader.py:267] Loading weights took 0.30 seconds
(EngineCore_0 pid=1333) INFO 09-01 05:34:35 [gpu_model_runner.py:1980] Model loading took 0.2389 GiB and 2.342473 seconds
(EngineCore_0 pid=1333) INFO 09-01 05:34:38 [backends.py:538] Using cache directory: /root/.cache/vllm/torch_compile_cache/d4f9f3287c/rank_0_0/backbone for vLLM's torch.compile
(EngineCore_0 pid=1333) INFO 09-01 05:34:38 [backends.py:549] Dynamo bytecode transform time: 2.61 s
(EngineCore_0 pid=1333) INFO 09-01 05:34:41 [backends.py:194] Cache the graph for dynamic shape for later use
(EngineCore_0 pid=1333) [rank0]:W0901 05:34:42.399000 1333 torch/_inductor/utils.py:1436] [0/0] Not enough SMs to use max_autotune_gemm mode
(EngineCore_0 pid=1333) INFO 09-01 05:34:47 [backends.py:215] Compiling a graph for dynamic shape takes 9.14 s
(EngineCore_0 pid=1333) INFO 09-01 05:34:48 [monitor.py:34] torch.compile takes 11.75 s in total
(EngineCore_0 pid=1333) INFO 09-01 05:34:49 [gpu_worker.py:276] Available KV cache memory: 12.54 GiB
(EngineCore_0 pid=1333) INFO 09-01 05:34:49 [kv_cache_utils.py:849] GPU KV cache size: 365,200 tokens
(EngineCore_0 pid=1333) INFO 09-01 05:34:49 [kv_cache_utils.py:853] Maximum concurrency for 2,048 tokens per request: 178.32x
Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|████████████████████████████████████████████████████████████████████████████████| 67/67 [00:01<00:00, 46.10it/s]
(EngineCore_0 pid=1333) INFO 09-01 05:34:51 [gpu_model_runner.py:2681] Graph capturing finished in 2 secs, took 0.21 GiB
(EngineCore_0 pid=1333) INFO 09-01 05:34:51 [core.py:217] init engine (profile, create kv cache, warmup model) took 16.03 seconds
INFO 09-01 05:34:52 [llm.py:283] Supported_tasks: ['generate']
Adding requests: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 781.03it/s]
Processed prompts:  25%|████████████▊                                      | 1/4 [00:16<00:49, 16.44s/it, est. speed input: 0.36 toks/s, output: 0.97 toks/s]Processed prompts: 100%|███████████████████████████████████████████████████| 4/4 [00:16<00:00,  4.12s/it, est. speed input: 1.58 toks/s, output: 3.89 toks/s]

Generated Outputs:
------------------------------------------------------------
Prompt:    'Hello, my name is'
Output:    ' Joel, I am a 4yo, I am very naughty, and I like'
------------------------------------------------------------
Prompt:    'The president of the United States is'
Output:    ' reportedly holding back from issuing a statement about the Ukraine crisis after he called the International'
------------------------------------------------------------
Prompt:    'The capital of France is'
Output:    ' the capital of the French colony of Ireland.\nLaws of France are not'
------------------------------------------------------------
Prompt:    'The future of AI is'
Output:    " in the hands of the vast majority of people - you've probably seen it already"
------------------------------------------------------------


Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request removes a workaround that disabled V1 engine support for pre-Ampere GPUs (Volta, Turing) using FP16 precision. The change is justified by an upgrade to Triton v3.4.0, which fixes the underlying bug. By deleting this conditional check in vllm/engine/arg_utils.py, the PR correctly re-enables V1 FP16 inference on these GPU architectures, broadening hardware compatibility. The change is clear, concise, and supported by test results on a T4 machine. I find the change to be correct and have no further suggestions.
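
For context, the removed workaround amounted to a device-capability/dtype guard of roughly the following shape (an illustrative sketch only, not the actual code in vllm/engine/arg_utils.py; the function name is made up):

import torch

def _fp16_v1_supported_on_current_gpu() -> bool:
    # Sketch of the kind of guard this PR removes: before the Triton 3.4.0
    # upgrade, FP16 on pre-Ampere GPUs (compute capability < 8.0, i.e.
    # Volta/Turing) was broken, so V1 had to be disabled for them.
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8

With Triton 3.4.0 the FP16 path works on these architectures, so the guard is no longer needed and V1 can be selected on Volta and Turing unconditionally.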

Member

@DarkLight1337 DarkLight1337 left a comment

Nice!

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) September 1, 2025 05:47
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Sep 1, 2025
@DarkLight1337 DarkLight1337 merged commit d7fbc6d into vllm-project:main Sep 1, 2025
45 of 47 checks passed
@Isotr0py Isotr0py deleted the v1-turing branch September 1, 2025 09:52

Jyothirmaikottu commented Sep 8, 2025

Hi team, I'm facing an error when I try to run "vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --tensor-parallel-size 2 --max-num-batched-tokens 16384".

I built vLLM wheels from source to support an NVIDIA T4 GPU on an arm64 host.
My environment:
cuda=12.8
torch=2.7.0
pytorch-triton=3.4.0
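
For reference, a quick way to confirm which versions are actually being imported (the PR above assumes the PyTorch 2.8.0 / Triton 3.4.0 combination):

python3 -c "import torch, triton; print('torch:', torch.__version__); print('triton:', triton.__version__)"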

I used the source code from the main branch to build the vLLM wheels. Basic inference works, but when I try to run the server I get the following error:

  • Basic inference:
docker run --rm \
    -v /fsx/vllm-dlc/vllm:/vllm \
    --entrypoint /bin/bash \
    -e "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
    -e "VLLM_WORKER_MULTIPROC_METHOD=spawn" \
    -e "VLLM_USE_V1=0" \
    -v /fsx/.cache/huggingface:/root/.cache/huggingface \
    --gpus=all \
    vllm-arm64-image \
    -c "python3 /vllm/examples/offline_inference/basic/generate.py \
        --model ${MODEL_NAME} \
        --dtype half \
        --tensor-parallel-size 1 \
        --max-model-len 2048"
        

Output:

--------------------------------------------------
Prompt: 'Hello, my name is'
Generated text: " ... and I'm trying to solve this problem: For each positive integer \\( n"
--------------------------------------------------
Prompt: 'The president of the United States is'
Generated text: ' the only one with the ability to... The president of the United States is the'
--------------------------------------------------
Prompt: 'The capital of France is'
Generated text: ' the capital of Germany. The capital of Germany is the capital of Japan. The'
--------------------------------------------------
Prompt: 'The future of AI is'
Generated text: ' a big topic. I want to understand the future of AI in terms of...'
--------------------------------------------------
[rank0]:[W908 16:44:12.547704272 ProcessGroupNCCL.cpp:1476] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())

Error from the vllm serve command:

(VllmWorkerProcess pid=421) ERROR 09-08 17:17:18 [multiproc_worker_utils.py:232] torch._inductor.exc.InductorError: AttributeError: type object 'CompiledKernel' has no attribute 'launch_enter_hook'
(VllmWorkerProcess pid=421) ERROR 09-08 17:17:18 [multiproc_worker_utils.py:232]
(VllmWorkerProcess pid=421) ERROR 09-08 17:17:18 [multiproc_worker_utils.py:232] Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
