Skip to content

Commit b0b069e

Browse files
authored
fix mmcls get classes (open-mmlab#215)
* fix mmcls get classes * resolve comment * resolve comment
1 parent 4510661 commit b0b069e

File tree

1 file changed

+27
-13
lines changed

1 file changed

+27
-13
lines changed

mmdeploy/codebase/mmcls/deploy/classification_model.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,12 @@
44
import mmcv
55
import numpy as np
66
import torch
7-
from mmcls.datasets import DATASETS
87
from mmcls.models.classifiers.base import BaseClassifier
98
from mmcv.utils import Registry
109

1110
from mmdeploy.codebase.base import BaseBackendModel
1211
from mmdeploy.utils import (Backend, get_backend, get_codebase_config,
13-
load_config)
12+
get_root_logger, load_config)
1413

1514

1615
def __build_backend_model(cls_name: str, registry: Registry, *args, **kwargs):
@@ -150,20 +149,35 @@ def get_classes_from_config(model_cfg: Union[str, mmcv.Config]):
150149
Returns:
151150
list[str]: A list of string specifying names of different class.
152151
"""
153-
model_cfg = load_config(model_cfg)[0]
152+
from mmcls.datasets import DATASETS
153+
154154
module_dict = DATASETS.module_dict
155+
model_cfg = load_config(model_cfg)[0]
155156
data_cfg = model_cfg.data
156157

157-
if 'train' in data_cfg:
158-
module = module_dict[data_cfg.train.type]
159-
elif 'val' in data_cfg:
160-
module = module_dict[data_cfg.val.type]
161-
elif 'test' in data_cfg:
162-
module = module_dict[data_cfg.test.type]
163-
else:
164-
raise RuntimeError(f'No dataset config found in: {model_cfg}')
165-
166-
return module.CLASSES
158+
def _get_class_names(dataset_type: str):
159+
dataset = data_cfg.get(dataset_type, None)
160+
if (not dataset) or (dataset.type not in module_dict):
161+
return None
162+
163+
module = module_dict[dataset.type]
164+
if module.CLASSES is not None:
165+
return module.CLASSES
166+
return module.get_classes(dataset.get('classes', None))
167+
168+
class_names = None
169+
for dataset_type in ['val', 'test', 'train']:
170+
class_names = _get_class_names(dataset_type)
171+
if class_names is not None:
172+
break
173+
174+
if class_names is None:
175+
logger = get_root_logger()
176+
logger.warning(f'Use generated class names, because \
177+
it failed to parse CLASSES from config: {data_cfg}')
178+
num_classes = model_cfg.model.head.num_classes
179+
class_names = [str(i) for i in range(num_classes)]
180+
return class_names
167181

168182

169183
def build_classification_model(model_files: Sequence[str],

0 commit comments

Comments
 (0)