diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 4d38174c2c8..954f205e7de 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -159,7 +159,7 @@ def __init__( assert self.split_P_arrive % 32 == 0 assert self.split_P_arrive < self.n_block_size self.arch = BaseDSL._get_dsl().get_arch_enum() - assert self.arch >= Arch.sm_100 and self.arch <= Arch.sm_110f, "Only SM 10.x and 11.x are supported" + assert self.arch.value[0] in [10, 11], "Only SM 10.x and 11.x are supported" self.cta_group_size = 2 if self.use_2cta_instrs else 1 # cta_tiler M includes only 1 CTA, the scheduler will take into account the cluster shape