Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
liyinrong
composable_kernel
Commits
8bbb4f85
Commit
8bbb4f85
authored
3 years ago
by
rocking
Browse files
Options
Download
Email Patches
Plain Diff
Use index_t instead of int in API
parent
389cab4b
eltwise_op
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
example/19_binary_elementwise/elementwise_add_4d.cpp
+10
-9
example/19_binary_elementwise/elementwise_add_4d.cpp
include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp
+10
-10
...tensor_operation/gpu/device/device_binary_elementwise.hpp
library/include/ck/library/host_tensor/host_utility.hpp
+17
-0
library/include/ck/library/host_tensor/host_utility.hpp
with
37 additions
and
19 deletions
+37
-19
example/19_binary_elementwise/elementwise_add_4d.cpp
View file @
8bbb4f85
...
...
@@ -3,9 +3,9 @@
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_reduce_util.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_utility.hpp"
#include "device_tensor.hpp"
#include "binary_element_wise_operation.hpp"
...
...
@@ -71,14 +71,15 @@ int main()
b_m_device_buf
.
ToDevice
(
b_m
.
mData
.
data
());
auto
broadcastAdd
=
DeviceElementwiseAddInstance
{};
auto
argument
=
broadcastAdd
.
MakeArgumentPointer
(
a_m_device_buf
.
GetDeviceBuffer
(),
b_m_device_buf
.
GetDeviceBuffer
(),
c_m_device_buf
.
GetDeviceBuffer
(),
ck
::
to_int_vector
(
nchw
),
ck
::
to_int_vector
(
a_m
.
mDesc
.
GetStrides
()),
ck
::
to_int_vector
(
b_m
.
mDesc
.
GetStrides
()),
ck
::
to_int_vector
(
c_m
.
mDesc
.
GetStrides
()),
Add
{});
auto
argument
=
broadcastAdd
.
MakeArgumentPointer
(
a_m_device_buf
.
GetDeviceBuffer
(),
b_m_device_buf
.
GetDeviceBuffer
(),
c_m_device_buf
.
GetDeviceBuffer
(),
ck
::
convert_vector_element_type
<
std
::
size_t
,
ck
::
index_t
>
(
nchw
),
ck
::
convert_vector_element_type
<
std
::
size_t
,
ck
::
index_t
>
(
a_m
.
mDesc
.
GetStrides
()),
ck
::
convert_vector_element_type
<
std
::
size_t
,
ck
::
index_t
>
(
b_m
.
mDesc
.
GetStrides
()),
ck
::
convert_vector_element_type
<
std
::
size_t
,
ck
::
index_t
>
(
c_m
.
mDesc
.
GetStrides
()),
Add
{});
if
(
!
broadcastAdd
.
IsSupportedArgument
(
argument
.
get
()))
{
...
...
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp
View file @
8bbb4f85
...
...
@@ -37,8 +37,8 @@ struct DeviceBinaryElementwise : public BaseOperator
return
desc_m0_pad
;
}
static
auto
MakeDescriptor_M0
(
const
std
::
vector
<
int
>&
shape
,
const
std
::
vector
<
int
>&
stride
,
static
auto
MakeDescriptor_M0
(
const
std
::
vector
<
in
dex_
t
>&
shape
,
const
std
::
vector
<
in
dex_
t
>&
stride
,
index_t
gridSize
,
index_t
blockSize
)
{
...
...
@@ -77,10 +77,10 @@ struct DeviceBinaryElementwise : public BaseOperator
Argument
(
const
ADataType
*
p_a
,
const
BDataType
*
p_b
,
CDataType
*
p_c
,
const
std
::
vector
<
int
>&
shape
,
const
std
::
vector
<
int
>&
stride_a
,
const
std
::
vector
<
int
>&
stride_b
,
const
std
::
vector
<
int
>&
stride_c
,
const
std
::
vector
<
in
dex_
t
>&
shape
,
const
std
::
vector
<
in
dex_
t
>&
stride_a
,
const
std
::
vector
<
in
dex_
t
>&
stride_b
,
const
std
::
vector
<
in
dex_
t
>&
stride_c
,
ElementwiseFunctor
functor
,
index_t
blockSize
)
:
p_a_
(
p_a
),
...
...
@@ -160,10 +160,10 @@ struct DeviceBinaryElementwise : public BaseOperator
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
std
::
vector
<
int
>
shape
,
std
::
vector
<
int
>
stride_a
,
std
::
vector
<
int
>
stride_b
,
std
::
vector
<
int
>
stride_c
,
std
::
vector
<
in
dex_
t
>
shape
,
std
::
vector
<
in
dex_
t
>
stride_a
,
std
::
vector
<
in
dex_
t
>
stride_b
,
std
::
vector
<
in
dex_
t
>
stride_c
,
ElementwiseFunctor
functor
)
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
...
...
This diff is collapsed.
Click to expand it.
library/include/ck/library/host_tensor/host_utility.hpp
0 → 100644
View file @
8bbb4f85
#pragma once
#include <vector>
namespace ck {

/// @brief Copy a vector while converting every element to another type.
///
/// Each element of @p inData is converted with `static_cast<Dst>`, so the usual
/// rules of that cast apply (e.g. truncation when narrowing); no range check is
/// performed here.
///
/// @tparam Src element type of the input vector
/// @tparam Dst element type of the returned vector
/// @param inData vector whose elements are converted
/// @return a new vector with the same length and element order as @p inData
template <typename Src, typename Dst>
inline std::vector<Dst> convert_vector_element_type(const std::vector<Src>& inData)
{
    std::vector<Dst> outData;

    // The final size is known up front: allocate once instead of letting
    // push_back grow the buffer repeatedly.
    outData.reserve(inData.size());

    for(const auto& elem : inData)
        outData.push_back(static_cast<Dst>(elem));

    return outData;
}

} // namespace ck
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment
Menu
Projects
Groups
Snippets
Help