diff --git a/rld/attributation.py b/rld/attributation.py index 3546bff..f3a42ff 100644 --- a/rld/attributation.py +++ b/rld/attributation.py @@ -23,7 +23,15 @@ remove_channel_dim_from_image_space, Attributation, ) -from rld.typing import BaselineBuilder, AttributationLike, ActionLike, HiddenState +from rld.typing import ( + BaselineBuilder, + AttributationLike, + ActionLike, + HiddenStateTensor, + AttributationLikeStrict, + ObsTensorStrict, + MultiDiscreteAsTuple, +) class AttributationNormalizationMode(IntEnum): @@ -55,7 +63,15 @@ def transform(self, attr: AttributationLike) -> AttributationLike: if self.obs_image_channel_dim is not None: attr = np.sum(attr, axis=self.obs_image_channel_dim) obs_space = remove_channel_dim_from_image_space(obs_space) - attr = flatten(self.obs_space, attr) + attr = unflatten(obs_space, self._transform(flatten(self.obs_space, attr))) + return attr + + def transform_only( + self, hs_attr: AttributationLikeStrict + ) -> AttributationLikeStrict: + return self._transform(hs_attr) + + def _transform(self, attr: AttributationLikeStrict) -> AttributationLikeStrict: if self.mode == AttributationNormalizationMode.ALL: scaling_factor = self._calculate_safe_scaling_factor(np.abs(attr)) elif self.mode == AttributationNormalizationMode.POSITIVE: @@ -70,7 +86,7 @@ def transform(self, attr: AttributationLike) -> AttributationLike: else: raise EnumValueNotFound(self.mode, AttributationNormalizationMode) attr_norm = self._scale(attr, scaling_factor) - return unflatten(obs_space, attr_norm) + return attr_norm def _calculate_safe_scaling_factor(self, attr: AttributationLike) -> float: sorted_vals = np.sort(attr.flatten()) @@ -103,9 +119,10 @@ class AttributationTarget(IntEnum): @dataclass class TimestepAttributationBatch: - inputs: torch.Tensor - state: Optional[HiddenState] - baselines: torch.Tensor + input: ObsTensorStrict + hs: Optional[HiddenStateTensor] + input_baselines: ObsTensorStrict + hs_baselines: Optional[HiddenStateTensor] targets: torch.Tensor actions: List[ActionLike] action_probs: List[float] @@ -127,64 +144,71 @@ def __init__( self.target = target self._it = iter(self.trajectory) self._model_recurrent = isinstance(self.model, RecurrentModel) - self._state: Optional[ - HiddenState + self._model_obs_space = model.obs_space() + self._model_action_space = model.action_space() + self._hs: Optional[ + HiddenStateTensor ] = self.model.initial_state() if self._model_recurrent else None + if isinstance(self._model_action_space, gym.spaces.Tuple): + _validate_multi_discrete_tuple_action_space(self._model_action_space) + @torch.no_grad() def __next__(self) -> TimestepAttributationBatch: - try: - timestep = next(self._it) - except StopIteration: - raise StopIteration + timestep = next(self._it) - # inputs = self.model.flatten_obs(timestep.obs) - inputs = pack_array(timestep.obs, self.model.obs_space()) + inputs = pack_array(timestep.obs, self._model_obs_space) if self.baseline is not None: + # TODO Use inputs before packing baselines = self.baseline(inputs) else: baselines = np.zeros_like(inputs) inputs = torch.tensor(inputs, device=self.model.input_device()).unsqueeze(dim=0) - baselines = torch.tensor( - baselines, device=self.model.output_device() - ).unsqueeze(dim=0) + baselines = torch.tensor(baselines, device=self.model.input_device()).unsqueeze( + dim=0 + ) if self._model_recurrent: - out = self.model(inputs, self._state) + timestep = replace(timestep, hs=self._hs) + out = self.model(inputs, self._hs) state = self.model.last_output_state().to(device=self.model.input_device()) - self._state = state + self._hs = state logits = out.squeeze(dim=0).to(device=self.model.input_device()) else: + state = None logits = ( self.model(inputs).squeeze(dim=0).to(device=self.model.input_device()) ) - if isinstance(self.model.action_space(), gym.spaces.Discrete): + if isinstance(self._model_action_space, gym.spaces.Discrete): probs = torch.softmax(logits, dim=-1) - elif isinstance(self.model.action_space(), gym.spaces.MultiDiscrete): - probs = _multicategorical_softmax( - logits, list(self.model.action_space().nvec) - ) + elif isinstance( + self._model_action_space, (gym.spaces.MultiDiscrete, MultiDiscreteAsTuple) + ): + subs = _multi_discrete_action_sizes(self._model_action_space) + probs = _multicategorical_softmax(logits, subs) else: - raise ActionSpaceNotSupported(self.model.action_space()) + raise ActionSpaceNotSupported(self._model_action_space) targets = [] picked_action = timestep.action raw_picked_action = _action_to_raw_action( - self.model.action_space(), timestep.action + self._model_action_space, timestep.action ) - if isinstance(self.model.action_space(), gym.spaces.Discrete): + if isinstance(self._model_action_space, gym.spaces.Discrete): prob_of_picked_action = probs[picked_action].item() - elif isinstance(self.model.action_space(), gym.spaces.MultiDiscrete): + elif isinstance( + self._model_action_space, (gym.spaces.MultiDiscrete, MultiDiscreteAsTuple) + ): prob_of_picked_action = ( torch.stack([probs[s] for s in raw_picked_action]).prod().item() ) else: - raise ActionSpaceNotSupported(self.model.action_space()) + raise ActionSpaceNotSupported(self._model_action_space) # Always add the picked action to the targets list targets.append((picked_action, raw_picked_action, prob_of_picked_action)) @@ -195,7 +219,7 @@ def __next__(self) -> TimestepAttributationBatch: AttributationTarget.TOP5, AttributationTarget.ALL, ): - if isinstance(self.model.action_space(), gym.spaces.Discrete): + if isinstance(self._model_action_space, gym.spaces.Discrete): # Get total available actions num_actions = logits.numel() num_top = min(int(self.target.value), num_actions) @@ -204,28 +228,29 @@ def __next__(self) -> TimestepAttributationBatch: top_probs_with_index.indices, top_probs_with_index.values ): action = index.item() - raw_action = _action_to_raw_action( - self.model.action_space(), action - ) + raw_action = _action_to_raw_action(self._model_action_space, action) prob = prob.item() targets.append((action, raw_action, prob)) - elif isinstance(self.model.action_space(), gym.spaces.MultiDiscrete): + elif isinstance( + self._model_action_space, + (gym.spaces.MultiDiscrete, MultiDiscreteAsTuple), + ): # Extract probs for each sub-action subs_probs = _extract_multi_discrete_action_probs( - self.model.action_space(), probs + self._model_action_space, probs ) # Calculate most probably actions top_probs_with_index = _sort_multi_discrete_action_probs(subs_probs) - num_actions = _total_multi_discrete_actions(self.model.action_space()) + num_actions = _total_multi_discrete_actions(self._model_action_space) num_top = min(int(self.target.value), num_actions) for action, prob in top_probs_with_index[:num_top]: - raw_action = _action_to_raw_action( - self.model.action_space(), action - ) + raw_action = _action_to_raw_action(self._model_action_space, action) targets.append((action, raw_action, prob)) else: - raise ActionSpaceNotSupported(self.model.action_space()) + raise ActionSpaceNotSupported(self._model_action_space) + else: + raise EnumValueNotFound(self.target, AttributationTarget) inputs_all = [] if self._model_recurrent: @@ -234,7 +259,7 @@ def __next__(self) -> TimestepAttributationBatch: targets_all = [] for _, target, _ in targets: - if isinstance(self.model.action_space(), gym.spaces.Discrete): + if isinstance(self._model_action_space, gym.spaces.Discrete): inputs_for_target = inputs if self._model_recurrent: states_for_target = state @@ -242,9 +267,12 @@ def __next__(self) -> TimestepAttributationBatch: targets_for_target = torch.tensor( target, device=self.model.output_device(), dtype=torch.long ).unsqueeze(dim=0) - elif isinstance(self.model.action_space(), gym.spaces.MultiDiscrete): + elif isinstance( + self._model_action_space, + (gym.spaces.MultiDiscrete, MultiDiscreteAsTuple), + ): num_sub_actions = _multi_discrete_actions_count( - self.model.action_space() + self._model_action_space ) inputs_for_target = _extend_batch_dim(inputs, num_sub_actions) if self._model_recurrent: @@ -254,7 +282,7 @@ def __next__(self) -> TimestepAttributationBatch: target, device=self.model.output_device(), dtype=torch.long ) else: - raise ActionSpaceNotSupported(self.model.action_space()) + raise ActionSpaceNotSupported(self._model_action_space) inputs_all.append(inputs_for_target) if self._model_recurrent: states_all.append(states_for_target) @@ -264,17 +292,20 @@ def __next__(self) -> TimestepAttributationBatch: inputs_batch = torch.cat(inputs_all, dim=0) if self._model_recurrent: states_batch = torch.cat(states_all, dim=0) + state_baselines_batch = torch.zeros_like(states_batch) else: states_batch = None + state_baselines_batch = None baselines_batch = torch.cat(baselines_all, dim=0) targets_batch = torch.cat(targets_all, dim=0) actions = [action for action, _, _ in targets] action_probs = [prob for _, _, prob in targets] return TimestepAttributationBatch( - inputs=inputs_batch, - state=states_batch, - baselines=baselines_batch, + input=inputs_batch, + hs=states_batch, + input_baselines=baselines_batch, + hs_baselines=state_baselines_batch, targets=targets_batch, actions=actions, action_probs=action_probs, @@ -292,24 +323,27 @@ def attribute_trajectory( normalizer: AttributationNormalizer, ) -> Trajectory: obs_space = model.obs_space() + action_space = model.action_space() is_recurrent = isinstance(model, RecurrentModel) algo = IntegratedGradients(model) timesteps = [] for batch in trajectory_it: if not is_recurrent: raw_attributation = algo.attribute( - batch.inputs, baselines=batch.baselines, target=batch.targets + batch.input, baselines=batch.input_baselines, target=batch.targets ) + raw_hs_attr = None else: - raw_attributation = algo.attribute( - (batch.inputs, batch.state), - baselines=(batch.baselines, batch.state), + raw_attributation, raw_hs_attr = algo.attribute( + (batch.input, batch.hs), + baselines=(batch.input_baselines, batch.hs_baselines), target=batch.targets, ) - raw_attributation = raw_attributation[0] raw_attributation = raw_attributation.detach().cpu().numpy() + if raw_hs_attr is not None: + raw_hs_attr = raw_hs_attr.detach().cpu().numpy() - if isinstance(model.action_space(), gym.spaces.Discrete): + if isinstance(action_space, gym.spaces.Discrete): attributation = Attributation( picked=DiscreteActionAttributation( action=batch.actions[0], @@ -318,6 +352,10 @@ def attribute_trajectory( normalized=normalizer.transform( unpack_array(raw_attributation[0], obs_space) ), + hs_raw=raw_hs_attr[0] if is_recurrent else None, + hs_normalized=normalizer.transform_only(raw_hs_attr[0]) + if is_recurrent + else None, ), top=[ DiscreteActionAttributation( @@ -327,16 +365,27 @@ def attribute_trajectory( normalized=normalizer.transform( unpack_array(raw_attributation[i], obs_space) ), + hs_raw=raw_hs_attr[i] if is_recurrent else None, + hs_normalized=normalizer.transform_only(raw_hs_attr[i]) + if is_recurrent + else None, ) for i in range(1, len(batch.actions)) ], ) - elif isinstance(model.action_space(), gym.spaces.MultiDiscrete): - num_sub_actions = _multi_discrete_actions_count(model.action_space()) + elif isinstance(action_space, (gym.spaces.MultiDiscrete, MultiDiscreteAsTuple)): + num_sub_actions = _multi_discrete_actions_count(action_space) grouped_attributation = [ raw_attributation[i : i + num_sub_actions] for i in range(0, raw_attributation.shape[0], num_sub_actions) ] + if is_recurrent: + grouped_hs_attributation = [ + raw_hs_attr[i : i + num_sub_actions] + for i in range(0, raw_hs_attr.shape[0], num_sub_actions) + ] + else: + grouped_hs_attributation = None attributation = Attributation( picked=MultiDiscreteActionAttributation( @@ -352,6 +401,17 @@ def attribute_trajectory( ) for j in range(num_sub_actions) ], + hs_raw=[ + grouped_hs_attributation[0][j] for j in range(num_sub_actions) + ] + if is_recurrent + else None, + hs_normalized=[ + normalizer.transform_only(grouped_hs_attributation[0][j]) + for j in range(num_sub_actions) + ] + if is_recurrent + else None, ), top=[ MultiDiscreteActionAttributation( @@ -367,12 +427,24 @@ def attribute_trajectory( ) for j in range(num_sub_actions) ], + hs_raw=[ + grouped_hs_attributation[i][j] + for j in range(num_sub_actions) + ] + if is_recurrent + else None, + hs_normalized=[ + normalizer.transform_only(grouped_hs_attributation[i][j]) + for j in range(num_sub_actions) + ] + if is_recurrent + else None, ) for i in range(1, len(batch.actions)) ], ) else: - raise ActionSpaceNotSupported(model.action_space()) + raise ActionSpaceNotSupported(action_space) timesteps.append(replace(batch.timestep, attributations=attributation)) @@ -389,19 +461,36 @@ def _extend_batch_dim(t: torch.Tensor, new_batch_dim: int) -> torch.Tensor: return t.repeat(repeat_shape) -def _multi_discrete_actions_count(action_space: gym.spaces.MultiDiscrete) -> int: - return len(action_space.nvec) +def _multi_discrete_action_sizes( + space: Union[gym.spaces.MultiDiscrete, MultiDiscreteAsTuple] +) -> List[int]: + if isinstance(space, gym.spaces.MultiDiscrete): + return list(space.nvec) + else: + return [sub_space.n for sub_space in space.spaces] -def _total_multi_discrete_actions(action_space: gym.spaces.MultiDiscrete) -> int: - return reduce(operator.mul, action_space.nvec) +def _multi_discrete_actions_count( + action_space: Union[gym.spaces.MultiDiscrete, MultiDiscreteAsTuple] +) -> int: + if isinstance(action_space, gym.spaces.MultiDiscrete): + return len(action_space.nvec) + else: + return len(action_space.spaces) + + +def _total_multi_discrete_actions( + action_space: Union[gym.spaces.MultiDiscrete, MultiDiscreteAsTuple] +) -> int: + discrete_actions = _multi_discrete_action_sizes(action_space) + return reduce(operator.mul, discrete_actions) def _extract_multi_discrete_action_probs( action_space: gym.spaces.MultiDiscrete, probs: torch.Tensor ) -> List[torch.Tensor]: # Extract sizes of each dimension (e.g. [4, 2, 3]) - subs = action_space.nvec + subs = _multi_discrete_action_sizes(action_space) # Accumulate sizes to calculate offsets (e.g. [0, 4, 6, 9]) # TODO Use initial=0 when moved to Python 3.8 offsets = [0] + list(accumulate(subs, operator.add))[:-1] @@ -421,9 +510,9 @@ def _sort_multi_discrete_action_probs( def _action_to_raw_action(action_space: gym.Space, action: ActionLike) -> ActionLike: if isinstance(action_space, gym.spaces.Discrete): return action - elif isinstance(action_space, gym.spaces.MultiDiscrete): + elif isinstance(action_space, (gym.spaces.MultiDiscrete, MultiDiscreteAsTuple)): # Extract sizes of each dimension (e.g. [4, 2, 3]) - subs = action_space.nvec + subs = _multi_discrete_action_sizes(action_space) # Accumulate sizes to calculate offsets (e.g. [0, 4, 6, 9]) # TODO Use initial=0 when moved to Python 3.8 offsets = [0] + list(accumulate(subs, operator.add))[:-1] @@ -441,3 +530,11 @@ def _multicategorical_softmax(t: torch.Tensor, sizes: List[int]) -> torch.Tensor ) offset += sub_space_size return probs + + +def _validate_multi_discrete_tuple_action_space(space: gym.spaces.Tuple): + for space in space.spaces: + if not isinstance(space, gym.spaces.Discrete): + raise ValueError( + "Only a flat `Tuple` spaces with `Discrete` subspaces are supported." + ) diff --git a/rld/model.py b/rld/model.py index c50ca9d..7b6e852 100644 --- a/rld/model.py +++ b/rld/model.py @@ -17,7 +17,7 @@ ObsLikeStrict, ObsTensorLike, ObsTensorStrict, - HiddenState, + HiddenStateTensor, ) @@ -40,19 +40,19 @@ def _forward_unimplemented(self, *input: Any) -> None: class RecurrentModel(Model, ABC): - def initial_state(self) -> HiddenState: + def initial_state(self) -> HiddenStateTensor: """ [B x 2 x 1 x CELL_SIZE] """ raise NotImplementedError - def last_output_state(self) -> HiddenState: + def last_output_state(self) -> HiddenStateTensor: raise NotImplementedError - def reshape_to_torch(self, state: HiddenState) -> HiddenState: + def reshape_to_torch(self, state: HiddenStateTensor) -> HiddenStateTensor: return state.permute((1, 2, 0, 3)) - def reshape_to_store(self, state: HiddenState) -> HiddenState: + def reshape_to_store(self, state: HiddenStateTensor) -> HiddenStateTensor: return state.permute((2, 0, 1, 3)) @@ -102,9 +102,9 @@ def __init__(self, model: RayModel, lstm_cell_size: int): super().__init__(model) self.lstm_cell_size = lstm_cell_size - self._last_state: Optional[HiddenState] = None + self._last_state: Optional[HiddenStateTensor] = None - def forward(self, obs_flat: ObsTensorStrict, state: HiddenState): + def forward(self, obs_flat: ObsTensorStrict, state: HiddenStateTensor): if isinstance(self.obs_space(), Box): # We need to unpack e.g. image-like observation, # as RLlib doesn't flatten them into 1D vectors @@ -127,14 +127,14 @@ def forward(self, obs_flat: ObsTensorStrict, state: HiddenState): return logits - def initial_state(self) -> HiddenState: + def initial_state(self) -> HiddenStateTensor: initial_state = [ torch.tensor(s, device=self.input_device()) for s in self.model.get_initial_state() ] return torch.stack(initial_state).unsqueeze(dim=0).unsqueeze(dim=2) - def last_output_state(self) -> HiddenState: + def last_output_state(self) -> HiddenStateTensor: if self._last_state is None: raise RuntimeError( "Trying to get last output hidden state without calling " @@ -142,11 +142,11 @@ def last_output_state(self) -> HiddenState: ) return self._last_state - def reshape_to_torch(self, state: HiddenState) -> HiddenState: + def reshape_to_torch(self, state: HiddenStateTensor) -> HiddenStateTensor: permuted = super().reshape_to_torch(state) return permuted.squeeze(dim=1) - def reshape_to_store(self, state: HiddenState) -> HiddenState: + def reshape_to_store(self, state: HiddenStateTensor) -> HiddenStateTensor: return super().reshape_to_store(state.unsqueeze(dim=1)) diff --git a/rld/rollout.py b/rld/rollout.py index 7f22480..38be96f 100644 --- a/rld/rollout.py +++ b/rld/rollout.py @@ -28,6 +28,8 @@ ActionLike, InfoLike, AttributationLike, + HiddenState, + AttributationLikeStrict, ) @@ -48,6 +50,10 @@ class ActionAttributation(ABC): prob: float raw: Union[AttributationLike, Sequence[AttributationLike]] normalized: Union[AttributationLike, Sequence[AttributationLike]] + hs_raw: Optional[Union[AttributationLikeStrict, Sequence[AttributationLikeStrict]]] + hs_normalized: Optional[ + Union[AttributationLikeStrict, Sequence[AttributationLikeStrict]] + ] def is_complied(self, obs_space: gym.Space) -> bool: raise NotImplementedError @@ -57,6 +63,8 @@ def is_complied(self, obs_space: gym.Space) -> bool: class DiscreteActionAttributation(ActionAttributation): raw: AttributationLike normalized: AttributationLike + hs_raw: Optional[AttributationLikeStrict] = None + hs_normalized: Optional[AttributationLikeStrict] = None def is_complied(self, obs_space: gym.Space) -> bool: obs_space = remove_value_constraints_from_space(obs_space) @@ -67,6 +75,8 @@ def is_complied(self, obs_space: gym.Space) -> bool: class MultiDiscreteActionAttributation(ActionAttributation): raw: Sequence[AttributationLike] normalized: Sequence[AttributationLike] + hs_raw: Optional[Sequence[AttributationLikeStrict]] = None + hs_normalized: Optional[Sequence[AttributationLikeStrict]] = None def is_complied(self, obs_space: gym.Space) -> bool: obs_space = remove_value_constraints_from_space(obs_space) @@ -84,31 +94,6 @@ def sub_action(self, index: int) -> AttributationLike: return self.raw[index] -# @dataclass -# class TupleActionAttributation(ActionAttributation): -# raw: Sequence[AttributationLike] -# normalized: Sequence[AttributationLike] -# -# def is_complied(self, obs_space: gym.Space) -> bool: -# obs_space = remove_value_constraints_from_space(obs_space) -# return all( -# [obs_space.contains(self.space(i)) for i in range(self.num_spaces())] -# ) -# -# def map( -# self, fn: Callable[[AttributationLike], AttributationLike] -# ) -> ActionAttributation: -# return TupleActionAttributation( -# [fn(self.space(i)) for i in range(self.num_spaces())] -# ) -# -# def num_spaces(self) -> int: -# return len(self.data) -# -# def space(self, index: int) -> AttributationLike: -# return self.data[index] - - @dataclass class Timestep: obs: ObsLike @@ -117,6 +102,7 @@ class Timestep: done: DoneLike info: InfoLike attributations: Optional[Attributation] = None + hs: Optional[HiddenState] = None @dataclass diff --git a/rld/tests/resources/envs.py b/rld/tests/resources/envs.py index 2385bae..5e00c70 100644 --- a/rld/tests/resources/envs.py +++ b/rld/tests/resources/envs.py @@ -80,13 +80,13 @@ class DictObsTupleActionEnv(BaseEnv): ALL_ENVS = [ BoxObsDiscreteActionEnv, BoxObsMultiDiscreteActionEnv, - # BoxObsTupleActionEnv, + BoxObsTupleActionEnv, ImageObsDiscreteActionEnv, ImageObsMultiDiscreteActionEnv, - # ImageObsTupleActionEnv, + ImageObsTupleActionEnv, DictObsDiscreteActionEnv, DictObsMultiDiscreteActionEnv, - # DictObsTupleActionEnv, + DictObsTupleActionEnv, ] diff --git a/rld/tests/resources/models.py b/rld/tests/resources/models.py index dc9961c..11dc486 100644 --- a/rld/tests/resources/models.py +++ b/rld/tests/resources/models.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Optional +from typing import Optional, Union, List import torch import torch.nn as nn @@ -13,8 +13,9 @@ MULTI_DISCRETE_ACTION_SPACE, IMAGE_OBS_SPACE, DICT_OBS_SPACE, + TUPLE_ACTION_SPACE, ) -from rld.typing import HiddenState, ObsTensorStrict, ObsTensorLike +from rld.typing import HiddenStateTensor, ObsTensorStrict, ObsTensorLike class ObsMixin: @@ -29,7 +30,12 @@ def preprocess_obs(self, obs: torch.Tensor) -> torch.Tensor: class ActionMixin: - def init_head(self) -> nn.Module: + def init_head(self) -> Union[nn.Module, List[nn.Module]]: + raise NotImplementedError + + def call_head( + self, hidden: torch.Tensor + ) -> Union[torch.Tensor, List[torch.Tensor]]: raise NotImplementedError @@ -46,7 +52,11 @@ def forward(self, obs_flat: ObsTensorStrict) -> torch.Tensor: obs = unpack_tensor(obs_flat, self.obs_space()) obs = self.preprocess_obs(obs) x = F.relu(self.call_hidden(obs)) - x = self.head(x) + x = self.call_head(x) + + if isinstance(x, list): + x = torch.cat(x, dim=-1) + return x def input_device(self) -> torch.device: @@ -65,9 +75,11 @@ def __init__(self): ) self.head = self.init_head() - self._last_state: Optional[HiddenState] = None + self._last_state: Optional[HiddenStateTensor] = None - def forward(self, obs_flat: ObsTensorStrict, state: HiddenState) -> torch.Tensor: + def forward( + self, obs_flat: ObsTensorStrict, state: HiddenStateTensor + ) -> torch.Tensor: obs = unpack_tensor(obs_flat, self.obs_space()) obs = self.preprocess_obs(obs) @@ -79,7 +91,10 @@ def forward(self, obs_flat: ObsTensorStrict, state: HiddenState) -> torch.Tensor state = self.reshape_to_store(torch.stack(state)) x = x.squeeze(dim=1) - x = self.head(x) + x = self.call_head(x) + + if isinstance(x, list): + x = torch.cat(x, dim=-1) self._last_state = state return x @@ -87,14 +102,14 @@ def forward(self, obs_flat: ObsTensorStrict, state: HiddenState) -> torch.Tensor def input_device(self) -> torch.device: return torch.device("cpu") - def initial_state(self) -> HiddenState: + def initial_state(self) -> HiddenStateTensor: return torch.zeros( (1, 2, 1, self.NUM_HIDDEN_NEURONS), dtype=torch.float32, device=self.input_device(), ) - def last_output_state(self) -> HiddenState: + def last_output_state(self) -> HiddenStateTensor: if self._last_state is None: raise RuntimeError( "Trying to get last output hidden state without calling " @@ -162,6 +177,9 @@ def action_space(self) -> Space: def init_head(self: BaseModel) -> nn.Module: return nn.Linear(self.NUM_HIDDEN_NEURONS, self.action_space().n) + def call_head(self: BaseModel, hidden: torch.Tensor) -> torch.Tensor: + return self.head(hidden) + class MultiDiscreteActionMixin(ActionMixin): def action_space(self) -> Space: @@ -170,6 +188,23 @@ def action_space(self) -> Space: def init_head(self: BaseModel) -> nn.Module: return nn.Linear(self.NUM_HIDDEN_NEURONS, sum(self.action_space().nvec)) + def call_head(self: BaseModel, hidden: torch.Tensor) -> torch.Tensor: + return self.head(hidden) + + +class TupleActionMixin(ActionMixin): + def action_space(self) -> Space: + return TUPLE_ACTION_SPACE + + def init_head(self: BaseModel) -> List[nn.Module]: + return [ + nn.Linear(self.NUM_HIDDEN_NEURONS, sub_space.n) + for sub_space in self.action_space().spaces + ] + + def call_head(self: BaseModel, hidden: torch.Tensor) -> List[torch.Tensor]: + return [head(hidden) for head in self.head] + class BoxObsDiscreteActionModel(BoxObsMixin, DiscreteActionMixin, BaseModel): pass @@ -179,6 +214,10 @@ class BoxObsMultiDiscreteActionModel(BoxObsMixin, MultiDiscreteActionMixin, Base pass +class BoxObsTupleActionModel(BoxObsMixin, TupleActionMixin, BaseModel): + pass + + class ImageObsDiscreteActionModel(ImageObsMixin, DiscreteActionMixin, BaseModel): pass @@ -189,6 +228,10 @@ class ImageObsMultiDiscreteActionModel( pass +class ImageObsTupleActionModel(BoxObsMixin, TupleActionMixin, BaseModel): + pass + + class DictObsDiscreteActionModel(DictObxMixin, DiscreteActionMixin, BaseModel): pass @@ -199,6 +242,10 @@ class DictObsMultiDiscreteActionModel( pass +class DictObsTupleActionModel(DictObxMixin, TupleActionMixin, BaseModel): + pass + + class BoxObsDiscreteActionRecurrentModel( BoxObsMixin, DiscreteActionMixin, BaseRecurrentModel ): @@ -211,6 +258,12 @@ class BoxObsMultiDiscreteActionRecurrentModel( pass +class BoxObsTupleActionRecurrentModel( + BoxObsMixin, TupleActionMixin, BaseRecurrentModel +): + pass + + class ImageObsDiscreteActionRecurrentModel( ImageObsMixin, DiscreteActionMixin, BaseRecurrentModel ): @@ -223,6 +276,12 @@ class ImageObsMultiDiscreteActionRecurrentModel( pass +class ImageObsTupleActionRecurrentModel( + ImageObsMixin, TupleActionMixin, BaseRecurrentModel +): + pass + + class DictObsDiscreteActionRecurrentModel( DictObxMixin, DiscreteActionMixin, BaseRecurrentModel ): @@ -235,18 +294,30 @@ class DictObsMultiDiscreteActionRecurrentModel( pass +class DictObsTupleActionRecurrentModel( + DictObxMixin, TupleActionMixin, BaseRecurrentModel +): + pass + + ALL_MODELS = [ BoxObsDiscreteActionModel, BoxObsMultiDiscreteActionModel, + BoxObsTupleActionModel, ImageObsDiscreteActionModel, ImageObsMultiDiscreteActionModel, + ImageObsTupleActionModel, DictObsDiscreteActionModel, DictObsMultiDiscreteActionModel, + DictObsTupleActionModel, # Recurrent models BoxObsDiscreteActionRecurrentModel, BoxObsMultiDiscreteActionRecurrentModel, + BoxObsTupleActionRecurrentModel, ImageObsDiscreteActionRecurrentModel, ImageObsMultiDiscreteActionRecurrentModel, + ImageObsTupleActionRecurrentModel, DictObsDiscreteActionRecurrentModel, DictObsMultiDiscreteActionRecurrentModel, + DictObsTupleActionRecurrentModel, ] diff --git a/rld/tests/resources/spaces.py b/rld/tests/resources/spaces.py index 6ebc938..62416e4 100644 --- a/rld/tests/resources/spaces.py +++ b/rld/tests/resources/spaces.py @@ -1,12 +1,8 @@ -import gym +from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple -BOX_OBS_SPACE = gym.spaces.Box(-1, 1, shape=(6,)) -IMAGE_OBS_SPACE = gym.spaces.Box(0, 1, shape=(84, 84, 4)) -DICT_OBS_SPACE = gym.spaces.Dict( - a=gym.spaces.Box(-1, 1, (4, 6)), b=gym.spaces.Box(-1, 1, (2,)) -) -DISCRETE_ACTION_SPACE = gym.spaces.Discrete(4) -MULTI_DISCRETE_ACTION_SPACE = gym.spaces.MultiDiscrete([4, 3, 2]) -TUPLE_ACTION_SPACE = gym.spaces.Tuple( - (gym.spaces.MultiDiscrete([4, 2, 3]), gym.spaces.Discrete(2)) -) +BOX_OBS_SPACE = Box(-1, 1, shape=(6,)) +IMAGE_OBS_SPACE = Box(0, 1, shape=(84, 84, 4)) +DICT_OBS_SPACE = Dict(a=Box(-1, 1, (4, 6)), b=Box(-1, 1, (2,))) +DISCRETE_ACTION_SPACE = Discrete(4) +MULTI_DISCRETE_ACTION_SPACE = MultiDiscrete([4, 3, 2]) +TUPLE_ACTION_SPACE = Tuple((Discrete(4), Discrete(3), Discrete(2))) diff --git a/rld/typing.py b/rld/typing.py index 2e4882d..a081440 100644 --- a/rld/typing.py +++ b/rld/typing.py @@ -1,5 +1,6 @@ from typing import Union, Sequence, Callable, Dict, List +import gym import numpy as np import torch @@ -23,10 +24,13 @@ InfoLike = Dict[str, InfoValueLike] InfoBatchLike = Sequence[InfoLike] -HiddenState = torch.Tensor +HiddenState = np.ndarray +HiddenStateTensor = torch.Tensor AttributationLike = Union[np.ndarray, Dict[str, np.ndarray]] -AttributationLikeStrict = torch.Tensor +AttributationLikeStrict = np.ndarray AttributationBatchLike = Sequence[AttributationLike] BaselineBuilder = Callable[[ObsLike], np.ndarray] + +MultiDiscreteAsTuple = gym.spaces.Tuple