Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 8 additions & 11 deletions flax/testing/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,10 @@

from flax import io

flags.DEFINE_string(
_BENCHMARK_OUTPUT_DIR = flags.DEFINE_string(
'benchmark_output_dir', default=None, help='Benchmark output directory.'
)


FLAGS = flags.FLAGS

_SCALAR_PLUGIN_NAME = (
summary_lib.scalar_pb('', 0).value[0].metadata.plugin_data.plugin_name
)
Expand Down Expand Up @@ -125,9 +122,10 @@ def __init__(self, *args, **kwargs):
patched_func = functools.partial(self._collect_assert_wrapper, fn=func)
setattr(self, func_name, patched_func)

self._benchmark_output_dir = _BENCHMARK_OUTPUT_DIR.value
# Create target directory if defined.
if FLAGS.benchmark_output_dir and not io.exists(FLAGS.benchmark_output_dir):
io.makedirs(FLAGS.benchmark_output_dir)
if self._benchmark_output_dir and not io.exists(self._benchmark_output_dir):
io.makedirs(self._benchmark_output_dir)

# pylint: disable=invalid-name
def _collect_assert_wrapper(self, *args, fn=None, **kwargs):
Expand Down Expand Up @@ -160,8 +158,8 @@ def get_tmp_model_dir(self):
if defined else uses a temporary directory. This helps to export summary
files to tensorboard as multiple separate runs for each test method.
"""
if FLAGS.benchmark_output_dir:
model_dir = FLAGS.benchmark_output_dir
if self._benchmark_output_dir:
model_dir = self._benchmark_output_dir
else:
model_dir = tempfile.mkdtemp()
model_dir_path = os.path.join(
Expand Down Expand Up @@ -281,8 +279,7 @@ def _report_benchmark_results(self):
logging.info(results_str)

# Maybe save results as a file for pickup by CI / monitoring frameworks.
benchmark_output_dir = FLAGS.benchmark_output_dir
if benchmark_output_dir:
filename = os.path.join(benchmark_output_dir, name + '.json')
if self._benchmark_output_dir:
filename = os.path.join(self._benchmark_output_dir, name + '.json')
with io.GFile(filename, 'w') as fout:
fout.write(results_str)
Loading