Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 27 additions & 13 deletions mmdeploy/codebase/mmcls/deploy/classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
import mmcv
import numpy as np
import torch
from mmcls.datasets import DATASETS
from mmcls.models.classifiers.base import BaseClassifier
from mmcv.utils import Registry

from mmdeploy.codebase.base import BaseBackendModel
from mmdeploy.utils import (Backend, get_backend, get_codebase_config,
load_config)
get_root_logger, load_config)


def __build_backend_model(cls_name: str, registry: Registry, *args, **kwargs):
Expand Down Expand Up @@ -150,20 +149,35 @@ def get_classes_from_config(model_cfg: Union[str, mmcv.Config]):
Returns:
list[str]: A list of string specifying names of different class.
"""
model_cfg = load_config(model_cfg)[0]
from mmcls.datasets import DATASETS

module_dict = DATASETS.module_dict
model_cfg = load_config(model_cfg)[0]
data_cfg = model_cfg.data

if 'train' in data_cfg:
module = module_dict[data_cfg.train.type]
elif 'val' in data_cfg:
module = module_dict[data_cfg.val.type]
elif 'test' in data_cfg:
module = module_dict[data_cfg.test.type]
else:
raise RuntimeError(f'No dataset config found in: {model_cfg}')

return module.CLASSES
def _get_class_names(dataset_type: str):
dataset = data_cfg.get(dataset_type, None)
if (not dataset) or (dataset.type not in module_dict):
return None

module = module_dict[dataset.type]
if module.CLASSES is not None:
return module.CLASSES
return module.get_classes(dataset.get('classes', None))

class_names = None
for dataset_type in ['val', 'test', 'train']:
class_names = _get_class_names(dataset_type)
if class_names is not None:
break

if class_names is None:
logger = get_root_logger()
logger.warning(f'Use generated class names, because \
it failed to parse CLASSES from config: {data_cfg}')
num_classes = model_cfg.model.head.num_classes
class_names = [str(i) for i in range(num_classes)]
return class_names


def build_classification_model(model_files: Sequence[str],
Expand Down