Implement ONNX model serialisation phase 2 #227
RAMitchell wants to merge 22 commits into rapidsai:main
Pull Request Overview
This PR implements ONNX model serialization (phase 2) by combining model initialization with individual ONNX models into a serialized estimator. Key changes include updating dependencies to include onnxruntime, adding new test functions to verify ONNX predictions, and modifying each model’s to_onnx method (and related functions) to accept an explicit data type parameter and use consistent input/output naming.
Reviewed Changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| pyproject.toml, dependencies.yaml, conda YAML | Updated dependencies to include onnxruntime>=1.21 |
| legateboost/test/test_onnx.py | Added new functions for ONNX predictions and updated test naming |
| legateboost/models/tree.py, nn.py, linear.py, krr.py | Modified to_onnx methods to accept X_dtype and standardized input/output names |
| legateboost/legateboost.py | Introduced _make_onnx_init and updated to_onnx to merge ONNX models |
seberg left a comment
The approach looks good to me; commenting since I suspect the classification needs a bit of work. This is tricky to test!
I was thinking about the predict_function= argument, but it seems good to me. (predict_raw seems to duplicate "predict" a bit.)
I should look closer at some of the ONNX code probably.
    assert onnx_pred.dtype == pred.dtype
    assert pred.shape == onnx_pred.shape
    number_wrong = np.sum(
        np.abs(pred - onnx_pred) > (1e-2 if X.dtype == np.float32 else 1e-5)
Are predictions always similarly sized (i.e. 0-1)? Just curious whether it would make sense to allow a relative deviation.
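To make the relative-deviation suggestion concrete, here is a minimal sketch of how a combined absolute and relative tolerance could be counted with `np.isclose`. The helper name `count_mismatches` and the tolerance values are assumptions for illustration, not part of this PR's test code.

```python
import numpy as np


def count_mismatches(pred, onnx_pred, rtol, atol):
    """Count elements where |onnx_pred - pred| > atol + rtol * |pred|.

    Hypothetical helper: `pred` is treated as the reference value.
    """
    pred = np.asarray(pred)
    onnx_pred = np.asarray(onnx_pred)
    # np.isclose(a, b) checks |a - b| <= atol + rtol * |b| element-wise.
    return int(np.sum(~np.isclose(onnx_pred, pred, rtol=rtol, atol=atol)))


# Predictions of very different magnitude, each off by a factor of 1e-3.
pred = np.array([0.5, 100.0, 1e-3], dtype=np.float32)
onnx_pred = pred * (1 + 1e-3)

# A purely absolute threshold flags the large prediction only.
print(count_mismatches(pred, onnx_pred, rtol=0.0, atol=1e-2))
# A relative threshold scales with the size of each prediction.
print(count_mismatches(pred, onnx_pred, rtol=1e-2, atol=1e-6))
```

With a relative term the check stays meaningful when raw predictions are not confined to a range such as 0-1, at the cost of needing a small `atol` for values near zero.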
Cupynumeric test failures here require https://github.com/nv-legate/legate.internal/pull/2177 - we need to wait for nightlies to become available.
Implements #225
This PR combines the model initialisation term as well as individual onnx models together into a serialised estimator.
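As a rough mental model of what the serialised estimator computes, the sketch below assumes the raw prediction is the initialisation term plus the sum of every model's output. It uses plain NumPy callables as stand-ins for the ONNX sub-graphs; `predict_raw_sketch` and the toy models are hypothetical and do not reflect the actual graph-merging code in this PR.

```python
import numpy as np


def predict_raw_sketch(X, init, models):
    """Assumed additive form: init term plus the sum of all model outputs.

    X:      (n_samples, n_features) input array
    init:   (n_outputs,) constant initialisation term
    models: callables mapping X to an (n_samples, n_outputs) array
    """
    X = np.asarray(X)
    init = np.asarray(init)
    # Start every row from the initialisation term, in the input dtype.
    out = np.broadcast_to(init, (X.shape[0], init.shape[0])).astype(X.dtype)
    for model in models:
        out = out + model(X)
    return out


# Toy stand-ins for the individual models (illustrative only).
toy_models = [
    lambda X: X.sum(axis=1, keepdims=True),
    lambda X: -0.5 * X[:, :1],
]
X = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
raw = predict_raw_sketch(X, np.array([0.5]), toy_models)
print(raw.dtype, raw.ravel())
```

Keeping the accumulator in the input dtype mirrors the dtype assertion in the test snippet above, where the ONNX output is expected to match the dtype of the native prediction.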
I will likely only implement the predict_raw method here and leave predict_proba for another PR as it will require e.g. softmax transforms.
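For reference, the kind of transform that predict_proba would need on top of the raw output can be sketched in NumPy as below. This is a generic, numerically stable softmax over the class axis, shown as an assumption about what a later PR might encode in ONNX, not as code from this one.

```python
import numpy as np


def softmax(raw):
    """Row-wise softmax over an (n_samples, n_classes) array of raw scores."""
    raw = np.asarray(raw)
    # Subtracting the row maximum avoids overflow in exp without
    # changing the result.
    shifted = raw - raw.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


raw_scores = np.array([[0.0, 0.0], [1000.0, 1000.0], [2.0, -1.0]])
proba = softmax(raw_scores)
print(proba.sum(axis=1))
```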