-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathbuffer.py
More file actions
216 lines (186 loc) · 9.09 KB
/
buffer.py
File metadata and controls
216 lines (186 loc) · 9.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
from stable_baselines3.common.buffers import BaseBuffer
import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Union
import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape
from stable_baselines3.common.type_aliases import (
RolloutBufferSamples,
)
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import VecNormalize
try:
# Check memory used by replay buffer when possible
import psutil
except ImportError:
psutil = None
class RolloutBuffer(BaseBuffer):
    """
    Rollout buffer used in on-policy algorithms like A2C/PPO.

    It corresponds to ``buffer_size`` transitions collected
    using the current policy.
    This experience will be discarded after the policy update.
    In order to use PPO objective, we also store the current value of each state
    and the log probability of each taken action.

    The term rollout here refers to the model-free notion and should not
    be used with the concept of rollout used in model-based RL or planning.
    Hence, it is only involved in policy and value function training but not action selection.

    Unlike a single-critic rollout buffer, ``values``, ``log_probs``,
    ``advantages`` and ``returns`` are stored with a trailing axis of size 2
    (see :meth:`reset`). The original code refers to index 0 as ``in`` and to
    index 1 as ``cd``; what these two channels represent is decided by the
    caller and is not visible in this file.

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
        self.gae_lambda = gae_lambda
        self.gamma = gamma
        # Declared up-front so the attributes always exist; the arrays
        # themselves are allocated by reset() just below.
        self.observations = None
        self.next_observations = None
        self.actions = None
        self.rewards = None
        self.advantages = None
        self.returns = None
        self.episode_starts = None
        self.values = None
        self.log_probs = None
        self.generator_ready = False
        self.reset()

    def reset(self) -> None:
        """
        Re-allocate every storage array filled with zeros and mark the
        sampling generator as not ready.

        Arrays are laid out as ``(buffer_size, n_envs, ...)``. Value-like
        quantities (``values``, ``log_probs``, ``advantages``, ``returns``)
        get a trailing axis of size 2, one slot per channel.
        """
        per_step = (self.buffer_size, self.n_envs)
        self.observations = np.zeros(per_step + self.obs_shape, dtype=np.float32)
        self.next_observations = np.zeros(per_step + self.obs_shape, dtype=np.float32)
        self.actions = np.zeros(per_step + (self.action_dim,), dtype=np.float32)
        self.rewards = np.zeros(per_step, dtype=np.float32)
        self.returns = np.zeros(per_step + (2,), dtype=np.float32)
        self.episode_starts = np.zeros(per_step, dtype=np.float32)
        self.values = np.zeros(per_step + (2,), dtype=np.float32)
        self.log_probs = np.zeros(per_step + (2,), dtype=np.float32)
        self.advantages = np.zeros(per_step + (2,), dtype=np.float32)
        self.generator_ready = False
        # The parent reset is defined outside this file; it is expected to
        # rewind the ``pos`` / ``full`` bookkeeping used by add() and get().
        super().reset()

    def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
        """
        Post-processing step: compute the lambda-return (TD(lambda) estimate)
        and GAE(lambda) advantage, independently for each of the two value channels.

        Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
        to compute the advantage. To obtain Monte-Carlo advantage estimate (A(s) = R - V(S))
        where R is the sum of discounted reward with value bootstrap
        (because we don't always have full episode), set ``gae_lambda=1.0`` during initialization.

        The TD(lambda) estimator has also two special cases:
        - TD(1) is Monte-Carlo estimate (sum of discounted rewards)
        - TD(0) is one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1}))

        For more information, see discussion in https://github.com/DLR-RM/stable-baselines3/pull/375.

        Both channels are driven by the same ``rewards`` array; they only
        differ in the value estimates used for bootstrapping.

        :param last_values: state value estimation for the last step. It is indexed
            as ``last_values[0]`` (channel 0) and ``last_values[1]`` (channel 1), each
            flattened to one entry per env.
        :param dones: if the last step was a terminal step (one bool for each env).
        """
        # Bring the bootstrap values to numpy, laid out like ``self.values[step]``:
        # one row per env, one column per channel.
        last_values_np = np.stack(
            [
                last_values[0].clone().cpu().numpy().flatten(),
                last_values[1].clone().cpu().numpy().flatten(),
            ],
            axis=-1,
        )
        # Column vector so it broadcasts across the channel axis.
        final_non_terminal = 1.0 - np.asarray(dones).reshape(-1, 1)

        last_gae_lam = np.zeros((self.n_envs, 2), dtype=np.float32)
        for step in reversed(range(self.buffer_size)):
            if step == self.buffer_size - 1:
                next_non_terminal = final_non_terminal
                next_values = last_values_np
            else:
                next_non_terminal = 1.0 - self.episode_starts[step + 1].reshape(-1, 1)
                next_values = self.values[step + 1]
            step_rewards = self.rewards[step].reshape(-1, 1)
            delta = step_rewards + self.gamma * next_values * next_non_terminal - self.values[step]
            last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
            # Stored directly as (n_envs, 2). The previous
            # concatenate/transpose/reshape sequence interleaved values of
            # different environments whenever n_envs > 1.
            self.advantages[step] = last_gae_lam
        # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
        # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
        self.returns = self.advantages + self.values

    def add(
        self,
        obs: np.ndarray,
        next_obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor,
    ) -> None:
        """
        Store one transition per environment at the current write position
        and advance it; the buffer is flagged full once ``buffer_size`` steps
        have been written.

        :param obs: Observation
        :param next_obs: Observation obtained after taking ``action``
        :param action: Action
        :param reward:
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy. Must be assignable to a
            ``(n_envs, 2)`` slot.
        :param log_prob: log probability of the action
            following the current policy. Must be assignable to a
            ``(n_envs, 2)`` slot.
        """
        if len(log_prob.shape) == 0:
            # Reshape 0-d tensor to avoid error
            log_prob = log_prob.reshape(-1, 1)

        # Reshape needed when using multiple envs with discrete observations
        # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
        if isinstance(self.observation_space, spaces.Discrete):
            obs = obs.reshape((self.n_envs,) + self.obs_shape)
            # next_obs goes into an array of the same shape as obs, so it
            # needs the very same treatment.
            next_obs = next_obs.reshape((self.n_envs,) + self.obs_shape)

        # Same reshape, for actions
        action = action.reshape((self.n_envs, self.action_dim))

        self.observations[self.pos] = np.array(obs).copy()
        self.next_observations[self.pos] = np.array(next_obs).copy()
        self.actions[self.pos] = np.array(action).copy()
        self.rewards[self.pos] = np.array(reward).copy()
        self.episode_starts[self.pos] = np.array(episode_start).copy()
        self.values[self.pos] = value.clone().cpu().numpy()
        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True

    def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]:
        """
        Yield the stored transitions as minibatches, in a random order.

        On the first call after a :meth:`reset`, the per-step arrays are passed
        through ``swap_and_flatten`` (inherited) so that they can be indexed
        with a single flat index; later calls reuse the flattened arrays.

        :param batch_size: Number of transitions per minibatch. ``None`` yields
            everything as one batch. The last minibatch may be smaller.
        :return: generator of ``RolloutBufferSamples``
        """
        assert self.full, "Rollout buffer must be full before sampling from it"
        n_transitions = self.buffer_size * self.n_envs
        indices = np.random.permutation(n_transitions)
        # Prepare the data
        if not self.generator_ready:
            _tensor_names = [
                "observations",
                "next_observations",
                "actions",
                "values",
                "log_probs",
                "advantages",
                "returns",
            ]
            for tensor in _tensor_names:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = n_transitions

        start_idx = 0
        while start_idx < n_transitions:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(
        self,
        batch_inds: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> RolloutBufferSamples:  # type: ignore[signature-mismatch] #FIXME
        """
        Gather the transitions at ``batch_inds`` and convert them to torch tensors.

        ``values``, ``log_probs``, ``advantages`` and ``returns`` are fully
        flattened, so their channel axis is collapsed into the returned 1-D
        tensors; consumers have to restore it themselves.

        :param batch_inds: flat indices into the (already flattened) storage arrays
        :param env: unused here; kept for signature compatibility
        :return: the selected samples
        """
        # NOTE: next_observations are stored but deliberately not part of the
        # sample tuple, which is built from exactly these six fields.
        data = (
            self.observations[batch_inds],
            self.actions[batch_inds],
            self.values[batch_inds].flatten(),
            self.log_probs[batch_inds].flatten(),
            self.advantages[batch_inds].flatten(),
            self.returns[batch_inds].flatten(),
        )
        return RolloutBufferSamples(*tuple(map(self.to_torch, data)))