
Commit 9026ad0

[RunInference] Add content-aware dynamic batching via element_size_fn (Issue #37414)
1 parent: e12436a

File tree (2 files changed: +170 −0 lines):
  sdks/python/apache_beam/ml/inference/base.py
  sdks/python/apache_beam/ml/inference/base_test.py


sdks/python/apache_beam/ml/inference/base.py

Lines changed: 95 additions & 0 deletions
@@ -301,6 +301,22 @@ def with_postprocess_fn(
     inference result in order from first applied to last applied."""
     return _PostProcessingModelHandler(self, fn)
 
+  def with_element_size_fn(
+      self, fn: Callable[[Union[ExampleT, tuple[KeyT, ExampleT]]], int]
+  ) -> 'ModelHandler[ExampleT, PredictionT, ModelT]':
+    """Returns a new `ModelHandler` that uses `fn` for element sizing.
+
+    The provided sizing function is passed through to `beam.BatchElements`
+    via `batch_elements_kwargs` as `element_size_fn`.
+
+    Args:
+      fn: A callable that returns the size (as an `int`) for each element.
+
+    Returns:
+      A `ModelHandler` wrapping this handler, with `fn` used for batching.
+    """
+    return _SizingModelHandler(self, fn)
+
   def with_no_batching(
       self
   ) -> """ModelHandler[Union[
@@ -1287,6 +1303,85 @@ def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
     return self._base.get_postprocess_fns() + [self._postprocess_fn]
 
 
+class _SizingModelHandler(Generic[ExampleT, PredictionT, ModelT],
+                          ModelHandler[ExampleT, PredictionT, ModelT]):
+  def __init__(
+      self,
+      base: ModelHandler[ExampleT, PredictionT, ModelT],
+      element_size_fn: Callable[[Union[ExampleT, tuple[KeyT, ExampleT]]], int]):
+    """A ModelHandler that has an element_size_fn associated with it.
+
+    Args:
+      base: An implementation of the underlying model handler.
+      element_size_fn: the element sizing function to use for batching.
+    """
+    self._base = base
+    self._env_vars = getattr(base, '_env_vars', {})
+    self._element_size_fn = element_size_fn
+
+  def set_environment_vars(self):
+    return self._base.set_environment_vars()
+
+  def load_model(self) -> ModelT:
+    return self._base.load_model()
+
+  def run_inference(
+      self,
+      batch: Sequence[Union[ExampleT, tuple[KeyT, ExampleT]]],
+      model: ModelT,
+      inference_args: Optional[dict[str, Any]] = None
+  ) -> Union[Iterable[PredictionT], Iterable[tuple[KeyT, PredictionT]]]:
+    return self._base.run_inference(batch, model, inference_args)
+
+  def get_num_bytes(
+      self, batch: Sequence[Union[ExampleT, tuple[KeyT, ExampleT]]]) -> int:
+    return self._base.get_num_bytes(batch)
+
+  def get_metrics_namespace(self) -> str:
+    return self._base.get_metrics_namespace()
+
+  def get_resource_hints(self):
+    return self._base.get_resource_hints()
+
+  def batch_elements_kwargs(self):
+    kwargs = dict(self._base.batch_elements_kwargs())
+    kwargs["element_size_fn"] = self._element_size_fn
+    return kwargs
+
+  def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
+    return self._base.validate_inference_args(inference_args)
+
+  def update_model_path(self, model_path: Optional[str] = None):
+    return self._base.update_model_path(model_path=model_path)
+
+  def update_model_paths(
+      self,
+      model: ModelT,
+      model_paths: Optional[Union[str, list[KeyModelPathMapping]]] = None):
+    return self._base.update_model_paths(model, model_paths)
+
+  def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+    return self._base.get_preprocess_fns()
+
+  def should_skip_batching(self) -> bool:
+    return self._base.should_skip_batching()
+
+  def share_model_across_processes(self) -> bool:
+    return self._base.share_model_across_processes()
+
+  def model_copies(self) -> int:
+    return self._base.model_copies()
+
+  def override_metrics(self, metrics_namespace: str = '') -> bool:
+    return self._base.override_metrics(metrics_namespace=metrics_namespace)
+
+  def should_garbage_collect_on_timeout(self) -> bool:
+    return self._base.should_garbage_collect_on_timeout()
+
+  def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+    return self._base.get_postprocess_fns()
+
+
 class RunInference(beam.PTransform[beam.PCollection[Union[ExampleT,
                                                            Iterable[ExampleT]]],
                    beam.PCollection[PredictionT]]):
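
For context, here is a minimal end-to-end sketch of how a pipeline author might opt into the new content-aware batching. The toy handler, toy model, and whitespace-token size heuristic are illustrative assumptions (modeled on FakeModelHandlerForSizing in the tests below), not part of this commit; only with_element_size_fn and RunInference come from the code above, and the sketch assumes the BatchElements used by RunInference accepts the element_size_fn kwarg, as the new docstring states.

import apache_beam as beam
from apache_beam.ml.inference import base
from apache_beam.ml.inference.base import RunInference


class WordCountModel:
  """Toy model used only for this sketch."""
  def predict(self, text: str) -> int:
    return len(text.split())


class ToyTextHandler(base.ModelHandler[str, int, WordCountModel]):
  """Illustrative handler; any ModelHandler gains with_element_size_fn."""
  def load_model(self) -> WordCountModel:
    return WordCountModel()

  def run_inference(self, batch, model, inference_args=None):
    return [model.predict(x) for x in batch]

  def batch_elements_kwargs(self):
    # Existing kwargs are preserved; element_size_fn is merged alongside them.
    return {'max_batch_size': 16}


def size_by_tokens(text: str) -> int:
  # Charge each element its whitespace token count so batches are bounded by
  # total content size rather than by element count alone.
  return len(text.split())


with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create(['short text', 'a much longer piece of text to score'])
      | RunInference(ToyTextHandler().with_element_size_fn(size_by_tokens)))

Because _SizingModelHandler copies the base handler's batch_elements_kwargs before adding element_size_fn, the max_batch_size of 16 in the sketch still applies alongside the sizing function.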

sdks/python/apache_beam/ml/inference/base_test.py

Lines changed: 75 additions & 0 deletions
@@ -2133,5 +2133,80 @@ def request(self, batch, model, inference_args=None):
     model_handler.run_inference([1], FakeModel())
 
 
+class FakeModelHandlerForSizing(base.ModelHandler[int, int, FakeModel]):
+  """A ModelHandler used to test element sizing behavior."""
+  def __init__(self, max_batch_size: int = 10):
+    self._max_batch_size = max_batch_size
+
+  def load_model(self) -> FakeModel:
+    return FakeModel()
+
+  def run_inference(self, batch, model, inference_args=None):
+    return [model.predict(x) for x in batch]
+
+  def batch_elements_kwargs(self):
+    return {'max_batch_size': self._max_batch_size}
+
+
+class RunInferenceSizeTest(unittest.TestCase):
+  """Tests for ModelHandler.with_element_size_fn."""
+  def test_kwargs_are_passed_correctly(self):
+    """Adds element_size_fn without clobbering existing kwargs."""
+    def size_fn(x):
+      return 10
+
+    base_handler = FakeModelHandlerForSizing(max_batch_size=20)
+    sized_handler = base_handler.with_element_size_fn(size_fn)
+
+    kwargs = sized_handler.batch_elements_kwargs()
+
+    self.assertEqual(kwargs['max_batch_size'], 20)
+    self.assertIn('element_size_fn', kwargs)
+    self.assertEqual(kwargs['element_size_fn'](1), 10)
+
+  def test_element_size_fn_wrapper_delegates_correctly(self):
+    """_SizingModelHandler delegates methods to the base handler."""
+    base_handler = FakeModelHandlerForSizing()
+    size_fn = lambda x: x * 2
+    sized_handler = base_handler.with_element_size_fn(size_fn)
+
+    model = sized_handler.load_model()
+    self.assertIsInstance(model, FakeModel)
+
+    result = list(sized_handler.run_inference([1, 2], model))
+    expected = [2, 3]  # FakeModel.predict(x) = x + 1
+    self.assertEqual(result, expected)
+
+    self.assertEqual(sized_handler.get_metrics_namespace(), 'RunInference')
+
+  def test_multiple_wrappers_can_be_chained(self):
+    """Sizing can be chained with other ModelHandler wrappers."""
+    base_handler = FakeModelHandlerForSizing()
+    preprocess_fn = lambda x: x * 10
+    size_fn = lambda x: 5
+
+    chained_handler = (
+        base_handler.with_preprocess_fn(preprocess_fn).with_element_size_fn(
+            size_fn))
+
+    kwargs = chained_handler.batch_elements_kwargs()
+    self.assertIn('element_size_fn', kwargs)
+    self.assertEqual(kwargs['element_size_fn'](1), 5)
+
+  def test_sizing_with_edge_cases(self):
+    """Allows extreme values from element_size_fn."""
+    base_handler = FakeModelHandlerForSizing(max_batch_size=1)
+
+    zero_size_fn = lambda x: 0
+    sized_handler = base_handler.with_element_size_fn(zero_size_fn)
+    kwargs = sized_handler.batch_elements_kwargs()
+    self.assertEqual(kwargs['element_size_fn'](999), 0)
+
+    large_size_fn = lambda x: 1000000
+    sized_handler = base_handler.with_element_size_fn(large_size_fn)
+    kwargs = sized_handler.batch_elements_kwargs()
+    self.assertEqual(kwargs['element_size_fn'](1), 1000000)
+
+
 if __name__ == '__main__':
   unittest.main()
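
To make the downstream effect concrete, a standalone sketch of the batching stage that the new kwarg drives. It assumes, per the with_element_size_fn docstring in this commit, that beam.BatchElements accepts an element_size_fn argument; how element sizes interact with the other batching knobs is defined by BatchElements itself, so treat this as illustrative rather than authoritative.

import apache_beam as beam

# Hedged sketch: group strings by summed token count rather than by element
# count, assuming BatchElements supports element_size_fn as described above.
with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create(['a b', 'c d e f', 'g', 'h i j'])
      | beam.BatchElements(
          max_batch_size=6,
          element_size_fn=lambda s: len(s.split()))
      | beam.Map(print))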
