Skip to content

Commit 46f7c67

Browse files
committed
Fix helper function tests
1 parent b0648fd commit 46f7c67

1 file changed

Lines changed: 44 additions & 45 deletions

File tree

tests/torchtune/training/checkpointing/test_checkpointer_utils.py

Lines changed: 44 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import os
78
from copy import deepcopy
89
from pathlib import Path
910

@@ -278,88 +279,86 @@ def test_output_dir_ckpt_dir_few_levels_down(self):
278279
class TestGetAllCheckpointsInDir:
279280
"""Series of tests for the ``get_all_checkpoints_in_dir`` function."""
280281

281-
def test_get_all_ckpts_simple(self, tmp_dir):
282-
ckpt_dir_0 = tmp_dir / "epoch_0"
282+
def test_get_all_ckpts_simple(self, tmpdir):
283+
tmpdir = Path(tmpdir)
284+
ckpt_dir_0 = tmpdir / "epoch_0"
283285
ckpt_dir_0.mkdir(parents=True, exist_ok=True)
284286

285-
ckpt_dir_1 = tmp_dir / "epoch_1"
286-
ckpt_dir_1.mkdir(parents=True, exist_ok=True)
287+
ckpt_dir_1 = tmpdir / "epoch_1"
288+
ckpt_dir_1.mkdir()
287289

288-
all_ckpts = get_all_checkpoints_in_dir(tmp_dir)
290+
all_ckpts = get_all_checkpoints_in_dir(tmpdir)
289291
assert len(all_ckpts) == 2
290-
assert all_ckpts == [ckpt_dir_0, ckpt_dir_1]
292+
assert ckpt_dir_0 in all_ckpts
293+
assert ckpt_dir_1 in all_ckpts
291294

292-
def test_get_all_ckpts_with_pattern_that_matches_some(self, tmp_dir):
295+
def test_get_all_ckpts_with_pattern_that_matches_some(self, tmpdir):
293296
"""Test that we only return the checkpoints that match the pattern."""
294-
ckpt_dir_0 = tmp_dir / "epoch_0"
297+
tmpdir = Path(tmpdir)
298+
ckpt_dir_0 = tmpdir / "epoch_0"
295299
ckpt_dir_0.mkdir(parents=True, exist_ok=True)
296300

297-
ckpt_dir_1 = tmp_dir / "step_1"
298-
ckpt_dir_1.mkdir(parents=True, exist_ok=True)
301+
ckpt_dir_1 = tmpdir / "step_1"
302+
ckpt_dir_1.mkdir()
299303

300-
all_ckpts = get_all_checkpoints_in_dir(tmp_dir)
304+
all_ckpts = get_all_checkpoints_in_dir(tmpdir)
301305
assert len(all_ckpts) == 1
302306
assert all_ckpts == [ckpt_dir_0]
303307

304-
def test_get_all_ckpts_override_pattern(self, tmp_dir):
308+
def test_get_all_ckpts_override_pattern(self, tmpdir):
305309
"""Test that we can override the default pattern and it works."""
306-
ckpt_dir_0 = tmp_dir / "epoch_0"
310+
tmpdir = Path(tmpdir)
311+
ckpt_dir_0 = tmpdir / "epoch_0"
307312
ckpt_dir_0.mkdir(parents=True, exist_ok=True)
308313

309-
ckpt_dir_1 = tmp_dir / "step_1"
310-
ckpt_dir_1.mkdir(parents=True, exist_ok=True)
314+
ckpt_dir_1 = tmpdir / "step_1"
315+
ckpt_dir_1.mkdir()
311316

312-
all_ckpts = get_all_checkpoints_in_dir(tmp_dir, pattern="step_*")
317+
all_ckpts = get_all_checkpoints_in_dir(tmpdir, pattern="step_*")
313318
assert len(all_ckpts) == 1
314319
assert all_ckpts == [ckpt_dir_1]
315320

316-
def test_get_all_ckpts_only_return_dirs(self, tmp_dir):
321+
def test_get_all_ckpts_only_return_dirs(self, tmpdir):
317322
"""Test that even if a file matches the pattern, we only return directories."""
318-
ckpt_dir_0 = tmp_dir / "epoch_0"
323+
tmpdir = Path(tmpdir)
324+
ckpt_dir_0 = tmpdir / "epoch_0"
319325
ckpt_dir_0.mkdir(parents=True, exist_ok=True)
320326

321-
file = tmp_dir / "epoch_1"
322-
ckpt_dir_1.touch()
327+
file = tmpdir / "epoch_1"
328+
file.touch()
323329

324-
all_ckpts = get_all_checkpoints_in_dir(tmp_dir)
330+
all_ckpts = get_all_checkpoints_in_dir(tmpdir)
325331
assert len(all_ckpts) == 1
326332
assert all_ckpts == [ckpt_dir_0]
327333

328-
def test_get_all_ckpts_non_unique(self, tmp_dir):
329-
"""Test that we return all checkpoints, even if they have the same name."""
330-
ckpt_dir_0 = tmp_dir / "epoch_0"
331-
ckpt_dir_0.mkdir(parents=True, exist_ok=True)
332-
333-
ckpt_dir_1 = tmp_dir / "epoch_0"
334-
ckpt_dir_1.mkdir(parents=True, exist_ok=True)
335-
336-
all_ckpts = get_all_checkpoints_in_dir(tmp_dir)
337-
assert len(all_ckpts) == 2
338-
assert all_ckpts == [ckpt_dir_0, ckpt_dir_1]
339-
340334

341335
class TestPruneSurplusCheckpoints:
342336
"""Series of tests for the ``prune_surplus_checkpoints`` function."""
343337

344-
def test_prune_surplus_checkpoints_simple(self, tmp_dir):
345-
ckpt_dir_0 = tmp_dir / "epoch_0"
338+
def test_prune_surplus_checkpoints_simple(self, tmpdir):
339+
tmpdir = Path(tmpdir)
340+
ckpt_dir_0 = tmpdir / "epoch_0"
346341
ckpt_dir_0.mkdir(parents=True, exist_ok=True)
347342

348-
ckpt_dir_1 = tmp_dir / "epoch_1"
349-
ckpt_dir_1.mkdir(parents=True, exist_ok=True)
343+
ckpt_dir_1 = tmpdir / "epoch_1"
344+
ckpt_dir_1.mkdir()
350345

351-
prune_surplus_checkpoints(tmp_dir, 1)
352-
remaining_ckpts = os.listdir(tmp_dir)
346+
prune_surplus_checkpoints([ckpt_dir_0, ckpt_dir_1], 1)
347+
remaining_ckpts = os.listdir(tmpdir)
353348
assert len(remaining_ckpts) == 1
354349
assert remaining_ckpts == ["epoch_1"]
355350

356-
def test_prune_surplus_checkpoints_keep_last_invalid(self, tmp_dir):
351+
def test_prune_surplus_checkpoints_keep_last_invalid(self, tmpdir):
357352
"""Test that we raise an error if keep_last_n_checkpoints is not >= 1"""
358-
ckpt_dir_0 = tmp_dir / "epoch_0"
353+
tmpdir = Path(tmpdir)
354+
ckpt_dir_0 = tmpdir / "epoch_0"
359355
ckpt_dir_0.mkdir(parents=True, exist_ok=True)
360356

361-
ckpt_dir_1 = tmp_dir / "epoch_1"
362-
ckpt_dir_1.mkdir(parents=True, exist_ok=True)
357+
ckpt_dir_1 = tmpdir / "epoch_1"
358+
ckpt_dir_1.mkdir()
363359

364-
with pytest.raises(ValueError, match="keep_last_n_checkpoints must be >= 1"):
365-
prune_surplus_checkpoints(tmp_dir, 0)
360+
with pytest.raises(
361+
ValueError,
362+
match="keep_last_n_checkpoints must be greater than or equal to 1",
363+
):
364+
prune_surplus_checkpoints([ckpt_dir_0, ckpt_dir_1], 0)

0 commit comments

Comments
 (0)