Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions src/transformers/hf_argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,27 @@
from enum import Enum
from inspect import isclass
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Literal, NewType, Optional, Tuple, Union, get_type_hints
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Literal,
NewType,
Optional,
Tuple,
TypeVar,
Union,
get_type_hints,
)

import yaml


DataClass = NewType("DataClass", Any)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any idea why these NewType()s were used originally instead of a simple alias? Like DataClass = Any, DataclassType = Any.

DataClassType = NewType("DataClassType", Any)
T = TypeVar("T")


# From https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
Expand Down Expand Up @@ -269,7 +283,7 @@ def parse_args_into_dataclasses(
look_for_args_file=True,
args_filename=None,
args_file_flag=None,
) -> Tuple[DataClass, ...]:
) -> Tuple[T, ...]:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To me this looks the same as annotating the return type to Tuple[Any, ...], since T isn't used anywhere else.

The usual use of TypeVar is to infer an output type from either an input type (if T is used as annotation for an argument, like identity(x: T) -> T) or from a concretized version of a generic class (like __getitem__() return T for a list[T]).

I have some ideas for suggestions that I can leave in a comment!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks so much for the comment. I think the most ideal workflow is something like below working, but not sure if it's possible 👀


class HfArgumentParser:
    def __init__(self, test: List[Type[T]]):
        self.test = test

    def parse_args_into_dataclasses(self) -> List[T]:
        return self.test
    
parser = HfArgumentParser([RewardConfig, Config2])
args, args2 = parser.parse_args_into_dataclasses()

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first option I listed below should do this exactly! The trick is that HfArgumentParser needs to inherit from Generic[T].

"""
Parse command-line args into instances of the specified dataclass types.

Expand Down