Skip to content

Commit a9fe772

Browse files
author
Nikolas Schmitz
committed
Added function get_img_at_mpp to class TifffileWSIReader; changed resizing function to Image.resize, cucim.skimage.transform.resize
1 parent 88002e8 commit a9fe772

File tree

1 file changed

+130
-30
lines changed

1 file changed

+130
-30
lines changed

monai/data/wsi_reader.py

Lines changed: 130 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
import numpy as np
2121
import torch
22-
import cv2
2322

2423
from monai.config import DtypeLike, NdarrayOrTensor, PathLike
2524
from monai.data.image_reader import ImageReader, _stack_images
@@ -778,9 +777,14 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05
778777
779778
"""
780779

780+
cucim_resize, _ = optional_import("cucim.skimage.transform", name="resize")
781+
cp, _ = optional_import("cupy")
782+
781783
user_mpp_x, user_mpp_y = mpp
782784
mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.resolutions['level_count'])]
783-
closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # Should not throw ValueError, instead just return the closest value;
785+
closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5)
786+
# -> Should not throw ValueError, instead just return the closest value; how to select tolerances?
787+
784788
mpp_closest_lvl = mpp_list[closest_lvl]
785789
closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl]
786790

@@ -797,13 +801,12 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05
797801
within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x)
798802
within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y)
799803
within_tolerance = within_tolerance_x & within_tolerance_y
800-
804+
801805
if within_tolerance:
802806
# Take closest_level and continue with returning img at level
803807
print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.')
804-
closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3]
808+
closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers)
805809

806-
return closest_lvl_wsi
807810
else:
808811
# If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp
809812
closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x
@@ -814,15 +817,16 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05
814817
ds_factor_x = mpp_closest_lvl_x / user_mpp_x
815818
ds_factor_y = mpp_closest_lvl_y / user_mpp_y
816819

817-
closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers))[:, :, :3]
820+
closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers)
821+
wsi_arr = cp.array(closest_lvl_wsi)
818822

819-
target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x))
820-
target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y))
821-
822-
closest_lvl_wsi = cv2.resize(closest_lvl_wsi, dsize=(target_res_x, target_res_y), interpolation=cv2.INTER_LINEAR)
823+
target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x))
824+
target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y))
823825

826+
# closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), Image.BILINEAR)
827+
closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0)
824828
print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}')
825-
return closest_lvl_wsi
829+
826830
else:
827831
# Else: increase resolution (ie, decrement level) and then downsample
828832
closest_lvl = closest_lvl - 1
@@ -833,15 +837,18 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05
833837
ds_factor_y = mpp_closest_lvl_y / user_mpp_y
834838

835839
closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl]
836-
closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers))[:, :, :3]
840+
closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers)
841+
wsi_arr = cp.array(closest_lvl_wsi)
837842

838-
target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x))
839-
target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y))
840-
841-
closest_lvl_wsi = cv2.resize(closest_lvl_wsi, dsize=(target_res_x, target_res_y), interpolation=cv2.INTER_LINEAR)
843+
target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x))
844+
target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y))
842845

846+
# closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), Image.BILINEAR)
847+
closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0)
843848
print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}')
844-
return closest_lvl_wsi
849+
850+
wsi_arr = cp.asnumpy(closest_lvl_wsi)
851+
return wsi_arr
845852

846853
def get_power(self, wsi, level: int) -> float:
847854
"""
@@ -1055,9 +1062,12 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05
10551062
10561063
"""
10571064

1065+
pil_image, _ = optional_import("PIL", name="Image")
10581066
user_mpp_x, user_mpp_y = mpp
10591067
mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.level_count)]
1060-
closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # Should not throw ValueError, instead just return the closest value;
1068+
closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5)
1069+
# -> Should not throw ValueError, instead just return the closest value; how to select tolerances?
1070+
10611071
mpp_closest_lvl = mpp_list[closest_lvl]
10621072
closest_lvl_dim = wsi.level_dimensions[closest_lvl]
10631073

@@ -1078,9 +1088,8 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05
10781088
if within_tolerance:
10791089
# Take closest_level and continue with returning img at level
10801090
print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.')
1081-
closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3]
1091+
closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim)
10821092

1083-
return closest_lvl_wsi
10841093
else:
10851094
# If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp
10861095
closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x
@@ -1091,15 +1100,14 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05
10911100
ds_factor_x = mpp_closest_lvl_x / user_mpp_x
10921101
ds_factor_y = mpp_closest_lvl_y / user_mpp_y
10931102

1094-
closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3]
1103+
closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim)
10951104

10961105
target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x))
10971106
target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y))
10981107

1099-
closest_lvl_wsi = cv2.resize(closest_lvl_wsi, dsize=(target_res_x, target_res_y), interpolation=cv2.INTER_LINEAR)
1100-
1108+
closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR)
11011109
print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}')
1102-
return closest_lvl_wsi
1110+
11031111
else:
11041112
# Else: increase resolution (ie, decrement level) and then downsample
11051113
closest_lvl = closest_lvl - 1
@@ -1110,15 +1118,16 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05
11101118
ds_factor_y = mpp_closest_lvl_y / user_mpp_y
11111119

11121120
closest_lvl_dim = wsi.level_dimensions[closest_lvl]
1113-
closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3]
1121+
closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim)
11141122

11151123
target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x))
11161124
target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y))
11171125

1118-
closest_lvl_wsi = cv2.resize(closest_lvl_wsi, dsize=(target_res_x, target_res_y), interpolation=cv2.INTER_LINEAR)
1119-
1126+
closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR)
11201127
print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}')
1121-
return closest_lvl_wsi
1128+
1129+
wsi_arr = np.array(closest_lvl_wsi)
1130+
return wsi_arr
11221131

11231132
def get_power(self, wsi, level: int) -> float:
11241133
"""
@@ -1276,8 +1285,10 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]:
12761285
and wsi.pages[level].tags["YResolution"].value
12771286
):
12781287
unit = wsi.pages[level].tags.get("ResolutionUnit")
1279-
if unit is not None:
1280-
unit = str(unit.value)[8:]
1288+
if unit is not None: # Needs to be extended
1289+
# unit = str(unit.value)[8:]
1290+
unit = str(unit.value.name).lower() # TODO: Merge both methods
1291+
12811292
else:
12821293
warnings.warn("The resolution unit is missing. `micrometer` will be used as default.")
12831294
unit = "micrometer"
@@ -1290,6 +1301,95 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]:
12901301

12911302
raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.")
12921303

1304+
def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array:
1305+
"""
1306+
Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution.
1307+
The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user.
1308+
If the user-provided mpp is larger than the mpp of the closest level—indicating that the closest level has a higher resolution than requested—the image is downscaled to a resolution that matches the user-provided mpp.
1309+
Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, the next lower level (which has a higher resolution) is chosen.
1310+
The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value.
1311+
1312+
Args:
1313+
wsi: whole slide image object from WSIReader
1314+
mpp: the resolution in micron per pixel at which the representation of the whole slide image should be extracted.
1315+
atol: the acceptable absolute tolerance for resolution in micro per pixel.
1316+
rtol: the acceptable relative tolerance for resolution in micro per pixel.
1317+
1318+
"""
1319+
1320+
pil_image, _ = optional_import("PIL", name="Image")
1321+
user_mpp_x, user_mpp_y = mpp
1322+
mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] # QuPath show 4 levels in the pyramid, but len(wsi.pages) is 1?
1323+
closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5)
1324+
# -> Should not throw ValueError, instead just return the closest value; how to select tolerances?
1325+
1326+
mpp_closest_lvl = mpp_list[closest_lvl]
1327+
1328+
lvl_dims = [self.get_size(wsi, lvl) for lvl in range(len(wsi.pages))] # Returns size in (height, width)
1329+
closest_lvl_dim = lvl_dims[closest_lvl]
1330+
closest_lvl_dim = (closest_lvl_dim[1], closest_lvl_dim[0])
1331+
1332+
print(f'Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}')
1333+
mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl
1334+
1335+
# Define tolerance intervals for x and y of closest level
1336+
lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol
1337+
upper_bound_x = mpp_closest_lvl_x * (1 + rtol) + atol
1338+
lower_bound_y = mpp_closest_lvl_y * (1 - rtol) - atol
1339+
upper_bound_y = mpp_closest_lvl_y * (1 + rtol) + atol
1340+
1341+
# Check if user-provided mpp_x and mpp_y fall within the tolerance intervals for closest level
1342+
within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x)
1343+
within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y)
1344+
within_tolerance = within_tolerance_x & within_tolerance_y
1345+
1346+
if within_tolerance:
1347+
# Take closest_level and continue with returning img at level
1348+
print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.')
1349+
closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim)
1350+
1351+
else:
1352+
# If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp
1353+
closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x
1354+
closest_level_is_bigger_y = mpp_closest_lvl_y < user_mpp_y
1355+
closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y
1356+
1357+
if closest_level_is_bigger:
1358+
ds_factor_x = mpp_closest_lvl_x / user_mpp_x
1359+
ds_factor_y = mpp_closest_lvl_y / user_mpp_y
1360+
1361+
# closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers)
1362+
closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) # Might be suboptimal
1363+
1364+
target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x))
1365+
target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y))
1366+
1367+
closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR)
1368+
print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}')
1369+
1370+
else:
1371+
# Else: increase resolution (ie, decrement level) and then downsample
1372+
closest_lvl = closest_lvl - 1
1373+
mpp_closest_lvl = mpp_list[closest_lvl] # Update MPP
1374+
mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl
1375+
1376+
ds_factor_x = mpp_closest_lvl_x / user_mpp_x
1377+
ds_factor_y = mpp_closest_lvl_y / user_mpp_y
1378+
1379+
closest_lvl_dim = lvl_dims[closest_lvl]
1380+
closest_lvl_dim = (closest_lvl_dim[1], closest_lvl_dim[0])
1381+
# closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers)
1382+
closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) # Might be suboptimal
1383+
1384+
target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x))
1385+
target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y))
1386+
1387+
closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR)
1388+
print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}')
1389+
1390+
wsi_arr = np.array(closest_lvl_wsi)
1391+
return wsi_arr
1392+
12931393
def get_power(self, wsi, level: int) -> float:
12941394
"""
12951395
Returns the objective power of the whole slide image at a given level.

0 commit comments

Comments
 (0)