Skip to content

Commit cbb4f80

Browse files
khoFlax Authors
authored andcommitted
Internal change
PiperOrigin-RevId: 780870062
1 parent 171b28e commit cbb4f80

File tree

1 file changed

+8
-11
lines changed

1 file changed

+8
-11
lines changed

flax/testing/benchmark.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,10 @@
4242

4343
from flax import io
4444

45-
flags.DEFINE_string(
45+
_BENCHMARK_OUTPUT_DIR = flags.DEFINE_string(
4646
'benchmark_output_dir', default=None, help='Benchmark output directory.'
4747
)
4848

49-
50-
FLAGS = flags.FLAGS
51-
5249
_SCALAR_PLUGIN_NAME = (
5350
summary_lib.scalar_pb('', 0).value[0].metadata.plugin_data.plugin_name
5451
)
@@ -125,9 +122,10 @@ def __init__(self, *args, **kwargs):
125122
patched_func = functools.partial(self._collect_assert_wrapper, fn=func)
126123
setattr(self, func_name, patched_func)
127124

125+
self._benchmark_output_dir = _BENCHMARK_OUTPUT_DIR.value
128126
# Create target directory if defined.
129-
if FLAGS.benchmark_output_dir and not io.exists(FLAGS.benchmark_output_dir):
130-
io.makedirs(FLAGS.benchmark_output_dir)
127+
if self._benchmark_output_dir and not io.exists(self._benchmark_output_dir):
128+
io.makedirs(self._benchmark_output_dir)
131129

132130
# pylint: disable=invalid-name
133131
def _collect_assert_wrapper(self, *args, fn=None, **kwargs):
@@ -160,8 +158,8 @@ def get_tmp_model_dir(self):
160158
if defined else uses a temporary directory. This helps to export summary
161159
files to tensorboard as multiple separate runs for each test method.
162160
"""
163-
if FLAGS.benchmark_output_dir:
164-
model_dir = FLAGS.benchmark_output_dir
161+
if self._benchmark_output_dir:
162+
model_dir = self._benchmark_output_dir
165163
else:
166164
model_dir = tempfile.mkdtemp()
167165
model_dir_path = os.path.join(
@@ -281,8 +279,7 @@ def _report_benchmark_results(self):
281279
logging.info(results_str)
282280

283281
# Maybe save results as a file for pickup by CI / monitoring frameworks.
284-
benchmark_output_dir = FLAGS.benchmark_output_dir
285-
if benchmark_output_dir:
286-
filename = os.path.join(benchmark_output_dir, name + '.json')
282+
if self._benchmark_output_dir:
283+
filename = os.path.join(self._benchmark_output_dir, name + '.json')
287284
with io.GFile(filename, 'w') as fout:
288285
fout.write(results_str)

0 commit comments

Comments
 (0)