Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/setupapp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
which python
python -m pip install --upgrade pip --no-cache-dir
python -m pip uninstall -y torch torchvision
python -m pip install -q -r requirements.txt --no-cache-dir
python -m pip install --upgrade -q -r requirements.txt --no-cache-dir
python -m pip list
- name: Run unit tests report coverage
run: |
Expand Down
241 changes: 241 additions & 0 deletions tests/test_integration_classification_2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import subprocess
import tarfile
import tempfile
import unittest

import numpy as np
import torch
from torch.utils.data import DataLoader

import monai
from monai.metrics import compute_roc_auc
from monai.networks.nets import densenet121
from monai.transforms import (AddChannel, Compose, LoadPNG, RandFlip, RandRotate, RandZoom, Resize, ScaleIntensity,
ToTensor)
from tests.utils import skip_if_quick

TEST_DATA_URL = 'https://www.dropbox.com/s/5wwskxctvcxiuea/MedNIST.tar.gz'


class MedNISTDataset(torch.utils.data.Dataset):

def __init__(self, image_files, labels, transforms):
self.image_files = image_files
self.labels = labels
self.transforms = transforms

def __len__(self):
return len(self.image_files)

def __getitem__(self, index):
return self.transforms(self.image_files[index]), self.labels[index]


def run_training_test(root_dir, train_x, train_y, val_x, val_y, device=torch.device("cuda:0")):

monai.config.print_config()
# define transforms for image and classification
train_transforms = Compose([
LoadPNG(),
AddChannel(),
ScaleIntensity(),
RandRotate(degrees=15, prob=0.5),
RandFlip(spatial_axis=0, prob=0.5),
RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
Resize(spatial_size=(64, 64), mode='constant'),
ToTensor()
])
train_transforms.set_random_state(1234)
val_transforms = Compose([LoadPNG(), AddChannel(), ScaleIntensity(), ToTensor()])

# create train, val data loaders
train_ds = MedNISTDataset(train_x, train_y, train_transforms)
train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=10)

val_ds = MedNISTDataset(val_x, val_y, val_transforms)
val_loader = DataLoader(val_ds, batch_size=300, num_workers=10)

model = densenet121(
spatial_dims=2,
in_channels=1,
out_channels=len(np.unique(train_y)),
).to(device)
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), 1e-5)
epoch_num = 4
val_interval = 1

# start training validation
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()
model_filename = os.path.join(root_dir, 'best_metric_model.pth')
for epoch in range(epoch_num):
print('-' * 10)
print('Epoch {}/{}'.format(epoch + 1, epoch_num))
model.train()
epoch_loss = 0
step = 0
for batch_data in train_loader:
step += 1
inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_function(outputs, labels)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
epoch_loss /= step
epoch_loss_values.append(epoch_loss)
print("epoch %d average loss:%0.4f" % (epoch + 1, epoch_loss))

if (epoch + 1) % val_interval == 0:
model.eval()
with torch.no_grad():
y_pred = torch.tensor([], dtype=torch.float32, device=device)
y = torch.tensor([], dtype=torch.long, device=device)
for val_data in val_loader:
val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
y_pred = torch.cat([y_pred, model(val_images)], dim=0)
y = torch.cat([y, val_labels], dim=0)
auc_metric = compute_roc_auc(y_pred, y, to_onehot_y=True, add_softmax=True)
metric_values.append(auc_metric)
acc_value = torch.eq(y_pred.argmax(dim=1), y)
acc_metric = acc_value.sum().item() / len(acc_value)
if auc_metric > best_metric:
best_metric = auc_metric
best_metric_epoch = epoch + 1
torch.save(model.state_dict(), model_filename)
print('saved new best metric model')
print("current epoch %d current AUC: %0.4f current accuracy: %0.4f best AUC: %0.4f at epoch %d" %
(epoch + 1, auc_metric, acc_metric, best_metric, best_metric_epoch))
print('train completed, best_metric: %0.4f at epoch: %d' % (best_metric, best_metric_epoch))
return epoch_loss_values, best_metric, best_metric_epoch


def run_inference_test(root_dir, test_x, test_y, device=torch.device("cuda:0")):
# define transforms for image and classification
val_transforms = Compose([LoadPNG(), AddChannel(), ScaleIntensity(), ToTensor()])
val_ds = MedNISTDataset(test_x, test_y, val_transforms)
val_loader = DataLoader(val_ds, batch_size=300, num_workers=10)

model = densenet121(
spatial_dims=2,
in_channels=1,
out_channels=len(np.unique(test_y)),
).to(device)

model_filename = os.path.join(root_dir, 'best_metric_model.pth')
model.load_state_dict(torch.load(model_filename))
model.eval()
y_true = list()
y_pred = list()
with torch.no_grad():
for test_data in val_loader:
test_images, test_labels = test_data[0].to(device), test_data[1].to(device)
pred = model(test_images).argmax(dim=1)
for i in range(len(pred)):
y_true.append(test_labels[i].item())
y_pred.append(pred[i].item())
tps = [np.sum((np.asarray(y_true) == idx) & (np.asarray(y_pred) == idx)) for idx in np.unique(test_y)]
return tps


class IntegrationClassification2D(unittest.TestCase):

def setUp(self):
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(0)
self.data_dir = tempfile.mkdtemp()

# download
subprocess.call(['wget', '-nv', '-P', self.data_dir, TEST_DATA_URL])
dataset_file = os.path.join(self.data_dir, 'MedNIST.tar.gz')
assert os.path.exists(dataset_file)

# extract tarfile
datafile = tarfile.open(dataset_file)
datafile.extractall(path=self.data_dir)
datafile.close()

# find image files and labels
data_dir = os.path.join(self.data_dir, 'MedNIST')
class_names = sorted(os.listdir(data_dir))
image_files = [[
os.path.join(data_dir, class_name, x) for x in sorted(os.listdir(os.path.join(data_dir, class_name)))
] for class_name in class_names]
image_file_list, image_classes = [], []
for i, class_name in enumerate(class_names):
image_file_list.extend(image_files[i])
image_classes.extend([i] * len(image_files[i]))

# split train, val, test
valid_frac, test_frac = 0.1, 0.1
self.train_x, self.train_y = [], []
self.val_x, self.val_y = [], []
self.test_x, self.test_y = [], []
for i in range(len(image_classes)):
rann = np.random.random()
if rann < valid_frac:
self.val_x.append(image_file_list[i])
self.val_y.append(image_classes[i])
elif rann < test_frac + valid_frac:
self.test_x.append(image_file_list[i])
self.test_y.append(image_classes[i])
else:
self.train_x.append(image_file_list[i])
self.train_y.append(image_classes[i])

np.random.seed(seed=None)
self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu:0')

def tearDown(self):
shutil.rmtree(self.data_dir)

@skip_if_quick
def test_training(self):
repeated = []
for i in range(2):
torch.manual_seed(0)

repeated.append([])
losses, best_metric, best_metric_epoch = \
run_training_test(self.data_dir, self.train_x, self.train_y, self.val_x, self.val_y, device=self.device)

# check training properties
np.testing.assert_allclose(
losses, [0.8501208358129878, 0.18469145818121113, 0.08108749352158255, 0.04965383692342005], rtol=1e-3)
repeated[i].extend(losses)
print('best metric', best_metric)
np.testing.assert_allclose(best_metric, 0.9999480167572079, rtol=1e-4)
repeated[i].append(best_metric)
np.testing.assert_allclose(best_metric_epoch, 4)
model_file = os.path.join(self.data_dir, 'best_metric_model.pth')
self.assertTrue(os.path.exists(model_file))

infer_metric = run_inference_test(self.data_dir, self.test_x, self.test_y, device=self.device)

# check inference properties
np.testing.assert_allclose(np.asarray(infer_metric), [1036, 895, 982, 1033, 958, 1047])
repeated[i].extend(infer_metric)

np.testing.assert_allclose(repeated[0], repeated[1])


if __name__ == '__main__':
unittest.main()