From bf05372d98b14b60eaaf4355332ab49869edbbc2 Mon Sep 17 00:00:00 2001 From: "Pankiewicz, Nikodem" Date: Tue, 25 May 2021 11:21:39 +0200 Subject: [PATCH] add support for action space of type: gym.spaces.Tuple --- rld/attributation.py | 117 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 113 insertions(+), 4 deletions(-) diff --git a/rld/attributation.py b/rld/attributation.py index 3546bff..81d5d37 100644 --- a/rld/attributation.py +++ b/rld/attributation.py @@ -167,6 +167,10 @@ def __next__(self) -> TimestepAttributationBatch: probs = _multicategorical_softmax( logits, list(self.model.action_space().nvec) ) + elif isinstance(self.model.action_space(), gym.spaces.Tuple): + probs = _multicategorical_softmax( + logits, n_vecs(self.model.action_space()) + ) else: raise ActionSpaceNotSupported(self.model.action_space()) @@ -183,6 +187,13 @@ def __next__(self) -> TimestepAttributationBatch: prob_of_picked_action = ( torch.stack([probs[s] for s in raw_picked_action]).prod().item() ) + elif isinstance(self.model.action_space(), gym.spaces.Tuple): + if len(raw_picked_action) > 1: + prob_of_picked_action = ( + torch.stack([probs[s] for s in raw_picked_action]).prod().item() + ) + else: + prob_of_picked_action = probs[picked_action].item() else: raise ActionSpaceNotSupported(self.model.action_space()) @@ -224,6 +235,22 @@ def __next__(self) -> TimestepAttributationBatch: self.model.action_space(), action ) targets.append((action, raw_action, prob)) + elif isinstance(self.model.action_space(), gym.spaces.Tuple): + # Extract probs for each sub-action + subs_probs = _extract_tuple_discrete_action_probs( + self.model.action_space(), probs + ) + # Calculate most probably actions + top_probs_with_index = _sort_multi_discrete_action_probs(subs_probs) + + num_actions = _total_tuple_discrete_actions(self.model.action_space()) + num_top = min(int(self.target.value), num_actions) + for action, prob in top_probs_with_index[:num_top]: + raw_action = _action_to_raw_action( + self.model.action_space(), action + ) + targets.append((action, raw_action, prob)) + else: raise ActionSpaceNotSupported(self.model.action_space()) @@ -253,6 +280,17 @@ def __next__(self) -> TimestepAttributationBatch: targets_for_target = torch.tensor( target, device=self.model.output_device(), dtype=torch.long ) + elif isinstance(self.model.action_space(), gym.spaces.Tuple): + num_sub_actions = _tuple_discrete_actions_count( + self.model.action_space() + ) + inputs_for_target = _extend_batch_dim(inputs, num_sub_actions) + if self._model_recurrent: + states_for_target = _extend_batch_dim(state, num_sub_actions) + baselines_for_target = _extend_batch_dim(baselines, num_sub_actions) + targets_for_target = torch.tensor( + target, device=self.model.output_device(), dtype=torch.long + ) else: raise ActionSpaceNotSupported(self.model.action_space()) inputs_all.append(inputs_for_target) @@ -334,7 +372,47 @@ def attribute_trajectory( elif isinstance(model.action_space(), gym.spaces.MultiDiscrete): num_sub_actions = _multi_discrete_actions_count(model.action_space()) grouped_attributation = [ - raw_attributation[i : i + num_sub_actions] + raw_attributation[i: i + num_sub_actions] + for i in range(0, raw_attributation.shape[0], num_sub_actions) + ] + + attributation = Attributation( + picked=MultiDiscreteActionAttributation( + action=batch.actions[0], + prob=batch.action_probs[0], + raw=[ + unpack_array(grouped_attributation[0][j], obs_space) + for j in range(num_sub_actions) + ], + normalized=[ + normalizer.transform( + unpack_array(grouped_attributation[0][j], obs_space) + ) + for j in range(num_sub_actions) + ], + ), + top=[ + MultiDiscreteActionAttributation( + action=batch.actions[i], + prob=batch.action_probs[i], + raw=[ + unpack_array(grouped_attributation[i][j], obs_space) + for j in range(num_sub_actions) + ], + normalized=[ + normalizer.transform( + unpack_array(grouped_attributation[i][j], obs_space) + ) + for j in range(num_sub_actions) + ], + ) + for i in range(1, len(batch.actions)) + ], + ) + elif isinstance(model.action_space(), gym.spaces.Tuple): + num_sub_actions = _tuple_discrete_actions_count(model.action_space()) + grouped_attributation = [ + raw_attributation[i: i + num_sub_actions] for i in range(0, raw_attributation.shape[0], num_sub_actions) ] @@ -393,10 +471,18 @@ def _multi_discrete_actions_count(action_space: gym.spaces.MultiDiscrete) -> int return len(action_space.nvec) +def _tuple_discrete_actions_count(action_space: gym.spaces.MultiDiscrete) -> int: + return len(n_vecs(action_space)) + + def _total_multi_discrete_actions(action_space: gym.spaces.MultiDiscrete) -> int: return reduce(operator.mul, action_space.nvec) +def _total_tuple_discrete_actions(action_space: gym.spaces.MultiDiscrete) -> int: + return reduce(operator.mul, n_vecs(action_space)) + + def _extract_multi_discrete_action_probs( action_space: gym.spaces.MultiDiscrete, probs: torch.Tensor ) -> List[torch.Tensor]: @@ -405,7 +491,14 @@ def _extract_multi_discrete_action_probs( # Accumulate sizes to calculate offsets (e.g. [0, 4, 6, 9]) # TODO Use initial=0 when moved to Python 3.8 offsets = [0] + list(accumulate(subs, operator.add))[:-1] - return [probs[o : o + s] for o, s in zip(offsets, subs)] + return [probs[o: o + s] for o, s in zip(offsets, subs)] + + +def _extract_tuple_discrete_action_probs(action_space: gym.spaces.MultiDiscrete, + probs: torch.Tensor) -> List[torch.Tensor]: + subs = n_vecs(action_space) + offsets = [0] + list(accumulate(subs, operator.add))[:-1] + return [probs[o: o + s] for o, s in zip(offsets, subs)] def _sort_multi_discrete_action_probs( @@ -428,6 +521,10 @@ def _action_to_raw_action(action_space: gym.Space, action: ActionLike) -> Action # TODO Use initial=0 when moved to Python 3.8 offsets = [0] + list(accumulate(subs, operator.add))[:-1] return action + np.array(offsets) + elif isinstance(action_space, gym.spaces.Tuple): + subs = n_vecs(action_space) + offsets = [0] + list(accumulate(subs, operator.add))[:-1] + return action + np.array(offsets) else: raise ActionSpaceNotSupported(action_space) @@ -436,8 +533,20 @@ def _multicategorical_softmax(t: torch.Tensor, sizes: List[int]) -> torch.Tensor probs = torch.empty_like(t) offset = 0 for sub_space_size in sizes: - probs[offset : offset + sub_space_size] = torch.softmax( - t[offset : offset + sub_space_size], dim=-1 + probs[offset: offset + sub_space_size] = torch.softmax( + t[offset: offset + sub_space_size], dim=-1 ) offset += sub_space_size return probs + + +def n_vecs(action_space: gym.spaces.Tuple) -> List[int]: + nvecs = [] + for a_s in action_space.spaces: + if isinstance(a_s, gym.spaces.Discrete): + nvecs.append(a_s.n) + elif isinstance(a_s, gym.spaces.MultiDiscrete): + nvecs.extend(list(a_s.nvec)) + elif isinstance(a_s, gym.spaces.Tuple): + nvecs.extend(n_vecs(a_s)) + return nvecs