diff --git a/test/test_decorator.py b/test/test_decorator.py index d70811d2..9521eee7 100644 --- a/test/test_decorator.py +++ b/test/test_decorator.py @@ -1,13 +1,11 @@ -"""Tests simple parsing decorators. -""" +"""Tests simple parsing decorators.""" import collections import dataclasses import functools -import inspect import sys -import typing from typing import Callable - +import typing +import inspect import pytest import simple_parsing as sp @@ -43,7 +41,7 @@ def pop(self, index, default_value=None): return default_value -def partial(fn: Callable, *args, **kwargs) -> Callable: +def change_defaults(fn: Callable, *args, **kwargs) -> Callable: """Partial via changing the signature defaults.""" @functools.wraps(fn) @@ -69,22 +67,38 @@ def _wrapper(*other_args, **other_kwargs): return _wrapper +def _xfail_in_py311(*param): + return pytest.param( + *param, + marks=pytest.mark.xfail( + sys.version_info >= (3, 11), + reason="TODO: test doesn't work in Python 3.11", + raises=ValueError, # "non-default argument follows default argument" + strict=True, + ), + ) + + @pytest.mark.parametrize( "args, expected, fn", [ - ("", 1, partial(_fn_with_positional_only, 1)), - ("2", 2, partial(_fn_with_positional_only, 1)), - ("2", 2, _fn_with_positional_only), - ("", 1, partial(_fn_with_keyword_only, x=1)), - ("--x=2", 2, partial(_fn_with_keyword_only, x=1)), - ("--x=2", 2, _fn_with_keyword_only), - ("", 3, partial(_fn_with_all_argument_types, 1, b=1, c=1)), - ("2", 4, partial(_fn_with_all_argument_types, b=1, c=1)), - ("2 --b=2", 5, partial(_fn_with_all_argument_types, c=1)), - ("2 --b=2 --c=2", 6, _fn_with_all_argument_types), - ("--c=2", 4, partial(_fn_with_all_argument_types, 1, b=1)), - ("--b=2", 4, partial(_fn_with_all_argument_types, 1, c=1)), - ("--b=2 --c=2", 5, partial(_fn_with_all_argument_types, 1)), + ("", 1, change_defaults(_fn_with_positional_only, 1)), + ("2", 2, change_defaults(_fn_with_positional_only, 1)), + ("2", 2, change_defaults(_fn_with_positional_only)), + ("", 1, change_defaults(_fn_with_keyword_only, x=1)), + ("--x=2", 2, change_defaults(_fn_with_keyword_only, x=1)), + ("--x=2", 2, change_defaults(_fn_with_keyword_only)), + ("", 3, change_defaults(_fn_with_all_argument_types, 1, b=1, c=1)), + ("2", 4, change_defaults(_fn_with_all_argument_types, b=1, c=1)), + ("2 --b=2", 5, change_defaults(_fn_with_all_argument_types, c=1)), + ("2 --b=2 --c=2", 6, change_defaults(_fn_with_all_argument_types)), + ("--c=2", 4, change_defaults(_fn_with_all_argument_types, 1, b=1)), + _xfail_in_py311( + "--b=2", 4, functools.partial(change_defaults, _fn_with_all_argument_types, 1, c=1) + ), + _xfail_in_py311( + "--b=2 --c=2", 5, functools.partial(change_defaults, _fn_with_all_argument_types, 1) + ), ], ) def test_simple_arguments( @@ -92,6 +106,10 @@ def test_simple_arguments( expected: int, fn: Callable, ): + if isinstance(fn, functools.partial): + # In Python 3.11.4, the inspect module had a backward-incompatible change. We need to have + # this additional level of indirection. + fn = fn() decorated = sp.decorators.main(fn, args=args) assert decorated() == expected @@ -100,25 +118,16 @@ def _fn_with_nested_dataclass(x: int, /, *, data: AddThreeNumbers) -> int: return x + data() -def _xfail_in_py311(*param): - return pytest.param( - *param, - marks=pytest.mark.xfail( - sys.version_info >= (3, 11), - reason="TODO: test doesn't work in Python 3.11", - strict=True, - ), - ) - - @pytest.mark.parametrize( - "args, expected, fn", + ("args", "expected", "fn"), [ - _xfail_in_py311("", 1, partial(_fn_with_nested_dataclass, 1, data=AddThreeNumbers())), - ("--a=1", 2, partial(_fn_with_nested_dataclass, 1)), - ("--a=1 --b=1", 3, partial(_fn_with_nested_dataclass, 1)), - ("--a=1 --b=1 --c=1", 4, partial(_fn_with_nested_dataclass, 1)), - ("2 --a=1 --b=1 --c=1", 5, partial(_fn_with_nested_dataclass)), + _xfail_in_py311( + "", 1, change_defaults(_fn_with_nested_dataclass, 1, data=AddThreeNumbers()) + ), + ("--a=1", 2, change_defaults(_fn_with_nested_dataclass, 1)), + ("--a=1 --b=1", 3, change_defaults(_fn_with_nested_dataclass, 1)), + ("--a=1 --b=1 --c=1", 4, change_defaults(_fn_with_nested_dataclass, 1)), + ("2 --a=1 --b=1 --c=1", 5, change_defaults(_fn_with_nested_dataclass)), ], ) def test_nested_dataclass(