Commit 9fdbcf90 authored by Chao Liu
Browse files

change cgemm example to fp16

No related merge requests found
Showing with 40 additions and 64 deletions
+40 -64
add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp)
add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp)
......@@ -44,17 +44,17 @@
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using F32 = float;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = BF16;
using BDataType = BF16;
using CDataType = BF16;
using ADataType = F16;
using BDataType = F16;
using CDataType = F16;
using AccDataType = F32;
using ALayout = ck::tensor_layout::gemm::RowMajor;
......@@ -109,7 +109,7 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_
// clang-format on
using ReferenceCGemmInstance = ck::tensor_operation::host::
ReferenceCGemm<float, float, float, PassThrough, PassThrough, PassThrough>;
ReferenceCGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
int main(int argc, char* argv[])
{
......@@ -266,31 +266,18 @@ int main(int argc, char* argv[])
if(do_verification)
{
Tensor<float> a_f32_m_k_real(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<float> a_f32_m_k_imag(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<float> b_f32_k_n_real(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<float> b_f32_k_n_imag(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<float> c_m_n_real_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<float> c_m_n_imag_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<float> c_m_n_real_device_f32_result(
Tensor<CDataType> c_m_n_real_host_result(
f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<float> c_m_n_imag_device_f32_result(
Tensor<CDataType> c_m_n_imag_host_result(
f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
bf16_to_f32_(a_m_k_real, a_f32_m_k_real);
bf16_to_f32_(a_m_k_imag, a_f32_m_k_imag);
bf16_to_f32_(b_k_n_real, b_f32_k_n_real);
bf16_to_f32_(b_k_n_imag, b_f32_k_n_imag);
bf16_to_f32_(c_m_n_real_device_result, c_m_n_real_device_f32_result);
bf16_to_f32_(c_m_n_imag_device_result, c_m_n_imag_device_f32_result);
auto ref_cgemm = ReferenceCGemmInstance{};
auto ref_invoker = ref_cgemm.MakeInvoker();
auto ref_argument = ref_cgemm.MakeArgument(a_f32_m_k_real,
a_f32_m_k_imag,
b_f32_k_n_real,
b_f32_k_n_imag,
auto ref_argument = ref_cgemm.MakeArgument(a_m_k_real,
a_m_k_imag,
b_k_n_real,
b_k_n_imag,
c_m_n_real_host_result,
c_m_n_imag_host_result,
a_element_op,
......@@ -299,12 +286,12 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument);
ck::utils::check_err(c_m_n_real_device_f32_result.mData,
ck::utils::check_err(c_m_n_real_device_result.mData,
c_m_n_real_host_result.mData,
"Verification error: incorrect results in real part!",
1e-2f,
1e-1f);
ck::utils::check_err(c_m_n_imag_device_f32_result.mData,
ck::utils::check_err(c_m_n_imag_device_result.mData,
c_m_n_imag_host_result.mData,
"Verification error: incorrect results in imaginary part!",
1e-2f,
......
......@@ -33,12 +33,19 @@ namespace ck {
namespace tensor_operation {
namespace host {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
// FIXME: support arbitrary elementwise operation for A/B/C
template <
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false>
struct ReferenceCGemm : public device::BaseOperator
{
// Argument
......@@ -92,52 +99,34 @@ struct ReferenceCGemm : public device::BaseOperator
}
auto f_mk_kn_mn_real = [&](auto m, auto n) {
float v_acc = 0;
float v_c_real = 0;
for(std::size_t k = 0; k < K; ++k)
{
float v_a_real;
float v_b_real;
float v_a_imag;
float v_b_imag;
float v_a_real = ck::type_convert<float>(arg.a_m_k_real_(m, k));
float v_a_imag = ck::type_convert<float>(arg.a_m_k_imag_(m, k));
float v_b_real = ck::type_convert<float>(arg.b_k_n_real_(k, n));
float v_b_imag = ck::type_convert<float>(arg.b_k_n_imag_(k, n));
arg.a_element_op_(v_a_real, ck::type_convert<float>(arg.a_m_k_real_(m, k)));
arg.a_element_op_(v_a_imag, ck::type_convert<float>(arg.a_m_k_imag_(m, k)));
arg.b_element_op_(v_b_real, ck::type_convert<float>(arg.b_k_n_real_(k, n)));
arg.b_element_op_(v_b_imag, ck::type_convert<float>(arg.b_k_n_imag_(k, n)));
v_acc += v_a_real * v_b_real - v_a_imag * v_b_imag;
v_c_real += v_a_real * v_b_real - v_a_imag * v_b_imag;
}
float v_c_real;
arg.c_element_op_(v_c_real, v_acc);
arg.c_m_n_real_(m, n) = v_c_real;
};
auto f_mk_kn_mn_imag = [&](auto m, auto n) {
float v_acc = 0;
float v_c_imag = 0;
for(std::size_t k = 0; k < K; ++k)
{
float v_a_real;
float v_b_real;
float v_a_imag;
float v_b_imag;
float v_a_real = ck::type_convert<float>(arg.a_m_k_real_(m, k));
float v_a_imag = ck::type_convert<float>(arg.a_m_k_imag_(m, k));
float v_b_real = ck::type_convert<float>(arg.b_k_n_real_(k, n));
float v_b_imag = ck::type_convert<float>(arg.b_k_n_imag_(k, n));
arg.a_element_op_(v_a_real, ck::type_convert<float>(arg.a_m_k_real_(m, k)));
arg.a_element_op_(v_a_imag, ck::type_convert<float>(arg.a_m_k_imag_(m, k)));
arg.b_element_op_(v_b_real, ck::type_convert<float>(arg.b_k_n_real_(k, n)));
arg.b_element_op_(v_b_imag, ck::type_convert<float>(arg.b_k_n_imag_(k, n)));
v_acc += v_a_real * v_b_imag + v_a_imag * v_b_real;
v_c_imag += v_a_real * v_b_imag + v_a_imag * v_b_real;
}
float v_c_imag;
arg.c_element_op_(v_c_imag, v_acc);
arg.c_m_n_imag_(m, n) = v_c_imag;
};
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment