Commit 8d92ff86 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Rangify check_err()

By rangifying check_err(), we can not only compare values between
std::vector<>s, but also compare any ranges which have same value
type.
Showing with 47 additions and 60 deletions
+47 -60
......@@ -240,7 +240,7 @@ int main(int argc, char* argv[])
show_2d_matrix(std::cout << "c_host :", c_m_n_host_result) << std::endl;
}
#endif
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
}
return 0;
......
......@@ -133,11 +133,11 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>();
return ck::utils::check_err(c_m_n_device_result_converted.mData, c_m_n_host_result.mData);
return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
#else
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
#endif
}
......
......@@ -299,7 +299,7 @@ int main(int argc, char* argv[])
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData) ? 0 : 1;
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
}
return 0;
......
......@@ -275,7 +275,7 @@ int main(int argc, char* argv[])
}
}
return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData) ? 0 : 1;
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
}
return 0;
......
......@@ -147,9 +147,9 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC
#ifdef BUILD_INT4_EXAMPLE
const Tensor<EDataType> e_m_n_device_result_converted(e_m_n_device_result);
return ck::utils::check_err(e_m_n_device_result_converted.mData, e_m_n_host_result.mData);
return ck::utils::check_err(e_m_n_device_result_converted, e_m_n_host_result);
#else
return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData);
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
#endif
}
......
......@@ -164,7 +164,7 @@ bool run_grouped_conv_fwd(bool do_verification,
out_device_buf.FromDevice(out_device.mData.data());
return ck::utils::check_err(
out_device.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
}
return true;
......
......@@ -276,16 +276,13 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size,
conv_output_device_buf.FromDevice(conv_output_device.mData.data());
r0_device_buf.FromDevice(r0_device.mData.data());
return ck::utils::check_err(conv_output_device.mData,
conv_output_host.mData,
return ck::utils::check_err(conv_output_device,
conv_output_host,
"Error: incorrect results! (Matrix E)",
1e-5f,
1e-4f) &&
ck::utils::check_err(r0_device.mData,
r0_host.mData,
"Error: incorrect results! (Matrix R0)",
1e-5f,
1e-4f);
ck::utils::check_err(
r0_device, r0_host, "Error: incorrect results! (Matrix R0)", 1e-5f, 1e-4f);
}
return true;
......
......@@ -322,12 +322,12 @@ int reduce_blockwise_impl(bool do_verification,
#endif
out_dev.FromDevice(out.mData.data());
pass = pass && ck::utils::check_err(out.mData, out_ref.mData);
pass = pass && ck::utils::check_err(out, out_ref);
if(OutputIndex)
{
out_index_dev.FromDevice(out_indices.mData.data());
pass = pass && ck::utils::check_err(out_indices.mData, out_indices_ref.mData);
pass = pass && ck::utils::check_err(out_indices, out_indices_ref);
};
};
......
......@@ -294,7 +294,7 @@ int main(int argc, char* argv[])
if(do_verify)
{
out_dev.FromDevice(out.mData.data());
pass = pass && ck::utils::check_err(out.mData, out_ref.mData);
pass = pass && ck::utils::check_err(out, out_ref);
};
return (pass ? 0 : 1);
......
......@@ -223,7 +223,7 @@ int reduce_multiblock_atomic_add_impl(bool do_verification,
if(do_verification)
{
out_dev.FromDevice(out.mData.data());
pass = pass && ck::utils::check_err(out.mData, out_ref.mData);
pass = pass && ck::utils::check_err(out, out_ref);
};
return (pass ? 0 : 1);
......
......@@ -267,14 +267,14 @@ bool pool_test(bool do_verification,
out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data());
pass = pass && ck::utils::check_err(out_n_c_ho_wo_device.mData, out_n_c_ho_wo_host.mData);
pass = pass && ck::utils::check_err(out_n_c_ho_wo_device, out_n_c_ho_wo_host);
if constexpr(OutputIndex)
{
out_indices_device_buf.FromDevice(out_indices_n_c_ho_wo_device.mData.data());
pass = pass && ck::utils::check_err(out_indices_n_c_ho_wo_device.mData,
out_indices_n_c_ho_wo_host.mData);
pass = pass &&
ck::utils::check_err(out_indices_n_c_ho_wo_device, out_indices_n_c_ho_wo_host);
};
}
......
......@@ -249,7 +249,7 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument);
return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result) ? 0 : 1;
}
return 0;
......
......@@ -208,10 +208,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
#ifdef BUILD_INT4_EXAMPLE
const Tensor<EDataType> c_device_result_converted(c_device_tensors[i]);
pass &= ck::utils::check_err(c_device_result_converted.mData, c_host_tensors[i].mData);
pass &= ck::utils::check_err(c_device_result_converted, c_host_tensors[i]);
#else
pass &= ck::utils::check_err(c_device_tensors[i].mData, c_host_tensors[i].mData);
pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]);
#endif
}
}
......
......@@ -259,12 +259,9 @@ int main()
r0_device_buf.FromDevice(r0_m.mData.data());
r1_device_buf.FromDevice(r1_m.mData.data());
pass = ck::utils::check_err(
e_m_n.mData, e_m_n_host.mData, "Error: Incorrect results c", 1e-2, 1e-2);
pass &= ck::utils::check_err(
r0_m.mData, r0_m_host.mData, "Error: Incorrect results d0", 1e-2, 1e-2);
pass &= ck::utils::check_err(
r1_m.mData, r1_m_host.mData, "Error: Incorrect results d1", 1e-2, 1e-2);
pass = ck::utils::check_err(e_m_n, e_m_n_host, "Error: Incorrect results c", 1e-2, 1e-2);
pass &= ck::utils::check_err(r0_m, r0_m_host, "Error: Incorrect results d0", 1e-2, 1e-2);
pass &= ck::utils::check_err(r1_m, r1_m_host, "Error: Incorrect results d1", 1e-2, 1e-2);
}
bool time_kernel = true;
......
......@@ -264,15 +264,13 @@ bool run_gemm_reduce_add_addsquare_xdl(ck::index_t M,
Tensor<EDataType> e_m_n_host_converted(e_m_n_host);
pass = ck::utils::check_err(
e_m_n.mData, e_m_n_host_converted.mData, "Error: Incorrect results c", 1e-2, 1e-2);
e_m_n, e_m_n_host_converted, "Error: Incorrect results c", 1e-2, 1e-2);
r0_device_buf.FromDevice(r0_m.mData.data());
r1_device_buf.FromDevice(r1_m.mData.data());
pass &= ck::utils::check_err(
r0_m.mData, r0_m_host.mData, "Error: Incorrect results d0", 1e-2, 1e-2);
pass &= ck::utils::check_err(
r1_m.mData, r1_m_host.mData, "Error: Incorrect results d1", 1e-2, 1e-2);
pass &= ck::utils::check_err(r0_m, r0_m_host, "Error: Incorrect results d0", 1e-2, 1e-2);
pass &= ck::utils::check_err(r1_m, r1_m_host, "Error: Incorrect results d1", 1e-2, 1e-2);
if(pass)
{
......
......@@ -243,8 +243,8 @@ auto run_gemm_reduce_max_xdl(ck::index_t M,
if constexpr(std::is_same_v<ADataType, ck::int4_t>)
{
Tensor<EDataType> e_m_n_device_converted(e_m_n);
pass = ck::utils::check_err(e_m_n_device_converted.mData,
e_m_n_host_converted.mData,
pass = ck::utils::check_err(e_m_n_device_converted,
e_m_n_host_converted,
"Error: Incorrect results c",
1e-2,
1e-2);
......@@ -253,12 +253,11 @@ auto run_gemm_reduce_max_xdl(ck::index_t M,
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
{
pass = ck::utils::check_err(
e_m_n.mData, e_m_n_host_converted.mData, "Error: Incorrect results c", 1e-2, 1e-2);
e_m_n, e_m_n_host_converted, "Error: Incorrect results c", 1e-2, 1e-2);
}
r0_device_buf.FromDevice(r0_m.mData.data());
pass &= ck::utils::check_err(
r0_m.mData, r0_m_host.mData, "Error: Incorrect results d0", 1e-2, 1e-2);
pass &= ck::utils::check_err(r0_m, r0_m_host, "Error: Incorrect results d0", 1e-2, 1e-2);
if(pass)
{
......@@ -460,8 +459,8 @@ bool run_gemm_reduce_mean_meansquare_xdl(ck::index_t M,
if constexpr(std::is_same_v<ADataType, ck::int4_t>)
{
Tensor<EDataType> e_m_n_device_converted(e_m_n);
pass = ck::utils::check_err(e_m_n_device_converted.mData,
e_m_n_host_converted.mData,
pass = ck::utils::check_err(e_m_n_device_converted,
e_m_n_host_converted,
"Error: Incorrect results c",
1e-2,
1e-2);
......@@ -470,16 +469,14 @@ bool run_gemm_reduce_mean_meansquare_xdl(ck::index_t M,
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
{
pass = ck::utils::check_err(
e_m_n.mData, e_m_n_host_converted.mData, "Error: Incorrect results c", 1e-2, 1e-2);
e_m_n, e_m_n_host_converted, "Error: Incorrect results c", 1e-2, 1e-2);
}
r0_device_buf.FromDevice(r0_m.mData.data());
r1_device_buf.FromDevice(r1_m.mData.data());
pass &= ck::utils::check_err(
r0_m.mData, r0_m_host.mData, "Error: Incorrect results d0", 1e-2, 1e-2);
pass &= ck::utils::check_err(
r1_m.mData, r1_m_host.mData, "Error: Incorrect results d1", 1e-2, 1e-2);
pass &= ck::utils::check_err(r0_m, r0_m_host, "Error: Incorrect results d0", 1e-2, 1e-2);
pass &= ck::utils::check_err(r1_m, r1_m_host, "Error: Incorrect results d1", 1e-2, 1e-2);
if(pass)
{
......
......@@ -142,7 +142,7 @@ int run_conv_bwd_data(bool do_verification,
in_device_buf.FromDevice(in_device.mData.data());
return ck::utils::check_err(in_device.mData, in_host.mData) ? 0 : 1;
return ck::utils::check_err(in_device, in_host) ? 0 : 1;
}
return 0;
......
......@@ -296,16 +296,15 @@ int main(int argc, char* argv[])
}
}
pass = ck::utils::check_err(c_g_m_n_host_result.mData,
c_g_m_n_device_result.mData,
"Error: Incorrect results c") &&
ck::utils::check_err(d0_g_m_device_result.mData,
d0_g_m_host_result.mData,
pass = ck::utils::check_err(
c_g_m_n_host_result, c_g_m_n_device_result, "Error: Incorrect results c") &&
ck::utils::check_err(d0_g_m_device_result,
d0_g_m_host_result,
"Error: Incorrect results! D0",
1e-4,
1e-5) &&
ck::utils::check_err(d1_g_m_device_result.mData,
d1_g_m_host_result.mData,
ck::utils::check_err(d1_g_m_device_result,
d1_g_m_host_result,
"Error: Incorrect results! D1",
1e-3,
1e-5);
......
......@@ -128,8 +128,7 @@ int main()
host_broadcast2D<Tensor<ABDataType>, Tensor<ABDataType>, Tensor<CDataType>, Add, 0>(
host_c_m_n, a_m_n, b_n, M, N, Add{});
pass &= ck::utils::check_err(
c_m_n.mData, host_c_m_n.mData, "Error: Incorrect results c", 1e-3, 1e-3);
pass &= ck::utils::check_err(c_m_n, host_c_m_n, "Error: Incorrect results c", 1e-3, 1e-3);
}
return pass ? 0 : 1;
......
......@@ -113,8 +113,8 @@ int main()
host_broadcast3D_am_bmnk<Tensor<ABDataType>, Tensor<ABDataType>, Tensor<CDataType>, Add>(
host_c_m_n_k, a_m, b_m_n_k, mnk, Add{});
pass &= ck::utils::check_err(
c_m_n_k.mData, host_c_m_n_k.mData, "Error: Incorrect results c", 1e-3, 1e-3);
pass &=
ck::utils::check_err(c_m_n_k, host_c_m_n_k, "Error: Incorrect results c", 1e-3, 1e-3);
}
return pass ? 0 : 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment