From 2b49e3376eaaa7fe4c9b0649a7704cb8f920be6a Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Fri, 13 Feb 2026 15:17:20 -0800 Subject: [PATCH 01/20] Add mypy, ruff and pyright configuration --- pyproject.toml | 20 +++++++ uv.lock | 142 ++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 161 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index fea893c6..eafba3f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,4 +26,24 @@ dev = [ "pytest>=8.0", "pytest-asyncio>=0.24", "rich>=14.2.0", + "mypy>=1.11", + "ruff>=0.8", ] + +[tool.mypy] +python_version = "3.13" +strict = true +plugins = ["pydantic.mypy"] + +[tool.pyright] +pythonVersion = "3.13" +typeCheckingMode = "standard" +reportMissingTypeStubs = false +reportUnusedCallResult = false + +[tool.ruff] +target-version = "py313" +src = ["src"] + +[tool.ruff.lint] +select = ["E", "F", "I", "UP", "B", "SIM"] diff --git a/uv.lock b/uv.lock index 469f021b..35dd8533 100644 --- a/uv.lock +++ b/uv.lock @@ -376,6 +376,66 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/41/45/1a4ed80516f02155c51f51e8cedb3c1902296743db0bbc66608a0db2814f/jsonschema_specifications-2025.9.1-py3-none-any.whl", hash = "sha256:98802fee3a11ee76ecaca44429fda8a41bff98b00a0f2838151b113f210cc6fe", size = 18437, upload-time = "2025-09-08T01:34:57.871Z" }, ] +[[package]] +name = "librt" +version = "0.8.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8a/3f/4ca7dd7819bf8ff303aca39c3c60e5320e46e766ab7f7dd627d3b9c11bdf/librt-0.8.0.tar.gz", hash = "sha256:cb74cdcbc0103fc988e04e5c58b0b31e8e5dd2babb9182b6f9490488eb36324b", size = 177306, upload-time = "2026-02-12T14:53:54.743Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fb/53/f3bc0c4921adb0d4a5afa0656f2c0fbe20e18e3e0295e12985b9a5dc3f55/librt-0.8.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:17269dd2745dbe8e42475acb28e419ad92dfa38214224b1b01020b8cac70b645", size = 66511, upload-time = "2026-02-12T14:52:30.34Z" }, + { url = "https://files.pythonhosted.org/packages/89/4b/4c96357432007c25a1b5e363045373a6c39481e49f6ba05234bb59a839c1/librt-0.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f4617cef654fca552f00ce5ffdf4f4b68770f18950e4246ce94629b789b92467", size = 68628, upload-time = "2026-02-12T14:52:31.491Z" }, + { url = "https://files.pythonhosted.org/packages/47/16/52d75374d1012e8fc709216b5eaa25f471370e2a2331b8be00f18670a6c7/librt-0.8.0-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:5cb11061a736a9db45e3c1293cfcb1e3caf205912dfa085734ba750f2197ff9a", size = 198941, upload-time = "2026-02-12T14:52:32.489Z" }, + { url = "https://files.pythonhosted.org/packages/fc/11/d5dd89e5a2228567b1228d8602d896736247424484db086eea6b8010bcba/librt-0.8.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b4bb00bd71b448f16749909b08a0ff16f58b079e2261c2e1000f2bbb2a4f0a45", size = 210009, upload-time = "2026-02-12T14:52:33.634Z" }, + { url = "https://files.pythonhosted.org/packages/49/d8/fc1a92a77c3020ee08ce2dc48aed4b42ab7c30fb43ce488d388673b0f164/librt-0.8.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:95a719a049f0eefaf1952673223cf00d442952273cbd20cf2ed7ec423a0ef58d", size = 224461, upload-time = "2026-02-12T14:52:34.868Z" }, + { url = "https://files.pythonhosted.org/packages/7f/98/eb923e8b028cece924c246104aa800cf72e02d023a8ad4ca87135b05a2fe/librt-0.8.0-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:bd32add59b58fba3439d48d6f36ac695830388e3da3e92e4fc26d2d02670d19c", size = 217538, upload-time = "2026-02-12T14:52:36.078Z" }, + { url = "https://files.pythonhosted.org/packages/fd/67/24e80ab170674a1d8ee9f9a83081dca4635519dbd0473b8321deecddb5be/librt-0.8.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4f764b2424cb04524ff7a486b9c391e93f93dc1bd8305b2136d25e582e99aa2f", size = 225110, upload-time = "2026-02-12T14:52:37.301Z" }, + { url = "https://files.pythonhosted.org/packages/d8/c7/6fbdcbd1a6e5243c7989c21d68ab967c153b391351174b4729e359d9977f/librt-0.8.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:f04ca50e847abc486fa8f4107250566441e693779a5374ba211e96e238f298b9", size = 217758, upload-time = "2026-02-12T14:52:38.89Z" }, + { url = "https://files.pythonhosted.org/packages/4b/bd/4d6b36669db086e3d747434430073e14def032dd58ad97959bf7e2d06c67/librt-0.8.0-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:9ab3a3475a55b89b87ffd7e6665838e8458e0b596c22e0177e0f961434ec474a", size = 218384, upload-time = "2026-02-12T14:52:40.637Z" }, + { url = "https://files.pythonhosted.org/packages/50/2d/afe966beb0a8f179b132f3e95c8dd90738a23e9ebdba10f89a3f192f9366/librt-0.8.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3e36a8da17134ffc29373775d88c04832f9ecfab1880470661813e6c7991ef79", size = 241187, upload-time = "2026-02-12T14:52:43.55Z" }, + { url = "https://files.pythonhosted.org/packages/02/d0/6172ea4af2b538462785ab1a68e52d5c99cfb9866a7caf00fdf388299734/librt-0.8.0-cp312-cp312-win32.whl", hash = "sha256:4eb5e06ebcc668677ed6389164f52f13f71737fc8be471101fa8b4ce77baeb0c", size = 54914, upload-time = "2026-02-12T14:52:44.676Z" }, + { url = "https://files.pythonhosted.org/packages/d4/cb/ceb6ed6175612a4337ad49fb01ef594712b934b4bc88ce8a63554832eb44/librt-0.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:0a33335eb59921e77c9acc05d0e654e4e32e45b014a4d61517897c11591094f8", size = 62020, upload-time = "2026-02-12T14:52:45.676Z" }, + { url = "https://files.pythonhosted.org/packages/f1/7e/61701acbc67da74ce06ddc7ba9483e81c70f44236b2d00f6a4bfee1aacbf/librt-0.8.0-cp312-cp312-win_arm64.whl", hash = "sha256:24a01c13a2a9bdad20997a4443ebe6e329df063d1978bbe2ebbf637878a46d1e", size = 52443, upload-time = "2026-02-12T14:52:47.218Z" }, + { url = "https://files.pythonhosted.org/packages/6d/32/3edb0bcb4113a9c8bdcd1750663a54565d255027657a5df9d90f13ee07fa/librt-0.8.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7f820210e21e3a8bf8fde2ae3c3d10106d4de9ead28cbfdf6d0f0f41f5b12fa1", size = 66522, upload-time = "2026-02-12T14:52:48.219Z" }, + { url = "https://files.pythonhosted.org/packages/30/ab/e8c3d05e281f5d405ebdcc5bc8ab36df23e1a4b40ac9da8c3eb9928b72b9/librt-0.8.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4831c44b8919e75ca0dfb52052897c1ef59fdae19d3589893fbd068f1e41afbf", size = 68658, upload-time = "2026-02-12T14:52:50.351Z" }, + { url = "https://files.pythonhosted.org/packages/7c/d3/74a206c47b7748bbc8c43942de3ed67de4c231156e148b4f9250869593df/librt-0.8.0-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:88c6e75540f1f10f5e0fc5e87b4b6c290f0e90d1db8c6734f670840494764af8", size = 199287, upload-time = "2026-02-12T14:52:51.938Z" }, + { url = "https://files.pythonhosted.org/packages/fa/29/ef98a9131cf12cb95771d24e4c411fda96c89dc78b09c2de4704877ebee4/librt-0.8.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9646178cd794704d722306c2c920c221abbf080fede3ba539d5afdec16c46dad", size = 210293, upload-time = "2026-02-12T14:52:53.128Z" }, + { url = "https://files.pythonhosted.org/packages/5b/3e/89b4968cb08c53d4c2d8b02517081dfe4b9e07a959ec143d333d76899f6c/librt-0.8.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6e1af31a710e17891d9adf0dbd9a5fcd94901a3922a96499abdbf7ce658f4e01", size = 224801, upload-time = "2026-02-12T14:52:54.367Z" }, + { url = "https://files.pythonhosted.org/packages/6d/28/f38526d501f9513f8b48d78e6be4a241e15dd4b000056dc8b3f06ee9ce5d/librt-0.8.0-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:507e94f4bec00b2f590fbe55f48cd518a208e2474a3b90a60aa8f29136ddbada", size = 218090, upload-time = "2026-02-12T14:52:55.758Z" }, + { url = "https://files.pythonhosted.org/packages/02/ec/64e29887c5009c24dc9c397116c680caffc50286f62bd99c39e3875a2854/librt-0.8.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f1178e0de0c271231a660fbef9be6acdfa1d596803464706862bef6644cc1cae", size = 225483, upload-time = "2026-02-12T14:52:57.375Z" }, + { url = "https://files.pythonhosted.org/packages/ee/16/7850bdbc9f1a32d3feff2708d90c56fc0490b13f1012e438532781aa598c/librt-0.8.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:71fc517efc14f75c2f74b1f0a5d5eb4a8e06aa135c34d18eaf3522f4a53cd62d", size = 218226, upload-time = "2026-02-12T14:52:58.534Z" }, + { url = "https://files.pythonhosted.org/packages/1c/4a/166bffc992d65ddefa7c47052010a87c059b44a458ebaf8f5eba384b0533/librt-0.8.0-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:0583aef7e9a720dd40f26a2ad5a1bf2ccbb90059dac2b32ac516df232c701db3", size = 218755, upload-time = "2026-02-12T14:52:59.701Z" }, + { url = "https://files.pythonhosted.org/packages/da/5d/9aeee038bcc72a9cfaaee934463fe9280a73c5440d36bd3175069d2cb97b/librt-0.8.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5d0f76fc73480d42285c609c0ea74d79856c160fa828ff9aceab574ea4ecfd7b", size = 241617, upload-time = "2026-02-12T14:53:00.966Z" }, + { url = "https://files.pythonhosted.org/packages/64/ff/2bec6b0296b9d0402aa6ec8540aa19ebcb875d669c37800cb43d10d9c3a3/librt-0.8.0-cp313-cp313-win32.whl", hash = "sha256:e79dbc8f57de360f0ed987dc7de7be814b4803ef0e8fc6d3ff86e16798c99935", size = 54966, upload-time = "2026-02-12T14:53:02.042Z" }, + { url = "https://files.pythonhosted.org/packages/08/8d/bf44633b0182996b2c7ea69a03a5c529683fa1f6b8e45c03fe874ff40d56/librt-0.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:25b3e667cbfc9000c4740b282df599ebd91dbdcc1aa6785050e4c1d6be5329ab", size = 62000, upload-time = "2026-02-12T14:53:03.822Z" }, + { url = "https://files.pythonhosted.org/packages/5c/fd/c6472b8e0eac0925001f75e366cf5500bcb975357a65ef1f6b5749389d3a/librt-0.8.0-cp313-cp313-win_arm64.whl", hash = "sha256:e9a3a38eb4134ad33122a6d575e6324831f930a771d951a15ce232e0237412c2", size = 52496, upload-time = "2026-02-12T14:53:04.889Z" }, + { url = "https://files.pythonhosted.org/packages/e0/13/79ebfe30cd273d7c0ce37a5f14dc489c5fb8b722a008983db2cfd57270bb/librt-0.8.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:421765e8c6b18e64d21c8ead315708a56fc24f44075059702e421d164575fdda", size = 66078, upload-time = "2026-02-12T14:53:06.085Z" }, + { url = "https://files.pythonhosted.org/packages/4b/8f/d11eca40b62a8d5e759239a80636386ef88adecb10d1a050b38cc0da9f9e/librt-0.8.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:48f84830a8f8ad7918afd743fd7c4eb558728bceab7b0e38fd5a5cf78206a556", size = 68309, upload-time = "2026-02-12T14:53:07.121Z" }, + { url = "https://files.pythonhosted.org/packages/9c/b4/f12ee70a3596db40ff3c88ec9eaa4e323f3b92f77505b4d900746706ec6a/librt-0.8.0-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:9f09d4884f882baa39a7e36bbf3eae124c4ca2a223efb91e567381d1c55c6b06", size = 196804, upload-time = "2026-02-12T14:53:08.164Z" }, + { url = "https://files.pythonhosted.org/packages/8b/7e/70dbbdc0271fd626abe1671ad117bcd61a9a88cdc6a10ccfbfc703db1873/librt-0.8.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:693697133c3b32aa9b27f040e3691be210e9ac4d905061859a9ed519b1d5a376", size = 206915, upload-time = "2026-02-12T14:53:09.333Z" }, + { url = "https://files.pythonhosted.org/packages/79/13/6b9e05a635d4327608d06b3c1702166e3b3e78315846373446cf90d7b0bf/librt-0.8.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c5512aae4648152abaf4d48b59890503fcbe86e85abc12fb9b096fe948bdd816", size = 221200, upload-time = "2026-02-12T14:53:10.68Z" }, + { url = "https://files.pythonhosted.org/packages/35/6c/e19a3ac53e9414de43a73d7507d2d766cd22d8ca763d29a4e072d628db42/librt-0.8.0-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:995d24caa6bbb34bcdd4a41df98ac6d1af637cfa8975cb0790e47d6623e70e3e", size = 214640, upload-time = "2026-02-12T14:53:12.342Z" }, + { url = "https://files.pythonhosted.org/packages/30/f0/23a78464788619e8c70f090cfd099cce4973eed142c4dccb99fc322283fd/librt-0.8.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:b9aef96d7593584e31ef6ac1eb9775355b0099fee7651fae3a15bc8657b67b52", size = 221980, upload-time = "2026-02-12T14:53:13.603Z" }, + { url = "https://files.pythonhosted.org/packages/03/32/38e21420c5d7aa8a8bd2c7a7d5252ab174a5a8aaec8b5551968979b747bf/librt-0.8.0-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:4f6e975377fbc4c9567cb33ea9ab826031b6c7ec0515bfae66a4fb110d40d6da", size = 215146, upload-time = "2026-02-12T14:53:14.8Z" }, + { url = "https://files.pythonhosted.org/packages/bb/00/bd9ecf38b1824c25240b3ad982fb62c80f0a969e6679091ba2b3afb2b510/librt-0.8.0-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:daae5e955764be8fd70a93e9e5133c75297f8bce1e802e1d3683b98f77e1c5ab", size = 215203, upload-time = "2026-02-12T14:53:16.087Z" }, + { url = "https://files.pythonhosted.org/packages/b9/60/7559bcc5279d37810b98d4a52616febd7b8eef04391714fd6bdf629598b1/librt-0.8.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:7bd68cebf3131bb920d5984f75fe302d758db33264e44b45ad139385662d7bc3", size = 237937, upload-time = "2026-02-12T14:53:17.236Z" }, + { url = "https://files.pythonhosted.org/packages/41/cc/be3e7da88f1abbe2642672af1dc00a0bccece11ca60241b1883f3018d8d5/librt-0.8.0-cp314-cp314-win32.whl", hash = "sha256:1e6811cac1dcb27ca4c74e0ca4a5917a8e06db0d8408d30daee3a41724bfde7a", size = 50685, upload-time = "2026-02-12T14:53:18.888Z" }, + { url = "https://files.pythonhosted.org/packages/38/27/e381d0df182a8f61ef1f6025d8b138b3318cc9d18ad4d5f47c3bf7492523/librt-0.8.0-cp314-cp314-win_amd64.whl", hash = "sha256:178707cda89d910c3b28bf5aa5f69d3d4734e0f6ae102f753ad79edef83a83c7", size = 57872, upload-time = "2026-02-12T14:53:19.942Z" }, + { url = "https://files.pythonhosted.org/packages/c5/0c/ca9dfdf00554a44dea7d555001248269a4bab569e1590a91391feb863fa4/librt-0.8.0-cp314-cp314-win_arm64.whl", hash = "sha256:3e8b77b5f54d0937b26512774916041756c9eb3e66f1031971e626eea49d0bf4", size = 48056, upload-time = "2026-02-12T14:53:21.473Z" }, + { url = "https://files.pythonhosted.org/packages/f2/ed/6cc9c4ad24f90c8e782193c7b4a857408fd49540800613d1356c63567d7b/librt-0.8.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:789911e8fa40a2e82f41120c936b1965f3213c67f5a483fc5a41f5839a05dcbb", size = 68307, upload-time = "2026-02-12T14:53:22.498Z" }, + { url = "https://files.pythonhosted.org/packages/84/d8/0e94292c6b3e00b6eeea39dd44d5703d1ec29b6dafce7eea19dc8f1aedbd/librt-0.8.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:2b37437e7e4ef5e15a297b36ba9e577f73e29564131d86dd75875705e97402b5", size = 70999, upload-time = "2026-02-12T14:53:23.603Z" }, + { url = "https://files.pythonhosted.org/packages/0e/f4/6be1afcbdeedbdbbf54a7c9d73ad43e1bf36897cebf3978308cd64922e02/librt-0.8.0-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:671a6152edf3b924d98a5ed5e6982ec9cb30894085482acadce0975f031d4c5c", size = 220782, upload-time = "2026-02-12T14:53:25.133Z" }, + { url = "https://files.pythonhosted.org/packages/f0/8d/f306e8caa93cfaf5c6c9e0d940908d75dc6af4fd856baa5535c922ee02b1/librt-0.8.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8992ca186a1678107b0af3d0c9303d8c7305981b9914989b9788319ed4d89546", size = 235420, upload-time = "2026-02-12T14:53:27.047Z" }, + { url = "https://files.pythonhosted.org/packages/d6/f2/65d86bd462e9c351326564ca805e8457442149f348496e25ccd94583ffa2/librt-0.8.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:001e5330093d887b8b9165823eca6c5c4db183fe4edea4fdc0680bbac5f46944", size = 246452, upload-time = "2026-02-12T14:53:28.341Z" }, + { url = "https://files.pythonhosted.org/packages/03/94/39c88b503b4cb3fcbdeb3caa29672b6b44ebee8dcc8a54d49839ac280f3f/librt-0.8.0-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:d920789eca7ef71df7f31fd547ec0d3002e04d77f30ba6881e08a630e7b2c30e", size = 238891, upload-time = "2026-02-12T14:53:29.625Z" }, + { url = "https://files.pythonhosted.org/packages/e3/c6/6c0d68190893d01b71b9569b07a1c811e280c0065a791249921c83dc0290/librt-0.8.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:82fb4602d1b3e303a58bfe6165992b5a78d823ec646445356c332cd5f5bbaa61", size = 250249, upload-time = "2026-02-12T14:53:30.93Z" }, + { url = "https://files.pythonhosted.org/packages/52/7a/f715ed9e039035d0ea637579c3c0155ab3709a7046bc408c0fb05d337121/librt-0.8.0-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:4d3e38797eb482485b486898f89415a6ab163bc291476bd95712e42cf4383c05", size = 240642, upload-time = "2026-02-12T14:53:32.174Z" }, + { url = "https://files.pythonhosted.org/packages/c2/3c/609000a333debf5992efe087edc6467c1fdbdddca5b610355569bbea9589/librt-0.8.0-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:a905091a13e0884701226860836d0386b88c72ce5c2fdfba6618e14c72be9f25", size = 239621, upload-time = "2026-02-12T14:53:33.39Z" }, + { url = "https://files.pythonhosted.org/packages/b9/df/87b0673d5c395a8f34f38569c116c93142d4dc7e04af2510620772d6bd4f/librt-0.8.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:375eda7acfce1f15f5ed56cfc960669eefa1ec8732e3e9087c3c4c3f2066759c", size = 262986, upload-time = "2026-02-12T14:53:34.617Z" }, + { url = "https://files.pythonhosted.org/packages/09/7f/6bbbe9dcda649684773aaea78b87fff4d7e59550fbc2877faa83612087a3/librt-0.8.0-cp314-cp314t-win32.whl", hash = "sha256:2ccdd20d9a72c562ffb73098ac411de351b53a6fbb3390903b2d33078ef90447", size = 51328, upload-time = "2026-02-12T14:53:36.15Z" }, + { url = "https://files.pythonhosted.org/packages/bb/f3/e1981ab6fa9b41be0396648b5850267888a752d025313a9e929c4856208e/librt-0.8.0-cp314-cp314t-win_amd64.whl", hash = "sha256:25e82d920d4d62ad741592fcf8d0f3bda0e3fc388a184cb7d2f566c681c5f7b9", size = 58719, upload-time = "2026-02-12T14:53:37.183Z" }, + { url = "https://files.pythonhosted.org/packages/94/d1/433b3c06e78f23486fe4fdd19bc134657eb30997d2054b0dbf52bbf3382e/librt-0.8.0-cp314-cp314t-win_arm64.whl", hash = "sha256:92249938ab744a5890580d3cb2b22042f0dce71cdaa7c1369823df62bedf7cbc", size = 48753, upload-time = "2026-02-12T14:53:38.539Z" }, +] + [[package]] name = "markdown-it-py" version = "4.0.0" @@ -422,6 +482,48 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, ] +[[package]] +name = "mypy" +version = "1.19.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "librt", marker = "platform_python_implementation != 'PyPy'" }, + { name = "mypy-extensions" }, + { name = "pathspec" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f5/db/4efed9504bc01309ab9c2da7e352cc223569f05478012b5d9ece38fd44d2/mypy-1.19.1.tar.gz", hash = "sha256:19d88bb05303fe63f71dd2c6270daca27cb9401c4ca8255fe50d1d920e0eb9ba", size = 3582404, upload-time = "2025-12-15T05:03:48.42Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/06/8a/19bfae96f6615aa8a0604915512e0289b1fad33d5909bf7244f02935d33a/mypy-1.19.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a8174a03289288c1f6c46d55cef02379b478bfbc8e358e02047487cad44c6ca1", size = 13206053, upload-time = "2025-12-15T05:03:46.622Z" }, + { url = "https://files.pythonhosted.org/packages/a5/34/3e63879ab041602154ba2a9f99817bb0c85c4df19a23a1443c8986e4d565/mypy-1.19.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ffcebe56eb09ff0c0885e750036a095e23793ba6c2e894e7e63f6d89ad51f22e", size = 12219134, upload-time = "2025-12-15T05:03:24.367Z" }, + { url = "https://files.pythonhosted.org/packages/89/cc/2db6f0e95366b630364e09845672dbee0cbf0bbe753a204b29a944967cd9/mypy-1.19.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b64d987153888790bcdb03a6473d321820597ab8dd9243b27a92153c4fa50fd2", size = 12731616, upload-time = "2025-12-15T05:02:44.725Z" }, + { url = "https://files.pythonhosted.org/packages/00/be/dd56c1fd4807bc1eba1cf18b2a850d0de7bacb55e158755eb79f77c41f8e/mypy-1.19.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c35d298c2c4bba75feb2195655dfea8124d855dfd7343bf8b8c055421eaf0cf8", size = 13620847, upload-time = "2025-12-15T05:03:39.633Z" }, + { url = "https://files.pythonhosted.org/packages/6d/42/332951aae42b79329f743bf1da088cd75d8d4d9acc18fbcbd84f26c1af4e/mypy-1.19.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:34c81968774648ab5ac09c29a375fdede03ba253f8f8287847bd480782f73a6a", size = 13834976, upload-time = "2025-12-15T05:03:08.786Z" }, + { url = "https://files.pythonhosted.org/packages/6f/63/e7493e5f90e1e085c562bb06e2eb32cae27c5057b9653348d38b47daaecc/mypy-1.19.1-cp312-cp312-win_amd64.whl", hash = "sha256:b10e7c2cd7870ba4ad9b2d8a6102eb5ffc1f16ca35e3de6bfa390c1113029d13", size = 10118104, upload-time = "2025-12-15T05:03:10.834Z" }, + { url = "https://files.pythonhosted.org/packages/de/9f/a6abae693f7a0c697dbb435aac52e958dc8da44e92e08ba88d2e42326176/mypy-1.19.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e3157c7594ff2ef1634ee058aafc56a82db665c9438fd41b390f3bde1ab12250", size = 13201927, upload-time = "2025-12-15T05:02:29.138Z" }, + { url = "https://files.pythonhosted.org/packages/9a/a4/45c35ccf6e1c65afc23a069f50e2c66f46bd3798cbe0d680c12d12935caa/mypy-1.19.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:bdb12f69bcc02700c2b47e070238f42cb87f18c0bc1fc4cdb4fb2bc5fd7a3b8b", size = 12206730, upload-time = "2025-12-15T05:03:01.325Z" }, + { url = "https://files.pythonhosted.org/packages/05/bb/cdcf89678e26b187650512620eec8368fded4cfd99cfcb431e4cdfd19dec/mypy-1.19.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f859fb09d9583a985be9a493d5cfc5515b56b08f7447759a0c5deaf68d80506e", size = 12724581, upload-time = "2025-12-15T05:03:20.087Z" }, + { url = "https://files.pythonhosted.org/packages/d1/32/dd260d52babf67bad8e6770f8e1102021877ce0edea106e72df5626bb0ec/mypy-1.19.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c9a6538e0415310aad77cb94004ca6482330fece18036b5f360b62c45814c4ef", size = 13616252, upload-time = "2025-12-15T05:02:49.036Z" }, + { url = "https://files.pythonhosted.org/packages/71/d0/5e60a9d2e3bd48432ae2b454b7ef2b62a960ab51292b1eda2a95edd78198/mypy-1.19.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:da4869fc5e7f62a88f3fe0b5c919d1d9f7ea3cef92d3689de2823fd27e40aa75", size = 13840848, upload-time = "2025-12-15T05:02:55.95Z" }, + { url = "https://files.pythonhosted.org/packages/98/76/d32051fa65ecf6cc8c6610956473abdc9b4c43301107476ac03559507843/mypy-1.19.1-cp313-cp313-win_amd64.whl", hash = "sha256:016f2246209095e8eda7538944daa1d60e1e8134d98983b9fc1e92c1fc0cb8dd", size = 10135510, upload-time = "2025-12-15T05:02:58.438Z" }, + { url = "https://files.pythonhosted.org/packages/de/eb/b83e75f4c820c4247a58580ef86fcd35165028f191e7e1ba57128c52782d/mypy-1.19.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:06e6170bd5836770e8104c8fdd58e5e725cfeb309f0a6c681a811f557e97eac1", size = 13199744, upload-time = "2025-12-15T05:03:30.823Z" }, + { url = "https://files.pythonhosted.org/packages/94/28/52785ab7bfa165f87fcbb61547a93f98bb20e7f82f90f165a1f69bce7b3d/mypy-1.19.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:804bd67b8054a85447c8954215a906d6eff9cabeabe493fb6334b24f4bfff718", size = 12215815, upload-time = "2025-12-15T05:02:42.323Z" }, + { url = "https://files.pythonhosted.org/packages/0a/c6/bdd60774a0dbfb05122e3e925f2e9e846c009e479dcec4821dad881f5b52/mypy-1.19.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:21761006a7f497cb0d4de3d8ef4ca70532256688b0523eee02baf9eec895e27b", size = 12740047, upload-time = "2025-12-15T05:03:33.168Z" }, + { url = "https://files.pythonhosted.org/packages/32/2a/66ba933fe6c76bd40d1fe916a83f04fed253152f451a877520b3c4a5e41e/mypy-1.19.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:28902ee51f12e0f19e1e16fbe2f8f06b6637f482c459dd393efddd0ec7f82045", size = 13601998, upload-time = "2025-12-15T05:03:13.056Z" }, + { url = "https://files.pythonhosted.org/packages/e3/da/5055c63e377c5c2418760411fd6a63ee2b96cf95397259038756c042574f/mypy-1.19.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:481daf36a4c443332e2ae9c137dfee878fcea781a2e3f895d54bd3002a900957", size = 13807476, upload-time = "2025-12-15T05:03:17.977Z" }, + { url = "https://files.pythonhosted.org/packages/cd/09/4ebd873390a063176f06b0dbf1f7783dd87bd120eae7727fa4ae4179b685/mypy-1.19.1-cp314-cp314-win_amd64.whl", hash = "sha256:8bb5c6f6d043655e055be9b542aa5f3bdd30e4f3589163e85f93f3640060509f", size = 10281872, upload-time = "2025-12-15T05:03:05.549Z" }, + { url = "https://files.pythonhosted.org/packages/8d/f4/4ce9a05ce5ded1de3ec1c1d96cf9f9504a04e54ce0ed55cfa38619a32b8d/mypy-1.19.1-py3-none-any.whl", hash = "sha256:f1235f5ea01b7db5468d53ece6aaddf1ad0b88d9e7462b86ef96fe04995d7247", size = 2471239, upload-time = "2025-12-15T05:03:07.248Z" }, +] + +[[package]] +name = "mypy-extensions" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/6e/371856a3fb9d31ca8dac321cda606860fa4548858c0cc45d9d1d4ca2628b/mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558", size = 6343, upload-time = "2025-04-22T14:54:24.164Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" }, +] + [[package]] name = "openai" version = "2.14.0" @@ -450,6 +552,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/b9/c538f279a4e237a006a2c98387d081e9eb060d203d8ed34467cc0f0b9b53/packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529", size = 74366, upload-time = "2026-01-21T20:50:37.788Z" }, ] +[[package]] +name = "pathspec" +version = "1.0.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fa/36/e27608899f9b8d4dff0617b2d9ab17ca5608956ca44461ac14ac48b44015/pathspec-1.0.4.tar.gz", hash = "sha256:0210e2ae8a21a9137c0d470578cb0e595af87edaa6ebf12ff176f14a02e0e645", size = 131200, upload-time = "2026-01-27T03:59:46.938Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/3c/2c197d226f9ea224a9ab8d197933f9da0ae0aac5b6e0f884e2b8d9c8e9f7/pathspec-1.0.4-py3-none-any.whl", hash = "sha256:fb6ae2fd4e7c921a165808a552060e722767cfa526f99ca5156ed2ce45a5c723", size = 55206, upload-time = "2026-01-27T03:59:45.137Z" }, +] + [[package]] name = "pluggy" version = "1.6.0" @@ -762,6 +873,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d0/02/fa464cdfbe6b26e0600b62c528b72d8608f5cc49f96b8d6e38c95d60c676/rpds_py-0.30.0-cp314-cp314t-win_amd64.whl", hash = "sha256:27f4b0e92de5bfbc6f86e43959e6edd1425c33b5e69aab0984a72047f2bcf1e3", size = 226532, upload-time = "2025-11-30T20:24:14.634Z" }, ] +[[package]] +name = "ruff" +version = "0.15.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/04/dc/4e6ac71b511b141cf626357a3946679abeba4cf67bc7cc5a17920f31e10d/ruff-0.15.1.tar.gz", hash = "sha256:c590fe13fb57c97141ae975c03a1aedb3d3156030cabd740d6ff0b0d601e203f", size = 4540855, upload-time = "2026-02-12T23:09:09.998Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/23/bf/e6e4324238c17f9d9120a9d60aa99a7daaa21204c07fcd84e2ef03bb5fd1/ruff-0.15.1-py3-none-linux_armv6l.whl", hash = "sha256:b101ed7cf4615bda6ffe65bdb59f964e9f4a0d3f85cbf0e54f0ab76d7b90228a", size = 10367819, upload-time = "2026-02-12T23:09:03.598Z" }, + { url = "https://files.pythonhosted.org/packages/b3/ea/c8f89d32e7912269d38c58f3649e453ac32c528f93bb7f4219258be2e7ed/ruff-0.15.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:939c995e9277e63ea632cc8d3fae17aa758526f49a9a850d2e7e758bfef46602", size = 10798618, upload-time = "2026-02-12T23:09:22.928Z" }, + { url = "https://files.pythonhosted.org/packages/5e/0f/1d0d88bc862624247d82c20c10d4c0f6bb2f346559d8af281674cf327f15/ruff-0.15.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:1d83466455fdefe60b8d9c8df81d3c1bbb2115cede53549d3b522ce2bc703899", size = 10148518, upload-time = "2026-02-12T23:08:58.339Z" }, + { url = "https://files.pythonhosted.org/packages/f5/c8/291c49cefaa4a9248e986256df2ade7add79388fe179e0691be06fae6f37/ruff-0.15.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a9457e3c3291024866222b96108ab2d8265b477e5b1534c7ddb1810904858d16", size = 10518811, upload-time = "2026-02-12T23:09:31.865Z" }, + { url = "https://files.pythonhosted.org/packages/c3/1a/f5707440e5ae43ffa5365cac8bbb91e9665f4a883f560893829cf16a606b/ruff-0.15.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:92c92b003e9d4f7fbd33b1867bb15a1b785b1735069108dfc23821ba045b29bc", size = 10196169, upload-time = "2026-02-12T23:09:17.306Z" }, + { url = "https://files.pythonhosted.org/packages/2a/ff/26ddc8c4da04c8fd3ee65a89c9fb99eaa5c30394269d424461467be2271f/ruff-0.15.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fe5c41ab43e3a06778844c586251eb5a510f67125427625f9eb2b9526535779", size = 10990491, upload-time = "2026-02-12T23:09:25.503Z" }, + { url = "https://files.pythonhosted.org/packages/fc/00/50920cb385b89413f7cdb4bb9bc8fc59c1b0f30028d8bccc294189a54955/ruff-0.15.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:66a6dd6df4d80dc382c6484f8ce1bcceb55c32e9f27a8b94c32f6c7331bf14fb", size = 11843280, upload-time = "2026-02-12T23:09:19.88Z" }, + { url = "https://files.pythonhosted.org/packages/5d/6d/2f5cad8380caf5632a15460c323ae326f1e1a2b5b90a6ee7519017a017ca/ruff-0.15.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6a4a42cbb8af0bda9bcd7606b064d7c0bc311a88d141d02f78920be6acb5aa83", size = 11274336, upload-time = "2026-02-12T23:09:14.907Z" }, + { url = "https://files.pythonhosted.org/packages/a3/1d/5f56cae1d6c40b8a318513599b35ea4b075d7dc1cd1d04449578c29d1d75/ruff-0.15.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4ab064052c31dddada35079901592dfba2e05f5b1e43af3954aafcbc1096a5b2", size = 11137288, upload-time = "2026-02-12T23:09:07.475Z" }, + { url = "https://files.pythonhosted.org/packages/cd/20/6f8d7d8f768c93b0382b33b9306b3b999918816da46537d5a61635514635/ruff-0.15.1-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:5631c940fe9fe91f817a4c2ea4e81f47bee3ca4aa646134a24374f3c19ad9454", size = 11070681, upload-time = "2026-02-12T23:08:55.43Z" }, + { url = "https://files.pythonhosted.org/packages/9a/67/d640ac76069f64cdea59dba02af2e00b1fa30e2103c7f8d049c0cff4cafd/ruff-0.15.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:68138a4ba184b4691ccdc39f7795c66b3c68160c586519e7e8444cf5a53e1b4c", size = 10486401, upload-time = "2026-02-12T23:09:27.927Z" }, + { url = "https://files.pythonhosted.org/packages/65/3d/e1429f64a3ff89297497916b88c32a5cc88eeca7e9c787072d0e7f1d3e1e/ruff-0.15.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:518f9af03bfc33c03bdb4cb63fabc935341bb7f54af500f92ac309ecfbba6330", size = 10197452, upload-time = "2026-02-12T23:09:12.147Z" }, + { url = "https://files.pythonhosted.org/packages/78/83/e2c3bade17dad63bf1e1c2ffaf11490603b760be149e1419b07049b36ef2/ruff-0.15.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:da79f4d6a826caaea95de0237a67e33b81e6ec2e25fc7e1993a4015dffca7c61", size = 10693900, upload-time = "2026-02-12T23:09:34.418Z" }, + { url = "https://files.pythonhosted.org/packages/a1/27/fdc0e11a813e6338e0706e8b39bb7a1d61ea5b36873b351acee7e524a72a/ruff-0.15.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:3dd86dccb83cd7d4dcfac303ffc277e6048600dfc22e38158afa208e8bf94a1f", size = 11227302, upload-time = "2026-02-12T23:09:36.536Z" }, + { url = "https://files.pythonhosted.org/packages/f6/58/ac864a75067dcbd3b95be5ab4eb2b601d7fbc3d3d736a27e391a4f92a5c1/ruff-0.15.1-py3-none-win32.whl", hash = "sha256:660975d9cb49b5d5278b12b03bb9951d554543a90b74ed5d366b20e2c57c2098", size = 10462555, upload-time = "2026-02-12T23:09:29.899Z" }, + { url = "https://files.pythonhosted.org/packages/e0/5e/d4ccc8a27ecdb78116feac4935dfc39d1304536f4296168f91ed3ec00cd2/ruff-0.15.1-py3-none-win_amd64.whl", hash = "sha256:c820fef9dd5d4172a6570e5721704a96c6679b80cf7be41659ed439653f62336", size = 11599956, upload-time = "2026-02-12T23:09:01.157Z" }, + { url = "https://files.pythonhosted.org/packages/2a/07/5bda6a85b220c64c65686bc85bd0bbb23b29c62b3a9f9433fa55f17cda93/ruff-0.15.1-py3-none-win_arm64.whl", hash = "sha256:5ff7d5f0f88567850f45081fac8f4ec212be8d0b963e385c3f7d0d2eb4899416", size = 10874604, upload-time = "2026-02-12T23:09:05.515Z" }, +] + [[package]] name = "sniffio" version = "1.3.1" @@ -861,7 +997,7 @@ wheels = [ [[package]] name = "vercel-ai-sdk" -version = "0.0.1.dev3" +version = "0.0.1.dev4" source = { editable = "." } dependencies = [ { name = "anthropic" }, @@ -874,10 +1010,12 @@ dependencies = [ [package.dev-dependencies] dev = [ + { name = "mypy" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "python-dotenv" }, { name = "rich" }, + { name = "ruff" }, ] [package.metadata] @@ -892,10 +1030,12 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ + { name = "mypy", specifier = ">=1.11" }, { name = "pytest", specifier = ">=8.0" }, { name = "pytest-asyncio", specifier = ">=0.24" }, { name = "python-dotenv", specifier = ">=1.2.1" }, { name = "rich", specifier = ">=14.2.0" }, + { name = "ruff", specifier = ">=0.8" }, ] [[package]] From 6887aa6910d05a63c2e708567cad0ea8841137a3 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Thu, 19 Feb 2026 10:32:42 -0800 Subject: [PATCH 02/20] Refactor core/tools.py to enable typed tool outputs --- src/vercel_ai_sdk/core/tools.py | 131 +++++++++++++++++--------------- 1 file changed, 68 insertions(+), 63 deletions(-) diff --git a/src/vercel_ai_sdk/core/tools.py b/src/vercel_ai_sdk/core/tools.py index ce1a474f..48cde37b 100644 --- a/src/vercel_ai_sdk/core/tools.py +++ b/src/vercel_ai_sdk/core/tools.py @@ -1,16 +1,21 @@ from __future__ import annotations import inspect -from collections.abc import Awaitable -from typing import Any, Callable, get_args, get_origin, get_type_hints, overload +from collections.abc import Awaitable, Callable +from typing import ( + Any, + get_args, + get_origin, + get_type_hints, +) import pydantic # Module-level tool registry - populated at decoration time -_tool_registry: dict[str, "Tool"] = {} +_tool_registry: dict[str, Tool] = {} -def get_tool(name: str) -> "Tool | None": +def get_tool(name: str) -> Tool | None: """Look up a tool by name from the global registry.""" return _tool_registry.get(name) @@ -41,76 +46,76 @@ def _is_optional(param_type: type) -> bool: class ToolSchema(pydantic.BaseModel): - """What the LLM sees: name, description, and JSON Schema for parameters. - - This is the serializable, function-free description of a tool. - Use this when you need tool metadata without the callable (e.g. passing - tool schemas across a serialization boundary to an LLM activity). - """ + """What the LLM sees: name, description, and JSON Schema for parameters.""" name: str description: str - tool_schema: dict[str, Any] + param_schema: dict[str, Any] + return_type: Any -class Tool(ToolSchema): - """A ToolSchema plus the async callable that implements it.""" +class Tool[**P, R]: + def __init__(self, fn: Callable[P, Awaitable[R]], schema: ToolSchema) -> None: + self._fn = fn + self.schema = schema - fn: Callable[..., Awaitable[Any]] = pydantic.Field(exclude=True) + async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: + return await self._fn(*args, **kwargs) - model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) + @property + def name(self) -> str: + return self.schema.name + @property + def description(self) -> str: + return self.schema.description -@overload -def tool(fn: Callable[..., Awaitable[Any]]) -> Tool: ... + @property + def param_schema(self) -> dict[str, Any]: + return self.schema.param_schema -@overload -def tool(fn: None = None) -> Callable[[Callable[..., Awaitable[Any]]], Tool]: ... +def tool[**P, R](fn: Callable[P, Awaitable[R]]) -> Tool[P, R]: + """Decorator to define a tool from an async function.""" + # 1. build tool schema by parsing the function + sig = inspect.signature(fn) + hints = get_type_hints(fn) if hasattr(fn, "__annotations__") else {} -def tool( - fn: Callable[..., Awaitable[Any]] | None = None, -) -> Tool | Callable[[Callable[..., Awaitable[Any]]], Tool]: - """Decorator to define a tool from an async function.""" + properties = {} + required = [] + + for param_name, param in sig.parameters.items(): + param_type = hints.get(param_name, str) + + # Skip Runtime-typed parameters - they're injected, not from LLM + if _is_runtime_type(param_type): + continue + + properties[param_name] = _get_param_schema(param_type) + + if param.default is inspect.Parameter.empty and not _is_optional(param_type): + required.append(param_name) + + parameters = { + "type": "object", + "properties": properties, + } + + if required: + parameters["required"] = required + + # 2. instantiate the tool + + schema = ToolSchema( + name=fn.__name__, + description=inspect.getdoc(fn) or "", + param_schema=parameters, + return_type=hints.get("return", None), + ) + + t = Tool(fn=fn, schema=schema) - def make_tool(f: Callable[..., Awaitable[Any]]) -> Tool: - sig = inspect.signature(f) - hints = get_type_hints(f) if hasattr(f, "__annotations__") else {} - - properties = {} - required = [] - - for param_name, param in sig.parameters.items(): - param_type = hints.get(param_name, str) - - # Skip Runtime-typed parameters - they're injected, not from LLM - if _is_runtime_type(param_type): - continue - properties[param_name] = _get_param_schema(param_type) - - if param.default is inspect.Parameter.empty and not _is_optional( - param_type - ): - required.append(param_name) - - parameters = { - "type": "object", - "properties": properties, - } - if required: - parameters["required"] = required - - t = Tool( - name=f.__name__, - description=inspect.getdoc(f) or "", - tool_schema=parameters, - fn=f, - ) - # Register in global registry - _tool_registry[t.name] = t - return t - - if fn is not None: - return make_tool(fn) - return make_tool + # Register in global registry + _tool_registry[t.name] = t + return t From 074d371ae4934aea209ccb59ce08000ac9f57532 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Thu, 19 Feb 2026 11:25:18 -0800 Subject: [PATCH 03/20] Refactor runtime.py to match the updated tool API --- src/vercel_ai_sdk/core/runtime.py | 11 +---- src/vercel_ai_sdk/core/tools.py | 75 ++++++++++++++----------------- 2 files changed, 35 insertions(+), 51 deletions(-) diff --git a/src/vercel_ai_sdk/core/runtime.py b/src/vercel_ai_sdk/core/runtime.py index b055b9f7..10203eae 100644 --- a/src/vercel_ai_sdk/core/runtime.py +++ b/src/vercel_ai_sdk/core/runtime.py @@ -225,15 +225,8 @@ async def execute_tool( if tool is None: raise ValueError(f"Tool not found in registry: {tool_call.tool_name}") - kwargs: dict[str, Any] = ( - json.loads(tool_call.tool_args) if tool_call.tool_args else {} - ) - - # Inject runtime if the tool has a Runtime-typed parameter - if rt and (runtime_param := _find_runtime_param(tool.fn)): - kwargs[runtime_param] = rt - - result = await tool.fn(**kwargs) + # TODO catch validation error and json error + result = await tool.validate_and_call(tool_call.tool_args, rt) tool_call.set_result(result) # Record for checkpoint diff --git a/src/vercel_ai_sdk/core/tools.py b/src/vercel_ai_sdk/core/tools.py index 48cde37b..fbdedb58 100644 --- a/src/vercel_ai_sdk/core/tools.py +++ b/src/vercel_ai_sdk/core/tools.py @@ -1,16 +1,15 @@ from __future__ import annotations import inspect +import json from collections.abc import Awaitable, Callable -from typing import ( - Any, - get_args, - get_origin, - get_type_hints, -) +from typing import TYPE_CHECKING, Any, get_type_hints import pydantic +if TYPE_CHECKING: + from . import runtime as runtime_ + # Module-level tool registry - populated at decoration time _tool_registry: dict[str, Tool] = {} @@ -28,23 +27,6 @@ def _is_runtime_type(hint: Any) -> bool: return hint is Runtime -def _get_param_schema(param_type: type) -> dict[str, Any]: - """Get JSON schema for a Python type using Pydantic's TypeAdapter.""" - schema = pydantic.TypeAdapter(param_type).json_schema() - if "$defs" in schema and len(schema.get("$defs", {})) == 0: - del schema["$defs"] - return schema - - -def _is_optional(param_type: type) -> bool: - """Check if a type is Optional (Union with None).""" - origin = get_origin(param_type) - if origin is not None: - args = get_args(param_type) - return type(None) in args - return False - - class ToolSchema(pydantic.BaseModel): """What the LLM sees: name, description, and JSON Schema for parameters.""" @@ -55,13 +37,31 @@ class ToolSchema(pydantic.BaseModel): class Tool[**P, R]: - def __init__(self, fn: Callable[P, Awaitable[R]], schema: ToolSchema) -> None: + def __init__( + self, + fn: Callable[P, Awaitable[R]], + validator: type[pydantic.BaseModel], + schema: ToolSchema, + ) -> None: self._fn = fn + self._validator = validator self.schema = schema async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: return await self._fn(*args, **kwargs) + async def validate_and_call( + self, json_str: str, runtime: runtime_.Runtime | None + ) -> R: + kwargs = json.loads(json_str) + + if runtime and (rt_param := runtime_._find_runtime_param(self._fn)): + kwargs[rt_param] = runtime + + # validate llm-generated inputs + self._validator.model_validate(kwargs) + return await self(**kwargs) # type: ignore[arg-type] + @property def name(self) -> str: return self.schema.name @@ -82,40 +82,31 @@ def tool[**P, R](fn: Callable[P, Awaitable[R]]) -> Tool[P, R]: sig = inspect.signature(fn) hints = get_type_hints(fn) if hasattr(fn, "__annotations__") else {} - properties = {} - required = [] + fields = {} for param_name, param in sig.parameters.items(): param_type = hints.get(param_name, str) - # Skip Runtime-typed parameters - they're injected, not from LLM if _is_runtime_type(param_type): continue + if param.default is inspect.Parameter.empty: + fields[param_name] = (param_type, ...) + else: + fields[param_name] = (param_type, param.default) - properties[param_name] = _get_param_schema(param_type) - - if param.default is inspect.Parameter.empty and not _is_optional(param_type): - required.append(param_name) - - parameters = { - "type": "object", - "properties": properties, - } - - if required: - parameters["required"] = required + validator = pydantic.create_model(f"{fn.__name__}_Args", **fields) # 2. instantiate the tool schema = ToolSchema( name=fn.__name__, description=inspect.getdoc(fn) or "", - param_schema=parameters, + param_schema=validator.model_json_schema(), return_type=hints.get("return", None), ) - t = Tool(fn=fn, schema=schema) + t = Tool(fn=fn, validator=validator, schema=schema) - # Register in global registry + # 3. register in global registry _tool_registry[t.name] = t return t From 88d50b304bb47df0d21667ec7e1eae255da3d42b Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Thu, 19 Feb 2026 12:36:38 -0800 Subject: [PATCH 04/20] Refactor remaining callsites that interact with tools in some way --- src/vercel_ai_sdk/agent/agent.py | 1 - src/vercel_ai_sdk/agent/tools.py | 1 - src/vercel_ai_sdk/anthropic/__init__.py | 8 ++++---- src/vercel_ai_sdk/core/llm.py | 4 ++-- src/vercel_ai_sdk/core/runtime.py | 4 ++-- src/vercel_ai_sdk/core/tools.py | 17 ++++++++++------- src/vercel_ai_sdk/mcp/client.py | 18 +++++++++++++----- src/vercel_ai_sdk/openai/__init__.py | 8 ++++---- tests/conftest.py | 4 ++-- tests/core/test_tools.py | 24 ++++++++++++------------ tests/mcp/test_client.py | 6 +++--- 11 files changed, 52 insertions(+), 43 deletions(-) diff --git a/src/vercel_ai_sdk/agent/agent.py b/src/vercel_ai_sdk/agent/agent.py index 4d7ed2ab..400aa88d 100644 --- a/src/vercel_ai_sdk/agent/agent.py +++ b/src/vercel_ai_sdk/agent/agent.py @@ -3,7 +3,6 @@ import asyncio import dataclasses import traceback - import pydantic import vercel_ai_sdk as ai diff --git a/src/vercel_ai_sdk/agent/tools.py b/src/vercel_ai_sdk/agent/tools.py index 7b3d905b..422fd22d 100644 --- a/src/vercel_ai_sdk/agent/tools.py +++ b/src/vercel_ai_sdk/agent/tools.py @@ -1,7 +1,6 @@ from __future__ import annotations import contextvars - import vercel_ai_sdk as ai from . import proto diff --git a/src/vercel_ai_sdk/anthropic/__init__.py b/src/vercel_ai_sdk/anthropic/__init__.py index 1d6281c3..d4f89390 100644 --- a/src/vercel_ai_sdk/anthropic/__init__.py +++ b/src/vercel_ai_sdk/anthropic/__init__.py @@ -10,13 +10,13 @@ from .. import core -def _tools_to_anthropic(tools: Sequence[core.tools.ToolSchema]) -> list[dict[str, Any]]: +def _tools_to_anthropic(tools: Sequence[core.tools.Tool]) -> list[dict[str, Any]]: """Convert internal Tool objects to Anthropic tool schema format.""" return [ { "name": tool.name, "description": tool.description, - "input_schema": tool.tool_schema, + "input_schema": tool.param_schema, } for tool in tools ] @@ -117,7 +117,7 @@ def __init__( async def stream_events( self, messages: list[core.messages.Message], - tools: Sequence[core.tools.ToolSchema] | None = None, + tools: Sequence[core.tools.Tool] | None = None, ) -> AsyncGenerator[core.llm.StreamEvent, None]: """Yield raw stream events from Anthropic API.""" system_prompt, anthropic_messages = _messages_to_anthropic(messages) @@ -206,7 +206,7 @@ async def stream_events( async def stream( self, messages: list[core.messages.Message], - tools: Sequence[core.tools.ToolSchema] | None = None, + tools: Sequence[core.tools.Tool] | None = None, ) -> AsyncGenerator[core.messages.Message, None]: """Stream Messages (uses StreamProcessor internally).""" handler = core.llm.StreamHandler() diff --git a/src/vercel_ai_sdk/core/llm.py b/src/vercel_ai_sdk/core/llm.py index 8a402fee..7b857164 100644 --- a/src/vercel_ai_sdk/core/llm.py +++ b/src/vercel_ai_sdk/core/llm.py @@ -215,7 +215,7 @@ class LanguageModel(abc.ABC): async def stream( self, messages: list[messages_.Message], - tools: Sequence[tools_.ToolSchema] | None = None, + tools: Sequence[tools_.Tool] | None = None, ) -> AsyncGenerator[messages_.Message, None]: raise NotImplementedError yield @@ -223,7 +223,7 @@ async def stream( async def buffer( self, messages: list[messages_.Message], - tools: Sequence[tools_.ToolSchema] | None = None, + tools: Sequence[tools_.Tool] | None = None, ) -> messages_.Message: """Drain the stream and return the final message.""" final = None diff --git a/src/vercel_ai_sdk/core/runtime.py b/src/vercel_ai_sdk/core/runtime.py index 10203eae..b0ace4d9 100644 --- a/src/vercel_ai_sdk/core/runtime.py +++ b/src/vercel_ai_sdk/core/runtime.py @@ -187,7 +187,7 @@ def _find_runtime_param(fn: Callable[..., Any]) -> str | None: async def stream_step( llm: llm_.LanguageModel, messages: list[messages_.Message], - tools: Sequence[tools_.ToolSchema] | None = None, + tools: Sequence[tools_.Tool] | None = None, label: str | None = None, ) -> AsyncGenerator[messages_.Message, None]: """Single LLM call that streams to Runtime.""" @@ -243,7 +243,7 @@ async def execute_tool( async def stream_loop( llm: llm_.LanguageModel, messages: list[messages_.Message], - tools: Sequence[tools_.ToolSchema], + tools: Sequence[tools_.Tool], label: str | None = None, ) -> streams_.StreamResult: """Agent loop: stream LLM, execute tools, repeat until done.""" diff --git a/src/vercel_ai_sdk/core/tools.py b/src/vercel_ai_sdk/core/tools.py index fbdedb58..b4bc7bc6 100644 --- a/src/vercel_ai_sdk/core/tools.py +++ b/src/vercel_ai_sdk/core/tools.py @@ -40,8 +40,8 @@ class Tool[**P, R]: def __init__( self, fn: Callable[P, Awaitable[R]], - validator: type[pydantic.BaseModel], schema: ToolSchema, + validator: type[pydantic.BaseModel] | None = None, ) -> None: self._fn = fn self._validator = validator @@ -53,14 +53,17 @@ async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: async def validate_and_call( self, json_str: str, runtime: runtime_.Runtime | None ) -> R: - kwargs = json.loads(json_str) + from .runtime import _find_runtime_param - if runtime and (rt_param := runtime_._find_runtime_param(self._fn)): + kwargs = json.loads(json_str) if json_str else {} + + if runtime and (rt_param := _find_runtime_param(self._fn)): kwargs[rt_param] = runtime - # validate llm-generated inputs - self._validator.model_validate(kwargs) - return await self(**kwargs) # type: ignore[arg-type] + # validate llm-generated inputs (skipped for MCP tools) + if self._validator is not None: + self._validator.model_validate(kwargs) + return await self._fn(**kwargs) # type: ignore[arg-type] @property def name(self) -> str: @@ -105,7 +108,7 @@ def tool[**P, R](fn: Callable[P, Awaitable[R]]) -> Tool[P, R]: return_type=hints.get("return", None), ) - t = Tool(fn=fn, validator=validator, schema=schema) + t = Tool(fn=fn, schema=schema, validator=validator) # 3. register in global registry _tool_registry[t.name] = t diff --git a/src/vercel_ai_sdk/mcp/client.py b/src/vercel_ai_sdk/mcp/client.py index 97524fe4..133af068 100644 --- a/src/vercel_ai_sdk/mcp/client.py +++ b/src/vercel_ai_sdk/mcp/client.py @@ -5,8 +5,10 @@ import contextvars import dataclasses import json -from typing import Any, Callable +from collections.abc import Callable +from typing import Any +import httpx import mcp.client.session import mcp.client.stdio import mcp.client.streamable_http @@ -207,8 +209,9 @@ async def get_http_tools( connection_key = f"http:{url}" def transport_factory(): - return mcp.client.streamable_http.streamablehttp_client( - url=url, headers=headers + http_client = httpx.AsyncClient(headers=headers) if headers else None + return mcp.client.streamable_http.streamable_http_client( + url=url, http_client=http_client ) client = await _get_or_create_connection(connection_key, transport_factory) @@ -231,11 +234,16 @@ def _mcp_tool_to_native( if tool_prefix: name = f"{tool_prefix}_{name}" - t = core.tools.Tool( + schema = core.tools.ToolSchema( name=name, description=mcp_tool.description or "", - tool_schema=mcp_tool.inputSchema, + param_schema=mcp_tool.inputSchema, + return_type=Any, + ) + + t = core.tools.Tool( fn=_make_tool_fn(connection_key, mcp_tool.name, transport_factory), + schema=schema, ) # Register so execute_tool() can find it by name core.tools._tool_registry[name] = t diff --git a/src/vercel_ai_sdk/openai/__init__.py b/src/vercel_ai_sdk/openai/__init__.py index 7975b3b6..f0b66bf6 100644 --- a/src/vercel_ai_sdk/openai/__init__.py +++ b/src/vercel_ai_sdk/openai/__init__.py @@ -9,7 +9,7 @@ from .. import core -def _tools_to_openai(tools: Sequence[core.tools.ToolSchema]) -> list[dict[str, Any]]: +def _tools_to_openai(tools: Sequence[core.tools.Tool]) -> list[dict[str, Any]]: """Convert internal Tool objects to OpenAI tool schema format.""" return [ { @@ -17,7 +17,7 @@ def _tools_to_openai(tools: Sequence[core.tools.ToolSchema]) -> list[dict[str, A "function": { "name": tool.name, "description": tool.description, - "parameters": tool.tool_schema, + "parameters": tool.param_schema, }, } for tool in tools @@ -132,7 +132,7 @@ def __init__( async def stream_events( self, messages: list[core.messages.Message], - tools: Sequence[core.tools.ToolSchema] | None = None, + tools: Sequence[core.tools.Tool] | None = None, ) -> AsyncGenerator[core.llm.StreamEvent, None]: """Yield raw stream events from OpenAI API.""" openai_messages = _messages_to_openai(messages) @@ -243,7 +243,7 @@ async def stream_events( async def stream( self, messages: list[core.messages.Message], - tools: Sequence[core.tools.ToolSchema] | None = None, + tools: Sequence[core.tools.Tool] | None = None, ) -> AsyncGenerator[core.messages.Message, None]: """Stream Messages (uses StreamHandler internally).""" handler = core.llm.StreamHandler() diff --git a/tests/conftest.py b/tests/conftest.py index e382fcd9..754a633e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Sequence import vercel_ai_sdk as ai from vercel_ai_sdk.core import messages @@ -17,7 +17,7 @@ def __init__(self, responses: list[list[messages.Message]]) -> None: async def stream( self, messages: list[messages.Message], - tools: list[ai.Tool] | None = None, + tools: Sequence[ai.Tool] | None = None, ) -> AsyncGenerator[messages.Message, None]: if self._call_index >= len(self._responses): raise RuntimeError("MockLLM: no more responses configured") diff --git a/tests/core/test_tools.py b/tests/core/test_tools.py index 3923802b..b465a612 100644 --- a/tests/core/test_tools.py +++ b/tests/core/test_tools.py @@ -20,10 +20,10 @@ async def greet(name: str, count: int) -> str: assert greet.name == "greet" assert greet.description == "Say hello." - props = greet.tool_schema["properties"] - assert props["name"] == {"type": "string"} - assert props["count"] == {"type": "integer"} - assert set(greet.tool_schema["required"]) == {"name", "count"} + props = greet.param_schema["properties"] + assert props["name"]["type"] == "string" + assert props["count"]["type"] == "integer" + assert set(greet.param_schema["required"]) == {"name", "count"} def test_optional_param_not_required(): @@ -32,10 +32,10 @@ async def search(query: str, limit: Optional[int] = None) -> str: """Search.""" return query - assert "query" in search.tool_schema.get("required", []) - assert "limit" not in search.tool_schema.get("required", []) + assert "query" in search.param_schema.get("required", []) + assert "limit" not in search.param_schema.get("required", []) # limit should still appear in properties - assert "limit" in search.tool_schema["properties"] + assert "limit" in search.param_schema["properties"] def test_default_value_not_required(): @@ -54,7 +54,7 @@ async def send(recipients: list[str], urgent: bool = False) -> str: """Send message.""" return "sent" - props = send.tool_schema["properties"] + props = send.param_schema["properties"] assert props["recipients"]["type"] == "array" assert props["recipients"]["items"]["type"] == "string" @@ -68,10 +68,10 @@ async def needs_runtime(query: str, rt: Runtime) -> str: """Tool that needs runtime.""" return query - props = needs_runtime.tool_schema["properties"] + props = needs_runtime.param_schema["properties"] assert "rt" not in props assert "query" in props - assert set(needs_runtime.tool_schema.get("required", [])) == {"query"} + assert set(needs_runtime.param_schema.get("required", [])) == {"query"} # -- Registry ------------------------------------------------------------- @@ -100,7 +100,7 @@ async def add(a: int, b: int) -> int: """Add two numbers.""" return a + b - result = await add.fn(a=1, b=2) + result = await add(a=1, b=2) assert result == 3 @@ -108,4 +108,4 @@ async def add(a: int, b: int) -> int: def search_required(tool: ai.Tool) -> list[str]: - return tool.tool_schema.get("required", []) + return tool.param_schema.get("required", []) diff --git a/tests/mcp/test_client.py b/tests/mcp/test_client.py index 7853fb97..c9d3893b 100644 --- a/tests/mcp/test_client.py +++ b/tests/mcp/test_client.py @@ -55,11 +55,11 @@ def test_mcp_tool_to_native_with_prefix(): def test_mcp_tool_to_native_schema_preserved(): - """The inputSchema from the MCP tool is passed through as tool_schema.""" + """The inputSchema from the MCP tool is passed through as param_schema.""" mcp_tool = _fake_mcp_tool() native = _mcp_tool_to_native(mcp_tool, "test:key", _noop_transport_factory, None) - assert native.tool_schema == mcp_tool.inputSchema + assert native.param_schema == mcp_tool.inputSchema assert native.description == "Echo input" @@ -80,7 +80,7 @@ async def fake_fn(**kwargs): mcp_tool = _fake_mcp_tool(name="mcp_e2e_echo") native = _mcp_tool_to_native(mcp_tool, "test:key", _noop_transport_factory, None) # Replace the real fn (which would try to connect) with our fake - native.fn = fake_fn + native._fn = fake_fn _tool_registry[native.name] = native async def graph(llm: ai.LanguageModel): From bb5823eb68c425377fa0e6f36272d56275ead40c Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Thu, 19 Feb 2026 13:00:19 -0800 Subject: [PATCH 05/20] Add a ToolLike protocol to decouple tool schema serialization boundary --- examples/temporal-durable/workflow.py | 4 ++-- src/vercel_ai_sdk/__init__.py | 3 ++- src/vercel_ai_sdk/anthropic/__init__.py | 6 +++--- src/vercel_ai_sdk/core/llm.py | 4 ++-- src/vercel_ai_sdk/core/runtime.py | 4 ++-- src/vercel_ai_sdk/core/tools.py | 14 +++++++++++++- src/vercel_ai_sdk/openai/__init__.py | 6 +++--- tests/conftest.py | 2 +- 8 files changed, 28 insertions(+), 15 deletions(-) diff --git a/examples/temporal-durable/workflow.py b/examples/temporal-durable/workflow.py index 1f747286..1f6d9c9a 100644 --- a/examples/temporal-durable/workflow.py +++ b/examples/temporal-durable/workflow.py @@ -28,12 +28,12 @@ def __init__( async def stream( self, messages: list[ai.Message], - tools: Sequence[ai.ToolSchema] | None = None, + tools: Sequence[ai.ToolLike] | None = None, ) -> AsyncGenerator[ai.Message, None]: result = await self.call_fn( activities.LLMCallParams( messages=[m.model_dump() for m in messages], - tool_schemas=[t.model_dump() for t in (tools or [])], + tool_schemas=[t.schema.model_dump() for t in (tools or [])], ) ) yield ai.Message.model_validate(result.message) diff --git a/src/vercel_ai_sdk/__init__.py b/src/vercel_ai_sdk/__init__.py index fc25bbce..7dff172e 100644 --- a/src/vercel_ai_sdk/__init__.py +++ b/src/vercel_ai_sdk/__init__.py @@ -12,7 +12,7 @@ HookPart, make_messages, ) -from .core.tools import ToolSchema, Tool, tool +from .core.tools import ToolLike, ToolSchema, Tool, tool from .core.llm import LanguageModel from .core.streams import StreamResult, stream from .core.runtime import ( @@ -37,6 +37,7 @@ "ToolPart", "ToolDelta", "ReasoningPart", + "ToolLike", "ToolSchema", "Tool", "LanguageModel", diff --git a/src/vercel_ai_sdk/anthropic/__init__.py b/src/vercel_ai_sdk/anthropic/__init__.py index d4f89390..3d3f7e27 100644 --- a/src/vercel_ai_sdk/anthropic/__init__.py +++ b/src/vercel_ai_sdk/anthropic/__init__.py @@ -10,7 +10,7 @@ from .. import core -def _tools_to_anthropic(tools: Sequence[core.tools.Tool]) -> list[dict[str, Any]]: +def _tools_to_anthropic(tools: Sequence[core.tools.ToolLike]) -> list[dict[str, Any]]: """Convert internal Tool objects to Anthropic tool schema format.""" return [ { @@ -117,7 +117,7 @@ def __init__( async def stream_events( self, messages: list[core.messages.Message], - tools: Sequence[core.tools.Tool] | None = None, + tools: Sequence[core.tools.ToolLike] | None = None, ) -> AsyncGenerator[core.llm.StreamEvent, None]: """Yield raw stream events from Anthropic API.""" system_prompt, anthropic_messages = _messages_to_anthropic(messages) @@ -206,7 +206,7 @@ async def stream_events( async def stream( self, messages: list[core.messages.Message], - tools: Sequence[core.tools.Tool] | None = None, + tools: Sequence[core.tools.ToolLike] | None = None, ) -> AsyncGenerator[core.messages.Message, None]: """Stream Messages (uses StreamProcessor internally).""" handler = core.llm.StreamHandler() diff --git a/src/vercel_ai_sdk/core/llm.py b/src/vercel_ai_sdk/core/llm.py index 7b857164..ee8bcaa2 100644 --- a/src/vercel_ai_sdk/core/llm.py +++ b/src/vercel_ai_sdk/core/llm.py @@ -215,7 +215,7 @@ class LanguageModel(abc.ABC): async def stream( self, messages: list[messages_.Message], - tools: Sequence[tools_.Tool] | None = None, + tools: Sequence[tools_.ToolLike] | None = None, ) -> AsyncGenerator[messages_.Message, None]: raise NotImplementedError yield @@ -223,7 +223,7 @@ async def stream( async def buffer( self, messages: list[messages_.Message], - tools: Sequence[tools_.Tool] | None = None, + tools: Sequence[tools_.ToolLike] | None = None, ) -> messages_.Message: """Drain the stream and return the final message.""" final = None diff --git a/src/vercel_ai_sdk/core/runtime.py b/src/vercel_ai_sdk/core/runtime.py index b0ace4d9..67d6da2f 100644 --- a/src/vercel_ai_sdk/core/runtime.py +++ b/src/vercel_ai_sdk/core/runtime.py @@ -187,7 +187,7 @@ def _find_runtime_param(fn: Callable[..., Any]) -> str | None: async def stream_step( llm: llm_.LanguageModel, messages: list[messages_.Message], - tools: Sequence[tools_.Tool] | None = None, + tools: Sequence[tools_.ToolLike] | None = None, label: str | None = None, ) -> AsyncGenerator[messages_.Message, None]: """Single LLM call that streams to Runtime.""" @@ -243,7 +243,7 @@ async def execute_tool( async def stream_loop( llm: llm_.LanguageModel, messages: list[messages_.Message], - tools: Sequence[tools_.Tool], + tools: Sequence[tools_.ToolLike], label: str | None = None, ) -> streams_.StreamResult: """Agent loop: stream LLM, execute tools, repeat until done.""" diff --git a/src/vercel_ai_sdk/core/tools.py b/src/vercel_ai_sdk/core/tools.py index b4bc7bc6..0c47abb9 100644 --- a/src/vercel_ai_sdk/core/tools.py +++ b/src/vercel_ai_sdk/core/tools.py @@ -3,7 +3,7 @@ import inspect import json from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Any, get_type_hints +from typing import TYPE_CHECKING, Any, Protocol, get_type_hints, runtime_checkable import pydantic @@ -27,6 +27,18 @@ def _is_runtime_type(hint: Any) -> bool: return hint is Runtime +@runtime_checkable +class ToolLike(Protocol): + """Anything the LLM layer can use as a tool definition.""" + + @property + def name(self) -> str: ... + @property + def description(self) -> str: ... + @property + def param_schema(self) -> dict[str, Any]: ... + + class ToolSchema(pydantic.BaseModel): """What the LLM sees: name, description, and JSON Schema for parameters.""" diff --git a/src/vercel_ai_sdk/openai/__init__.py b/src/vercel_ai_sdk/openai/__init__.py index f0b66bf6..72de5a25 100644 --- a/src/vercel_ai_sdk/openai/__init__.py +++ b/src/vercel_ai_sdk/openai/__init__.py @@ -9,7 +9,7 @@ from .. import core -def _tools_to_openai(tools: Sequence[core.tools.Tool]) -> list[dict[str, Any]]: +def _tools_to_openai(tools: Sequence[core.tools.ToolLike]) -> list[dict[str, Any]]: """Convert internal Tool objects to OpenAI tool schema format.""" return [ { @@ -132,7 +132,7 @@ def __init__( async def stream_events( self, messages: list[core.messages.Message], - tools: Sequence[core.tools.Tool] | None = None, + tools: Sequence[core.tools.ToolLike] | None = None, ) -> AsyncGenerator[core.llm.StreamEvent, None]: """Yield raw stream events from OpenAI API.""" openai_messages = _messages_to_openai(messages) @@ -243,7 +243,7 @@ async def stream_events( async def stream( self, messages: list[core.messages.Message], - tools: Sequence[core.tools.Tool] | None = None, + tools: Sequence[core.tools.ToolLike] | None = None, ) -> AsyncGenerator[core.messages.Message, None]: """Stream Messages (uses StreamHandler internally).""" handler = core.llm.StreamHandler() diff --git a/tests/conftest.py b/tests/conftest.py index 754a633e..6445cd8c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,7 +17,7 @@ def __init__(self, responses: list[list[messages.Message]]) -> None: async def stream( self, messages: list[messages.Message], - tools: Sequence[ai.Tool] | None = None, + tools: Sequence[ai.ToolLike] | None = None, ) -> AsyncGenerator[messages.Message, None]: if self._call_index >= len(self._responses): raise RuntimeError("MockLLM: no more responses configured") From 5570c3d409ca1e5308646a91770de1ab1c25e8b0 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Thu, 19 Feb 2026 14:50:58 -0800 Subject: [PATCH 06/20] Add set_error to clean up tool error logic --- examples/multiagent-textual/server.py | 4 +- examples/samples/hooks.py | 2 +- src/vercel_ai_sdk/agent/agent.py | 36 ++++++-------- src/vercel_ai_sdk/ai_sdk_ui/adapter.py | 25 ++++++---- src/vercel_ai_sdk/anthropic/__init__.py | 19 +++---- src/vercel_ai_sdk/core/checkpoint.py | 8 ++- src/vercel_ai_sdk/core/messages.py | 23 ++++++--- src/vercel_ai_sdk/core/runtime.py | 66 +++++++++++++++---------- src/vercel_ai_sdk/openai/__init__.py | 4 +- tests/core/test_messages.py | 12 ++++- 10 files changed, 115 insertions(+), 84 deletions(-) diff --git a/examples/multiagent-textual/server.py b/examples/multiagent-textual/server.py index 5e0f071d..e12cd06e 100644 --- a/examples/multiagent-textual/server.py +++ b/examples/multiagent-textual/server.py @@ -81,7 +81,7 @@ async def mothership_branch(llm: ai.LanguageModel, query: str): if approval.granted: await ai.execute_tool(tc, message=result.last_message) else: - tc.set_result(f"Denied: {approval.reason}") + tc.set_error(f"Denied: {approval.reason}") else: await ai.execute_tool(tc, message=result.last_message) @@ -113,7 +113,7 @@ async def data_center_branch(llm: ai.LanguageModel, query: str): if approval.granted: await ai.execute_tool(tc, message=result.last_message) else: - tc.set_result(f"Access denied: {approval.reason}") + tc.set_error(f"Access denied: {approval.reason}") else: await ai.execute_tool(tc, message=result.last_message) diff --git a/examples/samples/hooks.py b/examples/samples/hooks.py index 0fc89353..0ff7031c 100644 --- a/examples/samples/hooks.py +++ b/examples/samples/hooks.py @@ -43,7 +43,7 @@ async def graph(llm: ai.LanguageModel, query: str): if approval.granted: await ai.execute_tool(tc, message=result.last_message) else: - tc.set_result({"error": f"Rejected: {approval.reason}"}) + tc.set_error(f"Rejected: {approval.reason}") else: await ai.execute_tool(tc, message=result.last_message) diff --git a/src/vercel_ai_sdk/agent/agent.py b/src/vercel_ai_sdk/agent/agent.py index 400aa88d..f230ce28 100644 --- a/src/vercel_ai_sdk/agent/agent.py +++ b/src/vercel_ai_sdk/agent/agent.py @@ -2,7 +2,6 @@ import asyncio import dataclasses -import traceback import pydantic import vercel_ai_sdk as ai @@ -40,28 +39,21 @@ class Agent: async def _execute_tool( self, tc: ai.ToolPart, message: ai.Message | None = None ) -> None: - """Execute a single tool call with approval check and error handling.""" - try: - # TODO this should be tucked away into the framework - # and done using Pydantic - approval = await ToolApproval.create( - f"approve_{tc.tool_call_id}", - metadata={"tool_name": tc.tool_name, "tool_args": tc.tool_args}, - ) + """Execute a single tool call with approval check. - if approval.granted: - await ai.execute_tool(tc, message=message) - else: - tc.set_result({"error": "Tool call was denied by the user."}) - return - - except Exception as exc: - tc.set_result( - { - "error": f"{type(exc).__name__}: {exc}", - "traceback": traceback.format_exc(), - } - ) + Tool execution errors are handled inside ``ai.execute_tool``. + """ + # TODO this should be tucked away into the framework + # and done using Pydantic + approval = await ToolApproval.create( + f"approve_{tc.tool_call_id}", + metadata={"tool_name": tc.tool_name, "tool_args": tc.tool_args}, + ) + + if approval.granted: + await ai.execute_tool(tc, message=message) + else: + tc.set_error("Tool call was denied by the user.") async def _loop( self, diff --git a/src/vercel_ai_sdk/ai_sdk_ui/adapter.py b/src/vercel_ai_sdk/ai_sdk_ui/adapter.py index 417b6fe3..4b8d2305 100644 --- a/src/vercel_ai_sdk/ai_sdk_ui/adapter.py +++ b/src/vercel_ai_sdk/ai_sdk_ui/adapter.py @@ -179,7 +179,7 @@ async def to_ui_message_stream( for part in state.close_open_blocks(): yield part - # Scan tool parts for new pending/result states + # Scan tool parts for new pending/completed states has_new_pending_tools = False has_new_tool_results = False @@ -191,7 +191,7 @@ async def to_ui_message_stream( ): has_new_pending_tools = True elif ( - part.status == "result" + part.status in ("result", "error") and part.tool_call_id not in state.emitted_tool_results ): has_new_tool_results = True @@ -232,16 +232,18 @@ async def to_ui_message_stream( input=args, ) - # Pass 2: Tool results (same step as tool input per AI SDK protocol) + # Pass 2: Tool outputs (same step as tool input per AI SDK protocol) # Tool input and output are part of the same "step" (one LLM turn) if has_new_tool_results: for part in msg.parts: match part: case core.messages.ToolPart( - status="result", tool_call_id=tc_id, result=result, - ) if tc_id not in state.emitted_tool_results: + ) if ( + part.status in ("result", "error") + and tc_id not in state.emitted_tool_results + ): state.emitted_tool_results.add(tc_id) state.pending_tool_calls.discard(tc_id) yield protocol.ToolOutputAvailablePart( @@ -284,16 +286,19 @@ async def to_sse_stream( # Tool conversion helpers # ============================================================================ -_TOOL_RESULT_STATES: frozenset[str] = frozenset( - {"output-available", "output-error", "output-denied"} -) +_TOOL_RESULT_STATES: frozenset[str] = frozenset({"output-available"}) +_TOOL_ERROR_STATES: frozenset[str] = frozenset({"output-error", "output-denied"}) def _map_tool_status( state: ui_message.UIToolInvocationState, -) -> Literal["pending", "result"]: +) -> Literal["pending", "result", "error"]: """Map AI SDK v6 tool invocation state to internal status.""" - return "result" if state in _TOOL_RESULT_STATES else "pending" + if state in _TOOL_ERROR_STATES: + return "error" + if state in _TOOL_RESULT_STATES: + return "result" + return "pending" def _normalize_tool_args(tool_input: str | dict[str, Any] | None) -> str: diff --git a/src/vercel_ai_sdk/anthropic/__init__.py b/src/vercel_ai_sdk/anthropic/__init__.py index 3d3f7e27..714228c1 100644 --- a/src/vercel_ai_sdk/anthropic/__init__.py +++ b/src/vercel_ai_sdk/anthropic/__init__.py @@ -71,15 +71,16 @@ def _messages_to_anthropic( "input": tool_input, } ) - # If tool has a result, collect it for a separate user message - if part.status == "result" and part.result is not None: - tool_results.append( - { - "type": "tool_result", - "tool_use_id": part.tool_call_id, - "content": str(part.result), - } - ) + # If tool has completed (success or error), collect for user message + if part.status in ("result", "error") and part.result is not None: + entry: dict[str, Any] = { + "type": "tool_result", + "tool_use_id": part.tool_call_id, + "content": str(part.result), + } + if part.status == "error": + entry["is_error"] = True + tool_results.append(entry) if content: result.append({"role": "assistant", "content": content}) diff --git a/src/vercel_ai_sdk/core/checkpoint.py b/src/vercel_ai_sdk/core/checkpoint.py index 85d6abd4..7d99cfd3 100644 --- a/src/vercel_ai_sdk/core/checkpoint.py +++ b/src/vercel_ai_sdk/core/checkpoint.py @@ -26,6 +26,7 @@ class ToolEvent: tool_call_id: str result: Any + status: str = "result" # "result" | "error" @dataclasses.dataclass @@ -46,7 +47,12 @@ def serialize(self) -> dict[str, Any]: return { "steps": [{"index": s.index, "messages": s.messages} for s in self.steps], "tools": [ - {"tool_call_id": t.tool_call_id, "result": t.result} for t in self.tools + { + "tool_call_id": t.tool_call_id, + "result": t.result, + **({"status": t.status} if t.status != "result" else {}), + } + for t in self.tools ], "hooks": [ {"label": h.label, "resolution": h.resolution} for h in self.hooks diff --git a/src/vercel_ai_sdk/core/messages.py b/src/vercel_ai_sdk/core/messages.py index 09a2684a..10a03a73 100644 --- a/src/vercel_ai_sdk/core/messages.py +++ b/src/vercel_ai_sdk/core/messages.py @@ -5,7 +5,6 @@ import pydantic - # Streaming state for parts PartState = Literal["streaming", "done"] @@ -22,7 +21,7 @@ class ToolPart(pydantic.BaseModel): tool_call_id: str tool_name: str tool_args: str - status: Literal["pending", "result"] = "pending" # Execution status + status: Literal["pending", "result", "error"] = "pending" # Execution status result: Any = None type: Literal["tool"] = "tool" # Streaming state (for args streaming) @@ -34,6 +33,11 @@ def set_result(self, result: Any) -> None: self.status = "result" self.result = result + def set_error(self, message: str) -> None: + """Set a tool error and mark as failed.""" + self.status = "error" + self.result = message + class ReasoningPart(pydantic.BaseModel): text: str @@ -85,9 +89,11 @@ class Message(pydantic.BaseModel): def is_done(self) -> bool: """Message is done when all parts are done (or have no streaming state).""" for part in self.parts: - if isinstance(part, (TextPart, ReasoningPart, ToolPart)): - if part.state == "streaming": - return False + if ( + isinstance(part, (TextPart, ReasoningPart, ToolPart)) + and part.state == "streaming" + ): + return False return True @property @@ -149,9 +155,10 @@ def get_tool_part(self, tool_call_id: str) -> ToolPart | None: def get_hook_part(self, hook_id: str | None = None) -> HookPart | None: """Find a HookPart by hook_id, or return the first HookPart if no id given.""" for part in self.parts: - if isinstance(part, HookPart): - if hook_id is None or part.hook_id == hook_id: - return part + if isinstance(part, HookPart) and ( + hook_id is None or part.hook_id == hook_id + ): + return part return None diff --git a/src/vercel_ai_sdk/core/runtime.py b/src/vercel_ai_sdk/core/runtime.py index 67d6da2f..d09a28f8 100644 --- a/src/vercel_ai_sdk/core/runtime.py +++ b/src/vercel_ai_sdk/core/runtime.py @@ -1,20 +1,21 @@ from __future__ import annotations import asyncio -import json import contextvars import dataclasses +import json from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Sequence from typing import Any, get_type_hints +import pydantic + from .. import mcp -from . import messages as messages_ -from . import tools as tools_ +from . import checkpoint as checkpoint_ +from . import hooks as hooks_ from . import llm as llm_ +from . import messages as messages_ from . import streams as streams_ -from . import hooks as hooks_ -from . import checkpoint as checkpoint_ - +from . import tools as tools_ # ── Queue item types ────────────────────────────────────────────── @@ -64,8 +65,8 @@ def __init__( # Replay cursors self._step_index: int = 0 - self._tool_replay: dict[str, Any] = { - t.tool_call_id: t.result for t in self._checkpoint.tools + self._tool_replay: dict[str, checkpoint_.ToolEvent] = { + t.tool_call_id: t for t in self._checkpoint.tools } self._hook_replay: dict[str, dict[str, Any]] = { h.label: h.resolution for h in self._checkpoint.hooks @@ -128,14 +129,17 @@ def record_step(self, result: streams_.StreamResult) -> None: # ── Replay / record: tools ──────────────────────────────────── - def try_replay_tool(self, tool_call_id: str) -> Any | None: - if tool_call_id in self._tool_replay: - return self._tool_replay[tool_call_id] - return None + def try_replay_tool(self, tool_call_id: str) -> checkpoint_.ToolEvent | None: + """Return the cached ToolEvent if available, else None.""" + return self._tool_replay.get(tool_call_id) - def record_tool(self, tool_call_id: str, result: Any) -> None: + def record_tool( + self, tool_call_id: str, result: Any, *, status: str = "result" + ) -> None: self._tool_log.append( - checkpoint_.ToolEvent(tool_call_id=tool_call_id, result=result) + checkpoint_.ToolEvent( + tool_call_id=tool_call_id, result=result, status=status + ) ) # ── Replay / record: hooks ──────────────────────────────────── @@ -189,7 +193,7 @@ async def stream_step( messages: list[messages_.Message], tools: Sequence[tools_.ToolLike] | None = None, label: str | None = None, -) -> AsyncGenerator[messages_.Message, None]: +) -> AsyncGenerator[messages_.Message]: """Single LLM call that streams to Runtime.""" async for msg in llm.stream(messages=messages, tools=tools): msg.label = label @@ -206,7 +210,7 @@ async def execute_tool( Looks up the tool by name from the global registry, executes it, and updates the ToolPart (and parent Message) with the result. Emits the updated message to the Runtime queue so the UI sees - the transition from status="pending" to status="result". + the transition from status="pending" to status="result" (or "error"). If a checkpoint exists with a cached result for this tool_call_id, returns the cached result without re-executing. @@ -217,23 +221,31 @@ async def execute_tool( if rt: cached = rt.try_replay_tool(tool_call.tool_call_id) if cached is not None: - tool_call.set_result(cached) - return cached + if cached.status == "error": + tool_call.set_error(cached.result) + else: + tool_call.set_result(cached.result) + return cached.result # Fresh execution tool = tools_.get_tool(tool_call.tool_name) if tool is None: raise ValueError(f"Tool not found in registry: {tool_call.tool_name}") - # TODO catch validation error and json error - result = await tool.validate_and_call(tool_call.tool_args, rt) - tool_call.set_result(result) + try: + result = await tool.validate_and_call(tool_call.tool_args, rt) + tool_call.set_result(result) + except (json.JSONDecodeError, pydantic.ValidationError) as exc: + # LLM produced malformed JSON or args that don't match the schema. + # Report back as a tool error so the model can retry. + result = f"{type(exc).__name__}: {exc}" + tool_call.set_error(result) # Record for checkpoint if rt: - rt.record_tool(tool_call.tool_call_id, result) + rt.record_tool(tool_call.tool_call_id, result, status=tool_call.status) - # Emit updated message so UI sees status="result" + # Emit updated message so UI sees status change if rt and message: await rt.put_message(message.model_copy(deep=True)) @@ -288,7 +300,7 @@ class RunResult: """ def __init__(self) -> None: - self._messages: AsyncGenerator[messages_.Message, None] | None = None + self._messages: AsyncGenerator[messages_.Message] | None = None self._runtime: Runtime | None = None @property @@ -310,7 +322,7 @@ def pending_hooks(self) -> dict[str, HookInfo]: for label, sus in self._runtime._pending_hooks.items() } - async def __aiter__(self) -> AsyncGenerator[messages_.Message, None]: + async def __aiter__(self) -> AsyncGenerator[messages_.Message]: if self._messages is not None: async for msg in self._messages: yield msg @@ -347,7 +359,7 @@ def run( """ result = RunResult() - async def _generate() -> AsyncGenerator[messages_.Message, None]: + async def _generate() -> AsyncGenerator[messages_.Message]: runtime = Runtime(checkpoint=checkpoint) result._runtime = runtime token_runtime = _runtime.set(runtime) @@ -375,7 +387,7 @@ async def _generate() -> AsyncGenerator[messages_.Message, None]: step_item = await asyncio.wait_for( runtime._step_queue.get(), timeout=0.1 ) - except asyncio.TimeoutError: + except TimeoutError: continue if isinstance(step_item, Runtime._Sentinel): diff --git a/src/vercel_ai_sdk/openai/__init__.py b/src/vercel_ai_sdk/openai/__init__.py index 72de5a25..32451bf2 100644 --- a/src/vercel_ai_sdk/openai/__init__.py +++ b/src/vercel_ai_sdk/openai/__init__.py @@ -62,8 +62,8 @@ def _messages_to_openai(messages: list[core.messages.Message]) -> list[dict[str, }, } ) - # If tool has a result, collect it for separate tool messages - if part.status == "result" and part.result is not None: + # If tool has completed (success or error), collect for tool messages + if part.status in ("result", "error") and part.result is not None: tool_results.append( { "role": "tool", diff --git a/tests/core/test_messages.py b/tests/core/test_messages.py index 612e801f..608675bb 100644 --- a/tests/core/test_messages.py +++ b/tests/core/test_messages.py @@ -1,4 +1,4 @@ -"""Message model: properties, ToolPart.set_result, make_messages.""" +"""Message model: properties, ToolPart.set_result/set_error, make_messages.""" from vercel_ai_sdk.core.messages import ( HookPart, @@ -178,7 +178,7 @@ def test_get_hook_part_missing(): assert m.get_hook_part("h-nope") is None -# -- ToolPart.set_result --------------------------------------------------- +# -- ToolPart.set_result / set_error --------------------------------------- def test_set_result(): @@ -189,6 +189,14 @@ def test_set_result(): assert tp.result == {"answer": 42} +def test_set_error(): + tp = ToolPart(tool_call_id="tc1", tool_name="t", tool_args="{}") + assert tp.status == "pending" + tp.set_error("Something went wrong") + assert tp.status == "error" + assert tp.result == "Something went wrong" + + # -- make_messages --------------------------------------------------------- From 3d8165ab0bd45bf1094d59ae7c3b69bf3505c0db Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Thu, 19 Feb 2026 15:26:44 -0800 Subject: [PATCH 07/20] Convert Checkpoint to Pydantic --- README.md | 6 +-- examples/fastapi-vite/backend/routes/chat.py | 4 +- src/vercel_ai_sdk/core/checkpoint.py | 51 +++++--------------- src/vercel_ai_sdk/core/runtime.py | 2 +- tests/core/test_checkpoint.py | 12 ++--- 5 files changed, 23 insertions(+), 52 deletions(-) diff --git a/README.md b/README.md index 6b5e543f..ce9659ec 100644 --- a/README.md +++ b/README.md @@ -187,10 +187,10 @@ checkpoint = ai.get_checkpoint() ```python # After a run completes or suspends checkpoint = result.checkpoint -data = checkpoint.serialize() # dict, JSON-safe +data = checkpoint.model_dump() # dict, JSON-safe # Later: restore and resume -checkpoint = ai.Checkpoint.deserialize(data) +checkpoint = ai.Checkpoint.model_validate(data) result = ai.run(my_agent, llm, query, checkpoint=checkpoint) ``` @@ -277,7 +277,7 @@ return StreamingResponse(stream_response(), headers=UI_MESSAGE_STREAM_HEADERS) | `RunResult` | Return type of `run()`. Async-iterable for messages, then `.checkpoint` and `.pending_hooks` | | `HookInfo` | Pending hook info: `label`, `hook_type`, `metadata` | | `Hook` | Generic hook base with `.create()`, `.resolve()`, `.cancel()` class methods | -| `Checkpoint` | Serializable snapshot of completed work: `steps[]`, `tools[]`, `hooks[]`. Has `.serialize()` / `.deserialize()` | +| `Checkpoint` | Pydantic model — serializable snapshot of completed work: `steps[]`, `tools[]`, `hooks[]`. Use `.model_dump()` / `.model_validate()` | | `LanguageModel` | Abstract base class for LLM providers | ## Examples diff --git a/examples/fastapi-vite/backend/routes/chat.py b/examples/fastapi-vite/backend/routes/chat.py index afd0d109..0ed362b4 100644 --- a/examples/fastapi-vite/backend/routes/chat.py +++ b/examples/fastapi-vite/backend/routes/chat.py @@ -37,7 +37,7 @@ async def chat(request: ChatRequest): # run — the frontend carries the full message history — so we only # load a checkpoint when one was saved from a previous incomplete run. saved = await file_storage.get(checkpoint_key) - checkpoint = ai.Checkpoint.deserialize(saved) if saved else None + checkpoint = ai.Checkpoint.model_validate(saved) if saved else None result = ai.run(agent.graph, llm, messages, agent.TOOLS, checkpoint=checkpoint) @@ -49,7 +49,7 @@ async def stream_response(): # so the next request starts fresh. If hooks are pending, save # the checkpoint so the next request can resume from here. if result.pending_hooks: - await file_storage.put(checkpoint_key, result.checkpoint.serialize()) + await file_storage.put(checkpoint_key, result.checkpoint.model_dump()) else: await file_storage.delete(checkpoint_key) diff --git a/src/vercel_ai_sdk/core/checkpoint.py b/src/vercel_ai_sdk/core/checkpoint.py index 7d99cfd3..b49629d3 100644 --- a/src/vercel_ai_sdk/core/checkpoint.py +++ b/src/vercel_ai_sdk/core/checkpoint.py @@ -1,27 +1,24 @@ from __future__ import annotations -import dataclasses from typing import Any +import pydantic + from . import messages as messages_ from . import streams as streams_ -@dataclasses.dataclass -class StepEvent: +class StepEvent(pydantic.BaseModel): """A completed @stream step.""" index: int - messages: list[dict[str, Any]] # Message.model_dump() for each + messages: list[messages_.Message] def to_stream_result(self) -> streams_.StreamResult: - return streams_.StreamResult( - messages=[messages_.Message.model_validate(m) for m in self.messages] - ) + return streams_.StreamResult(messages=list(self.messages)) -@dataclasses.dataclass -class ToolEvent: +class ToolEvent(pydantic.BaseModel): """A completed tool execution.""" tool_call_id: str @@ -29,40 +26,14 @@ class ToolEvent: status: str = "result" # "result" | "error" -@dataclasses.dataclass -class HookEvent: +class HookEvent(pydantic.BaseModel): """A resolved hook.""" label: str resolution: dict[str, Any] -@dataclasses.dataclass -class Checkpoint: - steps: list[StepEvent] = dataclasses.field(default_factory=list) - tools: list[ToolEvent] = dataclasses.field(default_factory=list) - hooks: list[HookEvent] = dataclasses.field(default_factory=list) - - def serialize(self) -> dict[str, Any]: - return { - "steps": [{"index": s.index, "messages": s.messages} for s in self.steps], - "tools": [ - { - "tool_call_id": t.tool_call_id, - "result": t.result, - **({"status": t.status} if t.status != "result" else {}), - } - for t in self.tools - ], - "hooks": [ - {"label": h.label, "resolution": h.resolution} for h in self.hooks - ], - } - - @classmethod - def deserialize(cls, data: dict[str, Any]) -> Checkpoint: - return cls( - steps=[StepEvent(**s) for s in data.get("steps", [])], - tools=[ToolEvent(**t) for t in data.get("tools", [])], - hooks=[HookEvent(**h) for h in data.get("hooks", [])], - ) +class Checkpoint(pydantic.BaseModel): + steps: list[StepEvent] = [] + tools: list[ToolEvent] = [] + hooks: list[HookEvent] = [] diff --git a/src/vercel_ai_sdk/core/runtime.py b/src/vercel_ai_sdk/core/runtime.py index d09a28f8..634f0737 100644 --- a/src/vercel_ai_sdk/core/runtime.py +++ b/src/vercel_ai_sdk/core/runtime.py @@ -122,7 +122,7 @@ def try_replay_step(self) -> streams_.StreamResult | None: def record_step(self, result: streams_.StreamResult) -> None: event = checkpoint_.StepEvent( index=self._step_index, - messages=[m.model_dump() for m in result.messages], + messages=list(result.messages), ) self._step_log.append(event) self._step_index += 1 diff --git a/tests/core/test_checkpoint.py b/tests/core/test_checkpoint.py index 0d5e1184..eefc995e 100644 --- a/tests/core/test_checkpoint.py +++ b/tests/core/test_checkpoint.py @@ -163,18 +163,18 @@ def test_checkpoint_serialization_roundtrip(): StepEvent( index=0, messages=[ - { - "id": "m1", - "role": "assistant", - "parts": [{"type": "text", "text": "hi"}], - } + ai.Message( + id="m1", + role="assistant", + parts=[ai.TextPart(text="hi")], + ) ], ) ], tools=[ToolEvent(tool_call_id="tc-1", result=42)], hooks=[HookEvent(label="h1", resolution={"granted": True})], ) - cp2 = Checkpoint.deserialize(cp.serialize()) + cp2 = Checkpoint.model_validate(cp.model_dump()) assert cp2.steps[0].index == 0 assert cp2.tools[0].result == 42 assert cp2.hooks[0].label == "h1" From 5fbbc2fb7424a8a7ec74fd9503fd195f59696329 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Thu, 19 Feb 2026 15:39:34 -0800 Subject: [PATCH 08/20] Fix @ai.stream type erasure --- src/vercel_ai_sdk/core/streams.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/vercel_ai_sdk/core/streams.py b/src/vercel_ai_sdk/core/streams.py index 4f8354fa..c108018d 100644 --- a/src/vercel_ai_sdk/core/streams.py +++ b/src/vercel_ai_sdk/core/streams.py @@ -3,7 +3,7 @@ import asyncio import dataclasses import functools -from collections.abc import AsyncGenerator, Callable +from collections.abc import AsyncGenerator, Awaitable, Callable from typing import Any from . import messages as messages_ @@ -31,13 +31,13 @@ def text(self) -> str: return "" -Stream = Callable[[], AsyncGenerator[messages_.Message, None]] +Stream = Callable[[], AsyncGenerator[messages_.Message]] # maybe it should have a name and an id inferred from LLM outputs -def stream( - fn: Callable[..., AsyncGenerator[messages_.Message, None]], -) -> Callable[..., Any]: +def stream[**P]( + fn: Callable[P, AsyncGenerator[messages_.Message]], +) -> Callable[P, Awaitable[StreamResult]]: """ Decorator to put an async generator into the Runtime queue. @@ -64,7 +64,7 @@ async def wrapped(*args: Any, **kwargs: Any) -> StreamResult: # Fresh execution: submit to queue and wait future: asyncio.Future[StreamResult] = asyncio.Future() - async def stream_fn() -> AsyncGenerator[messages_.Message, None]: + async def stream_fn() -> AsyncGenerator[messages_.Message]: async for msg in fn(*args, **kwargs): yield msg From 8a95ad676f0c11584798a6a44147e40dc5163630 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Thu, 19 Feb 2026 15:53:18 -0800 Subject: [PATCH 09/20] Fix mypy errors in tools.py --- src/vercel_ai_sdk/core/tools.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/vercel_ai_sdk/core/tools.py b/src/vercel_ai_sdk/core/tools.py index 0c47abb9..c1a9a49b 100644 --- a/src/vercel_ai_sdk/core/tools.py +++ b/src/vercel_ai_sdk/core/tools.py @@ -11,10 +11,10 @@ from . import runtime as runtime_ # Module-level tool registry - populated at decoration time -_tool_registry: dict[str, Tool] = {} +_tool_registry: dict[str, Tool[..., Any]] = {} -def get_tool(name: str) -> Tool | None: +def get_tool(name: str) -> Tool[..., Any] | None: """Look up a tool by name from the global registry.""" return _tool_registry.get(name) @@ -75,7 +75,7 @@ async def validate_and_call( # validate llm-generated inputs (skipped for MCP tools) if self._validator is not None: self._validator.model_validate(kwargs) - return await self._fn(**kwargs) # type: ignore[arg-type] + return await self._fn(**kwargs) # type: ignore[call-arg] @property def name(self) -> str: @@ -97,7 +97,7 @@ def tool[**P, R](fn: Callable[P, Awaitable[R]]) -> Tool[P, R]: sig = inspect.signature(fn) hints = get_type_hints(fn) if hasattr(fn, "__annotations__") else {} - fields = {} + fields: dict[str, Any] = {} for param_name, param in sig.parameters.items(): param_type = hints.get(param_name, str) From 70f87d95606a283a51ced87776321eab3e945638 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Thu, 19 Feb 2026 16:01:18 -0800 Subject: [PATCH 10/20] Fix mypy errors in mcp/client.py --- src/vercel_ai_sdk/mcp/client.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/vercel_ai_sdk/mcp/client.py b/src/vercel_ai_sdk/mcp/client.py index 133af068..0582e2c4 100644 --- a/src/vercel_ai_sdk/mcp/client.py +++ b/src/vercel_ai_sdk/mcp/client.py @@ -98,8 +98,10 @@ async def call_tool(**kwargs: Any) -> Any: client.call_tool(tool_name, kwargs), timeout=30.0, ) - except asyncio.TimeoutError: - raise RuntimeError(f"MCP tool call timed out after 30 seconds: {tool_name}") + except TimeoutError as e: + raise RuntimeError( + f"MCP tool call timed out after 30 seconds: {tool_name}" + ) from e # Handle error responses if result.isError: @@ -137,7 +139,7 @@ async def get_stdio_tools( env: dict[str, str] | None = None, cwd: str | None = None, tool_prefix: str | None = None, -) -> list[core.tools.Tool]: +) -> list[core.tools.Tool[..., Any]]: """ Get tools from an MCP server running as a subprocess. @@ -161,7 +163,7 @@ async def get_stdio_tools( """ connection_key = f"stdio:{command}:{':'.join(args)}" - def transport_factory(): + def transport_factory() -> contextlib.AbstractAsyncContextManager[Any]: return mcp.client.stdio.stdio_client( mcp.client.stdio.StdioServerParameters( command=command, @@ -185,7 +187,7 @@ async def get_http_tools( *, headers: dict[str, str] | None = None, tool_prefix: str | None = None, -) -> list[core.tools.Tool]: +) -> list[core.tools.Tool[..., Any]]: """ Get tools from an MCP server over HTTP (Streamable HTTP transport). @@ -208,7 +210,7 @@ async def get_http_tools( """ connection_key = f"http:{url}" - def transport_factory(): + def transport_factory() -> contextlib.AbstractAsyncContextManager[Any]: http_client = httpx.AsyncClient(headers=headers) if headers else None return mcp.client.streamable_http.streamable_http_client( url=url, http_client=http_client @@ -228,7 +230,7 @@ def _mcp_tool_to_native( connection_key: str, transport_factory: Callable[[], contextlib.AbstractAsyncContextManager[Any]], tool_prefix: str | None, -) -> core.tools.Tool: +) -> core.tools.Tool[..., Any]: """Convert an MCP tool to a native Tool.""" name = mcp_tool.name if tool_prefix: From cf2e838c4723b7a97ff808f29d455736015adb84 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Thu, 19 Feb 2026 16:13:18 -0800 Subject: [PATCH 11/20] Fix mypy errors in the AI SDK UI adapter --- src/vercel_ai_sdk/ai_sdk_ui/adapter.py | 32 +++++++++++------------ src/vercel_ai_sdk/ai_sdk_ui/ui_message.py | 4 +-- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/vercel_ai_sdk/ai_sdk_ui/adapter.py b/src/vercel_ai_sdk/ai_sdk_ui/adapter.py index 4b8d2305..cd2f8800 100644 --- a/src/vercel_ai_sdk/ai_sdk_ui/adapter.py +++ b/src/vercel_ai_sdk/ai_sdk_ui/adapter.py @@ -13,7 +13,6 @@ from .. import core from . import protocol, ui_message - # ============================================================================ # Serialization utilities # ============================================================================ @@ -128,7 +127,7 @@ def begin_message( async def to_ui_message_stream( messages: AsyncIterable[core.messages.Message], -) -> AsyncGenerator[protocol.UIMessageStreamPart, None]: +) -> AsyncGenerator[protocol.UIMessageStreamPart]: """ Convert a proto_sdk message stream into AI SDK UI message stream parts. @@ -183,16 +182,16 @@ async def to_ui_message_stream( has_new_pending_tools = False has_new_tool_results = False - for part in msg.parts: - if isinstance(part, core.messages.ToolPart): + for msg_part in msg.parts: + if isinstance(msg_part, core.messages.ToolPart): if ( - part.status == "pending" - and part.tool_call_id not in state.pending_tool_calls + msg_part.status == "pending" + and msg_part.tool_call_id not in state.pending_tool_calls ): has_new_pending_tools = True elif ( - part.status in ("result", "error") - and part.tool_call_id not in state.emitted_tool_results + msg_part.status in ("result", "error") + and msg_part.tool_call_id not in state.emitted_tool_results ): has_new_tool_results = True @@ -201,8 +200,8 @@ async def to_ui_message_stream( # 2. Then handle tool results (which may need their own step) # Pass 1: Text and pending tool inputs - for part in msg.parts: - match part: + for msg_part in msg.parts: + match msg_part: case core.messages.TextPart(text=text) if ( text and not had_active_text @@ -235,13 +234,14 @@ async def to_ui_message_stream( # Pass 2: Tool outputs (same step as tool input per AI SDK protocol) # Tool input and output are part of the same "step" (one LLM turn) if has_new_tool_results: - for part in msg.parts: - match part: + for msg_part in msg.parts: + match msg_part: case core.messages.ToolPart( tool_call_id=tc_id, result=result, + status=status, ) if ( - part.status in ("result", "error") + status in ("result", "error") and tc_id not in state.emitted_tool_results ): state.emitted_tool_results.add(tc_id) @@ -261,7 +261,7 @@ async def to_ui_message_stream( async def filter_by_label( messages: AsyncIterable[core.messages.Message], label: str | None = None, -) -> AsyncGenerator[core.messages.Message, None]: +) -> AsyncGenerator[core.messages.Message]: """Filter a message stream to a single agent label. If label is provided, only messages with that label pass through. @@ -276,7 +276,7 @@ async def filter_by_label( async def to_sse_stream( messages: AsyncIterable[core.messages.Message], -) -> AsyncGenerator[str, None]: +) -> AsyncGenerator[str]: """Convert a proto_sdk message stream directly into SSE-formatted strings.""" async for part in to_ui_message_stream(messages): yield format_sse(part) @@ -379,7 +379,7 @@ def to_messages( ): pass # Skip unsupported/boundary parts - # Validate user/system messages have content - OpenAI requires it for these roles. + # Validate user/system messages have content - OpenAI requires it there. # Assistant messages can have empty content if they have tool calls. if ui_msg.role in ("user", "system") and not internal_parts: raise ValueError( diff --git a/src/vercel_ai_sdk/ai_sdk_ui/ui_message.py b/src/vercel_ai_sdk/ai_sdk_ui/ui_message.py index 96e9e89e..a920b0f4 100644 --- a/src/vercel_ai_sdk/ai_sdk_ui/ui_message.py +++ b/src/vercel_ai_sdk/ai_sdk_ui/ui_message.py @@ -10,7 +10,7 @@ from __future__ import annotations import uuid -from typing import Any, Literal +from typing import Any, Literal, cast import pydantic @@ -172,7 +172,7 @@ def _parse_ui_part(part_data: dict[str, Any]) -> UIMessagePart | None: part_type = part_data.get("type", "") if model_cls := _STATIC_UI_PART_TYPES.get(part_type): - return model_cls.model_validate(part_data) + return cast(UIMessagePart, model_cls.model_validate(part_data)) match part_type: case str() as t if t.startswith("tool-"): From 871bbd0898a3d842ec9ffb605a101e3c37da6f0f Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Thu, 19 Feb 2026 16:32:09 -0800 Subject: [PATCH 12/20] Fix typing errors and smells in hooks.py --- src/vercel_ai_sdk/core/hooks.py | 57 ++++++++++++++++----------------- tests/core/test_hooks.py | 7 ++-- 2 files changed, 31 insertions(+), 33 deletions(-) diff --git a/src/vercel_ai_sdk/core/hooks.py b/src/vercel_ai_sdk/core/hooks.py index cf8cb743..17a1c1f3 100644 --- a/src/vercel_ai_sdk/core/hooks.py +++ b/src/vercel_ai_sdk/core/hooks.py @@ -1,21 +1,19 @@ from __future__ import annotations import asyncio -from typing import Any, ClassVar, Generic, TypeVar +from typing import TYPE_CHECKING, Any, ClassVar import pydantic from . import messages as messages_ -T = TypeVar("T", bound=pydantic.BaseModel) +if TYPE_CHECKING: + from . import runtime as runtime_ # --------------------------------------------------------------------------- # Module-level hook registries # -# These allow Hook.resolve() and Hook.cancel() to work from anywhere — -# no ContextVar lookup, no RunResult handle, no copy_context(). -# # _live_hooks: # Populated by Hook.create() when a hook suspends inside a running graph. # Maps hook label -> (future, metadata dict, Runtime). @@ -32,8 +30,9 @@ # immediately without suspending. Entries are removed on consumption. # --------------------------------------------------------------------------- -_live_hooks: dict[str, tuple[asyncio.Future[Any], dict[str, Any], Any]] = {} -# label -> (future, metadata, Runtime) +_live_hooks: dict[ + str, tuple[asyncio.Future[Any], dict[str, Any], runtime_.Runtime] +] = {} _pending_resolutions: dict[str, dict[str, Any]] = {} # label -> validated resolution dict @@ -46,17 +45,17 @@ def _cleanup_run(labels: set[str]) -> None: _pending_resolutions.pop(label, None) -class Hook(Generic[T]): +class Hook[T: pydantic.BaseModel]: """ Hook: a suspension point that requires external input to continue. - Usage in graph code (identical in all modes): + Usage in graph code: approval = await ToolApproval.create("approve_delete", metadata={...}) if approval.granted: ... - Resolution from outside the graph (any task, any context): + Resolution from outside the graph: ToolApproval.resolve("approve_delete", {"granted": True, ...}) @@ -92,9 +91,9 @@ async def create(cls, label: str, metadata: dict[str, Any] | None = None) -> T: a single run. metadata: Optional metadata surfaced in the pending HookPart message. """ - from . import runtime as runtime_ + from . import runtime as rt_mod - rt = runtime_._runtime.get(None) + rt = rt_mod._runtime.get(None) if rt is None: raise ValueError("No Runtime context - must be called within ai.run()") @@ -112,7 +111,7 @@ async def create(cls, label: str, metadata: dict[str, Any] | None = None) -> T: # Submit to step queue — run() decides what to do future: asyncio.Future[dict[str, Any]] = asyncio.Future() - suspension = runtime_.HookSuspension( + suspension = rt_mod.HookSuspension( label=label, hook_type=cls._hook_type, metadata=metadata or {}, @@ -169,9 +168,6 @@ def resolve(cls, label: str, data: T | dict[str, Any]) -> None: replays the graph and Hook.create() executes, it finds the pre-registered resolution and returns without suspending. - Can be called from any task, any context — no ContextVar or - RunResult handle needed. - Args: label: The hook label to resolve. data: Resolution payload (dict or pydantic model). Validated @@ -201,7 +197,7 @@ def resolve(cls, label: str, data: T | dict[str, Any]) -> None: _pending_resolutions[label] = resolution @classmethod - def cancel(cls, label: str, reason: str | None = None) -> None: + async def cancel(cls, label: str, reason: str | None = None) -> None: """Cancel a pending hook. Only works for live hooks (long-running mode). Raises if the @@ -213,21 +209,22 @@ def cancel(cls, label: str, reason: str | None = None) -> None: future, hook_metadata, rt = _live_hooks.pop(label) future.cancel(reason) - msg = messages_.Message( - role="assistant", - parts=[ - messages_.HookPart( - hook_id=label, - hook_type=cls._hook_type, - status="cancelled", - metadata=hook_metadata, - ) - ], + await rt.put_message( + messages_.Message( + role="assistant", + parts=[ + messages_.HookPart( + hook_id=label, + hook_type=cls._hook_type, + status="cancelled", + metadata=hook_metadata, + ) + ], + ) ) - asyncio.create_task(rt.put_message(msg)) -def hook(cls: type[T]) -> type[Hook[T]]: +def hook[T: pydantic.BaseModel](cls: type[T]) -> type[Hook[T]]: """ Decorator to create a Hook type from a pydantic model. @@ -243,4 +240,4 @@ def hook(cls: type[T]) -> type[Hook[T]]: }, ) - return hook_impl # type: ignore[return-value] + return hook_impl diff --git a/tests/core/test_hooks.py b/tests/core/test_hooks.py index 5e580523..058f3e90 100644 --- a/tests/core/test_hooks.py +++ b/tests/core/test_hooks.py @@ -74,7 +74,7 @@ async def graph(llm: ai.LanguageModel): async for msg in run_result: if any(isinstance(p, ai.HookPart) and p.status == "pending" for p in msg.parts): - Confirmation.cancel("cancel_me", reason="denied") + await Confirmation.cancel("cancel_me", reason="denied") assert was_cancelled @@ -82,9 +82,10 @@ async def graph(llm: ai.LanguageModel): # -- Hook.cancel() on non-existent label raises ---------------------------- -def test_cancel_nonexistent_raises(): +@pytest.mark.asyncio +async def test_cancel_nonexistent_raises(): with pytest.raises(ValueError, match="No pending hook"): - Confirmation.cancel("does_not_exist_xyz") + await Confirmation.cancel("does_not_exist_xyz") # -- Pre-registration (serverless re-entry) -------------------------------- From 998b0c82e845a6045d5df4804652769d0ad105ca Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Thu, 19 Feb 2026 16:36:31 -0800 Subject: [PATCH 13/20] Fix remaining type errors in runtime and openai adapter --- src/vercel_ai_sdk/core/runtime.py | 3 ++- src/vercel_ai_sdk/openai/__init__.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/vercel_ai_sdk/core/runtime.py b/src/vercel_ai_sdk/core/runtime.py index 634f0737..69ae5fdf 100644 --- a/src/vercel_ai_sdk/core/runtime.py +++ b/src/vercel_ai_sdk/core/runtime.py @@ -268,7 +268,8 @@ async def stream_loop( return result last_msg = result.last_message - local_messages.append(last_msg) + if last_msg is not None: + local_messages.append(last_msg) await asyncio.gather( *(execute_tool(tc, message=last_msg) for tc in result.tool_calls) diff --git a/src/vercel_ai_sdk/openai/__init__.py b/src/vercel_ai_sdk/openai/__init__.py index 32451bf2..6e095661 100644 --- a/src/vercel_ai_sdk/openai/__init__.py +++ b/src/vercel_ai_sdk/openai/__init__.py @@ -162,7 +162,7 @@ async def stream_events( # Track active blocks for Start/End events text_started = False reasoning_started = False - tool_calls: dict[int, dict] = {} # index -> {id, name, started} + tool_calls: dict[int, dict[str, Any]] = {} # index -> {id, name, started} async for chunk in stream: if not chunk.choices: From bf00dbe3f84a65aae9417e37af40c84c2ca9966b Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Thu, 19 Feb 2026 16:51:46 -0800 Subject: [PATCH 14/20] Fix typing errors in the agent module --- src/vercel_ai_sdk/agent/agent.py | 16 ++++++++++------ src/vercel_ai_sdk/agent/tools.py | 4 +++- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/vercel_ai_sdk/agent/agent.py b/src/vercel_ai_sdk/agent/agent.py index f230ce28..8e3b4b2a 100644 --- a/src/vercel_ai_sdk/agent/agent.py +++ b/src/vercel_ai_sdk/agent/agent.py @@ -2,11 +2,14 @@ import asyncio import dataclasses +from typing import Any + import pydantic import vercel_ai_sdk as ai -from .tools import _filesystem, BUILTIN_TOOLS + from . import proto +from .tools import BUILTIN_TOOLS, _filesystem @ai.hook @@ -34,7 +37,7 @@ class Agent: model: ai.LanguageModel filesystem: proto.Filesystem system: str = "" - tools: list[ai.Tool] = dataclasses.field(default_factory=list) + tools: list[ai.Tool[..., Any]] = dataclasses.field(default_factory=list) async def _execute_tool( self, tc: ai.ToolPart, message: ai.Message | None = None @@ -43,9 +46,9 @@ async def _execute_tool( Tool execution errors are handled inside ``ai.execute_tool``. """ - # TODO this should be tucked away into the framework - # and done using Pydantic - approval = await ToolApproval.create( + # TODO: mypy doesn't support class decorators that change the class type — + # @ai.hook returns type[Hook[T]] but mypy still sees the original BaseModel. + approval = await ToolApproval.create( # type: ignore[attr-defined] f"approve_{tc.tool_call_id}", metadata={"tool_name": tc.tool_name, "tool_args": tc.tool_args}, ) @@ -58,7 +61,7 @@ async def _execute_tool( async def _loop( self, messages: list[ai.Message], - tools: list[ai.Tool], + tools: list[ai.Tool[..., Any]], label: str | None = None, ) -> ai.StreamResult: local_messages = list(messages) @@ -72,6 +75,7 @@ async def _loop( return result last_msg = result.last_message + assert last_msg is not None, "tool_calls present but no last_message" local_messages.append(last_msg) await asyncio.gather( diff --git a/src/vercel_ai_sdk/agent/tools.py b/src/vercel_ai_sdk/agent/tools.py index 422fd22d..300e0dcf 100644 --- a/src/vercel_ai_sdk/agent/tools.py +++ b/src/vercel_ai_sdk/agent/tools.py @@ -1,6 +1,8 @@ from __future__ import annotations import contextvars +from typing import Any + import vercel_ai_sdk as ai from . import proto @@ -84,4 +86,4 @@ async def bash(command: str, timeout: int | None = None) -> str: return await _fs().bash(command, timeout=timeout) -BUILTIN_TOOLS: list[ai.Tool] = [read, write, edit, ls, glob, grep, bash] +BUILTIN_TOOLS: list[ai.Tool[..., Any]] = [read, write, edit, ls, glob, grep, bash] From 871faedf85c6abf4d540c8319635f3ee9003e24f Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Thu, 19 Feb 2026 17:12:56 -0800 Subject: [PATCH 15/20] Add type annotations and type ignores to examples --- examples/fastapi-vite/backend/agent.py | 8 +++-- examples/fastapi-vite/backend/main.py | 2 +- examples/fastapi-vite/backend/routes/chat.py | 6 ++-- examples/fastapi-vite/backend/storage.py | 2 +- examples/multiagent-textual/server.py | 37 +++++++++++++------- examples/samples/agent.py | 7 ++-- examples/samples/custom_loop.py | 11 +++--- examples/samples/hooks.py | 17 ++++++--- examples/samples/mcp.py | 8 +++-- examples/samples/multiagent.py | 4 +-- examples/samples/simple.py | 4 +-- examples/samples/streaming_tool.py | 4 +-- 12 files changed, 71 insertions(+), 39 deletions(-) diff --git a/examples/fastapi-vite/backend/agent.py b/examples/fastapi-vite/backend/agent.py index 22562595..d5e98867 100644 --- a/examples/fastapi-vite/backend/agent.py +++ b/examples/fastapi-vite/backend/agent.py @@ -2,6 +2,8 @@ import os +from typing import Any + import vercel_ai_sdk as ai @@ -20,14 +22,14 @@ def get_llm() -> ai.LanguageModel: ) -TOOLS: list[ai.Tool] = [talk_to_mothership] +TOOLS: list[ai.Tool[..., Any]] = [talk_to_mothership] async def graph( llm: ai.LanguageModel, messages: list[ai.Message], - tools: list[ai.Tool], -): + tools: list[ai.Tool[..., Any]], +) -> ai.StreamResult: """ Agent graph: stream LLM, execute tools, repeat until done. diff --git a/examples/fastapi-vite/backend/main.py b/examples/fastapi-vite/backend/main.py index a38fd2ab..ca4a4a0b 100644 --- a/examples/fastapi-vite/backend/main.py +++ b/examples/fastapi-vite/backend/main.py @@ -22,6 +22,6 @@ @app.get("/api/health") -async def health(): +async def health() -> dict[str, str]: """Health check endpoint.""" return {"status": "ok"} diff --git a/examples/fastapi-vite/backend/routes/chat.py b/examples/fastapi-vite/backend/routes/chat.py index 0ed362b4..26326b20 100644 --- a/examples/fastapi-vite/backend/routes/chat.py +++ b/examples/fastapi-vite/backend/routes/chat.py @@ -2,6 +2,8 @@ from __future__ import annotations +from collections.abc import AsyncGenerator + import fastapi import fastapi.responses import pydantic @@ -24,7 +26,7 @@ class ChatRequest(pydantic.BaseModel): @router.post("/chat") -async def chat(request: ChatRequest): +async def chat(request: ChatRequest) -> fastapi.responses.StreamingResponse: """Handle chat requests and stream responses.""" messages = ai.ai_sdk_ui.to_messages(request.messages) session_id = request.session_id or "default" @@ -41,7 +43,7 @@ async def chat(request: ChatRequest): result = ai.run(agent.graph, llm, messages, agent.TOOLS, checkpoint=checkpoint) - async def stream_response(): + async def stream_response() -> AsyncGenerator[str]: async for chunk in ai.ai_sdk_ui.to_sse_stream(result): yield chunk diff --git a/examples/fastapi-vite/backend/storage.py b/examples/fastapi-vite/backend/storage.py index 2bd4c59d..de2cfc0e 100644 --- a/examples/fastapi-vite/backend/storage.py +++ b/examples/fastapi-vite/backend/storage.py @@ -44,7 +44,7 @@ async def get(self, key: str) -> dict[str, Any] | None: path = self._path(key) if not path.exists(): return None - return json.loads(path.read_text()) + return json.loads(path.read_text()) # type: ignore[no-any-return] async def put(self, key: str, value: dict[str, Any]) -> None: path = self._path(key) diff --git a/examples/multiagent-textual/server.py b/examples/multiagent-textual/server.py index e12cd06e..e1d7d719 100644 --- a/examples/multiagent-textual/server.py +++ b/examples/multiagent-textual/server.py @@ -15,6 +15,8 @@ import os import warnings +from typing import Any + import fastapi import pydantic @@ -58,7 +60,7 @@ class Approval(pydantic.BaseModel): # --------------------------------------------------------------------------- -async def mothership_branch(llm: ai.LanguageModel, query: str): +async def mothership_branch(llm: ai.LanguageModel, query: str) -> ai.StreamResult: """Agent that contacts the mothership, gated by an approval hook.""" messages = ai.make_messages( system="You are assistant 1. Use contact_mothership when asked about the future.", @@ -74,7 +76,10 @@ async def mothership_branch(llm: ai.LanguageModel, query: str): for tc in result.tool_calls: if tc.tool_name == "contact_mothership": - approval = await Approval.create( + # TODO: mypy doesn't support class decorators that change the + # class type — @ai.hook returns type[Hook[T]] but mypy still + # sees the original BaseModel. + approval = await Approval.create( # type: ignore[attr-defined] f"mothership_{tc.tool_call_id}", metadata={"branch": "mothership", "tool": tc.tool_name}, ) @@ -85,12 +90,13 @@ async def mothership_branch(llm: ai.LanguageModel, query: str): else: await ai.execute_tool(tc, message=result.last_message) - messages.append(result.last_message) + if result.last_message is not None: + messages.append(result.last_message) return result -async def data_center_branch(llm: ai.LanguageModel, query: str): +async def data_center_branch(llm: ai.LanguageModel, query: str) -> ai.StreamResult: """Agent that contacts data centers, gated by an approval hook.""" messages = ai.make_messages( system="You are assistant 2. Use contact_data_centers when asked about the future.", @@ -106,7 +112,10 @@ async def data_center_branch(llm: ai.LanguageModel, query: str): for tc in result.tool_calls: if tc.tool_name == "contact_data_centers": - approval = await Approval.create( + # TODO: mypy doesn't support class decorators that change the + # class type — @ai.hook returns type[Hook[T]] but mypy still + # sees the original BaseModel. + approval = await Approval.create( # type: ignore[attr-defined] f"data_centers_{tc.tool_call_id}", metadata={"branch": "data_centers", "tool": tc.tool_name}, ) @@ -117,7 +126,8 @@ async def data_center_branch(llm: ai.LanguageModel, query: str): else: await ai.execute_tool(tc, message=result.last_message) - messages.append(result.last_message) + if result.last_message is not None: + messages.append(result.last_message) return result @@ -127,7 +137,7 @@ async def data_center_branch(llm: ai.LanguageModel, query: str): # --------------------------------------------------------------------------- -async def multiagent(llm: ai.LanguageModel, query: str): +async def multiagent(llm: ai.LanguageModel, query: str) -> ai.StreamResult: """Run two gated agents in parallel, then summarise their results.""" r1, r2 = await asyncio.gather( mothership_branch(llm, query), @@ -154,7 +164,7 @@ async def multiagent(llm: ai.LanguageModel, query: str): # --------------------------------------------------------------------------- -def _normalise_message(data: dict) -> dict: +def _normalise_message(data: dict[str, Any]) -> dict[str, Any]: """Ensure ToolPart.result is always a dict for safe deserialisation.""" for part in data.get("parts", []): if part.get("type") == "tool" and isinstance(part.get("result"), str): @@ -168,7 +178,7 @@ def _normalise_message(data: dict) -> dict: @app.websocket("/ws") -async def ws_endpoint(websocket: fastapi.WebSocket): +async def ws_endpoint(websocket: fastapi.WebSocket) -> None: await websocket.accept() print("Client connected") @@ -181,13 +191,16 @@ async def ws_endpoint(websocket: fastapi.WebSocket): result = ai.run(multiagent, llm, "When will the robots take over?") # Background task: read hook resolutions from the client. - async def read_resolutions(): + async def read_resolutions() -> None: try: while True: raw = await websocket.receive_text() data = json.loads(raw) print(f" Resolution received: {data['hook_id']}") - Approval.resolve( + # TODO: mypy doesn't support class decorators that change the + # class type — @ai.hook returns type[Hook[T]] but mypy still + # sees the original BaseModel. + Approval.resolve( # type: ignore[attr-defined] data["hook_id"], {"granted": data["granted"], "reason": data["reason"]}, ) @@ -220,5 +233,5 @@ async def read_resolutions(): @app.get("/api/health") -async def health(): +async def health() -> dict[str, str]: return {"status": "ok"} diff --git a/examples/samples/agent.py b/examples/samples/agent.py index ebca2c6a..215105e8 100644 --- a/examples/samples/agent.py +++ b/examples/samples/agent.py @@ -7,7 +7,7 @@ import vercel_ai_sdk.agent as agent -async def main(): +async def main() -> None: llm = ai.openai.OpenAIModel( model="anthropic/claude-sonnet-4-20250514", base_url="https://ai-gateway.vercel.sh/v1", @@ -25,7 +25,10 @@ async def main(): async for msg in coding_agent.run(messages): # Auto-approve all tool calls if (hook := msg.get_hook_part()) and hook.status == "pending": - agent.ToolApproval.resolve(hook.hook_id, {"granted": True}) + # TODO: mypy doesn't support class decorators that change the + # class type — @ai.hook returns type[Hook[T]] but mypy still + # sees the original BaseModel. + agent.ToolApproval.resolve(hook.hook_id, {"granted": True}) # type: ignore[attr-defined] if msg.text_delta: print(msg.text_delta, end="", flush=True) diff --git a/examples/samples/custom_loop.py b/examples/samples/custom_loop.py index 3c10355a..c5165770 100644 --- a/examples/samples/custom_loop.py +++ b/examples/samples/custom_loop.py @@ -4,6 +4,8 @@ import os from collections.abc import AsyncGenerator +from typing import Any + import vercel_ai_sdk as ai @@ -25,7 +27,7 @@ async def get_population(city: str) -> int: async def custom_stream_step( llm: ai.LanguageModel, messages: list[ai.Message], - tools: list[ai.Tool], + tools: list[ai.Tool[..., Any]], label: str | None = None, ) -> AsyncGenerator[ai.Message, None]: """Wraps llm.stream to inject a label on every message.""" @@ -34,7 +36,7 @@ async def custom_stream_step( yield msg -async def agent(llm: ai.LanguageModel, user_query: str): +async def agent(llm: ai.LanguageModel, user_query: str) -> ai.StreamResult: """Custom agent loop with manual tool execution. Uses @ai.stream for custom streaming, stream_step-style while loop, @@ -52,7 +54,8 @@ async def agent(llm: ai.LanguageModel, user_query: str): if not result.tool_calls: return result - messages.append(result.last_message) + if result.last_message is not None: + messages.append(result.last_message) await asyncio.gather( *( ai.execute_tool(tc, message=result.last_message) @@ -61,7 +64,7 @@ async def agent(llm: ai.LanguageModel, user_query: str): ) -async def main(): +async def main() -> None: llm = ai.anthropic.AnthropicModel( model="anthropic/claude-sonnet-4", base_url="https://ai-gateway.vercel.sh", diff --git a/examples/samples/hooks.py b/examples/samples/hooks.py index 0ff7031c..a1c4c14f 100644 --- a/examples/samples/hooks.py +++ b/examples/samples/hooks.py @@ -20,7 +20,7 @@ class CommunicationApproval(pydantic.BaseModel): reason: str -async def graph(llm: ai.LanguageModel, query: str): +async def graph(llm: ai.LanguageModel, query: str) -> ai.StreamResult: messages = ai.make_messages( system="Use the contact_mothership tool when asked about the future.", user=query, @@ -36,7 +36,10 @@ async def graph(llm: ai.LanguageModel, query: str): for tc in result.tool_calls: if tc.tool_name == "contact_mothership": # Blocks until resolved (long-running) or cancelled (serverless) - approval = await CommunicationApproval.create( + # TODO: mypy doesn't support class decorators that change the + # class type — @ai.hook returns type[Hook[T]] but mypy still + # sees the original BaseModel. + approval = await CommunicationApproval.create( # type: ignore[attr-defined] f"approve_{tc.tool_call_id}", metadata={"tool": tc.tool_name}, ) @@ -47,12 +50,13 @@ async def graph(llm: ai.LanguageModel, query: str): else: await ai.execute_tool(tc, message=result.last_message) - messages.append(result.last_message) + if result.last_message is not None: + messages.append(result.last_message) return result -async def main(): +async def main() -> None: llm = ai.openai.OpenAIModel( model="anthropic/claude-sonnet-4-20250514", base_url="https://ai-gateway.vercel.sh/v1", @@ -63,7 +67,10 @@ async def main(): # Hook parts arrive as pending, waiting for resolution if (hook := msg.get_hook_part()) and hook.status == "pending": answer = input(f"Approve {hook.hook_id}? [y/n] ") - CommunicationApproval.resolve( + # TODO: mypy doesn't support class decorators that change the + # class type — @ai.hook returns type[Hook[T]] but mypy still + # sees the original BaseModel. + CommunicationApproval.resolve( # type: ignore[attr-defined] hook.hook_id, { "granted": answer.strip().lower() in ("y", "yes"), diff --git a/examples/samples/mcp.py b/examples/samples/mcp.py index 410aa768..9d6971fa 100644 --- a/examples/samples/mcp.py +++ b/examples/samples/mcp.py @@ -5,13 +5,15 @@ import rich +from typing import Any + import vercel_ai_sdk as ai -async def context7_agent(llm: ai.LanguageModel, user_query: str): +async def context7_agent(llm: ai.LanguageModel, user_query: str) -> ai.StreamResult: """Agent with Context7 MCP tools for up-to-date library documentation.""" - context7_tools: list[ai.Tool] = await ai.mcp.get_http_tools( + context7_tools: list[ai.Tool[..., Any]] = await ai.mcp.get_http_tools( "https://mcp.context7.com/mcp", headers={"CONTEXT7_API_KEY": os.environ.get("CONTEXT7_API_KEY", "")}, tool_prefix="context7", @@ -28,7 +30,7 @@ async def context7_agent(llm: ai.LanguageModel, user_query: str): ) -async def main(): +async def main() -> None: llm = ai.openai.OpenAIModel( model="openai/gpt-4.1", base_url="https://ai-gateway.vercel.sh/v1", diff --git a/examples/samples/multiagent.py b/examples/samples/multiagent.py index feceae2d..098b56c7 100644 --- a/examples/samples/multiagent.py +++ b/examples/samples/multiagent.py @@ -16,7 +16,7 @@ async def multiply_by_two(number: int) -> int: return number * 2 -async def multiagent(llm: ai.LanguageModel, user_query: str): +async def multiagent(llm: ai.LanguageModel, user_query: str) -> ai.StreamResult: """Run two agents in parallel, then combine their results.""" result1, result2 = await asyncio.gather( @@ -53,7 +53,7 @@ async def multiagent(llm: ai.LanguageModel, user_query: str): ) -async def main(): +async def main() -> None: llm = ai.anthropic.AnthropicModel( model="anthropic/claude-haiku-4.5", base_url="https://ai-gateway.vercel.sh", diff --git a/examples/samples/simple.py b/examples/samples/simple.py index f06c1284..69ba4196 100644 --- a/examples/samples/simple.py +++ b/examples/samples/simple.py @@ -9,7 +9,7 @@ async def talk_to_mothership(question: str) -> str: return "Soon." -async def agent(llm: ai.LanguageModel, user_query: str): +async def agent(llm: ai.LanguageModel, user_query: str) -> ai.StreamResult: return await ai.stream_loop( llm, messages=ai.make_messages( @@ -20,7 +20,7 @@ async def agent(llm: ai.LanguageModel, user_query: str): ) -async def main(): +async def main() -> None: llm = ai.openai.OpenAIModel( model="anthropic/claude-sonnet-4", base_url="https://ai-gateway.vercel.sh/v1", diff --git a/examples/samples/streaming_tool.py b/examples/samples/streaming_tool.py index d67196a3..01eafd0b 100644 --- a/examples/samples/streaming_tool.py +++ b/examples/samples/streaming_tool.py @@ -22,7 +22,7 @@ async def talk_to_mothership(question: str, runtime: ai.Runtime) -> str: return "The mothership says: Soon." -async def agent(llm: ai.LanguageModel, user_query: str): +async def agent(llm: ai.LanguageModel, user_query: str) -> ai.StreamResult: return await ai.stream_loop( llm, messages=ai.make_messages( @@ -33,7 +33,7 @@ async def agent(llm: ai.LanguageModel, user_query: str): ) -async def main(): +async def main() -> None: llm = ai.openai.OpenAIModel( model="anthropic/claude-sonnet-4", base_url="https://ai-gateway.vercel.sh/v1", From c79427d31515c7208855c5d375a4022e6f4e33ca Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Thu, 19 Feb 2026 17:51:48 -0800 Subject: [PATCH 16/20] Update the ToolLike usage in the durable example --- examples/temporal-durable/workflow.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/temporal-durable/workflow.py b/examples/temporal-durable/workflow.py index 1f6d9c9a..5baf999b 100644 --- a/examples/temporal-durable/workflow.py +++ b/examples/temporal-durable/workflow.py @@ -33,7 +33,14 @@ async def stream( result = await self.call_fn( activities.LLMCallParams( messages=[m.model_dump() for m in messages], - tool_schemas=[t.schema.model_dump() for t in (tools or [])], + tool_schemas=[ + { + "name": t.name, + "description": t.description, + "param_schema": t.param_schema, + } + for t in (tools or []) + ], ) ) yield ai.Message.model_validate(result.message) From 1fa78c0c91f680a352f1f6cfc9d8080cbe7a9f8f Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Thu, 19 Feb 2026 18:04:43 -0800 Subject: [PATCH 17/20] Add missing type annotations to tests --- tests/ai_sdk_ui/test_adapter.py | 16 ++++++------- tests/core/test_checkpoint.py | 35 +++++++++++++++-------------- tests/core/test_hooks.py | 25 +++++++++++---------- tests/core/test_llm.py | 16 ++++++------- tests/core/test_messages.py | 40 ++++++++++++++++----------------- tests/core/test_runtime.py | 32 +++++++++++++------------- tests/core/test_streams.py | 12 +++++----- tests/core/test_tools.py | 16 ++++++------- tests/mcp/test_client.py | 14 ++++++------ 9 files changed, 104 insertions(+), 102 deletions(-) diff --git a/tests/ai_sdk_ui/test_adapter.py b/tests/ai_sdk_ui/test_adapter.py index 4b7971c3..4770dabe 100644 --- a/tests/ai_sdk_ui/test_adapter.py +++ b/tests/ai_sdk_ui/test_adapter.py @@ -16,7 +16,7 @@ async def get_event_types(msgs: list[messages.Message]) -> list[str]: """Stream messages through adapter and return event type sequence.""" - async def stream(): + async def stream() -> AsyncGenerator[messages.Message, None]: for m in msgs: yield m @@ -29,7 +29,7 @@ async def stream(): @pytest.mark.asyncio -async def test_text_streaming(): +async def test_text_streaming() -> None: """Text: start -> start-step -> text-start/delta/end -> finish-step -> finish""" msgs = [ messages.Message( @@ -66,7 +66,7 @@ async def test_text_streaming(): @pytest.mark.asyncio -async def test_tool_roundtrip(): +async def test_tool_roundtrip() -> None: """Server-side tool: input-available -> output-available -> text response. Reference: process-ui-message-stream.test.ts "server-side tool roundtrip" @@ -127,7 +127,7 @@ async def test_tool_roundtrip(): @pytest.mark.asyncio -async def test_text_then_tool_then_text(): +async def test_text_then_tool_then_text() -> None: """Full mothership scenario: text -> tool -> result -> final text. Input: "when will the robots take over?" @@ -250,7 +250,7 @@ async def mock_agent( @pytest.mark.asyncio -async def test_runtime_tool_roundtrip(): +async def test_runtime_tool_roundtrip() -> None: """ Integration test: run a mock agent loop through ai.run() and verify that tool-input-available and tool-output-available events are emitted. @@ -340,7 +340,7 @@ async def _async_iter( # ----------------------------------------------------------------------------- -def test_ui_to_internal_two_turn_with_tool(): +def test_ui_to_internal_two_turn_with_tool() -> None: """Test converting a realistic two-turn conversation with tool call. This test uses the exact payload structure from a real AI SDK frontend @@ -443,7 +443,7 @@ def test_ui_to_internal_two_turn_with_tool(): assert internal[2].text == "this is a test run. can you remember the first turn?" -def test_ui_tool_part_with_dict_input(): +def test_ui_tool_part_with_dict_input() -> None: """Test that tool parts with dict input (not JSON string) are handled.""" raw_message = { "id": "msg-1", @@ -468,7 +468,7 @@ def test_ui_tool_part_with_dict_input(): assert tool_part.status == "pending" # input-available maps to pending -def test_ui_skips_unsupported_parts(): +def test_ui_skips_unsupported_parts() -> None: """Test that unsupported part types are skipped gracefully.""" raw_message = { "id": "msg-1", diff --git a/tests/core/test_checkpoint.py b/tests/core/test_checkpoint.py index eefc995e..2cfe9967 100644 --- a/tests/core/test_checkpoint.py +++ b/tests/core/test_checkpoint.py @@ -1,6 +1,7 @@ """Checkpoint replay, hook cancellation/resolution, serialization.""" import asyncio +from typing import Any import pydantic import pytest @@ -20,8 +21,8 @@ class Approval(pydantic.BaseModel): @pytest.mark.asyncio -async def test_step_replay_skips_llm(): - async def graph(llm: ai.LanguageModel): +async def test_step_replay_skips_llm() -> None: + async def graph(llm: ai.LanguageModel) -> ai.StreamResult: return await ai.stream_step( llm, messages=ai.make_messages(system="test", user="hello") ) @@ -39,7 +40,7 @@ async def graph(llm: ai.LanguageModel): @pytest.mark.asyncio -async def test_tool_replay_skips_execution(): +async def test_tool_replay_skips_execution() -> None: execution_count = 0 @ai.tool @@ -49,7 +50,7 @@ async def counting_tool(x: int) -> int: execution_count += 1 return x + 1 - async def graph(llm: ai.LanguageModel): + async def graph(llm: ai.LanguageModel) -> ai.StreamResult: result = await ai.stream_step(llm, ai.make_messages(system="t", user="go")) if result.tool_calls: await asyncio.gather( @@ -76,8 +77,8 @@ async def graph(llm: ai.LanguageModel): @pytest.mark.asyncio -async def test_hook_cancellation_pending(): - async def graph(llm: ai.LanguageModel): +async def test_hook_cancellation_pending() -> None: + async def graph(llm: ai.LanguageModel) -> Any: await ai.stream_step(llm, ai.make_messages(system="t", user="go")) return await Approval.create("my_approval", metadata={"tool": "test"}) @@ -89,8 +90,8 @@ async def graph(llm: ai.LanguageModel): @pytest.mark.asyncio -async def test_hook_resolution_on_reentry(): - async def graph(llm: ai.LanguageModel): +async def test_hook_resolution_on_reentry() -> None: + async def graph(llm: ai.LanguageModel) -> Any: await ai.stream_step(llm, ai.make_messages(system="t", user="go")) return await Approval.create("my_approval") @@ -107,14 +108,14 @@ async def graph(llm: ai.LanguageModel): @pytest.mark.asyncio -async def test_parallel_hooks_all_collected(): - async def graph(llm: ai.LanguageModel): +async def test_parallel_hooks_all_collected() -> None: + async def graph(llm: ai.LanguageModel) -> None: await ai.stream_step(llm, ai.make_messages(system="t", user="go")) - async def a(): + async def a() -> Any: return await Approval.create("hook_a") - async def b(): + async def b() -> Any: return await Approval.create("hook_b") async with asyncio.TaskGroup() as tg: @@ -127,14 +128,14 @@ async def b(): @pytest.mark.asyncio -async def test_parallel_hooks_resolve_on_reentry(): - async def graph(llm: ai.LanguageModel): +async def test_parallel_hooks_resolve_on_reentry() -> None: + async def graph(llm: ai.LanguageModel) -> Any: await ai.stream_step(llm, ai.make_messages(system="t", user="go")) - async def a(): + async def a() -> Any: return await Approval.create("hook_a") - async def b(): + async def b() -> Any: return await Approval.create("hook_b") async with asyncio.TaskGroup() as tg: @@ -157,7 +158,7 @@ async def b(): # -- Serialization --------------------------------------------------------- -def test_checkpoint_serialization_roundtrip(): +def test_checkpoint_serialization_roundtrip() -> None: cp = Checkpoint( steps=[ StepEvent( diff --git a/tests/core/test_hooks.py b/tests/core/test_hooks.py index 058f3e90..b479ab74 100644 --- a/tests/core/test_hooks.py +++ b/tests/core/test_hooks.py @@ -1,6 +1,7 @@ """Hooks: live resolution, cancellation, pre-registration, schema validation.""" import asyncio +from typing import Any import pydantic import pytest @@ -21,11 +22,11 @@ class Confirmation(pydantic.BaseModel): @pytest.mark.asyncio -async def test_resolve_live_future(): +async def test_resolve_live_future() -> None: """In long-running mode, Hook.resolve() unblocks the awaiting coroutine.""" resolved_value = None - async def graph(llm: ai.LanguageModel): + async def graph(llm: ai.LanguageModel) -> None: nonlocal resolved_value await ai.stream_step(llm, ai.make_messages(user="go")) result = await Confirmation.create("confirm_1") @@ -57,11 +58,11 @@ async def graph(llm: ai.LanguageModel): @pytest.mark.asyncio -async def test_cancel_live_hook(): +async def test_cancel_live_hook() -> None: """Hook.cancel() cancels the future, causing CancelledError in graph.""" was_cancelled = False - async def graph(llm: ai.LanguageModel): + async def graph(llm: ai.LanguageModel) -> None: nonlocal was_cancelled await ai.stream_step(llm, ai.make_messages(user="go")) try: @@ -83,7 +84,7 @@ async def graph(llm: ai.LanguageModel): @pytest.mark.asyncio -async def test_cancel_nonexistent_raises(): +async def test_cancel_nonexistent_raises() -> None: with pytest.raises(ValueError, match="No pending hook"): await Confirmation.cancel("does_not_exist_xyz") @@ -92,10 +93,10 @@ async def test_cancel_nonexistent_raises(): @pytest.mark.asyncio -async def test_pre_registered_resolution_consumed(): +async def test_pre_registered_resolution_consumed() -> None: """Pre-registered resolution is consumed by Hook.create() without suspending.""" - async def graph(llm: ai.LanguageModel): + async def graph(llm: ai.LanguageModel) -> Any: await ai.stream_step(llm, ai.make_messages(user="go")) result = await Confirmation.create("pre_reg_1") return result @@ -116,7 +117,7 @@ async def graph(llm: ai.LanguageModel): # -- Schema validation on resolve ----------------------------------------- -def test_resolve_validates_schema(): +def test_resolve_validates_schema() -> None: """resolve() with invalid data raises from pydantic validation.""" # 'approved' is required bool, passing string should raise with pytest.raises(pydantic.ValidationError): @@ -127,10 +128,10 @@ def test_resolve_validates_schema(): @pytest.mark.asyncio -async def test_resolved_hook_emits_message(): +async def test_resolved_hook_emits_message() -> None: """After resolution, a 'resolved' HookPart message is emitted.""" - async def graph(llm: ai.LanguageModel): + async def graph(llm: ai.LanguageModel) -> None: await ai.stream_step(llm, ai.make_messages(user="go")) await Confirmation.create("emit_test") @@ -156,8 +157,8 @@ async def graph(llm: ai.LanguageModel): @pytest.mark.asyncio -async def test_hook_metadata_in_pending(): - async def graph(llm: ai.LanguageModel): +async def test_hook_metadata_in_pending() -> None: + async def graph(llm: ai.LanguageModel) -> None: await ai.stream_step(llm, ai.make_messages(user="go")) await Confirmation.create("meta_test", metadata={"tool": "rm -rf", "path": "/"}) diff --git a/tests/core/test_llm.py b/tests/core/test_llm.py index ba33a55b..66055362 100644 --- a/tests/core/test_llm.py +++ b/tests/core/test_llm.py @@ -21,7 +21,7 @@ # -- Text streaming -------------------------------------------------------- -def test_text_lifecycle(): +def test_text_lifecycle() -> None: h = StreamHandler(message_id="m1") m = h.handle_event(TextStart(block_id="b1")) assert len(m.parts) == 1 @@ -46,7 +46,7 @@ def test_text_lifecycle(): # -- Reasoning streaming --------------------------------------------------- -def test_reasoning_lifecycle(): +def test_reasoning_lifecycle() -> None: h = StreamHandler(message_id="m1") h.handle_event(ReasoningStart(block_id="r1")) m = h.handle_event(ReasoningDelta(block_id="r1", delta="thinking")) @@ -62,7 +62,7 @@ def test_reasoning_lifecycle(): # -- Tool streaming -------------------------------------------------------- -def test_tool_lifecycle(): +def test_tool_lifecycle() -> None: h = StreamHandler(message_id="m1") h.handle_event(ToolStart(tool_call_id="tc1", tool_name="get_weather")) m = h.handle_event(ToolArgsDelta(tool_call_id="tc1", delta='{"ci')) @@ -83,7 +83,7 @@ def test_tool_lifecycle(): # -- Multi-part messages --------------------------------------------------- -def test_reasoning_then_text_then_tool(): +def test_reasoning_then_text_then_tool() -> None: """Full message: reasoning block, text block, tool call.""" h = StreamHandler(message_id="m1") h.handle_event(ReasoningStart(block_id="r1")) @@ -105,7 +105,7 @@ def test_reasoning_then_text_then_tool(): assert all(p.state == "done" for p in m.parts) -def test_multiple_tool_calls(): +def test_multiple_tool_calls() -> None: """Parallel tool calls in one message.""" h = StreamHandler(message_id="m1") h.handle_event(ToolStart(tool_call_id="tc1", tool_name="read_file")) @@ -128,7 +128,7 @@ def test_multiple_tool_calls(): # -- MessageDone ----------------------------------------------------------- -def test_message_done_finalizes_all(): +def test_message_done_finalizes_all() -> None: h = StreamHandler(message_id="m1") h.handle_event(TextStart(block_id="t1")) h.handle_event(TextDelta(block_id="t1", delta="hello")) @@ -141,13 +141,13 @@ def test_message_done_finalizes_all(): # -- Message properties propagate ------------------------------------------ -def test_message_id_propagates(): +def test_message_id_propagates() -> None: h = StreamHandler(message_id="custom-id") m = h.handle_event(TextStart(block_id="b1")) assert m.id == "custom-id" -def test_deltas_only_on_active_blocks(): +def test_deltas_only_on_active_blocks() -> None: """Delta should be None on inactive blocks, present only on active.""" h = StreamHandler(message_id="m1") h.handle_event(TextStart(block_id="t1")) diff --git a/tests/core/test_messages.py b/tests/core/test_messages.py index 608675bb..91d958cd 100644 --- a/tests/core/test_messages.py +++ b/tests/core/test_messages.py @@ -14,7 +14,7 @@ # -- is_done --------------------------------------------------------------- -def test_is_done_all_done(): +def test_is_done_all_done() -> None: m = Message( id="m1", role="assistant", @@ -26,7 +26,7 @@ def test_is_done_all_done(): assert m.is_done is True -def test_is_done_streaming(): +def test_is_done_streaming() -> None: m = Message( id="m1", role="assistant", @@ -35,7 +35,7 @@ def test_is_done_streaming(): assert m.is_done is False -def test_is_done_no_state(): +def test_is_done_no_state() -> None: """Parts without state (restored from storage) count as done.""" m = Message(id="m1", role="assistant", parts=[TextPart(text="hi")]) assert m.is_done is True @@ -44,7 +44,7 @@ def test_is_done_no_state(): # -- text / reasoning properties ------------------------------------------- -def test_text_returns_first_text_part(): +def test_text_returns_first_text_part() -> None: m = Message( id="m1", role="assistant", @@ -56,7 +56,7 @@ def test_text_returns_first_text_part(): assert m.text == "first" -def test_text_empty_when_no_text_parts(): +def test_text_empty_when_no_text_parts() -> None: m = Message( id="m1", role="assistant", @@ -65,7 +65,7 @@ def test_text_empty_when_no_text_parts(): assert m.text == "" -def test_reasoning_returns_first(): +def test_reasoning_returns_first() -> None: m = Message( id="m1", role="assistant", @@ -77,7 +77,7 @@ def test_reasoning_returns_first(): # -- deltas ---------------------------------------------------------------- -def test_text_delta(): +def test_text_delta() -> None: m = Message( id="m1", role="assistant", @@ -86,12 +86,12 @@ def test_text_delta(): assert m.text_delta == "b" -def test_text_delta_empty_when_no_delta(): +def test_text_delta_empty_when_no_delta() -> None: m = Message(id="m1", role="assistant", parts=[TextPart(text="done", state="done")]) assert m.text_delta == "" -def test_reasoning_delta(): +def test_reasoning_delta() -> None: m = Message( id="m1", role="assistant", @@ -100,7 +100,7 @@ def test_reasoning_delta(): assert m.reasoning_delta == "b" -def test_tool_deltas(): +def test_tool_deltas() -> None: m = Message( id="m1", role="assistant", @@ -123,7 +123,7 @@ def test_tool_deltas(): # -- tool_calls / get_tool_part ------------------------------------------- -def test_tool_calls(): +def test_tool_calls() -> None: m = Message( id="m1", role="assistant", @@ -137,7 +137,7 @@ def test_tool_calls(): assert m.tool_calls[0].tool_call_id == "tc1" -def test_get_tool_part_found(): +def test_get_tool_part_found() -> None: m = Message( id="m1", role="assistant", @@ -147,7 +147,7 @@ def test_get_tool_part_found(): assert m.get_tool_part("tc1").tool_name == "t" -def test_get_tool_part_missing(): +def test_get_tool_part_missing() -> None: m = Message(id="m1", role="assistant", parts=[TextPart(text="no tools")]) assert m.get_tool_part("tc-nope") is None @@ -155,7 +155,7 @@ def test_get_tool_part_missing(): # -- get_hook_part --------------------------------------------------------- -def test_get_hook_part_found(): +def test_get_hook_part_found() -> None: """get_hook_part returns the HookPart when present.""" hook = HookPart(hook_id="h1", hook_type="Approval", status="pending") m = Message(id="m1", role="assistant", parts=[hook]) @@ -163,7 +163,7 @@ def test_get_hook_part_found(): assert m.get_hook_part("h1") is hook -def test_get_hook_part_by_id(): +def test_get_hook_part_by_id() -> None: """get_hook_part with a specific hook_id skips non-matching hooks.""" h1 = HookPart(hook_id="h1", hook_type="Approval", status="pending") h2 = HookPart(hook_id="h2", hook_type="Approval", status="resolved") @@ -171,7 +171,7 @@ def test_get_hook_part_by_id(): assert m.get_hook_part("h2") is h2 -def test_get_hook_part_missing(): +def test_get_hook_part_missing() -> None: """get_hook_part returns None when no HookPart exists.""" m = Message(id="m1", role="assistant", parts=[TextPart(text="no hooks")]) assert m.get_hook_part() is None @@ -181,7 +181,7 @@ def test_get_hook_part_missing(): # -- ToolPart.set_result / set_error --------------------------------------- -def test_set_result(): +def test_set_result() -> None: tp = ToolPart(tool_call_id="tc1", tool_name="t", tool_args="{}") assert tp.status == "pending" tp.set_result({"answer": 42}) @@ -189,7 +189,7 @@ def test_set_result(): assert tp.result == {"answer": 42} -def test_set_error(): +def test_set_error() -> None: tp = ToolPart(tool_call_id="tc1", tool_name="t", tool_args="{}") assert tp.status == "pending" tp.set_error("Something went wrong") @@ -200,7 +200,7 @@ def test_set_error(): # -- make_messages --------------------------------------------------------- -def test_make_messages_system_and_user(): +def test_make_messages_system_and_user() -> None: msgs = make_messages(system="You are helpful.", user="Hi") assert len(msgs) == 2 assert msgs[0].role == "system" @@ -209,7 +209,7 @@ def test_make_messages_system_and_user(): assert msgs[1].text == "Hi" -def test_make_messages_user_only(): +def test_make_messages_user_only() -> None: msgs = make_messages(user="Hi") assert len(msgs) == 1 assert msgs[0].role == "user" diff --git a/tests/core/test_runtime.py b/tests/core/test_runtime.py index 5933eb40..64ec86c4 100644 --- a/tests/core/test_runtime.py +++ b/tests/core/test_runtime.py @@ -30,10 +30,10 @@ async def concat(a: str, b: str) -> str: @pytest.mark.asyncio -async def test_stream_loop_text_only(): +async def test_stream_loop_text_only() -> None: """stream_loop with no tool calls returns after one LLM call.""" - async def graph(llm: ai.LanguageModel): + async def graph(llm: ai.LanguageModel) -> ai.StreamResult: return await ai.stream_loop( llm, messages=ai.make_messages(user="Hi"), @@ -51,10 +51,10 @@ async def graph(llm: ai.LanguageModel): @pytest.mark.asyncio -async def test_stream_loop_tool_then_text(): +async def test_stream_loop_tool_then_text() -> None: """stream_loop calls tool, feeds result back, gets final text.""" - async def graph(llm: ai.LanguageModel): + async def graph(llm: ai.LanguageModel) -> ai.StreamResult: return await ai.stream_loop( llm, messages=ai.make_messages(user="Double 5"), @@ -80,10 +80,10 @@ async def graph(llm: ai.LanguageModel): @pytest.mark.asyncio -async def test_stream_loop_parallel_tools(): +async def test_stream_loop_parallel_tools() -> None: """LLM returns two tool calls in one message; both execute.""" - async def graph(llm: ai.LanguageModel): + async def graph(llm: ai.LanguageModel) -> ai.StreamResult: return await ai.stream_loop( llm, messages=ai.make_messages(user="Double 3 and 7"), @@ -129,10 +129,10 @@ async def graph(llm: ai.LanguageModel): @pytest.mark.asyncio -async def test_stream_loop_multi_turn(): +async def test_stream_loop_multi_turn() -> None: """LLM calls a tool, then calls another tool, then returns text.""" - async def graph(llm: ai.LanguageModel): + async def graph(llm: ai.LanguageModel) -> ai.StreamResult: return await ai.stream_loop( llm, messages=ai.make_messages(user="Concat then double"), @@ -155,13 +155,13 @@ async def graph(llm: ai.LanguageModel): @pytest.mark.asyncio -async def test_execute_tool_missing_raises(): +async def test_execute_tool_missing_raises() -> None: """execute_tool with unknown tool name raises ValueError (wrapped in ExceptionGroup by TaskGroup).""" tc = messages.ToolPart( tool_call_id="tc-1", tool_name="nonexistent_tool_zzz", tool_args="{}" ) - async def graph(llm: ai.LanguageModel): + async def graph(llm: ai.LanguageModel) -> None: await ai.execute_tool(tc) result = ai.run(graph, MockLLM([])) @@ -174,7 +174,7 @@ async def graph(llm: ai.LanguageModel): @pytest.mark.asyncio -async def test_execute_tool_injects_runtime(): +async def test_execute_tool_injects_runtime() -> None: """Tools with a Runtime parameter get the active runtime injected.""" received_rt = None @@ -185,7 +185,7 @@ async def introspect(query: str, rt: Runtime) -> str: received_rt = rt return "ok" - async def graph(llm: ai.LanguageModel): + async def graph(llm: ai.LanguageModel) -> None: result = await ai.stream_step(llm, ai.make_messages(user="go")) if result.tool_calls: await asyncio.gather( @@ -206,10 +206,10 @@ async def graph(llm: ai.LanguageModel): @pytest.mark.asyncio -async def test_execute_tool_updates_message(): +async def test_execute_tool_updates_message() -> None: """After execute_tool, the ToolPart in the message has status=result.""" - async def graph(llm: ai.LanguageModel): + async def graph(llm: ai.LanguageModel) -> None: result = await ai.stream_step(llm, ai.make_messages(user="go")) if result.tool_calls: msg = result.last_message @@ -228,10 +228,10 @@ async def graph(llm: ai.LanguageModel): @pytest.mark.asyncio -async def test_stream_loop_checkpoint_records_tools(): +async def test_stream_loop_checkpoint_records_tools() -> None: """stream_loop's tool executions are recorded in the checkpoint.""" - async def graph(llm: ai.LanguageModel): + async def graph(llm: ai.LanguageModel) -> ai.StreamResult: return await ai.stream_loop( llm, messages=ai.make_messages(user="Double 4"), diff --git a/tests/core/test_streams.py b/tests/core/test_streams.py index 9180bf9e..1fa21a6d 100644 --- a/tests/core/test_streams.py +++ b/tests/core/test_streams.py @@ -15,14 +15,14 @@ # -- StreamResult properties ----------------------------------------------- -def test_stream_result_empty(): +def test_stream_result_empty() -> None: r = StreamResult() assert r.last_message is None assert r.tool_calls == [] assert r.text == "" -def test_stream_result_last_message(): +def test_stream_result_last_message() -> None: m1 = text_msg("first", id="m1") m2 = text_msg("second", id="m2") r = StreamResult(messages=[m1, m2]) @@ -30,7 +30,7 @@ def test_stream_result_last_message(): assert r.text == "second" -def test_stream_result_tool_calls(): +def test_stream_result_tool_calls() -> None: m = messages.Message( id="m1", role="assistant", @@ -51,7 +51,7 @@ def test_stream_result_tool_calls(): @pytest.mark.asyncio -async def test_stream_outside_run_raises(): +async def test_stream_outside_run_raises() -> None: """@stream-decorated fn called without ai.run() should raise.""" with pytest.raises(ValueError, match="No Runtime context"): await ai.stream_step( @@ -64,10 +64,10 @@ async def test_stream_outside_run_raises(): @pytest.mark.asyncio -async def test_stream_step_replays_from_checkpoint(): +async def test_stream_step_replays_from_checkpoint() -> None: """stream_step inside ai.run with a checkpoint replays without calling LLM.""" - async def graph(llm: ai.LanguageModel): + async def graph(llm: ai.LanguageModel) -> ai.StreamResult: return await ai.stream_step(llm, ai.make_messages(user="hello")) # First run diff --git a/tests/core/test_tools.py b/tests/core/test_tools.py index b465a612..0aa02820 100644 --- a/tests/core/test_tools.py +++ b/tests/core/test_tools.py @@ -12,7 +12,7 @@ # -- Schema extraction from type hints ------------------------------------ -def test_simple_types_produce_correct_schema(): +def test_simple_types_produce_correct_schema() -> None: @ai.tool async def greet(name: str, count: int) -> str: """Say hello.""" @@ -26,7 +26,7 @@ async def greet(name: str, count: int) -> str: assert set(greet.param_schema["required"]) == {"name", "count"} -def test_optional_param_not_required(): +def test_optional_param_not_required() -> None: @ai.tool async def search(query: str, limit: Optional[int] = None) -> str: """Search.""" @@ -38,7 +38,7 @@ async def search(query: str, limit: Optional[int] = None) -> str: assert "limit" in search.param_schema["properties"] -def test_default_value_not_required(): +def test_default_value_not_required() -> None: @ai.tool async def fetch(url: str, timeout: int = 30) -> str: """Fetch URL.""" @@ -48,7 +48,7 @@ async def fetch(url: str, timeout: int = 30) -> str: assert "timeout" not in search_required(fetch) -def test_complex_type_schema(): +def test_complex_type_schema() -> None: @ai.tool async def send(recipients: list[str], urgent: bool = False) -> str: """Send message.""" @@ -62,7 +62,7 @@ async def send(recipients: list[str], urgent: bool = False) -> str: # -- Runtime parameter skipping ------------------------------------------- -def test_runtime_param_excluded_from_schema(): +def test_runtime_param_excluded_from_schema() -> None: @ai.tool async def needs_runtime(query: str, rt: Runtime) -> str: """Tool that needs runtime.""" @@ -77,7 +77,7 @@ async def needs_runtime(query: str, rt: Runtime) -> str: # -- Registry ------------------------------------------------------------- -def test_tool_registered_on_decoration(): +def test_tool_registered_on_decoration() -> None: @ai.tool async def unique_tool_abc() -> str: """Unique.""" @@ -86,7 +86,7 @@ async def unique_tool_abc() -> str: assert get_tool("unique_tool_abc") is unique_tool_abc -def test_get_tool_returns_none_for_missing(): +def test_get_tool_returns_none_for_missing() -> None: assert get_tool("nonexistent_tool_xyz") is None @@ -94,7 +94,7 @@ def test_get_tool_returns_none_for_missing(): @pytest.mark.asyncio -async def test_tool_fn_is_callable(): +async def test_tool_fn_is_callable() -> None: @ai.tool async def add(a: int, b: int) -> int: """Add two numbers.""" diff --git a/tests/mcp/test_client.py b/tests/mcp/test_client.py index c9d3893b..3e530e1f 100644 --- a/tests/mcp/test_client.py +++ b/tests/mcp/test_client.py @@ -27,7 +27,7 @@ def _fake_mcp_tool( ) -def _noop_transport_factory(): +def _noop_transport_factory() -> None: """Dummy transport factory — never actually called in these tests.""" raise NotImplementedError("should not be called") @@ -35,7 +35,7 @@ def _noop_transport_factory(): # -- _mcp_tool_to_native registers in global registry ---------------------- -def test_mcp_tool_to_native_registers_in_global_registry(): +def test_mcp_tool_to_native_registers_in_global_registry() -> None: """Converting an MCP tool to native registers it in _tool_registry.""" mcp_tool = _fake_mcp_tool(name="mcp_reg_test") native = _mcp_tool_to_native(mcp_tool, "test:key", _noop_transport_factory, None) @@ -45,7 +45,7 @@ def test_mcp_tool_to_native_registers_in_global_registry(): assert _tool_registry["mcp_reg_test"] is native -def test_mcp_tool_to_native_with_prefix(): +def test_mcp_tool_to_native_with_prefix() -> None: """Tool prefix is prepended to the name and both name forms are correct.""" mcp_tool = _fake_mcp_tool(name="echo") native = _mcp_tool_to_native(mcp_tool, "test:key", _noop_transport_factory, "ctx7") @@ -54,7 +54,7 @@ def test_mcp_tool_to_native_with_prefix(): assert get_tool("ctx7_echo") is native -def test_mcp_tool_to_native_schema_preserved(): +def test_mcp_tool_to_native_schema_preserved() -> None: """The inputSchema from the MCP tool is passed through as param_schema.""" mcp_tool = _fake_mcp_tool() native = _mcp_tool_to_native(mcp_tool, "test:key", _noop_transport_factory, None) @@ -67,11 +67,11 @@ def test_mcp_tool_to_native_schema_preserved(): @pytest.mark.asyncio -async def test_mcp_tool_executes_through_stream_loop(): +async def test_mcp_tool_executes_through_stream_loop() -> None: """An MCP-style tool registered via _mcp_tool_to_native can be called by the agent loop.""" call_log: list[dict] = [] - async def fake_fn(**kwargs): + async def fake_fn(**kwargs: str) -> str: call_log.append(kwargs) return f"echoed: {kwargs.get('text', '')}" @@ -83,7 +83,7 @@ async def fake_fn(**kwargs): native._fn = fake_fn _tool_registry[native.name] = native - async def graph(llm: ai.LanguageModel): + async def graph(llm: ai.LanguageModel) -> ai.StreamResult: return await ai.stream_loop( llm, messages=ai.make_messages(user="echo hello"), From 9368ab11950ddbea1c483df2c1629d4ee616bf32 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Thu, 19 Feb 2026 18:11:56 -0800 Subject: [PATCH 18/20] Add type ignores in hook tests --- tests/core/test_checkpoint.py | 20 ++++++++++---------- tests/core/test_hooks.py | 24 ++++++++++++------------ 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/tests/core/test_checkpoint.py b/tests/core/test_checkpoint.py index 2cfe9967..c3c209f1 100644 --- a/tests/core/test_checkpoint.py +++ b/tests/core/test_checkpoint.py @@ -80,27 +80,27 @@ async def graph(llm: ai.LanguageModel) -> ai.StreamResult: async def test_hook_cancellation_pending() -> None: async def graph(llm: ai.LanguageModel) -> Any: await ai.stream_step(llm, ai.make_messages(system="t", user="go")) - return await Approval.create("my_approval", metadata={"tool": "test"}) + return await Approval.create("my_approval", metadata={"tool": "test"}) # type: ignore[attr-defined] result = ai.run(graph, MockLLM([[text_msg("OK")]]), cancel_on_hooks=True) msgs = [msg async for msg in result] assert "my_approval" in result.pending_hooks hook_msgs = [m for m in msgs if any(isinstance(p, ai.HookPart) for p in m.parts)] - assert hook_msgs[0].parts[0].status == "pending" + assert hook_msgs[0].parts[0].status == "pending" # type: ignore[union-attr] @pytest.mark.asyncio async def test_hook_resolution_on_reentry() -> None: async def graph(llm: ai.LanguageModel) -> Any: await ai.stream_step(llm, ai.make_messages(system="t", user="go")) - return await Approval.create("my_approval") + return await Approval.create("my_approval") # type: ignore[attr-defined] resp = [text_msg("OK")] result1 = ai.run(graph, MockLLM([resp]), cancel_on_hooks=True) [msg async for msg in result1] cp = result1.checkpoint - Approval.resolve("my_approval", {"granted": True}) + Approval.resolve("my_approval", {"granted": True}) # type: ignore[attr-defined] result2 = ai.run(graph, MockLLM([]), checkpoint=cp) [msg async for msg in result2] assert len(result2.pending_hooks) == 0 @@ -113,10 +113,10 @@ async def graph(llm: ai.LanguageModel) -> None: await ai.stream_step(llm, ai.make_messages(system="t", user="go")) async def a() -> Any: - return await Approval.create("hook_a") + return await Approval.create("hook_a") # type: ignore[attr-defined] async def b() -> Any: - return await Approval.create("hook_b") + return await Approval.create("hook_b") # type: ignore[attr-defined] async with asyncio.TaskGroup() as tg: tg.create_task(a()) @@ -133,10 +133,10 @@ async def graph(llm: ai.LanguageModel) -> Any: await ai.stream_step(llm, ai.make_messages(system="t", user="go")) async def a() -> Any: - return await Approval.create("hook_a") + return await Approval.create("hook_a") # type: ignore[attr-defined] async def b() -> Any: - return await Approval.create("hook_b") + return await Approval.create("hook_b") # type: ignore[attr-defined] async with asyncio.TaskGroup() as tg: ta = tg.create_task(a()) @@ -148,8 +148,8 @@ async def b() -> Any: [msg async for msg in result1] cp = result1.checkpoint - Approval.resolve("hook_a", {"granted": True}) - Approval.resolve("hook_b", {"granted": False}) + Approval.resolve("hook_a", {"granted": True}) # type: ignore[attr-defined] + Approval.resolve("hook_b", {"granted": False}) # type: ignore[attr-defined] result2 = ai.run(graph, MockLLM([]), checkpoint=cp) [msg async for msg in result2] assert len(result2.pending_hooks) == 0 diff --git a/tests/core/test_hooks.py b/tests/core/test_hooks.py index b479ab74..06c6bd40 100644 --- a/tests/core/test_hooks.py +++ b/tests/core/test_hooks.py @@ -29,7 +29,7 @@ async def test_resolve_live_future() -> None: async def graph(llm: ai.LanguageModel) -> None: nonlocal resolved_value await ai.stream_step(llm, ai.make_messages(user="go")) - result = await Confirmation.create("confirm_1") + result = await Confirmation.create("confirm_1") # type: ignore[attr-defined] resolved_value = result llm = MockLLM([[text_msg("OK")]]) @@ -41,7 +41,7 @@ async def graph(llm: ai.LanguageModel) -> None: collected.append(msg) # When we see the pending hook message, resolve it if any(isinstance(p, ai.HookPart) and p.status == "pending" for p in msg.parts): - Confirmation.resolve( + Confirmation.resolve( # type: ignore[attr-defined] "confirm_1", {"approved": True, "reason": "looks good"} ) @@ -66,7 +66,7 @@ async def graph(llm: ai.LanguageModel) -> None: nonlocal was_cancelled await ai.stream_step(llm, ai.make_messages(user="go")) try: - await Confirmation.create("cancel_me") + await Confirmation.create("cancel_me") # type: ignore[attr-defined] except asyncio.CancelledError: was_cancelled = True @@ -75,7 +75,7 @@ async def graph(llm: ai.LanguageModel) -> None: async for msg in run_result: if any(isinstance(p, ai.HookPart) and p.status == "pending" for p in msg.parts): - await Confirmation.cancel("cancel_me", reason="denied") + await Confirmation.cancel("cancel_me", reason="denied") # type: ignore[attr-defined] assert was_cancelled @@ -86,7 +86,7 @@ async def graph(llm: ai.LanguageModel) -> None: @pytest.mark.asyncio async def test_cancel_nonexistent_raises() -> None: with pytest.raises(ValueError, match="No pending hook"): - await Confirmation.cancel("does_not_exist_xyz") + await Confirmation.cancel("does_not_exist_xyz") # type: ignore[attr-defined] # -- Pre-registration (serverless re-entry) -------------------------------- @@ -98,11 +98,11 @@ async def test_pre_registered_resolution_consumed() -> None: async def graph(llm: ai.LanguageModel) -> Any: await ai.stream_step(llm, ai.make_messages(user="go")) - result = await Confirmation.create("pre_reg_1") + result = await Confirmation.create("pre_reg_1") # type: ignore[attr-defined] return result # Pre-register BEFORE run - Confirmation.resolve("pre_reg_1", {"approved": True}) + Confirmation.resolve("pre_reg_1", {"approved": True}) # type: ignore[attr-defined] llm = MockLLM([[text_msg("OK")]]) run_result = ai.run(graph, llm) @@ -121,7 +121,7 @@ def test_resolve_validates_schema() -> None: """resolve() with invalid data raises from pydantic validation.""" # 'approved' is required bool, passing string should raise with pytest.raises(pydantic.ValidationError): - Confirmation.resolve("schema_test", {"approved": "not_a_bool"}) + Confirmation.resolve("schema_test", {"approved": "not_a_bool"}) # type: ignore[attr-defined] # -- Resolved hook emits message ------------------------------------------- @@ -133,7 +133,7 @@ async def test_resolved_hook_emits_message() -> None: async def graph(llm: ai.LanguageModel) -> None: await ai.stream_step(llm, ai.make_messages(user="go")) - await Confirmation.create("emit_test") + await Confirmation.create("emit_test") # type: ignore[attr-defined] llm = MockLLM([[text_msg("OK")]]) run_result = ai.run(graph, llm) @@ -142,7 +142,7 @@ async def graph(llm: ai.LanguageModel) -> None: async for msg in run_result: msgs.append(msg) if any(isinstance(p, ai.HookPart) and p.status == "pending" for p in msg.parts): - Confirmation.resolve("emit_test", {"approved": False}) + Confirmation.resolve("emit_test", {"approved": False}) # type: ignore[attr-defined] hook_msgs = [ m @@ -150,7 +150,7 @@ async def graph(llm: ai.LanguageModel) -> None: if any(isinstance(p, ai.HookPart) and p.status == "resolved" for p in m.parts) ] assert len(hook_msgs) == 1 - assert hook_msgs[0].parts[0].resolution == {"approved": False, "reason": ""} + assert hook_msgs[0].parts[0].resolution == {"approved": False, "reason": ""} # type: ignore[union-attr] # -- Hook metadata surfaces in pending message ----------------------------- @@ -160,7 +160,7 @@ async def graph(llm: ai.LanguageModel) -> None: async def test_hook_metadata_in_pending() -> None: async def graph(llm: ai.LanguageModel) -> None: await ai.stream_step(llm, ai.make_messages(user="go")) - await Confirmation.create("meta_test", metadata={"tool": "rm -rf", "path": "/"}) + await Confirmation.create("meta_test", metadata={"tool": "rm -rf", "path": "/"}) # type: ignore[attr-defined] run_result = ai.run(graph, MockLLM([[text_msg("OK")]]), cancel_on_hooks=True) msgs = [m async for m in run_result] From b7a2046ce3d009ca006a33dc0aef746a662e674f Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Thu, 19 Feb 2026 18:24:34 -0800 Subject: [PATCH 19/20] Correctly narrow types and add missing annotations in tests --- tests/conftest.py | 2 +- tests/core/test_llm.py | 77 ++++++++++++++++++++++++------------- tests/core/test_messages.py | 11 ++++-- tests/core/test_runtime.py | 1 + tests/core/test_streams.py | 4 +- tests/core/test_tools.py | 6 ++- tests/mcp/test_client.py | 6 ++- 7 files changed, 71 insertions(+), 36 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 6445cd8c..14e2a025 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,7 +45,7 @@ def tool_msg( name: str = "test_tool", args: str = "{}", status: str = "pending", - result: dict | None = None, + result: dict[str, object] | None = None, ) -> messages.Message: return messages.Message( id=id, diff --git a/tests/core/test_llm.py b/tests/core/test_llm.py index 66055362..3bae1c81 100644 --- a/tests/core/test_llm.py +++ b/tests/core/test_llm.py @@ -25,22 +25,29 @@ def test_text_lifecycle() -> None: h = StreamHandler(message_id="m1") m = h.handle_event(TextStart(block_id="b1")) assert len(m.parts) == 1 - assert isinstance(m.parts[0], TextPart) - assert m.parts[0].state == "streaming" - assert m.parts[0].text == "" + part = m.parts[0] + assert isinstance(part, TextPart) + assert part.state == "streaming" + assert part.text == "" m = h.handle_event(TextDelta(block_id="b1", delta="Hello")) - assert m.parts[0].text == "Hello" - assert m.parts[0].delta == "Hello" - assert m.parts[0].state == "streaming" + part = m.parts[0] + assert isinstance(part, TextPart) + assert part.text == "Hello" + assert part.delta == "Hello" + assert part.state == "streaming" m = h.handle_event(TextDelta(block_id="b1", delta=" world")) - assert m.parts[0].text == "Hello world" - assert m.parts[0].delta == " world" + part = m.parts[0] + assert isinstance(part, TextPart) + assert part.text == "Hello world" + assert part.delta == " world" m = h.handle_event(TextEnd(block_id="b1")) - assert m.parts[0].state == "done" - assert m.parts[0].delta is None + part = m.parts[0] + assert isinstance(part, TextPart) + assert part.state == "done" + assert part.delta is None # -- Reasoning streaming --------------------------------------------------- @@ -50,13 +57,16 @@ def test_reasoning_lifecycle() -> None: h = StreamHandler(message_id="m1") h.handle_event(ReasoningStart(block_id="r1")) m = h.handle_event(ReasoningDelta(block_id="r1", delta="thinking")) - assert isinstance(m.parts[0], ReasoningPart) - assert m.parts[0].text == "thinking" - assert m.parts[0].state == "streaming" + part = m.parts[0] + assert isinstance(part, ReasoningPart) + assert part.text == "thinking" + assert part.state == "streaming" m = h.handle_event(ReasoningEnd(block_id="r1", signature="sig123")) - assert m.parts[0].state == "done" - assert m.parts[0].signature == "sig123" + part = m.parts[0] + assert isinstance(part, ReasoningPart) + assert part.state == "done" + assert part.signature == "sig123" # -- Tool streaming -------------------------------------------------------- @@ -66,18 +76,23 @@ def test_tool_lifecycle() -> None: h = StreamHandler(message_id="m1") h.handle_event(ToolStart(tool_call_id="tc1", tool_name="get_weather")) m = h.handle_event(ToolArgsDelta(tool_call_id="tc1", delta='{"ci')) - assert isinstance(m.parts[0], ToolPart) - assert m.parts[0].tool_name == "get_weather" - assert m.parts[0].tool_args == '{"ci' - assert m.parts[0].state == "streaming" - assert m.parts[0].args_delta == '{"ci' + part = m.parts[0] + assert isinstance(part, ToolPart) + assert part.tool_name == "get_weather" + assert part.tool_args == '{"ci' + assert part.state == "streaming" + assert part.args_delta == '{"ci' m = h.handle_event(ToolArgsDelta(tool_call_id="tc1", delta='ty":"London"}')) - assert m.parts[0].tool_args == '{"city":"London"}' + part = m.parts[0] + assert isinstance(part, ToolPart) + assert part.tool_args == '{"city":"London"}' m = h.handle_event(ToolEnd(tool_call_id="tc1")) - assert m.parts[0].state == "done" - assert m.parts[0].args_delta is None + part = m.parts[0] + assert isinstance(part, ToolPart) + assert part.state == "done" + assert part.args_delta is None # -- Multi-part messages --------------------------------------------------- @@ -102,7 +117,11 @@ def test_reasoning_then_text_then_tool() -> None: assert isinstance(m.parts[0], ReasoningPart) assert isinstance(m.parts[1], TextPart) assert isinstance(m.parts[2], ToolPart) - assert all(p.state == "done" for p in m.parts) + assert all( + p.state == "done" + for p in m.parts + if isinstance(p, (TextPart, ToolPart, ReasoningPart)) + ) def test_multiple_tool_calls() -> None: @@ -122,7 +141,11 @@ def test_multiple_tool_calls() -> None: h.handle_event(ToolArgsDelta(tool_call_id="tc2", delta='{"dir":"."}')) h.handle_event(ToolEnd(tool_call_id="tc1")) m = h.handle_event(ToolEnd(tool_call_id="tc2")) - assert all(p.state == "done" for p in m.parts) + assert all( + p.state == "done" + for p in m.parts + if isinstance(p, (TextPart, ToolPart, ReasoningPart)) + ) # -- MessageDone ----------------------------------------------------------- @@ -134,7 +157,9 @@ def test_message_done_finalizes_all() -> None: h.handle_event(TextDelta(block_id="t1", delta="hello")) # Don't send TextEnd -- MessageDone should finalize everything m = h.handle_event(MessageDone(finish_reason="end_turn")) - assert m.parts[0].state == "done" + part = m.parts[0] + assert isinstance(part, TextPart) + assert part.state == "done" assert m.is_done diff --git a/tests/core/test_messages.py b/tests/core/test_messages.py index 91d958cd..7d4a991a 100644 --- a/tests/core/test_messages.py +++ b/tests/core/test_messages.py @@ -143,8 +143,9 @@ def test_get_tool_part_found() -> None: role="assistant", parts=[ToolPart(tool_call_id="tc1", tool_name="t", tool_args="{}")], ) - assert m.get_tool_part("tc1") is not None - assert m.get_tool_part("tc1").tool_name == "t" + tp = m.get_tool_part("tc1") + assert tp is not None + assert tp.tool_name == "t" def test_get_tool_part_missing() -> None: @@ -185,7 +186,9 @@ def test_set_result() -> None: tp = ToolPart(tool_call_id="tc1", tool_name="t", tool_args="{}") assert tp.status == "pending" tp.set_result({"answer": 42}) - assert tp.status == "result" + # mypy narrows status to Literal["pending"] from the constructor default and + # can't track that set_result() mutates it to "result" + assert tp.status == "result" # type: ignore[comparison-overlap] assert tp.result == {"answer": 42} @@ -193,7 +196,7 @@ def test_set_error() -> None: tp = ToolPart(tool_call_id="tc1", tool_name="t", tool_args="{}") assert tp.status == "pending" tp.set_error("Something went wrong") - assert tp.status == "error" + assert tp.status == "error" # type: ignore[comparison-overlap] assert tp.result == "Something went wrong" diff --git a/tests/core/test_runtime.py b/tests/core/test_runtime.py index 64ec86c4..b5bc2905 100644 --- a/tests/core/test_runtime.py +++ b/tests/core/test_runtime.py @@ -216,6 +216,7 @@ async def graph(llm: ai.LanguageModel) -> None: for tc in result.tool_calls: await ai.execute_tool(tc, message=msg) # Verify the tool part was mutated + assert msg is not None assert msg.tool_calls[0].status == "result" assert msg.tool_calls[0].result == 10 diff --git a/tests/core/test_streams.py b/tests/core/test_streams.py index 1fa21a6d..31803244 100644 --- a/tests/core/test_streams.py +++ b/tests/core/test_streams.py @@ -26,7 +26,9 @@ def test_stream_result_last_message() -> None: m1 = text_msg("first", id="m1") m2 = text_msg("second", id="m2") r = StreamResult(messages=[m1, m2]) - assert r.last_message.id == "m2" + last = r.last_message + assert last is not None + assert last.id == "m2" assert r.text == "second" diff --git a/tests/core/test_tools.py b/tests/core/test_tools.py index 0aa02820..563fa5f0 100644 --- a/tests/core/test_tools.py +++ b/tests/core/test_tools.py @@ -107,5 +107,7 @@ async def add(a: int, b: int) -> int: # -- Helpers --------------------------------------------------------------- -def search_required(tool: ai.Tool) -> list[str]: - return tool.param_schema.get("required", []) +def search_required(tool: ai.Tool[..., object]) -> list[str]: + result = tool.param_schema.get("required", []) + assert isinstance(result, list) + return result diff --git a/tests/mcp/test_client.py b/tests/mcp/test_client.py index 3e530e1f..3cf88123 100644 --- a/tests/mcp/test_client.py +++ b/tests/mcp/test_client.py @@ -1,6 +1,8 @@ """MCP client: tool registration in global registry, end-to-end execution.""" import asyncio +import contextlib +from typing import Any import mcp.types import pytest @@ -27,7 +29,7 @@ def _fake_mcp_tool( ) -def _noop_transport_factory() -> None: +def _noop_transport_factory() -> contextlib.AbstractAsyncContextManager[Any]: """Dummy transport factory — never actually called in these tests.""" raise NotImplementedError("should not be called") @@ -69,7 +71,7 @@ def test_mcp_tool_to_native_schema_preserved() -> None: @pytest.mark.asyncio async def test_mcp_tool_executes_through_stream_loop() -> None: """An MCP-style tool registered via _mcp_tool_to_native can be called by the agent loop.""" - call_log: list[dict] = [] + call_log: list[dict[str, str]] = [] async def fake_fn(**kwargs: str) -> str: call_log.append(kwargs) From 58eda59c368c83f64fc4c06defad28f5c4075024 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Thu, 19 Feb 2026 18:29:42 -0800 Subject: [PATCH 20/20] Add type to the DataPart in the AI SDK UI protocol --- src/vercel_ai_sdk/ai_sdk_ui/adapter.py | 5 +++++ src/vercel_ai_sdk/ai_sdk_ui/protocol.py | 13 +++++++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/vercel_ai_sdk/ai_sdk_ui/adapter.py b/src/vercel_ai_sdk/ai_sdk_ui/adapter.py index cd2f8800..ed080f1a 100644 --- a/src/vercel_ai_sdk/ai_sdk_ui/adapter.py +++ b/src/vercel_ai_sdk/ai_sdk_ui/adapter.py @@ -32,6 +32,11 @@ def _generate_id(prefix: str = "id") -> str: def serialize_part(part: protocol.UIMessageStreamPart) -> str: """Serialize a stream part to JSON with camelCase keys.""" d = dataclasses.asdict(part) + if isinstance(part, protocol.DataPart): + # DataPart's wire type is computed (``data-{data_type}``); replace + # the raw ``data_type`` field with the protocol ``type`` key. + d["type"] = part.type + del d["data_type"] camel_dict = {_to_camel_case(k): v for k, v in d.items() if v is not None} return json.dumps(camel_dict) diff --git a/src/vercel_ai_sdk/ai_sdk_ui/protocol.py b/src/vercel_ai_sdk/ai_sdk_ui/protocol.py index 12d717b4..e7ff7af7 100644 --- a/src/vercel_ai_sdk/ai_sdk_ui/protocol.py +++ b/src/vercel_ai_sdk/ai_sdk_ui/protocol.py @@ -129,9 +129,13 @@ class FilePart: @dataclasses.dataclass class DataPart: - """Custom data parts allow streaming of arbitrary structured data with type-specific handling. + """ + Custom data parts allow streaming of arbitrary structured data with type-specific + handling. - The type will be formatted as 'data-{data_type}' in the output. + The wire type is ``data-{data_type}`` (e.g. ``data-custom``), exposed + via the ``type`` property so that ``DataPart`` is uniform with every + other ``UIMessageStreamPart`` variant. """ data_type: str @@ -139,6 +143,11 @@ class DataPart: id: str | None = None transient: bool | None = None + @property + def type(self) -> str: + """Wire type for the AI SDK SSE protocol.""" + return f"data-{self.data_type}" + @dataclasses.dataclass class ToolInputStartPart: