-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdebug_controller.py
More file actions
257 lines (219 loc) · 10.5 KB
/
Copy pathdebug_controller.py
File metadata and controls
257 lines (219 loc) · 10.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
"""
Debug script to record controller behavior and analyze why it falls
"""
import argparse
import numpy as np
import time
import json
from datetime import datetime
from residual_rl import HumanoidPushEnv, create_base_controller
# Observation layout assumed by this script: the first 24 entries are
# generalized positions (base xyz, base quaternion (w, x, y, z), 17 joint
# angles) and the next 23 are generalized velocities (base linear, base
# angular, 17 joint rates).
# NOTE(review): presumably a floating-base humanoid; confirm against
# HumanoidPushEnv's observation space.
_N_QPOS = 24
_N_QVEL = 23
_N_JOINTS = 17
_GRAVITY = 9.81
# Desired capture point. The original script treats the world origin as the
# centre of the support polygon (assumes the robot starts at the origin).
_CP_DESIRED = np.array([0.0, 0.0])
# Keys of the recorded trace, in the order they appear in the JSON output.
_RECORD_KEYS = (
    'timesteps', 'joint_positions', 'joint_velocities', 'base_position',
    'base_velocity', 'base_orientation', 'base_angular_velocity',
    'com_position', 'com_velocity', 'capture_point', 'capture_point_error',
    'torso_pitch_roll', 'actions', 'rewards', 'height', 'terminated',
    'truncated',
)


def _split_observation(obs):
    """Flatten *obs* and split it into ``(qpos, qvel)`` arrays.

    Uses the fixed 24/23 layout when the observation is long enough. Shorter
    observations are split in half, which is only a rough guess at the layout.
    """
    obs_array = np.asarray(obs).flatten()
    if len(obs_array) >= _N_QPOS + _N_QVEL:
        return obs_array[:_N_QPOS], obs_array[_N_QPOS:_N_QPOS + _N_QVEL]
    mid = len(obs_array) // 2
    return obs_array[:mid], obs_array[mid:]


def _unpack_state(qpos, qvel):
    """Slice qpos/qvel into base and joint parts, using defaults when too short.

    Returns ``(base_pos, base_quat, joint_qpos, base_lin_vel, base_ang_vel,
    joint_qvel)``. A missing quaternion defaults to identity ``(1, 0, 0, 0)``.
    Other missing pieces default to zeros.
    """
    base_pos = qpos[:3] if len(qpos) >= 3 else np.zeros(3)
    base_quat = qpos[3:7] if len(qpos) >= 7 else np.array([1.0, 0.0, 0.0, 0.0])
    joint_qpos = qpos[7:7 + _N_JOINTS] if len(qpos) > 7 else np.zeros(_N_JOINTS)
    base_lin_vel = qvel[:3] if len(qvel) >= 3 else np.zeros(3)
    base_ang_vel = qvel[3:6] if len(qvel) >= 6 else np.zeros(3)
    joint_qvel = qvel[6:6 + _N_JOINTS] if len(qvel) > 6 else np.zeros(_N_JOINTS)
    return base_pos, base_quat, joint_qpos, base_lin_vel, base_ang_vel, joint_qvel


def _capture_point(com_pos, com_vel):
    """Return ``(capture_point, error)`` from the linear-inverted-pendulum model.

    The capture point is ``x + xdot * sqrt(z / g)`` in the ground plane. The
    CoM height is clamped to 0.1 m so a fallen robot cannot give a zero or NaN
    time constant. The error is measured against ``_CP_DESIRED``.
    """
    zc = com_pos[2] if len(com_pos) > 2 else 1.4
    zc = max(zc, 0.1)
    tc = np.sqrt(zc / _GRAVITY)
    cp = com_pos[:2] + com_vel[:2] * tc
    return cp, cp - _CP_DESIRED


def _pitch_roll(quat):
    """Return ``(pitch, roll)`` in radians from a ``(w, x, y, z)`` quaternion."""
    w, x, y, z = quat
    # Clip guards arcsin against |arg| slightly > 1 from rounding.
    pitch = np.arcsin(np.clip(2 * (w * y - x * z), -1, 1))
    roll = np.arctan2(2 * (w * x + y * z), 1 - 2 * (x**2 + y**2))
    return float(pitch), float(roll)


def _estimate_total_action(env, obs, info, residual_action):
    """Re-query the base controller and return ``(action, error)``.

    ``action`` is the base command plus the scaled residual. If there is no
    controller, or the controller call fails, ``action`` is the residual
    action alone. ``error`` is the caught exception, or ``None`` on success.

    NOTE(review): the controller is queried on the *post-step* observation, so
    this is its command for the recorded state (i.e. what the next step would
    apply), not the action that produced this state. If the controller keeps
    internal state, this extra call may perturb it. The 0.5 residual scale is
    assumed to match ``HumanoidPushEnv.step``; confirm both.
    """
    controller = env.base_controller
    if controller is None:
        return residual_action, None
    try:
        base_action = np.asarray(controller(obs, info), dtype=float)
        return base_action + residual_action * 0.5, None
    except Exception as exc:  # debug aid: keep recording, but report it
        return residual_action, exc


def _save_recording(recorded_data, output_file):
    """Write the recorded trace (already plain lists/floats/bools) as JSON."""
    print(f"\nSaving {len(recorded_data['timesteps'])} timesteps to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(recorded_data, f, indent=2)
    print(f"Data saved to {output_file}")


def _print_summary(recorded_data):
    """Print aggregate statistics for a trace; does nothing if it is empty."""
    n_steps = len(recorded_data['timesteps'])
    if n_steps == 0:
        return
    print("\n" + "=" * 60)
    print("SUMMARY STATISTICS")
    print("=" * 60)
    heights = np.array(recorded_data['height'])
    forward_vels = [v[0] for v in recorded_data['base_velocity']]
    pitches = [p[0] for p in recorded_data['torso_pitch_roll']]
    cp_errors = [np.linalg.norm(e) for e in recorded_data['capture_point_error']]
    print(f"Total steps recorded: {n_steps}")
    print(f"Final height: {heights[-1]:.3f} m")
    print(f"Min height: {heights.min():.3f} m")
    print(f"Max height: {heights.max():.3f} m")
    print("\nForward velocity:")
    print(f" Mean: {np.mean(forward_vels):.4f} m/s")
    print(f" Std: {np.std(forward_vels):.4f} m/s")
    print(f" Final: {forward_vels[-1]:.4f} m/s")
    print("\nTorso pitch:")
    print(f" Mean: {np.degrees(np.mean(pitches)):.2f}°")
    print(f" Std: {np.degrees(np.std(pitches)):.2f}°")
    print(f" Final: {np.degrees(pitches[-1]):.2f}°")
    print("\nCapture point error:")
    print(f" Mean: {np.mean(cp_errors):.4f} m")
    print(f" Max: {np.max(cp_errors):.4f} m")
    print(f" Final: {cp_errors[-1]:.4f} m")
    # Analyze joint positions
    joint_pos_array = np.array(recorded_data['joint_positions'])
    print("\nJoint positions (first 4 joints - typically ankles/hips):")
    if joint_pos_array.ndim == 2:
        for i in range(min(4, joint_pos_array.shape[1])):
            joint_data = joint_pos_array[:, i]
            print(f" Joint {i}: mean={np.mean(joint_data):.4f}, "
                  f"std={np.std(joint_data):.4f}, "
                  f"final={joint_data[-1]:.4f}")


def record_controller_run(
    base_controller_type="capture_point",
    push_probability=0.0,
    push_force_range=(50, 200),
    max_steps=5000,
    speed=1.0,
    output_file=None
):
    """Run the base controller alone (zero residual) and save a per-step trace.

    Args:
        base_controller_type: Name passed to ``create_base_controller``.
            ``"capture_point"``/``"cp"`` use the factory defaults. Any other
            type is created with ``kp=100.0, kd=10.0``.
        push_probability: Forwarded to ``HumanoidPushEnv``.
        push_force_range: ``(min, max)`` push magnitude forwarded to the env.
        max_steps: Maximum number of steps to record (also used as the env's
            ``max_episode_steps``).
        speed: Real-time multiplier; each step sleeps ``env.unwrapped.dt / speed``.
        output_file: JSON output path. Defaults to
            ``controller_debug_<YYYYmmdd_HHMMSS>.json``.

    Raises:
        ValueError: If ``speed`` is not positive.

    The trace is saved and summarized even if the run is interrupted with
    Ctrl+C or an exception escapes the loop, and the environment is always
    closed.
    """
    if speed <= 0:
        raise ValueError(f"speed must be positive, got {speed}")
    if output_file is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_file = f"controller_debug_{timestamp}.json"

    print(f"Creating {base_controller_type} controller...")
    if base_controller_type.lower() in ["capture_point", "cp"]:
        base_controller = create_base_controller(base_controller_type)
    else:
        base_controller = create_base_controller(base_controller_type, kp=100.0, kd=10.0)

    print("Creating environment...")
    env = HumanoidPushEnv(
        base_controller=base_controller,
        push_probability=push_probability,
        push_force_range=push_force_range,
        max_episode_steps=max_steps,
        render_mode="human"
    )
    print(f"Recording data (will save to {output_file})...")
    print("Press Ctrl+C to stop early")

    recorded_data = {key: [] for key in _RECORD_KEYS}
    warned_controller_failure = False
    try:
        try:
            obs, info = env.reset()
            step = 0
            while step < max_steps:
                # Zero residual action: only the base controller acts.
                residual_action = np.zeros(env.action_space.shape[0])
                obs, reward, terminated, truncated, info = env.step(residual_action)

                qpos, qvel = _split_observation(obs)
                (base_pos, base_quat, joint_qpos,
                 base_lin_vel, base_ang_vel, joint_qvel) = _unpack_state(qpos, qvel)
                # The floating-base state stands in for the CoM here.
                com_pos = base_pos.copy()
                com_vel = base_lin_vel.copy()
                cp, cp_error = _capture_point(com_pos, com_vel)
                pitch, roll = _pitch_roll(base_quat)

                action, controller_error = _estimate_total_action(
                    env, obs, info, residual_action)
                if controller_error is not None and not warned_controller_failure:
                    print(f"Warning: base controller query failed ({controller_error!r}); "
                          "recording the residual action only")
                    warned_controller_failure = True

                recorded_data['timesteps'].append(step)
                recorded_data['joint_positions'].append(joint_qpos.tolist())
                recorded_data['joint_velocities'].append(joint_qvel.tolist())
                recorded_data['base_position'].append(base_pos.tolist())
                recorded_data['base_velocity'].append(base_lin_vel.tolist())
                recorded_data['base_orientation'].append(base_quat.tolist())
                recorded_data['base_angular_velocity'].append(base_ang_vel.tolist())
                recorded_data['com_position'].append(com_pos.tolist())
                recorded_data['com_velocity'].append(com_vel.tolist())
                recorded_data['capture_point'].append(cp.tolist())
                recorded_data['capture_point_error'].append(cp_error.tolist())
                recorded_data['torso_pitch_roll'].append([pitch, roll])
                recorded_data['actions'].append(np.asarray(action).tolist())
                recorded_data['rewards'].append(float(reward))
                recorded_data['height'].append(float(base_pos[2]))
                recorded_data['terminated'].append(bool(terminated))
                recorded_data['truncated'].append(bool(truncated))
                step += 1

                if step % 100 == 0:
                    print(f"Step {step}: Height={base_pos[2]:.3f}, "
                          f"Forward Vel={base_lin_vel[0]:.3f}, "
                          f"Pitch={np.degrees(pitch):.1f}°, "
                          f"CP Error={np.linalg.norm(cp_error):.3f}")
                if terminated or truncated:
                    print(f"\nEpisode ended at step {step}")
                    print(f" Reason: {'Terminated' if terminated else 'Truncated'}")
                    print(f" Final height: {base_pos[2]:.3f}")
                    print(f" Final forward velocity: {base_lin_vel[0]:.3f}")
                    print(f" Final pitch: {np.degrees(pitch):.1f}°")
                    break
                # Throttle to (a multiple of) real time for the viewer.
                time.sleep(env.unwrapped.dt / speed)
        except KeyboardInterrupt:
            print("\nRecording stopped by user")
        finally:
            _save_recording(recorded_data, output_file)
            _print_summary(recorded_data)
    finally:
        # Always release the environment/viewer, even if saving failed.
        env.close()
def main():
    """Parse command-line options and start a recording run.

    Every option maps one-to-one onto a ``record_controller_run`` keyword,
    except the two push-force bounds, which are combined into a
    ``(min, max)`` tuple.
    """
    # (flag, add_argument kwargs), in the order they appear in --help.
    option_specs = [
        ("--base_controller", dict(type=str, default="capture_point",
                                   choices=["pd", "lqr", "capture_point", "cp"],
                                   help="Type of base controller")),
        ("--push_probability", dict(type=float, default=0.0,
                                    help="Probability of push disturbance per step")),
        ("--push_force_min", dict(type=float, default=50.0,
                                  help="Minimum push force magnitude")),
        ("--push_force_max", dict(type=float, default=200.0,
                                  help="Maximum push force magnitude")),
        ("--max_steps", dict(type=int, default=5000,
                             help="Maximum steps to record")),
        ("--speed", dict(type=float, default=1.0,
                         help="Simulation speed multiplier")),
        ("--output", dict(type=str, default=None,
                          help="Output file path (default: auto-generated)")),
    ]
    cli = argparse.ArgumentParser(description="Record controller behavior for debugging")
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    force_bounds = (opts.push_force_min, opts.push_force_max)
    record_controller_run(
        base_controller_type=opts.base_controller,
        push_probability=opts.push_probability,
        push_force_range=force_bounds,
        max_steps=opts.max_steps,
        speed=opts.speed,
        output_file=opts.output
    )
if __name__ == "__main__":
    # Run the CLI only when executed directly, not when imported.
    main()