Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions examples/lifelong_learning/RFNet/accuracy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from basemodel import val_args
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The import of the relative path should be adjusted.

from utils.metrics import Evaluator
from tqdm import tqdm
from dataloaders import make_data_loader
from sedna.common.class_factory import ClassType, ClassFactory
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note the order of import


__all__ = ('accuracy')

@ClassFactory.register(ClassType.GENERAL)
def accuracy(y_true, y_pred, **kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Common keyword. Use alias while register.

args = val_args()
_, _, test_loader, num_class = make_data_loader(args, test_data=y_true)
evaluator = Evaluator(num_class)

tbar = tqdm(test_loader, desc='\r')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

useless

for i, (sample, img_path) in enumerate(tbar):
if args.depth:
image, depth, target = sample['image'], sample['depth'], sample['label']
else:
image, target = sample['image'], sample['label']
if args.cuda:
image, target = image.cuda(), target.cuda()
if args.depth:
depth = depth.cuda()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check whether the device supports GPU.


target[target > evaluator.num_class-1] = 255
target = target.cpu().numpy()
# Add batch sample into evaluator
evaluator.add_batch(target, y_pred[i])

# Test during the training
# Acc = evaluator.Pixel_Accuracy()
CPA = evaluator.Pixel_Accuracy_Class()
mIoU = evaluator.Mean_Intersection_over_Union()
FWIoU = evaluator.Frequency_Weighted_Intersection_over_Union()

print("CPA:{}, mIoU:{}, fwIoU: {}".format(CPA, mIoU, FWIoU))
return CPA
Loading