Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion qa/L0_pytorch_debug_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml
pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_perf.py"

# standard sanity and numerics tests with initialized debug
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "debug test_sanity.py"
NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "debug test_sanity.py"
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "debug test_numerics.py"

if [ "$RET" -ne 0 ]; then
Expand Down
6 changes: 3 additions & 3 deletions qa/L0_pytorch_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ mkdir -p "$XML_LOG_DIR"

pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"

python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
Expand All @@ -37,11 +37,11 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_grouped_tensor.xml $TE_PATH/tests/pytorch/test_grouped_tensor.py || test_fail "test_grouped_tensor.py"
NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_grouped_tensor.xml $TE_PATH/tests/pytorch/test_grouped_tensor.py || test_fail "test_grouped_tensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_backward_override.xml $TE_PATH/tests/pytorch/test_backward_override.py || test_fail "test_backward_override.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
Expand Down
69 changes: 53 additions & 16 deletions tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3084,7 +3084,10 @@ def _make_zero_tokens_grouped_tensor(logical_last_dim, is_a):
weight_tensors = [torch.randn(n, k, dtype=dtype, device=device) for _ in range(z)]
if use_mxfp8:
grouped_A = _make_grouped_tensor_quantized_mxfp8(
weight_tensors, is_a=True, transposed=transa, device=device
weight_tensors,
rowwise=transa,
columnwise=not transa,
device=device,
)
else:
grouped_A = _make_grouped_tensor_uniform(z, n, k, device, dtype)
Expand Down Expand Up @@ -3138,36 +3141,61 @@ def _make_zero_tokens_grouped_tensor(logical_last_dim, is_a):
def _make_grouped_tensor_quantized_mxfp8(
tensors: List[torch.Tensor],
*,
is_a: bool,
transposed: bool,
rowwise: bool,
columnwise: bool,
device: torch.device,
optimize_for_gemm: bool = True,
is_weight: bool = False,
) -> GroupedTensor:
"""Create a quantized MXFP8 GroupedTensor from a list of per-expert tensors.

For weights (uniform per-expert shape), we generally won't keep it swizzled since we
might need for future dequantize operations. Swizzling is done internally within
general_grouped_gemm_for_grouped_tensor call.

For non-weight tensors (inputs / grad_outputs), we still pass
``first_dims`` and keep ``optimize_for_gemm=True``; so the kernel must emit the
already-swizzled layout up front.
"""
if not tensors:
raise ValueError("Expected non-empty tensor list for grouped quantization.")
if is_a:
rowwise = transposed
columnwise = not transposed
else:
rowwise = not transposed
columnwise = transposed
quantizer = MXFP8Quantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
rowwise=rowwise,
columnwise=columnwise,
)
quantizer.optimize_for_gemm = optimize_for_gemm
quantizer.optimize_for_gemm = not is_weight
grouped_input = torch.cat(tensors, dim=0)
first_dims = torch.tensor([t.shape[0] for t in tensors], dtype=torch.int64, device=device)
if is_weight:
first_dims = None
else:
first_dims = torch.tensor([t.shape[0] for t in tensors], dtype=torch.int64, device=device)
return tex.group_quantize(grouped_input, quantizer, len(tensors), first_dims)


def _per_tensor_quantize_mxfp8(
tensors: List[torch.Tensor],
*,
rowwise: bool,
columnwise: bool,
) -> List:
"""Quantize each tensor individually with MXFP8.
Used to build reference discrete inputs for grouped GEMM.
"""
quantizer = MXFP8Quantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
rowwise=rowwise,
columnwise=columnwise,
)
return [quantizer(t) for t in tensors]


@pytest.mark.parametrize(
"shape",
[
(1, 128, 128, 512),
(8, 1024, 128, 512),
(16, 4096, 128, 512),
(2, 256, 2880, 2880),
],
)
@pytest.mark.parametrize("accumulate", [False, True])
Expand Down Expand Up @@ -3208,12 +3236,21 @@ def test_grouped_gemm_grouped_tensor_mxfp8(

transa = layout[0] == "T"
transb = layout[1] == "T"
grouped_A = _make_grouped_tensor_quantized_mxfp8(A, is_a=True, transposed=transa, device="cuda")
a_is_weight = all(t.shape == A[0].shape for t in A)
a_rowwise, a_columnwise = transa, not transa
b_rowwise, b_columnwise = not transb, transb
grouped_A = _make_grouped_tensor_quantized_mxfp8(
A,
rowwise=a_rowwise,
columnwise=a_columnwise,
device="cuda",
is_weight=a_is_weight,
)
grouped_B = _make_grouped_tensor_quantized_mxfp8(
B, is_a=False, transposed=transb, device="cuda"
B, rowwise=b_rowwise, columnwise=b_columnwise, device="cuda"
)
A_fp8 = grouped_A.split_into_quantized_tensors()
B_fp8 = grouped_B.split_into_quantized_tensors()
A_fp8 = _per_tensor_quantize_mxfp8(A, rowwise=a_rowwise, columnwise=a_columnwise)
B_fp8 = _per_tensor_quantize_mxfp8(B, rowwise=b_rowwise, columnwise=b_columnwise)

general_grouped_gemm(
A_fp8,
Expand Down
7 changes: 5 additions & 2 deletions transformer_engine/common/cast/mxfp8/swizzle.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ namespace dispatch {
namespace mxfp8 {
namespace swizzle {

constexpr size_t GEMM_SWIZZLED_SCALE_TILE_DIM_X = 4;
constexpr size_t GEMM_SWIZZLED_SCALE_TILE_DIM_Y = 128;

/*! \brief Convert compact scale indices into GEMM swizzled scale index
*
* MXFP8 GEMM expects scaling factors to be in a "swizzled" order
Expand All @@ -25,8 +28,8 @@ namespace swizzle {
*
*/
__device__ __forceinline__ size_t gemm_swizzled_scale_idx(size_t i, size_t j, size_t num_tiles_X) {
constexpr size_t TILE_DIM_X = 4; // Tile dim in scale buffer
constexpr size_t TILE_DIM_Y = 128;
constexpr size_t TILE_DIM_X = GEMM_SWIZZLED_SCALE_TILE_DIM_X;
constexpr size_t TILE_DIM_Y = GEMM_SWIZZLED_SCALE_TILE_DIM_Y;
constexpr size_t TILE_SIZE = TILE_DIM_X * TILE_DIM_Y;
const size_t tile_idx_X = j / TILE_DIM_X;
const size_t tile_idx_Y = i / TILE_DIM_Y;
Expand Down
65 changes: 59 additions & 6 deletions transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <type_traits>
#include <vector>

#include "../cast/mxfp8/swizzle.cuh"
#include "../common.h"
#include "../util/cuda_runtime.h"
#include "../util/handle_manager.h"
Expand Down Expand Up @@ -330,6 +331,7 @@ struct GroupedOperandSelection {
NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
bool with_gemm_swizzled_scales = false;
bool trans = false;
bool rowwise = true;
};

constexpr int kMaxGroups = 64;
Expand Down Expand Up @@ -613,6 +615,7 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine::
sel.dptr = static_cast<char *>(t->columnwise_data.dptr);
sel.scale_inv = t->columnwise_scale_inv.dptr;
sel.dtype = col_dtype;
sel.rowwise = false;
sel.shape = create_shape_info(t, swap_dims);
};

Expand All @@ -621,6 +624,7 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine::
sel.dptr = static_cast<char *>(t->data.dptr);
sel.scale_inv = t->scale_inv.dptr;
sel.dtype = row_dtype;
sel.rowwise = true;
sel.shape = create_shape_info(t, /*swap_dims=*/false);
};

Expand Down Expand Up @@ -846,6 +850,45 @@ __forceinline__ __device__ int64_t compute_grouped_tensor_offset(const TensorSha
}
}

__forceinline__ __device__ int64_t padded_mxfp8_scale_inv_bytes(int64_t first, int64_t last,
bool rowwise) {
namespace mxfp8_swizzle = transformer_engine::dispatch::mxfp8::swizzle;
constexpr int64_t kMxfp8BlockSize = 32;
// x is the dimension along which quantization is applied, y is other dimension
const int64_t scale_tile_y = static_cast<int64_t>(mxfp8_swizzle::GEMM_SWIZZLED_SCALE_TILE_DIM_Y);
const int64_t scale_tile_x = static_cast<int64_t>(mxfp8_swizzle::GEMM_SWIZZLED_SCALE_TILE_DIM_X);
// Padded byte size of the swizzled MXFP8 scale_inv for a single tensor with data
// shape (first, last). Rowwise scales use rows=first, cols=last; columnwise
// scales swap the orientation since they are stored in column-major order.
const int64_t scale_dim_y = rowwise ? first : last;
const int64_t padded_scale_dim_y =
((scale_dim_y + scale_tile_y - 1) / scale_tile_y) * scale_tile_y;
const int64_t data_dim_x = rowwise ? last : first;
const int64_t scale_dim_x = (data_dim_x + kMxfp8BlockSize - 1) / kMxfp8BlockSize;
const int64_t padded_scale_dim_x =
((scale_dim_x + scale_tile_x - 1) / scale_tile_x) * scale_tile_x;
// MXFP8 scales are E8M0 (1 byte per element), so element count == byte count.
return padded_scale_dim_y * padded_scale_dim_x;
}

// Device helper: byte offset into a contiguous grouped MXFP8 scale_inv buffer for
// tensor `idx`. Each expert's scale_inv is expected to be padded
// to the 128x4 swizzled layout.
__forceinline__ __device__ int64_t compute_grouped_tensor_mxfp8_scale_inv_offset(
const TensorShapeInfo &meta, size_t idx, bool rowwise) {
if (meta.first_dims != nullptr || meta.last_dims != nullptr) {
int64_t cumsum = 0;
for (size_t i = 0; i < idx; i++) {
const int64_t f = meta.first_dims ? meta.first_dims[i] : meta.uniform_first;
const int64_t l = meta.last_dims ? meta.last_dims[i] : meta.uniform_last;
cumsum += padded_mxfp8_scale_inv_bytes(f, l, rowwise);
}
return cumsum;
}
return static_cast<int64_t>(idx) *
padded_mxfp8_scale_inv_bytes(meta.uniform_first, meta.uniform_last, rowwise);
}

// Linear scan to find which tensor contains the given row.
// Returns the tensor index and writes the exclusive end-row of that tensor to *out_tensor_row_end.
__forceinline__ __device__ int find_tensor_for_row(const int64_t *first_dims, int64_t uniform_first,
Expand Down Expand Up @@ -977,7 +1020,8 @@ __global__ void setup_grouped_gemm_kernel(
size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, float *alpha_ptr, float *beta_ptr,
// Scale inputs: for tensor scaling, pass float* and set mxfp8_base to nullptr
// For MXFP8, pass nullptr for tensor_scale and set mxfp8_base
float *a_scale_base, float *b_scale_base, NVTEScalingMode scaling_mode, size_t num_tensors,
float *a_scale_base, float *b_scale_base, bool a_rowwise, bool b_rowwise,
NVTEScalingMode scaling_mode, size_t num_tensors,
MultiTensorGroupGemmInputArgs a_multi_tensor_args,
MultiTensorGroupGemmOutputArgs c_multi_tensor_args,
MultiTensorGroupGemmOutputArgs d_multi_tensor_args) {
Expand Down Expand Up @@ -1038,12 +1082,13 @@ __global__ void setup_grouped_gemm_kernel(

// Fill scale pointers (per-matrix).
// The interpretation of the scale buffers depends on the shared scaling recipe:
// NVTE_MXFP8_1D_SCALING : E8M0 byte stream; offset = data_offset / 32 elements
// otherwise : one float per tensor, indexed by tensor index
if (a_scale_base) {
if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
const int64_t a_scale_offset =
compute_grouped_tensor_mxfp8_scale_inv_offset(A_meta, idx, a_rowwise);
a_scale_inv_ptrs[idx] = reinterpret_cast<void *>(
static_cast<char *>(static_cast<void *>(a_scale_base)) + a_offset / 32);
static_cast<char *>(static_cast<void *>(a_scale_base)) + a_scale_offset);
} else {
a_scale_inv_ptrs[idx] = static_cast<float *>(a_scale_base) + idx;
}
Expand All @@ -1052,8 +1097,10 @@ __global__ void setup_grouped_gemm_kernel(
}
if (b_scale_base) {
if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
const int64_t b_scale_offset =
compute_grouped_tensor_mxfp8_scale_inv_offset(B_meta, idx, b_rowwise);
b_scale_inv_ptrs[idx] = reinterpret_cast<void *>(
static_cast<char *>(static_cast<void *>(b_scale_base)) + b_offset / 32);
static_cast<char *>(static_cast<void *>(b_scale_base)) + b_scale_offset);
} else {
b_scale_inv_ptrs[idx] = static_cast<float *>(b_scale_base) + idx;
}
Expand Down Expand Up @@ -1116,14 +1163,19 @@ inline void launch_grouped_gemm_setup(

// A and B share the same scaling recipe (validated in validate_grouped_gemm_inputs).
// Pass scale buffers as void* and let the kernel interpret them via scaling_mode.

// Scale rowwise flag for MXFP8/NVFP4: to calculate scale_inv padding based offsets
// within kernel. Ignored for tensor scaling.
const bool a_rowwise = A_sel.rowwise;
const bool b_rowwise = B_sel.rowwise;
setup_grouped_gemm_kernel<<<num_blocks, threads_per_block, 0, stream>>>(
ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, ws.a_rows, ws.a_cols, ws.b_rows, ws.b_cols,
ws.d_rows, ws.d_cols, ws.alpha_ptrs, ws.beta_ptrs, ws.a_scale_inv_ptrs, ws.b_scale_inv_ptrs,
A_sel.dptr, B_sel.dptr, c_base, d_base, A_meta, B_meta, C_meta, D_meta, a_elem_size,
b_elem_size, c_elem_size, d_elem_size, static_cast<float *>(alpha_tensor->data.dptr),
static_cast<float *>(beta_tensor->data.dptr), reinterpret_cast<float *>(A_sel.scale_inv),
reinterpret_cast<float *>(B_sel.scale_inv), A_sel.scaling_mode, num_tensors,
a_multi_tensor_args, c_multi_tensor_args, d_multi_tensor_args);
reinterpret_cast<float *>(B_sel.scale_inv), a_rowwise, b_rowwise, A_sel.scaling_mode,
num_tensors, a_multi_tensor_args, c_multi_tensor_args, d_multi_tensor_args);

NVTE_CHECK_CUDA(cudaGetLastError());
}
Expand Down Expand Up @@ -1276,6 +1328,7 @@ void nvte_grouped_gemm_with_discrete_inputA(const NVTETensor *A_list, size_t num
choose_grouped_operand_storage(static_cast<bool>(transa), /*is_A=*/true, mxfp8, is_fp8,
non_tn_fp8_ok, A_list_info.all_row, A_list_info.all_col, "A");
A_sel.trans = choice.trans;
A_sel.rowwise = choice.use_rowwise;
if (choice.use_rowwise) {
NVTE_CHECK(A_list_info.all_row, "Grouped GEMM: A_list is missing row-wise data");
A_sel.dtype = A_list_info.row_dtype;
Expand Down
12 changes: 10 additions & 2 deletions transformer_engine/pytorch/module/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
clear_tensor_data,
init_method_constant,
requires_grad,
resolve_grouped_linear_single_param_flags,
get_nvtx_range_context,
)
from ..distributed import (
Expand Down Expand Up @@ -659,11 +660,15 @@ class GroupedLinear(TransformerEngineBaseModule):
single_grouped_weight : bool, default = False
If set to ``True``, grouped weights are stored as a single grouped parameter
instead of one parameter per GEMM.
EXPERIMENTAL and subject to change.
EXPERIMENTAL and subject to change. Gated by the
``NVTE_GROUPED_LINEAR_SINGLE_PARAM`` environment variable: if the env var
is not set this argument is forced to ``False`` with a warning.
single_grouped_bias : bool, default = False
If set to ``True``, grouped biases are stored as a single grouped bias
instead of one bias per GEMM.
EXPERIMENTAL and subject to change.
EXPERIMENTAL and subject to change. Gated by the
``NVTE_GROUPED_LINEAR_SINGLE_PARAM`` environment variable: if the env var
is not set this argument is forced to ``False`` with a warning.

Notes
-----
Expand Down Expand Up @@ -712,6 +717,9 @@ def __init__(
self.ub_overlap_ag = ub_overlap_ag
self.ub_name = ub_name
self.save_original_input = save_original_input
single_grouped_weight, single_grouped_bias = resolve_grouped_linear_single_param_flags(
single_grouped_weight, single_grouped_bias
)
self.single_grouped_weight = single_grouped_weight
self.single_grouped_bias = single_grouped_bias
if ub_overlap_rs or ub_overlap_ag:
Expand Down
Loading
Loading