Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
408 changes: 173 additions & 235 deletions genesis/engine/solvers/rigid/collider/box_contact.py

Large diffs are not rendered by default.

10 changes: 2 additions & 8 deletions genesis/engine/solvers/rigid/collider/broadphase.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,12 @@
import genesis as gs
import genesis.utils.array_class as array_class

from .utils import (
func_is_geom_aabbs_overlap,
)
from .utils import func_is_geom_aabbs_overlap


@qd.func
def func_find_intersect_midpoint(
i_ga,
i_gb,
i_b,
geoms_state: array_class.GeomsState,
geoms_info: array_class.GeomsInfo,
i_ga, i_gb, i_b, geoms_state: array_class.GeomsState, geoms_info: array_class.GeomsInfo
):
# return the center of the intersecting AABB of AABBs of two geoms
intersect_lower = qd.max(geoms_state.aabb_min[i_ga, i_b], geoms_state.aabb_min[i_gb, i_b])
Expand Down
42 changes: 12 additions & 30 deletions genesis/engine/solvers/rigid/collider/capsule_contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,7 @@


@qd.func
def func_capsule_capsule_contact(
i_ga,
i_gb,
ga_pos,
ga_quat,
gb_pos,
gb_quat,
geoms_info: array_class.GeomsInfo,
rigid_global_info: array_class.RigidGlobalInfo,
):
def func_capsule_capsule_contact(i_ga, i_gb, ga_pos, ga_quat, gb_pos, gb_quat, gst: array_class.GlobalState):
"""
Analytical capsule-capsule collision detection.

Expand All @@ -31,19 +22,19 @@ def func_capsule_capsule_contact(
ga_pos, ga_quat : Position and orientation of capsule A (may be perturbed for multi-contact).
gb_pos, gb_quat : Position and orientation of capsule B (may be perturbed for multi-contact).
"""
EPS = rigid_global_info.EPS[None]
EPS = gst.rigid_global_info.EPS[None]

# Get capsule A parameters
pos_a = ga_pos
quat_a = ga_quat
radius_a = geoms_info.data[i_ga][0]
halflength_a = 0.5 * geoms_info.data[i_ga][1]
radius_a = gst.geoms_info.data[i_ga][0]
halflength_a = 0.5 * gst.geoms_info.data[i_ga][1]

# Get capsule B parameters
pos_b = gb_pos
quat_b = gb_quat
radius_b = geoms_info.data[i_gb][0]
halflength_b = 0.5 * geoms_info.data[i_gb][1]
radius_b = gst.geoms_info.data[i_gb][0]
halflength_b = 0.5 * gst.geoms_info.data[i_gb][1]

# Capsules are aligned along local Z-axis by convention
local_z_unit = qd.Vector([0.0, 0.0, 1.0], dt=gs.qd_float)
Expand Down Expand Up @@ -99,16 +90,7 @@ def func_capsule_capsule_contact(


@qd.func
def func_sphere_capsule_contact(
i_ga,
i_gb,
ga_pos,
ga_quat,
gb_pos,
gb_quat,
geoms_info: array_class.GeomsInfo,
rigid_global_info: array_class.RigidGlobalInfo,
):
def func_sphere_capsule_contact(i_ga, i_gb, ga_pos, ga_quat, gb_pos, gb_quat, gst: array_class.GlobalState):
"""
Analytical sphere-capsule collision detection.

Expand All @@ -122,25 +104,25 @@ def func_sphere_capsule_contact(
ga_pos, ga_quat : Position and orientation of geom A (may be perturbed for multi-contact).
gb_pos, gb_quat : Position and orientation of geom B (may be perturbed for multi-contact).
"""
EPS = rigid_global_info.EPS[None]
EPS = gst.rigid_global_info.EPS[None]

# Ensure sphere is always i_ga and capsule is i_gb
normal_dir = 1
sphere_center = ga_pos
capsule_center = gb_pos
capsule_q = gb_quat
if geoms_info.type[i_gb] == gs.GEOM_TYPE.SPHERE:
if gst.geoms_info.type[i_gb] == gs.GEOM_TYPE.SPHERE:
i_ga, i_gb = i_gb, i_ga
sphere_center = gb_pos
capsule_center = ga_pos
capsule_q = ga_quat
normal_dir = -1

sphere_radius = geoms_info.data[i_ga][0]
sphere_radius = gst.geoms_info.data[i_ga][0]

capsule_quat = capsule_q
capsule_radius = geoms_info.data[i_gb][0]
capsule_halflength = 0.5 * geoms_info.data[i_gb][1]
capsule_radius = gst.geoms_info.data[i_gb][0]
capsule_halflength = 0.5 * gst.geoms_info.data[i_gb][1]

# Capsule is aligned along local Z-axis
local_z_unit = qd.Vector([0.0, 0.0, 1.0], dt=gs.qd_float)
Expand Down
178 changes: 77 additions & 101 deletions genesis/engine/solvers/rigid/collider/collider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
terrain), and contact management.
"""

import functools
import math
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -551,8 +552,7 @@ def _init_terrain_state(self):
rc = np.array(entity.terrain_hf.shape, dtype=gs.np_int)
hf = entity.terrain_hf.astype(gs.np_float) * scale[1]
xyz_maxmin = np.array(
[rc[0] * scale[0], rc[1] * scale[0], hf.max(), 0, 0, hf.min() - 1.0],
dtype=gs.np_float,
[rc[0] * scale[0], rc[1] * scale[0], hf.max(), 0, 0, hf.min() - 1.0], dtype=gs.np_float
)

self._collider_info.terrain_hf.from_numpy(hf)
Expand Down Expand Up @@ -586,7 +586,12 @@ def reset(self, envs_idx=None, *, cache_only: bool = True) -> None:
return

envs_idx = self._solver._scene._sanitize_envs_idx(envs_idx)
collider_kernel_reset(envs_idx, self._solver._static_rigid_sim_config, self._collider_state, cache_only)
collider_kernel_reset(
envs_idx,
self._solver._static_rigid_sim_config,
self._gst,
cache_only,
)

def clear(self, envs_idx=None):
self.reset(envs_idx, cache_only=False)
Expand Down Expand Up @@ -648,44 +653,64 @@ def clear(self, envs_idx=None):
fn = kernel_collider_clear
fn(
envs_idx,
self._solver.links_state,
self._solver.links_info,
self._gst,
self._solver._static_rigid_sim_config,
self._collider_state,
)

@functools.cached_property
def _gst(self) -> array_class.GlobalState:
"""Lazily-built GlobalState bundling every dataclass narrowphase wants.

The fields are stable references for the lifetime of the Collider, so we build
``GlobalState`` exactly once -- on first access -- and reuse the instance at every call
site. We can't build it in ``__init__`` because ``self._solver.constraint_solver`` is
created later in the solver's bring-up; deferring to first use sidesteps that ordering
constraint. Algorithm-local scratch buffers (MPRState / GJKState) are not part of
``GlobalState`` and are passed explicitly to the kernels that need them.
"""
s = self._solver
return array_class.GlobalState(
rigid_global_info=s._rigid_global_info,
constraint_state=s.constraint_solver.constraint_state,
collider_state=self._collider_state,
collider_info=self._collider_info,
mpr_info=self._mpr._mpr_info,
gjk_info=self._gjk._gjk_info,
sdf_info=self._sdf._sdf_info,
support_field_info=self._support_field._support_field_info,
dofs_state=s.dofs_state,
dofs_info=s.dofs_info,
links_state=s.links_state,
links_info=s.links_info,
joints_state=s.joints_state,
joints_info=s.joints_info,
geoms_state=s.geoms_state,
geoms_info=s.geoms_info,
verts_info=s.verts_info,
faces_info=s.faces_info,
edges_info=s.edges_info,
equalities_info=s.equalities_info,
entities_state=s.entities_state,
entities_info=s.entities_info,
errno=s._errno,
)

def _call_multicontact(self):
narrowphase._func_narrowphase_multicontact_mixed(
self._solver.links_state,
self._solver.links_info,
self._solver.geoms_state,
self._solver.geoms_info,
self._solver.geoms_init_AABB,
self._solver.verts_info,
self._solver.faces_info,
self._solver._rigid_global_info,
self._solver._static_rigid_sim_config,
self._collider_state,
self._collider_info,
self._collider_static_config,
self._gst,
self._multicontact_mpr_state,
self._mpr._mpr_info,
self._multicontact_gjk_state,
self._gjk._gjk_info,
self._solver._static_rigid_sim_config,
self._collider_static_config,
self._gjk._gjk_static_config,
self._support_field._support_field_info,
self._multicontact_gjk_state.diff_contact_input,
self._solver._errno,
self._multicontact_n_gjk_threads,
self._multicontact_n_total_threads,
self._multicontact_max_items_per_thread,
)

def detection(self) -> None:
rigid_solver.kernel_update_geom_aabbs(
self._solver.geoms_state,
self._solver.geoms_init_AABB,
self._solver._static_rigid_sim_config,
self._solver.geoms_state, self._solver.geoms_init_AABB, self._solver._static_rigid_sim_config
)

if self._n_possible_pairs == 0:
Expand All @@ -706,105 +731,51 @@ def detection(self) -> None:
self._solver._errno,
)
if self._use_split_narrowphase:
narrowphase._func_reset_narrowphase_work_queues(
self._collider_state,
)
narrowphase._func_reset_narrowphase_work_queues(self._gst)
narrowphase._func_narrowphase_contact0(
self._solver.geoms_state,
self._solver.geoms_info,
self._solver.geoms_init_AABB,
self._solver.verts_info,
self._solver._rigid_global_info,
self._solver._static_rigid_sim_config,
self._collider_state,
self._collider_info,
self._collider_static_config,
self._gst,
self._contact0_mpr_state,
self._mpr._mpr_info,
self._contact0_gjk_state,
self._gjk._gjk_info,
self._support_field._support_field_info,
self._solver._errno,
self._solver._static_rigid_sim_config,
self._collider_static_config,
self._solver._B,
self._contact0_n_chunks,
)
self._call_multicontact()
narrowphase._func_prepare_gjk_rerun(self._collider_state)
narrowphase._func_prepare_gjk_rerun(self._gst)
self._call_multicontact()
elif self._collider_static_config.has_non_box_plane_convex_convex:
narrowphase.func_narrow_phase_convex_vs_convex(
self._solver.links_state,
self._solver.links_info,
self._solver.geoms_state,
self._solver.geoms_info,
self._solver.geoms_init_AABB,
self._solver.verts_info,
self._solver.faces_info,
self._solver.edges_info,
self._solver._rigid_global_info,
self._solver._static_rigid_sim_config,
self._collider_state,
self._collider_info,
self._collider_static_config,
self._gst,
self._mpr._mpr_state,
self._mpr._mpr_info,
self._gjk._gjk_state,
self._gjk._gjk_info,
self._solver._static_rigid_sim_config,
self._collider_static_config,
self._gjk._gjk_static_config,
self._sdf._sdf_info,
self._support_field._support_field_info,
self._gjk._gjk_state.diff_contact_input,
self._solver._errno,
)
if self._collider_static_config.has_convex_specialization:
func_narrow_phase_convex_specializations(
self._solver.geoms_state,
self._solver.geoms_info,
self._solver.geoms_init_AABB,
self._solver.verts_info,
self._solver._rigid_global_info,
self._gst,
self._solver._static_rigid_sim_config,
self._collider_state,
self._collider_info,
self._collider_static_config,
self._solver._errno,
)
if self._collider_static_config.has_terrain:
func_narrow_phase_any_vs_terrain(
self._solver.geoms_state,
self._solver.geoms_info,
self._solver.geoms_init_AABB,
self._gst,
self._mpr._mpr_state,
self._solver._static_rigid_sim_config,
self._collider_state,
self._collider_info,
self._collider_static_config,
self._mpr._mpr_state,
self._mpr._mpr_info,
self._support_field._support_field_info,
self._solver._errno,
)
if self._collider_static_config.has_nonconvex_nonterrain:
func_narrow_phase_nonconvex_vs_nonterrain(
self._solver.links_state,
self._solver.links_info,
self._solver.geoms_state,
self._solver.geoms_info,
self._solver.geoms_init_AABB,
self._solver.verts_info,
self._solver.edges_info,
self._solver._rigid_global_info,
self._gst,
self._solver._static_rigid_sim_config,
self._collider_state,
self._collider_info,
self._collider_static_config,
self._sdf._sdf_info,
self._solver._errno,
)

if self._use_split_narrowphase:
func_clamp_and_sort_contacts(
self._collider_state,
self._collider_info,
self._gst,
self._solver._static_rigid_sim_config,
)

Expand Down Expand Up @@ -862,7 +833,11 @@ def get_contacts(self, as_tensor: bool = True, to_torch: bool = True, keep_batch
# Copy contact data
if n_contacts_max > 0:
collider_kernel_get_contacts(
as_tensor, iout, fout, self._solver._static_rigid_sim_config, self._collider_state
as_tensor,
iout,
fout,
self._solver._static_rigid_sim_config,
self._gst,
)

# Build structured view (no copy)
Expand Down Expand Up @@ -914,17 +889,18 @@ def get_contacts(self, as_tensor: bool = True, to_torch: bool = True, keep_batch
return contact_data.copy()

def backward(self, dL_dposition, dL_dnormal, dL_dpenetration):
func_set_upstream_grad(dL_dposition, dL_dnormal, dL_dpenetration, self._collider_state)
func_set_upstream_grad(
dL_dposition,
dL_dnormal,
dL_dpenetration,
self._gst,
)

# Compute gradient
func_narrow_phase_diff_convex_vs_convex.grad(
self._solver.geoms_state,
self._solver.geoms_info,
self._gst,
self._gjk._gjk_state,
self._solver._static_rigid_sim_config,
self._collider_state,
self._collider_info,
self._gjk._gjk_info,
self._collider_state.diff_contact_input,
)


Expand Down
Loading
Loading