|
4 | 4 | import mmcv |
5 | 5 | import numpy as np |
6 | 6 | import torch |
7 | | -from mmcls.datasets import DATASETS |
8 | 7 | from mmcls.models.classifiers.base import BaseClassifier |
9 | 8 | from mmcv.utils import Registry |
10 | 9 |
|
11 | 10 | from mmdeploy.codebase.base import BaseBackendModel |
12 | 11 | from mmdeploy.utils import (Backend, get_backend, get_codebase_config, |
13 | | - load_config) |
| 12 | + get_root_logger, load_config) |
14 | 13 |
|
15 | 14 |
|
16 | 15 | def __build_backend_model(cls_name: str, registry: Registry, *args, **kwargs): |
@@ -150,20 +149,35 @@ def get_classes_from_config(model_cfg: Union[str, mmcv.Config]): |
150 | 149 | Returns: |
151 | 150 | list[str]: A list of string specifying names of different class. |
152 | 151 | """ |
153 | | - model_cfg = load_config(model_cfg)[0] |
| 152 | + from mmcls.datasets import DATASETS |
| 153 | + |
154 | 154 | module_dict = DATASETS.module_dict |
| 155 | + model_cfg = load_config(model_cfg)[0] |
155 | 156 | data_cfg = model_cfg.data |
156 | 157 |
|
157 | | - if 'train' in data_cfg: |
158 | | - module = module_dict[data_cfg.train.type] |
159 | | - elif 'val' in data_cfg: |
160 | | - module = module_dict[data_cfg.val.type] |
161 | | - elif 'test' in data_cfg: |
162 | | - module = module_dict[data_cfg.test.type] |
163 | | - else: |
164 | | - raise RuntimeError(f'No dataset config found in: {model_cfg}') |
165 | | - |
166 | | - return module.CLASSES |
| 158 | + def _get_class_names(dataset_type: str): |
| 159 | + dataset = data_cfg.get(dataset_type, None) |
| 160 | + if (not dataset) or (dataset.type not in module_dict): |
| 161 | + return None |
| 162 | + |
| 163 | + module = module_dict[dataset.type] |
| 164 | + if module.CLASSES is not None: |
| 165 | + return module.CLASSES |
| 166 | + return module.get_classes(dataset.get('classes', None)) |
| 167 | + |
| 168 | + class_names = None |
| 169 | + for dataset_type in ['val', 'test', 'train']: |
| 170 | + class_names = _get_class_names(dataset_type) |
| 171 | + if class_names is not None: |
| 172 | + break |
| 173 | + |
| 174 | + if class_names is None: |
| 175 | + logger = get_root_logger() |
| 176 | + logger.warning(f'Use generated class names, because \ |
| 177 | + it failed to parse CLASSES from config: {data_cfg}') |
| 178 | + num_classes = model_cfg.model.head.num_classes |
| 179 | + class_names = [str(i) for i in range(num_classes)] |
| 180 | + return class_names |
167 | 181 |
|
168 | 182 |
|
169 | 183 | def build_classification_model(model_files: Sequence[str], |
|
0 commit comments