Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
liyinrong
composable_kernel
Commits
6f92737f
Commit
6f92737f
authored
3 years ago
by
Chao Liu
Browse files
Options
Download
Email Patches
Plain Diff
use remove_cvref_t
parent
21444cfc
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
composable_kernel/include/tensor_description/tensor_adaptor.hpp
+1
-2
...able_kernel/include/tensor_description/tensor_adaptor.hpp
composable_kernel/include/tensor_description/tensor_descriptor.hpp
+3
-4
...e_kernel/include/tensor_description/tensor_descriptor.hpp
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp
+5
-7
...rnel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp
composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp
+18
-24
...include/tensor_operation/threadwise_contraction_dlops.hpp
composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp
+9
-12
...nel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp
composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp
+2
-2
.../include/tensor_operation/threadwise_tensor_slice_set.hpp
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp
+25
-34
...ude/tensor_operation/threadwise_tensor_slice_transfer.hpp
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp
+16
-19
.../tensor_operation/threadwise_tensor_slice_transfer_v2.hpp
composable_kernel/include/utility/array.hpp
+1
-1
composable_kernel/include/utility/array.hpp
composable_kernel/include/utility/dynamic_buffer.hpp
+49
-64
composable_kernel/include/utility/dynamic_buffer.hpp
composable_kernel/include/utility/tuple.hpp
+1
-1
composable_kernel/include/utility/tuple.hpp
composable_kernel/include/utility/tuple_helper.hpp
+1
-3
composable_kernel/include/utility/tuple_helper.hpp
with
131 additions
and
173 deletions
+131
-173
composable_kernel/include/tensor_description/tensor_adaptor.hpp
View file @
6f92737f
...
...
@@ -189,8 +189,7 @@ struct TensorAdaptor
bool
is_known
=
true
;
static_for
<
0
,
Transforms
::
Size
(),
1
>
{}([
&
](
auto
i
)
{
is_known
&=
remove_cv_t
<
remove_reference_t
<
decltype
(
Transforms
{}[
i
])
>>::
IsKnownAtCompileTime
();
is_known
&=
remove_cvref_t
<
decltype
(
Transforms
{}[
i
])
>::
IsKnownAtCompileTime
();
});
return
is_known
&&
is_known_at_compile_time
<
ElementSize
>::
value
;
...
...
This diff is collapsed.
Click to expand it.
composable_kernel/include/tensor_description/tensor_descriptor.hpp
View file @
6f92737f
...
...
@@ -185,8 +185,7 @@ struct TensorDescriptor
bool
is_known
=
true
;
static_for
<
0
,
Transforms
::
Size
(),
1
>
{}([
&
](
auto
i
)
{
is_known
&=
remove_cv_t
<
remove_reference_t
<
decltype
(
Transforms
{}[
i
])
>>::
IsKnownAtCompileTime
();
is_known
&=
remove_cvref_t
<
decltype
(
Transforms
{}[
i
])
>::
IsKnownAtCompileTime
();
});
return
is_known
&&
is_known_at_compile_time
<
ElementSize
>::
value
&&
...
...
@@ -587,11 +586,11 @@ __host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc&
template
<
typename
TensorDesc
>
using
TensorCoordinate_t
=
decltype
(
make_tensor_coordinate
(
TensorDesc
{},
MultiIndex
<
remove_cv
_t
<
remove_reference
_t
<
TensorDesc
>
>
::
GetNumOfDimension
()
>
{}));
TensorDesc
{},
MultiIndex
<
remove_cv
ref
_t
<
TensorDesc
>::
GetNumOfDimension
()
>
{}));
template
<
typename
TensorDesc
>
using
TensorCoordinateStep_t
=
decltype
(
make_tensor_coordinate_step
(
TensorDesc
{},
MultiIndex
<
remove_cv
_t
<
remove_reference
_t
<
TensorDesc
>
>
::
GetNumOfDimension
()
>
{}));
TensorDesc
{},
MultiIndex
<
remove_cv
ref
_t
<
TensorDesc
>::
GetNumOfDimension
()
>
{}));
}
// namespace ck
#endif
This diff is collapsed.
Click to expand it.
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp
View file @
6f92737f
...
...
@@ -110,13 +110,11 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
const
BThreadBuffer
&
b_thread_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
ABlockBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatA
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
BThreadBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatB
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
CThreadBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatC
>>>::
value
&&
"wrong! inconsistent type"
);
static_assert
(
is_same
<
remove_cvref_t
<
typename
ABlockBuffer
::
type
>
,
remove_cvref_t
<
FloatA
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
BThreadBuffer
::
type
>
,
remove_cvref_t
<
FloatB
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
CThreadBuffer
::
type
>
,
remove_cvref_t
<
FloatC
>>::
value
&&
"wrong! inconsistent type"
);
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
This diff is collapsed.
Click to expand it.
composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp
View file @
6f92737f
...
...
@@ -55,19 +55,16 @@ struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
CBuffer
&
c_buf
,
COriginIdx
)
{
static_assert
(
is_known_at_compile_time
<
remove_cvref_t
<
AOriginIdx
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
BOriginIdx
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
COriginIdx
>>::
value
,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"
);
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
AOriginIdx
>>>::
value
&&
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
BOriginIdx
>>>::
value
&&
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
COriginIdx
>>>::
value
,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
ABuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatA
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
BBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatB
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
CBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatC
>>>::
value
&&
"wrong! inconsistent type"
);
is_same
<
remove_cvref_t
<
typename
ABuffer
::
type
>
,
remove_cvref_t
<
FloatA
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
BBuffer
::
type
>
,
remove_cvref_t
<
FloatB
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
CBuffer
::
type
>
,
remove_cvref_t
<
FloatC
>>::
value
&&
"wrong! inconsistent type"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -157,19 +154,16 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_
CBuffer
&
c_buf
,
COriginIdx
)
{
static_assert
(
is_known_at_compile_time
<
remove_cvref_t
<
AOriginIdx
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
BOriginIdx
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
COriginIdx
>>::
value
,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"
);
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
AOriginIdx
>>>::
value
&&
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
BOriginIdx
>>>::
value
&&
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
COriginIdx
>>>::
value
,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
ABuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatA
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
BBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatB
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
CBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatC
>>>::
value
&&
"wrong! inconsistent type"
);
is_same
<
remove_cvref_t
<
typename
ABuffer
::
type
>
,
remove_cvref_t
<
FloatA
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
BBuffer
::
type
>
,
remove_cvref_t
<
FloatB
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
CBuffer
::
type
>
,
remove_cvref_t
<
FloatC
>>::
value
&&
"wrong! inconsistent type"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
This diff is collapsed.
Click to expand it.
composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp
View file @
6f92737f
...
...
@@ -41,19 +41,16 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
CDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
is_known_at_compile_time
<
remove_cvref_t
<
AOriginIdx
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
BOriginIdx
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
COriginIdx
>>::
value
,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"
);
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
AOriginIdx
>>>::
value
&&
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
BOriginIdx
>>>::
value
&&
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
COriginIdx
>>>::
value
,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
ABuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatA
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
BBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatB
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
CBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatC
>>>::
value
&&
"wrong! inconsistent type"
);
is_same
<
remove_cvref_t
<
typename
ABuffer
::
type
>
,
remove_cvref_t
<
FloatA
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
BBuffer
::
type
>
,
remove_cvref_t
<
FloatB
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
CBuffer
::
type
>
,
remove_cvref_t
<
FloatC
>>::
value
&&
"wrong! inconsistent type"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
This diff is collapsed.
Click to expand it.
composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp
View file @
6f92737f
...
...
@@ -30,11 +30,11 @@ struct ThreadwiseTensorSliceSet_v1
static_assert
(
Buffer
::
IsStaticBuffer
(),
"wrong! DstBuffer need to be StaticBuffer"
);
static_assert
(
is_known_at_compile_time
<
remove_cv
_t
<
remove_reference
_t
<
OriginIdx
>>
>
::
value
,
static_assert
(
is_known_at_compile_time
<
remove_cv
ref
_t
<
OriginIdx
>>::
value
,
"wrong! OriginIdx need to be known at compile-time"
);
// Desc is known at compile-time
constexpr
auto
desc
=
remove_cv
_t
<
remove_reference
_t
<
Desc
>
>
{};
constexpr
auto
desc
=
remove_cv
ref
_t
<
Desc
>
{};
// OriginIdx is known at compile-time
constexpr
auto
origin_idx
=
to_multi_index
(
OriginIdx
{});
...
...
This diff is collapsed.
Click to expand it.
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp
View file @
6f92737f
...
...
@@ -95,18 +95,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3
static_assert
(
SrcDesc
::
IsKnownAtCompileTime
(),
"wrong! SrcDesc need to known at compile-time"
);
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
SrcSliceOriginIdx
>>>::
value
,
"wrong! SrcSliceOrigin need to known at compile-time"
);
static_assert
(
is_known_at_compile_time
<
remove_cvref_t
<
SrcSliceOriginIdx
>>::
value
,
"wrong! SrcSliceOrigin need to known at compile-time"
);
static_assert
(
SrcBuffer
::
IsStaticBuffer
(),
"wrong! SrcBuffer need to be StaticBuffer"
);
// static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
// remove_cv_t<remove_reference_t<SrcData>>>::value,
//"wrong! SrcBuffer data type is wrong");
// SrcDesc and src_slice_origin_idx are known at compile-time
constexpr
auto
src_desc
=
remove_cv
_t
<
remove_reference
_t
<
SrcDesc
>
>
{};
constexpr
auto
src_desc
=
remove_cv
ref
_t
<
SrcDesc
>
{};
constexpr
auto
src_slice_origin_idx
=
to_multi_index
(
SrcSliceOriginIdx
{});
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -421,16 +416,15 @@ struct ThreadwiseTensorSliceTransfer_v2
static_assert
(
DstDesc
::
IsKnownAtCompileTime
(),
"wrong! DstDesc need to known at compile-time"
);
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
DstSliceOriginIdx
>>>::
value
,
"wrong! DstSliceOrigin need to known at compile-time"
);
static_assert
(
is_known_at_compile_time
<
remove_cvref_t
<
DstSliceOriginIdx
>>::
value
,
"wrong! DstSliceOrigin need to known at compile-time"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
DstBuffer
::
type
>>
,
remove_cv_t
<
remove_ref
erence
_t
<
DstData
>>
>
::
value
&&
"wrong! inconsistent type"
);
static_assert
(
is_same
<
remove_cvref_t
<
typename
DstBuffer
::
type
>
,
remove_
cv
ref_t
<
DstData
>>::
value
&&
"wrong! inconsistent type"
);
// DstDesc and dst_slice_origin_idx are known at compile-time
constexpr
auto
dst_desc
=
remove_cv
_t
<
remove_reference
_t
<
DstDesc
>
>
{};
constexpr
auto
dst_desc
=
remove_cv
ref
_t
<
DstDesc
>
{};
constexpr
auto
dst_slice_origin_idx
=
DstSliceOriginIdx
{};
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -742,9 +736,9 @@ struct ThreadwiseTensorSliceTransfer_v3
SrcBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum_t
::
Lds
,
"wrong!"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
SrcBuffer
::
type
>>
,
remove_cv_t
<
remove_ref
erence
_t
<
SrcData
>>
>
::
value
,
"wrong! SrcBuffer and SrcData data type are inconsistent"
);
static_assert
(
is_same
<
remove_cvref_t
<
typename
SrcBuffer
::
type
>
,
remove_
cv
ref_t
<
SrcData
>>::
value
,
"wrong! SrcBuffer and SrcData data type are inconsistent"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -899,9 +893,9 @@ struct ThreadwiseTensorSliceTransfer_v3
DstBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum_t
::
Lds
,
"wrong!"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
DstBuffer
::
type
>>
,
remove_cv_t
<
remove_ref
erence
_t
<
DstData
>>
>
::
value
,
"wrong! SrcBuffer or DstBuffer data type is wrong"
);
static_assert
(
is_same
<
remove_cvref_t
<
typename
DstBuffer
::
type
>
,
remove_
cv
ref_t
<
DstData
>>::
value
,
"wrong! SrcBuffer or DstBuffer data type is wrong"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -1315,24 +1309,21 @@ struct ThreadwiseTensorSliceTransfer_v4
static_assert
(
SrcDesc
::
IsKnownAtCompileTime
()
&&
DstDesc
::
IsKnownAtCompileTime
(),
"wrong! SrcDesc and DstDesc need to known at compile-time"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
SrcBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
SrcData
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
DstBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
DstData
>>>::
value
,
"wrong! SrcBuffer or DstBuffer data type is wrong"
);
static_assert
(
is_same
<
remove_cvref_t
<
typename
SrcBuffer
::
type
>
,
remove_cvref_t
<
SrcData
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
DstBuffer
::
type
>
,
remove_cvref_t
<
DstData
>>::
value
,
"wrong! SrcBuffer or DstBuffer data type is wrong"
);
static_assert
(
DstBuffer
::
IsStaticBuffer
(),
"wrong! DstBuffer need to be StaticBuffer"
);
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
SrcRefToOriginDisplacement
>>>::
value
&&
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
DstOriginIdx
>>>::
value
,
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
"at compile-time"
);
static_assert
(
is_known_at_compile_time
<
remove_cvref_t
<
SrcRefToOriginDisplacement
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
DstOriginIdx
>>::
value
,
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
"at compile-time"
);
// SrcDesc and DstDesc are known at compile-time
constexpr
auto
src_desc
=
remove_cv
_t
<
remove_reference
_t
<
SrcDesc
>
>
{};
constexpr
auto
dst_desc
=
remove_cv
_t
<
remove_reference
_t
<
DstDesc
>
>
{};
constexpr
auto
src_desc
=
remove_cv
ref
_t
<
SrcDesc
>
{};
constexpr
auto
dst_desc
=
remove_cv
ref
_t
<
DstDesc
>
{};
// SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time
constexpr
auto
src_ref_to_origin_disp_idx
=
to_multi_index
(
SrcRefToOriginDisplacement
{});
...
...
This diff is collapsed.
Click to expand it.
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp
View file @
6f92737f
...
...
@@ -80,9 +80,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
SrcBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum_t
::
Lds
,
"wrong!"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
SrcBuffer
::
type
>>
,
remove_cv_t
<
remove_ref
erence
_t
<
SrcData
>>
>
::
value
,
"wrong! SrcBuffer and SrcData data type are inconsistent"
);
static_assert
(
is_same
<
remove_cvref_t
<
typename
SrcBuffer
::
type
>
,
remove_
cv
ref_t
<
SrcData
>>::
value
,
"wrong! SrcBuffer and SrcData data type are inconsistent"
);
// tensor descriptor for src_vector
constexpr
auto
src_vector_tensor_lengths
=
SrcVectorTensorLengths
{};
...
...
@@ -248,9 +248,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
DstBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum_t
::
Lds
,
"wrong!"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
DstBuffer
::
type
>>
,
remove_cv_t
<
remove_ref
erence
_t
<
DstData
>>
>
::
value
,
"wrong! SrcBuffer or DstBuffer data type is wrong"
);
static_assert
(
is_same
<
remove_cvref_t
<
typename
DstBuffer
::
type
>
,
remove_
cv
ref_t
<
DstData
>>::
value
,
"wrong! SrcBuffer or DstBuffer data type is wrong"
);
// tensor descriptor for dst_vector
constexpr
auto
dst_vector_tensor_lengths
=
DstVectorTensorLengths
{};
...
...
@@ -669,24 +669,21 @@ struct ThreadwiseTensorSliceTransfer_v4r1
static_assert
(
SrcDesc
::
IsKnownAtCompileTime
()
&&
DstDesc
::
IsKnownAtCompileTime
(),
"wrong! SrcDesc and DstDesc need to known at compile-time"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
SrcBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
SrcData
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
DstBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
DstData
>>>::
value
,
"wrong! SrcBuffer or DstBuffer data type is wrong"
);
static_assert
(
is_same
<
remove_cvref_t
<
typename
SrcBuffer
::
type
>
,
remove_cvref_t
<
SrcData
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
DstBuffer
::
type
>
,
remove_cvref_t
<
DstData
>>::
value
,
"wrong! SrcBuffer or DstBuffer data type is wrong"
);
static_assert
(
DstBuffer
::
IsStaticBuffer
(),
"wrong! DstBuffer need to be StaticBuffer"
);
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
SrcRefToOriginDisplacement
>>>::
value
&&
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
DstOriginIdx
>>>::
value
,
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
"at compile-time"
);
static_assert
(
is_known_at_compile_time
<
remove_cvref_t
<
SrcRefToOriginDisplacement
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
DstOriginIdx
>>::
value
,
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
"at compile-time"
);
// SrcDesc and DstDesc are known at compile-time
constexpr
auto
src_desc
=
remove_cv
_t
<
remove_reference
_t
<
SrcDesc
>
>
{};
constexpr
auto
dst_desc
=
remove_cv
_t
<
remove_reference
_t
<
DstDesc
>
>
{};
constexpr
auto
src_desc
=
remove_cv
ref
_t
<
SrcDesc
>
{};
constexpr
auto
dst_desc
=
remove_cv
ref
_t
<
DstDesc
>
{};
// SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time
constexpr
auto
src_ref_to_origin_disp_idx
=
to_multi_index
(
SrcRefToOriginDisplacement
{});
...
...
This diff is collapsed.
Click to expand it.
composable_kernel/include/utility/array.hpp
View file @
6f92737f
...
...
@@ -48,7 +48,7 @@ struct Array<TData, 0>
template
<
typename
X
,
typename
...
Xs
>
__host__
__device__
constexpr
auto
make_array
(
X
&&
x
,
Xs
&&
...
xs
)
{
using
data_type
=
remove_cv
_t
<
remove_reference
_t
<
X
>
>
;
using
data_type
=
remove_cv
ref
_t
<
X
>
;
return
Array
<
data_type
,
sizeof
...(
Xs
)
+
1
>
{{
std
::
forward
<
X
>
(
x
),
std
::
forward
<
Xs
>
(
xs
)...}};
}
...
...
This diff is collapsed.
Click to expand it.
composable_kernel/include/utility/dynamic_buffer.hpp
View file @
6f92737f
...
...
@@ -39,18 +39,15 @@ struct DynamicBuffer
}
template
<
typename
X
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cv_t
<
remove_reference_t
<
X
>
>>::
type
,
typename
scalar_type
<
remove_cv_t
<
remove_reference_t
<
T
>>>::
type
>::
value
,
bool
>::
type
=
false
>
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
auto
Get
(
index_t
i
,
bool
is_valid_element
)
const
{
// X contains multiple T
constexpr
index_t
scalar_per_t_vector
=
scalar_type
<
remove_cv_t
<
remove_reference_t
<
T
>>>::
vector_size
;
constexpr
index_t
scalar_per_t_vector
=
scalar_type
<
remove_cvref_t
<
T
>>::
vector_size
;
constexpr
index_t
scalar_per_x_vector
=
scalar_type
<
remove_cv_t
<
remove_reference_t
<
X
>>>::
vector_size
;
constexpr
index_t
scalar_per_x_vector
=
scalar_type
<
remove_cvref_t
<
X
>>::
vector_size
;
static_assert
(
scalar_per_x_vector
%
scalar_per_t_vector
==
0
,
"wrong! X need to be multiple T"
);
...
...
@@ -67,15 +64,14 @@ struct DynamicBuffer
if
constexpr
(
InvalidElementUseNumericalZeroValue
)
{
return
amd_buffer_load_invalid_element_return_return_zero
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
t_per_x
>
(
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
return
amd_buffer_load_invalid_element_return_return_zero
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
}
else
{
return
amd_buffer_load_invalid_element_return_customized_value
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
t_per_x
>
(
return
amd_buffer_load_invalid_element_return_customized_value
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
p_data_
,
i
,
is_valid_element
,
element_space_size_
,
invalid_element_value_
);
}
}
...
...
@@ -94,18 +90,15 @@ struct DynamicBuffer
}
template
<
typename
X
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cv_t
<
remove_reference_t
<
X
>
>>::
type
,
typename
scalar_type
<
remove_cv_t
<
remove_reference_t
<
T
>>>::
type
>::
value
,
bool
>::
type
=
false
>
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
,
bool
>::
type
=
false
>
__host__
__device__
void
Set
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
{
// X contains multiple T
constexpr
index_t
scalar_per_t_vector
=
scalar_type
<
remove_cv_t
<
remove_reference_t
<
T
>>>::
vector_size
;
constexpr
index_t
scalar_per_t_vector
=
scalar_type
<
remove_cvref_t
<
T
>>::
vector_size
;
constexpr
index_t
scalar_per_x_vector
=
scalar_type
<
remove_cv_t
<
remove_reference_t
<
X
>>>::
vector_size
;
constexpr
index_t
scalar_per_x_vector
=
scalar_type
<
remove_cvref_t
<
X
>>::
vector_size
;
static_assert
(
scalar_per_x_vector
%
scalar_per_t_vector
==
0
,
"wrong! X need to be multiple T"
);
...
...
@@ -115,7 +108,7 @@ struct DynamicBuffer
#if CK_USE_AMD_BUFFER_ADDRESSING
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_store
<
remove_cv
_t
<
remove_reference
_t
<
T
>
>
,
t_per_x
>
(
amd_buffer_store
<
remove_cv
ref
_t
<
T
>
,
t_per_x
>
(
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
#else
if
(
is_valid_element
)
...
...
@@ -136,70 +129,65 @@ struct DynamicBuffer
// ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
// ds_write_b128
// TODO: remove this after compiler fix
if
constexpr
(
is_same
<
typename
scalar_type
<
remove_cv_t
<
remove_reference_t
<
T
>>>::
type
,
int8_t
>::
value
)
if
constexpr
(
is_same
<
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
,
int8_t
>::
value
)
{
static_assert
(
(
is_same
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
int8_t
>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
X
>>
,
int8_t
>::
value
)
||
(
is_same
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
int8_t
>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
X
>>
,
int8x2_t
>::
value
)
||
(
is_same
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
int8_t
>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
X
>>
,
int8x4_t
>::
value
)
||
(
is_same
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
int8x4_t
>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
X
>>
,
int8x4_t
>::
value
)
||
(
is_same
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
int8x8_t
>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
X
>>
,
int8x8_t
>::
value
)
||
(
is_same
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
int8x16_t
>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
X
>>
,
int8x16_t
>::
value
),
"wrong! not implemented for this combination, please add "
"implementation"
);
static_assert
((
is_same
<
remove_cvref_t
<
T
>
,
int8_t
>::
value
&&
is_same
<
remove_cvref_t
<
X
>
,
int8_t
>::
value
)
||
(
is_same
<
remove_cvref_t
<
T
>
,
int8_t
>::
value
&&
is_same
<
remove_cvref_t
<
X
>
,
int8x2_t
>::
value
)
||
(
is_same
<
remove_cvref_t
<
T
>
,
int8_t
>::
value
&&
is_same
<
remove_cvref_t
<
X
>
,
int8x4_t
>::
value
)
||
(
is_same
<
remove_cvref_t
<
T
>
,
int8x4_t
>::
value
&&
is_same
<
remove_cvref_t
<
X
>
,
int8x4_t
>::
value
)
||
(
is_same
<
remove_cvref_t
<
T
>
,
int8x8_t
>::
value
&&
is_same
<
remove_cvref_t
<
X
>
,
int8x8_t
>::
value
)
||
(
is_same
<
remove_cvref_t
<
T
>
,
int8x16_t
>::
value
&&
is_same
<
remove_cvref_t
<
X
>
,
int8x16_t
>::
value
),
"wrong! not implemented for this combination, please add "
"implementation"
);
if
constexpr
(
is_same
<
remove_cv
_t
<
remove_reference
_t
<
T
>
>
,
int8_t
>::
value
&&
is_same
<
remove_cv
_t
<
remove_reference
_t
<
X
>
>
,
int8_t
>::
value
)
if
constexpr
(
is_same
<
remove_cv
ref
_t
<
T
>
,
int8_t
>::
value
&&
is_same
<
remove_cv
ref
_t
<
X
>
,
int8_t
>::
value
)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*
c_style_pointer_cast
<
int8_t
*>
(
&
p_data_
[
i
])
=
*
c_style_pointer_cast
<
const
int8_t
*>
(
&
x
);
}
else
if
constexpr
(
is_same
<
remove_cv
_t
<
remove_reference
_t
<
T
>
>
,
int8_t
>::
value
&&
is_same
<
remove_cv
_t
<
remove_reference
_t
<
X
>
>
,
int8x2_t
>::
value
)
else
if
constexpr
(
is_same
<
remove_cv
ref
_t
<
T
>
,
int8_t
>::
value
&&
is_same
<
remove_cv
ref
_t
<
X
>
,
int8x2_t
>::
value
)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*
c_style_pointer_cast
<
int16_t
*>
(
&
p_data_
[
i
])
=
*
c_style_pointer_cast
<
const
int16_t
*>
(
&
x
);
}
else
if
constexpr
(
is_same
<
remove_cv
_t
<
remove_reference
_t
<
T
>
>
,
int8_t
>::
value
&&
is_same
<
remove_cv
_t
<
remove_reference
_t
<
X
>
>
,
int8x4_t
>::
value
)
else
if
constexpr
(
is_same
<
remove_cv
ref
_t
<
T
>
,
int8_t
>::
value
&&
is_same
<
remove_cv
ref
_t
<
X
>
,
int8x4_t
>::
value
)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*
c_style_pointer_cast
<
int32_t
*>
(
&
p_data_
[
i
])
=
*
c_style_pointer_cast
<
const
int32_t
*>
(
&
x
);
}
else
if
constexpr
(
is_same
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
int8x4_t
>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
X
>>
,
int8x4_t
>::
value
)
else
if
constexpr
(
is_same
<
remove_cvref_t
<
T
>
,
int8x4_t
>::
value
&&
is_same
<
remove_cvref_t
<
X
>
,
int8x4_t
>::
value
)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*
c_style_pointer_cast
<
int32_t
*>
(
&
p_data_
[
i
])
=
*
c_style_pointer_cast
<
const
int32_t
*>
(
&
x
);
}
else
if
constexpr
(
is_same
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
int8x8_t
>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
X
>>
,
int8x8_t
>::
value
)
else
if
constexpr
(
is_same
<
remove_cvref_t
<
T
>
,
int8x8_t
>::
value
&&
is_same
<
remove_cvref_t
<
X
>
,
int8x8_t
>::
value
)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*
c_style_pointer_cast
<
int32x2_t
*>
(
&
p_data_
[
i
])
=
*
c_style_pointer_cast
<
const
int32x2_t
*>
(
&
x
);
}
else
if
constexpr
(
is_same
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
int8x16_t
>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
X
>>
,
int8x16_t
>::
value
)
else
if
constexpr
(
is_same
<
remove_cvref_t
<
T
>
,
int8x16_t
>::
value
&&
is_same
<
remove_cvref_t
<
X
>
,
int8x16_t
>::
value
)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
...
...
@@ -224,18 +212,15 @@ struct DynamicBuffer
}
template
<
typename
X
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cv_t
<
remove_reference_t
<
X
>
>>::
type
,
typename
scalar_type
<
remove_cv_t
<
remove_reference_t
<
T
>>>::
type
>::
value
,
bool
>::
type
=
false
>
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
,
bool
>::
type
=
false
>
__host__
__device__
void
AtomicAdd
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
{
// X contains multiple T
constexpr
index_t
scalar_per_t_vector
=
scalar_type
<
remove_cv_t
<
remove_reference_t
<
T
>>>::
vector_size
;
constexpr
index_t
scalar_per_t_vector
=
scalar_type
<
remove_cvref_t
<
T
>>::
vector_size
;
constexpr
index_t
scalar_per_x_vector
=
scalar_type
<
remove_cv_t
<
remove_reference_t
<
X
>>>::
vector_size
;
constexpr
index_t
scalar_per_x_vector
=
scalar_type
<
remove_cvref_t
<
X
>>::
vector_size
;
static_assert
(
scalar_per_x_vector
%
scalar_per_t_vector
==
0
,
"wrong! X need to be multiple T"
);
...
...
@@ -245,7 +230,7 @@ struct DynamicBuffer
#if CK_USE_AMD_BUFFER_ADDRESSING
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_atomic_add
<
remove_cv
_t
<
remove_reference
_t
<
T
>
>
,
t_per_x
>
(
amd_buffer_atomic_add
<
remove_cv
ref
_t
<
T
>
,
t_per_x
>
(
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
#else
if
(
is_valid_element
)
...
...
This diff is collapsed.
Click to expand it.
composable_kernel/include/utility/tuple.hpp
View file @
6f92737f
...
...
@@ -159,7 +159,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
template
<
typename
...
Xs
>
__host__
__device__
constexpr
auto
make_tuple
(
Xs
&&
...
xs
)
{
return
Tuple
<
remove_cv
_t
<
remove_reference
_t
<
Xs
>
>
...
>
(
std
::
forward
<
Xs
>
(
xs
)...);
return
Tuple
<
remove_cv
ref
_t
<
Xs
>
...
>
(
std
::
forward
<
Xs
>
(
xs
)...);
}
}
// namespace ck
...
...
This diff is collapsed.
Click to expand it.
composable_kernel/include/utility/tuple_helper.hpp
View file @
6f92737f
...
...
@@ -14,9 +14,7 @@ struct is_known_at_compile_time<Tuple<Ts...>>
return
container_reduce
(
Tuple
<
Ts
...
>
{},
[](
auto
x
,
bool
r
)
{
return
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
decltype
(
x
)
>>>::
value
&
r
;
return
is_known_at_compile_time
<
remove_cvref_t
<
decltype
(
x
)
>>::
value
&
r
;
},
true
);
}
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment