Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
liyinrong
composable_kernel
Commits
9fdbcf90
Commit
9fdbcf90
authored
3 years ago
by
Chao Liu
Browse files
Options
Download
Email Patches
Plain Diff
change cgemm example to fp16
parent
52ccb9bb
myamlak/cgemm
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
example/22_cgemm/CMakeLists.txt
+1
-1
example/22_cgemm/CMakeLists.txt
example/22_cgemm/cgemm_xdl_fp16.cpp
+14
-27
example/22_cgemm/cgemm_xdl_fp16.cpp
library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp
+25
-36
...ibrary/reference_tensor_operation/cpu/reference_cgemm.hpp
with
40 additions
and
64 deletions
+40
-64
example/22_cgemm/CMakeLists.txt
View file @
9fdbcf90
add_example_executable
(
example_cgemm_xdl_
b
f16 cgemm_xdl_
b
f16.cpp
)
add_example_executable
(
example_cgemm_xdl_f
p
16 cgemm_xdl_f
p
16.cpp
)
This diff is collapsed.
Click to expand it.
example/22_cgemm/cgemm_xdl_
b
f16.cpp
→
example/22_cgemm/cgemm_xdl_f
p
16.cpp
View file @
9fdbcf90
...
...
@@ -44,17 +44,17 @@
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
B
F16
=
ck
::
b
half_t
;
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ADataType
=
B
F16
;
using
BDataType
=
B
F16
;
using
CDataType
=
B
F16
;
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
CDataType
=
F16
;
using
AccDataType
=
F32
;
using
ALayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
...
...
@@ -109,7 +109,7 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_
// clang-format on
using
ReferenceCGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceCGemm
<
float
,
float
,
float
,
PassThrough
,
PassThrough
,
PassThrough
>
;
ReferenceCGemm
<
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
int
main
(
int
argc
,
char
*
argv
[])
{
...
...
@@ -266,31 +266,18 @@ int main(int argc, char* argv[])
if
(
do_verification
)
{
Tensor
<
float
>
a_f32_m_k_real
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
float
>
a_f32_m_k_imag
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
float
>
b_f32_k_n_real
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
float
>
b_f32_k_n_imag
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
float
>
c_m_n_real_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
float
>
c_m_n_imag_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
float
>
c_m_n_real_device_f32_result
(
Tensor
<
CDataType
>
c_m_n_real_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
float
>
c_m_n_imag_
device_f32
_result
(
Tensor
<
CDataType
>
c_m_n_imag_
host
_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
bf16_to_f32_
(
a_m_k_real
,
a_f32_m_k_real
);
bf16_to_f32_
(
a_m_k_imag
,
a_f32_m_k_imag
);
bf16_to_f32_
(
b_k_n_real
,
b_f32_k_n_real
);
bf16_to_f32_
(
b_k_n_imag
,
b_f32_k_n_imag
);
bf16_to_f32_
(
c_m_n_real_device_result
,
c_m_n_real_device_f32_result
);
bf16_to_f32_
(
c_m_n_imag_device_result
,
c_m_n_imag_device_f32_result
);
auto
ref_cgemm
=
ReferenceCGemmInstance
{};
auto
ref_invoker
=
ref_cgemm
.
MakeInvoker
();
auto
ref_argument
=
ref_cgemm
.
MakeArgument
(
a_
f32_
m_k_real
,
a_
f32_
m_k_imag
,
b_
f32_
k_n_real
,
b_
f32_
k_n_imag
,
auto
ref_argument
=
ref_cgemm
.
MakeArgument
(
a_m_k_real
,
a_m_k_imag
,
b_k_n_real
,
b_k_n_imag
,
c_m_n_real_host_result
,
c_m_n_imag_host_result
,
a_element_op
,
...
...
@@ -299,12 +286,12 @@ int main(int argc, char* argv[])
ref_invoker
.
Run
(
ref_argument
);
ck
::
utils
::
check_err
(
c_m_n_real_device_
f32_
result
.
mData
,
ck
::
utils
::
check_err
(
c_m_n_real_device_result
.
mData
,
c_m_n_real_host_result
.
mData
,
"Verification error: incorrect results in real part!"
,
1e-2
f
,
1e-1
f
);
ck
::
utils
::
check_err
(
c_m_n_imag_device_
f32_
result
.
mData
,
ck
::
utils
::
check_err
(
c_m_n_imag_device_result
.
mData
,
c_m_n_imag_host_result
.
mData
,
"Verification error: incorrect results in imaginary part!"
,
1e-2
f
,
...
...
This diff is collapsed.
Click to expand it.
library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp
View file @
9fdbcf90
...
...
@@ -33,12 +33,19 @@ namespace ck {
namespace
tensor_operation
{
namespace
host
{
template
<
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
// FIXME: support arbitrary elementwise operation for A/B/C
template
<
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
enable_if_t
<
is_same_v
<
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
&&
is_same_v
<
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
&&
is_same_v
<
CElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
,
bool
>
=
false
>
struct
ReferenceCGemm
:
public
device
::
BaseOperator
{
// Argument
...
...
@@ -92,52 +99,34 @@ struct ReferenceCGemm : public device::BaseOperator
}
auto
f_mk_kn_mn_real
=
[
&
](
auto
m
,
auto
n
)
{
float
v_
acc
=
0
;
float
v_
c_real
=
0
;
for
(
std
::
size_t
k
=
0
;
k
<
K
;
++
k
)
{
float
v_a_real
;
float
v_
b_real
;
float
v_
a_imag
;
float
v_b_imag
;
float
v_a_real
=
ck
::
type_convert
<
float
>
(
arg
.
a_m_k_real_
(
m
,
k
))
;
float
v_
a_imag
=
ck
::
type_convert
<
float
>
(
arg
.
a_m_k_imag_
(
m
,
k
))
;
float
v_
b_real
=
ck
::
type_convert
<
float
>
(
arg
.
b_k_n_real_
(
k
,
n
))
;
float
v_b_imag
=
ck
::
type_convert
<
float
>
(
arg
.
b_k_n_imag_
(
k
,
n
))
;
arg
.
a_element_op_
(
v_a_real
,
ck
::
type_convert
<
float
>
(
arg
.
a_m_k_real_
(
m
,
k
)));
arg
.
a_element_op_
(
v_a_imag
,
ck
::
type_convert
<
float
>
(
arg
.
a_m_k_imag_
(
m
,
k
)));
arg
.
b_element_op_
(
v_b_real
,
ck
::
type_convert
<
float
>
(
arg
.
b_k_n_real_
(
k
,
n
)));
arg
.
b_element_op_
(
v_b_imag
,
ck
::
type_convert
<
float
>
(
arg
.
b_k_n_imag_
(
k
,
n
)));
v_acc
+=
v_a_real
*
v_b_real
-
v_a_imag
*
v_b_imag
;
v_c_real
+=
v_a_real
*
v_b_real
-
v_a_imag
*
v_b_imag
;
}
float
v_c_real
;
arg
.
c_element_op_
(
v_c_real
,
v_acc
);
arg
.
c_m_n_real_
(
m
,
n
)
=
v_c_real
;
};
auto
f_mk_kn_mn_imag
=
[
&
](
auto
m
,
auto
n
)
{
float
v_
acc
=
0
;
float
v_
c_imag
=
0
;
for
(
std
::
size_t
k
=
0
;
k
<
K
;
++
k
)
{
float
v_a_real
;
float
v_
b_real
;
float
v_
a_imag
;
float
v_b_imag
;
float
v_a_real
=
ck
::
type_convert
<
float
>
(
arg
.
a_m_k_real_
(
m
,
k
))
;
float
v_
a_imag
=
ck
::
type_convert
<
float
>
(
arg
.
a_m_k_imag_
(
m
,
k
))
;
float
v_
b_real
=
ck
::
type_convert
<
float
>
(
arg
.
b_k_n_real_
(
k
,
n
))
;
float
v_b_imag
=
ck
::
type_convert
<
float
>
(
arg
.
b_k_n_imag_
(
k
,
n
))
;
arg
.
a_element_op_
(
v_a_real
,
ck
::
type_convert
<
float
>
(
arg
.
a_m_k_real_
(
m
,
k
)));
arg
.
a_element_op_
(
v_a_imag
,
ck
::
type_convert
<
float
>
(
arg
.
a_m_k_imag_
(
m
,
k
)));
arg
.
b_element_op_
(
v_b_real
,
ck
::
type_convert
<
float
>
(
arg
.
b_k_n_real_
(
k
,
n
)));
arg
.
b_element_op_
(
v_b_imag
,
ck
::
type_convert
<
float
>
(
arg
.
b_k_n_imag_
(
k
,
n
)));
v_acc
+=
v_a_real
*
v_b_imag
+
v_a_imag
*
v_b_real
;
v_c_imag
+=
v_a_real
*
v_b_imag
+
v_a_imag
*
v_b_real
;
}
float
v_c_imag
;
arg
.
c_element_op_
(
v_c_imag
,
v_acc
);
arg
.
c_m_n_imag_
(
m
,
n
)
=
v_c_imag
;
};
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment
Menu
Projects
Groups
Snippets
Help