|
42 | 42 |
|
43 | 43 | from flax import io |
44 | 44 |
|
45 | | -flags.DEFINE_string( |
| 45 | +_BENCHMARK_OUTPUT_DIR = flags.DEFINE_string( |
46 | 46 | 'benchmark_output_dir', default=None, help='Benchmark output directory.' |
47 | 47 | ) |
48 | 48 |
|
49 | | - |
50 | | -FLAGS = flags.FLAGS |
51 | | - |
52 | 49 | _SCALAR_PLUGIN_NAME = ( |
53 | 50 | summary_lib.scalar_pb('', 0).value[0].metadata.plugin_data.plugin_name |
54 | 51 | ) |
@@ -125,9 +122,10 @@ def __init__(self, *args, **kwargs): |
125 | 122 | patched_func = functools.partial(self._collect_assert_wrapper, fn=func) |
126 | 123 | setattr(self, func_name, patched_func) |
127 | 124 |
|
| 125 | + self._benchmark_output_dir = _BENCHMARK_OUTPUT_DIR.value |
128 | 126 | # Create target directory if defined. |
129 | | - if FLAGS.benchmark_output_dir and not io.exists(FLAGS.benchmark_output_dir): |
130 | | - io.makedirs(FLAGS.benchmark_output_dir) |
| 127 | + if self._benchmark_output_dir and not io.exists(self._benchmark_output_dir): |
| 128 | + io.makedirs(self._benchmark_output_dir) |
131 | 129 |
|
132 | 130 | # pylint: disable=invalid-name |
133 | 131 | def _collect_assert_wrapper(self, *args, fn=None, **kwargs): |
@@ -160,8 +158,8 @@ def get_tmp_model_dir(self): |
160 | 158 | if defined else uses a temporary directory. This helps to export summary |
161 | 159 | files to tensorboard as multiple separate runs for each test method. |
162 | 160 | """ |
163 | | - if FLAGS.benchmark_output_dir: |
164 | | - model_dir = FLAGS.benchmark_output_dir |
| 161 | + if self._benchmark_output_dir: |
| 162 | + model_dir = self._benchmark_output_dir |
165 | 163 | else: |
166 | 164 | model_dir = tempfile.mkdtemp() |
167 | 165 | model_dir_path = os.path.join( |
@@ -281,8 +279,7 @@ def _report_benchmark_results(self): |
281 | 279 | logging.info(results_str) |
282 | 280 |
|
283 | 281 | # Maybe save results as a file for pickup by CI / monitoring frameworks. |
284 | | - benchmark_output_dir = FLAGS.benchmark_output_dir |
285 | | - if benchmark_output_dir: |
286 | | - filename = os.path.join(benchmark_output_dir, name + '.json') |
| 282 | + if self._benchmark_output_dir: |
| 283 | + filename = os.path.join(self._benchmark_output_dir, name + '.json') |
287 | 284 | with io.GFile(filename, 'w') as fout: |
288 | 285 | fout.write(results_str) |
0 commit comments