Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion libreyolo/models/base/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,8 @@ def _set_device(self, device: str) -> None:
target = torch.device(device_str)
if target != self.model.device:
self.model.device = target
self.model.model.to(target)
if hasattr(self.model.model, "to"):
self.model.model.to(target)

def _process_in_batches(
self,
Expand Down
3 changes: 2 additions & 1 deletion libreyolo/models/l2cs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,4 +343,5 @@ def _set_device(self, device: str) -> None:
target = torch.device(device_str)
if target != self.model.device:
self.model.device = target
self.model.model.to(target)
if hasattr(self.model.model, "to"):
self.model.model.to(target)
56 changes: 56 additions & 0 deletions tests/unit/test_backend_validation_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,59 @@ def __init__(self):

assert backend.size == "n"
assert backend.FAMILY == "rfdetr"


def test_backend_eval_proxy_has_no_to():
from libreyolo.backends.base import _BackendEvalProxy

proxy = _BackendEvalProxy()
assert not hasattr(proxy, "to")
assert hasattr(proxy, "eval")


def test_set_device_does_not_call_to_on_backend_proxy():
"""_set_device must not raise when model.model is a _BackendEvalProxy."""
from types import SimpleNamespace
from unittest.mock import patch

from libreyolo.backends.base import _BackendEvalProxy
from libreyolo.models.base.inference import InferenceRunner

proxy = _BackendEvalProxy()
fake_model = SimpleNamespace(
device=torch.device("cpu"),
model=proxy,
)

runner = object.__new__(InferenceRunner)
runner.model = fake_model

# Calling _set_device with a different device must not raise AttributeError.
with patch("torch.cuda.is_available", return_value=True):
runner._set_device("cuda:0")

# model.device is updated; proxy is left untouched (no .to() call).
assert fake_model.device == torch.device("cuda:0")


def test_l2cs_set_device_does_not_call_to_on_backend_proxy():
"""GazeInferenceRunner._set_device must also tolerate _BackendEvalProxy."""
from types import SimpleNamespace
from unittest.mock import patch

from libreyolo.backends.base import _BackendEvalProxy
from libreyolo.models.l2cs.inference import GazeInferenceRunner

proxy = _BackendEvalProxy()
fake_model = SimpleNamespace(
device=torch.device("cpu"),
model=proxy,
)

runner = object.__new__(GazeInferenceRunner)
runner.model = fake_model

with patch("torch.cuda.is_available", return_value=True):
runner._set_device("cuda:0")

assert fake_model.device == torch.device("cuda:0")
Loading