Skip to content

Commit abfc8b4

Browse files
committed
Benchmarking infrastructure
1 parent 0709da3 commit abfc8b4

File tree

2 files changed

+87
-20
lines changed

2 files changed

+87
-20
lines changed

src/main/java/org/apache/sysds/api/DMLScript.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import java.util.Date;
3636
import java.util.Map;
3737
import java.util.Scanner;
38+
import java.util.function.BiConsumer;
3839
import java.util.function.Function;
3940

4041
import org.apache.commons.cli.AlreadySelectedException;
@@ -422,6 +423,7 @@ public static void loadConfiguration(String fnameOptConfig) throws IOException {
422423
}
423424

424425
public static Function<DMLProgram, Boolean> hopInterceptor = null;
426+
public static BiConsumer<Long, Long> runtimeMetricsInterceptor = null;
425427
/**
426428
* The running body of DMLScript execution. This method should be called after execution properties have been correctly set,
427429
* and customized parameters have been put into _argVals
@@ -460,11 +462,10 @@ private static void execute(String dmlScriptStr, String fnameOptConfig, Map<Stri
460462

461463
//init working directories (before usage by following compilation steps)
462464
initHadoopExecution( ConfigurationManager.getDMLConfig() );
463-
465+
466+
long startMillis1 = System.currentTimeMillis();
464467
//Step 5: rewrite HOP DAGs (incl IPA and memory estimates)
465-
long startMillis = System.currentTimeMillis();
466468
dmlt.rewriteHopsDAG(prog);
467-
System.out.println("Rewrite procedure took: " + (System.currentTimeMillis() - startMillis) + "ms");
468469

469470
if (hopInterceptor != null && !hopInterceptor.apply(prog))
470471
return;
@@ -496,7 +497,13 @@ private static void execute(String dmlScriptStr, String fnameOptConfig, Map<Stri
496497
ExecutionContext ec = null;
497498
try {
498499
ec = ExecutionContextFactory.createContext(rtprog);
500+
long startMillis2 = System.currentTimeMillis();
499501
ScriptExecutorUtils.executeRuntimeProgram(rtprog, ec, ConfigurationManager.getDMLConfig(), STATISTICS ? STATISTICS_COUNT : 0, null);
502+
503+
if (runtimeMetricsInterceptor != null) {
504+
long endMillis = System.currentTimeMillis();
505+
runtimeMetricsInterceptor.accept(endMillis - startMillis1, endMillis - startMillis2);
506+
}
500507
}
501508
finally {
502509
//cleanup scratch_space and all working dirs

src/test/java/org/apache/sysds/test/AutomatedTestBase.java

Lines changed: 77 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
import java.net.InetSocketAddress;
3232
import java.net.ServerSocket;
3333
import java.nio.charset.Charset;
34+
import java.nio.file.Files;
35+
import java.nio.file.Paths;
3436
import java.util.ArrayList;
3537
import java.util.Arrays;
3638
import java.util.HashMap;
@@ -39,6 +41,7 @@
3941
import java.util.Set;
4042
import java.util.concurrent.TimeoutException;
4143

44+
import org.apache.commons.collections.ArrayStack;
4245
import org.apache.commons.io.FileUtils;
4346
import org.apache.commons.io.IOUtils;
4447
import org.apache.commons.lang3.ArrayUtils;
@@ -91,6 +94,7 @@
9194
import org.junit.After;
9295
import org.junit.Assert;
9396
import org.junit.Before;
97+
import scala.Tuple4;
9498

9599
/**
96100
* <p>
@@ -106,9 +110,49 @@
106110
*
107111
*/
108112
public abstract class AutomatedTestBase {
113+
protected static final boolean BENCHMARK = true;
114+
protected static final int BENCHMARK_WARMUP_RUNS = 1;
115+
protected static final int BENCHMARK_REPETITIONS = 1;
116+
protected static final boolean ALLOW_GENERATED_REWRITES = true;
117+
protected static final String BASE_DATA_DIR = "/Users/janniklindemann/Dev/MScThesis/NGramAnalysis/";
118+
private static String currentTestName = "";
119+
private static int currentTestRun = -1;
120+
private static boolean benchmark_run = false;
121+
109122

110123
static {
111124
RewriterRuntimeUtils.setupIfNecessary();
125+
126+
if (BENCHMARK) {
127+
final List<Tuple4<String, Integer, Long, Long>> runTimes = new ArrayList<>();
128+
129+
DMLScript.runtimeMetricsInterceptor = (runTime, executionTime) -> {
130+
if (benchmark_run)
131+
runTimes.add(new Tuple4<>(currentTestName, currentTestRun, runTime, executionTime));
132+
};
133+
134+
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
135+
StringBuilder csvBuilder = new StringBuilder();
136+
csvBuilder.append("TestName,TestRun,RunTimeMS,ExecTimeMS\n");
137+
138+
for (Tuple4<String, Integer, Long, Long> entry : runTimes) {
139+
csvBuilder.append(entry._1());
140+
csvBuilder.append(',');
141+
csvBuilder.append(entry._2());
142+
csvBuilder.append(',');
143+
csvBuilder.append(entry._3());
144+
csvBuilder.append(',');
145+
csvBuilder.append(entry._4());
146+
csvBuilder.append('\n');
147+
}
148+
149+
try {
150+
Files.writeString(Paths.get(BASE_DATA_DIR + "runtimes.csv"), csvBuilder.toString());
151+
} catch (IOException e) {
152+
e.printStackTrace();
153+
}
154+
}));
155+
}
112156
}
113157

114158
private static final Log LOG = LogFactory.getLog(AutomatedTestBase.class.getName());
@@ -196,7 +240,6 @@ public String getCodgenConfig() {
196240
protected static ExecMode rtplatform = ExecMode.HYBRID;
197241

198242
protected static final boolean DEBUG = false;
199-
protected static final boolean ALLOW_GENERATED_REWRITES = true;
200243

201244
public static boolean VERBOSE_STATS = false;
202245

@@ -1397,23 +1440,40 @@ protected ByteArrayOutputStream runTest(boolean newWay, boolean exceptionExpecte
13971440
String errMessage, int maxSparkInst) {
13981441
try{
13991442
final List<ByteArrayOutputStream> out = new ArrayList<>();
1400-
Thread t = new Thread(
1401-
() -> out.add(runTestWithTimeout(newWay, exceptionExpected, expectedException, errMessage, maxSparkInst)),
1402-
"TestRunner_main");
1403-
Thread.UncaughtExceptionHandler h = new Thread.UncaughtExceptionHandler() {
1404-
@Override
1405-
public void uncaughtException(Thread th, Throwable ex) {
1406-
fail("Thread Failed test with message: " +ex.getMessage());
1407-
}
1408-
};
1409-
t.setUncaughtExceptionHandler(h);
1410-
t.start();
14111443

1412-
t.join(TEST_TIMEOUT * 1000);
1413-
if(t.isAlive())
1414-
throw new TimeoutException("Test failed to finish in time");
1415-
if(out.size() <= 0) // hack in case the test failed return empty string.
1416-
fail("test failed");
1444+
if (currentTestName == null || !currentTestName.equals(this.getClass().getSimpleName())) {
1445+
currentTestRun = 1;
1446+
currentTestName = this.getClass().getSimpleName();
1447+
} else {
1448+
currentTestRun++;
1449+
}
1450+
1451+
int totalReps = BENCHMARK_WARMUP_RUNS + BENCHMARK_REPETITIONS;
1452+
1453+
for (int i = 0; i < totalReps; i++) {
1454+
out.clear();
1455+
Statistics.reset();
1456+
1457+
benchmark_run = BENCHMARK && i >= BENCHMARK_WARMUP_RUNS;
1458+
1459+
Thread t = new Thread(
1460+
() -> out.add(runTestWithTimeout(newWay, exceptionExpected, expectedException, errMessage, maxSparkInst)),
1461+
"TestRunner_main");
1462+
Thread.UncaughtExceptionHandler h = new Thread.UncaughtExceptionHandler() {
1463+
@Override
1464+
public void uncaughtException(Thread th, Throwable ex) {
1465+
fail("Thread Failed test with message: " + ex.getMessage());
1466+
}
1467+
};
1468+
t.setUncaughtExceptionHandler(h);
1469+
t.start();
1470+
1471+
t.join(TEST_TIMEOUT * 1000);
1472+
if (t.isAlive())
1473+
throw new TimeoutException("Test failed to finish in time");
1474+
if (out.size() <= 0) // hack in case the test failed return empty string.
1475+
fail("test failed");
1476+
}
14171477

14181478
return out.get(0);
14191479
}

0 commit comments

Comments
 (0)