Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
liyinrong
composable_kernel
Commits
b01c83d1
Commit
b01c83d1
authored
3 years ago
by
ltqin
Browse files
Options
Download
Email Patches
Plain Diff
just for fp32
parent
074f7410
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp
+1
-57
...ght_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp
with
1 addition
and
57 deletions
+1
-57
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp
View file @
b01c83d1
...
...
@@ -48,35 +48,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
const
auto
wei_k_y_x_c_desc
=
make_naive_tensor_descriptor_packed
(
wei_k_y_x_c_lengths
);
const
auto
out_n_ho_wo_k_desc
=
make_naive_tensor_descriptor_packed
(
out_n_ho_wo_k_lengths
);
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif
1
#if 1
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
constexpr
index_t
BlockSize
=
256
;
...
...
@@ -106,34 +78,6 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
KBatch
=
32
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerXDL
=
32
;
constexpr
index_t
GemmNPerXDL
=
32
;
constexpr
index_t
GemmK1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
4
,
4
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
4
,
32
,
2
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmM
=
4
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmK1
=
4
;
using
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
1
,
4
,
4
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
4
,
32
,
2
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
4
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
4
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#endif
const
auto
descs
=
transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad
(
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment