Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
liyinrong
composable_kernel
Commits
91b19b3c
Commit
91b19b3c
authored
3 years ago
by
Chao Liu
Browse files
Options
Download
Email Patches
Plain Diff
clean
parent
4511f877
No related merge requests found
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp
+43
-30
...operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp
profiler/include/profile_gemm_splitk_impl.hpp
+3
-4
profiler/include/profile_gemm_splitk_impl.hpp
profiler/src/profile_gemm.cpp
+42
-256
profiler/src/profile_gemm.cpp
profiler/src/profile_gemm_splitk.cpp
+40
-137
profiler/src/profile_gemm_splitk.cpp
profiler/src/profiler.cpp
+28
-18
profiler/src/profiler.cpp
with
156 additions
and
445 deletions
+156
-445
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp
View file @
91b19b3c
...
...
@@ -167,12 +167,14 @@ struct DeviceGemmXdlSplitKCShuffle
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
auto
GetActualBatchAndK
Splitted
(
index_t
KRaw
,
index_t
KBatch
)
static
auto
GetActualBatchAndK
PerBatch
(
index_t
KRaw
,
index_t
KBatch
Desired
)
{
const
index_t
K
Splitted
=
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
*
KBatch
)
*
KPerBlock
;
const
index_t
actual_k_batch
=
math
::
integer_divide_ceil
(
KRaw
,
K
Splitted
)
;
const
index_t
K
PerBatch
=
math
::
integer_divide_ceil
(
KRaw
,
K
PerBlock
*
KBatchDesired
)
*
KPerBlock
;
return
std
::
make_pair
(
actual_k_batch
,
KSplitted
);
const
index_t
KBatch
=
math
::
integer_divide_ceil
(
KRaw
,
KPerBatch
);
return
std
::
make_tuple
(
KBatch
,
KPerBatch
);
}
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
...
...
@@ -488,32 +490,33 @@ struct DeviceGemmXdlSplitKCShuffle
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
index_t
k_batch
)
index_t
k_batch
_desired
)
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
BatchCount_
(
k_batch
),
KBatch_
{
0
},
has_k_batch_tail_
{
false
},
compute_ptr_offset_of_batch_
{
0
,
0
},
block_2_ctile_map_
{},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
}
{
const
auto
actual_batch_and_ksplitted
=
GetActualBatchAndKSplitted
(
KRaw
,
k_batch
);
BatchCount_
=
actual_batch_and_ksplitted
.
first
;
index_t
KPerBatch
=
0
;
const
auto
KSplitted
=
a
ctual
_b
atch
_and_ksplitted
.
second
;
std
::
tie
(
KBatch_
,
KPerBatch
)
=
GetA
ctual
B
atch
AndKPerBatch
(
KRaw
,
k_batch_desired
)
;
a_grid_desc_ak0_m_ak1_
=
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
MRaw
,
K
Splitted
,
StrideA
);
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
MRaw
,
K
PerBatch
,
StrideA
);
b_grid_desc_bk0_n_bk1_
=
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
K
Splitted
,
NRaw
,
StrideB
);
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
K
PerBatch
,
NRaw
,
StrideB
);
c_grid_desc_m_n_
=
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideC
);
if
(
KRaw
!=
K
Splitted
*
BatchCount_
||
KRaw
!=
KSplitted
*
BatchCount
_
)
if
(
KRaw
!=
K
PerBatch
*
KBatch
_
)
{
const
auto
KTail
=
KRaw
-
KSplitted
*
(
BatchCount_
-
1
);
has_k_batch_tail_
=
true
;
const
auto
KTail
=
KRaw
-
KPerBatch
*
(
KBatch_
-
1
);
a_grid_desc_ak0_m_ak1_tail_
=
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
MRaw
,
KTail
,
StrideA
);
...
...
@@ -525,27 +528,27 @@ struct DeviceGemmXdlSplitKCShuffle
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n_
);
const
index_t
a_batch_stride
=
[
K
Splitted
,
StrideA
]()
{
const
index_t
a_batch_stride
=
[
K
PerBatch
,
StrideA
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
{
ignore
=
StrideA
;
return
K
Splitted
;
return
K
PerBatch
;
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>::
value
)
{
return
K
Splitted
*
StrideA
;
return
K
PerBatch
*
StrideA
;
}
}();
const
index_t
b_batch_stride
=
[
K
Splitted
,
StrideB
]()
{
const
index_t
b_batch_stride
=
[
K
PerBatch
,
StrideB
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
return
K
Splitted
*
StrideB
;
return
K
PerBatch
*
StrideB
;
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
{
ignore
=
StrideB
;
return
K
Splitted
;
return
K
PerBatch
;
}
}();
...
...
@@ -553,14 +556,15 @@ struct DeviceGemmXdlSplitKCShuffle
ComputePtrOffsetOfStridedBatch
{
a_batch_stride
,
b_batch_stride
};
block_2_ctile_map_
=
MakeBlock2CTileMap
(
Batch
Count
_
,
c_grid_desc_m_n_
.
GetLength
(
I0
),
c_grid_desc_m_n_
.
GetLength
(
I1
),
1
,
1
);
K
Batch_
,
c_grid_desc_m_n_
.
GetLength
(
I0
),
c_grid_desc_m_n_
.
GetLength
(
I1
),
1
,
1
);
}
// private:
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
index_t
BatchCount_
;
index_t
KBatch_
;
bool
has_k_batch_tail_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_tail_
;
...
...
@@ -588,7 +592,7 @@ struct DeviceGemmXdlSplitKCShuffle
}
{
std
::
cout
<<
"k_batch = "
<<
arg
.
Batch
Count
_
<<
"
\n
"
;
std
::
cout
<<
"k_batch
_desired
= "
<<
arg
.
K
Batch_
<<
"
\n
"
;
std
::
cout
<<
"arg.a_grid_desc_ak0_m_ak1_{"
<<
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
)
<<
", "
...
...
@@ -614,7 +618,7 @@ struct DeviceGemmXdlSplitKCShuffle
}
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
)
*
arg
.
Batch
Count
_
;
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
)
*
arg
.
K
Batch_
;
const
auto
K0
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
...
...
@@ -634,7 +638,7 @@ struct DeviceGemmXdlSplitKCShuffle
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
Batch
Count
_
,
arg
.
K
Batch_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
...
...
@@ -742,11 +746,20 @@ struct DeviceGemmXdlSplitKCShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m_n_
)
&&
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_tail_
,
arg
.
b_grid_desc_bk0_n_bk1_tail_
,
arg
.
c_grid_desc_m_n_
);
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m_n_
))
{
return
false
;
}
if
(
arg
.
has_k_batch_tail_
&&
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_tail_
,
arg
.
b_grid_desc_bk0_n_bk1_tail_
,
arg
.
c_grid_desc_m_n_
))
{
return
false
;
}
return
true
;
}
// polymorphic
...
...
This diff is collapsed.
Click to expand it.
profiler/include/profile_gemm_splitk_impl.hpp
View file @
91b19b3c
...
...
@@ -98,9 +98,6 @@ bool profile_gemm_splitk_impl(int do_verification,
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
},
num_thread
);
}
// set zero to c_device_buf
c_m_n_device_result
.
GenerateTensorValue
(
GeneratorTensor_0
<
CDataType
>
{},
num_thread
);
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
...
...
@@ -115,7 +112,6 @@ bool profile_gemm_splitk_impl(int do_verification,
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
c_device_buf
.
ToDevice
(
c_m_n_device_result
.
mData
.
data
());
// add device GEMM instances
std
::
vector
<
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
DeviceGemmNoOpPtr
>
gemm_ptrs
;
...
...
@@ -196,6 +192,8 @@ bool profile_gemm_splitk_impl(int do_verification,
// profile device GEMM instances
for
(
auto
&
gemm_ptr
:
gemm_ptrs
)
{
std
::
cout
<<
gemm_ptr
->
GetTypeString
()
<<
std
::
endl
;
auto
argument_ptr
=
gemm_ptr
->
MakeArgumentPointer
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_device_buf
.
GetDeviceBuffer
()),
...
...
@@ -263,6 +261,7 @@ bool profile_gemm_splitk_impl(int do_verification,
a_m_k
,
b_k_n
,
c_m_n_host_result
,
a_element_op
,
b_element_op
,
c_element_op
);
ref_invoker
.
Run
(
ref_argument
);
pass
=
pass
&&
ck
::
utils
::
check_err
(
c_m_n_device_result
.
mData
,
c_m_n_host_result
.
mData
);
...
...
This diff is collapsed.
Click to expand it.
profiler/src/profile_gemm.cpp
View file @
91b19b3c
...
...
@@ -57,309 +57,95 @@ bool profile_gemm(int argc, char* argv[])
const
int
StrideB
=
std
::
stoi
(
argv
[
12
]);
const
int
StrideC
=
std
::
stoi
(
argv
[
13
]);
auto
profile
=
[
&
](
auto
a_type
,
auto
b_type
,
auto
c_type
,
auto
a_layout
,
auto
b_layout
,
auto
c_layout
)
{
using
ADataType
=
decltype
(
a_type
);
using
BDataType
=
decltype
(
b_type
);
using
CDataType
=
decltype
(
c_type
);
using
ALayout
=
decltype
(
a_layout
);
using
BLayout
=
decltype
(
b_layout
);
using
CLayout
=
decltype
(
c_layout
);
return
ck
::
profiler
::
profile_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
};
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
return
ck
::
profiler
::
profile_gemm_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
return
profile
(
ck
::
half_t
{},
ck
::
half_t
{},
ck
::
half_t
{},
Row
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
ck
::
profiler
::
profile_gemm_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
return
profile
(
ck
::
half_t
{},
ck
::
half_t
{},
ck
::
half_t
{},
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
return
ck
::
profiler
::
profile_gemm_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
return
profile
(
ck
::
half_t
{},
ck
::
half_t
{},
ck
::
half_t
{},
Col
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
return
ck
::
profiler
::
profile_gemm_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
return
profile
(
ck
::
half_t
{},
ck
::
half_t
{},
ck
::
half_t
{},
Col
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
return
ck
::
profiler
::
profile_gemm_impl
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
return
profile
(
float
{},
float
{},
float
{},
Row
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
ck
::
profiler
::
profile_gemm_impl
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
return
profile
(
float
{},
float
{},
float
{},
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
return
ck
::
profiler
::
profile_gemm_impl
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
return
profile
(
float
{},
float
{},
float
{},
Col
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
return
ck
::
profiler
::
profile_gemm_impl
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
return
profile
(
float
{},
float
{},
float
{},
Col
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_INT8
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
return
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
return
profile
(
int8_t
{},
int8_t
{},
int8_t
{},
Row
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_INT8
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
return
profile
(
int8_t
{},
int8_t
{},
int8_t
{},
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_INT8
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
return
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
return
profile
(
int8_t
{},
int8_t
{},
int8_t
{},
Col
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_INT8
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
return
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
return
profile
(
int8_t
{},
int8_t
{},
int8_t
{},
Col
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
return
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
return
profile
(
ck
::
bhalf_t
{},
ck
::
bhalf_t
{},
ck
::
bhalf_t
{},
Row
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
return
profile
(
ck
::
bhalf_t
{},
ck
::
bhalf_t
{},
ck
::
bhalf_t
{},
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
return
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
return
profile
(
ck
::
bhalf_t
{},
ck
::
bhalf_t
{},
ck
::
bhalf_t
{},
Col
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
return
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
return
profile
(
ck
::
bhalf_t
{},
ck
::
bhalf_t
{},
ck
::
bhalf_t
{},
Col
{},
Col
{},
Row
{});
}
else
{
...
...
This diff is collapsed.
Click to expand it.
profiler/src/profile_gemm_splitk.cpp
View file @
91b19b3c
...
...
@@ -25,7 +25,7 @@ bool profile_gemm_splitk(int argc, char* argv[])
if
(
argc
!=
15
)
{
printf
(
"arg1: tensor operation (gemm: GEMM)
\n
"
);
printf
(
"arg1: tensor operation (gemm: GEMM
SplitK
)
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)
\n
"
);
printf
(
"arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];
\n
"
);
printf
(
" 1: A[m, k] * B[n, k] = C[m, n];
\n
"
);
...
...
@@ -56,165 +56,68 @@ bool profile_gemm_splitk(int argc, char* argv[])
const
int
StrideC
=
std
::
stoi
(
argv
[
13
]);
const
int
KBatch
=
std
::
stoi
(
argv
[
14
]);
auto
profile
=
[
&
](
auto
a_type
,
auto
b_type
,
auto
c_type
,
auto
a_layout
,
auto
b_layout
,
auto
c_layout
)
{
using
ADataType
=
decltype
(
a_type
);
using
BDataType
=
decltype
(
b_type
);
using
CDataType
=
decltype
(
c_type
);
using
ALayout
=
decltype
(
a_layout
);
using
BLayout
=
decltype
(
b_layout
);
using
CLayout
=
decltype
(
c_layout
);
return
ck
::
profiler
::
profile_gemm_splitk_impl
<
ADataType
,
BDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
};
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
return
ck
::
profiler
::
profile_gemm_splitk_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
return
profile
(
ck
::
half_t
{},
ck
::
half_t
{},
ck
::
half_t
{},
Row
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
ck
::
profiler
::
profile_gemm_splitk_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
return
profile
(
ck
::
half_t
{},
ck
::
half_t
{},
ck
::
half_t
{},
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
return
ck
::
profiler
::
profile_gemm_splitk_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
return
profile
(
ck
::
half_t
{},
ck
::
half_t
{},
ck
::
half_t
{},
Col
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
return
ck
::
profiler
::
profile_gemm_splitk_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
return
profile
(
ck
::
half_t
{},
ck
::
half_t
{},
ck
::
half_t
{},
Col
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
return
ck
::
profiler
::
profile_gemm_splitk_impl
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
return
profile
(
float
{},
float
{},
float
{},
Row
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
ck
::
profiler
::
profile_gemm_splitk_impl
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
return
profile
(
float
{},
float
{},
float
{},
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
return
ck
::
profiler
::
profile_gemm_splitk_impl
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
return
profile
(
float
{},
float
{},
float
{},
Col
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
return
ck
::
profiler
::
profile_gemm_splitk_impl
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
M
:
StrideA
,
(
StrideB
<
0
)
?
K
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
,
KBatch
);
return
profile
(
float
{},
float
{},
float
{},
Col
{},
Col
{},
Row
{});
}
else
{
...
...
This diff is collapsed.
Click to expand it.
profiler/src/profiler.cpp
View file @
91b19b3c
...
...
@@ -22,13 +22,39 @@ bool profile_batched_gemm_reduce(int, char*[]);
int
main
(
int
argc
,
char
*
argv
[])
{
auto
print_help_message
=
[]()
{
// clang-format off
printf
(
"arg1: tensor operation, gemm: GEMM
\n
"
" gemm_splitk: GEMM Split-K
\n
"
" gemm_bias_2d: GEMM+Bias(2D)
\n
"
" gemm_bias_relu: GEMM+Bias+ReLU
\n
"
" gemm_bias_relu_add: GEMM+Bias+ReLU+Add
\n
"
" gemm_reduce: GEMM+Reduce
\n
"
" grouped_gemm: Grouped GEMM
\n
"
" conv_fwd: Convolution Forward
\n
"
" conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU
\n
"
" conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add
\n
"
" conv1d_bwd_data: Convolution Backward Data 1D
\n
"
" conv2d_bwd_data: Convolution Backward Data 2D
\n
"
" conv3d_bwd_data: Convolution Backward Data 3D
\n
"
" reduce: Reduce
\n
"
" conv2d_bwd_weight: Convolution Backward Weight 2D
\n
"
);
// clang-format on
};
if
(
argc
<
2
)
{
print_help_message
();
exit
(
1
);
}
bool
pass
=
true
;
if
(
strcmp
(
argv
[
1
],
"gemm"
)
==
0
)
{
pass
=
profile_gemm
(
argc
,
argv
);
}
if
(
strcmp
(
argv
[
1
],
"gemm_splitk"
)
==
0
)
else
if
(
strcmp
(
argv
[
1
],
"gemm_splitk"
)
==
0
)
{
pass
=
profile_gemm_splitk
(
argc
,
argv
);
}
...
...
@@ -94,23 +120,7 @@ int main(int argc, char* argv[])
}
else
{
// clang-format off
printf
(
"arg1: tensor operation, gemm: GEMM
\n
"
" gemm_splitk: GEMMSplitK)
\n
"
" gemm_bias_2d: GEMM+Bias(2D)
\n
"
" gemm_bias_relu: GEMM+Bias+ReLU
\n
"
" gemm_bias_relu_add: GEMM+Bias+ReLU+Add
\n
"
" gemm_reduce: GEMM+Reduce
\n
"
" grouped_gemm: Grouped GEMM
\n
"
" conv_fwd: ForwardConvolution
\n
"
" conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU
\n
"
" conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add
\n
"
" conv1d_bwd_data: BackwardConvolution data 1 dim
\n
"
" conv2d_bwd_data: BackwardConvolution data 2 dim
\n
"
" conv3d_bwd_data: BackwardConvolution data 3 dim
\n
"
" reduce: Reduce
\n
"
" conv2d_bwd_weight: Backward Weight Convolution 2d
\n
"
);
// clang-format on
print_help_message
();
}
return
pass
?
0
:
1
;
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment
Menu
Projects
Groups
Snippets
Help