Skip to content

Commit a205bc5

Browse files
SamuelLarkinroedoejet
authored andcommitted
feat: Making the configs' properties relative paths.
1 parent 6e9dde1 commit a205bc5

18 files changed

+533
-81
lines changed

everyvoice/config/preprocessing_config.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
from typing import List, Optional, Union
55

66
from loguru import logger
7-
from pydantic import Field, FilePath, field_validator, model_validator
7+
from pydantic import Field, FilePath, ValidationInfo, field_validator, model_validator
88

9-
from everyvoice.config.shared_types import ConfigModel
9+
from everyvoice.config.shared_types import ConfigModel, PartialLoadConfig, init_context
1010
from everyvoice.config.utils import (
1111
PossiblyRelativePath,
1212
PossiblySerializedCallable,
@@ -50,7 +50,7 @@ class PitchCalculationMethod(Enum):
5050
cwt = "cwt"
5151

5252

53-
class Dataset(ConfigModel):
53+
class Dataset(PartialLoadConfig):
5454
label: str = "YourDataSet"
5555
data_dir: PossiblyRelativePath = Path("/please/create/a/path/to/your/dataset/data")
5656
textgrid_dir: Union[PossiblyRelativePath, None] = None
@@ -60,8 +60,17 @@ class Dataset(ConfigModel):
6060
filelist_loader: PossiblySerializedCallable = generic_dict_loader
6161
sox_effects: list = [["channels", "1"]]
6262

63+
@field_validator(
64+
"data_dir",
65+
"textgrid_dir",
66+
"filelist",
67+
)
68+
@classmethod
69+
def relative_to_absolute(cls, value: Path, info: ValidationInfo) -> Path:
70+
return PartialLoadConfig.path_relative_to_absolute(value, info)
71+
6372

64-
class PreprocessingConfig(ConfigModel):
73+
class PreprocessingConfig(PartialLoadConfig):
6574
dataset: str = "YourDataSet"
6675
pitch_type: Union[
6776
PitchCalculationMethod, str
@@ -76,9 +85,16 @@ class PreprocessingConfig(ConfigModel):
7685
path_to_audio_config_file: Optional[FilePath] = None
7786
source_data: List[Dataset] = Field(default_factory=lambda: [Dataset()])
7887

79-
@model_validator(mode="before")
80-
def load_partials(self):
81-
return load_partials(self, ["audio"])
88+
@model_validator(mode="before") # type: ignore
89+
def load_partials(self, info: ValidationInfo):
90+
config_path = (
91+
info.context.get("config_path", None) if info.context is not None else None
92+
)
93+
return load_partials(
94+
self, # type: ignore
95+
("audio",),
96+
config_path=config_path,
97+
)
8298

8399
@field_validator("save_dir", mode="after")
84100
def create_dir(cls, value: Path):
@@ -93,4 +109,6 @@ def create_dir(cls, value: Path):
93109
def load_config_from_path(path: Path) -> "PreprocessingConfig":
94110
"""Load a config from a path"""
95111
config = load_config_from_json_or_yaml_path(path)
96-
return PreprocessingConfig(**config)
112+
with init_context({"config_path": path}):
113+
config = PreprocessingConfig(**config)
114+
return config

everyvoice/config/shared_types.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,35 @@
11
from collections.abc import Mapping, Sequence
2+
from contextlib import contextmanager
3+
from contextvars import ContextVar
24
from functools import cached_property
35
from pathlib import Path
4-
from typing import Tuple, Union
6+
from typing import Any, Dict, Iterator, Tuple, Union
57

68
from loguru import logger
7-
from pydantic import BaseModel, ConfigDict, DirectoryPath, Field, validator
9+
from pydantic import (
10+
BaseModel,
11+
ConfigDict,
12+
DirectoryPath,
13+
Field,
14+
ValidationInfo,
15+
field_validator,
16+
validator,
17+
)
818

919
from everyvoice.config.utils import PossiblyRelativePath, PossiblySerializedCallable
1020
from everyvoice.utils import generic_dict_loader, get_current_time, rel_path_to_abs_path
1121

22+
_init_context_var = ContextVar("_init_context_var", default=None)
23+
24+
25+
@contextmanager
26+
def init_context(value: Dict[str, Any]) -> Iterator[None]:
27+
token = _init_context_var.set(value) # type: ignore
28+
try:
29+
yield
30+
finally:
31+
_init_context_var.reset(token)
32+
1233

1334
class ConfigModel(BaseModel):
1435
model_config = ConfigDict(
@@ -48,7 +69,26 @@ def combine_configs(orig_dict: Union[dict, Sequence], new_dict: dict):
4869
return orig_dict
4970

5071

51-
class LoggerConfig(ConfigModel):
72+
class PartialLoadConfig(ConfigModel):
73+
"""Models that have partial models which requires a context to properly load."""
74+
75+
# [Using validation context with BaseModel initialization](https://docs.pydantic.dev/2.3/usage/validators/#using-validation-context-with-basemodel-initialization)
76+
def __init__(__pydantic_self__, **data: Any) -> None:
77+
__pydantic_self__.__pydantic_validator__.validate_python(
78+
data,
79+
self_instance=__pydantic_self__,
80+
context=_init_context_var.get(),
81+
)
82+
83+
@classmethod
84+
def path_relative_to_absolute(cls, value: Path, info: ValidationInfo) -> Path:
85+
if info.context and value is not None and not value.is_absolute():
86+
config_path = info.context.get("config_path", Path("."))
87+
value = (config_path / value).resolve()
88+
return value
89+
90+
91+
class LoggerConfig(PartialLoadConfig):
5292
"""The logger configures all the information needed for where to store your experiment's logs and checkpoints.
5393
The structure of your logs will then be:
5494
<name> / <version> / <sub_dir>
@@ -67,6 +107,11 @@ class LoggerConfig(ConfigModel):
67107
version: str = "base"
68108
"""The version of your experiment"""
69109

110+
@field_validator("save_dir")
111+
@classmethod
112+
def relative_to_absolute(cls, value: Path, info: ValidationInfo) -> Path:
113+
return PartialLoadConfig.path_relative_to_absolute(value, info)
114+
70115
@cached_property
71116
def sub_dir(self) -> str:
72117
return self.sub_dir_callable()
@@ -83,7 +128,7 @@ def convert_path(cls, v, values):
83128
return path
84129

85130

86-
class BaseTrainingConfig(ConfigModel):
131+
class BaseTrainingConfig(PartialLoadConfig):
87132
batch_size: int = 16
88133
save_top_k_ckpts: int = 5
89134
ckpt_steps: Union[int, None] = None
@@ -102,6 +147,11 @@ class BaseTrainingConfig(ConfigModel):
102147
val_data_workers: int = 0
103148
train_data_workers: int = 4
104149

150+
@field_validator("training_filelist", "validation_filelist")
151+
@classmethod
152+
def relative_to_absolute(cls, value: Path, info: ValidationInfo) -> Path:
153+
return PartialLoadConfig.path_relative_to_absolute(value, info)
154+
105155

106156
class BaseOptimizer(ConfigModel):
107157
learning_rate: float = 1e-4

everyvoice/config/utils.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from importlib import import_module
22
from pathlib import Path
3-
from typing import Any, Callable, Dict, List, Union
3+
from typing import Any, Callable, Dict, Optional, Sequence, Union
44

55
from loguru import logger
66
from pydantic import PlainSerializer, WithJsonSchema
@@ -10,10 +10,15 @@
1010
from everyvoice.utils import load_config_from_json_or_yaml_path, rel_path_to_abs_path
1111

1212

13-
def load_partials(pre_validated_model_dict: Dict[Any, Any], partial_keys: List[str]):
14-
"""Loads all partials based on a list of partial keys. For this to work, your model
15-
must have a {key}_config_file: Optional[FilePath] = None field defined, and you must
16-
have a model_validator(mode="before") that runs this function.
13+
def load_partials(
14+
pre_validated_model_dict: Dict[Any, Any],
15+
partial_keys: Sequence[str],
16+
config_path: Optional[Path] = None,
17+
):
18+
"""Loads all partials based on a list of partial keys. For this to work,
19+
your model must have a {key}_config_file: Optional[FilePath] = None field
20+
defined, and you must have a model_validator(mode="before") that runs this
21+
function.
1722
"""
1823
# If there's nothing there, just return the dict
1924
if not pre_validated_model_dict:
@@ -25,9 +30,10 @@ def load_partials(pre_validated_model_dict: Dict[Any, Any], partial_keys: List[s
2530
key_for_path_to_partial in pre_validated_model_dict
2631
and pre_validated_model_dict[key_for_path_to_partial]
2732
):
28-
subconfig_path = rel_path_to_abs_path(
29-
pre_validated_model_dict[key_for_path_to_partial]
30-
)
33+
subconfig_path = Path(pre_validated_model_dict[key_for_path_to_partial])
34+
if not subconfig_path.is_absolute() and config_path is not None:
35+
subconfig_path = (config_path.parent / subconfig_path).resolve()
36+
pre_validated_model_dict[key_for_path_to_partial] = subconfig_path
3137
# anything defined in the key will override the path
3238
# so audio would override any values in path_to_audio_config_file
3339
if key in pre_validated_model_dict:

everyvoice/model/e2e/config/__init__.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
from pathlib import Path
22
from typing import Optional, Union
33

4-
from pydantic import Field, FilePath, model_validator
4+
from pydantic import Field, FilePath, ValidationInfo, model_validator
55

6-
from everyvoice.config.shared_types import BaseTrainingConfig, ConfigModel
6+
from everyvoice.config.shared_types import (
7+
BaseTrainingConfig,
8+
PartialLoadConfig,
9+
init_context,
10+
)
711
from everyvoice.config.utils import PossiblyRelativePath, load_partials
812
from everyvoice.model.aligner.config import AlignerConfig
913
from everyvoice.model.feature_prediction.config import FeaturePredictionConfig
@@ -16,7 +20,7 @@ class E2ETrainingConfig(BaseTrainingConfig):
1620
vocoder_checkpoint: Union[None, PossiblyRelativePath] = None
1721

1822

19-
class EveryVoiceConfig(ConfigModel):
23+
class EveryVoiceConfig(PartialLoadConfig):
2024
aligner: AlignerConfig = Field(default_factory=AlignerConfig)
2125
path_to_aligner_config_file: Optional[FilePath] = None
2226

@@ -31,10 +35,15 @@ class EveryVoiceConfig(ConfigModel):
3135
training: E2ETrainingConfig = Field(default_factory=E2ETrainingConfig)
3236
path_to_training_config_file: Optional[FilePath] = None
3337

34-
@model_validator(mode="before")
35-
def load_partials(self):
38+
@model_validator(mode="before") # type: ignore
39+
def load_partials(self, info: ValidationInfo):
40+
config_path = (
41+
info.context.get("config_path", None) if info.context is not None else None
42+
)
3643
return load_partials(
37-
self, ["aligner", "feature_prediction", "vocoder", "training"]
44+
self, # type: ignore
45+
("aligner", "feature_prediction", "vocoder", "training"),
46+
config_path=config_path,
3847
)
3948

4049
@staticmethod
@@ -43,4 +52,6 @@ def load_config_from_path(
4352
) -> "EveryVoiceConfig":
4453
"""Load a config from a path"""
4554
config = load_config_from_json_or_yaml_path(path)
46-
return EveryVoiceConfig(**config)
55+
with init_context({"config_path": path}):
56+
config = EveryVoiceConfig(**config)
57+
return config

everyvoice/run_tests.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from loguru import logger
1111

1212
from everyvoice.tests.test_cli import CLITest
13-
from everyvoice.tests.test_configs import ConfigTest
13+
from everyvoice.tests.test_configs import ConfigTest, LoadConfigTest
1414
from everyvoice.tests.test_dataloader import DataLoaderTest
1515
from everyvoice.tests.test_model import ModelTest
1616
from everyvoice.tests.test_preprocessing import (
@@ -25,7 +25,9 @@
2525

2626
LOADER = TestLoader()
2727

28-
CONFIG_TESTS = [LOADER.loadTestsFromTestCase(test) for test in [ConfigTest]]
28+
CONFIG_TESTS = [
29+
LOADER.loadTestsFromTestCase(test) for test in [ConfigTest, LoadConfigTest]
30+
]
2931

3032
DATALOADER_TESTS = [LOADER.loadTestsFromTestCase(test) for test in [DataLoaderTest]]
3133

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
model: {conv_dim: 512, lstm_dim: 512}
2+
path_to_preprocessing_config_file: everyvoice-shared-data.yaml
3+
path_to_text_config_file: everyvoice-shared-text.yaml
4+
training:
5+
batch_size: 16
6+
binned_sampler: true
7+
ckpt_epochs: 1
8+
extraction_method: dijkstra
9+
filelist_loader: everyvoice.utils.generic_dict_loader
10+
logger: {name: AlignerExperiment, save_dir: ../logs_and_checkpoints, sub_dir_callable: everyvoice.utils.get_current_time,
11+
version: base}
12+
max_epochs: 1000
13+
max_steps: 100000
14+
optimizer:
15+
betas: [0.9, 0.98]
16+
eps: 1.0e-08
17+
learning_rate: 0.0001
18+
name: adamw
19+
weight_decay: 0.01
20+
plot_steps: 1000
21+
save_top_k_ckpts: 5
22+
train_data_workers: 4
23+
training_filelist: ../preprocessed/training_filelist.psv
24+
val_data_workers: 0
25+
validation_filelist: ../preprocessed/validation_filelist.psv
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
audio: {alignment_bit_depth: 16, alignment_sampling_rate: 22050, f_max: 8000, f_min: 0,
2+
fft_hop_frames: 256, fft_window_frames: 1024, input_sampling_rate: 22050, max_audio_length: 11.0,
3+
max_wav_value: 32767.0, min_audio_length: 0.4, n_fft: 1024, n_mels: 80, norm_db: -3.0,
4+
output_sampling_rate: 22050, sil_duration: 0.1, sil_threshold: 1.0, spec_type: mel-librosa,
5+
target_bit_depth: 16, vocoder_segment_size: 8192}
6+
dataset: relative
7+
dataset_split_seed: 1234
8+
energy_phone_averaging: true
9+
pitch_phone_averaging: true
10+
pitch_type: pyworld
11+
save_dir: ../preprocessed
12+
source_data:
13+
- data_dir: ../../lj/wavs
14+
filelist: ../r-filelist.psv
15+
filelist_loader: everyvoice.utils.generic_dict_loader
16+
label: dataset_0
17+
sox_effects:
18+
- [channel, '1']
19+
train_split: 0.9
20+
value_separator: --

0 commit comments

Comments
 (0)