Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 45 additions & 36 deletions test/test_decorator.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
"""Tests simple parsing decorators.
"""
"""Tests simple parsing decorators."""
import collections
import dataclasses
import functools
import inspect
import sys
import typing
from typing import Callable

import typing
import inspect
import pytest

import simple_parsing as sp
Expand Down Expand Up @@ -43,7 +41,7 @@ def pop(self, index, default_value=None):
return default_value


def partial(fn: Callable, *args, **kwargs) -> Callable:
def change_defaults(fn: Callable, *args, **kwargs) -> Callable:
"""Partial via changing the signature defaults."""

@functools.wraps(fn)
Expand All @@ -69,29 +67,49 @@ def _wrapper(*other_args, **other_kwargs):
return _wrapper


def _xfail_in_py311(*param):
return pytest.param(
*param,
marks=pytest.mark.xfail(
sys.version_info >= (3, 11),
reason="TODO: test doesn't work in Python 3.11",
raises=ValueError, # "non-default argument follows default argument"
strict=True,
),
)


@pytest.mark.parametrize(
"args, expected, fn",
[
("", 1, partial(_fn_with_positional_only, 1)),
("2", 2, partial(_fn_with_positional_only, 1)),
("2", 2, _fn_with_positional_only),
("", 1, partial(_fn_with_keyword_only, x=1)),
("--x=2", 2, partial(_fn_with_keyword_only, x=1)),
("--x=2", 2, _fn_with_keyword_only),
("", 3, partial(_fn_with_all_argument_types, 1, b=1, c=1)),
("2", 4, partial(_fn_with_all_argument_types, b=1, c=1)),
("2 --b=2", 5, partial(_fn_with_all_argument_types, c=1)),
("2 --b=2 --c=2", 6, _fn_with_all_argument_types),
("--c=2", 4, partial(_fn_with_all_argument_types, 1, b=1)),
("--b=2", 4, partial(_fn_with_all_argument_types, 1, c=1)),
("--b=2 --c=2", 5, partial(_fn_with_all_argument_types, 1)),
("", 1, change_defaults(_fn_with_positional_only, 1)),
("2", 2, change_defaults(_fn_with_positional_only, 1)),
("2", 2, change_defaults(_fn_with_positional_only)),
("", 1, change_defaults(_fn_with_keyword_only, x=1)),
("--x=2", 2, change_defaults(_fn_with_keyword_only, x=1)),
("--x=2", 2, change_defaults(_fn_with_keyword_only)),
("", 3, change_defaults(_fn_with_all_argument_types, 1, b=1, c=1)),
("2", 4, change_defaults(_fn_with_all_argument_types, b=1, c=1)),
("2 --b=2", 5, change_defaults(_fn_with_all_argument_types, c=1)),
("2 --b=2 --c=2", 6, change_defaults(_fn_with_all_argument_types)),
("--c=2", 4, change_defaults(_fn_with_all_argument_types, 1, b=1)),
_xfail_in_py311(
"--b=2", 4, functools.partial(change_defaults, _fn_with_all_argument_types, 1, c=1)
),
_xfail_in_py311(
"--b=2 --c=2", 5, functools.partial(change_defaults, _fn_with_all_argument_types, 1)
),
],
)
def test_simple_arguments(
args: str,
expected: int,
fn: Callable,
):
if isinstance(fn, functools.partial):
# In Python 3.11.4, the inspect module had a backward-incompatible change. We need to have
# this additional level of indirection.
fn = fn()
decorated = sp.decorators.main(fn, args=args)
assert decorated() == expected

Expand All @@ -100,25 +118,16 @@ def _fn_with_nested_dataclass(x: int, /, *, data: AddThreeNumbers) -> int:
return x + data()


def _xfail_in_py311(*param):
return pytest.param(
*param,
marks=pytest.mark.xfail(
sys.version_info >= (3, 11),
reason="TODO: test doesn't work in Python 3.11",
strict=True,
),
)


@pytest.mark.parametrize(
"args, expected, fn",
("args", "expected", "fn"),
[
_xfail_in_py311("", 1, partial(_fn_with_nested_dataclass, 1, data=AddThreeNumbers())),
("--a=1", 2, partial(_fn_with_nested_dataclass, 1)),
("--a=1 --b=1", 3, partial(_fn_with_nested_dataclass, 1)),
("--a=1 --b=1 --c=1", 4, partial(_fn_with_nested_dataclass, 1)),
("2 --a=1 --b=1 --c=1", 5, partial(_fn_with_nested_dataclass)),
_xfail_in_py311(
"", 1, change_defaults(_fn_with_nested_dataclass, 1, data=AddThreeNumbers())
),
("--a=1", 2, change_defaults(_fn_with_nested_dataclass, 1)),
("--a=1 --b=1", 3, change_defaults(_fn_with_nested_dataclass, 1)),
("--a=1 --b=1 --c=1", 4, change_defaults(_fn_with_nested_dataclass, 1)),
("2 --a=1 --b=1 --c=1", 5, change_defaults(_fn_with_nested_dataclass)),
],
)
def test_nested_dataclass(
Expand Down