diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index e0f870669f7f..ad2e2f8d0e3c 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -167,10 +167,48 @@ class KeyModelPathMapping(Generic[KeyT]): class ModelHandler(Generic[ExampleT, PredictionT, ModelT]): """Has the ability to load and apply an ML model.""" - def __init__(self): - """Environment variables are set using a dict named 'env_vars' before - loading the model. Child classes can accept this dict as a kwarg.""" - self._env_vars = {} + def __init__( + self, + *, + min_batch_size: Optional[int] = None, + max_batch_size: Optional[int] = None, + max_batch_duration_secs: Optional[int] = None, + max_batch_weight: Optional[int] = None, + element_size_fn: Optional[Callable[[Any], int]] = None, + large_model: bool = False, + model_copies: Optional[int] = None, + **kwargs): + """Initializes the ModelHandler. + + Args: + min_batch_size: the minimum batch size to use when batching inputs. + max_batch_size: the maximum batch size to use when batching inputs. + max_batch_duration_secs: the maximum amount of time to buffer a batch + before emitting; used in streaming contexts. + max_batch_weight: the maximum weight of a batch. Requires element_size_fn. + element_size_fn: a function that returns the size (weight) of an element. + large_model: set to true if your model is large enough to run into + memory pressure if you load multiple copies. + model_copies: The exact number of models that you would like loaded + onto your machine. + kwargs: 'env_vars' can be used to set environment variables + before loading the model. + """ + self._env_vars = kwargs.get('env_vars', {}) + self._batching_kwargs: dict[str, Any] = {} + if min_batch_size is not None: + self._batching_kwargs['min_batch_size'] = min_batch_size + if max_batch_size is not None: + self._batching_kwargs['max_batch_size'] = max_batch_size + if max_batch_duration_secs is not None: + self._batching_kwargs['max_batch_duration_secs'] = max_batch_duration_secs + if max_batch_weight is not None: + self._batching_kwargs['max_batch_weight'] = max_batch_weight + if element_size_fn is not None: + self._batching_kwargs['element_size_fn'] = element_size_fn + self._large_model = large_model + self._model_copies = model_copies + self._share_across_processes = large_model or (model_copies is not None) def load_model(self) -> ModelT: """Loads and initializes a model for processing.""" @@ -220,7 +258,7 @@ def batch_elements_kwargs(self) -> Mapping[str, Any]: Returns: kwargs suitable for beam.BatchElements. """ - return {} + return getattr(self, '_batching_kwargs', {}) def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): """ @@ -325,14 +363,14 @@ def share_model_across_processes(self) -> bool: memory. Multi-process support may vary by runner, but this will fallback to loading per process as necessary. See https://beam.apache.org/releases/pydoc/current/apache_beam.utils.multi_process_shared.html""" - return False + return getattr(self, '_share_across_processes', False) def model_copies(self) -> int: """Returns the maximum number of model copies that should be loaded at one time. 
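Illustrative usage of the new base-class plumbing (not part of the diff): a subclass can forward all batching and model-sharing options to ModelHandler.__init__ and inherit batch_elements_kwargs(), share_model_across_processes() and model_copies() for free. A minimal sketch; WordCountModel and WordCountHandler are hypothetical names introduced only for this example.

from collections.abc import Sequence

from apache_beam.ml.inference.base import ModelHandler


class WordCountModel:
  def predict(self, example: str) -> int:
    return len(example.split())


class WordCountHandler(ModelHandler[str, int, WordCountModel]):
  # No __init__ needed: the base class handles batching/model-sharing kwargs.
  def load_model(self) -> WordCountModel:
    return WordCountModel()

  def run_inference(self, batch: Sequence[str], model, inference_args=None):
    return [model.predict(example) for example in batch]


handler = WordCountHandler(
    max_batch_size=32,
    max_batch_weight=4096,                      # cap the total batch weight
    element_size_fn=lambda s: len(s.encode()))  # weight = UTF-8 byte size
assert handler.batch_elements_kwargs()['max_batch_weight'] == 4096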
This only impacts model handlers that are using share_model_across_processes to share their model across processes instead of being loaded per process.""" - return 1 + return getattr(self, '_model_copies', None) or 1 def override_metrics(self, metrics_namespace: str = '') -> bool: """Returns a boolean representing whether or not a model handler will diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index 381bf5456604..55784166ad5d 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -2133,5 +2133,151 @@ def request(self, batch, model, inference_args=None): model_handler.run_inference([1], FakeModel()) +class FakeModelHandlerForSizing(base.ModelHandler[int, int, FakeModel]): + """A ModelHandler used to test element sizing behavior.""" + def __init__( + self, + max_batch_size: int = 10, + max_batch_weight: Optional[int] = None, + element_size_fn=None): + super().__init__( + max_batch_size=max_batch_size, + max_batch_weight=max_batch_weight, + element_size_fn=element_size_fn) + + def load_model(self) -> FakeModel: + return FakeModel() + + def run_inference(self, batch, model, inference_args=None): + return [model.predict(x) for x in batch] + + +class RunInferenceSizeTest(unittest.TestCase): + """Tests for ModelHandler.batch_elements_kwargs with element_size_fn.""" + def test_kwargs_are_passed_correctly(self): + """Adds element_size_fn without clobbering existing kwargs.""" + def size_fn(x): + return 10 + + sized_handler = FakeModelHandlerForSizing( + max_batch_size=20, max_batch_weight=100, element_size_fn=size_fn) + + kwargs = sized_handler.batch_elements_kwargs() + + self.assertEqual(kwargs['max_batch_size'], 20) + self.assertEqual(kwargs['max_batch_weight'], 100) + self.assertIn('element_size_fn', kwargs) + self.assertEqual(kwargs['element_size_fn'](1), 10) + + def test_sizing_with_edge_cases(self): + """Allows extreme values from element_size_fn.""" + zero_size_fn = lambda x: 0 + sized_handler = FakeModelHandlerForSizing( + max_batch_size=1, element_size_fn=zero_size_fn) + kwargs = sized_handler.batch_elements_kwargs() + self.assertEqual(kwargs['element_size_fn'](999), 0) + + large_size_fn = lambda x: 1000000 + sized_handler = FakeModelHandlerForSizing( + max_batch_size=1, element_size_fn=large_size_fn) + kwargs = sized_handler.batch_elements_kwargs() + self.assertEqual(kwargs['element_size_fn'](1), 1000000) + + +class FakeModelHandlerForBatching(base.ModelHandler[int, int, FakeModel]): + """A ModelHandler used to test batching behavior via base class __init__.""" + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def load_model(self) -> FakeModel: + return FakeModel() + + def run_inference(self, batch, model, inference_args=None): + return [model.predict(x) for x in batch] + + +class ModelHandlerBatchingArgsTest(unittest.TestCase): + """Tests for ModelHandler.__init__ batching parameters.""" + def test_batch_elements_kwargs_all_args(self): + """All batching args passed to __init__ are in batch_elements_kwargs.""" + def size_fn(x): + return 10 + + handler = FakeModelHandlerForBatching( + min_batch_size=5, + max_batch_size=20, + max_batch_duration_secs=30, + max_batch_weight=100, + element_size_fn=size_fn) + + kwargs = handler.batch_elements_kwargs() + + self.assertEqual(kwargs['min_batch_size'], 5) + self.assertEqual(kwargs['max_batch_size'], 20) + self.assertEqual(kwargs['max_batch_duration_secs'], 30) + 
self.assertEqual(kwargs['max_batch_weight'], 100) + self.assertIn('element_size_fn', kwargs) + self.assertEqual(kwargs['element_size_fn'](1), 10) + + def test_batch_elements_kwargs_partial_args(self): + """Only provided batching args are included in kwargs.""" + handler = FakeModelHandlerForBatching(max_batch_size=50) + kwargs = handler.batch_elements_kwargs() + + self.assertEqual(kwargs, {'max_batch_size': 50}) + + def test_batch_elements_kwargs_empty_when_no_args(self): + """No batching kwargs when none are provided.""" + handler = FakeModelHandlerForBatching() + kwargs = handler.batch_elements_kwargs() + + self.assertEqual(kwargs, {}) + + def test_large_model_sets_share_across_processes(self): + """Setting large_model=True enables share_model_across_processes.""" + handler = FakeModelHandlerForBatching(large_model=True) + + self.assertTrue(handler.share_model_across_processes()) + + def test_model_copies_sets_share_across_processes(self): + """Setting model_copies enables share_model_across_processes.""" + handler = FakeModelHandlerForBatching(model_copies=2) + + self.assertTrue(handler.share_model_across_processes()) + self.assertEqual(handler.model_copies(), 2) + + def test_default_share_across_processes_is_false(self): + """Default share_model_across_processes is False.""" + handler = FakeModelHandlerForBatching() + + self.assertFalse(handler.share_model_across_processes()) + + def test_default_model_copies_is_one(self): + """Default model_copies is 1.""" + handler = FakeModelHandlerForBatching() + + self.assertEqual(handler.model_copies(), 1) + + def test_env_vars_from_kwargs(self): + """Environment variables can be passed via kwargs.""" + handler = FakeModelHandlerForBatching(env_vars={'MY_VAR': 'value'}) + + self.assertEqual(handler._env_vars, {'MY_VAR': 'value'}) + + def test_min_batch_size_only(self): + """min_batch_size can be passed alone.""" + handler = FakeModelHandlerForBatching(min_batch_size=10) + kwargs = handler.batch_elements_kwargs() + + self.assertEqual(kwargs, {'min_batch_size': 10}) + + def test_max_batch_duration_secs_only(self): + """max_batch_duration_secs can be passed alone.""" + handler = FakeModelHandlerForBatching(max_batch_duration_secs=60) + kwargs = handler.batch_elements_kwargs() + + self.assertEqual(kwargs, {'max_batch_duration_secs': 60}) + + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/ml/inference/gemini_inference.py b/sdks/python/apache_beam/ml/inference/gemini_inference.py index c840efedd8fd..a79fbe8a555f 100644 --- a/sdks/python/apache_beam/ml/inference/gemini_inference.py +++ b/sdks/python/apache_beam/ml/inference/gemini_inference.py @@ -112,6 +112,8 @@ def __init__( min_batch_size: Optional[int] = None, max_batch_size: Optional[int] = None, max_batch_duration_secs: Optional[int] = None, + max_batch_weight: Optional[int] = None, + element_size_fn: Optional[Callable[[Any], int]] = None, **kwargs): """Implementation of the ModelHandler interface for Google Gemini. **NOTE:** This API and its implementation are under development and @@ -134,15 +136,18 @@ def __init__( project: the GCP project to use for Vertex AI requests. Setting this parameter routes requests to Vertex AI. If this paramter is provided, location must also be provided and api_key should not be set. - location: the GCP project to use for Vertex AI requests. Setting this + location: the GCP project to use for Vertex AI requests. Setting this parameter routes requests to Vertex AI. 
If this paramter is provided, project must also be provided and api_key should not be set. min_batch_size: optional. the minimum batch size to use when batching inputs. max_batch_size: optional. the maximum batch size to use when batching inputs. - max_batch_duration_secs: optional. the maximum amount of time to buffer + max_batch_duration_secs: optional. the maximum amount of time to buffer a batch before emitting; used in streaming contexts. + max_batch_weight: optional. the maximum total weight of a batch. + element_size_fn: optional. a function that returns the size (weight) + of an element. """ self._batching_kwargs = {} self._env_vars = kwargs.get('env_vars', {}) @@ -152,6 +157,10 @@ def __init__( self._batching_kwargs["max_batch_size"] = max_batch_size if max_batch_duration_secs is not None: self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs + if max_batch_weight is not None: + self._batching_kwargs["max_batch_weight"] = max_batch_weight + if element_size_fn is not None: + self._batching_kwargs['element_size_fn'] = element_size_fn self.model_name = model_name self.request_fn = request_fn @@ -174,6 +183,9 @@ def __init__( retry_filter=_retry_on_appropriate_service_error, **kwargs) + def batch_elements_kwargs(self): + return self._batching_kwargs + def create_client(self) -> genai.Client: """Creates the GenAI client used to send requests. Creates a version for the Vertex AI API or the Gemini Developer API based on the arguments diff --git a/sdks/python/apache_beam/ml/inference/huggingface_inference.py b/sdks/python/apache_beam/ml/inference/huggingface_inference.py index 501a019c378e..2c1f5e2cc908 100644 --- a/sdks/python/apache_beam/ml/inference/huggingface_inference.py +++ b/sdks/python/apache_beam/ml/inference/huggingface_inference.py @@ -227,6 +227,8 @@ def __init__( max_batch_duration_secs: Optional[int] = None, large_model: bool = False, model_copies: Optional[int] = None, + max_batch_weight: Optional[int] = None, + element_size_fn: Optional[Callable[[Any], int]] = None, **kwargs): """ Implementation of the ModelHandler interface for HuggingFace with @@ -262,27 +264,28 @@ def __init__( model_copies: The exact number of models that you would like loaded onto your machine. This can be useful if you exactly know your CPU or GPU capacity and want to maximize resource utilization. + max_batch_weight: the maximum total weight of a batch. + element_size_fn: a function that returns the size (weight) of an element. kwargs: 'env_vars' can be used to set environment variables before loading the model. **Supported Versions:** HuggingFaceModelHandler supports transformers>=4.18.0. 
""" + super().__init__( + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + max_batch_duration_secs=max_batch_duration_secs, + max_batch_weight=max_batch_weight, + element_size_fn=element_size_fn, + large_model=large_model, + model_copies=model_copies, + **kwargs) self._model_uri = model_uri self._model_class = model_class self._device = device self._inference_fn = inference_fn self._model_config_args = load_model_args if load_model_args else {} - self._batching_kwargs = {} - self._env_vars = kwargs.get("env_vars", {}) - if min_batch_size is not None: - self._batching_kwargs["min_batch_size"] = min_batch_size - if max_batch_size is not None: - self._batching_kwargs["max_batch_size"] = max_batch_size - if max_batch_duration_secs is not None: - self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs - self._share_across_processes = large_model or (model_copies is not None) - self._model_copies = model_copies or 1 self._framework = framework _validate_constructor_args( @@ -352,15 +355,6 @@ def get_num_bytes( return sum( (el.element_size() for tensor in batch for el in tensor.values())) - def batch_elements_kwargs(self): - return self._batching_kwargs - - def share_model_across_processes(self) -> bool: - return self._share_across_processes - - def model_copies(self) -> int: - return self._model_copies - def get_metrics_namespace(self) -> str: """ Returns: @@ -415,6 +409,8 @@ def __init__( max_batch_duration_secs: Optional[int] = None, large_model: bool = False, model_copies: Optional[int] = None, + max_batch_weight: Optional[int] = None, + element_size_fn: Optional[Callable[[Any], int]] = None, **kwargs): """ Implementation of the ModelHandler interface for HuggingFace with @@ -450,27 +446,28 @@ def __init__( model_copies: The exact number of models that you would like loaded onto your machine. This can be useful if you exactly know your CPU or GPU capacity and want to maximize resource utilization. + max_batch_weight: the maximum total weight of a batch. + element_size_fn: a function that returns the size (weight) of an element. kwargs: 'env_vars' can be used to set environment variables before loading the model. **Supported Versions:** HuggingFaceModelHandler supports transformers>=4.18.0. 
""" + super().__init__( + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + max_batch_duration_secs=max_batch_duration_secs, + max_batch_weight=max_batch_weight, + element_size_fn=element_size_fn, + large_model=large_model, + model_copies=model_copies, + **kwargs) self._model_uri = model_uri self._model_class = model_class self._device = device self._inference_fn = inference_fn self._model_config_args = load_model_args if load_model_args else {} - self._batching_kwargs = {} - self._env_vars = kwargs.get("env_vars", {}) - if min_batch_size is not None: - self._batching_kwargs["min_batch_size"] = min_batch_size - if max_batch_size is not None: - self._batching_kwargs["max_batch_size"] = max_batch_size - if max_batch_duration_secs is not None: - self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs - self._share_across_processes = large_model or (model_copies is not None) - self._model_copies = model_copies or 1 self._framework = "" _validate_constructor_args( @@ -547,15 +544,6 @@ def get_num_bytes( return sum( (el.element_size() for tensor in batch for el in tensor.values())) - def batch_elements_kwargs(self): - return self._batching_kwargs - - def share_model_across_processes(self) -> bool: - return self._share_across_processes - - def model_copies(self) -> int: - return self._model_copies - def get_metrics_namespace(self) -> str: """ Returns: @@ -586,6 +574,8 @@ def __init__( max_batch_duration_secs: Optional[int] = None, large_model: bool = False, model_copies: Optional[int] = None, + max_batch_weight: Optional[int] = None, + element_size_fn: Optional[Callable[[Any], int]] = None, **kwargs): """ Implementation of the ModelHandler interface for Hugging Face Pipelines. @@ -629,27 +619,28 @@ def __init__( model_copies: The exact number of models that you would like loaded onto your machine. This can be useful if you exactly know your CPU or GPU capacity and want to maximize resource utilization. + max_batch_weight: the maximum total weight of a batch. + element_size_fn: a function that returns the size (weight) of an element. kwargs: 'env_vars' can be used to set environment variables before loading the model. **Supported Versions:** HuggingFacePipelineModelHandler supports transformers>=4.18.0. """ + super().__init__( + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + max_batch_duration_secs=max_batch_duration_secs, + max_batch_weight=max_batch_weight, + element_size_fn=element_size_fn, + large_model=large_model, + model_copies=model_copies, + **kwargs) self._task = task self._model = model self._inference_fn = inference_fn self._load_pipeline_args = load_pipeline_args if load_pipeline_args else {} - self._batching_kwargs = {} self._framework = "pt" - self._env_vars = kwargs.get('env_vars', {}) - if min_batch_size is not None: - self._batching_kwargs['min_batch_size'] = min_batch_size - if max_batch_size is not None: - self._batching_kwargs['max_batch_size'] = max_batch_size - if max_batch_duration_secs is not None: - self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs - self._share_across_processes = large_model or (model_copies is not None) - self._model_copies = model_copies or 1 # Check if the device is specified twice. If true then the device parameter # of model handler is overridden. 
@@ -726,15 +717,6 @@ def get_num_bytes(self, batch: Sequence[str]) -> int: """ return sum(sys.getsizeof(element) for element in batch) - def batch_elements_kwargs(self): - return self._batching_kwargs - - def share_model_across_processes(self) -> bool: - return self._share_across_processes - - def model_copies(self) -> int: - return self._model_copies - def get_metrics_namespace(self) -> str: """ Returns: diff --git a/sdks/python/apache_beam/ml/inference/onnx_inference.py b/sdks/python/apache_beam/ml/inference/onnx_inference.py index 3485866f11c3..4423eed2e407 100644 --- a/sdks/python/apache_beam/ml/inference/onnx_inference.py +++ b/sdks/python/apache_beam/ml/inference/onnx_inference.py @@ -17,7 +17,6 @@ from collections.abc import Callable from collections.abc import Iterable -from collections.abc import Mapping from collections.abc import Sequence from typing import Any from typing import Optional @@ -67,6 +66,8 @@ def __init__( #pylint: disable=dangerous-default-value min_batch_size: Optional[int] = None, max_batch_size: Optional[int] = None, max_batch_duration_secs: Optional[int] = None, + max_batch_weight: Optional[int] = None, + element_size_fn: Optional[Callable[[Any], int]] = None, **kwargs): """ Implementation of the ModelHandler interface for onnx using numpy arrays as input. @@ -91,24 +92,25 @@ def __init__( #pylint: disable=dangerous-default-value max_batch_size: the maximum batch size to use when batching inputs. max_batch_duration_secs: the maximum amount of time to buffer a batch before emitting; used in streaming contexts. + max_batch_weight: the maximum total weight of a batch. + element_size_fn: a function that returns the size (weight) of an element. kwargs: 'env_vars' can be used to set environment variables before loading the model. """ + super().__init__( + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + max_batch_duration_secs=max_batch_duration_secs, + max_batch_weight=max_batch_weight, + element_size_fn=element_size_fn, + large_model=large_model, + model_copies=model_copies, + **kwargs) self._model_uri = model_uri self._session_options = session_options self._providers = providers self._provider_options = provider_options self._model_inference_fn = inference_fn - self._env_vars = kwargs.get('env_vars', {}) - self._share_across_processes = large_model or (model_copies is not None) - self._model_copies = model_copies or 1 - self._batching_kwargs = {} - if min_batch_size is not None: - self._batching_kwargs["min_batch_size"] = min_batch_size - if max_batch_size is not None: - self._batching_kwargs["max_batch_size"] = max_batch_size - if max_batch_duration_secs is not None: - self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs def load_model(self) -> ort.InferenceSession: """Loads and initializes an onnx inference session for processing.""" @@ -167,12 +169,3 @@ def get_metrics_namespace(self) -> str: A namespace for metrics collected by the RunInference transform. 
""" return 'BeamML_Onnx' - - def share_model_across_processes(self) -> bool: - return self._share_across_processes - - def model_copies(self) -> int: - return self._model_copies - - def batch_elements_kwargs(self) -> Mapping[str, Any]: - return self._batching_kwargs diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py index affbcd977f5c..63c2a116fcc9 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py @@ -197,6 +197,8 @@ def __init__( large_model: bool = False, model_copies: Optional[int] = None, load_model_args: Optional[dict[str, Any]] = None, + max_batch_weight: Optional[int] = None, + element_size_fn: Optional[Callable[[Any], int]] = None, **kwargs): """Implementation of the ModelHandler interface for PyTorch. @@ -240,12 +242,23 @@ def __init__( GPU capacity and want to maximize resource utilization. load_model_args: a dictionary of parameters passed to the torch.load function to specify custom config for loading models. + max_batch_weight: the maximum total weight of a batch. + element_size_fn: a function that returns the size (weight) of an element. kwargs: 'env_vars' can be used to set environment variables before loading the model. **Supported Versions:** RunInference APIs in Apache Beam have been tested with PyTorch 1.9 and 1.10. """ + super().__init__( + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + max_batch_duration_secs=max_batch_duration_secs, + max_batch_weight=max_batch_weight, + element_size_fn=element_size_fn, + large_model=large_model, + model_copies=model_copies, + **kwargs) self._state_dict_path = state_dict_path if device == 'GPU': logging.info("Device is set to CUDA") @@ -256,18 +269,8 @@ def __init__( self._model_class = model_class self._model_params = model_params if model_params else {} self._inference_fn = inference_fn - self._batching_kwargs = {} - if min_batch_size is not None: - self._batching_kwargs['min_batch_size'] = min_batch_size - if max_batch_size is not None: - self._batching_kwargs['max_batch_size'] = max_batch_size - if max_batch_duration_secs is not None: - self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs self._torch_script_model_path = torch_script_model_path self._load_model_args = load_model_args if load_model_args else {} - self._env_vars = kwargs.get('env_vars', {}) - self._share_across_processes = large_model or (model_copies is not None) - self._model_copies = model_copies or 1 _validate_constructor_args( state_dict_path=self._state_dict_path, @@ -342,15 +345,6 @@ def get_metrics_namespace(self) -> str: """ return 'BeamML_PyTorch' - def batch_elements_kwargs(self): - return self._batching_kwargs - - def share_model_across_processes(self) -> bool: - return self._share_across_processes - - def model_copies(self) -> int: - return self._model_copies - def default_keyed_tensor_inference_fn( batch: Sequence[dict[str, torch.Tensor]], @@ -435,6 +429,8 @@ def __init__( large_model: bool = False, model_copies: Optional[int] = None, load_model_args: Optional[dict[str, Any]] = None, + max_batch_weight: Optional[int] = None, + element_size_fn: Optional[Callable[[Any], int]] = None, **kwargs): """Implementation of the ModelHandler interface for PyTorch. @@ -483,12 +479,23 @@ def __init__( GPU capacity and want to maximize resource utilization. 
load_model_args: a dictionary of parameters passed to the torch.load function to specify custom config for loading models. + max_batch_weight: the maximum total weight of a batch. + element_size_fn: a function that returns the size (weight) of an element. kwargs: 'env_vars' can be used to set environment variables before loading the model. **Supported Versions:** RunInference APIs in Apache Beam have been tested on torch>=1.9.0,<1.14.0. """ + super().__init__( + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + max_batch_duration_secs=max_batch_duration_secs, + max_batch_weight=max_batch_weight, + element_size_fn=element_size_fn, + large_model=large_model, + model_copies=model_copies, + **kwargs) self._state_dict_path = state_dict_path if device == 'GPU': logging.info("Device is set to CUDA") @@ -499,18 +506,8 @@ def __init__( self._model_class = model_class self._model_params = model_params if model_params else {} self._inference_fn = inference_fn - self._batching_kwargs = {} - if min_batch_size is not None: - self._batching_kwargs['min_batch_size'] = min_batch_size - if max_batch_size is not None: - self._batching_kwargs['max_batch_size'] = max_batch_size - if max_batch_duration_secs is not None: - self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs self._torch_script_model_path = torch_script_model_path self._load_model_args = load_model_args if load_model_args else {} - self._env_vars = kwargs.get('env_vars', {}) - self._share_across_processes = large_model or (model_copies is not None) - self._model_copies = model_copies or 1 _validate_constructor_args( state_dict_path=self._state_dict_path, @@ -586,12 +583,3 @@ def get_metrics_namespace(self) -> str: A namespace for metrics collected by the RunInference transform. """ return 'BeamML_PyTorch' - - def batch_elements_kwargs(self): - return self._batching_kwargs - - def share_model_across_processes(self) -> bool: - return self._share_across_processes - - def model_copies(self) -> int: - return self._model_copies diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py index 84947bec3dfb..e61ef9c194aa 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py +++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py @@ -93,6 +93,8 @@ def __init__( max_batch_duration_secs: Optional[int] = None, large_model: bool = False, model_copies: Optional[int] = None, + max_batch_weight: Optional[int] = None, + element_size_fn: Optional[Callable[[Any], int]] = None, **kwargs): """ Implementation of the ModelHandler interface for scikit-learn using numpy arrays as input. @@ -122,22 +124,23 @@ def __init__( model_copies: The exact number of models that you would like loaded onto your machine. This can be useful if you exactly know your CPU or GPU capacity and want to maximize resource utilization. + max_batch_weight: the maximum total weight of a batch. + element_size_fn: a function that returns the size (weight) of an element. kwargs: 'env_vars' can be used to set environment variables before loading the model. 
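Illustrative usage (not part of the diff; the model URI is a placeholder and scikit-learn must be installed): weighting each element by its in-memory byte size keeps batches of wide numpy feature rows from growing too large.

import numpy as np

from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy

handler = SklearnModelHandlerNumpy(
    model_uri='gs://my-bucket/model.pickle',   # placeholder URI
    max_batch_weight=64 * 1024 * 1024,         # ~64 MiB of features per batch
    element_size_fn=lambda arr: arr.nbytes)    # numpy arrays know their size

kwargs = handler.batch_elements_kwargs()
assert kwargs['element_size_fn'](np.zeros((10, 10))) == 800  # 100 float64s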
""" + super().__init__( + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + max_batch_duration_secs=max_batch_duration_secs, + max_batch_weight=max_batch_weight, + element_size_fn=element_size_fn, + large_model=large_model, + model_copies=model_copies, + **kwargs) self._model_uri = model_uri self._model_file_type = model_file_type self._model_inference_fn = inference_fn - self._batching_kwargs = {} - if min_batch_size is not None: - self._batching_kwargs['min_batch_size'] = min_batch_size - if max_batch_size is not None: - self._batching_kwargs['max_batch_size'] = max_batch_size - if max_batch_duration_secs is not None: - self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs - self._env_vars = kwargs.get('env_vars', {}) - self._share_across_processes = large_model or (model_copies is not None) - self._model_copies = model_copies or 1 def load_model(self) -> BaseEstimator: """Loads and initializes a model for processing.""" @@ -187,15 +190,6 @@ def get_metrics_namespace(self) -> str: """ return 'BeamML_Sklearn' - def batch_elements_kwargs(self): - return self._batching_kwargs - - def share_model_across_processes(self) -> bool: - return self._share_across_processes - - def model_copies(self) -> int: - return self._model_copies - PandasInferenceFn = Callable[ [BaseEstimator, Sequence[pandas.DataFrame], Optional[dict[str, Any]]], Any] @@ -228,6 +222,8 @@ def __init__( max_batch_duration_secs: Optional[int] = None, large_model: bool = False, model_copies: Optional[int] = None, + max_batch_weight: Optional[int] = None, + element_size_fn: Optional[Callable[[Any], int]] = None, **kwargs): """Implementation of the ModelHandler interface for scikit-learn that supports pandas dataframes. @@ -260,22 +256,23 @@ def __init__( model_copies: The exact number of models that you would like loaded onto your machine. This can be useful if you exactly know your CPU or GPU capacity and want to maximize resource utilization. + max_batch_weight: the maximum total weight of a batch. + element_size_fn: a function that returns the size (weight) of an element. kwargs: 'env_vars' can be used to set environment variables before loading the model. """ + super().__init__( + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + max_batch_duration_secs=max_batch_duration_secs, + max_batch_weight=max_batch_weight, + element_size_fn=element_size_fn, + large_model=large_model, + model_copies=model_copies, + **kwargs) self._model_uri = model_uri self._model_file_type = model_file_type self._model_inference_fn = inference_fn - self._batching_kwargs = {} - if min_batch_size is not None: - self._batching_kwargs['min_batch_size'] = min_batch_size - if max_batch_size is not None: - self._batching_kwargs['max_batch_size'] = max_batch_size - if max_batch_duration_secs is not None: - self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs - self._env_vars = kwargs.get('env_vars', {}) - self._share_across_processes = large_model or (model_copies is not None) - self._model_copies = model_copies or 1 def load_model(self) -> BaseEstimator: """Loads and initializes a model for processing.""" @@ -326,12 +323,3 @@ def get_metrics_namespace(self) -> str: A namespace for metrics collected by the RunInference transform. 
""" return 'BeamML_Sklearn' - - def batch_elements_kwargs(self): - return self._batching_kwargs - - def share_model_across_processes(self) -> bool: - return self._share_across_processes - - def model_copies(self) -> int: - return self._model_copies diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py index 5ce293a06ac0..97b74eb360a7 100644 --- a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py +++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py @@ -112,6 +112,8 @@ def __init__( max_batch_duration_secs: Optional[int] = None, large_model: bool = False, model_copies: Optional[int] = None, + max_batch_weight: Optional[int] = None, + element_size_fn: Optional[Callable[[Any], int]] = None, **kwargs): """Implementation of the ModelHandler interface for Tensorflow. @@ -140,28 +142,30 @@ def __init__( model_copies: The exact number of models that you would like loaded onto your machine. This can be useful if you exactly know your CPU or GPU capacity and want to maximize resource utilization. + max_batch_weight: the maximum total weight of a batch. + element_size_fn: a function that returns the size (weight) of an + element. kwargs: 'env_vars' can be used to set environment variables before loading the model. **Supported Versions:** RunInference APIs in Apache Beam have been tested with Tensorflow 2.9, 2.10, 2.11. """ + super().__init__( + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + max_batch_duration_secs=max_batch_duration_secs, + max_batch_weight=max_batch_weight, + element_size_fn=element_size_fn, + large_model=large_model, + model_copies=model_copies, + **kwargs) self._model_uri = model_uri self._model_type = model_type self._inference_fn = inference_fn self._create_model_fn = create_model_fn - self._env_vars = kwargs.get('env_vars', {}) self._load_model_args = {} if not load_model_args else load_model_args self._custom_weights = custom_weights - self._batching_kwargs = {} - if min_batch_size is not None: - self._batching_kwargs['min_batch_size'] = min_batch_size - if max_batch_size is not None: - self._batching_kwargs['max_batch_size'] = max_batch_size - if max_batch_duration_secs is not None: - self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs - self._share_across_processes = large_model or (model_copies is not None) - self._model_copies = model_copies or 1 def load_model(self) -> tf.Module: """Loads and initializes a Tensorflow model for processing.""" @@ -219,15 +223,6 @@ def get_metrics_namespace(self) -> str: """ return 'BeamML_TF_Numpy' - def batch_elements_kwargs(self): - return self._batching_kwargs - - def share_model_across_processes(self) -> bool: - return self._share_across_processes - - def model_copies(self) -> int: - return self._model_copies - class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult, tf.Module]): @@ -245,6 +240,8 @@ def __init__( max_batch_duration_secs: Optional[int] = None, large_model: bool = False, model_copies: Optional[int] = None, + max_batch_weight: Optional[int] = None, + element_size_fn: Optional[Callable[[Any], int]] = None, **kwargs): """Implementation of the ModelHandler interface for Tensorflow. @@ -278,28 +275,30 @@ def __init__( model_copies: The exact number of models that you would like loaded onto your machine. This can be useful if you exactly know your CPU or GPU capacity and want to maximize resource utilization. + max_batch_weight: the maximum total weight of a batch. 
+ element_size_fn: a function that returns the size (weight) of an + element. kwargs: 'env_vars' can be used to set environment variables before loading the model. **Supported Versions:** RunInference APIs in Apache Beam have been tested with Tensorflow 2.11. """ + super().__init__( + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + max_batch_duration_secs=max_batch_duration_secs, + max_batch_weight=max_batch_weight, + element_size_fn=element_size_fn, + large_model=large_model, + model_copies=model_copies, + **kwargs) self._model_uri = model_uri self._model_type = model_type self._inference_fn = inference_fn self._create_model_fn = create_model_fn - self._env_vars = kwargs.get('env_vars', {}) self._load_model_args = {} if not load_model_args else load_model_args self._custom_weights = custom_weights - self._batching_kwargs = {} - if min_batch_size is not None: - self._batching_kwargs['min_batch_size'] = min_batch_size - if max_batch_size is not None: - self._batching_kwargs['max_batch_size'] = max_batch_size - if max_batch_duration_secs is not None: - self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs - self._share_across_processes = large_model or (model_copies is not None) - self._model_copies = model_copies or 1 def load_model(self) -> tf.Module: """Loads and initializes a tensorflow model for processing.""" @@ -356,12 +355,3 @@ def get_metrics_namespace(self) -> str: A namespace for metrics collected by the RunInference transform. """ return 'BeamML_TF_Tensor' - - def batch_elements_kwargs(self): - return self._batching_kwargs - - def share_model_across_processes(self) -> bool: - return self._share_across_processes - - def model_copies(self) -> int: - return self._model_copies diff --git a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py index b575dfa849da..00a61b4934aa 100644 --- a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py +++ b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py @@ -230,6 +230,8 @@ def __init__( large_model: bool = False, model_copies: Optional[int] = None, max_batch_duration_secs: Optional[int] = None, + max_batch_weight: Optional[int] = None, + element_size_fn: Optional[Callable[[Any], int]] = None, **kwargs): """Implementation of the ModelHandler interface for TensorRT. @@ -258,6 +260,8 @@ def __init__( GPU capacity and want to maximize resource utilization. max_batch_duration_secs: the maximum amount of time to buffer a batch before emitting; used in streaming contexts. + max_batch_weight: the maximum total weight of a batch. + element_size_fn: a function that returns the size (weight) of an element. kwargs: Additional arguments like 'engine_path' and 'onnx_path' are currently supported. 'env_vars' can be used to set environment variables before loading the model. 
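Illustrative usage for the tensor-based handlers (not part of the diff; assumes TensorFlow is installed and the SavedModel path is a placeholder). One reasonable element weight for a tf.Tensor is its byte size:

import tensorflow as tf

from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor


def tensor_bytes(t: tf.Tensor) -> int:
  # dtype.size is the per-element size in bytes; tf.size is the element count.
  return int(t.dtype.size * int(tf.size(t)))


handler = TFModelHandlerTensor(
    model_uri='gs://my-bucket/saved_model',   # placeholder
    max_batch_weight=16 * 1024 * 1024,        # ~16 MiB of tensors per batch
    element_size_fn=tensor_bytes)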
@@ -265,25 +269,20 @@ def __init__( See https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/ for details """ - self.min_batch_size = min_batch_size - self.max_batch_size = max_batch_size - self.max_batch_duration_secs = max_batch_duration_secs + super().__init__( + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + max_batch_duration_secs=max_batch_duration_secs, + max_batch_weight=max_batch_weight, + element_size_fn=element_size_fn, + large_model=large_model, + model_copies=model_copies, + **kwargs) self.inference_fn = inference_fn if 'engine_path' in kwargs: self.engine_path = kwargs.get('engine_path') elif 'onnx_path' in kwargs: self.onnx_path = kwargs.get('onnx_path') - self._env_vars = kwargs.get('env_vars', {}) - self._share_across_processes = large_model or (model_copies is not None) - self._model_copies = model_copies or 1 - - def batch_elements_kwargs(self): - """Sets min_batch_size and max_batch_size of a TensorRT engine.""" - return { - 'min_batch_size': self.min_batch_size, - 'max_batch_size': self.max_batch_size, - 'max_batch_duration_secs': self.max_batch_duration_secs - } def load_model(self) -> TensorRTEngine: """Loads and initializes a TensorRT engine for processing.""" @@ -336,12 +335,6 @@ def get_metrics_namespace(self) -> str: """ return 'BeamML_TensorRT' - def share_model_across_processes(self) -> bool: - return self._share_across_processes - - def model_copies(self) -> int: - return self._model_copies - def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): """ Currently, this model handler does not support inference args. Given that, diff --git a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py index cd3d0beb593c..02827f9578f1 100644 --- a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py +++ b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py @@ -17,6 +17,7 @@ import json import logging +from collections.abc import Callable from collections.abc import Iterable from collections.abc import Mapping from collections.abc import Sequence @@ -69,6 +70,8 @@ def __init__( min_batch_size: Optional[int] = None, max_batch_size: Optional[int] = None, max_batch_duration_secs: Optional[int] = None, + max_batch_weight: Optional[int] = None, + element_size_fn: Optional[Callable[[Any], int]] = None, **kwargs): """Implementation of the ModelHandler interface for Vertex AI. **NOTE:** This API and its implementation are under development and @@ -107,8 +110,11 @@ def __init__( inputs. max_batch_size: optional. the maximum batch size to use when batching inputs. - max_batch_duration_secs: optional. the maximum amount of time to buffer + max_batch_duration_secs: optional. the maximum amount of time to buffer a batch before emitting; used in streaming contexts. + max_batch_weight: optional. the maximum total weight of a batch. + element_size_fn: optional. a function that returns the size (weight) + of an element. 
""" self._batching_kwargs = {} self._env_vars = kwargs.get('env_vars', {}) @@ -119,6 +125,10 @@ def __init__( self._batching_kwargs["max_batch_size"] = max_batch_size if max_batch_duration_secs is not None: self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs + if max_batch_weight is not None: + self._batching_kwargs["max_batch_weight"] = max_batch_weight + if element_size_fn is not None: + self._batching_kwargs['element_size_fn'] = element_size_fn if private and network is None: raise ValueError( diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference.py b/sdks/python/apache_beam/ml/inference/vllm_inference.py index bdbee9e51fd5..918b49155606 100644 --- a/sdks/python/apache_beam/ml/inference/vllm_inference.py +++ b/sdks/python/apache_beam/ml/inference/vllm_inference.py @@ -25,6 +25,7 @@ import threading import time import uuid +from collections.abc import Callable from collections.abc import Iterable from collections.abc import Sequence from dataclasses import dataclass @@ -175,7 +176,13 @@ class VLLMCompletionsModelHandler(ModelHandler[str, def __init__( self, model_name: str, - vllm_server_kwargs: Optional[dict[str, str]] = None): + vllm_server_kwargs: Optional[dict[str, str]] = None, + *, + min_batch_size: Optional[int] = None, + max_batch_size: Optional[int] = None, + max_batch_duration_secs: Optional[int] = None, + max_batch_weight: Optional[int] = None, + element_size_fn: Optional[Callable[[Any], int]] = None): """Implementation of the ModelHandler interface for vLLM using text as input. @@ -194,10 +201,24 @@ def __init__( `{'echo': 'true'}` to prepend new messages with the previous message. For a list of possible kwargs, see https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters-for-completions-api + min_batch_size: optional. the minimum batch size to use when batching + inputs. + max_batch_size: optional. the maximum batch size to use when batching + inputs. + max_batch_duration_secs: optional. the maximum amount of time to buffer + a batch before emitting; used in streaming contexts. + max_batch_weight: optional. the maximum total weight of a batch. + element_size_fn: optional. a function that returns the size (weight) of + an element. """ + super().__init__( + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + max_batch_duration_secs=max_batch_duration_secs, + max_batch_weight=max_batch_weight, + element_size_fn=element_size_fn) self._model_name = model_name self._vllm_server_kwargs: dict[str, str] = vllm_server_kwargs or {} - self._env_vars = {} def load_model(self) -> _VLLMModelServer: return _VLLMModelServer(self._model_name, self._vllm_server_kwargs) @@ -253,7 +274,13 @@ def __init__( self, model_name: str, chat_template_path: Optional[str] = None, - vllm_server_kwargs: Optional[dict[str, str]] = None): + vllm_server_kwargs: Optional[dict[str, str]] = None, + *, + min_batch_size: Optional[int] = None, + max_batch_size: Optional[int] = None, + max_batch_duration_secs: Optional[int] = None, + max_batch_weight: Optional[int] = None, + element_size_fn: Optional[Callable[[Any], int]] = None): """ Implementation of the ModelHandler interface for vLLM using previous messages as input. @@ -277,10 +304,24 @@ def __init__( `{'echo': 'true'}` to prepend new messages with the previous message. For a list of possible kwargs, see https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters-for-chat-api + min_batch_size: optional. the minimum batch size to use when batching + inputs. 
+ max_batch_size: optional. the maximum batch size to use when batching + inputs. + max_batch_duration_secs: optional. the maximum amount of time to buffer + a batch before emitting; used in streaming contexts. + max_batch_weight: optional. the maximum total weight of a batch. + element_size_fn: optional. a function that returns the size (weight) of + an element. """ + super().__init__( + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + max_batch_duration_secs=max_batch_duration_secs, + max_batch_weight=max_batch_weight, + element_size_fn=element_size_fn) self._model_name = model_name self._vllm_server_kwargs: dict[str, str] = vllm_server_kwargs or {} - self._env_vars = {} self._chat_template_path = chat_template_path self._chat_file = f'template-{uuid.uuid4().hex}.jinja' diff --git a/sdks/python/apache_beam/ml/inference/xgboost_inference.py b/sdks/python/apache_beam/ml/inference/xgboost_inference.py index 10289b076416..9d7413685113 100644 --- a/sdks/python/apache_beam/ml/inference/xgboost_inference.py +++ b/sdks/python/apache_beam/ml/inference/xgboost_inference.py @@ -19,7 +19,6 @@ from abc import ABC from collections.abc import Callable from collections.abc import Iterable -from collections.abc import Mapping from collections.abc import Sequence from typing import Any from typing import Optional @@ -79,6 +78,8 @@ def __init__( min_batch_size: Optional[int] = None, max_batch_size: Optional[int] = None, max_batch_duration_secs: Optional[int] = None, + max_batch_weight: Optional[int] = None, + element_size_fn: Optional[Callable[[Any], int]] = None, **kwargs): """Implementation of the ModelHandler interface for XGBoost. @@ -103,8 +104,11 @@ def __init__( inputs. max_batch_size: optional. the maximum batch size to use when batching inputs. - max_batch_duration_secs: optional. the maximum amount of time to buffer + max_batch_duration_secs: optional. the maximum amount of time to buffer a batch before emitting; used in streaming contexts. + max_batch_weight: optional. the maximum total weight of a batch. + element_size_fn: optional. a function that returns the size (weight) + of an element. kwargs: 'env_vars' can be used to set environment variables before loading the model. @@ -121,17 +125,16 @@ def __init__( and should not be instantiated directly. (See instead XGBoostModelHandlerNumpy, XGBoostModelHandlerPandas, etc.) """ + super().__init__( + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + max_batch_duration_secs=max_batch_duration_secs, + max_batch_weight=max_batch_weight, + element_size_fn=element_size_fn, + **kwargs) self._model_class = model_class self._model_state = model_state self._inference_fn = inference_fn - self._env_vars = kwargs.get('env_vars', {}) - self._batching_kwargs = {} - if min_batch_size is not None: - self._batching_kwargs["min_batch_size"] = min_batch_size - if max_batch_size is not None: - self._batching_kwargs["max_batch_size"] = max_batch_size - if max_batch_duration_secs is not None: - self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs def load_model(self) -> Union[xgboost.Booster, xgboost.XGBModel]: model = self._model_class() @@ -146,9 +149,6 @@ def load_model(self) -> Union[xgboost.Booster, xgboost.XGBModel]: def get_metrics_namespace(self) -> str: return 'BeamML_XGBoost' - def batch_elements_kwargs(self) -> Mapping[str, Any]: - return self._batching_kwargs - class XGBoostModelHandlerNumpy(XGBoostModelHandler[numpy.ndarray, PredictionResult,
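Finally, an end-to-end sketch of how the new kwargs flow from a handler into RunInference. ByteCountModel and ByteCountHandler are hypothetical and not part of this diff; the example assumes beam.BatchElements accepts the max_batch_weight and element_size_fn keywords that this change forwards to it.

import apache_beam as beam

from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import RunInference


class ByteCountModel:
  def predict(self, example: bytes) -> int:
    return len(example)


class ByteCountHandler(ModelHandler[bytes, int, ByteCountModel]):
  def load_model(self) -> ByteCountModel:
    return ByteCountModel()

  def run_inference(self, batch, model, inference_args=None):
    return [model.predict(e) for e in batch]


handler = ByteCountHandler(
    min_batch_size=2,
    max_batch_weight=1024,   # cap each batch at ~1 KiB of payload
    element_size_fn=len)

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([b'a' * 100, b'b' * 400, b'c' * 700])
      | RunInference(handler)
      | beam.Map(print))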