Commit 3dfa67b5 authored by Jing Zhang's avatar Jing Zhang
Browse files

improve sigmoid

No related merge requests found
Showing with 111 additions and 282 deletions
+111 -282
......@@ -20,6 +20,7 @@ template <typename GridwiseGemm,
typename DGridDesc_K_N_Hox2_Wox2,
typename CGridDesc_K_N_Ho_Wo,
typename CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo,
bool HasMainE0BlockLoop,
bool HasMainE1BlockLoop,
bool HasDoubleTailE1BlockLoop>
__global__ void
......@@ -51,6 +52,7 @@ __global__ void
b_e0_e1_n_ho_wo_e2_grid_desc,
d_k_n_hox2_wox2_grid_desc,
c_k_n_ho_wo_grid_desc,
integral_constant<bool, HasMainE0BlockLoop>{},
integral_constant<bool, HasMainE1BlockLoop>{},
integral_constant<bool, HasDoubleTailE1BlockLoop>{});
}
......@@ -66,6 +68,7 @@ template <typename GridwiseGemm,
typename DGridDesc_K_N_Hox2_Wox2,
typename CGridDesc_K_N_Ho_Wo,
typename CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo,
bool HasMainE0BlockLoop,
bool HasMainE1BlockLoop,
bool HasDoubleTailE1BlockLoop>
__global__ void
......@@ -109,6 +112,7 @@ __global__ void
b_e0_e1_n_ho_wo_e2_grid_desc,
d_k_n_hox2_wox2_grid_desc,
c_k_n_ho_wo_grid_desc,
integral_constant<bool, HasMainE0BlockLoop>,
integral_constant<bool, HasMainE1BlockLoop>{},
integral_constant<bool, HasDoubleTailE1BlockLoop>{});
}
......@@ -182,7 +186,7 @@ struct GridwiseGemmDlops_km_kn_mn_add_v3
return a_block_space_size * sizeof(FloatAB);
}
template <bool HasMainE1BlockLoop, bool HasDoubleTailE1BlockLoop>
template <bool HasMainE0BlockLoop, bool HasMainE1BlockLoop, bool HasDoubleTailE1BlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
const FloatC* __restrict__ p_d_global,
......@@ -192,6 +196,7 @@ struct GridwiseGemmDlops_km_kn_mn_add_v3
const BGlobalDesc_E0_E1_N_Ho_Wo_E2& b_e0_e1_n_ho_wo_e2_global_desc,
const DGlobalDesc_K_N_Hox2_Wox2& d_k_n_hox2_wox2_global_desc,
const CGlobalDesc_K_N_Ho_Wo& c_k_n_ho_wo_global_desc,
integral_constant<bool, HasMainE0BlockLoop>,
integral_constant<bool, HasMainE1BlockLoop>,
integral_constant<bool, HasDoubleTailE1BlockLoop>)
{
......@@ -375,8 +380,6 @@ struct GridwiseGemmDlops_km_kn_mn_add_v3
true>
b_thread_even_buf, b_thread_odd_buf;
constexpr auto HasMainE0BlockLoop = false;
if constexpr(HasMainE0BlockLoop)
{
const auto E0 = b_e0_e1_n_ho_wo_e2_global_desc.GetLength(I0);
......@@ -593,14 +596,7 @@ struct GridwiseGemmDlops_km_kn_mn_add_v3
}
else if constexpr(activ_type == 2)
{
// const auto x = c_thread_buf[i];
// constexpr auto log2_e = FloatAcc(1.44269504089);
// const auto r = 1.0 + pow(2, -x * log2_e);
// c_thread_buf(i) = 1.0 / r;
c_thread_buf(i) = 1.0 / (1.0 + expf(-c_thread_buf[i]));
// c_thread_buf(i) = 0.5 * (x / (1 + abs(x))) + 0.5;
// c_thread_buf(i) = x / sqrt(1 + x * x);
}
});
}
......@@ -627,6 +623,8 @@ struct GridwiseGemmDlops_km_kn_mn_add_v3
// Resize_Add
{
#if 1
ThreadwiseTensorSliceTransfer_v2<FloatC,
FloatC,
decltype(d_k_n_hox2_wox2_global_desc),
......@@ -647,6 +645,7 @@ struct GridwiseGemmDlops_km_kn_mn_add_v3
make_tuple(I0, I0, I0, I0),
d_thread_buf,
c_k_n_ho_wo_global_tensor_step_hacks);
#endif
static_for<0, KPerThread, 1>{}([&](auto k_i) {
static_for<0, HoPerThreadx2, 1>{}([&](auto h_i) {
......
......@@ -344,10 +344,10 @@ struct GridwiseGemmDlops_km_kn_mn_v3
c_thread_buf;
// initialize output thread tensor
// ThreadwiseTensorSliceSet_v1<FloatAcc,
// decltype(c_k_n_ho_wo_thread_desc),
// Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{}
//.Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_k_n_ho_wo_thread_desc),
Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{}
.Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
constexpr auto b_thread_slice_copy_step = make_multi_index(0, E1PerBlock, 0, 0, 0, 0);
......@@ -578,14 +578,14 @@ struct GridwiseGemmDlops_km_kn_mn_v3
}
else if constexpr(activ_type == 2)
{
const auto x = c_thread_buf[i];
// constexpr auto log2_e = FloatAcc(1.44269504089);
// const auto r = 1.0 + pow(2, -x * log2_e);
// c_thread_buf(i) = 1.0 / r;
c_thread_buf(i) = 1.0 / (1.0 + expf(-c_thread_buf[i]));
// c_thread_buf(i) = 0.5 * (x / (1 + abs(x))) + 0.5;
// c_thread_buf(i) = x / sqrt(1 + x * x);
FloatAcc x = 1.0 + exp(-c_thread_buf[i]);
asm volatile("\n \
v_rcp_f32 %0, %1 \n"
: "=v"(x)
: "0"(x));
c_thread_buf(i) = x;
}
});
}
......
......@@ -141,7 +141,7 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
constexpr index_t HoPerBlock = 8;
constexpr index_t WoPerBlock = 32;
constexpr index_t E1 = C0 * 9;
constexpr index_t E1 = C0;
constexpr index_t E2 = 1;
constexpr index_t EPerBlock = C0;
......@@ -150,7 +150,7 @@ void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = 1;
using ABlockTransferThreadSliceLengths_E0_E1_K_E2 = Sequence<1, 9, 1, E2>;
using ABlockTransferThreadSliceLengths_E0_E1_K_E2 = Sequence<1, 1, 1, E2>;
using ABlockTransferThreadClusterLengths_E0_E1_K_E2 = Sequence<1, EPerBlock, KPerBlock, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
......
......@@ -297,11 +297,12 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
const bool has_main_k_block_loop = (E1 + E1PerBlock) / (2 * E1PerBlock) > 1;
const bool has_main_E0_block_loop = E0 > 1;
const bool has_main_E1_block_loop = (E1 + E1PerBlock) / (2 * E1PerBlock) > 1;
const bool has_double_tail_k_block_loop = (E1 / E1PerBlock) % 2 == 0;
std::cerr << "has_main_k_block_loop = " << has_main_k_block_loop
std::cerr << "has_main_E0_block_loop = " << has_main_E0_block_loop
<< "has_main_E1_block_loop = " << has_main_E1_block_loop
<< " has_double_tail_k_block_loop = " << has_double_tail_k_block_loop
<< std::endl;
......@@ -316,7 +317,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
float ave_time = 0;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
if constexpr(has_main_k_block_loop && has_double_tail_k_block_loop)
if constexpr(has_main_E1_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_gemm_dlops_add_v2<
GridwiseGemm,
......@@ -327,6 +328,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
remove_reference_t<DGridDesc_K_N_Hox2_Wox2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
has_main_E0_block_loop,
true,
true>;
......@@ -345,7 +347,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
c_k_n_hop_wop_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
}
else if constexpr(has_main_k_block_loop && !has_double_tail_k_block_loop)
else if constexpr(has_main_E1_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = kernel_gemm_dlops_add_v2<
GridwiseGemm,
......@@ -356,6 +358,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
remove_reference_t<DGridDesc_K_N_Hox2_Wox2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
has_main_E0_block_loop,
true,
false>;
......@@ -374,7 +377,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
c_k_n_hop_wop_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
}
else if constexpr(!has_main_k_block_loop && has_double_tail_k_block_loop)
else if constexpr(!has_main_E1_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_gemm_dlops_add_v2<
GridwiseGemm,
......@@ -385,6 +388,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
remove_reference_t<DGridDesc_K_N_Hox2_Wox2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
has_main_E0_block_loop,
false,
true>;
......@@ -414,6 +418,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
remove_reference_t<DGridDesc_K_N_Hox2_Wox2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
has_main_E0_block_loop,
false,
false>;
......@@ -447,7 +452,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.ToDevice(
&c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
if constexpr(has_main_k_block_loop && has_double_tail_k_block_loop)
if constexpr(has_main_E1_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_gemm_dlops_add_v2<
GridwiseGemm,
......@@ -457,6 +462,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
has_main_E0_block_loop,
true,
true>;
......@@ -481,7 +487,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
cast_pointer_to_constant_address_space(
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
}
else if constexpr(has_main_k_block_loop && !has_double_tail_k_block_loop)
else if constexpr(has_main_E1_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = kernel_gemm_dlops_add_v2<
GridwiseGemm,
......@@ -491,6 +497,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
has_main_E0_block_loop,
true,
false>;
......@@ -515,7 +522,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
cast_pointer_to_constant_address_space(
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
}
else if constexpr(!has_main_k_block_loop && has_double_tail_k_block_loop)
else if constexpr(!has_main_E1_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_gemm_dlops_add_v2<
GridwiseGemm,
......@@ -525,6 +532,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
has_main_E0_block_loop,
false,
true>;
......@@ -559,6 +567,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
has_main_E0_block_loop,
false,
false>;
......
......@@ -279,12 +279,12 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
const bool has_main_e0_block_loop = E0 > 1;
const bool has_main_e1_block_loop = (E1 + E1PerBlock) / (2 * E1PerBlock) > 1;
const bool has_double_tail_k_block_loop = (E1 / E1PerBlock) % 2 == 0;
const bool has_main_e0_block_loop = E0 > 1;
const bool has_main_e1_block_loop = (E1 + E1PerBlock) / (2 * E1PerBlock) > 1;
const bool has_double_tail_e1_block_loop = (E1 / E1PerBlock) % 2 == 0;
std::cerr << "has_main_e1_block_loop = " << has_main_e1_block_loop
<< " has_double_tail_k_block_loop = " << has_double_tail_k_block_loop
<< " has_double_tail_e1_block_loop = " << has_double_tail_e1_block_loop
<< std::endl;
const auto c_blockid_to_k_n_ho_wo_block_cluster_adaptor =
......@@ -298,114 +298,31 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
float ave_time = 0;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
if constexpr(has_main_e1_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
has_main_e0_block_loop,
true,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_e0_e1_k_e2_grid_desc,
b_e0_e1_n_ho_wo_e2_grid_desc,
c_k_n_hop_wop_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
}
else if constexpr(has_main_e1_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
has_main_e0_block_loop,
true,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_e0_e1_k_e2_grid_desc,
b_e0_e1_n_ho_wo_e2_grid_desc,
c_k_n_hop_wop_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
}
else if constexpr(!has_main_e1_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
has_main_e0_block_loop,
false,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_e0_e1_k_e2_grid_desc,
b_e0_e1_n_ho_wo_e2_grid_desc,
c_k_n_hop_wop_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
}
else
{
const auto kernel =
kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
has_main_e0_block_loop,
false,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_e0_e1_k_e2_grid_desc,
b_e0_e1_n_ho_wo_e2_grid_desc,
c_k_n_hop_wop_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
}
const auto kernel =
kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
has_main_e0_block_loop,
has_main_e1_block_loop,
has_double_tail_e1_block_loop>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_e0_e1_k_e2_grid_desc,
b_e0_e1_n_ho_wo_e2_grid_desc,
c_k_n_hop_wop_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_e0_e1_k_e2_grid_desc_dev_buf(sizeof(AGridDesc_E0_E1_K_E2));
DeviceMem b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf(sizeof(BGridDesc_E0_E1_N_Ho_Wo_E2));
......@@ -419,134 +336,36 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.ToDevice(
&c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
if constexpr(has_main_e1_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
has_main_e0_block_loop,
true,
true>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
cast_pointer_to_constant_address_space(
a_e0_e1_k_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
}
else if constexpr(has_main_e1_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
has_main_e0_block_loop,
true,
false>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
cast_pointer_to_constant_address_space(
a_e0_e1_k_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
}
else if constexpr(!has_main_e1_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
has_main_e0_block_loop,
false,
true>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
cast_pointer_to_constant_address_space(
a_e0_e1_k_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
}
else
{
const auto kernel =
kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
has_main_e0_block_loop,
false,
false>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
cast_pointer_to_constant_address_space(
a_e0_e1_k_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
}
const auto kernel =
kernel_gemm_dlops_v2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AGridDesc_E0_E1_K_E2>,
remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
remove_reference_t<CGridDesc_K_N_Ho_Wo>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
has_main_e0_block_loop,
has_main_e1_block_loop,
has_double_tail_e1_block_loop>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
cast_pointer_to_constant_address_space(
a_e0_e1_k_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
}
#endif
return ave_time;
}
......
......@@ -108,14 +108,14 @@ int main(int argc, char* argv[])
const bool do_log = std::stoi(argv[5]);
const int nrepeat = std::stoi(argv[6]);
constexpr index_t activ_type = 0;
constexpr index_t activ_type = 2;
#if 1
#if 0
constexpr auto N = Number<1>{};
constexpr auto C = Number<16>{};
constexpr auto Hi = Number<1080>{};
constexpr auto Wi = Number<1920>{};
constexpr auto K = Number<16>{};
constexpr auto K = Number<64>{};
constexpr auto Y = Number<3>{};
constexpr auto X = Number<3>{};
#elif 0
......
......@@ -186,25 +186,27 @@ void host_direct_convolution_add(const Tensor<TIn>& in,
out(n, k, hox2 + 1, wox2 + 1) = v + add(n, k, hox2 + 1, wox2 + 1);
};
switch(layout)
if(layout == ConvTensorLayout::NCHW)
{
case ConvTensorLayout::NCHW:
make_ParallelTensorFunctor(f_nchw,
out.mDesc.GetLengths()[0],
out.mDesc.GetLengths()[1],
out.mDesc.GetLengths()[2] / 2,
out.mDesc.GetLengths()[3] /
2)(std::thread::hardware_concurrency());
break;
case ConvTensorLayout::NHWC:
}
else if(layout == ConvTensorLayout::NHWC)
{
make_ParallelTensorFunctor(f_nhwc,
out.mDesc.GetLengths()[0],
out.mDesc.GetLengths()[1],
out.mDesc.GetLengths()[2] / 2,
out.mDesc.GetLengths()[3] /
2)(std::thread::hardware_concurrency());
break;
default: throw std::runtime_error("wrong! not supported layout");
}
else
{
throw std::runtime_error("wrong! not supported layout");
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment