|
4 | 4 | # This source code is licensed under the BSD-style license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
| 7 | +import os |
7 | 8 | from copy import deepcopy |
8 | 9 | from pathlib import Path |
9 | 10 |
|
@@ -278,88 +279,86 @@ def test_output_dir_ckpt_dir_few_levels_down(self): |
278 | 279 | class TestGetAllCheckpointsInDir: |
279 | 280 | """Series of tests for the ``get_all_checkpoints_in_dir`` function.""" |
280 | 281 |
|
281 | | - def test_get_all_ckpts_simple(self, tmp_dir): |
282 | | - ckpt_dir_0 = tmp_dir / "epoch_0" |
| 282 | + def test_get_all_ckpts_simple(self, tmpdir): |
| 283 | + tmpdir = Path(tmpdir) |
| 284 | + ckpt_dir_0 = tmpdir / "epoch_0" |
283 | 285 | ckpt_dir_0.mkdir(parents=True, exist_ok=True) |
284 | 286 |
|
285 | | - ckpt_dir_1 = tmp_dir / "epoch_1" |
286 | | - ckpt_dir_1.mkdir(parents=True, exist_ok=True) |
| 287 | + ckpt_dir_1 = tmpdir / "epoch_1" |
| 288 | + ckpt_dir_1.mkdir() |
287 | 289 |
|
288 | | - all_ckpts = get_all_checkpoints_in_dir(tmp_dir) |
| 290 | + all_ckpts = get_all_checkpoints_in_dir(tmpdir) |
289 | 291 | assert len(all_ckpts) == 2 |
290 | | - assert all_ckpts == [ckpt_dir_0, ckpt_dir_1] |
| 292 | + assert ckpt_dir_0 in all_ckpts |
| 293 | + assert ckpt_dir_1 in all_ckpts |
291 | 294 |
|
292 | | - def test_get_all_ckpts_with_pattern_that_matches_some(self, tmp_dir): |
| 295 | + def test_get_all_ckpts_with_pattern_that_matches_some(self, tmpdir): |
293 | 296 | """Test that we only return the checkpoints that match the pattern.""" |
294 | | - ckpt_dir_0 = tmp_dir / "epoch_0" |
| 297 | + tmpdir = Path(tmpdir) |
| 298 | + ckpt_dir_0 = tmpdir / "epoch_0" |
295 | 299 | ckpt_dir_0.mkdir(parents=True, exist_ok=True) |
296 | 300 |
|
297 | | - ckpt_dir_1 = tmp_dir / "step_1" |
298 | | - ckpt_dir_1.mkdir(parents=True, exist_ok=True) |
| 301 | + ckpt_dir_1 = tmpdir / "step_1" |
| 302 | + ckpt_dir_1.mkdir() |
299 | 303 |
|
300 | | - all_ckpts = get_all_checkpoints_in_dir(tmp_dir) |
| 304 | + all_ckpts = get_all_checkpoints_in_dir(tmpdir) |
301 | 305 | assert len(all_ckpts) == 1 |
302 | 306 | assert all_ckpts == [ckpt_dir_0] |
303 | 307 |
|
304 | | - def test_get_all_ckpts_override_pattern(self, tmp_dir): |
| 308 | + def test_get_all_ckpts_override_pattern(self, tmpdir): |
305 | 309 | """Test that we can override the default pattern and it works.""" |
306 | | - ckpt_dir_0 = tmp_dir / "epoch_0" |
| 310 | + tmpdir = Path(tmpdir) |
| 311 | + ckpt_dir_0 = tmpdir / "epoch_0" |
307 | 312 | ckpt_dir_0.mkdir(parents=True, exist_ok=True) |
308 | 313 |
|
309 | | - ckpt_dir_1 = tmp_dir / "step_1" |
310 | | - ckpt_dir_1.mkdir(parents=True, exist_ok=True) |
| 314 | + ckpt_dir_1 = tmpdir / "step_1" |
| 315 | + ckpt_dir_1.mkdir() |
311 | 316 |
|
312 | | - all_ckpts = get_all_checkpoints_in_dir(tmp_dir, pattern="step_*") |
| 317 | + all_ckpts = get_all_checkpoints_in_dir(tmpdir, pattern="step_*") |
313 | 318 | assert len(all_ckpts) == 1 |
314 | 319 | assert all_ckpts == [ckpt_dir_1] |
315 | 320 |
|
316 | | - def test_get_all_ckpts_only_return_dirs(self, tmp_dir): |
| 321 | + def test_get_all_ckpts_only_return_dirs(self, tmpdir): |
317 | 322 | """Test that even if a file matches the pattern, we only return directories.""" |
318 | | - ckpt_dir_0 = tmp_dir / "epoch_0" |
| 323 | + tmpdir = Path(tmpdir) |
| 324 | + ckpt_dir_0 = tmpdir / "epoch_0" |
319 | 325 | ckpt_dir_0.mkdir(parents=True, exist_ok=True) |
320 | 326 |
|
321 | | - file = tmp_dir / "epoch_1" |
322 | | - ckpt_dir_1.touch() |
| 327 | + file = tmpdir / "epoch_1" |
| 328 | + file.touch() |
323 | 329 |
|
324 | | - all_ckpts = get_all_checkpoints_in_dir(tmp_dir) |
| 330 | + all_ckpts = get_all_checkpoints_in_dir(tmpdir) |
325 | 331 | assert len(all_ckpts) == 1 |
326 | 332 | assert all_ckpts == [ckpt_dir_0] |
327 | 333 |
|
328 | | - def test_get_all_ckpts_non_unique(self, tmp_dir): |
329 | | - """Test that we return all checkpoints, even if they have the same name.""" |
330 | | - ckpt_dir_0 = tmp_dir / "epoch_0" |
331 | | - ckpt_dir_0.mkdir(parents=True, exist_ok=True) |
332 | | - |
333 | | - ckpt_dir_1 = tmp_dir / "epoch_0" |
334 | | - ckpt_dir_1.mkdir(parents=True, exist_ok=True) |
335 | | - |
336 | | - all_ckpts = get_all_checkpoints_in_dir(tmp_dir) |
337 | | - assert len(all_ckpts) == 2 |
338 | | - assert all_ckpts == [ckpt_dir_0, ckpt_dir_1] |
339 | | - |
340 | 334 |
|
341 | 335 | class TestPruneSurplusCheckpoints: |
342 | 336 | """Series of tests for the ``prune_surplus_checkpoints`` function.""" |
343 | 337 |
|
344 | | - def test_prune_surplus_checkpoints_simple(self, tmp_dir): |
345 | | - ckpt_dir_0 = tmp_dir / "epoch_0" |
| 338 | + def test_prune_surplus_checkpoints_simple(self, tmpdir): |
| 339 | + tmpdir = Path(tmpdir) |
| 340 | + ckpt_dir_0 = tmpdir / "epoch_0" |
346 | 341 | ckpt_dir_0.mkdir(parents=True, exist_ok=True) |
347 | 342 |
|
348 | | - ckpt_dir_1 = tmp_dir / "epoch_1" |
349 | | - ckpt_dir_1.mkdir(parents=True, exist_ok=True) |
| 343 | + ckpt_dir_1 = tmpdir / "epoch_1" |
| 344 | + ckpt_dir_1.mkdir() |
350 | 345 |
|
351 | | - prune_surplus_checkpoints(tmp_dir, 1) |
352 | | - remaining_ckpts = os.listdir(tmp_dir) |
| 346 | + prune_surplus_checkpoints([ckpt_dir_0, ckpt_dir_1], 1) |
| 347 | + remaining_ckpts = os.listdir(tmpdir) |
353 | 348 | assert len(remaining_ckpts) == 1 |
354 | 349 | assert remaining_ckpts == ["epoch_1"] |
355 | 350 |
|
356 | | - def test_prune_surplus_checkpoints_keep_last_invalid(self, tmp_dir): |
| 351 | + def test_prune_surplus_checkpoints_keep_last_invalid(self, tmpdir): |
357 | 352 | """Test that we raise an error if keep_last_n_checkpoints is not >= 1""" |
358 | | - ckpt_dir_0 = tmp_dir / "epoch_0" |
| 353 | + tmpdir = Path(tmpdir) |
| 354 | + ckpt_dir_0 = tmpdir / "epoch_0" |
359 | 355 | ckpt_dir_0.mkdir(parents=True, exist_ok=True) |
360 | 356 |
|
361 | | - ckpt_dir_1 = tmp_dir / "epoch_1" |
362 | | - ckpt_dir_1.mkdir(parents=True, exist_ok=True) |
| 357 | + ckpt_dir_1 = tmpdir / "epoch_1" |
| 358 | + ckpt_dir_1.mkdir() |
363 | 359 |
|
364 | | - with pytest.raises(ValueError, match="keep_last_n_checkpoints must be >= 1"): |
365 | | - prune_surplus_checkpoints(tmp_dir, 0) |
| 360 | + with pytest.raises( |
| 361 | + ValueError, |
| 362 | + match="keep_last_n_checkpoints must be greater than or equal to 1", |
| 363 | + ): |
| 364 | + prune_surplus_checkpoints([ckpt_dir_0, ckpt_dir_1], 0) |
0 commit comments