Commit c94efafd authored by Chao Liu's avatar Chao Liu
Browse files

default iterator hack for blockwise copy

parent 1c1b56fe
Showing with 12 additions and 2 deletions
+12 -2
......@@ -89,6 +89,16 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
}
}
template <typename SrcBuffer>
__device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_desc, src_buf);
}
}
template <typename DstBuffer>
__device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
{
......
......@@ -150,7 +150,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmK1 = 8;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 1;
......@@ -164,7 +164,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment