-
Notifications
You must be signed in to change notification settings - Fork 29
Expand file tree
/
Copy pathconfig.py
More file actions
498 lines (415 loc) · 20.7 KB
/
config.py
File metadata and controls
498 lines (415 loc) · 20.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
import dataclass_wizard
import xarray as xr
from dataclass_wizard import JSONWizard
from deepdiff import DeepDiff
from packaging.version import Version
class InvalidConfigException(Exception):
    """Raised by `validate_config` when the input-dataset definitions in a config are invalid."""
def validate_config(config_inputs):
    """
    Validate the input-dataset definitions of a config.

    For every input dataset this checks that:

    - at least one of `variables` or `derived_variables` is present, and
    - if both `variables` and `derived_variables` are present, they don't
      add the same variables to the dataset.

    Parameters
    ----------
    config_inputs: Dict[str, InputDataset]
        Input dataset definitions, keyed by input dataset name.

    Returns
    -------
    None

    Raises
    ------
    InvalidConfigException
        If an input dataset has neither `variables` nor `derived_variables`,
        or if the two define variables with the same name.
    TypeError
        If both sections are present and `variables` is neither a list nor a
        dict.
    """
    for input_dataset_name, input_dataset in config_inputs.items():
        variables = input_dataset.variables
        derived_variables = input_dataset.derived_variables

        if not variables and not derived_variables:
            raise InvalidConfigException(
                f"Input dataset '{input_dataset_name}' is missing the keys `variables` and/or"
                " `derived_variables`. Make sure that you update the config so that the input"
                f" dataset '{input_dataset_name}' contains at least either a `variables` or"
                " `derived_variables` section."
            )

        if not (variables and derived_variables):
            # Only one of the two sections is given, so nothing can overlap
            continue

        # `variables` is either a plain list of names or a dict keyed by name
        if isinstance(variables, list):
            variable_names = set(variables)
        elif isinstance(variables, dict):
            variable_names = set(variables.keys())
        else:
            raise TypeError(
                f"Expected an instance of list or dict, but got {type(variables)}."
            )

        # Sort so the error message is deterministic (set iteration order is not)
        common_vars = sorted(variable_names & set(derived_variables.keys()))
        if common_vars:
            common_vars_str = ", ".join(common_vars)
            raise InvalidConfigException(
                "Both `variables` and `derived_variables` include the following variables name(s):"
                f" '{common_vars_str}'. This is not allowed. Make sure that there"
                " are no overlapping variable names between `variables` and `derived_variables`,"
                f" either by renaming or removing '{common_vars_str}' from one of them."
            )
@dataclass
class Range:
    """
    Defines a range for a variable to be used for selection, i.e.
    `xarray.Dataset.sel({var_name}: slice({start}, {end}, {step}))`, the variable
    name is the key in the dictionary and the slice object is created from the
    `start`, `end`, and `step` attributes.

    Attributes
    ----------
    start: Union[str, int, float]
        The start of the range, e.g. "1990-09-03T00:00", 0, or 0.0.
    end: Union[str, int, float]
        The end of the range, e.g. "1990-09-04T00:00", 1, or 1.0.
    step: Optional[Union[str, int, float]]
        The step size for the range, e.g. "PT3H", 1, or 1.0. Defaults to
        `None`; if not given then the entire range will be selected.
    """

    start: Union[str, int, float]
    end: Union[str, int, float]
    step: Optional[Union[str, int, float]] = None
@dataclass
class ValueSelection:
    """
    Defines a selection on the coordinate values of a variable. The
    `values` attribute can either be a list of values to select or a
    `Range` object to select a range of values. This is used to create
    a slice object for the selection. Optionally, the `units` attribute can be
    used to specify the units of the values, which will be used to ensure that
    the `units` attribute of the variable has the same value.

    Attributes
    ----------
    values: Union[List[Union[float, int]], Range]
        The values to select, either as an explicit list or as a `Range`.
    units: Optional[str]
        The units of the values. Defaults to `None`.
    """

    values: Union[List[Union[float, int]], Range]
    units: Optional[str] = None
@dataclass
class DerivedVariable:
    """
    Defines a derived variable, where the function (for calculating the variable) and
    the kwargs (arguments to the function) are specified. `kwargs` can contain both
    arguments which should extract/select data from the input dataset, in which case
    they should have the "ds_input." prefix to distinguish them from other arguments
    that should not be extracted from the dataset (e.g. a string to indicate if the
    sine or cosine component should be extracted).

    Optionally, attributes for the derived variable can be specified in `attrs`, e.g.
    `{"attrs": {"units": "W*m**-2", "long_name": "top-of-the-atmosphere radiation"}}`.
    In case a function does not return an `xr.DataArray` with the required attributes
    (`units` and `long_name`) set, these have to be specified in `attrs`.

    Attributes
    ----------
    kwargs: Dict[str, str]
        Variables required for calculating the derived variable.
    function: str
        Function used to calculate the derived variable.
    attrs: Dict[str, str]
        Attributes (e.g. `units` and `long_name`) to set for the derived variable.
        Defaults to a new, empty dict per instance.
    """

    kwargs: Dict[str, str]
    function: str
    attrs: Optional[Dict[str, str]] = field(default_factory=dict)
@dataclass
class DimMapping:
    """
    Defines the process for mapping dimensions and variables from an input
    dataset to a single new dimension (as in dimension in the
    output dataset of the dataset generation).

    There are three methods implemented for mapping:

    - "rename":
        Renames a dimension in the dataset to a new name.

        E.g. adding a dim-mapping as `{"time": {"method": "rename", "dim": "analysis_time"}}`
        will rename the "analysis_time" dimension in the input dataset to the "time"
        dimension in the output.

    - "stack_variables_by_var_name":
        Stacks all variables along a new dimension that is mapped to the given output
        dimension name.

        E.g. adding a dim-mapping as
        `{"state_feature": {"method": "stack_variables_by_var_name", "name_format": "{var_name}{altitude}m", "dims": ["altitude"]}}`
        will stack all variables in the input dataset along the "state_feature" dimension
        in the output, and the coordinate values will be given as f"{var_name}{altitude}m",
        where `var_name` is the name of the variable and `altitude` is the value of the
        "altitude" coordinate.
        If any dimensions are specified in the `dims` attribute, then these dimensions will
        also be stacked into this new dimension, and the `name_format` attribute can be used
        to include the coordinate values from the stacked dimensions in the new coordinate
        values.

    - "stack":
        Stacks the provided coordinates and maps the result to the output dimension.

        E.g. `{"grid_index": {"method": "stack", "dims": ["x", "y"]}}` will stack the "x"
        and "y" dimensions in the input dataset into a new "grid_index" dimension in the
        output.

    Attributes
    ----------
    method: str
        The method used for mapping. The options are:
        - "rename": Renames a dimension in the dataset to a new name.
        - "stack_variables_by_var_name": Stacks all variables along a new dimension that
          is mapped to the given output dimension name.
        - "stack": Stacks the provided coordinates and maps the result to the output
          dimension.
    dims: List[str]
        The dimensions to be mapped when using the "stack" or
        "stack_variables_by_var_name" methods. Defaults to `None`.
    dim: str
        The dimension to be renamed when using the "rename" method. Defaults to `None`.
    name_format: str
        The format for naming the mapped dimensions when using the
        "stack_variables_by_var_name" method. Defaults to `None`.
    coord_ranges: Dict[str, Range]
        Ranges keyed by coordinate name; defaults to a new, empty dict per instance.
        NOTE(review): how these ranges are applied is not defined in this module —
        confirm against the code that consumes `DimMapping`.
    """

    method: str
    dims: Optional[List[str]] = None
    dim: Optional[str] = None
    name_format: Optional[str] = field(default=None)
    coord_ranges: Optional[Dict[str, Range]] = field(default_factory=dict)
@dataclass
class InputDataset:
    """
    Definition of a single input dataset which will be mapped to one of the
    variables that have been defined as output variables in the produced dataset
    (i.e. the input variables for the model architecture being targeted by the dataset).

    The definition for an input dataset includes setting
        1) the path to the dataset,
        2) the expected dimensions of the dataset,
        3) the variables to select from the dataset (and optionally subsection
           along the coordinates for each variable) or the variables to derive
           from the dataset, and finally
        4) the method by which the dimensions and variables of the dataset are
           mapped to one of the output variables (this includes stacking of all
           the selected variables into a new single variable along a new coordinate,
           and may include renaming and stacking existing dimensions).

    At least one of `variables` or `derived_variables` must be given, and the two
    must not define variables with the same name; this is checked by
    `validate_config` when a `Config` is constructed.

    Attributes
    ----------
    path: str
        Path to the dataset, e.g. the path to a zarr dataset or netCDF file.
        This can be anything that can be passed to `xarray.open_dataset`.
    dims: List[str]
        List of the expected dimensions of the dataset. E.g. `["time", "x", "y"]`.
        These will be checked to ensure consistency of the dataset being read.
    dim_mapping: Dict[str, DimMapping]
        Mapping of the variables and dimensions in the input dataset to the dimensions
        of the output variable (`target_output_variable`). The key is the name of the
        output dimension to map to, and the `DimMapping` describes how to map the
        dimensions and variables of the input dataset to this dimension of the output
        variable.
    target_output_variable: str
        The name of the output variable (i.e. the name of a variable that is expected
        by the architecture to exist in the training dataset). If multiple datasets map
        to the same variable, then the data from all datasets will be concatenated along
        the dimension that isn't shared (e.g. two datasets that coincide in space and
        time will only differ in the feature dimension, so the two will be combined by
        concatenating along the feature dimension).
        If a single shared coordinate cannot be found then an exception will be raised.
    variables: Union[List[str], Dict[str, Dict[str, ValueSelection]]]
        List of the variables to select from the dataset. E.g.
        `["temperature", "precipitation"]`, or a dictionary where the keys are the
        variable names and the values are dictionaries defining the selection for each
        variable. E.g. `{"temperature": {"levels": {"values": [1000, 950, 900]}}}`
        would select the "temperature" variable and only the levels 1000, 950, and 900.
        Defaults to `None`.
    derived_variables: Dict[str, DerivedVariable]
        Dictionary of variables to derive from the dataset, where the keys are the names
        the variables will be given and the values are `DerivedVariable` definitions that
        specify how to derive each variable. Defaults to `None`.
    attributes: Dict[str, Any]
        Additional attributes for this input dataset; defaults to a new, empty dict per
        instance. NOTE(review): how these are used is not defined in this module —
        confirm against the code that consumes `InputDataset`.
    coord_ranges: Dict[str, Range]
        Optional ranges keyed by coordinate name; defaults to `None`.
        NOTE(review): confirm how these are applied when the dataset is loaded.
    """

    path: str
    dims: List[str]
    dim_mapping: Dict[str, DimMapping]
    target_output_variable: str
    variables: Optional[Union[List[str], Dict[str, Dict[str, ValueSelection]]]] = None
    derived_variables: Optional[Dict[str, DerivedVariable]] = None
    attributes: Optional[Dict[str, Any]] = field(default_factory=dict)
    coord_ranges: Optional[Dict[str, Range]] = None
@dataclass
class Statistics:
    """
    Define the statistics to compute for the output dataset. This includes defining
    the statistics to compute and the dimensions to compute the statistics over.
    The statistics will be computed for each variable in the output dataset separately.

    Attributes
    ----------
    ops: List[str]
        The statistics to compute, e.g. ["mean", "std", "min", "max"].
    dims: List[str]
        The dimensions to compute the statistics over, e.g. ["time", "grid_index"].
    """

    ops: List[str]
    dims: List[str]
@dataclass
class Split:
    """
    Define the `start` and `end` coordinate value (e.g. time) for a split of the
    dataset and, optionally, the statistics to compute for the split.

    Attributes
    ----------
    start: str
        The start of the split, e.g. "1990-09-03T00:00".
    end: str
        The end of the split, e.g. "1990-09-04T00:00".
    compute_statistics: Statistics
        The statistics to compute for the split. Defaults to `None`.
    """

    start: str
    end: str
    compute_statistics: Optional[Statistics] = None
@dataclass
class Splitting:
    """
    Define how the dataset is split (e.g. into train, validation and test) along
    a single dimension.

    Attributes
    ----------
    dim: str
        The dimension to split the dataset along, e.g. "time". This must be
        provided if splits are defined.
    splits: Dict[str, Split]
        Defines the splits of the dataset. The keys are the names of the splits and
        the values are the `Split` objects defining the start and end of each split.
        Optionally, the `compute_statistics` attribute of a `Split` can be used to
        define the statistics to compute for that split.
    """

    dim: str
    splits: Dict[str, Split]
@dataclass
class ConvexHullCropping:
    """
    Define the method applied for cropping the spatial domain before writing
    the transformed output dataset. This is typically used when you want to
    create a dataset to provide data in a boundary around a limited-area
    domain.

    The cropping is done by creating a convex hull around the spatial
    coordinates of an *interior* dataset (this will typically be the
    "limited-area" domain when doing Limited Area Modelling) and then including
    all points that are within a margin of the convex hull boundary. In addition
    to including the points inside the defined margin around the convex hull,
    you can also include the points inside the convex hull of the interior
    dataset by setting the `include_interior_points` attribute to `True`.

    Attributes
    ----------
    margin_width_degrees: float
        The width (in degrees) of the margin applied to the convex hull
        boundary of the interior dataset used to define the cropping domain.
    interior_dataset_config_path: str
        The path to the configuration file for the dataset defining the
        interior domain.
    include_interior_points: bool
        Whether to include the points inside the convex hull of the interior
        dataset. Defaults to `False`.
    """

    margin_width_degrees: float
    interior_dataset_config_path: str
    include_interior_points: bool = False
@dataclass
class Output:
    """
    Definition of the output dataset that will be created by the dataset generation.
    You should adapt this to the architecture of the model that you are going to use
    the dataset with. This includes defining what input variables the architecture
    expects (and the dimensions of each), the expected value range for each coordinate,
    and the chunking information for each dimension.

    Attributes
    ----------
    variables: Dict[str, List[str]]
        Defines the variables of the produced output, i.e. the input variables for the
        model architecture. The keys are the variable names to create and the values are
        lists of the dimensions. E.g. `{"static": ["grid_index", "feature"], "state":
        ["time", "grid_index", "state_feature"]}` would define that the architecture
        expects a variable named "static" with dimensions "grid_index" and "feature" and
        a variable named "state" with dimensions "time", "grid_index", and
        "state_feature".
    coord_ranges: Dict[str, Range]
        Defines the expected value range for each coordinate. The keys are the
        names of the coordinates and the values are the ranges, e.g.
        `{"time": {"start": "1990-09-03T00:00", "end": "1990-09-04T00:00", "step": "PT3H"}}`
        would define that the "time" coordinate should have values between
        "1990-09-03T00:00" and "1990-09-04T00:00" with a step size of 3 hours.
        These range definitions are used both to ensure that the input dataset
        has the expected range and to select the correct values from the input
        dataset. If not given then the entire range will be selected.
        Defaults to a new, empty dict per instance.
    chunking: Dict[str, int]
        Defines the chunking information for each dimension. The keys are the
        names of the dimensions and the values are the chunk size for that dimension.
        If chunking is not specified for a dimension, then the entire dimension
        will be a single chunk. Defaults to a new, empty dict per instance.
    splitting: Splitting
        Defines the splits of the dataset (e.g. train, test, validation), the dimension
        to split the dataset along, and optionally the statistics to compute for each
        split. Defaults to `None`.
    domain_cropping: ConvexHullCropping
        Defines the method applied for cropping the spatial domain before writing
        the transformed output dataset. This is typically used when you want to
        create a dataset to provide data in a boundary around a limited-area
        domain. Defaults to `None`.
    """

    variables: Dict[str, List[str]]
    coord_ranges: Dict[str, Range] = field(default_factory=dict)
    chunking: Dict[str, int] = field(default_factory=dict)
    splitting: Optional[Splitting] = None
    domain_cropping: Optional[ConvexHullCropping] = None
@dataclass
class Config(dataclass_wizard.JSONWizard, dataclass_wizard.YAMLWizard):
    """
    Configuration for creating a dataset with mllam-data-prep.

    On construction, `__post_init__` validates the input-dataset definitions with
    `validate_config`. The inner `JSONWizard.Meta` subclass sets
    `raise_on_unknown_json_key = True`, so unrecognised keys are intended to raise
    an error when a config is loaded rather than being silently ignored.

    Attributes
    ----------
    output: Output
        Information about the structure of the output from mllam-data-prep. You should
        set this to match the model architecture this dataset is intended for. This
        covers defining what input variables the architecture expects (and the
        dimensions of each), the expected value range for each coordinate, and the
        chunking information for each dimension.
    inputs: Dict[str, InputDataset]
        Input datasets for the model. The keys are the names of the datasets and the
        values are the input dataset configurations.
    schema_version: str
        Version string for the config file schema.
    dataset_version: str
        Version string for the dataset itself.
    extra: Dict[str, Any]
        Extra information to include in the config file. This will be ignored by the
        `mllam_data_prep` library, but can be used to include additional information
        that is useful for the user. Defaults to a new, empty dict per instance.
    """

    output: Output
    inputs: Dict[str, InputDataset]
    schema_version: str
    dataset_version: str
    extra: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # Fail early on invalid input-dataset definitions (missing or
        # overlapping `variables` / `derived_variables`)
        validate_config(self.inputs)

    class _(JSONWizard.Meta):
        # Reject keys in the loaded config that don't map to a field
        raise_on_unknown_json_key = True
class UnsupportedMllamDataPrepVersion(Exception):
    """
    Raised by `find_config_differences` when an existing dataset was created
    by a mllam-data-prep version that predates storing its creation config.
    """
def find_config_differences(
    config: Config, ds_existing: xr.Dataset
) -> Union[None, dict]:
    """
    Compare the provided config against the one the provided dataset was created
    from (which is stored in the `creation_config` attribute), and return the
    differences.

    Parameters
    ----------
    config : Config
        The configuration object to compare against
    ds_existing : xr.Dataset
        The existing dataset to compare against

    Returns
    -------
    Union[None, dict]
        If the configurations are the same, returns None. If they are different, returns
        a dictionary of the differences (as produced by `DeepDiff(...).to_dict()`).

    Raises
    ------
    UnsupportedMllamDataPrepVersion
        If the existing dataset has no `mdp_version` attribute, or was created with a
        version of mllam-data-prep older than v0.6.0, which did not store the
        `creation_config` attribute.
    ValueError
        If the existing dataset does not have a `creation_config` attribute.
    """
    required_mdp_version = Version("v0.6.0")

    existing_mdp_version_str = ds_existing.attrs.get("mdp_version")
    if existing_mdp_version_str is None:
        # Without a version we cannot know whether `creation_config` should be
        # present, so report this like an unsupported (too old) version instead
        # of failing with a bare KeyError
        raise UnsupportedMllamDataPrepVersion(
            "The existing dataset does not have an mdp_version attribute, so the "
            "version of mllam-data-prep that created it cannot be determined. "
            "Please delete the existing dataset or set overwrite='always' to "
            "overwrite it."
        )

    existing_mdp_version = Version(existing_mdp_version_str)
    if existing_mdp_version < required_mdp_version:
        raise UnsupportedMllamDataPrepVersion(
            "The existing dataset was created with an older version of mllam-data-prep "
            f"({existing_mdp_version}), and does not have the creation_config attribute "
            f"(added in v{required_mdp_version}). Please delete the existing dataset "
            "or set overwrite='always' to overwrite it."
        )

    existing_config_yaml = ds_existing.attrs.get("creation_config", None)
    if existing_config_yaml is None:
        raise ValueError(
            "The provided dataset does not have a creation_config attribute"
        )

    existing_config = Config.from_yaml(existing_config_yaml)
    if existing_config == config:
        return None

    return DeepDiff(
        existing_config.to_dict(), config.to_dict(), ignore_order=True
    ).to_dict()
if __name__ == "__main__":
    # Small CLI for loading a config file and pretty-printing the parsed result
    import argparse

    argparser = argparse.ArgumentParser()
    argparser.add_argument(
        "-f", help="Path to the yaml file to load.", default="example.danra.yaml"
    )
    args = argparser.parse_args()

    # Validate with argparser.error rather than `assert`: asserts are stripped
    # under `python -O`, and this gives the user a proper usage message
    if not args.f.endswith(".yaml"):
        argparser.error("Config file must have a .yaml extension.")

    config = Config.from_yaml_file(args.f)

    import rich

    rich.print(config)