alpaka
Abstraction Library for Parallel Kernel Acceleration
Loading...
Searching...
No Matches
warp.hpp
Go to the documentation of this file.
1/* Copyright 2025 Mehmet Yusufoglu, René Widera
2 * SPDX-License-Identifier: MPL-2.0
3 */
4
5#pragma once
6
11
12#include <cstdint>
13#include <type_traits>
14
15#if ALPAKA_LANG_CUDA
16namespace alpaka::onAcc::warp::internal
17{
18 template<alpaka::onAcc::concepts::Acc T_Acc>
19 struct Activemask::Op<T_Acc, api::Cuda>
20 {
21 constexpr __device__ auto operator()(T_Acc const&, api::Cuda) const
22 {
23 return __activemask();
24 }
25 };
26
27 template<alpaka::onAcc::concepts::Acc T_Acc>
28 struct GetLaneIdx::Op<T_Acc, api::Cuda>
29 {
30 constexpr __device__ auto operator()(T_Acc const&, api::Cuda) const
31 {
32 unsigned lIdx;
33# if ALPAKA_COMP_NVCC
34 asm volatile("mov.u32 %0, %laneid;" : "=r"(lIdx));
35# else
36 asm("mov.u32 %0, %%laneid;" : "=r"(lIdx));
37# endif
38 return lIdx;
39 }
40 };
41
42 template<alpaka::onAcc::concepts::Acc T_Acc>
43 struct GetWarpIdx::Op<T_Acc, api::Cuda>
44 {
45 constexpr __device__ uint32_t operator()(T_Acc const& acc, api::Cuda) const
46 {
47 constexpr uint32_t warpExtent = onAcc::warp::internal::getSize<ALPAKA_TYPEOF(acc)>();
48 alpaka::concepts::Vector auto blockThreadCount
49 = acc.getExtentsOf(onAcc::origin::block, onAcc::unit::threads);
50 alpaka::concepts::Vector auto threadIdxInBlock
51 = acc.getIdxWithin(alpaka::onAcc::origin::block, alpaka::onAcc::unit::threads);
52 return linearize(blockThreadCount, threadIdxInBlock) / warpExtent;
53 }
54 };
55
56 template<alpaka::onAcc::concepts::Acc T_Acc>
57 struct All::Op<T_Acc, api::Cuda>
58 {
59 constexpr __device__ bool operator()(T_Acc const&, api::Cuda, int32_t predicate) const
60 {
61 return __all_sync(__activemask(), static_cast<int>(predicate)) != 0;
62 }
63 };
64
65 template<alpaka::onAcc::concepts::Acc T_Acc>
66 struct Any::Op<T_Acc, api::Cuda>
67 {
68 constexpr __device__ bool operator()(T_Acc const&, api::Cuda, int32_t predicate) const
69 {
70 return __any_sync(__activemask(), static_cast<int>(predicate)) != 0;
71 }
72 };
73
74 template<alpaka::onAcc::concepts::Acc T_Acc>
75 struct Ballot::Op<T_Acc, api::Cuda>
76 {
77 constexpr __device__ auto operator()(T_Acc const&, api::Cuda, int32_t predicate) const
78 {
79 return __ballot_sync(__activemask(), static_cast<int>(predicate));
80 }
81 };
82
83 template<alpaka::onAcc::concepts::Acc T_Acc, typename T>
84 struct Shfl::Op<T_Acc, api::Cuda, T>
85 {
86 constexpr __device__ T
87 operator()(T_Acc const&, api::Cuda, T const& value, uint32_t srcLane, uint32_t width) const
88 {
89 return __shfl_sync(__activemask(), value, static_cast<int>(srcLane), static_cast<int>(width));
90 }
91 };
92
93 template<alpaka::onAcc::concepts::Acc T_Acc, typename T>
94 struct ShflDown::Op<T_Acc, api::Cuda, T>
95 {
96 constexpr __device__ T
97 operator()(T_Acc const&, api::Cuda, T const& value, uint32_t delta, uint32_t width) const
98 {
99 return __shfl_down_sync(__activemask(), value, static_cast<int>(delta), static_cast<int>(width));
100 }
101 };
102
103 template<alpaka::onAcc::concepts::Acc T_Acc, typename T>
104 struct ShflUp::Op<T_Acc, api::Cuda, T>
105 {
106 constexpr __device__ T
107 operator()(T_Acc const&, api::Cuda, T const& value, uint32_t delta, uint32_t width) const
108 {
109 return __shfl_up_sync(__activemask(), value, static_cast<int>(delta), static_cast<int>(width));
110 }
111 };
112
113 template<alpaka::onAcc::concepts::Acc T_Acc, typename T>
114 struct ShflXor::Op<T_Acc, api::Cuda, T>
115 {
116 constexpr __device__ T
117 operator()(T_Acc const&, api::Cuda, T const& value, uint32_t laneMask, uint32_t width) const
118 {
119 return __shfl_xor_sync(__activemask(), value, static_cast<int>(laneMask), static_cast<int>(width));
120 }
121 };
122} // namespace alpaka::onAcc::warp::internal
123#endif
constexpr Api api
Definition tag.hpp:24
constexpr T_IntegralType linearize(Vec< T_IntegralType, T_dim - 1u, T_Storage > const &dim, Vec< T_IntegralType, T_dim, T_OtherStorage > const &idx)
Give the linear index of an N-dimensional index within an N-dimensional index space.
Definition Vec.hpp:832