Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
290 changes: 231 additions & 59 deletions geoagent/core/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from __future__ import annotations

from dataclasses import dataclass, field
from dataclasses import asdict, dataclass, field
from functools import lru_cache
from typing import Any, Iterable, Sequence
import importlib


@dataclass
Expand All @@ -20,53 +22,196 @@ class GeoToolMeta:
requires_packages: tuple[str, ...] = ()
extra: dict[str, Any] = field(default_factory=dict)

@property
def needs_confirmation(self) -> bool:
"""
Determine whether user confirmation is required

Returns:
True if confirmation should be required before execution
"""
return self.requires_confirmation or self.destructive or self.long_running

def to_dict(self) -> dict[Any, Any]:
"""
Convert metadata into serial dictionary

Returns:
Dictionary representation of the metadata
"""
data = asdict(self)

data["available_in"] = list(self.available_in)
data["requires_confirmation"] = list(self.requires_packages)

Comment on lines +42 to +46

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Fix the serializer key for requires_packages.

Line 45 currently overwrites requires_confirmation with the package list, so get_all_tools_config() emits the wrong schema and loses the actual confirmation flag for every tool.

🐛 Proposed fix
         data = asdict(self)

         data["available_in"] = list(self.available_in)
-        data["requires_confirmation"] = list(self.requires_packages)
+        data["requires_packages"] = list(self.requires_packages)

         return data
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@geoagent/core/registry.py` around lines 42 - 46, In the asdict serialization
block, line 45 incorrectly assigns the list of requires_packages to the
requires_confirmation key, causing the confirmation flag to be overwritten.
Change the key on line 45 from requires_confirmation to requires_packages so
that data["requires_packages"] = list(self.requires_packages) correctly
serializes the packages field without overwriting the confirmation flag.

return data


class GeoToolRegistry:
"""Maps tool names to :class:`GeoToolMeta`."""

def __init__(self) -> None:
self._by_name: dict[str, GeoToolMeta] = {}
self._tools: dict[str, GeoToolMeta] = {}

def register_tool(self, tool_obj: Any, meta: GeoToolMeta) -> None:
"""Register metadata for a Strands decorated tool."""
name = getattr(tool_obj, "tool_name", None) or meta.name
self._by_name[str(name)] = meta
def __contains__(self, tool_name: str) -> bool:
return tool_name in self._tools

def __len__(self) -> int:
return len(self._tools)

def register(self, meta: GeoToolMeta) -> None:
"""Register metadata by explicit name."""
self._by_name[meta.name] = meta
"""
Register metadata by explicit name.

Args:
meta:
Metadata describing the tool.

Raises:
ValueError:
If a tool with the same name has already been registered.
"""
if meta.name in self._tools:
raise ValueError(f"'{meta.name}' already registered.")
self._tools[meta.name] = meta

def register_tool(self, tool_obj: Any, meta: GeoToolMeta) -> None:
"""
Register metadata for a Strands decorated tool.

Args:
tool_obj:
Tool implementation instance

meta:
Metadata associated with the tool
"""
tool_name = getattr(tool_obj, "tool_name", None)

if tool_name is None:
tool_name = meta.name

self.register(
GeoToolMeta(
name=str(tool_name),
description=meta.description,
category=meta.category,
requires_confirmation=meta.requires_confirmation,
destructive=meta.destructive,
long_running=meta.long_running,
available_in=meta.available_in,
requires_packages=meta.requires_packages,
extra=dict(meta.extra),
)
)

def unregister(self, tool_name: str) -> bool:
"""
Remove a tool from the registry

Args:
tool_name:
Name of the tool to remove

Returns:
True if the tool existed and was removed.
otherwise False.
"""
return self._tools.pop(tool_name, None) is not None

def get(self, tool_name: str) -> GeoToolMeta | None:
"""Return metadata for a registered tool name."""
return self._by_name.get(tool_name)
"""
Return metadata for a registered tool name.

Args:
tool_name:
Tool identifier.

Returns:
Tool metadata if found, otherwise None.
"""
return self._tools.get(tool_name)

def require(self, tool_name: str) -> GeoToolMeta:
"""
Return tool metadata

Args:
tool_name:
Tool identifier.

Returns:
KeyError:
If the tool is not registered.
"""
try:
return self._tools[tool_name]
except Exception as exc:
raise KeyError(f"Unknown tool '{tool_name}'") from exc

def exists(self, tool_name: str) -> bool:
return tool_name in self._tools

def list_names(self) -> list[str]:
"""Return registered tool names."""
return sorted(self._by_name.keys())
return sorted(self._tools)

def list_tools(self) -> list[GeoToolMeta]:
return list(self._tools.values())

def get_all_tools_config(self) -> list[dict[str, Any]]:
"""Return tool metadata as serializable config records.

Mirrors the type of inspection users expect from Strands-facing
registries while preserving GeoAgent-specific metadata.
"""
out: list[dict[str, Any]] = []
for name in self.list_names():
meta = self._by_name[name]
out.append(
{
"name": meta.name,
"description": meta.description,
"category": meta.category,
"requires_confirmation": meta.requires_confirmation,
"destructive": meta.destructive,
"long_running": meta.long_running,
"available_in": list(meta.available_in),
"requires_packages": list(meta.requires_packages),
"extra": dict(meta.extra),
}
)
return out
return [
meta.to_dict()
for meta in sorted(self._tools.values(), key=lambda m: m.name)
]

def get_tools_for_context(
self, tool_objects: Sequence[Any], *, context: str
) -> list[Any]:
"""
Filter tools that are available in a specific execution context.

Args:
tool_objects:
Collection of tool implementation.

context:
Execution context name such as ``full`` or ``fast``

Returns:
Filtered list of tool objects that may bbe used in the requested
"""
if context != "fast":
return list(tool_objects)
Comment on lines +189 to +190

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Reject unknown contexts instead of failing open.

Lines 189-190 treat anything except "fast" as "full". A typo like "Fast" or any future restricted context will silently expose every tool instead of preserving the filtered set.

🛡️ Proposed fix
-        if context != "fast":
+        if context == "full":
             return list(tool_objects)
+        if context != "fast":
+            raise ValueError(f"Unknown tool context '{context}'")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@geoagent/core/registry.py` around lines 189 - 190, The condition at lines
189-190 in geoagent/core/registry.py treats any context value that isn't exactly
"fast" as valid and returns all tool_objects, which is a "fail open" pattern.
This means typos like "Fast" or unknown context values will silently expose
every tool instead of being rejected. Replace the current `if context != "fast"`
logic with explicit validation that only accepts known valid context values
(like "fast" and "full") and raises an error for any unrecognized context,
ensuring that unknown or mistyped contexts fail explicitly rather than
defaulting to exposing all tools.


result: list[Any] = []

for tool in tool_objects:
name = getattr(tool, "tool_name", None)

if not name:
continue

meta = self.get(name)

if meta and context in meta.available_in:
result.append(tool)
continue

if name in FAST_TOOL_FALLBACK:
result.append(tool)
return result

def tool_requiring_confirmation(self) -> list[GeoToolMeta]:
"""
Return tool requiring user confirmation
"""
return [meta for meta in self._tools.values() if meta.needs_confirmation]

def needs_user_confirmation(self, meta: GeoToolMeta) -> bool:
"""Return True if the tool should go through the confirm callback."""
Expand Down Expand Up @@ -104,36 +249,63 @@ def needs_user_confirmation(self, meta: GeoToolMeta) -> bool:
)


@lru_cache(maxsize=128)
def package_available(package: str) -> bool:
"""
Return True if every named package is importable.

Args:
package:
Package name.

Returns:
True if the package is importable, otherwise False.
"""

try:
importlib.import_module(package)
return True
except ImportError:
return False


def packages_available(packages: Iterable[str]) -> bool:
"""
Verify that all required packages are available.

Args:
packages:
Collection of package names.

Returns:
True only if every package can be imported.
"""
return all(package_available(package) for package in packages)


def collect_tools_for_context(
tool_objects: Sequence[Any],
*,
fast: bool,
registry: GeoToolRegistry,
tool_objects: Sequence[Any], *, fast: bool, registry: GeoToolRegistry
) -> list[Any]:
"""Filter tools for fast mode using metadata or :data:`FAST_TOOL_FALLBACK`."""

if not fast:
return list(tool_objects)

out: list[Any] = []
for t in tool_objects:
name = getattr(t, "tool_name", None)
if name is None:
continue
meta = registry.get(str(name))
if meta is not None and "fast" in meta.available_in:
out.append(t)
continue
if str(name) in FAST_TOOL_FALLBACK:
out.append(t)
return out


def packages_available(requires: Iterable[str]) -> bool:
"""Return True if every named package is importable."""
for pkg in requires:
try:
__import__(pkg)
except ImportError:
return False
return True
"""
Convenience wrapper for context-based tool filtering.

This function preserves backward compatibility with older code that
uses a boolean ``fast`` flag instead of explicitly specifying a
context name.

Args:
tool_objects:
Collection of tool implementations.

fast:
Whether fast execution mode is enabled.

registry:
Registry used to evaluate tool availability.

Returns:
Filtered list of tool objects.
"""
context = "fast" if fast else "full"

return registry.get_tools_for_context(tool_objects, context=context)
Loading