Commit 91b19b3c authored by Chao Liu's avatar Chao Liu
Browse files

clean

parent 4511f877
No related merge requests found
Showing with 156 additions and 445 deletions
+156 -445
......@@ -167,12 +167,14 @@ struct DeviceGemmXdlSplitKCShuffle
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static auto GetActualBatchAndKSplitted(index_t KRaw, index_t KBatch)
static auto GetActualBatchAndKPerBatch(index_t KRaw, index_t KBatchDesired)
{
const index_t KSplitted = math::integer_divide_ceil(KRaw, KPerBlock * KBatch) * KPerBlock;
const index_t actual_k_batch = math::integer_divide_ceil(KRaw, KSplitted);
const index_t KPerBatch =
math::integer_divide_ceil(KRaw, KPerBlock * KBatchDesired) * KPerBlock;
return std::make_pair(actual_k_batch, KSplitted);
const index_t KBatch = math::integer_divide_ceil(KRaw, KPerBatch);
return std::make_tuple(KBatch, KPerBatch);
}
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
......@@ -488,32 +490,33 @@ struct DeviceGemmXdlSplitKCShuffle
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t k_batch)
index_t k_batch_desired)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
BatchCount_(k_batch),
KBatch_{0},
has_k_batch_tail_{false},
compute_ptr_offset_of_batch_{0, 0},
block_2_ctile_map_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
const auto actual_batch_and_ksplitted = GetActualBatchAndKSplitted(KRaw, k_batch);
BatchCount_ = actual_batch_and_ksplitted.first;
index_t KPerBatch = 0;
const auto KSplitted = actual_batch_and_ksplitted.second;
std::tie(KBatch_, KPerBatch) = GetActualBatchAndKPerBatch(KRaw, k_batch_desired);
a_grid_desc_ak0_m_ak1_ =
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KSplitted, StrideA);
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KPerBatch, StrideA);
b_grid_desc_bk0_n_bk1_ =
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KSplitted, NRaw, StrideB);
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KPerBatch, NRaw, StrideB);
c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC);
if(KRaw != KSplitted * BatchCount_ || KRaw != KSplitted * BatchCount_)
if(KRaw != KPerBatch * KBatch_)
{
const auto KTail = KRaw - KSplitted * (BatchCount_ - 1);
has_k_batch_tail_ = true;
const auto KTail = KRaw - KPerBatch * (KBatch_ - 1);
a_grid_desc_ak0_m_ak1_tail_ =
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KTail, StrideA);
......@@ -525,27 +528,27 @@ struct DeviceGemmXdlSplitKCShuffle
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
const index_t a_batch_stride = [KSplitted, StrideA]() {
const index_t a_batch_stride = [KPerBatch, StrideA]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
ignore = StrideA;
return KSplitted;
return KPerBatch;
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return KSplitted * StrideA;
return KPerBatch * StrideA;
}
}();
const index_t b_batch_stride = [KSplitted, StrideB]() {
const index_t b_batch_stride = [KPerBatch, StrideB]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return KSplitted * StrideB;
return KPerBatch * StrideB;
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
ignore = StrideB;
return KSplitted;
return KPerBatch;
}
}();
......@@ -553,14 +556,15 @@ struct DeviceGemmXdlSplitKCShuffle
ComputePtrOffsetOfStridedBatch{a_batch_stride, b_batch_stride};
block_2_ctile_map_ = MakeBlock2CTileMap(
BatchCount_, c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1), 1, 1);
KBatch_, c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1), 1, 1);
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
index_t BatchCount_;
index_t KBatch_;
bool has_k_batch_tail_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_tail_;
......@@ -588,7 +592,7 @@ struct DeviceGemmXdlSplitKCShuffle
}
{
std::cout << "k_batch = " << arg.BatchCount_ << "\n";
std::cout << "k_batch_desired = " << arg.KBatch_ << "\n";
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
......@@ -614,7 +618,7 @@ struct DeviceGemmXdlSplitKCShuffle
}
const index_t grid_size =
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.KBatch_;
const auto K0 = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
......@@ -634,7 +638,7 @@ struct DeviceGemmXdlSplitKCShuffle
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.KBatch_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
......@@ -742,11 +746,20 @@ struct DeviceGemmXdlSplitKCShuffle
// Check whether the given argument can be executed by this device op.
// The main (full-size K-batch) grid descriptors must always pass
// GridwiseGemm::CheckValidity; the tail descriptors are only populated
// when KRaw does not divide evenly across the K-batches, so they are
// validated only when has_k_batch_tail_ is set (validating the
// default-constructed tail descriptors would spuriously reject the op).
static bool IsSupportedArgument(const Argument& arg)
{
    if(!GridwiseGemm::CheckValidity(
           arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_))
    {
        return false;
    }
    if(arg.has_k_batch_tail_ &&
       !GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_tail_,
                                    arg.b_grid_desc_bk0_n_bk1_tail_,
                                    arg.c_grid_desc_m_n_))
    {
        return false;
    }
    return true;
}
// polymorphic
......
......@@ -98,9 +98,6 @@ bool profile_gemm_splitk_impl(int do_verification,
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
}
// set zero to c_device_buf
c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
......@@ -115,7 +112,6 @@ bool profile_gemm_splitk_impl(int do_verification,
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
c_device_buf.ToDevice(c_m_n_device_result.mData.data());
// add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmNoOpPtr> gemm_ptrs;
......@@ -196,6 +192,8 @@ bool profile_gemm_splitk_impl(int do_verification,
// profile device GEMM instances
for(auto& gemm_ptr : gemm_ptrs)
{
std::cout << gemm_ptr->GetTypeString() << std::endl;
auto argument_ptr =
gemm_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
......@@ -263,6 +261,7 @@ bool profile_gemm_splitk_impl(int do_verification,
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
pass = pass &&
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
......
......@@ -57,309 +57,95 @@ bool profile_gemm(int argc, char* argv[])
const int StrideB = std::stoi(argv[12]);
const int StrideC = std::stoi(argv[13]);
auto profile =
[&](auto a_type, auto b_type, auto c_type, auto a_layout, auto b_layout, auto c_layout) {
using ADataType = decltype(a_type);
using BDataType = decltype(b_type);
using CDataType = decltype(c_type);
using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout);
using CLayout = decltype(c_layout);
return ck::profiler::
profile_gemm_impl<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC);
};
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return ck::profiler::profile_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC);
return profile(ck::half_t{}, ck::half_t{}, ck::half_t{}, Row{}, Row{}, Row{});
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return ck::profiler::profile_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC);
return profile(ck::half_t{}, ck::half_t{}, ck::half_t{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{
return ck::profiler::profile_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC);
return profile(ck::half_t{}, ck::half_t{}, ck::half_t{}, Col{}, Row{}, Row{});
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
return ck::profiler::profile_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC);
return profile(ck::half_t{}, ck::half_t{}, ck::half_t{}, Col{}, Col{}, Row{});
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{
return ck::profiler::profile_gemm_impl<float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC);
return profile(float{}, float{}, float{}, Row{}, Row{}, Row{});
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
{
return ck::profiler::profile_gemm_impl<float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC);
return profile(float{}, float{}, float{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
{
return ck::profiler::profile_gemm_impl<float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC);
return profile(float{}, float{}, float{}, Col{}, Row{}, Row{});
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
{
return ck::profiler::profile_gemm_impl<float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC);
return profile(float{}, float{}, float{}, Col{}, Col{}, Row{});
}
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
{
return ck::profiler::profile_gemm_impl<int8_t,
int8_t,
int8_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC);
return profile(int8_t{}, int8_t{}, int8_t{}, Row{}, Row{}, Row{});
}
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN)
{
return ck::profiler::profile_gemm_impl<int8_t,
int8_t,
int8_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC);
return profile(int8_t{}, int8_t{}, int8_t{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN)
{
return ck::profiler::profile_gemm_impl<int8_t,
int8_t,
int8_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC);
return profile(int8_t{}, int8_t{}, int8_t{}, Col{}, Row{}, Row{});
}
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN)
{
return ck::profiler::profile_gemm_impl<int8_t,
int8_t,
int8_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC);
return profile(int8_t{}, int8_t{}, int8_t{}, Col{}, Col{}, Row{});
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return ck::profiler::profile_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC);
return profile(ck::bhalf_t{}, ck::bhalf_t{}, ck::bhalf_t{}, Row{}, Row{}, Row{});
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return ck::profiler::profile_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC);
return profile(ck::bhalf_t{}, ck::bhalf_t{}, ck::bhalf_t{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
{
return ck::profiler::profile_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC);
return profile(ck::bhalf_t{}, ck::bhalf_t{}, ck::bhalf_t{}, Col{}, Row{}, Row{});
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
{
return ck::profiler::profile_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC);
return profile(ck::bhalf_t{}, ck::bhalf_t{}, ck::bhalf_t{}, Col{}, Col{}, Row{});
}
else
{
......
......@@ -25,7 +25,7 @@ bool profile_gemm_splitk(int argc, char* argv[])
if(argc != 15)
{
printf("arg1: tensor operation (gemm: GEMM)\n");
printf("arg1: tensor operation (gemm: GEMMSplitK)\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
......@@ -56,165 +56,68 @@ bool profile_gemm_splitk(int argc, char* argv[])
const int StrideC = std::stoi(argv[13]);
const int KBatch = std::stoi(argv[14]);
auto profile = [&](auto a_type,
auto b_type,
auto c_type,
auto a_layout,
auto b_layout,
auto c_layout) {
using ADataType = decltype(a_type);
using BDataType = decltype(b_type);
using CDataType = decltype(c_type);
using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout);
using CLayout = decltype(c_layout);
return ck::profiler::
profile_gemm_splitk_impl<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
};
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return ck::profiler::profile_gemm_splitk_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
return profile(ck::half_t{}, ck::half_t{}, ck::half_t{}, Row{}, Row{}, Row{});
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return ck::profiler::profile_gemm_splitk_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
return profile(ck::half_t{}, ck::half_t{}, ck::half_t{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{
return ck::profiler::profile_gemm_splitk_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
return profile(ck::half_t{}, ck::half_t{}, ck::half_t{}, Col{}, Row{}, Row{});
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
return ck::profiler::profile_gemm_splitk_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
return profile(ck::half_t{}, ck::half_t{}, ck::half_t{}, Col{}, Col{}, Row{});
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{
return ck::profiler::profile_gemm_splitk_impl<float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
return profile(float{}, float{}, float{}, Row{}, Row{}, Row{});
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
{
return ck::profiler::profile_gemm_splitk_impl<float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
return profile(float{}, float{}, float{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
{
return ck::profiler::profile_gemm_splitk_impl<float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
return profile(float{}, float{}, float{}, Col{}, Row{}, Row{});
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
{
return ck::profiler::profile_gemm_splitk_impl<float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
return profile(float{}, float{}, float{}, Col{}, Col{}, Row{});
}
else
{
......
......@@ -22,13 +22,39 @@ bool profile_batched_gemm_reduce(int, char*[]);
int main(int argc, char* argv[])
{
auto print_help_message = []() {
// clang-format off
printf("arg1: tensor operation, gemm: GEMM\n"
" gemm_splitk: GEMM Split-K\n"
" gemm_bias_2d: GEMM+Bias(2D)\n"
" gemm_bias_relu: GEMM+Bias+ReLU\n"
" gemm_bias_relu_add: GEMM+Bias+ReLU+Add\n"
" gemm_reduce: GEMM+Reduce\n"
" grouped_gemm: Grouped GEMM\n"
" conv_fwd: Convolution Forward\n"
" conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU\n"
" conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n"
" conv1d_bwd_data: Convolution Backward Data 1D\n"
" conv2d_bwd_data: Convolution Backward Data 2D\n"
" conv3d_bwd_data: Convolution Backward Data 3D\n"
" reduce: Reduce\n"
" conv2d_bwd_weight: Convolution Backward Weight 2D\n");
// clang-format on
};
if(argc < 2)
{
print_help_message();
exit(1);
}
bool pass = true;
if(strcmp(argv[1], "gemm") == 0)
{
pass = profile_gemm(argc, argv);
}
if(strcmp(argv[1], "gemm_splitk") == 0)
else if(strcmp(argv[1], "gemm_splitk") == 0)
{
pass = profile_gemm_splitk(argc, argv);
}
......@@ -94,23 +120,7 @@ int main(int argc, char* argv[])
}
else
{
// clang-format off
printf("arg1: tensor operation, gemm: GEMM\n"
" gemm_splitk: GEMMSplitK)\n"
" gemm_bias_2d: GEMM+Bias(2D)\n"
" gemm_bias_relu: GEMM+Bias+ReLU\n"
" gemm_bias_relu_add: GEMM+Bias+ReLU+Add\n"
" gemm_reduce: GEMM+Reduce\n"
" grouped_gemm: Grouped GEMM\n"
" conv_fwd: ForwardConvolution\n"
" conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU\n"
" conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n"
" conv1d_bwd_data: BackwardConvolution data 1 dim\n"
" conv2d_bwd_data: BackwardConvolution data 2 dim\n"
" conv3d_bwd_data: BackwardConvolution data 3 dim\n"
" reduce: Reduce\n"
" conv2d_bwd_weight: Backward Weight Convolution 2d\n");
// clang-format on
print_help_message();
}
return pass ? 0 : 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment