Unverified Commit 31d1913f authored by Shaojie WANG's avatar Shaojie WANG Committed by GitHub
Browse files

Merge branch 'develop' into fix_ctile_err_for_conv2d_fwd_bias_relu_add

parents b5f1f3eb 86185bd7
Showing with 184 additions and 405 deletions
+184 -405
......@@ -147,8 +147,6 @@ class SimpleAppArgs
int main(int argc, char* argv[])
{
using namespace ck::host_reduce;
const std::vector<int> reduceDims{0, 1, 2};
const std::vector<int> invariantDims{3};
......@@ -254,7 +252,9 @@ int main(int argc, char* argv[])
ReductionHost<InDataType,
AccDataType,
OutDataType,
ReduceOpId,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
Rank,
NumReduceDim,
PropagateNan,
......
......@@ -108,8 +108,6 @@ int main(int argc, char* argv[])
const std::vector<size_t> outLengths = {64, 320, 80};
using namespace ck::host_reduce;
if(argc == 1)
{
do_verify = true;
......@@ -191,7 +189,9 @@ int main(int argc, char* argv[])
ReductionHost<InOutDataType,
AccDataType,
InOutDataType,
ReduceOpId,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
5, // Rank
2, // NumReduceDim
PropagateNan,
......
......@@ -8,10 +8,12 @@
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_reduce_util.hpp"
#include "device_tensor.hpp"
#include "tensor_layout.hpp"
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "reduction_functions_accumulate.hpp"
#include "device_pool2d_fwd_nhwc_nhwc.hpp"
template <typename InDataType,
......@@ -29,19 +31,24 @@ static void pool_host_verify(const Tensor<InDataType>& in,
const std::array<ck::index_t, 2>& in_left_pads,
const std::array<ck::index_t, 2>& /*in_right_pads*/)
{
using namespace ck::host_reduce;
const int32_t divider = window_spatial_lengths[0] * window_spatial_lengths[1];
const auto PreUnaryOp = PreUnaryOpFn<AccDataType, ReduceOpId>(divider);
const auto PosUnaryOp = PosUnaryOpFn<AccDataType, ReduceOpId>(divider);
using ReduceOperation = typename ck::reduce_binary_operator<AccDataType, ReduceOpId>::opType;
using InElementwiseOperation = typename ck::
reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation = typename ck::
reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation;
const InElementwiseOperation in_elementwise_op(divider);
const AccElementwiseOperation acc_elementwise_op(divider);
if constexpr(!OutputIndex)
{
auto opReduce = ReduceOpFn<AccDataType, ReduceOpId>();
using Accumulation =
ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
auto accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
auto accuVal = ReduceOperation::GetIdentityValue();
for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
{
......@@ -54,14 +61,14 @@ static void pool_host_verify(const Tensor<InDataType>& in,
{
AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));
PreUnaryOp(currVal);
in_elementwise_op(currVal, currVal);
binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal);
Accumulation::Calculate(accuVal, currVal);
}
}
}
PosUnaryOp(accuVal);
acc_elementwise_op(accuVal, accuVal);
out(n, c, ho, wo) = accuVal;
};
......@@ -74,10 +81,12 @@ static void pool_host_verify(const Tensor<InDataType>& in,
}
else
{
auto opReduce = ReduceOpFn2<AccDataType, ReduceOpId>();
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
auto accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
AccDataType,
IndexDataType>;
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
auto accuVal = ReduceOperation::GetIdentityValue();
IndexDataType accuIndex = 0;
for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
......@@ -92,15 +101,14 @@ static void pool_host_verify(const Tensor<InDataType>& in,
AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));
IndexDataType currIndex = y * window_spatial_lengths[1] + x;
PreUnaryOp(currVal);
in_elementwise_op(currVal, currVal);
binop_with_index_and_nan_check<AccDataType, IndexDataType, PropagateNan>(
opReduce, accuVal, currVal, accuIndex, currIndex);
Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
}
}
}
PosUnaryOp(accuVal);
acc_elementwise_op(accuVal, accuVal);
out(n, c, ho, wo) = accuVal;
out_indices(n, c, ho, wo) = accuIndex;
......@@ -139,8 +147,6 @@ bool pool_test(bool do_verification,
ck::index_t in_right_pad_h,
ck::index_t in_right_pad_w)
{
using namespace ck::host_reduce;
using DevicePoolFwdInstance =
ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<
InDataType, // InDataType
......
......@@ -27,8 +27,6 @@ static constexpr bool PropagateNan = false;
int main(int argc, char* argv[])
{
using namespace ck::host_reduce;
bool do_verification;
int init_method;
bool time_kernel;
......
......@@ -27,8 +27,6 @@ static constexpr bool PropagateNan = false;
int main(int argc, char* argv[])
{
using namespace ck::host_reduce;
bool do_verification;
int init_method;
bool time_kernel;
......
......@@ -236,7 +236,7 @@ int main(int argc, char* argv[])
for(int m = 0; m < M; ++m)
{
ReduceAccDataType d_acc = d_reduce_op.GetReductionZeroVal();
ReduceAccDataType d_acc = d_reduce_op.GetIdentityValue();
for(int n = 0; n < N; ++n)
d_reduce_op(d_acc, c_m_n_host_result(m, n));
......
......@@ -261,8 +261,8 @@ int main(int argc, char* argv[])
for(int m = 0; m < M; ++m)
{
float d0_acc = d0_reduce_op.GetReductionZeroVal();
float d1_acc = d1_reduce_op.GetReductionZeroVal();
float d0_acc = d0_reduce_op.GetIdentityValue();
float d1_acc = d1_reduce_op.GetIdentityValue();
for(int n = 0; n < N; ++n)
{
......
......@@ -259,8 +259,8 @@ int main(int argc, char* argv[])
{
for(int m = 0; m < M; ++m)
{
float d0_acc = d0_reduce_op.GetReductionZeroVal();
float d1_acc = d1_reduce_op.GetReductionZeroVal();
float d0_acc = d0_reduce_op.GetIdentityValue();
float d1_acc = d1_reduce_op.GetIdentityValue();
for(int n = 0; n < N; ++n)
{
......
......@@ -157,8 +157,8 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
auto reduceSumOpInst = ReduceSumOp{};
for(int m = 0; m < M; ++m)
{
float mean_acc = reduceSumOpInst.GetReductionZeroVal();
float square_mean_acc = reduceSumOpInst.GetReductionZeroVal();
float mean_acc = reduceSumOpInst.GetIdentityValue();
float square_mean_acc = reduceSumOpInst.GetIdentityValue();
for(int n = 0; n < N; ++n)
{
......
......@@ -348,8 +348,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
if constexpr(use_multiblock)
{
const auto zeroVal =
ck::reduce::GetReductionZeroValueForInMemoryDataOperation<OutDataType>(
const auto identityVal =
ck::reduce::GetIdentityValueueForInMemoryDataOperation<OutDataType>(
OutMemoryDataOperation);
const auto kernel_pre =
......@@ -362,7 +362,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
0,
out_grid_desc_m_2,
arg.out_dev_,
zeroVal);
identityVal);
};
avg_time += launch_and_time_kernel(stream_config,
......
#pragma once
#include "data_type.hpp"
#include "math_v2.hpp"
namespace ck {
namespace tensor_operation {
......@@ -296,7 +297,7 @@ struct UnaryAbs<float, float>
{
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(float& y, const float& x) const { y = abs(x); };
__host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::abs(x); };
};
template <>
......@@ -304,7 +305,7 @@ struct UnaryAbs<half_t, half_t>
{
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = __habs(x); };
__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = ck::math::abs(x); };
};
template <>
......@@ -312,7 +313,7 @@ struct UnaryAbs<double, double>
{
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(double& y, const double& x) const { y = abs(x); };
__host__ __device__ void operator()(double& y, const double& x) const { y = ck::math::abs(x); };
};
template <>
......@@ -320,12 +321,7 @@ struct UnaryAbs<int8_t, int8_t>
{
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const
{
int8_t sgn = x >> (8 - 1);
y = (x ^ sgn) - sgn;
};
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = ck::math::abs(x); };
};
template <typename Y, typename X>
......@@ -336,7 +332,7 @@ struct UnarySqrt<float, float>
{
__host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(float& y, const float& x) const { y = sqrtf(x); };
__host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::sqrt(x); };
};
template <>
......@@ -344,7 +340,10 @@ struct UnarySqrt<double, double>
{
__host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(double& y, const double& x) const { y = sqrt(x); };
__host__ __device__ void operator()(double& y, const double& x) const
{
y = ck::math::sqrt(x);
};
};
} // namespace element_wise
......
......@@ -171,7 +171,7 @@ struct GridwiseReduction_mk_to_m_multiblock
AccDataType beta,
OutDataType* const __restrict__ p_out_value_global)
{
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto identityVal = ReduceOperation::GetIdentityValue();
// LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
......@@ -179,7 +179,7 @@ struct GridwiseReduction_mk_to_m_multiblock
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
type_convert<InDataType>(identityVal));
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
......@@ -191,7 +191,7 @@ struct GridwiseReduction_mk_to_m_multiblock
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
......@@ -358,12 +358,12 @@ struct GridwiseReduction_mk_to_m_multiblock
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto identityVal = ReduceOperation::GetIdentityValue();
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
type_convert<InDataType>(identityVal));
const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
......@@ -418,7 +418,7 @@ struct GridwiseReduction_mk_to_m_multiblock
thread_k_cluster_id * KThreadSliceSize));
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = zeroVal;
accu_value_buf(I) = identityVal;
accu_index_buf(I) = 0;
});
......@@ -459,7 +459,7 @@ struct GridwiseReduction_mk_to_m_multiblock
in_thread_idx_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
AccDataType tmpValue = zeroVal;
AccDataType tmpValue = identityVal;
IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
......@@ -512,7 +512,7 @@ struct GridwiseReduction_mk_to_m_multiblock
in_thread_val_buf(Number<offset>{}));
});
AccDataType tmpValue = zeroVal;
AccDataType tmpValue = identityVal;
IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
......
......@@ -135,12 +135,12 @@ struct GridwiseReduction_mk_to_m_threadwise
ReduceOperation,
PropagateNan>;
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto identityVal = ReduceOperation::GetIdentityValue();
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
type_convert<InDataType>(identityVal));
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
......@@ -149,7 +149,7 @@ struct GridwiseReduction_mk_to_m_threadwise
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
......@@ -276,12 +276,12 @@ struct GridwiseReduction_mk_to_m_threadwise
(void)acc_elementwise_op;
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto identityVal = ReduceOperation::GetIdentityValue();
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
type_convert<InDataType>(identityVal));
const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
......@@ -303,7 +303,7 @@ struct GridwiseReduction_mk_to_m_threadwise
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = zeroVal;
accu_value_buf(I) = identityVal;
accu_index_buf(I) = 0;
});
......
......@@ -816,10 +816,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
false>;
// Global write Gemm shuffle + reduction
const auto d_zeroVal = DReduceOperation::GetReductionZeroVal();
const auto d_identityVal = DReduceOperation::GetIdentityValue();
static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { d_thread_buf(I) = d_zeroVal; });
[&](auto I) { d_thread_buf(I) = d_identityVal; });
// reduce in VGPR
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
......
......@@ -3,11 +3,13 @@
#include <cmath>
#include "data_type.hpp"
#include "half.hpp"
#include "type.hpp"
namespace ck {
namespace math {
// math functions for the host, some are implemented by calling C++ std functions
static inline __host__ float abs(float x) { return std::abs(x); };
static inline __host__ double abs(double x) { return std::abs(x); };
......@@ -28,26 +30,26 @@ static inline __host__ int32_t abs(int32_t x)
static inline __host__ half_t abs(half_t x)
{
half_float::half xx = *reinterpret_cast<half_float::half*>(&x);
uint16_t xx = ck::bit_cast<uint16_t>(x);
half_float::half abs_xx = half_float::abs(xx);
uint16_t abs_xx = xx & 0x7fff;
half_t abs_x = *reinterpret_cast<half_t*>(&abs_xx);
half_t abs_x = ck::bit_cast<half_t>(abs_xx);
return abs_x;
};
static inline __host__ float isnan(float x) { return std::isnan(x); };
static inline __host__ bool isnan(float x) { return std::isnan(x); };
static inline __host__ double isnan(double x) { return std::isnan(x); };
static inline __host__ bool isnan(double x) { return std::isnan(x); };
static inline __host__ int8_t isnan(int8_t x)
static inline __host__ bool isnan(int8_t x)
{
(void)x;
return false;
};
static inline __host__ int32_t isnan(int32_t x)
static inline __host__ bool isnan(int32_t x)
{
(void)x;
return false;
......@@ -55,11 +57,59 @@ static inline __host__ int32_t isnan(int32_t x)
static inline __host__ bool isnan(half_t x)
{
half_float::half xx = *reinterpret_cast<half_float::half*>(&x);
uint16_t xx = ck::bit_cast<uint16_t>(x);
return (xx & 0x7FFF) > 0x7C00;
};
static inline __host__ float sqrt(float x) { return std::sqrt(x); };
static inline __host__ double sqrt(double x) { return std::sqrt(x); };
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
static inline __device__ float abs(float x) { return ::abs(x); };
static inline __device__ double abs(double x) { return ::abs(x); };
static inline __device__ int8_t abs(int8_t x)
{
int8_t sgn = x >> (8 - 1);
return (x ^ sgn) - sgn;
};
static inline __device__ int32_t abs(int32_t x)
{
int32_t sgn = x >> (32 - 1);
return (x ^ sgn) - sgn;
};
static inline __device__ half_t abs(half_t x) { return ::__habs(x); };
static inline __device__ bool isnan(float x) { return ::isnan(x); };
static inline __device__ bool isnan(double x) { return ::isnan(x); };
static inline __device__ bool isnan(int8_t x)
{
(void)x;
return false;
};
return half_float::isnan(xx);
static inline __device__ bool isnan(int32_t x)
{
(void)x;
return false;
};
static inline __device__ bool isnan(half_t x) { return ::__hisnan(x); };
static inline __device__ float sqrt(float x) { return ::sqrtf(x); };
static inline __device__ double sqrt(double x) { return ::sqrt(x); };
} // namespace math
} // namespace ck
......
......@@ -27,6 +27,7 @@
#define CK_REDUCTION_FUNCTIONS_BINOP_HPP
#include "data_type.hpp"
#include "math_v2.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
......@@ -34,18 +35,6 @@
namespace ck {
namespace detail {
template <typename T>
static inline __device__ bool is_nan(T x)
{
return (isnan(x));
};
template <>
inline __device__ bool is_nan<half_t>(half_t x)
{
return (__hisnan(x));
};
template <bool PropagateNan, typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck;
......@@ -53,7 +42,7 @@ template <typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
{
// cppcheck-suppress constParameter
__device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
__host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
{
ReduceOperation{}(accuVal, currVal);
};
......@@ -62,9 +51,11 @@ struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
template <typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck<true, ReduceOperation, AccDataType>
{
__device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
__host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
{
if(is_nan(currVal))
using ck::math::isnan;
if(isnan(currVal))
{
accuVal = currVal;
}
......@@ -81,7 +72,7 @@ struct AccumulateWithIndexAndNanCheck;
template <typename ReduceOperation, typename AccDataType, typename IndexDataType>
struct AccumulateWithIndexAndNanCheck<false, ReduceOperation, AccDataType, IndexDataType>
{
__device__ static inline void
__host__ __device__ static inline void
// cppcheck-suppress constParameter
Calculate(AccDataType& accuVal,
AccDataType currVal,
......@@ -101,12 +92,14 @@ template <typename ReduceOperation, typename AccDataType, typename IndexDataType
struct AccumulateWithIndexAndNanCheck<true, ReduceOperation, AccDataType, IndexDataType>
{
// The method is called when the ReduceOperation is indexable and the user asked for indices
__device__ static inline void Calculate(AccDataType& accuVal,
AccDataType currVal,
IndexDataType& accuIndex,
IndexDataType currIndex)
__host__ __device__ static inline void Calculate(AccDataType& accuVal,
AccDataType currVal,
IndexDataType& accuIndex,
IndexDataType currIndex)
{
if(is_nan(currVal))
using ck::math::isnan;
if(isnan(currVal))
{
accuVal = currVal;
accuIndex = currIndex;
......
......@@ -36,7 +36,7 @@ namespace reduce {
// Every binary operator used in reduction is represented by a templated functor class. Each functor
// class must provide at least
// three members:
// 1) GetReductionZeroVal() -- the interface to return the "identity element" for the binary
// 1) GetIdentityValue() -- the interface to return the "identity element" for the binary
// operator, "identity element" is the unique
// element in the algebraic space that doesn't affect the value of other elements
// when operated against them, and the concept is similar to zero vector in
......@@ -59,7 +59,7 @@ struct Add
{
using dataType = T;
__host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
__host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };
__device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
......@@ -76,7 +76,7 @@ struct Mul
{
using dataType = T;
__host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(1.0f); };
__host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(1.0f); };
__device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
......@@ -92,7 +92,7 @@ struct Max
{
using dataType = T;
__host__ __device__ static constexpr T GetReductionZeroVal()
__host__ __device__ static constexpr T GetIdentityValue()
{
return NumericLimits<T>::Lowest();
};
......@@ -125,10 +125,7 @@ struct Min
{
using dataType = T;
__host__ __device__ static constexpr T GetReductionZeroVal()
{
return NumericLimits<T>::Max();
};
__host__ __device__ static constexpr T GetIdentityValue() { return NumericLimits<T>::Max(); };
__device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
......@@ -158,7 +155,7 @@ struct AMax
{
using dataType = T;
__host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
__host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };
__device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
......@@ -184,7 +181,7 @@ struct AMax
};
template <typename T>
T GetReductionZeroValueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
T GetIdentityValueueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
T result = ck::type_convert<T>(0.0f);
......
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef GUARD_HOST_REDUCE_UTIL_HPP
#define GUARD_HOST_REDUCE_UTIL_HPP
#include <limits>
#include <cmath>
#include <functional>
#include "reduction_enums.hpp"
#include "data_type.hpp"
#include "math_v2.hpp"
namespace ck {
namespace host_reduce {
using ck::NanPropagation;
using ck::ReduceTensorOp;
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&)> PreUnaryOpFn(int)
{
using ck::math::abs;
if constexpr(ReduceOpId == ReduceTensorOp::NORM1)
{
return ([&](AccDataType& a_) { a_ = abs(a_); });
}
else if constexpr(ReduceOpId == ReduceTensorOp::NORM2)
{
return ([&](AccDataType& a_) { a_ = a_ * a_; });
}
else if constexpr(ReduceOpId == ReduceTensorOp::AMAX)
{
return ([&](AccDataType& a_) { a_ = abs(a_); });
}
else
{
// ReduceTensorOp::AVG:
// ReduceTensorOp::ADD:
// ReduceTensorOp::MUL:
// ReduceTensorOp::MIN:
// ReduceTensorOp::MAX:
return ([&](AccDataType&) {});
};
};
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&)> PosUnaryOpFn(int32_t divider)
{
using std::sqrt;
if constexpr(ReduceOpId == ReduceTensorOp::NORM2)
{
return ([&](AccDataType& a_) { a_ = sqrt(a_); });
}
else if constexpr(ReduceOpId == ReduceTensorOp::AVG)
{
return ([&, divider](AccDataType& a_) {
a_ = a_ / static_cast<AccDataType>(static_cast<float>(divider));
});
}
else
{
// ReduceTensorOp::ADD:
// ReduceTensorOp::NORM1:
// ReduceTensorOp::MUL:
// ReduceTensorOp::MIN:
// ReduceTensorOp::MAX:
// ReduceTensorOp::AMAX:
return ([&](AccDataType&) {});
}
};
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&, AccDataType)> ReduceOpFn()
{
if constexpr(ReduceOpId == ReduceTensorOp::ADD || ReduceOpId == ReduceTensorOp::AVG ||
ReduceOpId == ReduceTensorOp::NORM1 || ReduceOpId == ReduceTensorOp::NORM2)
{
return ([&](AccDataType& a_, AccDataType b_) { a_ = a_ + b_; });
}
else if constexpr(ReduceOpId == ReduceTensorOp::MUL)
{
return ([&](AccDataType& a_, AccDataType b_) { a_ = a_ * b_; });
}
else if constexpr(ReduceOpId == ReduceTensorOp::MIN)
{
return ([&](AccDataType& a_, AccDataType b_) {
if(a_ > b_)
a_ = b_;
});
}
else if constexpr(ReduceOpId == ReduceTensorOp::MAX || ReduceOpId == ReduceTensorOp::AMAX)
{
return ([&](AccDataType& a_, AccDataType b_) {
if(a_ < b_)
a_ = b_;
});
}
};
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&, AccDataType, bool& changed)> ReduceOpFn2()
{
if constexpr(ReduceOpId == ReduceTensorOp::MIN)
{
return ([&](AccDataType& a_, AccDataType b_, bool& changed) {
if(a_ > b_)
{
a_ = b_;
changed = true;
}
else
changed = false;
});
}
else if constexpr(ReduceOpId == ReduceTensorOp::MAX || ReduceOpId == ReduceTensorOp::AMAX)
{
return ([&](AccDataType& a_, AccDataType b_, bool& changed) {
if(a_ < b_)
{
a_ = b_;
changed = true;
}
else
changed = false;
});
}
else
{
// ReduceTensorOp::ADD:
// ReduceTensorOp::MUL:
// ReduceTensorOp::AVG:
// ReduceTensorOp::NORM1:
// ReduceTensorOp::NORM2:
return (std::function<void(AccDataType&, AccDataType, bool&)>{});
};
};
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline AccDataType ReduceOpZeroVal()
{
if constexpr(ReduceOpId == ReduceTensorOp::MUL)
{
return (static_cast<AccDataType>(1.0f));
}
else if constexpr(ReduceOpId == ReduceTensorOp::MIN)
{
return (ck::NumericLimits<AccDataType>::Max());
}
else if constexpr(ReduceOpId == ReduceTensorOp::MAX)
{
return (ck::NumericLimits<AccDataType>::Lowest());
}
else if constexpr(ReduceOpId == ReduceTensorOp::AMAX)
{
return (static_cast<AccDataType>(0.0f));
}
else
{
// ReduceTensorOp::ADD
// ReduceTensorOp::AVG
// ReduceTensorOp::NORM1
// ReduceTensorOp::NORM2
return (static_cast<AccDataType>(0.0f));
};
};
template <typename AccDataType, bool PropagateNan>
__host__ static inline void
binop_with_nan_check(std::function<void(AccDataType&, AccDataType)> opReduce,
AccDataType& accuVal,
AccDataType currVal)
{
using ck::math::isnan;
if constexpr(!PropagateNan)
{
opReduce(accuVal, currVal);
}
else
{
if(isnan(currVal))
accuVal = currVal;
else
opReduce(accuVal, currVal);
};
};
template <typename AccDataType, typename IndexDataType, bool PropagateNan>
__host__ static inline void
binop_with_index_and_nan_check(std::function<void(AccDataType&, AccDataType, bool&)> opReduce,
AccDataType& accuVal,
AccDataType currVal,
IndexDataType& accuIndex,
IndexDataType currIndex)
{
using ck::math::isnan;
if constexpr(!PropagateNan)
{
bool changed;
opReduce(accuVal, currVal, changed);
if(changed)
accuIndex = currIndex;
}
else
{
if(isnan(currVal))
{
accuVal = currVal;
accuIndex = currIndex;
}
else
{
bool changed;
opReduce(accuVal, currVal, changed);
if(changed)
accuIndex = currIndex;
};
};
};
}; // namespace host_reduce
}; // namespace ck
#endif
......@@ -33,10 +33,10 @@
#include "reduction_enums.hpp"
#include "reduction_common.hpp"
#include "host_reduce_util.hpp"
#include "host_common_util.hpp"
#include "host_tensor.hpp"
#include "data_type.hpp"
#include "reduction_functions_accumulate.hpp"
template <int NDim>
static void get_all_indexes(const std::array<size_t, NDim>& dimLengths,
......@@ -106,11 +106,13 @@ static size_t get_offset_from_index(const std::vector<size_t>& strides,
template <typename InDataType,
typename AccDataType,
typename OutDataType,
ck::ReduceTensorOp ReduceOpId,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
int Rank,
int NumReduceDim,
bool PropagateNan,
bool NeedIndices>
bool OutputIndex>
struct ReductionHost
{
using IndexDataType = int32_t;
......@@ -122,8 +124,6 @@ struct ReductionHost
std::vector<int> reduceDims;
IndexDataType divider;
std::function<void(AccDataType&)> preUnaryOp;
std::function<void(AccDataType&)> posUnaryOp;
std::array<size_t, NumReduceDim> reduceLengths;
std::array<size_t, NumReduceDim> reduceStrides;
std::array<size_t, NumInvariantDim> invariantLengths;
......@@ -137,9 +137,6 @@ struct ReductionHost
const std::vector<int>& invariantDims_,
const std::vector<int>& reduceDims_)
{
using ck::host_reduce::PosUnaryOpFn;
using ck::host_reduce::PreUnaryOpFn;
// this->outLengths = to_int_vector(outDesc.GetLengths());
this->outStrides = outDesc.GetStrides();
......@@ -171,9 +168,6 @@ struct ReductionHost
invariant_dim_indexes.clear();
get_all_indexes<NumInvariantDim>(invariantLengths, invariant_dim_indexes);
};
preUnaryOp = PreUnaryOpFn<AccDataType, ReduceOpId>(divider);
posUnaryOp = PosUnaryOpFn<AccDataType, ReduceOpId>(divider);
};
void Run(float alpha,
......@@ -182,7 +176,7 @@ struct ReductionHost
OutDataType* out_data,
IndexDataType* out_indices)
{
if constexpr(NeedIndices)
if constexpr(OutputIndex)
{
RunImpl_with_index(alpha, in_data, beta, out_data, out_indices);
}
......@@ -201,15 +195,17 @@ struct ReductionHost
using ck::float_equal_one;
using ck::float_equal_zero;
using ck::type_convert;
using ck::host_reduce::binop_with_index_and_nan_check;
using ck::host_reduce::ReduceOpFn2;
using ck::host_reduce::ReduceOpZeroVal;
auto opReduce2 = ReduceOpFn2<AccDataType, ReduceOpId>();
using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
AccDataType,
IndexDataType>;
InElementwiseOperation in_elementwise_op(divider);
AccElementwiseOperation acc_elementwise_op(divider);
if constexpr(NumInvariantDim == 0)
{
AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
AccDataType accuVal = ReduceOperation::GetIdentityValue();
IndexDataType accuIndex = 0;
for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++)
......@@ -219,15 +215,14 @@ struct ReductionHost
auto currVal = type_convert<AccDataType>(in_data[offset_reduce]);
preUnaryOp(currVal);
in_elementwise_op(currVal, currVal);
auto currIndex = static_cast<IndexDataType>(i);
binop_with_index_and_nan_check<AccDataType, IndexDataType, PropagateNan>(
opReduce2, accuVal, currVal, accuIndex, currIndex);
Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
};
posUnaryOp(accuVal);
acc_elementwise_op(accuVal, accuVal);
if(!float_equal_one{}(alpha))
accuVal *= type_convert<AccDataType>(alpha);
......@@ -241,7 +236,7 @@ struct ReductionHost
else
{
auto thread_reduce_func = [&](auto invariant_index) {
AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
AccDataType accuVal = ReduceOperation::GetIdentityValue();
IndexDataType accuIndex = 0;
auto offset_invariant =
......@@ -255,15 +250,14 @@ struct ReductionHost
auto currVal =
type_convert<AccDataType>(in_data[offset_invariant + offset_reduce]);
preUnaryOp(currVal);
in_elementwise_op(currVal, currVal);
auto currIndex = static_cast<IndexDataType>(i);
binop_with_index_and_nan_check<AccDataType, IndexDataType, PropagateNan>(
opReduce2, accuVal, currVal, accuIndex, currIndex);
Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
};
posUnaryOp(accuVal);
acc_elementwise_op(accuVal, accuVal);
if(!float_equal_one{}(alpha))
accuVal *= type_convert<AccDataType>(alpha);
......@@ -308,15 +302,16 @@ struct ReductionHost
using ck::float_equal_one;
using ck::float_equal_zero;
using ck::type_convert;
using ck::host_reduce::binop_with_nan_check;
using ck::host_reduce::ReduceOpFn;
using ck::host_reduce::ReduceOpZeroVal;
auto opReduce = ReduceOpFn<AccDataType, ReduceOpId>();
using Accumulation =
ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
InElementwiseOperation in_elementwise_op(divider);
AccElementwiseOperation acc_elementwise_op(divider);
if constexpr(NumInvariantDim == 0)
{
AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
AccDataType accuVal = ReduceOperation::GetIdentityValue();
for(const auto& reduce_index : reduce_dim_indexes)
{
......@@ -325,12 +320,12 @@ struct ReductionHost
auto currVal = type_convert<AccDataType>(in_data[offset_reduce]);
preUnaryOp(currVal);
in_elementwise_op(currVal, currVal);
binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal);
Accumulation::Calculate(accuVal, currVal);
};
posUnaryOp(accuVal);
acc_elementwise_op(accuVal, accuVal);
if(!float_equal_one{}(alpha))
accuVal *= type_convert<AccDataType>(alpha);
......@@ -343,7 +338,7 @@ struct ReductionHost
else
{
auto thread_reduce_func = [&](auto invariant_index) {
AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
AccDataType accuVal = ReduceOperation::GetIdentityValue();
auto offset_invariant =
get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index);
......@@ -356,12 +351,12 @@ struct ReductionHost
auto currVal =
type_convert<AccDataType>(in_data[offset_invariant + offset_reduce]);
preUnaryOp(currVal);
in_elementwise_op(currVal, currVal);
binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal);
Accumulation::Calculate(accuVal, currVal);
};
posUnaryOp(accuVal);
acc_elementwise_op(accuVal, accuVal);
if(!float_equal_one{}(alpha))
accuVal *= type_convert<AccDataType>(alpha);
......
......@@ -171,8 +171,8 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
{
for(int m = 0; m < M; ++m)
{
float d0_acc = d0_reduce_op.GetReductionZeroVal();
float d1_acc = d1_reduce_op.GetReductionZeroVal();
float d0_acc = d0_reduce_op.GetIdentityValue();
float d1_acc = d1_reduce_op.GetIdentityValue();
for(int n = 0; n < N; ++n)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment