Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ All notable changes to this project will be documented in this file.

### New features

-
- Enable configurable confidence threshold for otx eval and export(<https://github.com/openvinotoolkit/training_extensions/pull/2388>)

### Enhancements

Expand Down
16 changes: 16 additions & 0 deletions docs/source/guide/tutorials/base/how_to_train/detection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,22 @@ using ``otx eval`` and passing the IR model path to the ``--load-weights`` param
Performance(score: 0.5487693710118504, dashboard: (1 metric groups))


4. ``Optional`` Additionally, we can tune confidence threshold via the command line.
Learn more about template-specific parameters using ``otx export params --help``.

For example, if there are too many False-Positive predictions (there we have a prediction, but don't have annotated object for it), we can suppress its number by increasing the confidence threshold as it is shown below.

Please note, by default, the optimal confidence threshold is detected based on validation results to maximize the final F1 metric. To set a custom confidence threshold, please disable ``result_based_confidence_threshold`` option.

.. code-block::

(otx) ...$ otx export --load-weights ../outputs/weights.pth \
--output ../outputs
params \
--postprocessing.confidence_threshold 0.5 \
--postprocessing.result_based_confidence_threshold false


*************
Optimization
*************
Expand Down
10 changes: 6 additions & 4 deletions src/otx/algorithms/detection/adapters/openvino/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,10 +434,12 @@ def load_inferencer(
if self.model is None:
raise RuntimeError("load_inferencer failed, model is None")
_hparams = copy.deepcopy(self.hparams)
self.confidence_threshold = float(
np.frombuffer(self.model.get_data("confidence_threshold"), dtype=np.float32)[0]
)
_hparams.postprocessing.confidence_threshold = self.confidence_threshold
if _hparams.postprocessing.result_based_confidence_threshold:
self.confidence_threshold = float(
np.frombuffer(self.model.get_data("confidence_threshold"), dtype=np.float32)[0]
)
_hparams.postprocessing.confidence_threshold = self.confidence_threshold
logger.info(f"Confidence Threshold: {_hparams.postprocessing.confidence_threshold}")
_hparams.postprocessing.use_ellipse_shapes = self.config.postprocessing.use_ellipse_shapes
async_requests_num = get_default_async_reqs_num()
args = [
Expand Down
24 changes: 13 additions & 11 deletions src/otx/algorithms/detection/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,12 @@ def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str]
)
self._anchors: Dict[str, int] = {}

if hasattr(self._hyperparams, "postprocessing"):
if hasattr(self._hyperparams.postprocessing, "confidence_threshold"):
self.confidence_threshold = self._hyperparams.postprocessing.confidence_threshold
if (
not self._hyperparams.postprocessing.result_based_confidence_threshold
and hasattr(self._hyperparams, "postprocessing")
and hasattr(self._hyperparams.postprocessing, "confidence_threshold")
):
self.confidence_threshold = self._hyperparams.postprocessing.confidence_threshold
else:
self.confidence_threshold = 0.0

Expand All @@ -104,14 +107,20 @@ def _load_postprocessing(self, model_data):

Args:
model_data: The model data.

"""
loaded_postprocessing = model_data.get("config", {}).get("postprocessing", {})
hparams = self._hyperparams.postprocessing
if "use_ellipse_shapes" in loaded_postprocessing:
hparams.use_ellipse_shapes = loaded_postprocessing["use_ellipse_shapes"]["value"]
else:
hparams.use_ellipse_shapes = False
# If confidence threshold is adaptive then up-to-date value should be stored in the model
# and should not be changed during inference. Otherwise user-specified value should be taken.
if hparams.result_based_confidence_threshold:
self.confidence_threshold = model_data.get("confidence_threshold", self.confidence_threshold)
else:
self.confidence_threshold = hparams.confidence_threshold
logger.info(f"Confidence threshold {self.confidence_threshold}")

def _load_tiling_parameters(self, model_data):
"""Load tiling parameters from PyTorch model.
Expand Down Expand Up @@ -158,8 +167,6 @@ def _load_model_ckpt(self, model: Optional[ModelEntity]) -> Optional[Dict]:
buffer = io.BytesIO(model.get_data("weights.pth"))
model_data = torch.load(buffer, map_location=torch.device("cpu"))

# set confidence_threshold as well
self.confidence_threshold = model_data.get("confidence_threshold", self.confidence_threshold)
if model_data.get("anchors"):
self._anchors = model_data["anchors"]
self._load_postprocessing(model_data)
Expand Down Expand Up @@ -287,11 +294,6 @@ def infer(
explain_predicted_classes = inference_parameters.explain_predicted_classes

self._time_monitor = InferenceProgressCallback(len(dataset), update_progress_callback)
# If confidence threshold is adaptive then up-to-date value should be stored in the model
# and should not be changed during inference. Otherwise user-specified value should be taken.
if not self._hyperparams.postprocessing.result_based_confidence_threshold:
self.confidence_threshold = self._hyperparams.postprocessing.confidence_threshold
logger.info(f"Confidence threshold {self.confidence_threshold}")

dataset.purpose = DatasetPurpose.INFERENCE
prediction_results, _ = self._infer_model(dataset, inference_parameters)
Expand Down
14 changes: 8 additions & 6 deletions src/otx/cli/tools/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

# Update environment variables for CLI use
import otx.cli # noqa: F401
from otx.api.configuration.helper import create
from otx.api.entities.model import ModelEntity, ModelOptimizationType, ModelPrecision
from otx.api.entities.task_environment import TaskEnvironment
from otx.api.usecases.adapters.model_adapter import ModelAdapter
Expand All @@ -27,12 +26,12 @@
from otx.cli.utils.importing import get_impl_class
from otx.cli.utils.io import read_binary, read_label_schema, save_model_data
from otx.cli.utils.nncf import is_checkpoint_nncf
from otx.cli.utils.parser import get_parser_and_hprams_data
from otx.cli.utils.parser import add_hyper_parameters_sub_parser, get_parser_and_hprams_data


def get_args():
"""Parses command line arguments."""
parser, _, _ = get_parser_and_hprams_data()
parser, hyper_parameters, params = get_parser_and_hprams_data()

parser.add_argument(
"--load-weights",
Expand Down Expand Up @@ -64,12 +63,15 @@ def get_args():
default="openvino",
)

return parser.parse_args()
add_hyper_parameters_sub_parser(parser, hyper_parameters, modes=("INFERENCE",))
override_param = [f"params.{param[2:].split('=')[0]}" for param in params if param.startswith("--")]

return parser.parse_args(), override_param


def main():
"""Main function that is used for model exporting."""
args = get_args()
args, override_param = get_args()
config_manager = ConfigManager(args, mode="export", workspace_root=args.workspace)
# Auto-Configuration for model template
config_manager.configure_template()
Expand All @@ -88,7 +90,7 @@ def main():
task_class = get_impl_class(template.entrypoints.nncf if is_nncf else template.entrypoints.base)

# Get hyper parameters schema.
hyper_parameters = create(template.hyper_parameters.data)
hyper_parameters = config_manager.get_hyparams_config(override_param)
assert hyper_parameters

environment = TaskEnvironment(
Expand Down
21 changes: 21 additions & 0 deletions tests/unit/algorithms/detection/adapters/mmdet/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,27 @@ def test_build_model(self, mocker) -> None:
model = self.det_task.build_model(_mock_recipe_cfg, True)
assert isinstance(model, CustomATSS)

@e2e_pytest_unit
def test_load_postprocessing(self):
"""Test _load_postprocessing function."""
mock_model_data = {
"config": {"postprocessing": {"use_ellipse_shapes": {"value": True}}},
"confidence_threshold": 0.75,
}
self.det_task._load_postprocessing(mock_model_data)
assert self.det_task._hyperparams.postprocessing.use_ellipse_shapes == True
assert self.det_task.confidence_threshold == 0.75

mock_model_data = {
"config": {"postprocessing": {"use_ellipse_shapes": {"value": False}}},
"confidence_threshold": 0.75,
}
self.det_task._hyperparams.postprocessing.result_based_confidence_threshold = False
self.det_task._hyperparams.postprocessing.confidence_threshold = 0.45
self.det_task._load_postprocessing(mock_model_data)
assert self.det_task._hyperparams.postprocessing.use_ellipse_shapes == False
assert self.det_task.confidence_threshold == 0.45

@e2e_pytest_unit
def test_train(self, mocker) -> None:
"""Test train function."""
Expand Down
23 changes: 19 additions & 4 deletions tests/unit/cli/tools/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,29 @@
def test_get_args(mocker):
mocker.patch("sys.argv", ["otx", "--load-weights", "load_weights", "--output", "output"])
mocker.patch.object(
target_package, "get_parser_and_hprams_data", return_value=[argparse.ArgumentParser(), "fake", "fake"]
target_package,
"get_parser_and_hprams_data",
return_value=[
argparse.ArgumentParser(),
{"result_based_confidence": False, "confidence_threshold": 0.35},
[
"params",
"--postprocessing.result_based_confidence",
"false",
"--postprocessing.confidence_threshold",
"0.95",
],
],
)

parsed_args = get_args()
parsed_args, override_param = get_args()

assert parsed_args.load_weights == "load_weights"
assert parsed_args.output == "output"
assert override_param == [
"params.postprocessing.result_based_confidence",
"params.postprocessing.confidence_threshold",
]


@pytest.fixture
Expand All @@ -33,7 +49,7 @@ def mock_contains(self, val):

mock_args.__contains__ = mock_contains
mock_get_args = mocker.patch("otx.cli.tools.export.get_args")
mock_get_args.return_value = mock_args
mock_get_args.return_value = (mock_args, [])

return mock_args

Expand Down Expand Up @@ -61,7 +77,6 @@ def mock_config_manager(mocker):
@e2e_pytest_unit
def test_main(mocker, mock_args, mock_task, mock_config_manager, tmp_dir):
mocker.patch.object(target_package, "is_checkpoint_nncf", return_value=True)
mocker.patch.object(target_package, "create")
mocker.patch.object(target_package, "TaskEnvironment")
mocker.patch.object(target_package, "read_label_schema")
mocker.patch.object(target_package, "read_binary")
Expand Down