diff --git a/flax/testing/benchmark.py b/flax/testing/benchmark.py index ea55db31b..a45759877 100644 --- a/flax/testing/benchmark.py +++ b/flax/testing/benchmark.py @@ -42,13 +42,10 @@ from flax import io -flags.DEFINE_string( +_BENCHMARK_OUTPUT_DIR = flags.DEFINE_string( 'benchmark_output_dir', default=None, help='Benchmark output directory.' ) - -FLAGS = flags.FLAGS - _SCALAR_PLUGIN_NAME = ( summary_lib.scalar_pb('', 0).value[0].metadata.plugin_data.plugin_name ) @@ -125,9 +122,10 @@ def __init__(self, *args, **kwargs): patched_func = functools.partial(self._collect_assert_wrapper, fn=func) setattr(self, func_name, patched_func) + self._benchmark_output_dir = _BENCHMARK_OUTPUT_DIR.value # Create target directory if defined. - if FLAGS.benchmark_output_dir and not io.exists(FLAGS.benchmark_output_dir): - io.makedirs(FLAGS.benchmark_output_dir) + if self._benchmark_output_dir and not io.exists(self._benchmark_output_dir): + io.makedirs(self._benchmark_output_dir) # pylint: disable=invalid-name def _collect_assert_wrapper(self, *args, fn=None, **kwargs): @@ -160,8 +158,8 @@ def get_tmp_model_dir(self): if defined else uses a temporary directory. This helps to export summary files to tensorboard as multiple separate runs for each test method. """ - if FLAGS.benchmark_output_dir: - model_dir = FLAGS.benchmark_output_dir + if self._benchmark_output_dir: + model_dir = self._benchmark_output_dir else: model_dir = tempfile.mkdtemp() model_dir_path = os.path.join( @@ -281,8 +279,7 @@ def _report_benchmark_results(self): logging.info(results_str) # Maybe save results as a file for pickup by CI / monitoring frameworks. - benchmark_output_dir = FLAGS.benchmark_output_dir - if benchmark_output_dir: - filename = os.path.join(benchmark_output_dir, name + '.json') + if self._benchmark_output_dir: + filename = os.path.join(self._benchmark_output_dir, name + '.json') with io.GFile(filename, 'w') as fout: fout.write(results_str)