|
1 | 1 | #include <ATen/CPUGenerator.h>
|
2 |
| - |
3 |
| -#define const_generator_cast(generator) \ |
4 |
| - dynamic_cast<const CPUGenerator&>(generator) |
| 2 | +#include <c10/util/C++17.h> |
| 3 | +#include <algorithm> |
5 | 4 |
|
6 | 5 | namespace at {
|
7 | 6 |
|
8 |
| -CPUGenerator::CPUGenerator(Context * context_) |
9 |
| - : context(context_), generator(THGenerator_new()) |
10 |
| -{} |
| 7 | +namespace detail { |
| 8 | + |
| 9 | +// Ensures default_gen_cpu is initialized once. |
| 10 | +static std::once_flag cpu_gen_init_flag; |
| 11 | + |
| 12 | +// Default, global CPU generator. |
| 13 | +static std::shared_ptr<CPUGenerator> default_gen_cpu; |
| 14 | + |
| 15 | +/** |
| 16 | + * PyTorch maintains a collection of default generators that get |
| 17 | + * initialized once. The purpose of these default generators is to |
| 18 | + * maintain a global running state of the pseudo random number generation, |
| 19 | + * when a user does not explicitly mention any generator. |
| 20 | + * getDefaultCPUGenerator gets the default generator for a particular |
| 21 | + * device. |
| 22 | + */ |
| 23 | +CPUGenerator* getDefaultCPUGenerator() { |
| 24 | + std::call_once(cpu_gen_init_flag, [&] { |
| 25 | + default_gen_cpu = std::make_shared<CPUGenerator>(getNonDeterministicRandom()); |
| 26 | + }); |
| 27 | + return default_gen_cpu.get(); |
| 28 | +} |
| 29 | + |
| 30 | +/** |
| 31 | + * Utility to create a CPUGenerator. Returns a shared_ptr |
| 32 | + */ |
| 33 | +std::shared_ptr<CPUGenerator> createCPUGenerator(uint64_t seed_val) { |
| 34 | + return std::make_shared<CPUGenerator>(seed_val); |
| 35 | +} |
| 36 | + |
| 37 | +/** |
| 38 | + * Helper function to concatenate two 32 bit unsigned int |
| 39 | + * and return them as a 64 bit unsigned int |
| 40 | + */ |
| 41 | +inline uint64_t make64BitsFrom32Bits(uint32_t hi, uint32_t lo) { |
| 42 | + return (static_cast<uint64_t>(hi) << 32) | lo; |
| 43 | +} |
| 44 | + |
| 45 | +} // namespace detail |
| 46 | + |
| 47 | +/** |
| 48 | + * CPUGenerator class implementation |
| 49 | + */ |
| 50 | +CPUGenerator::CPUGenerator(uint64_t seed_in) |
| 51 | + : Generator{Device(DeviceType::CPU)}, |
| 52 | + engine_{seed_in}, |
| 53 | + next_float_normal_sample_{c10::optional<float>()}, |
| 54 | + next_double_normal_sample_{c10::optional<double>()} { } |
| 55 | + |
| 56 | +/** |
| 57 | + * Manually seeds the engine with the seed input |
| 58 | + * See Note [Acquire lock when using random generators] |
| 59 | + */ |
| 60 | +void CPUGenerator::set_current_seed(uint64_t seed) { |
| 61 | + next_float_normal_sample_.reset(); |
| 62 | + next_double_normal_sample_.reset(); |
| 63 | + engine_ = mt19937(seed); |
| 64 | +} |
| 65 | + |
| 66 | +/** |
| 67 | + * Gets the current seed of CPUGenerator. |
| 68 | + */ |
| 69 | +uint64_t CPUGenerator::current_seed() const { |
| 70 | + return engine_.seed(); |
| 71 | +} |
| 72 | + |
| 73 | +/** |
| 74 | + * Gets the DeviceType of CPUGenerator. |
| 75 | + * Used for type checking during run time. |
| 76 | + */ |
| 77 | +DeviceType CPUGenerator::device_type() { |
| 78 | + return DeviceType::CPU; |
| 79 | +} |
| 80 | + |
| 81 | +/** |
| 82 | + * Gets a random 32 bit unsigned integer from the engine |
| 83 | + * |
| 84 | + * See Note [Acquire lock when using random generators] |
| 85 | + */ |
| 86 | +uint32_t CPUGenerator::random() { |
| 87 | + return engine_(); |
| 88 | +} |
| 89 | + |
| 90 | +/** |
| 91 | + * Gets a random 64 bit unsigned integer from the engine |
| 92 | + * |
| 93 | + * See Note [Acquire lock when using random generators] |
| 94 | + */ |
| 95 | +uint64_t CPUGenerator::random64() { |
| 96 | + uint32_t random1 = engine_(); |
| 97 | + uint32_t random2 = engine_(); |
| 98 | + return detail::make64BitsFrom32Bits(random1, random2); |
| 99 | +} |
11 | 100 |
|
12 |
| -CPUGenerator::~CPUGenerator() { |
13 |
| - if (generator) |
14 |
| - THGenerator_free(generator); |
| 101 | +/** |
| 102 | + * Get the cached normal random in float |
| 103 | + */ |
| 104 | +c10::optional<float> CPUGenerator::next_float_normal_sample() { |
| 105 | + return next_float_normal_sample_; |
15 | 106 | }
|
16 | 107 |
|
17 |
| -CPUGenerator& CPUGenerator::copy(const Generator& from) { |
18 |
| - THGenerator_copy(generator, const_generator_cast(from).generator); |
19 |
| - return *this; |
| 108 | +/** |
| 109 | + * Get the cached normal random in double |
| 110 | + */ |
| 111 | +c10::optional<double> CPUGenerator::next_double_normal_sample() { |
| 112 | + return next_double_normal_sample_; |
20 | 113 | }
|
21 | 114 |
|
22 |
| -CPUGenerator& CPUGenerator::free() { |
23 |
| - THGenerator_free(generator); |
24 |
| - return *this; |
| 115 | +/** |
| 116 | + * Cache normal random in float |
| 117 | + * |
| 118 | + * See Note [Acquire lock when using random generators] |
| 119 | + */ |
| 120 | +void CPUGenerator::set_next_float_normal_sample(c10::optional<float> randn) { |
| 121 | + next_float_normal_sample_ = randn; |
25 | 122 | }
|
26 | 123 |
|
27 |
| -uint64_t CPUGenerator::seed() { |
28 |
| - return THRandom_seed(generator); |
| 124 | +/** |
| 125 | + * Cache normal random in double |
| 126 | + * |
| 127 | + * See Note [Acquire lock when using random generators] |
| 128 | + */ |
| 129 | +void CPUGenerator::set_next_double_normal_sample(c10::optional<double> randn) { |
| 130 | + next_double_normal_sample_ = randn; |
29 | 131 | }
|
30 | 132 |
|
31 |
| -uint64_t CPUGenerator::initialSeed() { |
32 |
| - return THRandom_initialSeed(generator); |
| 133 | +/** |
| 134 | + * Get the engine of the CPUGenerator |
| 135 | + */ |
| 136 | +at::mt19937 CPUGenerator::engine() { |
| 137 | + return engine_; |
33 | 138 | }
|
34 | 139 |
|
35 |
| -CPUGenerator& CPUGenerator::manualSeed(uint64_t seed) { |
36 |
| - THRandom_manualSeed(generator, seed); |
37 |
| - return *this; |
| 140 | +/** |
| 141 | + * Set the engine of the CPUGenerator |
| 142 | + * |
| 143 | + * See Note [Acquire lock when using random generators] |
| 144 | + */ |
| 145 | +void CPUGenerator::set_engine(at::mt19937 engine) { |
| 146 | + engine_ = engine; |
38 | 147 | }
|
39 | 148 |
|
40 |
| -CPUGenerator& CPUGenerator::manualSeedAll(uint64_t seed) { |
41 |
| - // There's only one CPU generator |
42 |
| - return manualSeed(seed); |
| 149 | +/** |
| 150 | + * Public clone method implementation |
| 151 | + * |
| 152 | + * See Note [Acquire lock when using random generators] |
| 153 | + */ |
| 154 | +std::shared_ptr<CPUGenerator> CPUGenerator::clone() const { |
| 155 | + return std::shared_ptr<CPUGenerator>(this->clone_impl()); |
43 | 156 | }
|
44 | 157 |
|
45 |
| -void * CPUGenerator::unsafeGetTH() { |
46 |
| - return generator; |
| 158 | +/** |
| 159 | + * Private clone method implementation |
| 160 | + * |
| 161 | + * See Note [Acquire lock when using random generators] |
| 162 | + */ |
| 163 | +CPUGenerator* CPUGenerator::clone_impl() const { |
| 164 | + auto gen = new CPUGenerator(); |
| 165 | + gen->set_engine(engine_); |
| 166 | + gen->set_next_float_normal_sample(next_float_normal_sample_); |
| 167 | + gen->set_next_double_normal_sample(next_double_normal_sample_); |
| 168 | + return gen; |
47 | 169 | }
|
48 | 170 |
|
49 | 171 | } // namespace at
|
0 commit comments