Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 42 additions & 191 deletions tests/ap/matmul/all_tuning_configs.h

Large diffs are not rendered by default.

70 changes: 70 additions & 0 deletions tests/ap/matmul/autotune.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#pragma once

#include "all_tuning_configs.h"
#include "default_config_id.h"
#include "profile.h"

namespace ap {

/// Builds the table of `Runner::Apply<Is, ST>` entry points, one per index in
/// the given integer sequence, in sequence order.
///
/// @tparam Runner  type exposing a `template <int, SwizzleType> Apply` function
/// @tparam ST      swizzle type forwarded to every `Apply` instantiation
/// @tparam Is      config ids to instantiate
/// @return `std::vector` of function pointers; element `k` is the `k`-th id's
///         `Apply` instantiation. The element type is taken from `Apply<0, ST>`.
template <typename Runner, SwizzleType ST, int... Is>
auto GenerateFuncList(std::integer_sequence<int, Is...>) {
  using FuncPtr = decltype(&Runner::template Apply<0, ST>);

  std::vector<FuncPtr> table;
  table.reserve(sizeof...(Is));
  // Comma fold appends the instantiations left to right, i.e. in id order.
  (table.push_back(&Runner::template Apply<Is, ST>), ...);
  return table;
}

/// Runs a matmul through `Runner`, optionally auto-tuning the kernel config.
///
/// When `AP_ENABLE_AUTOTUNE` is non-zero:
///   - `config_id == -1`: every `SwizzleType::kCommon` candidate is handed to
///     `ap::ProfileBestConfig`. If `EnableStreamK` is true, the winner is then
///     profiled against its `SwizzleType::kStreamK` counterpart.
///   - any other `config_id`: the matching candidate is invoked directly.
///     Ids in `[0, N)` select the `kCommon` variant; with `EnableStreamK`, ids
///     in `[N, 2N)` select the `kStreamK` variant of config `config_id - N`
///     (`N = ap::ConfigsInfo<T>::kNumTotals`).
/// Otherwise the default config is run and `stream` / `config_id` are ignored.
///
/// NOTE(review): the profiling branch does not launch the selected kernel
/// itself after profiling; presumably `ProfileBestConfig` leaves a valid
/// result behind - confirm against profile.h.
///
/// @tparam T              element type used to look up `ap::ConfigsInfo<T>`
/// @tparam Runner         type exposing `template <int, SwizzleType> Apply`
/// @tparam EnableStreamK  whether stream-K variants take part in tuning/dispatch
/// @param stream     stream passed to `ap::ProfileBestConfig`
/// @param config_id  config to run, or -1 to profile
/// @param args       arguments passed on to the selected `Apply` / profiler
/// @return the config id that was selected (stream-K winners are encoded as
///         `id + N`), or -1 when autotuning is compiled out.
/// @throws std::out_of_range (via `std::vector::at`) if a config id does not
///         name an existing candidate.
template <typename T, typename Runner, bool EnableStreamK, typename... Args>
int RunWithAutotune(cudaStream_t stream, int config_id, Args &&...args) {
#if AP_ENABLE_AUTOTUNE
  using FuncPtr = decltype(&Runner::template Apply<0, SwizzleType::kCommon>);
  constexpr int N = ap::ConfigsInfo<T>::kNumTotals;

  // When stream-K is disabled the stream-K table is built from an empty index
  // sequence, so no stream-K kernel gets instantiated and the table is empty.
  constexpr int kNumStreamK = EnableStreamK ? N : 0;
  constexpr SwizzleType kStreamKSwizzle =
      EnableStreamK ? SwizzleType::kStreamK : SwizzleType::kCommon;

  // Initialized exactly once at first use (function-local static
  // initialization) instead of through an unsynchronized `empty()` check.
  static std::vector<FuncPtr> matmul_functions =
      GenerateFuncList<Runner, SwizzleType::kCommon>(
          std::make_integer_sequence<int, N>{});
  static std::vector<FuncPtr> streamk_functions =
      GenerateFuncList<Runner, kStreamKSwizzle>(
          std::make_integer_sequence<int, kNumStreamK>{});

  int selected_config_id = config_id;

  if (selected_config_id == -1) {
    if constexpr (EnableStreamK) {
      // `args` is needed again by the second profiling pass below, so it is
      // passed as lvalues here and only forwarded on its final use.
      selected_config_id =
          ap::ProfileBestConfig(matmul_functions, stream, args...);

      // Head-to-head: index 0 is the common variant, index 1 is stream-K.
      std::vector<FuncPtr> mixed_functions = {
          matmul_functions.at(selected_config_id),
          streamk_functions.at(selected_config_id)};
      const int mixed_config_id = ap::ProfileBestConfig(
          mixed_functions, stream, std::forward<Args>(args)...);
      if (mixed_config_id != 0) {
        selected_config_id += N;  // encode "stream-K variant won"
      }
    } else {
      selected_config_id = ap::ProfileBestConfig(
          matmul_functions, stream, std::forward<Args>(args)...);
    }
  } else {
    // `.at()` turns an invalid id into an exception instead of an
    // out-of-bounds read of the function table.
    if constexpr (EnableStreamK) {
      if (selected_config_id < N) {
        matmul_functions.at(selected_config_id)(std::forward<Args>(args)...);
      } else {
        streamk_functions.at(selected_config_id - N)(
            std::forward<Args>(args)...);
      }
    } else {
      matmul_functions.at(selected_config_id)(std::forward<Args>(args)...);
    }
  }

  return selected_config_id;
#else
  // Not used when autotuning is compiled out.
  (void)stream;
  (void)config_id;
  Runner::template Apply<DefaultConfig::kConfigId, SwizzleType::kCommon>(
      std::forward<Args>(args)...);
  return -1;
#endif
}

} // namespace ap
176 changes: 25 additions & 151 deletions tests/ap/matmul/cutlass_matmul.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,7 @@ template <typename ElementT,
bool TransposeA = false,
bool TransposeB = false,
int ConfigId = DefaultConfig::kConfigId,
int SwizzleFactor = DefaultConfig::kSwizzleFactor,
bool Batched = DefaultConfig::kBatched>
SwizzleType ST = DefaultConfig::kSwizzleType>
void CutlassMatmul(const GemmEpilogueParams& params) {
using ElementAccumulator = typename CutlassDataType<ElementComputeT>::Type; // <- data type of accumulator
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
Expand All @@ -107,12 +106,12 @@ void CutlassMatmul(const GemmEpilogueParams& params) {
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
typename GemmTuningConfigs<ElementT, SwizzleFactor, Batched, ConfigId>::TShape,
typename GemmTuningConfigs<ElementT, SwizzleFactor, Batched, ConfigId>::WShape,
typename GemmTuningConfigs<ElementT, SwizzleFactor, Batched, ConfigId>::IShape,
typename GemmTuningConfigs<ElementT, ConfigId>::TShape,
typename GemmTuningConfigs<ElementT, ConfigId>::WShape,
typename GemmTuningConfigs<ElementT, ConfigId>::IShape,
EpilogueOutputOp,
typename GemmTuningConfigs<ElementT, SwizzleFactor, Batched, ConfigId>::SwizzleThreadBlock,
GemmTuningConfigs<ElementT, SwizzleFactor, Batched, ConfigId>::kNumStages,
typename ThreadBlockSwizzle<ST>::Type,
GemmTuningConfigs<ElementT, ConfigId>::kNumStages,
128 / cutlass::sizeof_bits<ElementInputA>::value, // AlignA
128 / cutlass::sizeof_bits<ElementInputB>::value, // AlignB
typename GemmOperation<ElementT>::Type // Operation performed by GEMM
Expand Down Expand Up @@ -170,11 +169,10 @@ void CutlassMatmul(const GemmEpilogueParams& params) {
template <typename ElementT,
typename ElementComputeT,
template<typename T> class UnaryFunctor,
bool TransposeA = false,
bool TransposeB = false,
int AlignA = 128 / cutlass::sizeof_bits<ElementT>::value,
int AlignB = 128 / cutlass::sizeof_bits<ElementT>::value,
int ConfigId = DefaultConfig::kConfigId,
int SwizzleFactor = DefaultConfig::kSwizzleFactor,
bool Batched = DefaultConfig::kBatched>
SwizzleType ST = DefaultConfig::kSwizzleType>
void CutlassMatmulAddUnary(const GemmEpilogueParams& params, const typename UnaryFunctor<ElementComputeT>::Arguments& unary_args) {
using ElementAccumulator = typename CutlassDataType<ElementComputeT>::Type; // <- data type of accumulator
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
Expand All @@ -197,23 +195,23 @@ void CutlassMatmulAddUnary(const GemmEpilogueParams& params, const typename Unar

using GemmFunc = cutlass::gemm::device::GemmUniversal<
ElementInputA,
typename MatrixLayout<TransposeA>::Type,
cutlass::layout::RowMajor,
ElementInputB,
typename MatrixLayout<TransposeB>::Type,
cutlass::layout::RowMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
typename GemmTuningConfigs<ElementT, SwizzleFactor, Batched, ConfigId>::TShape,
typename GemmTuningConfigs<ElementT, SwizzleFactor, Batched, ConfigId>::WShape,
typename GemmTuningConfigs<ElementT, SwizzleFactor, Batched, ConfigId>::IShape,
typename GemmTuningConfigs<ElementT, ConfigId>::TShape,
typename GemmTuningConfigs<ElementT, ConfigId>::WShape,
typename GemmTuningConfigs<ElementT, ConfigId>::IShape,
EpilogueOutputOp,
typename GemmTuningConfigs<ElementT, SwizzleFactor, Batched, ConfigId>::SwizzleThreadBlock,
GemmTuningConfigs<ElementT, SwizzleFactor, Batched, ConfigId>::kNumStages,
128 / cutlass::sizeof_bits<ElementInputA>::value, // AlignA
128 / cutlass::sizeof_bits<ElementInputB>::value, // AlignB
typename GemmOperation<ElementT>::Type // Operation performed by GEMM
typename ThreadBlockSwizzle<ST>::Type,
GemmTuningConfigs<ElementT, ConfigId>::kNumStages,
AlignA,
AlignB,
typename GemmOperation<ElementT>::Type
>;

CHECK_CUTLASS(SetMaxDynamicSharedMemorySize<GemmFunc>());
Expand Down Expand Up @@ -265,137 +263,13 @@ void CutlassMatmulAddUnary(const GemmEpilogueParams& params, const typename Unar
#endif
}

template <typename ElementT,
typename ElementComputeT,
int ConfigId = DefaultConfig::kConfigId,
int SwizzleFactor = DefaultConfig::kSwizzleFactor,
bool Batched = DefaultConfig::kBatched>
void CutlassMatmulAddBroadcast(const GemmBroadcastEpilogueParams& params) {
using ElementAccumulator = typename CutlassDataType<ElementComputeT>::Type; // <- data type of accumulator
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
using ElementInputA = typename CutlassDataType<ElementT>::Type; // <- data type of elements in input matrix A
using ElementInputB = typename CutlassDataType<ElementT>::Type; // <- data type of elements in input matrix B
using ElementOutputC = typename CutlassDataType<ElementT>::Type;// <- data type of elements in output matrix D
using ElementOutputZ = ElementOutputC;
using ElementOutputT = ElementOutputC;

// Epilogue operation as LinearCombinationBiasElementwise:
// Y = GEMM(AB, C)
// T[i, j] = BinaryOp(Y[i, j], Broadcast[i])
// Z[i, j] = Elementwise(T[i, j])
using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise<
ElementOutputC,
ElementAccumulator,
ElementComputeEpilogue,
ElementOutputZ,
ElementOutputT,
128 / cutlass::sizeof_bits<ElementOutputC>::value
>;

// Epilogue operation as LinearCombinationResidualBlock:
// Y = GEMM(AB, C1)
// UnaryOp(BinaryOp2(BinaryOp1(ActivationOp(Y), residual1), residual2))
// using EpilogueOp = cutlass::epilogue::thread::LinearCombinationResidualBlock<
// ElementOutput, // Element type for output matrix
// ElementAccumulator, // Element type from internal accumulation
// ElementCompute, // Element type from internal accumulation
// ElementC, // Element type for C1/C2/D matrix operands
// AlignmentC, // Memory access granularity of C and D matrix in units of elements
// cutlass::epilogue::thread::Identity, // Activation
// cutlass::plus, // Binary operation 1
// cutlass::epilogue::thread::Identity, // Unary operation
// cutlass::plus // Binary operation 2
// >;

using GemmFunc = cutlass::gemm::device::GemmUniversalWithBroadcast<
ElementInputA,
cutlass::layout::RowMajor,
ElementInputB,
cutlass::layout::RowMajor,
ElementOutputC,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
typename GemmTuningConfigs<ElementT, SwizzleFactor, Batched, ConfigId>::TShape,
typename GemmTuningConfigs<ElementT, SwizzleFactor, Batched, ConfigId>::WShape,
typename GemmTuningConfigs<ElementT, SwizzleFactor, Batched, ConfigId>::IShape,
EpilogueOutputOp,
typename GemmTuningConfigs<ElementT, SwizzleFactor, Batched, ConfigId>::SwizzleThreadBlock,
GemmTuningConfigs<ElementT, SwizzleFactor, Batched, ConfigId>::kNumStages,
128 / cutlass::sizeof_bits<ElementInputA>::value, // AlignA
128 / cutlass::sizeof_bits<ElementInputB>::value, // AlignB
typename GemmOperation<ElementT>::Type // Operation performed by GEMM
>;

CHECK_CUTLASS(SetMaxDynamicSharedMemorySize<GemmFunc>());

/// Arguments
cutlass::gemm::GemmCoord problem_size{params.m, params.n, params.k};

const ElementInputA *input = reinterpret_cast<const ElementInputA *>(params.input);
const ElementInputB *weight = reinterpret_cast<const ElementInputB *>(params.weight);
const ElementOutputC *bias = reinterpret_cast<const ElementOutputC *>(params.bias);
ElementOutputZ *output = reinterpret_cast<ElementOutputZ *>(params.output);
ElementOutputC *broadcast = reinterpret_cast<ElementOutputC *>(params.broadcast);
ElementOutputT *broadcast_out = reinterpret_cast<ElementOutputT *>(params.broadcast_out);

const int64_t batch_stride_Broadcast = params.need_broadcast ? problem_size.m() : problem_size.m() * problem_size.n();
const int64_t ldr_broadcast = params.need_broadcast ? 0 : problem_size.n();

ElementComputeEpilogue alpha = static_cast<ElementComputeEpilogue>(1);
ElementComputeEpilogue beta = static_cast<ElementComputeEpilogue>(1);

typename GemmFunc::Arguments arguments{
GetGemmMode(params.batch_count),
problem_size, // <- problem size of matrix multiplication
params.batch_count, // <- batch_count or k-dimension split factor
{alpha, beta}, // <- epilogue params, alpha, beta
input, // <- input, ptr_A, A, shape={M, K}
weight, // <- input, ptr_B, B, shape={K, N}
bias, // <- input, ptr_C, shape={M, N} or {1, N}
output, // <- output, ptr_D, Z, shape={M, N}
broadcast, // <- input, ptr_Vector, Broadcast, shape={M, 1}
broadcast_out, // <- output, ptr_Tensor, T
params.shape_args.batch_stride_A,
params.shape_args.batch_stride_B,
params.shape_args.batch_stride_C,
params.shape_args.batch_stride_D,
batch_stride_Broadcast, // <- batch_stride_Vector, need broadcast
problem_size.m() * problem_size.n(), // <- batch_stride_Tensor
params.shape_args.lda,
params.shape_args.ldb,
params.shape_args.ldc_bias,
params.shape_args.ldd,
ldr_broadcast, // <- ldr, must be zero
problem_size.n() // <- ldt
};

size_t workspace_size = GemmFunc::get_workspace_size(arguments);
void* workspace = workspace_size > 0 ? GetWorkspace(workspace_size) : nullptr;

GemmFunc device_gemm;

CHECK_CUTLASS(device_gemm.can_implement(arguments));
CHECK_CUTLASS(device_gemm.initialize(arguments, workspace, params.stream));

//
// Run the GEMM
//
CHECK_CUTLASS(device_gemm(params.stream));
#if AP_ENABLE_DEBUG
CHECK_CUDA(cudaStreamSynchronize(params.stream));
#endif
}

template <typename ElementT,
typename ElementComputeT,
template<typename T> class VariadicFunctor,
int AlignA = 128 / cutlass::sizeof_bits<ElementT>::value,
int AlignB = 128 / cutlass::sizeof_bits<ElementT>::value,
int ConfigId = DefaultConfig::kConfigId,
int SwizzleFactor = DefaultConfig::kSwizzleFactor,
bool Batched = DefaultConfig::kBatched>
SwizzleType ST = DefaultConfig::kSwizzleType>
void CutlassMatmulAddVariadic(const GemmEpilogueParams& params, const typename VariadicFunctor<ElementComputeT>::Arguments& variadic_args) {
using ElementAccumulator = typename CutlassDataType<ElementComputeT>::Type; // <- data type of accumulator
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
Expand Down Expand Up @@ -425,12 +299,12 @@ void CutlassMatmulAddVariadic(const GemmEpilogueParams& params, const typename V
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
typename GemmTuningConfigs<ElementT, SwizzleFactor, Batched, ConfigId>::TShape,
typename GemmTuningConfigs<ElementT, SwizzleFactor, Batched, ConfigId>::WShape,
typename GemmTuningConfigs<ElementT, SwizzleFactor, Batched, ConfigId>::IShape,
typename GemmTuningConfigs<ElementT, ConfigId>::TShape,
typename GemmTuningConfigs<ElementT, ConfigId>::WShape,
typename GemmTuningConfigs<ElementT, ConfigId>::IShape,
EpilogueOutputOp,
typename GemmTuningConfigs<ElementT, SwizzleFactor, Batched, ConfigId>::SwizzleThreadBlock,
GemmTuningConfigs<ElementT, SwizzleFactor, Batched, ConfigId>::kNumStages,
typename ThreadBlockSwizzle<ST>::Type,
GemmTuningConfigs<ElementT, ConfigId>::kNumStages,
AlignA,
AlignB,
typename GemmOperation<ElementT>::Type
Expand Down
Loading