1+ import json
2+ import os
13import unittest
4+ from datetime import datetime
25from types import SimpleNamespace
36
47from sglang .srt .utils import kill_child_process
1417 popen_launch_server ,
1518)
1619
# Minimum acceptable evaluation score per model, keyed by the model identifier
# string. check_model_scores() looks each evaluated model up here and reports
# a failure when the measured score is below the listed value.
# NOTE(review): presumably these are mgsm_en scores - confirm against the eval.
MODEL_SCORE_THRESHOLDS = {
    # Checkpoints whose names carry no quantization suffix.
    "meta-llama/Llama-3.1-8B-Instruct": 0.8316,
    "mistralai/Mistral-7B-Instruct-v0.3": 0.5861,
    "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": 0.8672,
    "google/gemma-2-27b-it": 0.9227,
    "meta-llama/Llama-3.1-70B-Instruct": 0.9623,
    "mistralai/Mixtral-8x7B-Instruct-v0.1": 0.6415,
    "Qwen/Qwen2-57B-A14B-Instruct": 0.8791,
    # Checkpoints with an "-FP8" suffix.
    "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8": 0.8672,
    "neuralmagic/Mistral-7B-Instruct-v0.3-FP8": 0.5544,
    "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8": 0.8356,
    "neuralmagic/gemma-2-2b-it-FP8": 0.6059,
    "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8": 0.9504,
    "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8": 0.6138,
    "neuralmagic/Qwen2-72B-Instruct-FP8": 0.9504,
    "neuralmagic/Qwen2-57B-A14B-Instruct-FP8": 0.8197,
    # Checkpoints with an "-INT4" suffix (AWQ / GPTQ).
    "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4": 0.8395,
    "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4": 0.8435,
}
39+
1740
def parse_models(model_string):
    """Split a comma-separated string into a list of model names.

    Each piece is stripped of surrounding whitespace; pieces that are empty
    after stripping are dropped.
    """
    names = []
    for piece in model_string.split(","):
        name = piece.strip()
        if name:
            names.append(name)
    return names
@@ -23,10 +46,8 @@ def launch_server(base_url, model, is_fp8, is_tp2):
2346 other_args = ["--log-level-http" , "warning" , "--trust-remote-code" ]
2447 if is_fp8 :
2548 if "Llama-3" in model or "gemma-2" in model :
26- # compressed-tensors
2749 other_args .extend (["--kv-cache-dtype" , "fp8_e5m2" ])
2850 elif "Qwen2-72B-Instruct-FP8" in model :
29- # bug
3051 other_args .extend (["--quantization" , "fp8" ])
3152 else :
3253 other_args .extend (["--quantization" , "fp8" , "--kv-cache-dtype" , "fp8_e5m2" ])
@@ -48,6 +69,49 @@ def launch_server(base_url, model, is_fp8, is_tp2):
4869 return process
4970
5071
def write_results_to_json(model, metrics, mode="a", path="results.json"):
    """Record one model's evaluation result in a JSON file.

    The file holds a JSON list of records. Each record contains the current
    local time (``datetime.now().isoformat()``), the model name, the full
    ``metrics`` mapping and, duplicated at the top level, ``metrics["score"]``.

    Args:
        model: Model identifier stored under the ``"model"`` key.
        metrics: Mapping of metric names to values; must contain ``"score"``.
        mode: ``"a"`` appends to the records already in the file; any other
            value (callers pass ``"w"``) starts a fresh list.
        path: Destination file. Defaults to ``"results.json"`` in the current
            working directory.

    Raises:
        KeyError: If ``metrics`` has no ``"score"`` entry.
    """
    result = {
        "timestamp": datetime.now().isoformat(),
        "model": model,
        "metrics": metrics,
        "score": metrics["score"],
    }

    existing_results = []
    if mode == "a":
        try:
            with open(path, "r", encoding="utf-8") as f:
                existing_results = json.load(f)
        except FileNotFoundError:
            # Nothing to append to yet; start a new list.
            pass
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Best effort: an unreadable file is replaced rather than
            # aborting the whole evaluation run.
            existing_results = []

    # A file holding valid JSON that is not a list is replaced as well.
    if isinstance(existing_results, list):
        existing_results.append(result)
    else:
        existing_results = [result]

    with open(path, "w", encoding="utf-8") as f:
        json.dump(existing_results, f, indent=2)
95+
96+
def check_model_scores(results, thresholds=None):
    """Fail if any evaluated model scored below its minimum threshold.

    Every result is checked before anything is raised, so a single error
    lists all under-performing models at once.

    Args:
        results: Iterable of ``(model, score)`` pairs.
        thresholds: Mapping from model name to minimum acceptable score.
            Defaults to the module-level ``MODEL_SCORE_THRESHOLDS``.

    Raises:
        AssertionError: If at least one score is below its threshold; the
            message contains one entry per failing model.
    """
    if thresholds is None:
        thresholds = MODEL_SCORE_THRESHOLDS

    failed_models = []
    for model, score in results:
        threshold = thresholds.get(model)
        if threshold is None:
            # Models without a threshold are reported but never fail the run.
            print(f"Warning: No threshold defined for model {model} ")
            continue

        if score < threshold:
            failed_models.append(
                f"\n Score Check Failed: {model} \n "
                f"Model {model} score ({score:.4f} ) is below threshold ({threshold:.4f} )"
            )

    if failed_models:
        raise AssertionError("\n ".join(failed_models))
113+
114+
51115class TestEvalAccuracyLarge (unittest .TestCase ):
52116 @classmethod
53117 def setUpClass (cls ):
@@ -68,6 +132,9 @@ def tearDown(self):
68132 kill_child_process (self .process .pid , include_self = True )
69133
70134 def test_mgsm_en_all_models (self ):
135+ is_first = True
136+ all_results = []
137+
71138 for model_group , is_fp8 , is_tp2 in self .model_groups :
72139 for model in model_group :
73140 with self .subTest (model = model ):
@@ -85,11 +152,24 @@ def test_mgsm_en_all_models(self):
85152 print (
86153 f"{ '=' * 42 } \n { model } - metrics={ metrics } score={ metrics ['score' ]} \n { '=' * 42 } \n "
87154 )
88- # loosely threshold
89- assert metrics ["score" ] > 0.5 , f"score={ metrics ['score' ]} <= 0.5"
155+
156+ write_results_to_json (model , metrics , "w" if is_first else "a" )
157+ is_first = False
158+
159+ all_results .append ((model , metrics ["score" ]))
90160
91161 self .tearDown ()
92162
163+ try :
164+ with open ("results.json" , "r" ) as f :
165+ print ("\n Final Results from results.json:" )
166+ print (json .dumps (json .load (f ), indent = 2 ))
167+ except Exception as e :
168+ print (f"Error reading results.json: { e } " )
169+
170+ # Check all scores after collecting all results
171+ check_model_scores (all_results )
172+
93173
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
0 commit comments