Commit 8bbb4f85 authored by rocking's avatar rocking
Browse files

Use index_t instead of int in API

No related merge requests found
Showing with 37 additions and 19 deletions
+37 -19
......@@ -3,9 +3,9 @@
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_reduce_util.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_utility.hpp"
#include "device_tensor.hpp"
#include "binary_element_wise_operation.hpp"
......@@ -71,14 +71,15 @@ int main()
b_m_device_buf.ToDevice(b_m.mData.data());
auto broadcastAdd = DeviceElementwiseAddInstance{};
auto argument = broadcastAdd.MakeArgumentPointer(a_m_device_buf.GetDeviceBuffer(),
b_m_device_buf.GetDeviceBuffer(),
c_m_device_buf.GetDeviceBuffer(),
ck::to_int_vector(nchw),
ck::to_int_vector(a_m.mDesc.GetStrides()),
ck::to_int_vector(b_m.mDesc.GetStrides()),
ck::to_int_vector(c_m.mDesc.GetStrides()),
Add{});
auto argument = broadcastAdd.MakeArgumentPointer(
a_m_device_buf.GetDeviceBuffer(),
b_m_device_buf.GetDeviceBuffer(),
c_m_device_buf.GetDeviceBuffer(),
ck::convert_vector_element_type<std::size_t, ck::index_t>(nchw),
ck::convert_vector_element_type<std::size_t, ck::index_t>(a_m.mDesc.GetStrides()),
ck::convert_vector_element_type<std::size_t, ck::index_t>(b_m.mDesc.GetStrides()),
ck::convert_vector_element_type<std::size_t, ck::index_t>(c_m.mDesc.GetStrides()),
Add{});
if(!broadcastAdd.IsSupportedArgument(argument.get()))
{
......
......@@ -37,8 +37,8 @@ struct DeviceBinaryElementwise : public BaseOperator
return desc_m0_pad;
}
static auto MakeDescriptor_M0(const std::vector<int>& shape,
const std::vector<int>& stride,
static auto MakeDescriptor_M0(const std::vector<index_t>& shape,
const std::vector<index_t>& stride,
index_t gridSize,
index_t blockSize)
{
......@@ -77,10 +77,10 @@ struct DeviceBinaryElementwise : public BaseOperator
Argument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
const std::vector<int>& shape,
const std::vector<int>& stride_a,
const std::vector<int>& stride_b,
const std::vector<int>& stride_c,
const std::vector<index_t>& shape,
const std::vector<index_t>& stride_a,
const std::vector<index_t>& stride_b,
const std::vector<index_t>& stride_c,
ElementwiseFunctor functor,
index_t blockSize)
: p_a_(p_a),
......@@ -160,10 +160,10 @@ struct DeviceBinaryElementwise : public BaseOperator
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
std::vector<int> shape,
std::vector<int> stride_a,
std::vector<int> stride_b,
std::vector<int> stride_c,
std::vector<index_t> shape,
std::vector<index_t> stride_a,
std::vector<index_t> stride_b,
std::vector<index_t> stride_c,
ElementwiseFunctor functor)
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
......
#pragma once
#include <vector>
namespace ck {
template <typename Src, typename Dst>
inline std::vector<Dst> convert_vector_element_type(const std::vector<Src>& inData)
{
std::vector<Dst> outData;
for(auto elem : inData)
outData.push_back(static_cast<Dst>(elem));
return (outData);
};
}; // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment