cub Namespace Reference

CUB namespace. More...

Classes

struct  ArrayTraits
 Array traits. More...
 
struct  BaseTraits
 Basic type traits. More...
 
class  BlockDiscontinuity
 BlockDiscontinuity provides operations for flagging discontinuities within a list of data items partitioned across a CUDA threadblock. More...
 
class  BlockExchange
 BlockExchange provides operations for reorganizing the partitioning of ordered data across a CUDA threadblock. More...
 
class  BlockLoad
 BlockLoad provides data movement operations for reading block-arranged data from global memory. More...
 
class  BlockRadixSort
 BlockRadixSort provides variants of parallel radix sorting across a CUDA threadblock. More...
 
class  BlockReduce
 BlockReduce provides variants of parallel reduction across a CUDA threadblock. More...
 
class  BlockScan
 BlockScan provides variants of parallel prefix scan (and prefix sum) across a CUDA threadblock. More...
 
class  BlockStore
 BlockStore provides data movement operations for writing block-arranged data to global memory. More...
 
struct  EnableIf
 Simple enable-if (similar to Boost) More...
 
struct  Equality
 Default equality functor. More...
 
struct  Equals
 Type equality test. More...
 
struct  If
 Type selection (IF ? ThenType : ElseType) More...
 
struct  IsVolatile
 Volatile modifier test. More...
 
struct  Log2
 Statically determine log2(N), rounded up. More...
 
struct  Max
 Default max functor. More...
 
struct  NullType
 A simple "NULL" marker type. More...
 
struct  NumericTraits
 Numeric type traits. More...
 
struct  RemoveQualifiers
 Removes const and volatile qualifiers from type Tp. More...
 
struct  Sum
 Default sum functor. More...
 
struct  Traits
 Type traits. More...
 
class  WarpScan
 WarpScan provides variants of parallel prefix scan across a CUDA warp. More...
 

Enumerations

enum  BlockLoadPolicy { BLOCK_LOAD_DIRECT, BLOCK_LOAD_VECTORIZE, BLOCK_LOAD_TRANSPOSE }
 Tuning policy for cub::BlockLoad. More...
 
enum  BlockScanPolicy { BLOCK_SCAN_RAKING, BLOCK_SCAN_WARPSCANS }
 Tuning policy for cub::BlockScan. More...
 
enum  BlockStorePolicy { BLOCK_STORE_DIRECT, BLOCK_STORE_VECTORIZE, BLOCK_STORE_TRANSPOSE }
 Tuning policy for cub::BlockStore. More...
 
enum  Category { NOT_A_NUMBER, SIGNED_INTEGER, UNSIGNED_INTEGER, FLOATING_POINT }
 Basic type traits categories.
 
enum  PtxLoadModifier {
  PTX_LOAD_NONE, PTX_LOAD_CA, PTX_LOAD_CG, PTX_LOAD_CS,
  PTX_LOAD_CV, PTX_LOAD_LDG, PTX_LOAD_VS
}
 Enumeration of PTX cache-modifiers for memory load operations. More...
 
enum  PtxStoreModifier {
  PTX_STORE_NONE, PTX_STORE_WB, PTX_STORE_CG, PTX_STORE_CS,
  PTX_STORE_WT, PTX_STORE_VS
}
 Enumeration of PTX cache-modifiers for memory store operations. More...
 

Functions

__host__ __device__ __forceinline__ cudaError_t Debug (cudaError_t error, const char *message, const char *filename, int line)
 If CUB_STDERR is defined and error is not cudaSuccess, message is printed to stderr along with the supplied source context. More...
 
__host__ __device__ __forceinline__ cudaError_t Debug (cudaError_t error, const char *filename, int line)
 If CUB_STDERR is defined and error is not cudaSuccess, the corresponding error message is printed to stderr along with the supplied source context. More...
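For example (a minimal sketch; the umbrella header path is an assumption for this release), a runtime call can be wrapped so that failures are reported with source context:

    #include <cub/cub.cuh>   // umbrella header (path assumed)

    int main()
    {
        int *d_data = NULL;

        // If CUB_STDERR is defined and cudaMalloc fails, the corresponding
        // error message is printed to stderr with this file/line context.
        cudaError_t error = cub::Debug(
            cudaMalloc((void**) &d_data, 1024 * sizeof(int)), __FILE__, __LINE__);
        if (error != cudaSuccess) return 1;

        cudaFree(d_data);
        return 0;
    }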
 
Direct threadblock loads (blocked arrangement)
template<PtxLoadModifier MODIFIER, typename T , int ITEMS_PER_THREAD, typename InputIterator >
__device__ __forceinline__ void BlockLoadDirect (InputIterator block_itr, T(&items)[ITEMS_PER_THREAD])
 Load a tile of items across a threadblock directly using the specified cache modifier. More...
 
template<typename T , int ITEMS_PER_THREAD, typename InputIterator >
__device__ __forceinline__ void BlockLoadDirect (InputIterator block_itr, T(&items)[ITEMS_PER_THREAD])
 Load a tile of items across a threadblock directly. More...
 
template<PtxLoadModifier MODIFIER, typename T , int ITEMS_PER_THREAD, typename InputIterator , typename SizeT >
__device__ __forceinline__ void BlockLoadDirect (InputIterator block_itr, const SizeT &guarded_items, T(&items)[ITEMS_PER_THREAD])
 Load a tile of items across a threadblock directly using the specified cache modifier, guarded by range. More...
 
template<typename T , int ITEMS_PER_THREAD, typename InputIterator , typename SizeT >
__device__ __forceinline__ void BlockLoadDirect (InputIterator block_itr, const SizeT &guarded_items, T(&items)[ITEMS_PER_THREAD])
 Load a tile of items across a threadblock directly, guarded by range. More...
 
template<PtxLoadModifier MODIFIER, typename T , int ITEMS_PER_THREAD, typename InputIterator , typename SizeT >
__device__ __forceinline__ void BlockLoadDirect (InputIterator block_itr, const SizeT &guarded_items, T oob_default, T(&items)[ITEMS_PER_THREAD])
 Load a tile of items across a threadblock directly using the specified cache modifier, guarded by range, with assignment for out-of-bound elements. More...
 
template<typename T , int ITEMS_PER_THREAD, typename InputIterator , typename SizeT >
__device__ __forceinline__ void BlockLoadDirect (InputIterator block_itr, const SizeT &guarded_items, T oob_default, T(&items)[ITEMS_PER_THREAD])
 Load a tile of items across a threadblock directly, guarded by range, with assignment for out-of-bound elements. More...
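For example (an illustrative sketch; the kernel name and header path are assumptions), a threadblock can cooperatively load four consecutive items per thread, substituting a default value for out-of-bounds elements:

    #include <cub/cub.cuh>   // umbrella header (path assumed)

    // Blocked arrangement: thread t receives tile elements 4t .. 4t+3.
    __global__ void BlockedLoadKernel(int *d_in, int num_valid)
    {
        int items[4];

        // Guarded variant: elements at rank >= num_valid are assigned -1.
        cub::BlockLoadDirect(d_in, num_valid, -1, items);

        // ... per-thread computation on items[0..3] ...
    }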
 
Direct threadblock loads (striped arrangement)
template<PtxLoadModifier MODIFIER, typename T , int ITEMS_PER_THREAD, typename InputIterator >
__device__ __forceinline__ void BlockLoadDirectStriped (InputIterator block_itr, T(&items)[ITEMS_PER_THREAD], int stride=blockDim.x)
 Load striped tile directly using the specified cache modifier. More...
 
template<typename T , int ITEMS_PER_THREAD, typename InputIterator >
__device__ __forceinline__ void BlockLoadDirectStriped (InputIterator block_itr, T(&items)[ITEMS_PER_THREAD], int stride=blockDim.x)
 Load striped tile directly. More...
 
template<PtxLoadModifier MODIFIER, typename T , int ITEMS_PER_THREAD, typename InputIterator , typename SizeT >
__device__ __forceinline__ void BlockLoadDirectStriped (InputIterator block_itr, const SizeT &guarded_items, T(&items)[ITEMS_PER_THREAD], int stride=blockDim.x)
 Load striped tile directly using the specified cache modifier, guarded by range. More...
 
template<typename T , int ITEMS_PER_THREAD, typename InputIterator , typename SizeT >
__device__ __forceinline__ void BlockLoadDirectStriped (InputIterator block_itr, const SizeT &guarded_items, T(&items)[ITEMS_PER_THREAD], int stride=blockDim.x)
 Load striped tile directly, guarded by range. More...
 
template<PtxLoadModifier MODIFIER, typename T , int ITEMS_PER_THREAD, typename InputIterator , typename SizeT >
__device__ __forceinline__ void BlockLoadDirectStriped (InputIterator block_itr, const SizeT &guarded_items, T oob_default, T(&items)[ITEMS_PER_THREAD], int stride=blockDim.x)
 Load striped tile directly using the specified cache modifier, guarded by range, with assignment for out-of-bound elements. More...
 
template<typename T , int ITEMS_PER_THREAD, typename InputIterator , typename SizeT >
__device__ __forceinline__ void BlockLoadDirectStriped (InputIterator block_itr, const SizeT &guarded_items, T oob_default, T(&items)[ITEMS_PER_THREAD], int stride=blockDim.x)
 Load striped tile directly, guarded by range, with assignment for out-of-bound elements. More...
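For example (a sketch under the same assumptions as the blocked example above), the striped counterpart reads with the default stride of blockDim.x:

    #include <cub/cub.cuh>   // umbrella header (path assumed)

    // Striped arrangement: thread t reads tile elements t, t + blockDim.x,
    // t + 2*blockDim.x, t + 3*blockDim.x.
    __global__ void StripedLoadKernel(int *d_in, int num_valid)
    {
        int items[4];

        // Guarded variant: out-of-range elements are assigned 0.
        cub::BlockLoadDirectStriped(d_in, num_valid, 0, items);
    }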
 
Threadblock vectorized loads (blocked arrangement)
template<PtxLoadModifier MODIFIER, typename T , int ITEMS_PER_THREAD>
__device__ __forceinline__ void BlockLoadVectorized (T *block_ptr, T(&items)[ITEMS_PER_THREAD])
 Load a tile of items across a threadblock directly using the specified cache modifier. More...
 
template<typename T , int ITEMS_PER_THREAD>
__device__ __forceinline__ void BlockLoadVectorized (T *block_ptr, T(&items)[ITEMS_PER_THREAD])
 Load a tile of items across a threadblock directly. More...
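For example (a sketch; the alignment of d_in is the caller's responsibility), four ints per thread from a quadword-aligned pointer are eligible for ld.global.v4.s32 vector loads:

    #include <cub/cub.cuh>   // umbrella header (path assumed)

    __global__ void VectorizedLoadKernel(int *d_in)
    {
        int items[4];   // 4 ints per thread: candidate for v4 loads
        cub::BlockLoadVectorized(d_in, items);
    }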
 
Direct threadblock stores (blocked arrangement)
template<PtxStoreModifier MODIFIER, typename T , int ITEMS_PER_THREAD, typename OutputIterator >
__device__ __forceinline__ void BlockStoreDirect (OutputIterator block_itr, T(&items)[ITEMS_PER_THREAD])
 Store a tile of items across a threadblock directly using the specified cache modifier. More...
 
template<typename T , int ITEMS_PER_THREAD, typename OutputIterator >
__device__ __forceinline__ void BlockStoreDirect (OutputIterator block_itr, T(&items)[ITEMS_PER_THREAD])
 Store a tile of items across a threadblock directly. More...
 
template<PtxStoreModifier MODIFIER, typename T , int ITEMS_PER_THREAD, typename OutputIterator , typename SizeT >
__device__ __forceinline__ void BlockStoreDirect (OutputIterator block_itr, const SizeT &guarded_items, T(&items)[ITEMS_PER_THREAD])
 Store a tile of items across a threadblock directly using the specified cache modifier, guarded by range. More...
 
template<typename T , int ITEMS_PER_THREAD, typename OutputIterator , typename SizeT >
__device__ __forceinline__ void BlockStoreDirect (OutputIterator block_itr, const SizeT &guarded_items, T(&items)[ITEMS_PER_THREAD])
 Store a tile of items across a threadblock directly, guarded by range. More...
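For example (a sketch mirroring the blocked load example above; the per-thread data is illustrative):

    #include <cub/cub.cuh>   // umbrella header (path assumed)

    __global__ void BlockedStoreKernel(int *d_out, int num_valid)
    {
        int items[4] = {0, 1, 2, 3};   // illustrative per-thread data

        // Guarded variant: only the first num_valid tile elements are written.
        cub::BlockStoreDirect(d_out, num_valid, items);
    }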
 
Direct threadblock stores (striped arrangement)
template<PtxStoreModifier MODIFIER, typename T , int ITEMS_PER_THREAD, typename OutputIterator >
__device__ __forceinline__ void BlockStoreDirectStriped (OutputIterator block_itr, T(&items)[ITEMS_PER_THREAD], int stride=blockDim.x)
 Store striped tile directly using the specified cache modifier. More...
 
template<typename T , int ITEMS_PER_THREAD, typename OutputIterator >
__device__ __forceinline__ void BlockStoreDirectStriped (OutputIterator block_itr, T(&items)[ITEMS_PER_THREAD], int stride=blockDim.x)
 Store striped tile directly. More...
 
template<PtxStoreModifier MODIFIER, typename T , int ITEMS_PER_THREAD, typename OutputIterator , typename SizeT >
__device__ __forceinline__ void BlockStoreDirectStriped (OutputIterator block_itr, const SizeT &guarded_items, T(&items)[ITEMS_PER_THREAD], int stride=blockDim.x)
 Store striped tile directly using the specified cache modifier, guarded by range. More...
 
template<typename T , int ITEMS_PER_THREAD, typename OutputIterator , typename SizeT >
__device__ __forceinline__ void BlockStoreDirectStriped (OutputIterator block_itr, const SizeT &guarded_items, T(&items)[ITEMS_PER_THREAD], int stride=blockDim.x)
 Store striped tile directly, guarded by range. More...
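For example (a sketch; the explicit stride argument simply overrides the blockDim.x default):

    #include <cub/cub.cuh>   // umbrella header (path assumed)

    __global__ void StripedStoreKernel(int *d_out, int num_valid)
    {
        int items[4] = {0, 0, 0, 0};

        // Guarded striped store with an explicit stride of 128 threads.
        cub::BlockStoreDirectStriped(d_out, num_valid, items, 128);
    }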
 
Threadblock vectorized stores (blocked arrangement)
template<PtxStoreModifier MODIFIER, typename T , int ITEMS_PER_THREAD>
__device__ __forceinline__ void BlockStoreVectorized (T *block_ptr, T(&items)[ITEMS_PER_THREAD])
 Store a tile of items across a threadblock directly using the specified cache modifier. More...
 
template<typename T , int ITEMS_PER_THREAD>
__device__ __forceinline__ void BlockStoreVectorized (T *block_ptr, T(&items)[ITEMS_PER_THREAD])
 Store a tile of items across a threadblock directly. More...
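For example (a sketch; the alignment of d_out is the caller's responsibility), the store counterpart of the vectorized load above:

    #include <cub/cub.cuh>   // umbrella header (path assumed)

    __global__ void VectorizedStoreKernel(int *d_out)
    {
        int items[4] = {1, 2, 3, 4};   // candidate for st.global.v4.s32
        cub::BlockStoreVectorized(d_out, items);
    }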
 
Thread utilities for memory I/O using PTX cache modifiers
template<PtxLoadModifier MODIFIER, typename InputIterator >
__device__ __forceinline__ std::iterator_traits< InputIterator >::value_type ThreadLoad (InputIterator itr)
 Thread utility for reading memory using cub::PtxLoadModifier cache modifiers. More...
 
template<PtxStoreModifier MODIFIER, typename OutputIterator , typename T >
__device__ __forceinline__ void ThreadStore (OutputIterator itr, const T &val)
 Thread utility for writing memory using cub::PtxStoreModifier cache modifiers. More...
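For example (a minimal sketch; the kernel is an assumption), a one-element-per-thread copy can stream through L2 while bypassing L1 using the "cache global" modifiers:

    #include <cub/cub.cuh>   // umbrella header (path assumed)

    __global__ void CopyKernel(int *d_in, int *d_out)
    {
        int idx = blockIdx.x * blockDim.x + threadIdx.x;

        // ld.global.cg / st.global.cg: cache at L2, bypass L1.
        int val = cub::ThreadLoad<cub::PTX_LOAD_CG>(d_in + idx);
        cub::ThreadStore<cub::PTX_STORE_CG>(d_out + idx, val);
    }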
 

Detailed Description

CUB namespace.

Enumeration Type Documentation

enum BlockScanPolicy

Tuning policy for cub::BlockScan.

Enumerator
BLOCK_SCAN_RAKING 
Overview
An efficient "raking reduce-then-scan" prefix scan algorithm. Scan execution comprises five phases:
  1. Upsweep sequential reduction in registers (if threads contribute more than one input each). Each thread then places the partial reduction of its item(s) into shared memory.
  2. Upsweep sequential reduction in shared memory. Threads within a single warp rake across segments of shared partial reductions.
  3. A warp-synchronous Kogge-Stone style exclusive scan within the raking warp.
  4. Downsweep sequential exclusive scan in shared memory. Threads within a single warp rake across segments of shared partial reductions, seeded with the warp-scan output.
  5. Downsweep sequential scan in registers (if threads contribute more than one input), seeded with the raking scan output.
Figure: BLOCK_SCAN_RAKING data flow for a hypothetical 16-thread threadblock and 4-thread raking warp.
Performance Considerations
  • Although this variant may suffer longer turnaround latencies when the GPU is under-occupied, it can often provide higher overall throughput across the GPU when suitably occupied.
BLOCK_SCAN_WARPSCANS 
Overview
A quick "tiled warpscans" prefix scan algorithm. Scan execution is comprised of four phases:
  1. Upsweep sequential reduction in registers (if threads contribute more than one input each). Each thread then places the partial reduction of its item(s) into shared memory.
  2. Compute a shallow, but inefficient warp-synchronous Kogge-Stone style scan within each warp.
  3. A propagation phase where the warp scan outputs in each warp are updated with the aggregate from each preceding warp.
  4. Downsweep sequential scan in registers (if threads contribute more than one input), seeded with the raking scan output.
Figure: BLOCK_SCAN_WARPSCANS data flow for a hypothetical 16-thread threadblock and 4-thread raking warp.
Performance Considerations
  • Although this variant may suffer lower overall throughput across the GPU due to a heavy reliance on inefficient warpscans, it can often provide lower turnaround latencies when the GPU is under-occupied.
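The policy is selected when specializing cub::BlockScan. The following is a hedged sketch only: the template parameter order and the SmemStorage and ExclusiveSum member names are assumptions here, so consult the cub::BlockScan page for the authoritative interface:

    #include <cub/cub.cuh>   // umbrella header (path assumed)

    // Hypothetical specialization: data type, threads per block, policy.
    typedef cub::BlockScan<int, 128, cub::BLOCK_SCAN_WARPSCANS> BlockScanT;

    __global__ void ScanKernel(int *d_data)
    {
        // Shared memory for the cooperative scan (type name assumed).
        __shared__ typename BlockScanT::SmemStorage smem_storage;

        int input = d_data[blockIdx.x * blockDim.x + threadIdx.x];
        int output;
        BlockScanT::ExclusiveSum(smem_storage, input, output);   // name assumed
        d_data[blockIdx.x * blockDim.x + threadIdx.x] = output;
    }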

enum BlockLoadPolicy

Tuning policy for cub::BlockLoad.

Enumerator
BLOCK_LOAD_DIRECT 
Overview

A blocked arrangement of data is read directly from memory. The threadblock reads items in a parallel "raking" fashion: thread i reads the i-th segment of consecutive elements.

Performance Considerations
  • The utilization of memory transactions (coalescing) decreases as the access stride between threads increases (i.e., the number of items per thread).
BLOCK_LOAD_VECTORIZE 
Overview

A blocked arrangement of data is read directly from memory using CUDA's built-in vectorized loads as a coalescing optimization. The threadblock reads items in a parallel "raking" fashion: thread i uses vector loads to read the i-th segment of consecutive elements.

For example, ld.global.v4.s32 instructions will be generated when T = int and ITEMS_PER_THREAD > 4.

Performance Considerations
  • The utilization of memory transactions (coalescing) remains high until the access stride between threads (i.e., the number of items per thread) exceeds the maximum vector load width (typically 4 items or 64B, whichever is lower).
  • The following conditions will prevent vectorization, and loading will fall back to cub::BLOCK_LOAD_DIRECT:
    • ITEMS_PER_THREAD is odd
    • The InputIterator is not a simple pointer type
    • The block input offset is not quadword-aligned
    • The data type T is not a built-in primitive or CUDA vector type (e.g., short, int2, double, float2, etc.)
BLOCK_LOAD_TRANSPOSE 
Overview

A striped arrangement of data is read directly from memory and then locally transposed into a blocked arrangement. The threadblock reads items in a parallel "strip-mining" fashion: thread i reads items having stride BLOCK_THREADS between them. cub::BlockExchange is then used to locally reorder the items into a blocked arrangement.

Performance Considerations
  • The utilization of memory transactions (coalescing) remains high regardless of items loaded per thread.
  • The local reordering incurs slightly longer latencies and lower throughput than the direct cub::BLOCK_LOAD_DIRECT and cub::BLOCK_LOAD_VECTORIZE alternatives.
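The policy is selected when specializing cub::BlockLoad. A hedged sketch (the template parameter order and the SmemStorage and Load member names are assumptions; see the cub::BlockLoad page for the authoritative interface):

    #include <cub/cub.cuh>   // umbrella header (path assumed)

    // Hypothetical specialization: iterator, threads, items per thread, policy.
    typedef cub::BlockLoad<int*, 128, 4, cub::BLOCK_LOAD_TRANSPOSE> BlockLoadT;

    __global__ void TransposeLoadKernel(int *d_in)
    {
        __shared__ typename BlockLoadT::SmemStorage smem_storage;   // name assumed

        int items[4];
        BlockLoadT::Load(smem_storage, d_in, items);   // member name assumed
    }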

enum BlockStorePolicy

Tuning policy for cub::BlockStore.

Enumerator
BLOCK_STORE_DIRECT 
Overview

A blocked arrangement of data is written directly to memory. The threadblock writes items in a parallel "raking" fashion: thread i writes the i-th segment of consecutive elements.

Performance Considerations
  • The utilization of memory transactions (coalescing) decreases as the access stride between threads increases (i.e., the number of items per thread).
BLOCK_STORE_VECTORIZE 
Overview

A blocked arrangement of data is written directly to memory using CUDA's built-in vectorized stores as a coalescing optimization. The threadblock writes items in a parallel "raking" fashion: thread i uses vector stores to write the i-th segment of consecutive elements.

For example, st.global.v4.s32 instructions will be generated when T = int and ITEMS_PER_THREAD > 4.

Performance Considerations
  • The utilization of memory transactions (coalescing) remains high until the access stride between threads (i.e., the number of items per thread) exceeds the maximum vector store width (typically 4 items or 64B, whichever is lower).
  • The following conditions will prevent vectorization, and storing will fall back to cub::BLOCK_STORE_DIRECT:
    • ITEMS_PER_THREAD is odd
    • The OutputIterator is not a simple pointer type
    • The block output offset is not quadword-aligned
    • The data type T is not a built-in primitive or CUDA vector type (e.g., short, int2, double, float2, etc.)
BLOCK_STORE_TRANSPOSE 
Overview
A blocked arrangement is locally transposed into a striped arrangement which is then written to memory. More specifically, cub::BlockExchange is used to locally reorder the items into a striped arrangement, after which the threadblock writes items in a parallel "strip-mining" fashion: consecutive items owned by thread i are written to memory with stride BLOCK_THREADS between them.
Performance Considerations
  • The utilization of memory transactions (coalescing) remains high regardless of items written per thread.
  • The local reordering incurs slightly longer latencies and lower throughput than the direct cub::BLOCK_STORE_DIRECT and cub::BLOCK_STORE_VECTORIZE alternatives.
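The policy is selected when specializing cub::BlockStore. A hedged sketch, symmetric with the cub::BlockLoad example above (parameter order and member names are assumptions; see the cub::BlockStore page):

    #include <cub/cub.cuh>   // umbrella header (path assumed)

    // Hypothetical specialization: iterator, threads, items per thread, policy.
    typedef cub::BlockStore<int*, 128, 4, cub::BLOCK_STORE_TRANSPOSE> BlockStoreT;

    __global__ void TransposeStoreKernel(int *d_out)
    {
        __shared__ typename BlockStoreT::SmemStorage smem_storage;   // name assumed

        int items[4] = {0, 1, 2, 3};
        BlockStoreT::Store(smem_storage, d_out, items);   // member name assumed
    }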