Skip to content
69 changes: 43 additions & 26 deletions genesis/engine/solvers/rigid/collider/collider.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@
func_contact_orthogonals,
func_rotate_frame,
func_set_upstream_grad,
func_clamp_and_sort_contacts,
func_prune_contacts,
func_clamp_prune_and_sort_contacts,
func_prune_contacts_coop,
)
from . import narrowphase
from .narrowphase import (
Expand Down Expand Up @@ -221,27 +221,30 @@ def _init_collision_fields(self) -> None:
# 'contact_data_cache' is not used in Quadrants kernels, so keep it outside of the collider state / info
self._contact_data_cache: dict[tuple[bool, bool], dict[str, torch.Tensor | tuple[torch.Tensor]]] = {}

# Contact0 & multicontact scratch states only needed when split narrowphase is active.
# GPU core count (used by split-narrowphase chunking + the cooperative dedup dispatch gate).
# FIXME: Quadrants should expose a unified API to query GPU core count across all backends.
# Falling back to upper bound for backends where torch.cuda is unavailable (e.g., CPU-only torch). Benchmarks
# on RTX 6000 Blackwell (Genesis-Embodied-AI/Genesis#2616) showed that switching from hardcoded 40000 threads
# to hardware-derived 21760 had marginal performance impact, so it should be fine.
if torch.cuda.is_available():
gpu_props = torch.cuda.get_device_properties(torch.cuda.current_device())
# NVIDIA: 128 CUDA cores per SM. AMD/ROCm: 64 stream processors per CU.
cores_per_unit = 64 if torch.version.hip else 128
gpu_cores = gpu_props.multi_processor_count * cores_per_unit
elif gs.backend == gs.metal:
# Upper-bound estimate for Apple Silicon: 40 GPU cores, each GPU core having 128 ALUs
cores_per_unit = 128
gpu_cores = 5120
else:
# Using AMD GPU as a baseline. AMD MI350X has 256 SM (so-called Compute Units) with 64 cores each.
# See: https://www.amd.com/en/products/accelerators/instinct/mi350/mi350x.html
# For comparison, RTX6000 Blackwell boasts 188 SMs, compared to 170 SMs for RTX5090 with 128 cores each.
cores_per_unit = 64
gpu_cores = 16384
self._gpu_cores = gpu_cores

# Contact0 & multicontact scratch states only needed when split narrowphase is active.
if self._use_split_narrowphase:
if torch.cuda.is_available():
gpu_props = torch.cuda.get_device_properties(torch.cuda.current_device())
# NVIDIA: 128 CUDA cores per SM. AMD/ROCm: 64 stream processors per CU.
cores_per_unit = 64 if torch.version.hip else 128
gpu_cores = gpu_props.multi_processor_count * cores_per_unit
elif gs.backend == gs.metal:
# Upper-bound estimate for Apple Silicon: 40 GPU cores, each GPU core having 128 ALUs
cores_per_unit = 128
gpu_cores = 5120
else:
# Using AMD GPU as a baseline. AMD MI350X has 256 SM (so-called Compute Units) with 64 cores each.
# See: https://www.amd.com/en/products/accelerators/instinct/mi350/mi350x.html
# For comparison, RTX6000 Blackwell boasts 188 SMs, compared to 170 SMs for RTX5090 with 128 cores each.
cores_per_unit = 64
gpu_cores = 16384
self._contact0_n_chunks = max(1, math.ceil(gpu_cores / self._solver._B))
self._contact0_grid_size = self._solver._B * self._contact0_n_chunks
self._contact0_mpr_state = array_class.get_mpr_state(self._contact0_grid_size)
Expand Down Expand Up @@ -819,23 +822,37 @@ def detection(self) -> None:
self._solver._errno,
)

if (
self._collider_static_config.link_pair_pruning_supported
and self._solver._options.contact_pruning_tolerance is not None
# On GPU backends, when the scene is dedup-eligible and we're not in autodiff mode, dispatch the cooperative
# warp-per-env kernel (32 lanes/block; parallel reductions + lex-stride writes; serial sorts + hull build on
# lane 0; fused compact+spatial-sort in the final phase). This beats the serial fused kernel on dex_hand /
# g1_fall by spreading the per-env work across the warp instead of running one env per block thread, but only
# while the GPU has spare occupancy. Once n_envs exceeds half the GPU's CUDA core count, the coop launch fills
# the device and the warp-cooperation overhead stops paying off — the serial fused kernel (one thread per env)
# wins. The half-core threshold leaves room for the existing kernels already sharing the SMs (narrowphase,
# constraint solver) so we don't push them to second-wave scheduling. Everything else (CPU, autodiff, scenes
# where link_pair_pruning_supported=False, oversubscribed n_envs) falls through to the serial fused kernel,
# which has internal qd.static gates that drop the prune phases when they're not eligible.
ran_fused_dedup_coop = (
gs.backend != gs.cpu
and self._collider_static_config.link_pair_pruning_supported
and not self._solver._static_rigid_sim_config.requires_grad
):
func_prune_contacts(
and (self._solver._options.contact_pruning_tolerance or 0.0) > 0.0
and self._solver._B * 2 <= self._gpu_cores
)
if ran_fused_dedup_coop:
func_prune_contacts_coop(
self._collider_state,
self._collider_info,
self._solver._rigid_global_info,
self._solver._static_rigid_sim_config,
)

if self._use_split_narrowphase:
func_clamp_and_sort_contacts(
else:
func_clamp_prune_and_sort_contacts(
self._collider_state,
self._collider_info,
self._solver._rigid_global_info,
self._solver._static_rigid_sim_config,
self._collider_static_config,
)

def get_contacts(self, as_tensor: bool = True, to_torch: bool = True, keep_batch_dim: bool = False):
Expand Down
Loading
Loading