@@ -1,7 +1,10 @@
+from typing import Type
+
 import pytest
 import torch

-from vllm.model_executor.layers.activation import FastGELU, NewGELU, SiluAndMul
+from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul,
+                                                    NewGELU, SiluAndMul)
 from allclose_default import get_default_atol, get_default_rtol

 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -13,13 +16,15 @@
 ]


+@pytest.mark.parametrize("activation", [SiluAndMul, GeluAndMul])
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_silu_and_mul(
+def test_act_and_mul(
+    activation: Type[torch.nn.Module],
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
@@ -31,48 +36,23 @@ def test_silu_and_mul(
         torch.cuda.manual_seed(seed)
     torch.set_default_device(device)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype)
-    layer = SiluAndMul()
+    layer = activation()
     out = layer(x)
     ref_out = layer._forward(x)
-    assert torch.allclose(out,
-                          ref_out,
-                          atol=get_default_atol(out),
-                          rtol=get_default_rtol(out))
+    # The SiLU and GELU implementations are equivalent to the native PyTorch
+    # implementations, so we can do exact comparison.
+    assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0)


+@pytest.mark.parametrize("activation", [FastGELU, NewGELU])
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_gelu_new(
-    num_tokens: int,
-    d: int,
-    dtype: torch.dtype,
-    seed: int,
-    device: str,
-) -> None:
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
-    torch.set_default_device(device)
-    x = torch.randn(num_tokens, d, dtype=dtype)
-    layer = NewGELU()
-    out = layer(x)
-    ref_out = layer._forward(x)
-    assert torch.allclose(out,
-                          ref_out,
-                          atol=get_default_atol(out),
-                          rtol=get_default_rtol(out))
-
-
-@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
-@pytest.mark.parametrize("d", D)
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_gelu_fast(
+def test_activation(
+    activation: Type[torch.nn.Module],
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
@@ -84,7 +64,7 @@ def test_gelu_fast(
         torch.cuda.manual_seed(seed)
     torch.set_default_device(device)
     x = torch.randn(num_tokens, d, dtype=dtype)
-    layer = FastGELU()
+    layer = activation()
     out = layer(x)
     ref_out = layer._forward(x)
     assert torch.allclose(out,
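
A note for readers of this diff: SiluAndMul and GeluAndMul are fused "activation-and-multiply" layers, which is why test_act_and_mul builds its input with width 2 * d, while the plain activations (FastGELU, NewGELU) are covered by test_activation with width d. The sketch below only illustrates the reference semantics that layer._forward is assumed to provide; the class name RefSiluAndMul and its body are illustrative, not code from vLLM.

# Illustrative sketch (assumed semantics, not vLLM source): the reference
# path splits the last dimension in half and gates the second half with
# SiLU of the first half, so an input of shape (num_tokens, 2 * d)
# yields an output of shape (num_tokens, d).
import torch
import torch.nn.functional as F


class RefSiluAndMul(torch.nn.Module):  # hypothetical name

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # No custom CUDA kernel in this sketch; reuse the reference path.
        return self._forward(x)

With that shape convention, the test's torch.randn(num_tokens, 2 * d, dtype=dtype) input produces an output of shape (num_tokens, d), and comparing layer(x) against layer._forward(x) checks the fused kernel against the eager PyTorch reference.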