Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions mmdeploy/codebase/mmedit/deploy/super_resolution_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List, Optional, Sequence, Union

import mmcv
Expand Down Expand Up @@ -88,6 +89,9 @@ def forward(self,
def forward_test(self,
lq: torch.Tensor,
gt: Optional[torch.Tensor] = None,
meta=None,
save_image=False,
save_path=None,
*args,
**kwargs):
"""Run inference for restorer to generate evaluation result.
Expand All @@ -96,6 +100,8 @@ def forward_test(self,
lq (torch.Tensor): The input low-quality image of the model.
gt (torch.Tensor): The ground truth of input image. Defaults to
`None`.
save_image (bool): Whether to save image. Default: False.
save_path (str): Path to save image. Default: None.
*args: Other arguments.
**kwargs: Other key-pair arguments.

Expand All @@ -104,6 +110,17 @@ def forward_test(self,
"""
outputs = self.forward_dummy(lq)
result = self.test_post_process(outputs, lq, gt)

# Align to mmediting BasicRestorer
if save_image and save_path:
outputs = [torch.from_numpy(i) for i in outputs]

lq_path = meta[0]['lq_path']
folder_name = osp.splitext(osp.basename(lq_path))[0]
save_path = osp.join(save_path, f'{folder_name}.png')

mmcv.imwrite(tensor2img(outputs), save_path)

return result

def forward_dummy(self, lq: torch.Tensor, *args, **kwargs):
Expand Down