Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 160 additions & 63 deletions rld/attributation.py

Large diffs are not rendered by default.

22 changes: 11 additions & 11 deletions rld/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
ObsLikeStrict,
ObsTensorLike,
ObsTensorStrict,
HiddenState,
HiddenStateTensor,
)


Expand All @@ -40,19 +40,19 @@ def _forward_unimplemented(self, *input: Any) -> None:


class RecurrentModel(Model, ABC):
def initial_state(self) -> HiddenState:
def initial_state(self) -> HiddenStateTensor:
"""
[B x 2 x 1 x CELL_SIZE]
"""
raise NotImplementedError

def last_output_state(self) -> HiddenState:
def last_output_state(self) -> HiddenStateTensor:
raise NotImplementedError

def reshape_to_torch(self, state: HiddenState) -> HiddenState:
def reshape_to_torch(self, state: HiddenStateTensor) -> HiddenStateTensor:
return state.permute((1, 2, 0, 3))

def reshape_to_store(self, state: HiddenState) -> HiddenState:
def reshape_to_store(self, state: HiddenStateTensor) -> HiddenStateTensor:
return state.permute((2, 0, 1, 3))


Expand Down Expand Up @@ -102,9 +102,9 @@ def __init__(self, model: RayModel, lstm_cell_size: int):
super().__init__(model)
self.lstm_cell_size = lstm_cell_size

self._last_state: Optional[HiddenState] = None
self._last_state: Optional[HiddenStateTensor] = None

def forward(self, obs_flat: ObsTensorStrict, state: HiddenState):
def forward(self, obs_flat: ObsTensorStrict, state: HiddenStateTensor):
if isinstance(self.obs_space(), Box):
# We need to unpack e.g. image-like observation,
# as RLlib doesn't flatten them into 1D vectors
Expand All @@ -127,26 +127,26 @@ def forward(self, obs_flat: ObsTensorStrict, state: HiddenState):

return logits

def initial_state(self) -> HiddenState:
def initial_state(self) -> HiddenStateTensor:
initial_state = [
torch.tensor(s, device=self.input_device())
for s in self.model.get_initial_state()
]
return torch.stack(initial_state).unsqueeze(dim=0).unsqueeze(dim=2)

def last_output_state(self) -> HiddenState:
def last_output_state(self) -> HiddenStateTensor:
if self._last_state is None:
raise RuntimeError(
"Trying to get last output hidden state without calling "
"forward() first."
)
return self._last_state

def reshape_to_torch(self, state: HiddenState) -> HiddenState:
def reshape_to_torch(self, state: HiddenStateTensor) -> HiddenStateTensor:
permuted = super().reshape_to_torch(state)
return permuted.squeeze(dim=1)

def reshape_to_store(self, state: HiddenState) -> HiddenState:
def reshape_to_store(self, state: HiddenStateTensor) -> HiddenStateTensor:
return super().reshape_to_store(state.unsqueeze(dim=1))


Expand Down
36 changes: 11 additions & 25 deletions rld/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
ActionLike,
InfoLike,
AttributationLike,
HiddenState,
AttributationLikeStrict,
)


Expand All @@ -48,6 +50,10 @@ class ActionAttributation(ABC):
prob: float
raw: Union[AttributationLike, Sequence[AttributationLike]]
normalized: Union[AttributationLike, Sequence[AttributationLike]]
hs_raw: Optional[Union[AttributationLikeStrict, Sequence[AttributationLikeStrict]]]
hs_normalized: Optional[
Union[AttributationLikeStrict, Sequence[AttributationLikeStrict]]
]

def is_complied(self, obs_space: gym.Space) -> bool:
raise NotImplementedError
Expand All @@ -57,6 +63,8 @@ def is_complied(self, obs_space: gym.Space) -> bool:
class DiscreteActionAttributation(ActionAttributation):
raw: AttributationLike
normalized: AttributationLike
hs_raw: Optional[AttributationLikeStrict] = None
hs_normalized: Optional[AttributationLikeStrict] = None

def is_complied(self, obs_space: gym.Space) -> bool:
obs_space = remove_value_constraints_from_space(obs_space)
Expand All @@ -67,6 +75,8 @@ def is_complied(self, obs_space: gym.Space) -> bool:
class MultiDiscreteActionAttributation(ActionAttributation):
raw: Sequence[AttributationLike]
normalized: Sequence[AttributationLike]
hs_raw: Optional[Sequence[AttributationLikeStrict]] = None
hs_normalized: Optional[Sequence[AttributationLikeStrict]] = None

def is_complied(self, obs_space: gym.Space) -> bool:
obs_space = remove_value_constraints_from_space(obs_space)
Expand All @@ -84,31 +94,6 @@ def sub_action(self, index: int) -> AttributationLike:
return self.raw[index]


# @dataclass
# class TupleActionAttributation(ActionAttributation):
# raw: Sequence[AttributationLike]
# normalized: Sequence[AttributationLike]
#
# def is_complied(self, obs_space: gym.Space) -> bool:
# obs_space = remove_value_constraints_from_space(obs_space)
# return all(
# [obs_space.contains(self.space(i)) for i in range(self.num_spaces())]
# )
#
# def map(
# self, fn: Callable[[AttributationLike], AttributationLike]
# ) -> ActionAttributation:
# return TupleActionAttributation(
# [fn(self.space(i)) for i in range(self.num_spaces())]
# )
#
# def num_spaces(self) -> int:
# return len(self.data)
#
# def space(self, index: int) -> AttributationLike:
# return self.data[index]


@dataclass
class Timestep:
obs: ObsLike
Expand All @@ -117,6 +102,7 @@ class Timestep:
done: DoneLike
info: InfoLike
attributations: Optional[Attributation] = None
hs: Optional[HiddenState] = None


@dataclass
Expand Down
6 changes: 3 additions & 3 deletions rld/tests/resources/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,13 @@ class DictObsTupleActionEnv(BaseEnv):
ALL_ENVS = [
BoxObsDiscreteActionEnv,
BoxObsMultiDiscreteActionEnv,
# BoxObsTupleActionEnv,
BoxObsTupleActionEnv,
ImageObsDiscreteActionEnv,
ImageObsMultiDiscreteActionEnv,
# ImageObsTupleActionEnv,
ImageObsTupleActionEnv,
DictObsDiscreteActionEnv,
DictObsMultiDiscreteActionEnv,
# DictObsTupleActionEnv,
DictObsTupleActionEnv,
]


Expand Down
89 changes: 80 additions & 9 deletions rld/tests/resources/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC
from typing import Optional
from typing import Optional, Union, List

import torch
import torch.nn as nn
Expand All @@ -13,8 +13,9 @@
MULTI_DISCRETE_ACTION_SPACE,
IMAGE_OBS_SPACE,
DICT_OBS_SPACE,
TUPLE_ACTION_SPACE,
)
from rld.typing import HiddenState, ObsTensorStrict, ObsTensorLike
from rld.typing import HiddenStateTensor, ObsTensorStrict, ObsTensorLike


class ObsMixin:
Expand All @@ -29,7 +30,12 @@ def preprocess_obs(self, obs: torch.Tensor) -> torch.Tensor:


class ActionMixin:
def init_head(self) -> nn.Module:
def init_head(self) -> Union[nn.Module, List[nn.Module]]:
raise NotImplementedError

def call_head(
self, hidden: torch.Tensor
) -> Union[torch.Tensor, List[torch.Tensor]]:
raise NotImplementedError


Expand All @@ -46,7 +52,11 @@ def forward(self, obs_flat: ObsTensorStrict) -> torch.Tensor:
obs = unpack_tensor(obs_flat, self.obs_space())
obs = self.preprocess_obs(obs)
x = F.relu(self.call_hidden(obs))
x = self.head(x)
x = self.call_head(x)

if isinstance(x, list):
x = torch.cat(x, dim=-1)

return x

def input_device(self) -> torch.device:
Expand All @@ -65,9 +75,11 @@ def __init__(self):
)
self.head = self.init_head()

self._last_state: Optional[HiddenState] = None
self._last_state: Optional[HiddenStateTensor] = None

def forward(self, obs_flat: ObsTensorStrict, state: HiddenState) -> torch.Tensor:
def forward(
self, obs_flat: ObsTensorStrict, state: HiddenStateTensor
) -> torch.Tensor:
obs = unpack_tensor(obs_flat, self.obs_space())
obs = self.preprocess_obs(obs)

Expand All @@ -79,22 +91,25 @@ def forward(self, obs_flat: ObsTensorStrict, state: HiddenState) -> torch.Tensor
state = self.reshape_to_store(torch.stack(state))

x = x.squeeze(dim=1)
x = self.head(x)
x = self.call_head(x)

if isinstance(x, list):
x = torch.cat(x, dim=-1)

self._last_state = state
return x

def input_device(self) -> torch.device:
return torch.device("cpu")

def initial_state(self) -> HiddenState:
def initial_state(self) -> HiddenStateTensor:
return torch.zeros(
(1, 2, 1, self.NUM_HIDDEN_NEURONS),
dtype=torch.float32,
device=self.input_device(),
)

def last_output_state(self) -> HiddenState:
def last_output_state(self) -> HiddenStateTensor:
if self._last_state is None:
raise RuntimeError(
"Trying to get last output hidden state without calling "
Expand Down Expand Up @@ -162,6 +177,9 @@ def action_space(self) -> Space:
def init_head(self: BaseModel) -> nn.Module:
return nn.Linear(self.NUM_HIDDEN_NEURONS, self.action_space().n)

def call_head(self: BaseModel, hidden: torch.Tensor) -> torch.Tensor:
return self.head(hidden)


class MultiDiscreteActionMixin(ActionMixin):
def action_space(self) -> Space:
Expand All @@ -170,6 +188,23 @@ def action_space(self) -> Space:
def init_head(self: BaseModel) -> nn.Module:
return nn.Linear(self.NUM_HIDDEN_NEURONS, sum(self.action_space().nvec))

def call_head(self: BaseModel, hidden: torch.Tensor) -> torch.Tensor:
return self.head(hidden)


class TupleActionMixin(ActionMixin):
def action_space(self) -> Space:
return TUPLE_ACTION_SPACE

def init_head(self: BaseModel) -> List[nn.Module]:
return [
nn.Linear(self.NUM_HIDDEN_NEURONS, sub_space.n)
for sub_space in self.action_space().spaces
]

def call_head(self: BaseModel, hidden: torch.Tensor) -> List[torch.Tensor]:
return [head(hidden) for head in self.head]


class BoxObsDiscreteActionModel(BoxObsMixin, DiscreteActionMixin, BaseModel):
pass
Expand All @@ -179,6 +214,10 @@ class BoxObsMultiDiscreteActionModel(BoxObsMixin, MultiDiscreteActionMixin, Base
pass


class BoxObsTupleActionModel(BoxObsMixin, TupleActionMixin, BaseModel):
pass


class ImageObsDiscreteActionModel(ImageObsMixin, DiscreteActionMixin, BaseModel):
pass

Expand All @@ -189,6 +228,10 @@ class ImageObsMultiDiscreteActionModel(
pass


class ImageObsTupleActionModel(BoxObsMixin, TupleActionMixin, BaseModel):
pass


class DictObsDiscreteActionModel(DictObxMixin, DiscreteActionMixin, BaseModel):
pass

Expand All @@ -199,6 +242,10 @@ class DictObsMultiDiscreteActionModel(
pass


class DictObsTupleActionModel(DictObxMixin, TupleActionMixin, BaseModel):
pass


class BoxObsDiscreteActionRecurrentModel(
BoxObsMixin, DiscreteActionMixin, BaseRecurrentModel
):
Expand All @@ -211,6 +258,12 @@ class BoxObsMultiDiscreteActionRecurrentModel(
pass


class BoxObsTupleActionRecurrentModel(
BoxObsMixin, TupleActionMixin, BaseRecurrentModel
):
pass


class ImageObsDiscreteActionRecurrentModel(
ImageObsMixin, DiscreteActionMixin, BaseRecurrentModel
):
Expand All @@ -223,6 +276,12 @@ class ImageObsMultiDiscreteActionRecurrentModel(
pass


class ImageObsTupleActionRecurrentModel(
ImageObsMixin, TupleActionMixin, BaseRecurrentModel
):
pass


class DictObsDiscreteActionRecurrentModel(
DictObxMixin, DiscreteActionMixin, BaseRecurrentModel
):
Expand All @@ -235,18 +294,30 @@ class DictObsMultiDiscreteActionRecurrentModel(
pass


class DictObsTupleActionRecurrentModel(
DictObxMixin, TupleActionMixin, BaseRecurrentModel
):
pass


ALL_MODELS = [
BoxObsDiscreteActionModel,
BoxObsMultiDiscreteActionModel,
BoxObsTupleActionModel,
ImageObsDiscreteActionModel,
ImageObsMultiDiscreteActionModel,
ImageObsTupleActionModel,
DictObsDiscreteActionModel,
DictObsMultiDiscreteActionModel,
DictObsTupleActionModel,
# Recurrent models
BoxObsDiscreteActionRecurrentModel,
BoxObsMultiDiscreteActionRecurrentModel,
BoxObsTupleActionRecurrentModel,
ImageObsDiscreteActionRecurrentModel,
ImageObsMultiDiscreteActionRecurrentModel,
ImageObsTupleActionRecurrentModel,
DictObsDiscreteActionRecurrentModel,
DictObsMultiDiscreteActionRecurrentModel,
DictObsTupleActionRecurrentModel,
]
Loading