Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
liyinrong
composable_kernel
Commits
be59d8a8
Commit
be59d8a8
authored
2 years ago
by
Po-Yen, Chen
Browse files
Options
Download
Email Patches
Plain Diff
Refactor the design of DeviceGemmMultipleDMultipleR_Xdl_CShuffle
parent
c366de55
feature/refactor-gemm-multiple-d-multiple-r
feature/add-convnd-fwd-reduce-examples
No related merge requests found
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
include/ck/tensor_operation/gpu/device/device_gemm.hpp
+0
-3
include/ck/tensor_operation/gpu/device/device_gemm.hpp
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp
+1
-1
...ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp
+14
-2
...peration/gpu/device/device_gemm_multiple_d_multiple_r.hpp
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp
+43
-234
...device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp
+62
-20
...grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp
include/ck/utility/reduction_operator.hpp
+1
-1
include/ck/utility/reduction_operator.hpp
library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp
+2
-0
...include/ck/library/tensor_operation_instance/gpu/gemm.hpp
with
123 additions
and
261 deletions
+123
-261
include/ck/tensor_operation/gpu/device/device_gemm.hpp
View file @
be59d8a8
...
...
@@ -3,9 +3,6 @@
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
...
...
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp
View file @
be59d8a8
...
...
@@ -5,7 +5,7 @@
#include <array>
#include "device_base.hpp"
#include "
ck/tensor_operation/gpu/device/
device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp
View file @
be59d8a8
...
...
@@ -3,15 +3,27 @@
#pragma once
#include <
iostream
>
#include <
array
>
#include "device_base.hpp"
#include "
ck/tensor_operation/gpu/device/
device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// FIXME: DeviceGemmReduce type need to well define the problem
// GEMM:
// input : A[AK0, M, AK1]
// input : B[AK0, N, AK1]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// output : R0[M], R1[M], ...
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Q0 = reduce0(q_op0(E)), Q1 = reduce1(q_op0(E)), ...
// R0 = r_op0(Q0), R1 = r_op1(Q1), ...
// Assume:
// D0, D1, ... and E have the same layout
template
<
typename
ALayout
,
typename
BLayout
,
typename
DELayout
,
...
...
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp
View file @
be59d8a8
...
...
@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
...
@@ -192,7 +193,10 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
static
auto
MakeAGridDescriptor_M_K
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
...
...
@@ -207,95 +211,10 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
}
}();
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
)
*
KPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both M and K
assert
(
K
%
AK1
==
0
);
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad M, but not K
assert
(
KRaw
%
AK1
==
0
);
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_right_pad_transform
(
MRaw
,
MPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad K, but not M
assert
(
K
%
AK1
==
0
);
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_pass_through_transform
(
MRaw
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
MRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
{
// not pad M or K
assert
(
KRaw
%
AK1
==
0
);
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
MRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
return
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
}
static
auto
MakeBGridDescriptor_
BK0_N_BK1
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
static
auto
MakeBGridDescriptor_
N_K
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
...
...
@@ -310,92 +229,7 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
}
}();
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)
*
NPerBlock
;
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
)
*
KPerBlock
;
const
auto
NPad
=
N
-
NRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both N and K
assert
(
K
%
BK1
==
0
);
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_right_pad_transform
(
NRaw
,
NPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad N, but not K
assert
(
KRaw
%
BK1
==
0
);
const
auto
BK0
=
KRaw
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad K, but not N
assert
(
K
%
BK1
==
0
);
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_pass_through_transform
(
NRaw
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
{
// not pad N or K
assert
(
KRaw
%
BK1
==
0
);
const
auto
BK0
=
KRaw
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
return
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
}
static
auto
MakeEGridDescriptor_M_N
(
index_t
MRaw
,
index_t
NRaw
,
index_t
StrideE
)
...
...
@@ -413,47 +247,7 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
}
}();
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)
*
NPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
NPad
=
N
-
NRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad M and N
return
transform_tensor_descriptor
(
e_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad M, but not N
return
transform_tensor_descriptor
(
e_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad N, but not M
return
transform_tensor_descriptor
(
e_grid_desc_mraw_nraw
,
make_tuple
(
make_pass_through_transform
(
MRaw
),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
{
// not pad M or N
return
e_grid_desc_mraw_nraw
;
}
return
matrix_padder
.
PadCDescriptor_M_N
(
e_grid_desc_mraw_nraw
);
}
// assume D is packed tensor
...
...
@@ -482,10 +276,10 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
}
}
using
AGridDesc_
AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_
AK0_M_AK1
(
1
,
1
,
1
));
using
BGridDesc_
BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_
BK0_N_BK1
(
1
,
1
,
1
));
using
EGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
(
1
,
1
,
1
));
using
RGridDesc_M
=
decltype
(
MakeRGridDescriptor_M
(
1
));
using
AGridDesc_
M_K
=
decltype
(
MakeAGridDescriptor_
M_K
(
1
,
1
,
1
));
using
BGridDesc_
N_K
=
decltype
(
MakeBGridDescriptor_
N_K
(
1
,
1
,
1
));
using
EGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
(
1
,
1
,
1
));
using
RGridDesc_M
=
decltype
(
MakeRGridDescriptor_M
(
1
));
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
<
...
...
@@ -504,8 +298,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
ThreadReduceOperations
,
InMemoryDataOperationEnum
::
Set
,
RsGlobalMemoryDataOperation
,
AGridDesc_
AK0_M_AK1
,
BGridDesc_
BK0_N_BK1
,
AGridDesc_
M_K
,
BGridDesc_
N_K
,
EGridDesc_M_N
,
RGridDesc_M
,
NumGemmKPrefetchStage
,
...
...
@@ -542,6 +336,13 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
RThreadTransferDstScalarPerVector_MPerBlock
,
LoopSched
>
;
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
AGridDesc_M_K
{}))
>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
using
Block2ETileMap
=
typename
GridwiseGemm
::
DefaultBlock2ETileMap
;
// Argument
struct
Argument
:
public
BaseArgument
{
...
...
@@ -567,12 +368,16 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
p_ds_grid_
{},
// FIXME
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
p_rs_grid_
{},
// FIXME
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
MRaw
,
KRaw
,
StrideA
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
KRaw
,
NRaw
,
StrideB
)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
a_grid_desc_m_k_
{
DeviceOp
::
MakeAGridDescriptor_M_K
(
MRaw
,
KRaw
,
StrideA
)},
b_grid_desc_n_k_
{
DeviceOp
::
MakeBGridDescriptor_N_K
(
KRaw
,
NRaw
,
StrideB
)},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideE
)},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
r_grid_desc_m_
{
DeviceOp
::
MakeRGridDescriptor_M
(
MRaw
)},
a_grid_desc_ak0_m_ak1_
{
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
a_grid_desc_m_k_
)},
b_grid_desc_bk0_n_bk1_
{
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
b_grid_desc_n_k_
)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
rs_grid_desc_mblock_mperblock_
{},
block_2_etile_map_
{
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
...
...
@@ -581,8 +386,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
qs_element_op_
{
qs_element_op
},
rs_element_op_
{
rs_element_op
}
{
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_
ak0_m_ak1
_
,
b_grid_desc_
bk0_n_bk1
_
,
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_
m_k
_
,
b_grid_desc_
n_k
_
,
e_grid_desc_m_n_
,
r_grid_desc_m_
,
block_2_etile_map_
))
...
...
@@ -624,6 +429,12 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
typename
GridwiseGemm
::
RsGridPointer
p_rs_grid_
;
// tensor descriptors
AGridDesc_M_K
a_grid_desc_m_k_
;
BGridDesc_N_K
b_grid_desc_n_k_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
RGridDesc_M
r_grid_desc_m_
;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
StaticallyIndexedArray
<
...
...
@@ -631,16 +442,14 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
NumDTensor
>
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
// FIXME: Ds desc may be of different
// type from E
EGridDesc_M_N
e_grid_desc_m_n_
;
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
RGridDesc_M
r_grid_desc_m_
;
StaticallyIndexedArray
<
typename
GridwiseGemm
::
RGridDescriptor_MBlock_MPerBlock
,
NumRTensor
>
rs_grid_desc_mblock_mperblock_
;
// block-to-e-tile map
typename
GridwiseGemm
::
Default
Block2ETileMap
block_2_etile_map_
;
Block2ETileMap
block_2_etile_map_
;
// element-wise op
AElementwiseOperation
a_element_op_
;
...
...
@@ -657,8 +466,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
ak0_m_ak1
_
,
arg
.
b_grid_desc_
bk0_n_bk1
_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
m_k
_
,
arg
.
b_grid_desc_
n_k
_
,
arg
.
e_grid_desc_m_n_
,
arg
.
r_grid_desc_m_
,
arg
.
block_2_etile_map_
))
...
...
@@ -750,8 +559,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
return
false
;
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
ak0_m_ak1
_
,
arg
.
b_grid_desc_
bk0_n_bk1
_
,
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
m_k
_
,
arg
.
b_grid_desc_
n_k
_
,
arg
.
e_grid_desc_m_n_
,
arg
.
r_grid_desc_m_
,
arg
.
block_2_etile_map_
);
...
...
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp
View file @
be59d8a8
...
...
@@ -32,8 +32,8 @@ template <typename FloatAB,
typename
ThreadReduceOperations
,
InMemoryDataOperationEnum
EGlobalMemoryDataOperation
,
typename
RsGlobalMemoryDataOperation
,
typename
AGridDesc_
AK0_M_AK1
,
typename
BGridDesc_
BK0_N_BK1
,
typename
AGridDesc_
M_K
,
typename
BGridDesc_
N_K
,
typename
EGridDesc_M_N
,
typename
RGridDesc_M
,
index_t
NumGemmKPrefetchStage
,
...
...
@@ -84,10 +84,10 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static
constexpr
auto
I7
=
Number
<
7
>
{};
// K1 should be Number<...>
static
constexpr
auto
AK
0
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
BK
0
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
AK
1
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK
1
=
Number
<
BK1Value
>
{};
static
constexpr
auto
AK
1
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK
1
=
Number
<
BK1Value
>
{};
static
constexpr
auto
AK
0PerBlock
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
BK
0PerBlock
=
Number
<
KPerBlock
/
BK1Value
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
...
...
@@ -97,7 +97,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
// A matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
AK0
,
Number
<
MPerBlock
>
{},
AK1
),
make_tuple
(
AK0
PerBlock
,
Number
<
MPerBlock
>
{},
AK1
),
make_tuple
(
Number
<
MPerBlock
+
ABlockLdsExtraM
>
{}
*
AK1
,
AK1
,
I1
));
}
...
...
@@ -105,7 +105,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
// B matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
BK0
,
Number
<
NPerBlock
>
{},
BK1
),
make_tuple
(
BK0
PerBlock
,
Number
<
NPerBlock
>
{},
BK1
),
make_tuple
(
Number
<
NPerBlock
+
BBlockLdsExtraN
>
{}
*
BK1
,
BK1
,
I1
));
}
...
...
@@ -167,22 +167,57 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_block_size
*
sizeof
(
FloatCShuffle
));
}
// A desc for source in blockwise copy
__host__
__device__
static
constexpr
auto
MakeDefaultAGridDescriptor_AK0_M_AK1
(
const
AGridDesc_M_K
&
a_grid_desc_m_k
)
{
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
const
auto
AK0
=
K
/
AK1
;
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
// B desc for source in blockwise copy
__host__
__device__
static
constexpr
auto
MakeDefaultBGridDescriptor_BK0_N_BK1
(
const
BGridDesc_N_K
&
b_grid_desc_n_k
)
{
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b_grid_desc_n_k
.
GetLength
(
I1
);
const
auto
BK0
=
K
/
BK1
;
return
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2ETileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
EGridDesc_M_N
&
e_grid_desc_m_n
,
const
RGridDesc_M
&
r_grid_desc_m
,
const
Block2ETileMap
&
block_2_etile_map
)
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_M_K
&
a_grid_desc_m_k
,
const
BGridDesc_N_K
&
b_grid_desc_n_k
,
const
EGridDesc_M_N
&
e_grid_desc_m_n
,
const
RGridDesc_M
&
r_grid_desc_m
,
const
Block2ETileMap
&
block_2_etile_map
)
{
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
"Invalid tuning param!"
);
const
auto
M
=
a_grid_desc_ak0_m_ak1
.
GetLength
(
I1
);
const
auto
N
=
b_grid_desc_bk0_n_bk1
.
GetLength
(
I1
);
const
auto
K
=
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
);
static_assert
(
AGridDesc_M_K
::
GetNumOfDimension
()
==
2
);
static_assert
(
BGridDesc_N_K
::
GetNumOfDimension
()
==
2
);
static_assert
(
EGridDesc_M_N
::
GetNumOfDimension
()
==
2
);
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
if
(
!
(
M
==
e_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
e_grid_desc_m_n
.
GetLength
(
I1
)))
return
false
;
...
...
@@ -259,6 +294,10 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
e_grid_desc_m_n
);
}
using
DefaultAGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
MakeDefaultAGridDescriptor_AK0_M_AK1
(
AGridDesc_M_K
{}))
>
;
using
DefaultBGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
using
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
...
...
@@ -272,7 +311,10 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using
DsGridPointer
=
decltype
(
MakeTsGridPointer
<
DsDataType
,
true
>
());
using
RsGridPointer
=
decltype
(
MakeTsGridPointer
<
RsDataType
,
false
>
());
template
<
bool
HasMainKBlockLoop
,
typename
Block2ETileMap
>
template
<
bool
HasMainKBlockLoop
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
Block2ETileMap
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
...
...
@@ -356,7 +398,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
AK0
,
MPerBlock
,
AK1
>
,
Sequence
<
AK0
PerBlock
,
MPerBlock
,
AK1
>
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
FloatAB
,
...
...
@@ -387,7 +429,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
BK0
,
NPerBlock
,
BK1
>
,
Sequence
<
BK0
PerBlock
,
NPerBlock
,
BK1
>
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
FloatAB
,
...
...
This diff is collapsed.
Click to expand it.
include/ck/utility/reduction_operator.hpp
View file @
be59d8a8
...
...
@@ -79,7 +79,7 @@ struct SquaredAdd
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the
Max
accumulator!"
);
"The data type is not supported by the
SquaredAdd
accumulator!"
);
a
=
a
+
b
*
b
;
}
...
...
This diff is collapsed.
Click to expand it.
library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp
View file @
be59d8a8
...
...
@@ -4,6 +4,8 @@
#pragma once
#include <cstdlib>
#include <memory>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment
Menu
Projects
Groups
Snippets
Help