alpaka
Abstraction Library for Parallel Kernel Acceleration
Loading...
Searching...
No Matches
PhiloxStateless.hpp
Go to the documentation of this file.
1/* Copyright 2022 Jiri Vyskocil, Bernhard Manfred Gruber, Jeffrey Kelling
2 * SPDX-License-Identifier: MPL-2.0
3 */
4
5
6#pragma once
7
8#include "alpaka/Vec.hpp"
11
12#include <utility>
13
14namespace alpaka::rand::engine::internal
15{
/** Compile-time parameter bundle for the Philox algorithm
 *
 * Re-exports its template arguments as static constants so that a generator class can read all of them
 * through a single parameter type.
 *
 * @tparam TCounterSize number of elements in the counter
 * @tparam TWidth width of one counter element (in bits)
 * @tparam TRounds number of S-box rounds
 */
template<unsigned TCounterSize, unsigned TWidth, unsigned TRounds>
struct PhiloxParams
{
    /// Number of elements in the counter.
    static constexpr unsigned counterSize{TCounterSize};
    /// Width of one counter element, in bits.
    static constexpr unsigned width{TWidth};
    /// Number of S-box rounds.
    static constexpr unsigned rounds{TRounds};
};
29
30 /** Class basic Philox family counter-based PRNG
31 *
32 * Checks the validity of passed-in parameters and calls the backend methods to perform N rounds of the
33 * Philox shuffle.
34 *
35 * @tparam T_Params Philox algorithm parameters \sa PhiloxParams
36 */
37 template<typename T_Params>
38 class PhiloxStateless
39 {
40 static constexpr unsigned numRounds()
41 {
42 return T_Params::rounds;
43 }
44
45 static constexpr unsigned vectorSize()
46 {
47 return T_Params::counterSize;
48 }
49
50 static constexpr unsigned numberWidth()
51 {
52 return T_Params::width;
53 }
54
55 static_assert(numRounds() > 0, "Number of Philox rounds must be > 0.");
56 static_assert(vectorSize() % 2 == 0, "Philox counter size must be an even number.");
57 static_assert(vectorSize() <= 16, "Philox SP network is not specified for sizes > 16.");
58 static_assert(numberWidth() % 8 == 0, "Philox number width in bits must be a multiple of 8.");
59
60 static_assert(numberWidth() == 32, "Philox implemented only for 32 bit numbers.");
61
62 public:
63 using Counter = alpaka::Vec<std::uint32_t, T_Params::counterSize>;
64 using Key = alpaka::Vec<std::uint32_t, T_Params::counterSize / 2>;
65
66 protected:
67 /** Single round of the Philox shuffle
68 *
69 * @param counter state of the counter
70 * @param key value of the key
71 * @return shuffled counter
72 */
73 static constexpr auto singleRound(Counter const& counter, Key const& key)
74 {
75 std::uint32_t H0, L0, H1, L1;
76 multiplyAndSplit64to32(counter[0], PhiloxConstants::MULTIPLITER_4x32_0(), H0, L0);
77 multiplyAndSplit64to32(counter[2], PhiloxConstants::MULTIPLITER_4x32_1(), H1, L1);
78 return Counter{H1 ^ counter[1] ^ key[0], L1, H0 ^ counter[3] ^ key[1], L0};
79 }
80
81 /** Bump the \a key by the Weyl sequence step parameter
82 *
83 * @param key the key to be bumped
84 * @return the bumped key
85 */
86 static constexpr auto bumpKey(Key const& key)
87 {
88 return Key{key[0] + PhiloxConstants::WEYL_32_0(), key[1] + PhiloxConstants::WEYL_32_1()};
89 }
90
91 /** Performs N rounds of the Philox shuffle
92 *
93 * @param counter_in initial state of the counter
94 * @param key_in initial state of the key
95 * @return result of the PRNG shuffle; has the same size as the counter
96 */
97 static constexpr auto nRounds(Counter const& counter_in, Key const& key_in) -> Counter
98 {
99 Key key{key_in};
100 Counter counter = singleRound(counter_in, key);
101
102 // Use a constexpr variable to ensure the unroll factor is a compile-time constant
103 constexpr unsigned rounds = numRounds();
104
105 for(unsigned int n = 0; n < rounds; ++n)
106 {
107 key = bumpKey(key);
108 counter = singleRound(counter, key);
109 }
110
111 return counter;
112 }
113
114 public:
115 /** Generates a random number (\p TCounterSize x32-bit)
116 *
117 * @param counter initial state of the counter
118 * @param key initial state of the key
119 * @return result of the PRNG shuffle; has the same size as the counter
120 */
121 static constexpr auto generate(Counter const& counter, Key const& key) -> Counter
122 {
123 return nRounds(counter, key);
124 }
125 };
126} // namespace alpaka::rand::engine::internal
ALPAKA_FN_HOST_ACC Vec(T_1, T_Args...) -> Vec< T_1, uint32_t(sizeof...(T_Args)+1u), ArrayStorage< T_1, uint32_t(sizeof...(T_Args)+1u)> >