From 9b3f291201a0846fd4611a5c54214c73a53b43e4 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Thu, 18 Dec 2025 11:37:59 -0600 Subject: [PATCH 1/2] Added PyTorch clip dispatch --- pytensor/link/pytorch/dispatch/scalar.py | 9 +++++++++ tests/link/pytorch/test_elemwise.py | 17 +++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py index 6a1c6b235e..a64dcf23ba 100644 --- a/pytensor/link/pytorch/dispatch/scalar.py +++ b/pytensor/link/pytorch/dispatch/scalar.py @@ -5,6 +5,7 @@ from pytensor.link.pytorch.dispatch.basic import pytorch_funcify from pytensor.scalar.basic import ( Cast, + Clip, Invert, ScalarOp, ) @@ -71,6 +72,14 @@ def pytorch_funcify_Softplus(op, node, **kwargs): return torch.nn.Softplus() +@pytorch_funcify.register(Clip) +def pytorch_funcify_Clip(op, node, **kwargs): + def clip(x, min_val, max_val): + return torch.where(x < min_val, min_val, torch.where(x > max_val, max_val, x)) + + return clip + + @pytorch_funcify.register(ScalarLoop) def pytorch_funicify_ScalarLoop(op, node, **kwargs): update = pytorch_funcify(op.fgraph, **kwargs) diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py index dacf6f8699..08845b6880 100644 --- a/tests/link/pytorch/test_elemwise.py +++ b/tests/link/pytorch/test_elemwise.py @@ -151,6 +151,23 @@ def test_cast(): assert res.dtype == np.int32 +@pytest.mark.parametrize( + "x_val, min_val, max_val", + [ + (np.array([5.0], dtype=config.floatX), 0.0, 10.0), + (np.array([-5.0], dtype=config.floatX), 0.0, 10.0), + (np.array([15.0], dtype=config.floatX), 0.0, 10.0), + (np.array([5.0], dtype=config.floatX), 10.0, 0.0), + (np.array([-5.0, 5.0, 15.0], dtype=config.floatX), 0.0, 10.0), + (np.array([[-5.0, 5.0], [15.0, 7.0]], dtype=config.floatX), 0.0, 10.0), + ], +) +def test_clip(x_val, min_val, max_val): + x = pt.tensor("x", shape=x_val.shape, dtype=config.floatX) + out = pt.clip(x, min_val, max_val) + compare_pytorch_and_py([x], [out], [x_val]) + + def test_vmap_elemwise(): from pytensor.link.pytorch.dispatch.basic import pytorch_funcify From 7f40247f14b5b82f68138b8e5d0d91040046534b Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Thu, 18 Dec 2025 15:29:03 -0600 Subject: [PATCH 2/2] Removed most test cases --- tests/link/pytorch/test_elemwise.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py index 08845b6880..1db6c67e35 100644 --- a/tests/link/pytorch/test_elemwise.py +++ b/tests/link/pytorch/test_elemwise.py @@ -156,10 +156,6 @@ def test_cast(): [ (np.array([5.0], dtype=config.floatX), 0.0, 10.0), (np.array([-5.0], dtype=config.floatX), 0.0, 10.0), - (np.array([15.0], dtype=config.floatX), 0.0, 10.0), - (np.array([5.0], dtype=config.floatX), 10.0, 0.0), - (np.array([-5.0, 5.0, 15.0], dtype=config.floatX), 0.0, 10.0), - (np.array([[-5.0, 5.0], [15.0, 7.0]], dtype=config.floatX), 0.0, 10.0), ], ) def test_clip(x_val, min_val, max_val):