diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py
index 6a1c6b235e..a64dcf23ba 100644
--- a/pytensor/link/pytorch/dispatch/scalar.py
+++ b/pytensor/link/pytorch/dispatch/scalar.py
@@ -5,6 +5,7 @@
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.scalar.basic import (
     Cast,
+    Clip,
     Invert,
     ScalarOp,
 )
@@ -71,6 +72,14 @@ def pytorch_funcify_Softplus(op, node, **kwargs):
     return torch.nn.Softplus()
 
 
+@pytorch_funcify.register(Clip)
+def pytorch_funcify_Clip(op, node, **kwargs):
+    def clip(x, min_val, max_val):
+        return torch.where(x < min_val, min_val, torch.where(x > max_val, max_val, x))
+
+    return clip
+
+
 @pytorch_funcify.register(ScalarLoop)
 def pytorch_funicify_ScalarLoop(op, node, **kwargs):
     update = pytorch_funcify(op.fgraph, **kwargs)
diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py
index dacf6f8699..1db6c67e35 100644
--- a/tests/link/pytorch/test_elemwise.py
+++ b/tests/link/pytorch/test_elemwise.py
@@ -151,6 +151,19 @@ def test_cast():
     assert res.dtype == np.int32
 
 
+@pytest.mark.parametrize(
+    "x_val, min_val, max_val",
+    [
+        (np.array([5.0], dtype=config.floatX), 0.0, 10.0),
+        (np.array([-5.0], dtype=config.floatX), 0.0, 10.0),
+    ],
+)
+def test_clip(x_val, min_val, max_val):
+    x = pt.tensor("x", shape=x_val.shape, dtype=config.floatX)
+    out = pt.clip(x, min_val, max_val)
+    compare_pytorch_and_py([x], [out], [x_val])
+
+
 def test_vmap_elemwise():
     from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
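
A minimal standalone sketch (not part of the patch) of the behavior the new `pytorch_funcify_Clip` dispatch produces at runtime, assuming only a recent PyTorch install; the tensor values below are illustrative, not taken from the test suite:

```python
import torch

# Mirrors the inner `clip` closure returned by pytorch_funcify_Clip:
# values below min_val are replaced by min_val, values above max_val
# by max_val, and everything in between passes through unchanged.
def clip(x, min_val, max_val):
    return torch.where(x < min_val, min_val, torch.where(x > max_val, max_val, x))

x = torch.tensor([5.0, -5.0, 15.0])
print(clip(x, torch.tensor(0.0), torch.tensor(10.0)))  # tensor([ 5.,  0., 10.])
print(torch.clamp(x, min=0.0, max=10.0))               # same result on these inputs
```

`torch.clamp` would be a more direct way to express the same operation (it also accepts tensor `min`/`max`); the nested `torch.where` form used in the patch spells out the comparison-and-select logic explicitly, which is what the `Clip` scalar op receives as graph inputs.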