Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 113 additions & 4 deletions rld/attributation.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,10 @@ def __next__(self) -> TimestepAttributationBatch:
probs = _multicategorical_softmax(
logits, list(self.model.action_space().nvec)
)
elif isinstance(self.model.action_space(), gym.spaces.Tuple):
probs = _multicategorical_softmax(
logits, n_vecs(self.model.action_space())
)
else:
raise ActionSpaceNotSupported(self.model.action_space())

Expand All @@ -183,6 +187,13 @@ def __next__(self) -> TimestepAttributationBatch:
prob_of_picked_action = (
torch.stack([probs[s] for s in raw_picked_action]).prod().item()
)
elif isinstance(self.model.action_space(), gym.spaces.Tuple):
if len(raw_picked_action) > 1:
prob_of_picked_action = (
torch.stack([probs[s] for s in raw_picked_action]).prod().item()
)
else:
prob_of_picked_action = probs[picked_action].item()
else:
raise ActionSpaceNotSupported(self.model.action_space())

Expand Down Expand Up @@ -224,6 +235,22 @@ def __next__(self) -> TimestepAttributationBatch:
self.model.action_space(), action
)
targets.append((action, raw_action, prob))
elif isinstance(self.model.action_space(), gym.spaces.Tuple):
# Extract probs for each sub-action
subs_probs = _extract_tuple_discrete_action_probs(
self.model.action_space(), probs
)
# Calculate the most probable actions
top_probs_with_index = _sort_multi_discrete_action_probs(subs_probs)

num_actions = _total_tuple_discrete_actions(self.model.action_space())
num_top = min(int(self.target.value), num_actions)
for action, prob in top_probs_with_index[:num_top]:
raw_action = _action_to_raw_action(
self.model.action_space(), action
)
targets.append((action, raw_action, prob))

else:
raise ActionSpaceNotSupported(self.model.action_space())

Expand Down Expand Up @@ -253,6 +280,17 @@ def __next__(self) -> TimestepAttributationBatch:
targets_for_target = torch.tensor(
target, device=self.model.output_device(), dtype=torch.long
)
elif isinstance(self.model.action_space(), gym.spaces.Tuple):
num_sub_actions = _tuple_discrete_actions_count(
self.model.action_space()
)
inputs_for_target = _extend_batch_dim(inputs, num_sub_actions)
if self._model_recurrent:
states_for_target = _extend_batch_dim(state, num_sub_actions)
baselines_for_target = _extend_batch_dim(baselines, num_sub_actions)
targets_for_target = torch.tensor(
target, device=self.model.output_device(), dtype=torch.long
)
else:
raise ActionSpaceNotSupported(self.model.action_space())
inputs_all.append(inputs_for_target)
Expand Down Expand Up @@ -334,7 +372,47 @@ def attribute_trajectory(
elif isinstance(model.action_space(), gym.spaces.MultiDiscrete):
num_sub_actions = _multi_discrete_actions_count(model.action_space())
grouped_attributation = [
raw_attributation[i : i + num_sub_actions]
raw_attributation[i: i + num_sub_actions]
for i in range(0, raw_attributation.shape[0], num_sub_actions)
]

attributation = Attributation(
picked=MultiDiscreteActionAttributation(
action=batch.actions[0],
prob=batch.action_probs[0],
raw=[
unpack_array(grouped_attributation[0][j], obs_space)
for j in range(num_sub_actions)
],
normalized=[
normalizer.transform(
unpack_array(grouped_attributation[0][j], obs_space)
)
for j in range(num_sub_actions)
],
),
top=[
MultiDiscreteActionAttributation(
action=batch.actions[i],
prob=batch.action_probs[i],
raw=[
unpack_array(grouped_attributation[i][j], obs_space)
for j in range(num_sub_actions)
],
normalized=[
normalizer.transform(
unpack_array(grouped_attributation[i][j], obs_space)
)
for j in range(num_sub_actions)
],
)
for i in range(1, len(batch.actions))
],
)
elif isinstance(model.action_space(), gym.spaces.Tuple):
num_sub_actions = _tuple_discrete_actions_count(model.action_space())
grouped_attributation = [
raw_attributation[i: i + num_sub_actions]
for i in range(0, raw_attributation.shape[0], num_sub_actions)
]

Expand Down Expand Up @@ -393,10 +471,18 @@ def _multi_discrete_actions_count(action_space: gym.spaces.MultiDiscrete) -> int
return len(action_space.nvec)


def _tuple_discrete_actions_count(action_space: gym.spaces.Tuple) -> int:
    """Return the number of discrete sub-actions in a ``Tuple`` action space.

    The count is the length of the flat list of sub-action sizes produced by
    ``n_vecs`` for ``action_space``.
    """
    return len(n_vecs(action_space))


def _total_multi_discrete_actions(action_space: gym.spaces.MultiDiscrete) -> int:
    """Return the product of all sub-action sizes in ``action_space.nvec``.

    Returns:
        The product as a built-in ``int``; ``1`` when ``nvec`` is empty.
    """
    # Multiply Python ints rather than the raw ``nvec`` elements so the
    # result matches the annotated ``int`` return type and cannot wrap around
    # at a fixed integer width. The initial value 1 makes an empty ``nvec``
    # yield the empty product instead of a TypeError from ``reduce``.
    return reduce(operator.mul, (int(size) for size in action_space.nvec), 1)


def _total_tuple_discrete_actions(action_space: gym.spaces.Tuple) -> int:
    """Return the product of all sub-action sizes of a ``Tuple`` action space.

    The sizes are taken from the flat list returned by ``n_vecs``.

    Returns:
        The product as a built-in ``int``; ``1`` when there are no sizes.
    """
    # Multiply Python ints so the result matches the annotated ``int`` return
    # type; the initial value 1 avoids a TypeError from ``reduce`` when the
    # size list is empty.
    return reduce(
        operator.mul, (int(size) for size in n_vecs(action_space)), 1
    )


def _extract_multi_discrete_action_probs(
action_space: gym.spaces.MultiDiscrete, probs: torch.Tensor
) -> List[torch.Tensor]:
Expand All @@ -405,7 +491,14 @@ def _extract_multi_discrete_action_probs(
# Accumulate sizes to calculate offsets (e.g. [0, 4, 6, 9])
# TODO Use initial=0 when moved to Python 3.8
offsets = [0] + list(accumulate(subs, operator.add))[:-1]
return [probs[o : o + s] for o, s in zip(offsets, subs)]
return [probs[o: o + s] for o, s in zip(offsets, subs)]


def _extract_tuple_discrete_action_probs(
    action_space: gym.spaces.Tuple, probs: torch.Tensor
) -> List[torch.Tensor]:
    """Split ``probs`` into one slice per sub-action of a ``Tuple`` space.

    Args:
        action_space: Tuple action space whose flat sub-action sizes are
            obtained from ``n_vecs``.
        probs: Tensor sliced along its first dimension; slice ``k`` covers
            ``sizes[k]`` entries starting at the sum of the preceding sizes.

    Returns:
        A list with one tensor slice per sub-action, in ``n_vecs`` order.
    """
    subs = n_vecs(action_space)
    # Accumulate sizes to calculate offsets (e.g. sizes [4, 2, 3] -> [0, 4, 6])
    offsets = [0] + list(accumulate(subs, operator.add))[:-1]
    return [probs[o : o + s] for o, s in zip(offsets, subs)]


def _sort_multi_discrete_action_probs(
Expand All @@ -428,6 +521,10 @@ def _action_to_raw_action(action_space: gym.Space, action: ActionLike) -> Action
# TODO Use initial=0 when moved to Python 3.8
offsets = [0] + list(accumulate(subs, operator.add))[:-1]
return action + np.array(offsets)
elif isinstance(action_space, gym.spaces.Tuple):
subs = n_vecs(action_space)
offsets = [0] + list(accumulate(subs, operator.add))[:-1]
return action + np.array(offsets)
else:
raise ActionSpaceNotSupported(action_space)

Expand All @@ -436,8 +533,20 @@ def _multicategorical_softmax(t: torch.Tensor, sizes: List[int]) -> torch.Tensor
probs = torch.empty_like(t)
offset = 0
for sub_space_size in sizes:
probs[offset : offset + sub_space_size] = torch.softmax(
t[offset : offset + sub_space_size], dim=-1
probs[offset: offset + sub_space_size] = torch.softmax(
t[offset: offset + sub_space_size], dim=-1
)
offset += sub_space_size
return probs


def n_vecs(action_space: gym.spaces.Tuple) -> List[int]:
    """Flatten a (possibly nested) ``Tuple`` action space into sub-action sizes.

    Each member of ``action_space.spaces`` contributes, in order:

    * ``Discrete``: its ``n``;
    * ``MultiDiscrete``: every entry of its ``nvec``;
    * ``Tuple``: the sizes of its members, collected recursively.

    Returns:
        The flat list of sub-action sizes as built-in ``int`` values.

    Raises:
        ActionSpaceNotSupported: If a member space is none of the above.
            Skipping it silently would shift the offsets of every following
            sub-action.
    """
    nvecs: List[int] = []
    for a_s in action_space.spaces:
        if isinstance(a_s, gym.spaces.Discrete):
            nvecs.append(int(a_s.n))
        elif isinstance(a_s, gym.spaces.MultiDiscrete):
            # Convert to plain ints to honour the List[int] return type.
            nvecs.extend(int(n) for n in a_s.nvec)
        elif isinstance(a_s, gym.spaces.Tuple):
            nvecs.extend(n_vecs(a_s))
        else:
            raise ActionSpaceNotSupported(a_s)
    return nvecs