@@ -1551,7 +1551,7 @@ def test_fp16_full_eval(self):
         a = torch.ones(1000, bs) + 0.001
         b = torch.ones(1000, bs) - 0.001

-        # 1. with mem metrics enabled
+        # 1. with fp16_full_eval disabled
         trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, skip_memory_metrics=False)
         metrics = trainer.evaluate()
         del trainer
@@ -1572,7 +1572,7 @@ def test_fp16_full_eval(self):
         # perfect world: fp32_eval == close to zero
         self.assertLess(fp32_eval, 5_000)

-        # 2. with mem metrics disabled
+        # 2. with fp16_full_eval enabled
         trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, fp16_full_eval=True, skip_memory_metrics=False)
         metrics = trainer.evaluate()
         fp16_init = metrics["init_mem_gpu_alloc_delta"]
@@ -1611,7 +1611,7 @@ def test_bf16_full_eval(self):
         a = torch.ones(1000, bs) + 0.001
         b = torch.ones(1000, bs) - 0.001

-        # 1. with mem metrics enabled
+        # 1. with bf16_full_eval disabled
         trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, skip_memory_metrics=False)
         metrics = trainer.evaluate()
         del trainer
@@ -1632,7 +1632,7 @@ def test_bf16_full_eval(self):
         # perfect world: fp32_eval == close to zero
         self.assertLess(fp32_eval, 5_000)

-        # 2. with mem metrics disabled
+        # 2. with bf16_full_eval enabled
         trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, bf16_full_eval=True, skip_memory_metrics=False)
         metrics = trainer.evaluate()
         bf16_init = metrics["init_mem_gpu_alloc_delta"]
0 commit comments