Skip to content

Commit 8a065a1

Browse files
Merge pull request #118 from askui/feat-add-locate-all
Feat: Add locate all function
2 parents e2c9167 + 1d1c2ee commit 8a065a1

File tree

16 files changed

+129
-55
lines changed

16 files changed

+129
-55
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -367,12 +367,12 @@ class MyGetAndLocateModel(GetModel, LocateModel):
367367
locator: str | Locator,
368368
image: ImageSource,
369369
model_choice: ModelComposition | str,
370-
) -> Point:
370+
) -> PointList:
371371
# Implement custom locate logic, e.g.:
372372
# - Use a different object detection model
373373
# - Implement custom element finding
374374
# - Call external vision services
375-
return (100, 100) # Example coordinates
375+
return [(100, 100)] # Example coordinates
376376

377377

378378
# Create model registry

src/askui/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
OnMessageCb,
2626
OnMessageCbParam,
2727
Point,
28+
PointList,
2829
TextBlockParam,
2930
TextCitationParam,
3031
ToolResultBlockParam,
@@ -82,6 +83,7 @@
8283
"OnMessageCbParam",
8384
"PcKey",
8485
"Point",
86+
"PointList",
8587
"ResponseSchema",
8688
"ResponseSchemaBase",
8789
"Retry",

src/askui/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def _click(
182182
def _mouse_move(
183183
self, locator: str | Locator, model: ModelComposition | str | None = None
184184
) -> None:
185-
point = self._locate(locator=locator, model=model)
185+
point = self._locate(locator=locator, model=model)[0]
186186
self.tools.os.mouse_move(point[0], point[1])
187187

188188
@telemetry.record_call(exclude={"locator"})

src/askui/agent_base.py

Lines changed: 56 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
ModelName,
3030
ModelRegistry,
3131
Point,
32+
PointList,
3233
TotalModelChoice,
3334
)
3435
from .models.types.response_schemas import ResponseSchema
@@ -352,13 +353,14 @@ class LinkedListNode(ResponseSchemaBase):
352353
self._reporter.add_message("Agent", message_content)
353354
return response
354355

356+
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
355357
def _locate(
356358
self,
357359
locator: str | Locator,
358360
screenshot: Optional[Img] = None,
359361
model: ModelComposition | str | None = None,
360-
) -> Point:
361-
def locate_with_screenshot() -> Point:
362+
) -> PointList:
363+
def locate_with_screenshot() -> PointList:
362364
_screenshot = load_image_source(
363365
self._agent_os.screenshot() if screenshot is None else screenshot
364366
)
@@ -368,10 +370,10 @@ def locate_with_screenshot() -> Point:
368370
model_choice=model or self._model_choice["locate"],
369371
)
370372

371-
point = self._retry.attempt(locate_with_screenshot)
372-
self._reporter.add_message("ModelRouter", f"locate: ({point[0]}, {point[1]})")
373-
logger.debug("ModelRouter locate: (%d, %d)", point[0], point[1])
374-
return point
373+
points = self._retry.attempt(locate_with_screenshot)
374+
self._reporter.add_message("ModelRouter", f"locate {len(points)} elements")
375+
logger.debug("ModelRouter locate: %d elements", len(points))
376+
return points
375377

376378
@telemetry.record_call(exclude={"locator", "screenshot"})
377379
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
@@ -382,7 +384,7 @@ def locate(
382384
model: ModelComposition | str | None = None,
383385
) -> Point:
384386
"""
385-
Locates the UI element identified by the provided locator.
387+
Locates the first matching UI element identified by the provided locator.
386388
387389
Args:
388390
locator (str | Locator): The identifier or description of the element to
@@ -405,8 +407,53 @@ def locate(
405407
print(f"Element found at coordinates: {point}")
406408
```
407409
"""
408-
self._reporter.add_message("User", f"locate {locator}")
409-
logger.debug("VisionAgent received instruction to locate %s", locator)
410+
self._reporter.add_message("User", f"locate first matching element {locator}")
411+
logger.debug(
412+
"VisionAgent received instruction to locate first matching element %s",
413+
locator,
414+
)
415+
return self._locate(locator, screenshot, model)[0]
416+
417+
@telemetry.record_call(exclude={"locator", "screenshot"})
418+
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
419+
def locate_all(
420+
self,
421+
locator: str | Locator,
422+
screenshot: Optional[Img] = None,
423+
model: ModelComposition | str | None = None,
424+
) -> PointList:
425+
"""
426+
Locates all matching UI elements identified by the provided locator.
427+
428+
Note: Some LocateModels can only locate a single element. In this case, the
429+
returned list will have a length of 1.
430+
431+
Args:
432+
locator (str | Locator): The identifier or description of the element to
433+
locate.
434+
screenshot (Img | None, optional): The screenshot to use for locating the
435+
element. Can be a path to an image file, a PIL Image object or a data
436+
URL. If `None`, takes a screenshot of the currently selected display.
437+
model (ModelComposition | str | None, optional): The composition or name
438+
of the model(s) to be used for locating the element using the `locator`.
439+
440+
Returns:
441+
PointList: The coordinates of the elements as a list of tuples (x, y).
442+
443+
Example:
444+
```python
445+
from askui import VisionAgent
446+
447+
with VisionAgent() as agent:
448+
points = agent.locate_all("Submit button")
449+
print(f"Found {len(points)} elements at coordinates: {points}")
450+
```
451+
"""
452+
self._reporter.add_message("User", f"locate all matching UI elements {locator}")
453+
logger.debug(
454+
"VisionAgent received instruction to locate all matching UI elements %s",
455+
locator,
456+
)
410457
return self._locate(locator, screenshot, model)
411458

412459
@telemetry.record_call()

src/askui/android_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def tap(
198198
msg += f" on {target}"
199199
self._reporter.add_message("User", msg)
200200
logger.debug("VisionAgent received instruction to click on %s", target)
201-
point = self._locate(locator=target, model=model)
201+
point = self._locate(locator=target, model=model)[0]
202202
self.os.tap(point[0], point[1])
203203

204204
@telemetry.record_call(exclude={"text"})

src/askui/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
ModelName,
1010
ModelRegistry,
1111
Point,
12+
PointList,
1213
)
1314
from .openrouter.model import OpenRouterModel
1415
from .openrouter.settings import ChatCompletionsCreateSettings, OpenRouterSettings
@@ -53,6 +54,7 @@
5354
"OpenRouterModel",
5455
"OpenRouterSettings",
5556
"Point",
57+
"PointList",
5658
"TextBlockParam",
5759
"TextCitationParam",
5860
"ToolResultBlockParam",

src/askui/models/anthropic/messages_api.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
LocateModel,
2828
ModelComposition,
2929
ModelName,
30-
Point,
30+
PointList,
3131
)
3232
from askui.models.shared.agent_message_param import (
3333
Base64ImageSourceParam,
@@ -198,7 +198,7 @@ def locate(
198198
locator: str | Locator,
199199
image: ImageSource,
200200
model_choice: ModelComposition | str,
201-
) -> Point:
201+
) -> PointList:
202202
if not isinstance(model_choice, str):
203203
error_msg = "Model composition is not supported for Claude"
204204
raise NotImplementedError(error_msg)
@@ -219,12 +219,14 @@ def locate(
219219
),
220220
model_choice=model_choice,
221221
)
222-
return scale_coordinates(
223-
extract_click_coordinates(content),
224-
image.root.size,
225-
self._settings.resolution,
226-
inverse=True,
227-
)
222+
return [
223+
scale_coordinates(
224+
extract_click_coordinates(content),
225+
image.root.size,
226+
self._settings.resolution,
227+
inverse=True,
228+
)
229+
]
228230
except (
229231
_UnexpectedResponseError,
230232
ValueError,

src/askui/models/askui/inference_api.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from askui.locators.serializers import AskUiLocatorSerializer, AskUiSerializedLocator
2121
from askui.logger import logger
2222
from askui.models.exceptions import ElementNotFoundError
23-
from askui.models.models import GetModel, LocateModel, ModelComposition, Point
23+
from askui.models.models import GetModel, LocateModel, ModelComposition, PointList
2424
from askui.models.shared.agent_message_param import MessageParam
2525
from askui.models.shared.messages_api import MessagesApi
2626
from askui.models.shared.settings import MessageSettings
@@ -162,7 +162,7 @@ def locate(
162162
locator: str | Locator,
163163
image: ImageSource,
164164
model_choice: ModelComposition | str,
165-
) -> Point:
165+
) -> PointList:
166166
serialized_locator = (
167167
self._locator_serializer.serialize(locator=locator)
168168
if isinstance(locator, Locator)
@@ -171,7 +171,7 @@ def locate(
171171
logger.debug(f"serialized_locator:\n{json_lib.dumps(serialized_locator)}")
172172
json: dict[str, Any] = {
173173
"image": image.to_data_url(),
174-
"instruction": f"Click on {serialized_locator['instruction']}",
174+
"instruction": f"get element {serialized_locator['instruction']}",
175175
}
176176
if "customElements" in serialized_locator:
177177
json["customElements"] = serialized_locator["customElements"]
@@ -182,17 +182,20 @@ def locate(
182182
)
183183
response = self._post(path="/inference", json=json)
184184
content = response.json()
185-
assert content["type"] == "COMMANDS", (
185+
assert content["type"] == "DETECTED_ELEMENTS", (
186186
f"Received unknown content type {content['type']}"
187187
)
188-
actions = [
189-
el for el in content["data"]["actions"] if el["inputEvent"] == "MOUSE_MOVE"
190-
]
191-
if len(actions) == 0:
188+
detected_elements = content["data"]["detected_elements"]
189+
if len(detected_elements) == 0:
192190
raise ElementNotFoundError(locator, serialized_locator)
193191

194-
position = actions[0]["position"]
195-
return int(position["x"]), int(position["y"])
192+
return [
193+
(
194+
int((element["bndbox"]["xmax"] + element["bndbox"]["xmin"]) / 2),
195+
int((element["bndbox"]["ymax"] + element["bndbox"]["ymin"]) / 2),
196+
)
197+
for element in detected_elements
198+
]
196199

197200
@override
198201
def get(

src/askui/models/askui/model_router.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
ElementNotFoundError,
99
ModelNotFoundError,
1010
)
11-
from askui.models.models import LocateModel, ModelComposition, ModelName, Point
11+
from askui.models.models import LocateModel, ModelComposition, ModelName, PointList
1212
from askui.utils.image_utils import ImageSource
1313

1414

@@ -18,7 +18,7 @@ def __init__(self, inference_api: AskUiInferenceApi):
1818

1919
def _locate_with_askui_ocr(
2020
self, screenshot: ImageSource, locator: str | Text
21-
) -> Point:
21+
) -> PointList:
2222
locator = Text(locator) if isinstance(locator, str) else locator
2323
return self._inference_api.locate(
2424
locator, screenshot, model_choice=ModelName.ASKUI__OCR
@@ -30,7 +30,7 @@ def locate(
3030
locator: str | Locator,
3131
image: ImageSource,
3232
model_choice: ModelComposition | str,
33-
) -> Point:
33+
) -> PointList:
3434
if (
3535
isinstance(model_choice, ModelComposition)
3636
or model_choice == ModelName.ASKUI

src/askui/models/huggingface/spaces_api.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from askui.exceptions import AutomationError
1111
from askui.locators.locators import Locator
1212
from askui.locators.serializers import VlmLocatorSerializer
13-
from askui.models.models import LocateModel, ModelComposition, ModelName, Point
13+
from askui.models.models import LocateModel, ModelComposition, ModelName, PointList
1414
from askui.utils.image_utils import ImageSource
1515

1616

@@ -65,7 +65,7 @@ def locate(
6565
locator: str | Locator,
6666
image: ImageSource,
6767
model_choice: ModelComposition | str,
68-
) -> Point:
68+
) -> PointList:
6969
"""Predict element location using Hugging Face Spaces."""
7070
if not isinstance(model_choice, str):
7171
error_msg = "Model composition is not supported for Hugging Face Spaces"
@@ -76,9 +76,9 @@ def locate(
7676
if isinstance(locator, Locator)
7777
else locator
7878
)
79-
return self._spaces[model_choice](
80-
image.root, serialized_locator, model_choice
81-
)
79+
return [
80+
self._spaces[model_choice](image.root, serialized_locator, model_choice)
81+
]
8282
except (ValueError, json.JSONDecodeError, httpx.HTTPError) as e:
8383
error_msg = f"Hugging Face Spaces Exception: {e}"
8484
raise AutomationError(error_msg) from e

0 commit comments

Comments
 (0)