2 changes: 1 addition & 1 deletion external/vcm/vcm/__init__.py
@@ -11,7 +11,7 @@
parse_timestep_str_from_path,
parse_datetime_from_str,
)
from .calc import mass_integrate
from .calc import mass_integrate, r2_score, local_time
from .calc.thermo import (
net_heating,
net_precipitation,
3 changes: 2 additions & 1 deletion external/vcm/vcm/calc/__init__.py
@@ -1,5 +1,5 @@
from .advect import storage_and_advection
from .calc import apparent_heating, apparent_source, mass_integrate
from .calc import apparent_heating, apparent_source, mass_integrate, local_time
from .metrics import r2_score
from .q_terms import compute_Q_terms

@@ -10,4 +10,5 @@
"mass_integrate",
"r2_score",
"compute_Q_terms",
"local_time",
]
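
With these re-exports, r2_score and local_time are importable from both the vcm package top level and vcm.calc. A quick smoke test of the import paths (assuming the vcm package is installed; local_time's signature is not shown in this diff, so only the imports are exercised):

```python
# Both import paths resolve to the same objects after this change.
from vcm import mass_integrate, r2_score, local_time
from vcm.calc import local_time as local_time_from_calc

assert local_time is local_time_from_calc
```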
4 changes: 1 addition & 3 deletions external/vcm/vcm/select.py
@@ -2,7 +2,6 @@
This module is for functions that select subsets of the data
"""
import numpy as np
import warnings

from vcm.cubedsphere.constants import (
COORD_X_CENTER,
@@ -20,8 +19,7 @@ def mask_to_surface_type(ds, surface_type):
Returns:
input dataset masked to the surface_type specified
"""
if surface_type in ["none", "None", None]:
warnings.warn("surface_type provided as None: no mask applied.")
if surface_type in ["none", "None", None, "global"]:
return ds
elif surface_type not in ["sea", "land", "seaice"]:
raise ValueError("Must mask to surface_type in ['sea', 'land', 'seaice'].")
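
The behavioral change: "global" now joins "none"/"None"/None as a no-op surface type, and the warning for the None case is dropped. A minimal sketch of the new behavior, using an illustrative toy dataset and assuming the conventional FV3 slmsk encoding of 0=sea, 1=land, 2=seaice (not taken from the PR):

```python
import numpy as np
import xarray as xr
import vcm

ds = xr.Dataset(
    {
        "T": ("x", np.array([290.0, 285.0, 280.0])),
        "slmsk": ("x", np.array([0.0, 1.0, 2.0])),  # sea, land, seaice
    }
)

ds_land = vcm.mask_to_surface_type(ds, "land")  # non-land points masked out
ds_all = vcm.mask_to_surface_type(ds, "global")  # new: dataset returned unmasked
```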
6 changes: 6 additions & 0 deletions fv3net/diagnostics/__init__.py
@@ -0,0 +1,6 @@
from .create_report import create_report
from .data import (
merge_comparison_datasets,
get_latlon_grid_coords_set,
EXAMPLE_CLIMATE_LATLON_COORDS,
)
64 changes: 60 additions & 4 deletions fv3net/diagnostics/data_funcs.py → fv3net/diagnostics/data.py
@@ -17,7 +17,12 @@


def merge_comparison_datasets(
data_vars, datasets, dataset_labels, grid, additional_dataset=None
data_vars,
datasets,
dataset_labels,
grid,
concat_dim_name="dataset",
additional_dataset=None,
):
""" Makes a comparison dataset out of multiple datasets that all have a common
data variable. They are concatenated with a new dim "dataset" that can be used
@@ -31,17 +36,22 @@ def merge_comparison_datasets(
is the coords for the "dataset" dimension
grid: xr dataset with lat/lon grid vars
additional_data: xr data array, any additional data (e.g. slmsk) to merge along
with data arrays and grid
with data arrays and grid.

Returns:
Dataset with new dataset dimension to denote the target vs predicted
quantities. It is unstacked into the original x,y, time dimensions.
"""

src_dim_index = pd.Index(dataset_labels, name="dataset")
src_dim_index = pd.Index(dataset_labels, name=concat_dim_name)
datasets = [drop_nondim_coords(ds) for ds in datasets]
# if one of the datasets is missing data variable(s) that are in the others,
# fill it with an empty data array
_add_missing_data_vars(data_vars, datasets)
datasets_to_merge = [
xr.concat([ds[data_vars].squeeze(drop=True) for ds in datasets], src_dim_index),
xr.concat(
[ds[data_vars].squeeze(drop=True) for ds in datasets], dim=src_dim_index
),
grid,
]
if additional_dataset is not None:
@@ -126,3 +136,49 @@ def net_heating_from_dataset(ds: xr.Dataset, suffix: str = None) -> xr.DataArray
ds["PRATEsfc" + suffix],
Contributor review comment: "hard code here. I expect this name will change in future versions."
)
return vcm.net_heating(*fluxes)


def _add_empty_dataarray(ds, template_dataarray):
""" Adds an empty data array with the dimensions of the example
data array to the dataset. This is useful when concatenating mulitple
datasets where one does not have a data variable.
ex. concating prediction/target/highres datasets for
plotting comparisons, where the high res data does not have 3D variables.

Args:
ds (xarray dataset): dataset that will have additional empty data array added
example_dataarray (data array with the desired dimensions)

Returns:
original xarray dataset with an empty array assigned to the
template name dataarray.

"""
da_fill = np.empty(template_dataarray.shape)
da_fill[:] = np.nan
return ds.assign({template_dataarray.name: (template_dataarray.dims, da_fill)})


def _add_missing_data_vars(data_vars, datasets):
""" Checks if any dataset in a list to be concated is missing a data variable,
and returns of kwargs to be provided to _add_empty_dataarray

Args:
data_vars (list[str]): full list of data vars for final concated ds
datasets ([type]): datasets to check again

Returns:
List of dicts {"ds": dataset that needs empty datarray added,
"example_dataarray": example of data array with dims}
This can be passed as kwargs to _add_empty_dataarray
"""
for data_var in data_vars:
array_var = None
for ds in datasets:
if data_var in list(ds.data_vars):
array_var = ds[data_var]
if array_var is None:
raise ValueError(f"None of the datasets contain data array for {data_var}.")
for i in range(len(datasets)):
if data_var not in list(datasets[i].data_vars):
datasets[i] = _add_empty_dataarray(datasets[i], array_var)
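
The helper pair above is what lets merge_comparison_datasets concatenate datasets with mismatched variables. A toy demonstration of the filling behavior, run in the context of this module (variable names invented for illustration):

```python
import numpy as np
import xarray as xr

ds_a = xr.Dataset({"q": ("z", np.arange(3.0))})
ds_b = xr.Dataset()  # missing "q"
datasets = [ds_a, ds_b]

# _add_missing_data_vars mutates the list: ds_b gains a NaN-filled "q"
# with the same dims as ds_a["q"], so xr.concat no longer fails.
_add_missing_data_vars(["q"], datasets)
assert np.isnan(datasets[1]["q"]).all()
```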
77 changes: 69 additions & 8 deletions fv3net/diagnostics/sklearn_model_performance/__main__.py
Original file line number Diff line number Diff line change
@@ -5,15 +5,45 @@
from vcm.cloud.fsspec import get_fs, get_protocol
from vcm.cloud.gsutil import copy
from vcm.cubedsphere.constants import INIT_TIME_DIM
from fv3net.diagnostics.sklearn_model_performance.data_funcs_sklearn import (

from ..create_report import create_report
from ..data import merge_comparison_datasets
from .data import (
predict_on_test_data,
load_high_res_diag_dataset,
add_column_heating_moistening,
)
from fv3net.diagnostics.sklearn_model_performance.plotting_sklearn import make_all_plots
from fv3net.diagnostics.create_report import create_report
from .diagnostics import plot_diagnostics
from .create_metrics import create_metrics_dataset
from .plot_metrics import plot_metrics

DATA_VARS = [
"dQ1",
"dQ2",
"sphum",
"T",
"tsea",
"net_precipitation",
"net_heating",
"net_precipitation_physics",
"net_heating_physics",
"net_precipitation_ml",
"net_heating_ml",
"delp",
]
DATASET_NAME_PREDICTION = "prediction"
DATASET_NAME_FV3_TARGET = "C48_target"
DATASET_NAME_SHIELD_HIRES = "coarsened_high_res"

DPI_FIGURES = {
"LTS": 100,
"dQ2_pressure_profiles": 100,
"R2_pressure_profiles": 100,
"diurnal_cycle": 90,
"map_plot_3col": 120,
"map_plot_single": 100,
}

TEMP_OUTPUT_DIR = "temp_sklearn_prediction_report_output"

if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -79,16 +109,47 @@
args.model_type,
args.downsample_time_factor,
)

fs_input = get_fs(args.test_data_path)
fs_output = get_fs(args.output_path)

add_column_heating_moistening(ds_test)
add_column_heating_moistening(ds_pred)
init_times = list(set(ds_test[INIT_TIME_DIM].values))
ds_hires = load_high_res_diag_dataset(args.high_res_data_path, init_times)

grid_path = os.path.join(os.path.dirname(args.test_data_path), "grid_spec.zarr")
fs_input = get_fs(args.test_data_path)

grid = xr.open_zarr(fs_input.get_mapper(grid_path))
report_sections = make_all_plots(ds_pred, ds_test, ds_hires, grid, output_dir)
create_report(report_sections, "ml_model_predict_diagnostics", output_dir)
slmsk = ds_test["slmsk"].isel({INIT_TIME_DIM: 0})

ds = merge_comparison_datasets(
data_vars=DATA_VARS,
datasets=[ds_pred, ds_test, ds_hires],
dataset_labels=[
DATASET_NAME_PREDICTION,
DATASET_NAME_FV3_TARGET,
DATASET_NAME_SHIELD_HIRES,
],
grid=grid,
additional_dataset=slmsk,
)
# separate datasets will now have common grid/sfc_type variables and
# an identifying dataset coordinate
ds_pred = ds.sel(dataset=DATASET_NAME_PREDICTION)
ds_test = ds.sel(dataset=DATASET_NAME_FV3_TARGET)
ds_hires = ds.sel(dataset=DATASET_NAME_SHIELD_HIRES)

ds_metrics = create_metrics_dataset(ds_pred, ds_test, ds_hires)
Contributor review comment: "Great! this structure is pretty clear."

ds_metrics.to_netcdf(os.path.join(output_dir, "metrics.nc"))
metrics_plot_sections = plot_metrics(ds_metrics, output_dir, DPI_FIGURES)

diag_report_sections = plot_diagnostics(
ds_pred, ds_test, ds_hires, output_dir=output_dir, dpi_figures=DPI_FIGURES
)

combined_report_sections = {**metrics_plot_sections, **diag_report_sections}
create_report(combined_report_sections, "ml_offline_diagnostics", output_dir)

fs_output = get_fs(args.output_path)
if proto == "gs":
copy(output_dir, args.output_path)
94 changes: 94 additions & 0 deletions fv3net/diagnostics/sklearn_model_performance/create_metrics.py
@@ -0,0 +1,94 @@
import numpy as np
import xarray as xr

import vcm
from vcm.calc import r2_score
from vcm.cubedsphere.regridz import regrid_to_common_pressure
from vcm.cubedsphere.constants import (
INIT_TIME_DIM,
COORD_X_CENTER,
COORD_Y_CENTER,
PRESSURE_GRID,
GRID_VARS,
)

STACK_DIMS = ["tile", INIT_TIME_DIM, COORD_X_CENTER, COORD_Y_CENTER]
SAMPLE_DIM = "sample"


def create_metrics_dataset(ds_pred, ds_fv3, ds_shield):

ds_metrics = _r2_global_values(ds_pred, ds_fv3, ds_shield)
for grid_var in GRID_VARS:
ds_metrics[grid_var] = ds_pred[grid_var]

for sfc_type in ["global", "sea", "land"]:
for var in ["dQ1", "dQ2"]:
ds_metrics[
f"r2_{var}_pressure_levels_{sfc_type}"
] = _r2_pressure_level_metrics(
vcm.mask_to_surface_type(ds_fv3, sfc_type)[var],
vcm.mask_to_surface_type(ds_pred, sfc_type)[var],
vcm.mask_to_surface_type(ds_fv3, sfc_type)["delp"],
)
# add a coordinate for target datasets so that the plot_metrics functions
# can use it for labels
ds_metrics = ds_metrics.assign_coords(
{
"target_dataset_names": [
ds_target.dataset.values.item() for ds_target in [ds_fv3, ds_shield]
]
}
)
for var in ["net_precipitation", "net_heating"]:
for ds_target in [ds_fv3, ds_shield]:
target_label = ds_target.dataset.values.item()
ds_metrics[
f"rmse_{var}_vs_{target_label}"
] = _root_mean_squared_error_metrics(ds_target[var], ds_pred[var])

return ds_metrics


def _root_mean_squared_error_metrics(da_target, da_pred):
# square the errors, average over time, then take the root (RMSE)
rmse = np.sqrt(((da_target - da_pred) ** 2).mean(INIT_TIME_DIM))
return rmse


def _r2_pressure_level_metrics(da_target, da_pred, delp):
pressure = np.array(PRESSURE_GRID) / 100
target = regrid_to_common_pressure(da_target, delp).stack(sample=STACK_DIMS)
prediction = regrid_to_common_pressure(da_pred, delp).stack(sample=STACK_DIMS)
da = xr.DataArray(
r2_score(target, prediction, "sample"),
dims=["pressure"],
coords={"pressure": pressure},
)
return da


def _r2_global_values(ds_pred, ds_fv3, ds_shield):
""" Calculate global R^2 for net precipitation and heating against
target FV3 dataset and coarsened high res dataset

Args:
ds ([type]): [description]

Returns:
[type]: [description]
"""
r2_summary = xr.Dataset()
for var in ["net_heating", "net_precipitation"]:
for sfc_type in ["global", "sea", "land"]:
for ds_target in [ds_fv3, ds_shield]:
target_label = ds_target.dataset.values.item()
r2_summary[f"R2_{sfc_type}_{var}_vs_{target_label}"] = r2_score(
vcm.mask_to_surface_type(ds_target, sfc_type)[var].stack(
sample=STACK_DIMS
),
vcm.mask_to_surface_type(ds_pred, sfc_type)[var].stack(
sample=STACK_DIMS
),
"sample",
).values.item()
return r2_summary
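
Every metric in this file follows the same pattern: stack the tile/time/horizontal dims into one "sample" dimension and reduce along it. A self-contained sketch of that pattern, with a hand-rolled R^2 standing in for vcm.calc.r2_score (whose r2_score(target, prediction, dim) signature is taken from the calls above; the dim names here are illustrative):

```python
import numpy as np
import xarray as xr


def _r2(target: xr.DataArray, prediction: xr.DataArray, dim: str) -> xr.DataArray:
    # R^2 = 1 - SSE/SST, reduced along `dim` (stand-in for vcm.calc.r2_score)
    sse = ((target - prediction) ** 2).sum(dim)
    sst = ((target - target.mean(dim)) ** 2).sum(dim)
    return 1.0 - sse / sst


dims = ("tile", "initial_time", "grid_xt", "grid_yt")
rng = np.random.default_rng(0)
truth = xr.DataArray(rng.normal(size=(6, 4, 8, 8)), dims=dims)
pred = truth + 0.1 * rng.normal(size=(6, 4, 8, 8))

score = _r2(truth.stack(sample=dims), pred.stack(sample=dims), "sample")
print(float(score))  # close to 1.0, since pred is truth plus small noise
```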
@@ -3,7 +3,7 @@
import xarray as xr

import fv3net
from ..data_funcs import net_heating_from_dataset
from ..data import net_heating_from_dataset
from fv3net.pipelines.create_training_data import (
SUFFIX_COARSE_TRAIN_DIAG,
VAR_Q_HEATING_ML,
@@ -16,6 +16,7 @@
INIT_TIME_DIM,
COORD_X_CENTER,
COORD_Y_CENTER,
COORD_Z_CENTER,
TILE_COORDS,
)
from vcm.regrid import regrid_to_shared_coords
@@ -145,7 +146,7 @@ def lower_tropospheric_stability(ds):
[70000],
pressure,
regrid_dim_name="p700mb",
replace_dim_name="pfull",
replace_dim_name=COORD_Z_CENTER,
)
.squeeze()
.drop("p700mb")