|
| 1 | +import os |
| 2 | +import pytest |
| 3 | +import numpy as np |
| 4 | + |
| 5 | +from plantcv.plantcv import Spectral_data |
| 6 | +from plantcv.plantcv.hyperspectral import read_data |
| 7 | +from plantcv.plantcv.hyperspectral import write_data |
| 8 | + |
| 9 | +def test_write_data_default(tmpdir): |
| 10 | + """Test for PlantCV.""" |
| 11 | + rng = np.random.default_rng() |
| 12 | + |
| 13 | + # Create a test tmp directory |
| 14 | + cache_dir = tmpdir.mkdir("cache") |
| 15 | + |
| 16 | + lines = 32 |
| 17 | + samples = 32 |
| 18 | + bands = 5 |
| 19 | + |
| 20 | + # Create random array data in the interval [0-65535] and wavelengths in the |
| 21 | + # interval [400-1000) |
| 22 | + rand_array = rng.integers(0, 65535, size=(lines, samples, bands), dtype=np.uint16, endpoint=True) |
| 23 | + rand_wavelengths = np.sort(600.0*rng.random(size=bands) + 400.0) |
| 24 | + # Create dictionary of wavelengths |
| 25 | + wavelength_dict = {} |
| 26 | + for j, wavelength in enumerate(rand_wavelengths): |
| 27 | + wavelength_dict.update({wavelength: float(j)}) |
| 28 | + |
| 29 | + # Create spectral data object |
| 30 | + rand_spectral_array = Spectral_data(array_data=rand_array, |
| 31 | + max_wavelength=rand_wavelengths[-1], |
| 32 | + min_wavelength=rand_wavelengths[0], |
| 33 | + max_value=float(np.amax(rand_array)), |
| 34 | + min_value=float(np.amin(rand_array)), |
| 35 | + d_type=rand_array.dtype, |
| 36 | + wavelength_dict=wavelength_dict, |
| 37 | + samples=samples, |
| 38 | + lines=lines, |
| 39 | + interleave='bil', |
| 40 | + wavelength_units='nm', |
| 41 | + array_type="datacube", |
| 42 | + pseudo_rgb=None, |
| 43 | + filename='random_hyperspectral_test', |
| 44 | + default_bands=None) |
| 45 | + |
| 46 | + |
| 47 | + filename = os.path.join(cache_dir, 'plantcv_hyperspectral_write_data.raw') |
| 48 | + write_data(filename=filename, spectral_data=rand_spectral_array) |
| 49 | + |
| 50 | + # Read written hyperspectral image |
| 51 | + array_data = read_data(filename=filename) |
| 52 | + assert np.shape(array_data.array_data) == (lines, samples, bands) |
0 commit comments