Skip to content

Commit 539af47

Browse files
committed
revert test changes
1 parent ae4fc4a commit 539af47

1 file changed

Lines changed: 1 addition & 28 deletions

File tree

feature_transformer.py

Lines changed: 1 addition & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -517,7 +517,7 @@ def FeatureTransformerSliceFunctionEmulate(
517517

518518
def test():
519519
BATCH_SIZE = 16
520-
INPUT_SIZE = 768
520+
INPUT_SIZE = 10
521521
MAX_ACTIVE_FEATURES = 32
522522
STRIDE = 128
523523
MAX_ERROR = 1e-4
@@ -574,7 +574,6 @@ def bench():
574574
MAX_ACTIVE_FEATURES = 64
575575

576576
layer = DoubleFeatureTransformerSlice(INPUT_SIZE, STRIDE).cuda()
577-
layer = torch.compile(layer, fullgraph=True, mode="max-autotune")
578577
indices0 = torch.cat(
579578
[
580579
torch.sort(
@@ -606,38 +605,12 @@ def bench():
606605
BATCH_SIZE, MAX_ACTIVE_FEATURES, dtype=torch.float32
607606
).cuda()
608607

609-
# Warmup
610-
output0, output1 = layer(indices0, values0, indices1, values1)
611-
output0 = torch.clamp(output0, 0.0, 1.0)
612-
output0 = output0.clone()
613-
output1 = output1.clone()
614-
output1 = torch.clamp(output1, 0.0, 1.0)
615-
g = ((output0 - output1) ** 2).mean()
616-
g.backward()
617-
torch.cuda.synchronize()
618-
619-
for _ in range(ITERS):
620-
torch.compiler.cudagraph_mark_step_begin()
621-
output0, output1 = layer(indices0, values0, indices1, values1)
622-
output0 = output0.clone()
623-
output1 = output1.clone()
624-
output0 = torch.clamp(output0, 0.0, 1.0)
625-
output1 = torch.clamp(output1, 0.0, 1.0)
626-
627-
g = ((output0 - output1) ** 2).mean()
628-
g.backward()
629-
630-
torch.cuda.synchronize()
631-
632608
output0, output1 = layer(indices0, values0, indices1, values1)
633609

634610
start = time.time()
635611

636612
for _ in range(ITERS):
637-
torch.compiler.cudagraph_mark_step_begin()
638613
output0, output1 = layer(indices0, values0, indices1, values1)
639-
output0 = output0.clone()
640-
output1 = output1.clone()
641614
output0 = torch.clamp(output0, 0.0, 1.0)
642615
output1 = torch.clamp(output1, 0.0, 1.0)
643616

0 commit comments

Comments
 (0)