11from dataclasses import dataclass
2- from typing import Dict , List
2+ from typing import ClassVar , Dict , List
33
44from ..features import ClassLabel , Features , Value
55from .base import TaskTemplate
66
77
8+ class FeaturesWithLazyClassLabel :
9+ def __init__ (self , features , label_column = "labels" ):
10+ assert label_column in features , f"Key '{ label_column } ' missing in features { features } "
11+ self ._features = features
12+ self ._label_column = label_column
13+
14+ def __get__ (self , obj , objtype = None ):
15+ if obj is None :
16+ return self ._features
17+
18+ assert hasattr (obj , self ._label_column ), f"Object has no attribute '{ self ._label_column } '"
19+ features = self ._features .copy ()
20+ features ["labels" ] = ClassLabel (names = getattr (obj , self ._label_column ))
21+ return features
22+
23+
824@dataclass (frozen = True )
925class TextClassification (TaskTemplate ):
10- task = "text-classification"
11- input_schema = Features ({"text" : Value ("string" )})
12- # TODO(lewtun): Since we update this in __post_init__ do we need to set a default? We'll need it for __init__ so
13- # investigate if there's a more elegant approach.
14- label_schema = Features ({"labels" : ClassLabel })
26+ task : ClassVar [str ] = "text-classification"
27+ input_schema : ClassVar [Features ] = Features ({"text" : Value ("string" )})
28+ # TODO(lewtun): Find a more elegant approach without descriptors.
29+ label_schema : ClassVar [Features ] = FeaturesWithLazyClassLabel (Features ({"labels" : ClassLabel }))
1530 labels : List [str ]
1631 text_column : str = "text"
1732 label_column : str = "labels"
1833
1934 def __post_init__ (self ):
20- assert sorted ( set ( self .labels )) == sorted ( self .labels ), "Labels must be unique"
35+ assert len ( self .labels ) == len ( set ( self .labels ) ), "Labels must be unique"
2136 # Cast labels to tuple to allow hashing
22- object .__setattr__ (self , "labels" , tuple (sorted (self .labels )))
23- object .__setattr__ (self , "text_column" , self .text_column )
24- object .__setattr__ (self , "label_column" , self .label_column )
25- self .label_schema ["labels" ] = ClassLabel (names = self .labels )
26- object .__setattr__ (self , "label2id" , {label : idx for idx , label in enumerate (self .labels )})
27- object .__setattr__ (self , "id2label" , {idx : label for label , idx in self .label2id .items ()})
37+ self .__dict__ ["labels" ] = tuple (sorted (self .labels ))
2838
2939 @property
3040 def column_mapping (self ) -> Dict [str , str ]:
@@ -33,10 +43,10 @@ def column_mapping(self) -> Dict[str, str]:
3343 self .label_column : "labels" ,
3444 }
3545
36- @classmethod
37- def from_dict ( cls , template_dict : dict ) -> "TextClassification" :
38- return cls (
39- text_column = template_dict [ "text_column" ],
40- label_column = template_dict [ "label_column" ],
41- labels = template_dict [ "labels" ],
42- )
46+ @property
47+ def label2id ( self ) :
48+ return { label : idx for idx , label in enumerate ( self . labels )}
49+
50+ @ property
51+ def id2label ( self ):
52+ return { idx : label for idx , label in enumerate ( self . labels )}
0 commit comments