Skip to content

Commit 4b756da

Browse files
committed
feat(CLI): Creating a new CLI to view checkpoint's info.
1 parent e00cf32 commit 4b756da

File tree

4 files changed

+181
-0
lines changed

4 files changed

+181
-0
lines changed

everyvoice/base_cli/checkpoint.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
"""
2+
CLI command to inspect EveryVoice's checkpoints.
3+
"""
4+
import json
5+
import sys
6+
import warnings
7+
from enum import Enum
8+
from json import JSONEncoder
9+
from pathlib import Path
10+
from typing import Any, Dict
11+
12+
import typer
13+
import yaml
14+
from pydantic import BaseModel
15+
from typing_extensions import Annotated
16+
17+
from everyvoice.model.feature_prediction.FastSpeech2_lightning.fs2.model import (
18+
FastSpeech2,
19+
)
20+
from everyvoice.model.vocoder.HiFiGAN_iSTFT_lightning.hfgl.model import HiFiGAN
21+
22+
# Typer application object; the `inspect` command below registers itself on it
# through `@app.command()`.
app = typer.Typer(
    help="Extract checkpoint's hyperparameters.",
    pretty_exceptions_show_locals=False,
)
26+
27+
28+
class ExportType(str, Enum):
    """
    Available export formats for the configuration.

    Members also subclass `str`, so each one compares equal to its value.
    """

    JSON = "json"
    YAML = "yaml"
35+
36+
37+
class CheckpointEncoder(JSONEncoder):
    """
    JSON encoder that additionally accepts `torch.Tensor` and
    `pydantic.BaseModel` objects.
    """

    def default(self, obj: Any):
        """
        Return a JSON-serializable stand-in for `obj`.

        A tensor is reduced to its shape (a list of ints); a pydantic model
        is converted through its own JSON export.  Every other object is
        handed to the base class implementation.
        """
        # torch is imported at call time rather than at module import.
        import torch

        if isinstance(obj, torch.Tensor):
            # Only the dimensions are reported, never the values.
            shape = obj.shape
            return list(shape)

        if isinstance(obj, BaseModel):
            # Round-trip through the model's JSON text to get plain types.
            serialized = obj.json()
            return json.loads(serialized)

        return super().default(obj)
53+
54+
55+
def load_checkpoint(model_path: Path) -> Dict[str, Any]:
    """
    Loads a checkpoint on CPU and performs minor clean up of the checkpoint.

    The following entries are removed, when present, from the loaded object:
      - every optimizer's `state` in `optimizer_states`;
      - every `params` entry of each optimizer's `param_groups`;
      - the top-level `state_dict`;
      - the top-level `loops`.

    Args:
        model_path: path to the checkpoint file.

    Returns:
        What is left of the checkpoint after the clean up.
    """
    import torch

    # NOTE: `torch.load` unpickles the file; only inspect checkpoints that
    # come from a source you trust.
    checkpoint = torch.load(str(model_path), map_location=torch.device("cpu"))

    # Some clean up of useless stuff.
    for optimizer in checkpoint.get("optimizer_states") or ():
        # Delete the optimizer history values.
        optimizer.pop("state", None)
        # `params` are simply the values [0, len(optimizer["state"])].
        # `param_groups` is looked up with `.get` so that an optimizer entry
        # without it is skipped instead of raising `KeyError`.
        for param_group in optimizer.get("param_groups") or ():
            param_group.pop("params", None)

    for bulky_key in ("state_dict", "loops"):
        checkpoint.pop(bulky_key, None)

    return checkpoint
83+
84+
85+
@app.command()
def inspect(
    model_path: Path = typer.Argument(
        ...,
        exists=True,
        dir_okay=False,
        file_okay=True,
        help="The path to your model checkpoint file.",
    ),
    export_type: ExportType = ExportType.YAML,
    show_config: Annotated[
        bool,
        typer.Option(
            "--show-config/--no-show-config",  # noqa
            "-c/-C",  # noqa
            help="Show the configuration used during training in either json or yaml format",  # noqa
        ),
    ] = True,
    show_architecture: Annotated[
        bool,
        typer.Option(
            "--show-architecture/--no-show-architecture",  # noqa
            "-a/-A",  # noqa
            help="Show the model's architecture",  # noqa
        ),
    ] = True,
    show_weights: Annotated[
        bool,
        typer.Option(
            "--show-weights/--no-show-weights",  # noqa
            "-w/-W",  # noqa
            help="Show the number of weights per layer",  # noqa
        ),
    ] = True,
):
    """
    Given an EveryVoice checkpoint, show information about the configuration
    used during training, the model's architecture and the number of weights
    per layer and total weight count.
    """
    # NOTE: the docstring above is left untouched on purpose, Typer shows it
    # as the command's help text.
    if show_config:
        # The cleaned checkpoint is only needed for this section, so it is
        # only loaded when the section is requested.
        checkpoint = load_checkpoint(model_path)
        print("Configs:")
        if export_type is ExportType.JSON:
            json.dump(
                checkpoint,
                sys.stdout,
                ensure_ascii=False,
                indent=2,
                cls=CheckpointEncoder,
            )
        elif export_type is ExportType.YAML:
            # Round-trip through JSON so that tensors and pydantic models are
            # reduced to plain types before being handed to yaml.
            output = json.loads(json.dumps(checkpoint, cls=CheckpointEncoder))
            yaml.dump(output, stream=sys.stdout)
        else:
            raise NotImplementedError(f"Unsupported export type {export_type}!")

    # Both remaining sections need the instantiated model.  Loading it here,
    # rather than inside the architecture section, keeps
    # `--no-show-architecture --show-weights` from failing on an unbound
    # `model`.
    model = None
    if show_architecture or show_weights:
        model = _load_model(model_path)

    if show_architecture:
        print("\n\nModel Architecture:\n", model, sep="")

    if show_weights:
        from torchinfo import summary

        # According to Aidan (1, 80, 50) should be a valid input size but it
        # looks like the model is expecting a Dict which isn't supported by
        # torchsummary, hence no input size is given here.
        statistics = summary(model, None, verbose=0)
        print("\nModel's Weights:\n", statistics)


def _load_model(model_path: Path):
    """
    Instantiate the model stored in `model_path`, trying each supported
    model class in turn (HiFiGAN first, then FastSpeech2).

    Raises:
        NotImplementedError: if none of the supported classes can load it.
    """
    last_error = None
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for model_class in (HiFiGAN, FastSpeech2):
            try:
                return model_class.load_from_checkpoint(model_path)
            # NOTE if ANY exception is raised, that means the model couldn't
            # be loaded and we want to try another config type.  This is to
            # "ask forgiveness, not permission".
            except Exception as error:
                last_error = error
    raise NotImplementedError(
        "Your checkpoint contains a model type that is not yet supported!"
    ) from last_error

everyvoice/cli.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from everyvoice.model.aligner.wav2vec2aligner.aligner.cli import (
1111
align_single as ctc_segment,
1212
)
13+
from everyvoice.model.e2e.config import EveryVoiceConfig
14+
from everyvoice.model.feature_prediction.config import FeaturePredictionConfig
15+
from everyvoice.base_cli.checkpoint import inspect as inspect_checkpoint
1316
from everyvoice.model.feature_prediction.FastSpeech2_lightning.fs2.cli import (
1417
preprocess as preprocess_fs2,
1518
)
@@ -201,6 +204,10 @@ def new_project():
201204
short_help="Synthesize using your pre-trained EveryVoice models",
202205
)
203206

207+
app.command(
208+
name="inspect-checkpoint",
209+
short_help="Extract structural information from a checkpoint",
210+
)(inspect_checkpoint)
204211

205212
class TestSuites(str, Enum):
206213
all = "all"

everyvoice/tests/test_cli.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def setUp(self) -> None:
2020
"train",
2121
"synthesize",
2222
"preprocess",
23+
"inspect-checkpoint",
2324
]
2425

2526
def test_commands_present(self):
@@ -46,6 +47,11 @@ def test_update_schema(self):
4647
)
4748
)
4849

50+
def test_inspect_checkpoint(self):
    """`inspect-checkpoint --help` prints the command's usage line."""
    expected_usage = "inspect-checkpoint [OPTIONS] MODEL_PATH"
    help_run = self.runner.invoke(app, ["inspect-checkpoint", "--help"])
    self.assertIn(expected_usage, help_run.stdout)
54+
4955

5056
if __name__ == "__main__":
5157
main()

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,5 @@ simple-term-menu==1.5.2
2222
setuptools==59.5.0 # https://github.com/pytorch/pytorch/issues/69894
2323
tabulate==0.8.10
2424
tensorboard>=2.14.1
25+
torchinfo==1.8.0
2526
typer[all]>=0.9.0

0 commit comments

Comments
 (0)