1- from typing import Tuple , Callable , Sequence , Dict , Any
1+ from typing import Tuple , Callable , Sequence , cast
2+ from typing import Dict , Union , Optional , Hashable , Any
23
34from ..model import Model
45from ..config import registry
5- from ..types import Ints2d , DTypes
6+ from ..types import Ints1d , Ints2d , DTypes
7+ from ..util import is_xp_array , to_numpy
68
79
8- InT = Sequence [Any ]
10+ InT = Union [ Sequence [Hashable ], Ints1d , Ints2d ]
911OutT = Ints2d
1012
13+ InT_v1 = Sequence [Any ]
14+ OutT_v1 = Ints2d
15+
1116
1217@registry .layers ("remap_ids.v1" )
1318def remap_ids (
1419 mapping_table : Dict [Any , int ] = {}, default : int = 0 , dtype : DTypes = "i"
15- ) -> Model [InT , OutT ]:
20+ ) -> Model [InT_v1 , OutT_v1 ]:
1621 """Remap string or integer inputs using a mapping table, usually as a
1722 preprocess before embeddings. The mapping table can be passed in on input,
1823 or updated after the layer has been created. The mapping table is stored in
@@ -26,7 +31,7 @@ def remap_ids(
2631
2732
2833def forward (
29- model : Model [InT , OutT ], inputs : InT , is_train : bool
34+ model : Model [InT_v1 , OutT_v1 ], inputs : InT_v1 , is_train : bool
3035) -> Tuple [OutT , Callable ]:
3136 table = model .attrs ["mapping_table" ]
3237 default = model .attrs ["default" ]
@@ -35,7 +40,60 @@ def forward(
3540 arr = model .ops .asarray2i (values , dtype = dtype )
3641 output = model .ops .reshape2i (arr , - 1 , 1 )
3742
38- def backprop (dY : OutT ) -> InT :
43+ def backprop (dY : OutT_v1 ) -> InT :
3944 return []
4045
4146 return output , backprop
47+
48+
49+ @registry .layers ("remap_ids.v2" )
50+ def remap_ids_v2 (
51+ mapping_table : Optional [Union [Dict [int , int ], Dict [str , int ]]] = None ,
52+ default : int = 0 ,
53+ * ,
54+ column : Optional [int ] = None
55+ ) -> Model [InT , OutT ]:
56+ """Remap string or integer inputs using a mapping table,
57+ usually as a preprocessing step before embeddings.
58+ The mapping table can be passed in on input,
59+ or updated after the layer has been created.
60+ The mapping table is stored in the "mapping_table" attribute.
61+ Two dimensional arrays can be provided as input in which case
62+ the 'column' chooses which column to process. This is useful
63+ to work together with FeatureExtractor in spaCy.
64+ """
65+ return Model (
66+ "remap_ids" ,
67+ forward_v2 ,
68+ attrs = {"mapping_table" : mapping_table , "default" : default , "column" : column },
69+ )
70+
71+
72+ def forward_v2 (
73+ model : Model [InT , OutT ], inputs : InT , is_train : bool
74+ ) -> Tuple [OutT , Callable ]:
75+ table = model .attrs ["mapping_table" ]
76+ if table is None :
77+ raise ValueError ("'mapping table' not set" )
78+ default = model .attrs ["default" ]
79+ column = model .attrs ["column" ]
80+ if is_xp_array (inputs ):
81+ xp_input = True
82+ if column is not None :
83+ idx = to_numpy (cast (Ints2d , inputs )[:, column ])
84+ else :
85+ idx = to_numpy (inputs )
86+ else :
87+ xp_input = False
88+ idx = inputs
89+ values = [table .get (x , default ) for x in idx ]
90+ arr = model .ops .asarray2i (values , dtype = "i" )
91+ output = model .ops .reshape2i (arr , - 1 , 1 )
92+
93+ def backprop (dY : OutT ) -> InT :
94+ if xp_input :
95+ return model .ops .xp .empty (dY .shape ) # type: ignore
96+ else :
97+ return []
98+
99+ return output , backprop
0 commit comments