From d1ab097b5d4c77961516b2fa2761350d45059f2f Mon Sep 17 00:00:00 2001 From: jaegukhyun Date: Fri, 21 Jul 2023 17:16:16 +0900 Subject: [PATCH 1/6] Enable configurable confidence threshold for ov export and inference --- .../detection/adapters/openvino/task.py | 10 ++++---- src/otx/algorithms/detection/task.py | 23 ++++++++++--------- src/otx/cli/tools/export.py | 14 ++++++----- .../detection/adapters/mmdet/test_task.py | 21 +++++++++++++++++ 4 files changed, 47 insertions(+), 21 deletions(-) diff --git a/src/otx/algorithms/detection/adapters/openvino/task.py b/src/otx/algorithms/detection/adapters/openvino/task.py index 4ad7abba8fc..c870004aef5 100644 --- a/src/otx/algorithms/detection/adapters/openvino/task.py +++ b/src/otx/algorithms/detection/adapters/openvino/task.py @@ -434,10 +434,12 @@ def load_inferencer( if self.model is None: raise RuntimeError("load_inferencer failed, model is None") _hparams = copy.deepcopy(self.hparams) - self.confidence_threshold = float( - np.frombuffer(self.model.get_data("confidence_threshold"), dtype=np.float32)[0] - ) - _hparams.postprocessing.confidence_threshold = self.confidence_threshold + if _hparams.postprocessing.result_based_confidence_threshold: + self.confidence_threshold = float( + np.frombuffer(self.model.get_data("confidence_threshold"), dtype=np.float32)[0] + ) + _hparams.postprocessing.confidence_threshold = self.confidence_threshold + logger.info(f"Confidence Threshold: {_hparams.postprocessing.confidence_threshold}") _hparams.postprocessing.use_ellipse_shapes = self.config.postprocessing.use_ellipse_shapes async_requests_num = get_default_async_reqs_num() args = [ diff --git a/src/otx/algorithms/detection/task.py b/src/otx/algorithms/detection/task.py index ee94af104c2..b720c4effb4 100644 --- a/src/otx/algorithms/detection/task.py +++ b/src/otx/algorithms/detection/task.py @@ -83,11 +83,13 @@ def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str] ) self._anchors: Dict[str, int] = {} - if hasattr(self._hyperparams, "postprocessing"): + if self._hyperparams.postprocessing.result_based_confidence_threshold: + self.confidence_threshold = 0.0 # Use all predictions to compute best threshold + elif hasattr(self._hyperparams, "postprocessing"): if hasattr(self._hyperparams.postprocessing, "confidence_threshold"): self.confidence_threshold = self._hyperparams.postprocessing.confidence_threshold - else: - self.confidence_threshold = 0.0 + else: + self.confidence_threshold = 0.0 if task_environment.model is not None: self._load_model() @@ -104,7 +106,6 @@ def _load_postprocessing(self, model_data): Args: model_data: The model data. - """ loaded_postprocessing = model_data.get("config", {}).get("postprocessing", {}) hparams = self._hyperparams.postprocessing @@ -112,6 +113,13 @@ def _load_postprocessing(self, model_data): hparams.use_ellipse_shapes = loaded_postprocessing["use_ellipse_shapes"]["value"] else: hparams.use_ellipse_shapes = False + # If confidence threshold is adaptive then up-to-date value should be stored in the model + # and should not be changed during inference. Otherwise user-specified value should be taken. + if hparams.result_based_confidence_threshold: + self.confidence_threshold = model_data.get("confidence_threshold", self.confidence_threshold) + else: + self.confidence_threshold = hparams.confidence_threshold + logger.info(f"Confidence threshold {self.confidence_threshold}") def _load_tiling_parameters(self, model_data): """Load tiling parameters from PyTorch model. @@ -158,8 +166,6 @@ def _load_model_ckpt(self, model: Optional[ModelEntity]) -> Optional[Dict]: buffer = io.BytesIO(model.get_data("weights.pth")) model_data = torch.load(buffer, map_location=torch.device("cpu")) - # set confidence_threshold as well - self.confidence_threshold = model_data.get("confidence_threshold", self.confidence_threshold) if model_data.get("anchors"): self._anchors = model_data["anchors"] self._load_postprocessing(model_data) @@ -287,11 +293,6 @@ def infer( explain_predicted_classes = inference_parameters.explain_predicted_classes self._time_monitor = InferenceProgressCallback(len(dataset), update_progress_callback) - # If confidence threshold is adaptive then up-to-date value should be stored in the model - # and should not be changed during inference. Otherwise user-specified value should be taken. - if not self._hyperparams.postprocessing.result_based_confidence_threshold: - self.confidence_threshold = self._hyperparams.postprocessing.confidence_threshold - logger.info(f"Confidence threshold {self.confidence_threshold}") dataset.purpose = DatasetPurpose.INFERENCE prediction_results, _ = self._infer_model(dataset, inference_parameters) diff --git a/src/otx/cli/tools/export.py b/src/otx/cli/tools/export.py index 7c4cc0439e2..e347d767750 100644 --- a/src/otx/cli/tools/export.py +++ b/src/otx/cli/tools/export.py @@ -18,7 +18,6 @@ # Update environment variables for CLI use import otx.cli # noqa: F401 -from otx.api.configuration.helper import create from otx.api.entities.model import ModelEntity, ModelOptimizationType, ModelPrecision from otx.api.entities.task_environment import TaskEnvironment from otx.api.usecases.adapters.model_adapter import ModelAdapter @@ -27,12 +26,12 @@ from otx.cli.utils.importing import get_impl_class from otx.cli.utils.io import read_binary, read_label_schema, save_model_data from otx.cli.utils.nncf import is_checkpoint_nncf -from otx.cli.utils.parser import get_parser_and_hprams_data +from otx.cli.utils.parser import add_hyper_parameters_sub_parser, get_parser_and_hprams_data def get_args(): """Parses command line arguments.""" - parser, _, _ = get_parser_and_hprams_data() + parser, hyper_parameters, params = get_parser_and_hprams_data() parser.add_argument( "--load-weights", @@ -64,12 +63,15 @@ def get_args(): default="openvino", ) - return parser.parse_args() + add_hyper_parameters_sub_parser(parser, hyper_parameters, modes=("INFERENCE",)) + override_param = [f"params.{param[2:].split('=')[0]}" for param in params if param.startswith("--")] + + return parser.parse_args(), override_param def main(): """Main function that is used for model exporting.""" - args = get_args() + args, override_param = get_args() config_manager = ConfigManager(args, mode="export", workspace_root=args.workspace) # Auto-Configuration for model template config_manager.configure_template() @@ -88,7 +90,7 @@ def main(): task_class = get_impl_class(template.entrypoints.nncf if is_nncf else template.entrypoints.base) # Get hyper parameters schema. - hyper_parameters = create(template.hyper_parameters.data) + hyper_parameters = config_manager.get_hyparams_config(override_param) assert hyper_parameters environment = TaskEnvironment( diff --git a/tests/unit/algorithms/detection/adapters/mmdet/test_task.py b/tests/unit/algorithms/detection/adapters/mmdet/test_task.py index 63ef3b595fa..fd6b27bcb90 100644 --- a/tests/unit/algorithms/detection/adapters/mmdet/test_task.py +++ b/tests/unit/algorithms/detection/adapters/mmdet/test_task.py @@ -178,6 +178,27 @@ def test_build_model(self, mocker) -> None: model = self.det_task.build_model(_mock_recipe_cfg, True) assert isinstance(model, CustomATSS) + @e2e_pytest_unit + def test_load_postprocessing(self): + """Test _load_postprocessing function.""" + mock_model_data = { + "config": {"postprocessing": {"use_ellipse_shapes": {"value": True}}}, + "confidence_threshold": 0.75, + } + self.det_task._load_postprocessing(mock_model_data) + assert self.det_task._hyperparams.postprocessing.use_ellipse_shapes == True + assert self.det_task.confidence_threshold == 0.75 + + mock_model_data = { + "config": {"postprocessing": {"use_ellipse_shapes": {"value": False}}}, + "confidence_threshold": 0.75, + } + self.det_task._hyperparams.postprocessing.result_based_confidence_threshold = False + self.det_task._hyperparams.postprocessing.confidence_threshold = 0.45 + self.det_task._load_postprocessing(mock_model_data) + assert self.det_task._hyperparams.postprocessing.use_ellipse_shapes == False + assert self.det_task.confidence_threshold == 0.45 + @e2e_pytest_unit def test_train(self, mocker) -> None: """Test train function.""" From b990d4940564d460f8daf99bb09679388b94f18d Mon Sep 17 00:00:00 2001 From: jaegukhyun Date: Mon, 24 Jul 2023 11:07:14 +0900 Subject: [PATCH 2/6] Update unit tests --- tests/unit/cli/tools/test_export.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/tests/unit/cli/tools/test_export.py b/tests/unit/cli/tools/test_export.py index 2a760eb1c59..035bfcfbe97 100644 --- a/tests/unit/cli/tools/test_export.py +++ b/tests/unit/cli/tools/test_export.py @@ -12,13 +12,29 @@ def test_get_args(mocker): mocker.patch("sys.argv", ["otx", "--load-weights", "load_weights", "--output", "output"]) mocker.patch.object( - target_package, "get_parser_and_hprams_data", return_value=[argparse.ArgumentParser(), "fake", "fake"] + target_package, + "get_parser_and_hprams_data", + return_value=[ + argparse.ArgumentParser(), + {"result_based_confidence": False, "confidence_threshold": 0.35}, + [ + "params", + "--postprocessing.result_based_confidence", + "false", + "--postprocessing.confidence_threshold", + "0.95", + ], + ], ) - parsed_args = get_args() + parsed_args, override_param = get_args() assert parsed_args.load_weights == "load_weights" assert parsed_args.output == "output" + assert override_param == [ + "params.postprocessing.result_based_confidence", + "params.postprocessing.confidence_threshold", + ] @pytest.fixture @@ -33,7 +49,7 @@ def mock_contains(self, val): mock_args.__contains__ = mock_contains mock_get_args = mocker.patch("otx.cli.tools.export.get_args") - mock_get_args.return_value = mock_args + mock_get_args.return_value = (mock_args, []) return mock_args @@ -61,7 +77,6 @@ def mock_config_manager(mocker): @e2e_pytest_unit def test_main(mocker, mock_args, mock_task, mock_config_manager, tmp_dir): mocker.patch.object(target_package, "is_checkpoint_nncf", return_value=True) - mocker.patch.object(target_package, "create") mocker.patch.object(target_package, "TaskEnvironment") mocker.patch.object(target_package, "read_label_schema") mocker.patch.object(target_package, "read_binary") From 1039ce41b951bc7f60cbe7a4ed01d969f5abaaf4 Mon Sep 17 00:00:00 2001 From: jaegukhyun Date: Mon, 24 Jul 2023 15:00:14 +0900 Subject: [PATCH 3/6] Update task.py --- src/otx/algorithms/detection/task.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/otx/algorithms/detection/task.py b/src/otx/algorithms/detection/task.py index b720c4effb4..5ff417fdec5 100644 --- a/src/otx/algorithms/detection/task.py +++ b/src/otx/algorithms/detection/task.py @@ -83,13 +83,14 @@ def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str] ) self._anchors: Dict[str, int] = {} - if self._hyperparams.postprocessing.result_based_confidence_threshold: - self.confidence_threshold = 0.0 # Use all predictions to compute best threshold - elif hasattr(self._hyperparams, "postprocessing"): - if hasattr(self._hyperparams.postprocessing, "confidence_threshold"): - self.confidence_threshold = self._hyperparams.postprocessing.confidence_threshold - else: - self.confidence_threshold = 0.0 + if ( + not self._hyperparams.postprocessing.result_based_confidence_threshold + and hasattr(self._hyperparams, "postprocessing") + and hasattr(self._hyperparams.postprocessing, "confidence_threshold") + ): + self.confidence_threshold = self._hyperparams.postprocessing.confidence_threshold + else: + self.confidence_threshold = 0.0 if task_environment.model is not None: self._load_model() From 65d7866c1c8fbfa08f6ec444328fb5be09a7e023 Mon Sep 17 00:00:00 2001 From: jaegukhyun Date: Mon, 24 Jul 2023 16:35:44 +0900 Subject: [PATCH 4/6] Update CHANGELOG.md and docs --- CHANGELOG.md | 2 +- .../tutorials/base/how_to_train/detection.rst | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f3a433f0d8..6012a22ee89 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,7 @@ All notable changes to this project will be documented in this file. ### New features -- +- Enable configurable confidence threshold for otx eval and export() ### Enhancements diff --git a/docs/source/guide/tutorials/base/how_to_train/detection.rst b/docs/source/guide/tutorials/base/how_to_train/detection.rst index dc36b7e66ab..b91effaa2bc 100644 --- a/docs/source/guide/tutorials/base/how_to_train/detection.rst +++ b/docs/source/guide/tutorials/base/how_to_train/detection.rst @@ -359,6 +359,22 @@ using ``otx eval`` and passing the IR model path to the ``--load-weights`` param Performance(score: 0.5487693710118504, dashboard: (1 metric groups)) +4. ``Optional`` Additionally, we can tune confidence threshold via the command line. +Learn more about template-specific parameters using ``otx export params --help``. + +For example, if there are too many False-Positive predictions (there we have a prediction, but don't have annotated object for it), we can suppress its number by increasing the confidence threshold as it is shown below. + +Please note, by default, the optimal confidence threshold is detected based on validation results to maximize the final F1 metric. To set a custom confidence threshold, please disable ``result_based_confidence_threshold`` option. + +.. code-block:: + + (otx) ...$ otx export --load-weights ../outputs/weights.pth \ + --output ../outputs + params \ + --postprocessing.confidence_threshold 0.5 \ + --postprocessing.result_based_confidence_threshold false + + ************* Optimization ************* From 5e71721b8f0b492dc6d77ea8ae559bed3d5af3f2 Mon Sep 17 00:00:00 2001 From: jaegukhyun Date: Wed, 26 Jul 2023 10:38:31 +0900 Subject: [PATCH 5/6] Reflect reviews --- docs/source/guide/tutorials/base/how_to_train/detection.rst | 2 +- src/otx/algorithms/detection/task.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/guide/tutorials/base/how_to_train/detection.rst b/docs/source/guide/tutorials/base/how_to_train/detection.rst index b91effaa2bc..90b0363d5f7 100644 --- a/docs/source/guide/tutorials/base/how_to_train/detection.rst +++ b/docs/source/guide/tutorials/base/how_to_train/detection.rst @@ -369,7 +369,7 @@ Please note, by default, the optimal confidence threshold is detected based on v .. code-block:: (otx) ...$ otx export --load-weights ../outputs/weights.pth \ - --output ../outputs + --output ../outputs \ params \ --postprocessing.confidence_threshold 0.5 \ --postprocessing.result_based_confidence_threshold false diff --git a/src/otx/algorithms/detection/task.py b/src/otx/algorithms/detection/task.py index 5ff417fdec5..a16a1b4302e 100644 --- a/src/otx/algorithms/detection/task.py +++ b/src/otx/algorithms/detection/task.py @@ -84,8 +84,8 @@ def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str] self._anchors: Dict[str, int] = {} if ( - not self._hyperparams.postprocessing.result_based_confidence_threshold - and hasattr(self._hyperparams, "postprocessing") + hasattr(self._hyperparams, "postprocessing") + and not getattr(self._hyperparams.postprocessing, "result_based_confidence_threshold", False) and hasattr(self._hyperparams.postprocessing, "confidence_threshold") ): self.confidence_threshold = self._hyperparams.postprocessing.confidence_threshold From 56b9e98b002db3b0ccafbc272b12ad908c9e2343 Mon Sep 17 00:00:00 2001 From: jaegukhyun Date: Wed, 26 Jul 2023 17:20:50 +0900 Subject: [PATCH 6/6] Extract getting override params part --- src/otx/cli/tools/demo.py | 3 ++- src/otx/cli/tools/eval.py | 3 ++- src/otx/cli/tools/explain.py | 3 ++- src/otx/cli/tools/export.py | 4 ++-- src/otx/cli/tools/optimize.py | 3 ++- src/otx/cli/tools/train.py | 3 ++- src/otx/cli/utils/parser.py | 5 +++++ 7 files changed, 17 insertions(+), 7 deletions(-) diff --git a/src/otx/cli/tools/demo.py b/src/otx/cli/tools/demo.py index d0e9adae598..a1363c8966b 100644 --- a/src/otx/cli/tools/demo.py +++ b/src/otx/cli/tools/demo.py @@ -35,6 +35,7 @@ from otx.cli.utils.io import read_label_schema, read_model from otx.cli.utils.parser import ( add_hyper_parameters_sub_parser, + get_override_param, get_parser_and_hprams_data, ) @@ -87,7 +88,7 @@ def get_args(): ) add_hyper_parameters_sub_parser(parser, hyper_parameters, modes=("INFERENCE",)) - override_param = [f"params.{param[2:].split('=')[0]}" for param in params if param.startswith("--")] + override_param = get_override_param(params) return parser.parse_args(), override_param diff --git a/src/otx/cli/tools/eval.py b/src/otx/cli/tools/eval.py index 0883514d994..00a533510d0 100644 --- a/src/otx/cli/tools/eval.py +++ b/src/otx/cli/tools/eval.py @@ -30,6 +30,7 @@ from otx.cli.utils.nncf import is_checkpoint_nncf from otx.cli.utils.parser import ( add_hyper_parameters_sub_parser, + get_override_param, get_parser_and_hprams_data, ) from otx.core.data.adapter import get_dataset_adapter @@ -68,7 +69,7 @@ def get_args(): ) add_hyper_parameters_sub_parser(parser, hyper_parameters, modes=("INFERENCE",)) - override_param = [f"params.{param[2:].split('=')[0]}" for param in params if param.startswith("--")] + override_param = get_override_param(params) return parser.parse_args(), override_param diff --git a/src/otx/cli/tools/explain.py b/src/otx/cli/tools/explain.py index 10f8aa79ddc..ec7735acbdf 100644 --- a/src/otx/cli/tools/explain.py +++ b/src/otx/cli/tools/explain.py @@ -33,6 +33,7 @@ from otx.cli.utils.nncf import is_checkpoint_nncf from otx.cli.utils.parser import ( add_hyper_parameters_sub_parser, + get_override_param, get_parser_and_hprams_data, ) @@ -88,7 +89,7 @@ def get_args(): help="Weight of the saliency map when overlaying the input image with saliency map.", ) add_hyper_parameters_sub_parser(parser, hyper_parameters, modes=("INFERENCE",)) - override_param = [f"params.{param[2:].split('=')[0]}" for param in params if param.startswith("--")] + override_param = get_override_param(params) return parser.parse_args(), override_param diff --git a/src/otx/cli/tools/export.py b/src/otx/cli/tools/export.py index e347d767750..8ec6c0f92b5 100644 --- a/src/otx/cli/tools/export.py +++ b/src/otx/cli/tools/export.py @@ -26,7 +26,7 @@ from otx.cli.utils.importing import get_impl_class from otx.cli.utils.io import read_binary, read_label_schema, save_model_data from otx.cli.utils.nncf import is_checkpoint_nncf -from otx.cli.utils.parser import add_hyper_parameters_sub_parser, get_parser_and_hprams_data +from otx.cli.utils.parser import add_hyper_parameters_sub_parser, get_override_param, get_parser_and_hprams_data def get_args(): @@ -64,7 +64,7 @@ def get_args(): ) add_hyper_parameters_sub_parser(parser, hyper_parameters, modes=("INFERENCE",)) - override_param = [f"params.{param[2:].split('=')[0]}" for param in params if param.startswith("--")] + override_param = get_override_param(params) return parser.parse_args(), override_param diff --git a/src/otx/cli/tools/optimize.py b/src/otx/cli/tools/optimize.py index 61a3eba51d5..b29e3fa84ba 100644 --- a/src/otx/cli/tools/optimize.py +++ b/src/otx/cli/tools/optimize.py @@ -32,6 +32,7 @@ from otx.cli.utils.io import read_model, save_model_data from otx.cli.utils.parser import ( add_hyper_parameters_sub_parser, + get_override_param, get_parser_and_hprams_data, ) from otx.core.data.adapter import get_dataset_adapter @@ -70,7 +71,7 @@ def get_args(): ) add_hyper_parameters_sub_parser(parser, hyper_parameters) - override_param = [f"params.{param[2:].split('=')[0]}" for param in params if param.startswith("--")] + override_param = get_override_param(params) return parser.parse_args(), override_param diff --git a/src/otx/cli/tools/train.py b/src/otx/cli/tools/train.py index b81146adeda..5a0fecccf1f 100644 --- a/src/otx/cli/tools/train.py +++ b/src/otx/cli/tools/train.py @@ -38,6 +38,7 @@ from otx.cli.utils.parser import ( MemSizeAction, add_hyper_parameters_sub_parser, + get_override_param, get_parser_and_hprams_data, ) from otx.cli.utils.report import get_otx_report @@ -161,7 +162,7 @@ def get_args(): sub_parser = add_hyper_parameters_sub_parser(parser, hyper_parameters, return_sub_parser=True) # TODO: Temporary solution for cases where there is no template input - override_param = [f"params.{param[2:].split('=')[0]}" for param in params if param.startswith("--")] + override_param = get_override_param(params) if not hyper_parameters and "params" in params: if "params" in params: params = params[params.index("params") :] diff --git a/src/otx/cli/utils/parser.py b/src/otx/cli/utils/parser.py index acc65f09a65..1d58d5b6308 100644 --- a/src/otx/cli/utils/parser.py +++ b/src/otx/cli/utils/parser.py @@ -259,3 +259,8 @@ def get_parser_and_hprams_data(): parser.add_argument("template", nargs="?", default=None, help=template_help_str) return parser, hyper_parameters, params + + +def get_override_param(params): + """Get override param list from params.""" + return [f"params.{param[2:].split('=')[0]}" for param in params if param.startswith("--")]