From 1e37a7d7b83b1794dd2fcdf4b4b4b3724455706c Mon Sep 17 00:00:00 2001 From: piercehowell Date: Wed, 22 May 2024 10:55:11 -0400 Subject: [PATCH 1/2] fixed partial observation sensing dynamics. --- vmas/scenarios/transport.py | 41 ++++++++++++++++++++++++++----------- 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/vmas/scenarios/transport.py b/vmas/scenarios/transport.py index a264a744..3caa3c4c 100644 --- a/vmas/scenarios/transport.py +++ b/vmas/scenarios/transport.py @@ -81,8 +81,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): self.package_mass = kwargs.get("package_mass", 3) # partial obs - self.partial_observations = kwargs.get("partial_observations", False) - self.package_observation_radius = kwargs.get("package_observation_radius", 0.35) + self.partial_observations = kwargs.get("partial_observations", True) + selfpackage_observation_dist = kwargs.get("package_observation_radius", 0.35) # realism self.linear_friction = kwargs.get("linear_friction", 0.01) @@ -131,6 +131,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): # Add agents capabilities = [] # save capabilities for relative capabilities later + self.observation_sensors = [] # for partial observability for i in range(n_agents): max_linear_vel = self.default_agent_max_linear_vel * random.uniform(self.capability_mult_min, self.capability_mult_max) max_angular_vel = self.default_agent_max_angular_vel * random.uniform(self.capability_mult_min, self.capability_mult_max) @@ -152,6 +153,21 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): world.add_agent(agent) + # add the observation sensor if partial observability is turned on. + if self.partial_observations: + self.observation_sensors.append( + Landmark( + name=f'obs_sensor_agent_{i}', + collide=False, + + shape=Sphere(radius=selfpackage_observation_dist+radius), + color=(0.827, 0.827, 0.827, 0.65), + movable=False, + ) + ) + world.add_landmark(self.observation_sensors[-1]) + + self.capabilities = torch.tensor(capabilities) # Add landmarks @@ -256,6 +272,11 @@ def reset_world_at(self, env_index: int = None): ), occupied_positions=package_occupied_pos, ) + + # spawn the sensor radius for each agent + if self.partial_observations: + for i, agent_i_sensor in enumerate(self.observation_sensors): + agent_i_sensor.set_pos(self.world.agents[i].state.pos, env_index) self.package_starting_dists = [] self.og_package_positions = [] @@ -444,12 +465,17 @@ def partial_observation(self, agent: Agent): # get positions of all entities in this agent's reference frame package_obs = [] out_of_obs_val = -0.0001 # default value used for out-of-observation data in the observation vector + + # spawn the sensor radius for each agent + for i, agent_i_sensor in enumerate(self.observation_sensors): + agent_i_sensor.set_pos(self.world.agents[i].state.pos, None) + for i, package in enumerate(self.packages): # box starting position and goal position alway part of the observation package_obs.append(self.og_package_positions[i]) package_obs.append(package.on_goal.unsqueeze(-1)) - mask = (torch.linalg.vector_norm(package.state.pos - agent.state.pos, dim=-1) < self.package_observation_radius) + mask = self.world.is_overlapping(self.observation_sensors[i], package) pkg_state_vec = package.state.pos.clone() pkg_rot_vec = package.state.rot.clone() pkg_vel_vec = package.state.vel.clone() @@ -606,15 +632,6 @@ def extra_render(self, env_index: int = 0) -> "List[Geom]": geoms: List[Geom] = [] if not self.partial_observations: return geoms - - for i, agent in enumerate(self.world.agents): - - obs_circle = rendering.make_circle(self.package_observation_radius, filled=True) - xform = rendering.Transform() - xform.set_translation(*agent.state.pos[env_index]) - obs_circle.add_attr(xform) - obs_circle.set_color(*(0.827, 0.827, 0.827, 0.65)) - geoms.append(obs_circle) return geoms From 945730ff8fb48678891c93877fe2e2ccb1f2f170 Mon Sep 17 00:00:00 2001 From: piercehowell Date: Wed, 22 May 2024 11:29:51 -0400 Subject: [PATCH 2/2] fixed bug where the sensor shape wasn't being reset. --- vmas/scenarios/transport.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/vmas/scenarios/transport.py b/vmas/scenarios/transport.py index 3caa3c4c..b781f784 100644 --- a/vmas/scenarios/transport.py +++ b/vmas/scenarios/transport.py @@ -82,7 +82,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): # partial obs self.partial_observations = kwargs.get("partial_observations", True) - selfpackage_observation_dist = kwargs.get("package_observation_radius", 0.35) + self.package_observation_dist = kwargs.get("package_observation_dist", 0.35) # realism self.linear_friction = kwargs.get("linear_friction", 0.01) @@ -159,8 +159,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): Landmark( name=f'obs_sensor_agent_{i}', collide=False, - - shape=Sphere(radius=selfpackage_observation_dist+radius), + shape=Sphere(radius=self.package_observation_dist+radius), color=(0.827, 0.827, 0.827, 0.65), movable=False, ) @@ -207,7 +206,7 @@ def reset_world_at(self, env_index: int = None): # only do this during batched resets! if not env_index: capabilities = [] # save capabilities for relative capabilities later - for agent in self.world.agents: + for i, agent in enumerate(self.world.agents): max_linear_vel = self.default_agent_max_linear_vel * random.uniform(self.capability_mult_min, self.capability_mult_max) max_angular_vel = self.default_agent_max_angular_vel * random.uniform(self.capability_mult_min, self.capability_mult_max) radius = self.default_agent_radius * random.uniform(self.capability_mult_min, self.capability_mult_max) @@ -219,6 +218,12 @@ def reset_world_at(self, env_index: int = None): agent.shape=Sphere(radius) agent.mass=mass + # spawn the sensor radius for each agent + if self.partial_observations: + self.observation_sensors[i].set_pos(self.world.agents[i].state.pos, env_index) + self.observation_sensors[i].shape = Sphere(self.package_observation_dist+radius) + + self.capabilities = torch.tensor(capabilities) # spawn goal at origin @@ -273,11 +278,6 @@ def reset_world_at(self, env_index: int = None): occupied_positions=package_occupied_pos, ) - # spawn the sensor radius for each agent - if self.partial_observations: - for i, agent_i_sensor in enumerate(self.observation_sensors): - agent_i_sensor.set_pos(self.world.agents[i].state.pos, env_index) - self.package_starting_dists = [] self.og_package_positions = [] for i, package in enumerate(self.packages):