Skip to content

Commit ae342fd

Browse files
syed-ahmed authored and facebook-github-bot committed
Refactor Random Number Generators in ATen (pytorch#21364)
Summary: Pull Request resolved: pytorch#21364 ghimport-source-id: ca7d37e Differential Revision: D15696497 Pulled By: ezyang fbshipit-source-id: 2e713b8566ae915e175b5a79ac1dd9b86cc2a23d
1 parent 9691025 commit ae342fd

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

72 files changed

+1966
-798
lines changed

aten/src/ATen/CPUGenerator.cpp

+149-27
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,171 @@
11
#include <ATen/CPUGenerator.h>
2-
3-
#define const_generator_cast(generator) \
4-
dynamic_cast<const CPUGenerator&>(generator)
2+
#include <c10/util/C++17.h>
3+
#include <algorithm>
54

65
namespace at {
76

8-
CPUGenerator::CPUGenerator(Context * context_)
9-
: context(context_), generator(THGenerator_new())
10-
{}
7+
namespace detail {
8+
9+
// Ensures default_gen_cpu is initialized once.
10+
static std::once_flag cpu_gen_init_flag;
11+
12+
// Default, global CPU generator.
13+
static std::shared_ptr<CPUGenerator> default_gen_cpu;
14+
15+
/**
16+
* PyTorch maintains a collection of default generators that get
17+
* initialized once. The purpose of these default generators is to
18+
* maintain a global running state of the pseudo random number generation,
19+
* when a user does not explicitly mention any generator.
20+
* getDefaultCPUGenerator gets the default generator for a particular
21+
* device.
22+
*/
23+
CPUGenerator* getDefaultCPUGenerator() {
24+
std::call_once(cpu_gen_init_flag, [&] {
25+
default_gen_cpu = std::make_shared<CPUGenerator>(getNonDeterministicRandom());
26+
});
27+
return default_gen_cpu.get();
28+
}
29+
30+
/**
31+
* Utility to create a CPUGenerator. Returns a shared_ptr
32+
*/
33+
std::shared_ptr<CPUGenerator> createCPUGenerator(uint64_t seed_val) {
34+
return std::make_shared<CPUGenerator>(seed_val);
35+
}
36+
/**
 * Helper function to concatenate two 32-bit unsigned integers and return
 * them as a single 64-bit unsigned integer.
 *
 * hi becomes the upper 32 bits of the result and lo the lower 32 bits.
 * hi is widened to 64 bits before shifting so that no bits are lost.
 */
constexpr uint64_t make64BitsFrom32Bits(uint32_t hi, uint32_t lo) noexcept {
  return (static_cast<uint64_t>(hi) << 32) | static_cast<uint64_t>(lo);
}
44+
45+
} // namespace detail
46+
47+
/**
48+
* CPUGenerator class implementation
49+
*/
50+
CPUGenerator::CPUGenerator(uint64_t seed_in)
51+
: Generator{Device(DeviceType::CPU)},
52+
engine_{seed_in},
53+
next_float_normal_sample_{c10::optional<float>()},
54+
next_double_normal_sample_{c10::optional<double>()} { }
55+
56+
/**
57+
* Manually seeds the engine with the seed input
58+
* See Note [Acquire lock when using random generators]
59+
*/
60+
void CPUGenerator::set_current_seed(uint64_t seed) {
61+
next_float_normal_sample_.reset();
62+
next_double_normal_sample_.reset();
63+
engine_ = mt19937(seed);
64+
}
65+
66+
/**
67+
* Gets the current seed of CPUGenerator.
68+
*/
69+
uint64_t CPUGenerator::current_seed() const {
70+
return engine_.seed();
71+
}
72+
73+
/**
74+
* Gets the DeviceType of CPUGenerator.
75+
* Used for type checking during run time.
76+
*/
77+
DeviceType CPUGenerator::device_type() {
78+
return DeviceType::CPU;
79+
}
80+
81+
/**
82+
* Gets a random 32 bit unsigned integer from the engine
83+
*
84+
* See Note [Acquire lock when using random generators]
85+
*/
86+
uint32_t CPUGenerator::random() {
87+
return engine_();
88+
}
89+
90+
/**
91+
* Gets a random 64 bit unsigned integer from the engine
92+
*
93+
* See Note [Acquire lock when using random generators]
94+
*/
95+
uint64_t CPUGenerator::random64() {
96+
uint32_t random1 = engine_();
97+
uint32_t random2 = engine_();
98+
return detail::make64BitsFrom32Bits(random1, random2);
99+
}
11100

12-
CPUGenerator::~CPUGenerator() {
13-
if (generator)
14-
THGenerator_free(generator);
101+
/**
102+
* Get the cached normal random in float
103+
*/
104+
c10::optional<float> CPUGenerator::next_float_normal_sample() {
105+
return next_float_normal_sample_;
15106
}
16107

17-
CPUGenerator& CPUGenerator::copy(const Generator& from) {
18-
THGenerator_copy(generator, const_generator_cast(from).generator);
19-
return *this;
108+
/**
109+
* Get the cached normal random in double
110+
*/
111+
c10::optional<double> CPUGenerator::next_double_normal_sample() {
112+
return next_double_normal_sample_;
20113
}
21114

22-
CPUGenerator& CPUGenerator::free() {
23-
THGenerator_free(generator);
24-
return *this;
115+
/**
116+
* Cache normal random in float
117+
*
118+
* See Note [Acquire lock when using random generators]
119+
*/
120+
void CPUGenerator::set_next_float_normal_sample(c10::optional<float> randn) {
121+
next_float_normal_sample_ = randn;
25122
}
26123

27-
uint64_t CPUGenerator::seed() {
28-
return THRandom_seed(generator);
124+
/**
125+
* Cache normal random in double
126+
*
127+
* See Note [Acquire lock when using random generators]
128+
*/
129+
void CPUGenerator::set_next_double_normal_sample(c10::optional<double> randn) {
130+
next_double_normal_sample_ = randn;
29131
}
30132

31-
uint64_t CPUGenerator::initialSeed() {
32-
return THRandom_initialSeed(generator);
133+
/**
134+
* Get the engine of the CPUGenerator
135+
*/
136+
at::mt19937 CPUGenerator::engine() {
137+
return engine_;
33138
}
34139

35-
CPUGenerator& CPUGenerator::manualSeed(uint64_t seed) {
36-
THRandom_manualSeed(generator, seed);
37-
return *this;
140+
/**
141+
* Set the engine of the CPUGenerator
142+
*
143+
* See Note [Acquire lock when using random generators]
144+
*/
145+
void CPUGenerator::set_engine(at::mt19937 engine) {
146+
engine_ = engine;
38147
}
39148

40-
CPUGenerator& CPUGenerator::manualSeedAll(uint64_t seed) {
41-
// There's only one CPU generator
42-
return manualSeed(seed);
149+
/**
150+
* Public clone method implementation
151+
*
152+
* See Note [Acquire lock when using random generators]
153+
*/
154+
std::shared_ptr<CPUGenerator> CPUGenerator::clone() const {
155+
return std::shared_ptr<CPUGenerator>(this->clone_impl());
43156
}
44157

45-
void * CPUGenerator::unsafeGetTH() {
46-
return generator;
158+
/**
159+
* Private clone method implementation
160+
*
161+
* See Note [Acquire lock when using random generators]
162+
*/
163+
CPUGenerator* CPUGenerator::clone_impl() const {
164+
auto gen = new CPUGenerator();
165+
gen->set_engine(engine_);
166+
gen->set_next_float_normal_sample(next_float_normal_sample_);
167+
gen->set_next_double_normal_sample(next_double_normal_sample_);
168+
return gen;
47169
}
48170

49171
} // namespace at

aten/src/ATen/CPUGenerator.h

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#pragma once
2+
3+
#include <ATen/core/Generator.h>
4+
#include <ATen/core/MT19937RNGEngine.h>
5+
#include <ATen/core/PhiloxRNGEngine.h>
6+
#include <c10/util/Optional.h>
7+
8+
namespace at {
9+
10+
struct CAFFE2_API CPUGenerator : public Generator {
11+
// Constructors
12+
CPUGenerator(uint64_t seed_in = default_rng_seed_val);
13+
~CPUGenerator() = default;
14+
15+
// CPUGenerator methods
16+
std::shared_ptr<CPUGenerator> clone() const;
17+
void set_current_seed(uint64_t seed) override;
18+
uint64_t current_seed() const override;
19+
static DeviceType device_type();
20+
uint32_t random();
21+
uint64_t random64();
22+
c10::optional<float> next_float_normal_sample();
23+
c10::optional<double> next_double_normal_sample();
24+
void set_next_float_normal_sample(c10::optional<float> randn);
25+
void set_next_double_normal_sample(c10::optional<double> randn);
26+
at::mt19937 engine();
27+
void set_engine(at::mt19937 engine);
28+
29+
private:
30+
CPUGenerator* clone_impl() const override;
31+
at::mt19937 engine_;
32+
c10::optional<float> next_float_normal_sample_;
33+
c10::optional<double> next_double_normal_sample_;
34+
};
35+
36+
namespace detail {
37+
38+
CAFFE2_API CPUGenerator* getDefaultCPUGenerator();
39+
CAFFE2_API std::shared_ptr<CPUGenerator> createCPUGenerator(uint64_t seed_val = default_rng_seed_val);
40+
41+
} // namespace detail
42+
43+
}

aten/src/ATen/CPUTypeDefault.cpp

-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#include <ATen/CPUTypeDefault.h>
22

33
#include <ATen/Context.h>
4-
#include <ATen/CPUGenerator.h>
54

65
namespace at {
76

@@ -13,8 +12,4 @@ Device CPUTypeDefault::getDeviceFromPtr(void * data) const {
1312
return DeviceType::CPU;
1413
}
1514

16-
std::unique_ptr<Generator> CPUTypeDefault::generator() const {
17-
return std::unique_ptr<Generator>(new CPUGenerator(&at::globalContext()));
18-
}
19-
2015
} // namespace at

aten/src/ATen/CPUTypeDefault.h

-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ struct CAFFE2_API CPUTypeDefault : public TypeDefault {
88
: TypeDefault(type_id, is_variable, is_undefined) {}
99
Allocator* allocator() const override;
1010
Device getDeviceFromPtr(void * data) const override;
11-
std::unique_ptr<Generator> generator() const override;
1211
};
1312

1413
} // namespace at

aten/src/ATen/CheckGenerator.h

-18
This file was deleted.

aten/src/ATen/Context.cpp

-4
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#include <string>
1111
#include <stdexcept>
1212

13-
#include <ATen/CPUGenerator.h>
1413
#include <ATen/RegisterCPU.h>
1514
#include <ATen/Tensor.h>
1615
#include <ATen/cpu/FlushDenormal.h>
@@ -36,9 +35,6 @@ Context::Context()
3635

3736
THSetDefaultErrorHandler(errorHandler,nullptr);
3837
THSetDefaultArgErrorHandler(argErrorHandler,nullptr);
39-
40-
generator_registry[static_cast<int>(DeviceType::CPU)]
41-
.reset(new CPUGenerator(this));
4238
register_cpu_types(this);
4339
}
4440

aten/src/ATen/Context.h

+18-5
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <ATen/LegacyTHDispatcher.h>
99
#include <ATen/core/ATenGeneral.h>
1010
#include <ATen/core/Generator.h>
11+
#include <ATen/CPUGenerator.h>
1112
#include <ATen/core/LegacyTypeDispatch.h>
1213
#include <ATen/core/VariableHooksInterface.h>
1314
#include <ATen/detail/CUDAHooksInterface.h>
@@ -57,13 +58,20 @@ class CAFFE2_API Context {
5758
LegacyTHDispatch::LegacyTHDispatcherUniquePtr{t, LegacyTHDispatcherDeleter([](LegacyTHDispatcher* p) { delete p; }) });
5859
}
5960

60-
Generator & defaultGenerator(DeviceType device_type) {
61+
Generator & defaultGenerator(Device device) {
62+
DeviceType device_type = device.type();
6163
initCUDAIfNeeded(device_type);
6264
initHIPIfNeeded(device_type);
63-
auto & generator = generator_registry[static_cast<int>(device_type)];
64-
if(!generator)
65+
if (device_type == at::kCPU) {
66+
return *at::detail::getDefaultCPUGenerator();
67+
} else if (device_type == at::kCUDA) {
68+
auto & generator = generator_registry[static_cast<int>(device_type)];
69+
if(!generator)
6570
AT_ERROR(DeviceTypeName(device_type), " backend type not enabled.");
66-
return *generator;
71+
return *generator;
72+
} else {
73+
AT_ERROR(DeviceTypeName(device_type), " backend type not enabled.");
74+
}
6775
}
6876
bool hasOpenMP() const;
6977
bool hasMKL() const;
@@ -252,7 +260,12 @@ static inline bool hasMKLDNN() {
252260
}
253261

254262
static inline void manual_seed(uint64_t seed) {
255-
globalContext().defaultGenerator(DeviceType::CPU).manualSeed(seed);
263+
auto& gen = globalContext().defaultGenerator(DeviceType::CPU);
264+
{
265+
// See Note [Acquire lock when using random generators]
266+
std::lock_guard<std::mutex> lock(gen.mutex_);
267+
gen.set_current_seed(seed);
268+
}
256269
// NB: Sometimes we build with CUDA, but we don't have any GPUs
257270
// available. In that case, we must not seed CUDA; it will fail!
258271
if (hasCUDA() && detail::getCUDAHooks().getNumGPUs() > 0) {

aten/src/ATen/UndefinedType.cpp

-3
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,6 @@ Storage UndefinedType::unsafeStorageFromTH(void * th_pointer, bool retain) const
2323
Tensor UndefinedType::unsafeTensorFromTH(void * th_pointer, bool retain) const {
2424
AT_ERROR("unsafeTensorFromTH not defined for UndefinedType");
2525
}
26-
std::unique_ptr<Generator> UndefinedType::generator() const {
27-
AT_ERROR("generator not defined for UndefinedType");
28-
}
2926

3027
const char * UndefinedType::toString() const {
3128
return "UndefinedType";

aten/src/ATen/UndefinedType.h

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#pragma once
22

33
#include <ATen/TypeDefault.h>
4-
#include <ATen/CheckGenerator.h>
4+
#include <ATen/Utils.h>
55

66
#ifdef _MSC_VER
77
#ifdef Type
@@ -16,7 +16,6 @@ struct UndefinedType final : public TypeDefault {
1616
virtual Backend backend() const override;
1717
virtual Allocator* allocator() const override;
1818
virtual Device getDeviceFromPtr(void* data) const override;
19-
virtual std::unique_ptr<Generator> generator() const override;
2019
virtual const char * toString() const override;
2120
virtual Type & toBackend(Backend b) const override;
2221
virtual Type & toScalarType(ScalarType s) const override;

0 commit comments

Comments
 (0)