52 changes: 45 additions & 7 deletions sdks/python/apache_beam/ml/inference/base.py
@@ -167,10 +167,48 @@ class KeyModelPathMapping(Generic[KeyT]):

class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
"""Has the ability to load and apply an ML model."""
def __init__(self):
"""Environment variables are set using a dict named 'env_vars' before
loading the model. Child classes can accept this dict as a kwarg."""
self._env_vars = {}
def __init__(
self,
*,
min_batch_size: Optional[int] = None,
max_batch_size: Optional[int] = None,
max_batch_duration_secs: Optional[int] = None,
max_batch_weight: Optional[int] = None,
element_size_fn: Optional[Callable[[Any], int]] = None,
large_model: bool = False,
model_copies: Optional[int] = None,
**kwargs):
"""Initializes the ModelHandler.

Args:
min_batch_size: the minimum batch size to use when batching inputs.
max_batch_size: the maximum batch size to use when batching inputs.
max_batch_duration_secs: the maximum amount of time to buffer a batch
before emitting; used in streaming contexts.
max_batch_weight: the maximum weight of a batch. Requires element_size_fn.
element_size_fn: a function that returns the size (weight) of an element.
large_model: set to True if your model is large enough to run into
memory pressure if you load multiple copies. Setting this shares the
model across processes.
model_copies: the exact number of model copies to load onto your
machine; setting this also shares the model across processes.
kwargs: 'env_vars' can be used to set environment variables
before loading the model.
"""
self._env_vars = kwargs.get('env_vars', {})
self._batching_kwargs: dict[str, Any] = {}
if min_batch_size is not None:
self._batching_kwargs['min_batch_size'] = min_batch_size
if max_batch_size is not None:
self._batching_kwargs['max_batch_size'] = max_batch_size
if max_batch_duration_secs is not None:
self._batching_kwargs['max_batch_duration_secs'] = max_batch_duration_secs
if max_batch_weight is not None:
self._batching_kwargs['max_batch_weight'] = max_batch_weight
if element_size_fn is not None:
self._batching_kwargs['element_size_fn'] = element_size_fn
Comment on lines +198 to +208
medium

This series of if statements to populate _batching_kwargs can be made more concise and easier to maintain, especially if more parameters are added in the future.

    batching_params = {
        'min_batch_size': min_batch_size,
        'max_batch_size': max_batch_size,
        'max_batch_duration_secs': max_batch_duration_secs,
        'max_batch_weight': max_batch_weight,
        'element_size_fn': element_size_fn,
    }
    self._batching_kwargs: dict[str, Any] = {
        k: v for k, v in batching_params.items() if v is not None
    }

self._large_model = large_model
self._model_copies = model_copies
self._share_across_processes = large_model or (model_copies is not None)

def load_model(self) -> ModelT:
"""Loads and initializes a model for processing."""
@@ -220,7 +258,7 @@ def batch_elements_kwargs(self) -> Mapping[str, Any]:
Returns:
kwargs suitable for beam.BatchElements.
"""
return {}
return getattr(self, '_batching_kwargs', {})

def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
"""
@@ -325,14 +363,14 @@ def share_model_across_processes(self) -> bool:
memory. Multi-process support may vary by runner, but this will fall back to
loading per process as necessary. See
https://beam.apache.org/releases/pydoc/current/apache_beam.utils.multi_process_shared.html"""
return False
return getattr(self, '_share_across_processes', False)

def model_copies(self) -> int:
"""Returns the maximum number of model copies that should be loaded at one
time. This only impacts model handlers that are using
share_model_across_processes to share their model across processes instead
of being loaded per process."""
return 1
return getattr(self, '_model_copies', None) or 1

def override_metrics(self, metrics_namespace: str = '') -> bool:
"""Returns a boolean representing whether or not a model handler will
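
To make the new constructor contract concrete, here is a minimal sketch of a custom handler that leans on the base-class batching kwargs rather than overriding batch_elements_kwargs(). The class and model names are made up for illustration; the keyword arguments and the two assertions follow directly from the __init__, batch_elements_kwargs, and share_model_across_processes changes shown above.

    # Minimal sketch (not part of this PR): a handler that relies on the new
    # base-class batching kwargs instead of overriding batch_elements_kwargs.
    from typing import Any, Optional, Sequence

    from apache_beam.ml.inference.base import ModelHandler, PredictionResult


    class MyModel:  # hypothetical stand-in for a real model object
      def predict(self, x: int) -> int:
        return x * 2


    class MyHandler(ModelHandler[int, PredictionResult, MyModel]):
      def __init__(self):
        # Weight-based batching: every element weighs 10, each batch is capped
        # at a total weight of 100, and the model is shared across processes.
        super().__init__(
            min_batch_size=2,
            max_batch_size=16,
            max_batch_weight=100,
            element_size_fn=lambda element: 10,
            large_model=True)

      def load_model(self) -> MyModel:
        return MyModel()

      def run_inference(self, batch: Sequence[int], model: MyModel,
                        inference_args: Optional[dict[str, Any]] = None):
        return [PredictionResult(x, model.predict(x)) for x in batch]


    handler = MyHandler()
    assert handler.batch_elements_kwargs()['max_batch_weight'] == 100
    assert handler.share_model_across_processes()
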
146 changes: 146 additions & 0 deletions sdks/python/apache_beam/ml/inference/base_test.py
@@ -2133,5 +2133,151 @@ def request(self, batch, model, inference_args=None):
model_handler.run_inference([1], FakeModel())


class FakeModelHandlerForSizing(base.ModelHandler[int, int, FakeModel]):
"""A ModelHandler used to test element sizing behavior."""
def __init__(
self,
max_batch_size: int = 10,
max_batch_weight: Optional[int] = None,
element_size_fn=None):
super().__init__(
max_batch_size=max_batch_size,
max_batch_weight=max_batch_weight,
element_size_fn=element_size_fn)

def load_model(self) -> FakeModel:
return FakeModel()

def run_inference(self, batch, model, inference_args=None):
return [model.predict(x) for x in batch]


class RunInferenceSizeTest(unittest.TestCase):
"""Tests for ModelHandler.batch_elements_kwargs with element_size_fn."""
def test_kwargs_are_passed_correctly(self):
"""Adds element_size_fn without clobbering existing kwargs."""
def size_fn(x):
return 10

sized_handler = FakeModelHandlerForSizing(
max_batch_size=20, max_batch_weight=100, element_size_fn=size_fn)

kwargs = sized_handler.batch_elements_kwargs()

self.assertEqual(kwargs['max_batch_size'], 20)
self.assertEqual(kwargs['max_batch_weight'], 100)
self.assertIn('element_size_fn', kwargs)
self.assertEqual(kwargs['element_size_fn'](1), 10)

def test_sizing_with_edge_cases(self):
"""Allows extreme values from element_size_fn."""
zero_size_fn = lambda x: 0
sized_handler = FakeModelHandlerForSizing(
max_batch_size=1, element_size_fn=zero_size_fn)
kwargs = sized_handler.batch_elements_kwargs()
self.assertEqual(kwargs['element_size_fn'](999), 0)

large_size_fn = lambda x: 1000000
sized_handler = FakeModelHandlerForSizing(
max_batch_size=1, element_size_fn=large_size_fn)
kwargs = sized_handler.batch_elements_kwargs()
self.assertEqual(kwargs['element_size_fn'](1), 1000000)


class FakeModelHandlerForBatching(base.ModelHandler[int, int, FakeModel]):
"""A ModelHandler used to test batching behavior via base class __init__."""
def __init__(self, **kwargs):
super().__init__(**kwargs)

def load_model(self) -> FakeModel:
return FakeModel()

def run_inference(self, batch, model, inference_args=None):
return [model.predict(x) for x in batch]


class ModelHandlerBatchingArgsTest(unittest.TestCase):
"""Tests for ModelHandler.__init__ batching parameters."""
def test_batch_elements_kwargs_all_args(self):
"""All batching args passed to __init__ are in batch_elements_kwargs."""
def size_fn(x):
return 10

handler = FakeModelHandlerForBatching(
min_batch_size=5,
max_batch_size=20,
max_batch_duration_secs=30,
max_batch_weight=100,
element_size_fn=size_fn)

kwargs = handler.batch_elements_kwargs()

self.assertEqual(kwargs['min_batch_size'], 5)
self.assertEqual(kwargs['max_batch_size'], 20)
self.assertEqual(kwargs['max_batch_duration_secs'], 30)
self.assertEqual(kwargs['max_batch_weight'], 100)
self.assertIn('element_size_fn', kwargs)
self.assertEqual(kwargs['element_size_fn'](1), 10)

def test_batch_elements_kwargs_partial_args(self):
"""Only provided batching args are included in kwargs."""
handler = FakeModelHandlerForBatching(max_batch_size=50)
kwargs = handler.batch_elements_kwargs()

self.assertEqual(kwargs, {'max_batch_size': 50})

def test_batch_elements_kwargs_empty_when_no_args(self):
"""No batching kwargs when none are provided."""
handler = FakeModelHandlerForBatching()
kwargs = handler.batch_elements_kwargs()

self.assertEqual(kwargs, {})

def test_large_model_sets_share_across_processes(self):
"""Setting large_model=True enables share_model_across_processes."""
handler = FakeModelHandlerForBatching(large_model=True)

self.assertTrue(handler.share_model_across_processes())

def test_model_copies_sets_share_across_processes(self):
"""Setting model_copies enables share_model_across_processes."""
handler = FakeModelHandlerForBatching(model_copies=2)

self.assertTrue(handler.share_model_across_processes())
self.assertEqual(handler.model_copies(), 2)

def test_default_share_across_processes_is_false(self):
"""Default share_model_across_processes is False."""
handler = FakeModelHandlerForBatching()

self.assertFalse(handler.share_model_across_processes())

def test_default_model_copies_is_one(self):
"""Default model_copies is 1."""
handler = FakeModelHandlerForBatching()

self.assertEqual(handler.model_copies(), 1)

def test_env_vars_from_kwargs(self):
"""Environment variables can be passed via kwargs."""
handler = FakeModelHandlerForBatching(env_vars={'MY_VAR': 'value'})

self.assertEqual(handler._env_vars, {'MY_VAR': 'value'})

def test_min_batch_size_only(self):
"""min_batch_size can be passed alone."""
handler = FakeModelHandlerForBatching(min_batch_size=10)
kwargs = handler.batch_elements_kwargs()

self.assertEqual(kwargs, {'min_batch_size': 10})

def test_max_batch_duration_secs_only(self):
"""max_batch_duration_secs can be passed alone."""
handler = FakeModelHandlerForBatching(max_batch_duration_secs=60)
kwargs = handler.batch_elements_kwargs()

self.assertEqual(kwargs, {'max_batch_duration_secs': 60})


if __name__ == '__main__':
unittest.main()
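
For context on how these kwargs are consumed, a short usage sketch: RunInference forwards batch_elements_kwargs() to beam.BatchElements, so the weight-based settings below only take effect if the installed BatchElements transform accepts max_batch_weight and element_size_fn, which this change presumes. The handler class is the FakeModelHandlerForSizing defined in the tests above.

    # Sketch only: reuses FakeModelHandlerForSizing defined in the tests above.
    # With a per-element weight of 10 and max_batch_weight=20, BatchElements
    # should emit batches of at most two elements.
    import apache_beam as beam
    from apache_beam.ml.inference.base import RunInference

    with beam.Pipeline() as p:
      _ = (
          p
          | beam.Create([1, 2, 3, 4, 5])
          | RunInference(
              FakeModelHandlerForSizing(
                  max_batch_size=4,
                  max_batch_weight=20,
                  element_size_fn=lambda x: 10)))
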
16 changes: 14 additions & 2 deletions sdks/python/apache_beam/ml/inference/gemini_inference.py
@@ -112,6 +112,8 @@ def __init__(
min_batch_size: Optional[int] = None,
max_batch_size: Optional[int] = None,
max_batch_duration_secs: Optional[int] = None,
max_batch_weight: Optional[int] = None,
element_size_fn: Optional[Callable[[Any], int]] = None,
**kwargs):
"""Implementation of the ModelHandler interface for Google Gemini.
**NOTE:** This API and its implementation are under development and
@@ -134,15 +136,18 @@ def __init__(
project: the GCP project to use for Vertex AI requests. Setting this
parameter routes requests to Vertex AI. If this parameter is provided,
location must also be provided and api_key should not be set.
location: the GCP location to use for Vertex AI requests. Setting this
parameter routes requests to Vertex AI. If this parameter is provided,
project must also be provided and api_key should not be set.
min_batch_size: optional. the minimum batch size to use when batching
inputs.
max_batch_size: optional. the maximum batch size to use when batching
inputs.
max_batch_duration_secs: optional. the maximum amount of time to buffer
a batch before emitting; used in streaming contexts.
max_batch_weight: optional. the maximum total weight of a batch.
element_size_fn: optional. a function that returns the size (weight)
of an element.
"""
self._batching_kwargs = {}
self._env_vars = kwargs.get('env_vars', {})
@@ -152,6 +157,10 @@ def __init__(
self._batching_kwargs["max_batch_size"] = max_batch_size
if max_batch_duration_secs is not None:
self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
if max_batch_weight is not None:
self._batching_kwargs["max_batch_weight"] = max_batch_weight
if element_size_fn is not None:
self._batching_kwargs['element_size_fn'] = element_size_fn

self.model_name = model_name
self.request_fn = request_fn
@@ -174,6 +183,9 @@ def __init__(
retry_filter=_retry_on_appropriate_service_error,
**kwargs)

def batch_elements_kwargs(self):
return self._batching_kwargs

def create_client(self) -> genai.Client:
"""Creates the GenAI client used to send requests. Creates a version for
the Vertex AI API or the Gemini Developer API based on the arguments
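
As a usage sketch for the new Gemini parameters, assuming the handler class in this module is GeminiModelHandler: the constructor keywords mirror the ones shown in this diff, while the request function, its signature, and the model id are illustrative assumptions rather than values taken from the module.

    # Illustrative sketch, not from this PR: weight-based batching for Gemini
    # requests, where a prompt's weight is its length in characters.
    from apache_beam.ml.inference.gemini_inference import GeminiModelHandler


    def call_gemini(model_name, batch, model, inference_args):
      # Hypothetical request function; the exact signature expected by the
      # handler may differ, so adapt this to the module's real contract.
      return [
          model.models.generate_content(model=model_name, contents=prompt)
          for prompt in batch
      ]


    handler = GeminiModelHandler(
        model_name='gemini-2.0-flash',  # assumed model id, for illustration
        request_fn=call_gemini,
        api_key='YOUR_API_KEY',  # or project=/location= to route via Vertex AI
        max_batch_weight=4000,  # cap the summed prompt weight per batch
        element_size_fn=len)  # weight of a prompt = its character count
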