| 
8 | 8 | import pytest  | 
9 | 9 | 
 
  | 
10 | 10 | from ... import InferenceData, from_dict  | 
 | 11 | +from ... import to_zarr, from_zarr  | 
11 | 12 | 
 
  | 
12 | 13 | from ..helpers import (  # pylint: disable=unused-import  | 
13 | 14 |     chains,  | 
@@ -103,3 +104,41 @@ def test_io_method(self, data, eight_schools_params, store, fill_attrs):  | 
103 | 104 |                 assert inference_data2.attrs["test"] == 1  | 
104 | 105 |             else:  | 
105 | 106 |                 assert "test" not in inference_data2.attrs  | 
 | 107 | + | 
 | 108 | +    def test_io_function(self, data, eight_schools_params):  | 
 | 109 | +        # create InferenceData and check it has been properly created  | 
 | 110 | +        inference_data = self.get_inference_data(  # pylint: disable=W0612  | 
 | 111 | +            data,  | 
 | 112 | +            eight_schools_params,  | 
 | 113 | +            fill_attrs=True,  | 
 | 114 | +        )  | 
 | 115 | +        test_dict = {  | 
 | 116 | +            "posterior": ["eta", "theta", "mu", "tau"],  | 
 | 117 | +            "posterior_predictive": ["eta", "theta", "mu", "tau"],  | 
 | 118 | +            "sample_stats": ["eta", "theta", "mu", "tau"],  | 
 | 119 | +            "prior": ["eta", "theta", "mu", "tau"],  | 
 | 120 | +            "prior_predictive": ["eta", "theta", "mu", "tau"],  | 
 | 121 | +            "sample_stats_prior": ["eta", "theta", "mu", "tau"],  | 
 | 122 | +            "observed_data": ["J", "y", "sigma"],  | 
 | 123 | +        }  | 
 | 124 | +        fails = check_multiple_attrs(test_dict, inference_data)  | 
 | 125 | +        assert not fails  | 
 | 126 | + | 
 | 127 | +        assert inference_data.attrs["test"] == 1  | 
 | 128 | + | 
 | 129 | +        # check filename does not exist and use to_zarr method  | 
 | 130 | +        with TemporaryDirectory(prefix="arviz_tests_") as tmp_dir:  | 
 | 131 | +            filepath = os.path.join(tmp_dir, "zarr")  | 
 | 132 | + | 
 | 133 | +            to_zarr(inference_data, store=filepath)  | 
 | 134 | +            # assert file has been saved correctly  | 
 | 135 | +            assert os.path.exists(filepath)  | 
 | 136 | +            assert os.path.getsize(filepath) > 0  | 
 | 137 | + | 
 | 138 | +            inference_data2 = from_zarr(filepath)  | 
 | 139 | + | 
 | 140 | +            # Everything in dict still available in inference_data2 ?  | 
 | 141 | +            fails = check_multiple_attrs(test_dict, inference_data2)  | 
 | 142 | +            assert not fails  | 
 | 143 | + | 
 | 144 | +            assert inference_data2.attrs["test"] == 1  | 
0 commit comments