@@ -301,6 +301,22 @@ def with_postprocess_fn(
     inference result in order from first applied to last applied."""
     return _PostProcessingModelHandler(self, fn)
 
+  def with_element_size_fn(
+      self, fn: Callable[[Union[ExampleT, tuple[KeyT, ExampleT]]], int]
+  ) -> 'ModelHandler[ExampleT, PredictionT, ModelT]':
+    """Returns a new `ModelHandler` that uses `fn` for element sizing.
+
+    The provided sizing function is passed through to `beam.BatchElements`
+    via `batch_elements_kwargs` as `element_size_fn`.
+
+    Args:
+      fn: A callable that returns the size (as an `int`) for each element.
+
+    Returns:
+      A `ModelHandler` wrapping this handler, with `fn` used for batching.
+    """
+    return _SizingModelHandler(self, fn)
+
   def with_no_batching(
       self
   ) -> """ModelHandler[Union[
@@ -1287,6 +1303,85 @@ def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
     return self._base.get_postprocess_fns() + [self._postprocess_fn]
 
 
+class _SizingModelHandler(Generic[ExampleT, PredictionT, ModelT],
+                          ModelHandler[ExampleT, PredictionT, ModelT]):
+  def __init__(
+      self,
+      base: ModelHandler[ExampleT, PredictionT, ModelT],
+      element_size_fn: Callable[[Union[ExampleT, tuple[KeyT, ExampleT]]], int]):
+    """A ModelHandler that has an element_size_fn associated with it.
+
+    Args:
+      base: An implementation of the underlying model handler.
+      element_size_fn: the element sizing function to use for batching.
+    """
+    self._base = base
+    self._env_vars = getattr(base, '_env_vars', {})
+    self._element_size_fn = element_size_fn
+
+  def set_environment_vars(self):
+    return self._base.set_environment_vars()
+
+  def load_model(self) -> ModelT:
+    return self._base.load_model()
+
+  def run_inference(
+      self,
+      batch: Sequence[Union[ExampleT, tuple[KeyT, ExampleT]]],
+      model: ModelT,
+      inference_args: Optional[dict[str, Any]] = None
+  ) -> Union[Iterable[PredictionT], Iterable[tuple[KeyT, PredictionT]]]:
+    return self._base.run_inference(batch, model, inference_args)
+
+  def get_num_bytes(
+      self, batch: Sequence[Union[ExampleT, tuple[KeyT, ExampleT]]]) -> int:
+    return self._base.get_num_bytes(batch)
+
+  def get_metrics_namespace(self) -> str:
+    return self._base.get_metrics_namespace()
+
+  def get_resource_hints(self):
+    return self._base.get_resource_hints()
+
+  def batch_elements_kwargs(self):
+    kwargs = dict(self._base.batch_elements_kwargs())
+    kwargs["element_size_fn"] = self._element_size_fn
+    return kwargs
+
+  def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
+    return self._base.validate_inference_args(inference_args)
+
+  def update_model_path(self, model_path: Optional[str] = None):
+    return self._base.update_model_path(model_path=model_path)
+
+  def update_model_paths(
+      self,
+      model: ModelT,
+      model_paths: Optional[Union[str, list[KeyModelPathMapping]]] = None):
+    return self._base.update_model_paths(model, model_paths)
+
+  def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+    return self._base.get_preprocess_fns()
+
+  def should_skip_batching(self) -> bool:
+    return self._base.should_skip_batching()
+
+  def share_model_across_processes(self) -> bool:
+    return self._base.share_model_across_processes()
+
+  def model_copies(self) -> int:
+    return self._base.model_copies()
+
+  def override_metrics(self, metrics_namespace: str = '') -> bool:
+    return self._base.override_metrics(metrics_namespace=metrics_namespace)
+
+  def should_garbage_collect_on_timeout(self) -> bool:
+    return self._base.should_garbage_collect_on_timeout()
+
+  def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+    return self._base.get_postprocess_fns()
+
+
 class RunInference(beam.PTransform[beam.PCollection[Union[ExampleT,
                                                            Iterable[ExampleT]]],
                                    beam.PCollection[PredictionT]]):
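
For context, a minimal usage sketch (not part of the diff): only `ModelHandler.with_element_size_fn` and `RunInference` come from the code above. The toy `LengthModelHandler` and the byte-length sizing lambda are illustrative assumptions, and the sketch presupposes a Beam version whose `beam.BatchElements` accepts `element_size_fn`, as this change requires.

# Hypothetical usage sketch; LengthModelHandler is illustrative only.
import apache_beam as beam
from apache_beam.ml.inference.base import ModelHandler, RunInference


class LengthModelHandler(ModelHandler[bytes, int, object]):
  """Toy handler whose 'model' simply reports input lengths."""
  def load_model(self) -> object:
    return object()  # No real model is needed for this sketch.

  def run_inference(self, batch, model, inference_args=None):
    return [len(example) for example in batch]


# Wrap the handler so batching is sized by each element's byte length.
handler = LengthModelHandler().with_element_size_fn(
    lambda example: len(example))

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([b'short', b'a much longer payload'])
      | RunInference(handler)
      | beam.Map(print))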
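
The mechanism is just the `batch_elements_kwargs` override shown in the diff: the wrapper copies the base handler's kwargs and adds `element_size_fn`, which `RunInference` then forwards to `beam.BatchElements`. Continuing the hypothetical handler from the sketch above:

# The sizing function surfaces in the kwargs handed to beam.BatchElements.
sized_handler = LengthModelHandler().with_element_size_fn(len)
kwargs = sized_handler.batch_elements_kwargs()
assert 'element_size_fn' in kwargs
assert kwargs['element_size_fn'](b'abcde') == 5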