@@ -517,7 +517,7 @@ def FeatureTransformerSliceFunctionEmulate(
517517
518518 def test ():
519519 BATCH_SIZE = 16
520- INPUT_SIZE = 768
520+ INPUT_SIZE = 10
521521 MAX_ACTIVE_FEATURES = 32
522522 STRIDE = 128
523523 MAX_ERROR = 1e-4
@@ -574,7 +574,6 @@ def bench():
574574 MAX_ACTIVE_FEATURES = 64
575575
576576 layer = DoubleFeatureTransformerSlice (INPUT_SIZE , STRIDE ).cuda ()
577- layer = torch .compile (layer , fullgraph = True , mode = "max-autotune" )
578577 indices0 = torch .cat (
579578 [
580579 torch .sort (
@@ -606,38 +605,12 @@ def bench():
606605 BATCH_SIZE , MAX_ACTIVE_FEATURES , dtype = torch .float32
607606 ).cuda ()
608607
609- # Warmup
610- output0 , output1 = layer (indices0 , values0 , indices1 , values1 )
611- output0 = torch .clamp (output0 , 0.0 , 1.0 )
612- output0 = output0 .clone ()
613- output1 = output1 .clone ()
614- output1 = torch .clamp (output1 , 0.0 , 1.0 )
615- g = ((output0 - output1 ) ** 2 ).mean ()
616- g .backward ()
617- torch .cuda .synchronize ()
618-
619- for _ in range (ITERS ):
620- torch .compiler .cudagraph_mark_step_begin ()
621- output0 , output1 = layer (indices0 , values0 , indices1 , values1 )
622- output0 = output0 .clone ()
623- output1 = output1 .clone ()
624- output0 = torch .clamp (output0 , 0.0 , 1.0 )
625- output1 = torch .clamp (output1 , 0.0 , 1.0 )
626-
627- g = ((output0 - output1 ) ** 2 ).mean ()
628- g .backward ()
629-
630- torch .cuda .synchronize ()
631-
632608 output0 , output1 = layer (indices0 , values0 , indices1 , values1 )
633609
634610 start = time .time ()
635611
636612 for _ in range (ITERS ):
637- torch .compiler .cudagraph_mark_step_begin ()
638613 output0 , output1 = layer (indices0 , values0 , indices1 , values1 )
639- output0 = output0 .clone ()
640- output1 = output1 .clone ()
641614 output0 = torch .clamp (output0 , 0.0 , 1.0 )
642615 output1 = torch .clamp (output1 , 0.0 , 1.0 )
643616
0 commit comments