forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathMPSDevice.mm
158 lines (140 loc) · 5.4 KB
/
MPSDevice.mm
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
//
// Copyright (c) 2023 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//
#include "MPSDevice.h"
#include <executorch/runtime/platform/assert.h>
#include <memory>
#include <mutex>
namespace executorch {
namespace backends {
namespace mps {
namespace delegate {
using executorch::runtime::Error;
// Process-wide device instance, created lazily by MPSDevice::getInstance().
static std::unique_ptr<MPSDevice> mps_device;
// Ensures mps_device is constructed exactly once (used with std::call_once in getInstance()).
static std::once_flag mpsdev_init;
// Picks the Metal Shading Language version used when compiling runtime kernels.
// Metal 2.3 is the baseline. Metal 3.0 is chosen only when the running OS is at least
// iOS 16 / macOS 13 AND the caller reports the macOS 13 feature set (macOS13Plus).
// Aborts via ET_CHECK_MSG when `device` does not report MTLGPUFamilyMac2 support.
static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& device, bool macOS13Plus) {
  // MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffers and function constants)
  // host_name attribute needs at least Metal 2.2 and ulong needs Metal 2.3 (supported on MacOS 11+)
  MTLLanguageVersion selectedVersion = MTLLanguageVersion2_3;
  if (@available(iOS 16, macOS 13, *)) {
    // MTLLanguageVersion3_0 is referenced only inside the availability guard.
    selectedVersion = macOS13Plus ? MTLLanguageVersion3_0 : selectedVersion;
  }
  const bool supportsMac2Family = [device supportsFamily:MTLGPUFamilyMac2];
  ET_CHECK_MSG(supportsMac2Family, "Missing Metal support for MTLGPUFamilyMac2");
  return selectedVersion;
}
// Drops the +1 device reference taken in the constructor. This file uses manual
// retain/release, so that reference has to be balanced explicitly here.
MPSDevice::~MPSDevice() {
  id<MTLDevice> ownedDevice = _mtl_device;
  _mtl_device = nil;
  [ownedDevice release];
}
// Selects the Metal device used by the MPS delegate and keeps an owned (+1)
// reference to it in _mtl_device, released in ~MPSDevice.
// - iOS: the system default device.
// - macOS: the first device that is not low-power (skips integrated Intel GPUs).
// Aborts via ET_CHECK if no suitable device is found.
MPSDevice::MPSDevice(): _mtl_device(nil) {
  @autoreleasepool {
#if TARGET_OS_IPHONE
    // MTLCreateSystemDefaultDevice returns an owned (+1) reference; no extra retain needed.
    _mtl_device = MTLCreateSystemDefaultDevice();
#else
    // MTLCopyAllDevices follows the Copy rule: the caller owns the returned array.
    // Autorelease it so the enclosing pool frees it (it previously leaked).
    NSArray<id<MTLDevice>>* devices = [MTLCopyAllDevices() autorelease];
    for (id<MTLDevice> device in devices) {
      if (![device isLowPower]) { // exclude Intel GPUs
        // Retain: the element is only owned by the (now autoreleased) array.
        _mtl_device = [device retain];
        break;
      }
    }
#endif
  }
  // MPS TODO: Replace with `ET_CHECK_OR_RETURN_ERROR` and propagate back the error.
  ET_CHECK(_mtl_device != nil);
}
// Returns the process-wide MPSDevice, constructing it on the first call.
// std::call_once makes that one-time construction safe even if several threads
// race on the first call; the pointer stays owned by mps_device.
MPSDevice* MPSDevice::getInstance() {
  std::call_once(mpsdev_init, []() {
    mps_device.reset(new MPSDevice());
  });
  return mps_device.get();
}
// Reports whether the running OS offers the feature level tagged by `version`.
// Each OS release is detected by probing for an API it introduced: mostly MPSGraph
// selectors, plus an MTLCompileOptions property for 13.3. Every probe runs once, on
// the first call. Function-local statics are initialized thread-safely in C++11, so
// later calls only read cached flags.
// Returns false for MacOSVersion values not handled below.
bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
  // Class objects live for the whole process, so caching the lookup is safe. If
  // MPSGraph is unavailable this is nil, messaging nil yields NO, and every
  // MPSGraph-based probe reports false.
  static Class mpsGraphClass = NSClassFromString(@"MPSGraph");
  static const bool macos_13_0_plus =
    [mpsGraphClass instancesRespondToSelector:@selector(cumulativeSumWithTensor:axis:name:)];
  static const bool macos_13_1_plus =
    [mpsGraphClass instancesRespondToSelector:@selector(sampleGridWithSourceTensor:coordinateTensor:layout:
                                                        normalizeCoordinates:relativeCoordinates:alignCorners:
                                                        paddingMode:samplingMode:constantValue:name:)];
  static const bool macos_13_2_plus =
    [mpsGraphClass instancesRespondToSelector:@selector(convolution3DWithSourceTensor:weightsTensor:descriptor:name:)];
  static const bool macos_13_3_plus = [] {
    // Use a short-lived, explicitly released instance. Keeping an autoreleased object in
    // a static left a dangling pointer once its pool drained, and autoreleasing on a
    // thread without a pool leaks.
    MTLCompileOptions* probe = [[MTLCompileOptions alloc] init];
    const bool responds = [probe respondsToSelector:@selector(maxTotalThreadsPerThreadgroup)];
    [probe release];
    return responds;
  }();
  static const bool macos_14_0_plus =
    [mpsGraphClass instancesRespondToSelector:@selector(conjugateWithTensor:name:)];
  static const bool macos_15_0_plus =
    [mpsGraphClass instancesRespondToSelector:@selector(scaledDotProductAttentionWithQueryTensor:keyTensor:
                                                        valueTensor:maskTensor:scale:name:)];
  switch (version) {
    case MacOSVersion::MACOS_VER_13_0_PLUS:
      return macos_13_0_plus;
    case MacOSVersion::MACOS_VER_13_1_PLUS:
      return macos_13_1_plus;
    case MacOSVersion::MACOS_VER_13_2_PLUS:
      return macos_13_2_plus;
    case MacOSVersion::MACOS_VER_13_3_PLUS:
      return macos_13_3_plus;
    case MacOSVersion::MACOS_VER_14_0_PLUS:
      return macos_14_0_plus;
    case MacOSVersion::MACOS_VER_15_0_PLUS:
      return macos_15_0_plus;
    default:
      return false;
  }
}
// Returns the Metal source text for `libraryType`.
// NOTE(review): INDEXING_KERNELS currently maps to the placeholder string "TODO",
// not real Metal source. Any other library type aborts via ET_CHECK_MSG.
const char* getLibraryCString(LibraryType libraryType) {
  if (libraryType == LibraryType::INDEXING_KERNELS) {
    return "TODO";
  }
  ET_CHECK_MSG(false, "Unhandled library type!");
}
// Compiles the Metal source for `libraryType` on _mtl_device and stores the resulting
// library in _m_library_cache[libraryType].
// Returns Error::Ok on success and Error::Internal (after logging the NSError
// description) if Metal fails to build the library.
Error
MPSDevice::compileLibrary(LibraryType libraryType) {
  // This file uses manual retain/release. The local pool reclaims the autoreleased
  // source string and NSError even when called on a thread that has no pool.
  @autoreleasepool {
    NSError* error = nil;
    // Owned (+1) via alloc/init; released right after use (the old `new` call leaked it).
    MTLCompileOptions* options = [[MTLCompileOptions alloc] init];
    [options setLanguageVersion:getMetalLanguageVersion(_mtl_device, isMacOS13Plus(MacOSVersion::MACOS_VER_13_0_PLUS))];
    [options setFastMathEnabled:YES];
    NSString* source = [NSString stringWithCString:getLibraryCString(libraryType)
                                          encoding:NSASCIIStringEncoding];
    id<MTLLibrary> lib = [_mtl_device newLibraryWithSource:source
                                                   options:options
                                                     error:&error];
    [options release];
    // Check the return value, not the error pointer. The message is built before the
    // early return leaves the pool, so the error description is still alive.
    ET_CHECK_OR_RETURN_ERROR(
      lib != nil,
      Internal,
      "Failed to create indexing library, error: %s", [[error description] UTF8String]
    );
    // newLibraryWithSource: returns an owned (+1) library; the cache keeps that reference.
    _m_library_cache[libraryType] = lib;
  }
  return Error::Ok;
}
// Ensures the library for `libraryType` is compiled and cached, then looks up
// `kernelName` in the pipeline-state cache.
// NOTE(review): pipeline-state creation is not implemented yet (see the commented-out
// call below), so a missing kernel is only logged and Error::Ok is returned.
// Returns Error::InvalidArgument for a null kernelName, and Error::Internal if the
// library fails to compile.
Error
MPSDevice::compilePSO(LibraryType libraryType, const char* kernelName) {
  // kernelName is formatted with %s and used as a cache key; null would be undefined behavior.
  ET_CHECK_OR_RETURN_ERROR(
    kernelName != nullptr,
    InvalidArgument,
    "compilePSO requires a non-null kernel name"
  );
  if (_m_library_cache.find(libraryType) == _m_library_cache.end()) {
    // LibraryType is an enum; cast explicitly so it matches %d through varargs.
    ET_LOG(Debug, "Compiling library type: %d", static_cast<int>(libraryType));
    const Error err = compileLibrary(libraryType);
    ET_CHECK_OR_RETURN_ERROR(
      err == Error::Ok,
      Internal,
      "An error occurred while compiling library %d", static_cast<int>(libraryType)
    );
  }
  if (_m_pso_cache.find(kernelName) == _m_pso_cache.end()) {
    ET_LOG(Debug, "Compiling kernel: %s", kernelName);
    // TODO: build the compute pipeline state for kernelName and store it in _m_pso_cache.
    // err = compilePSO(libraryType, kernelName);
  }
  return Error::Ok;
}
// Free-function convenience wrapper: queries the shared MPSDevice for `version`.
bool is_macos_13_or_newer(MacOSVersion version) {
  MPSDevice* sharedDevice = MPSDevice::getInstance();
  return sharedDevice->isMacOS13Plus(version);
}
} // namespace delegate
} // namespace mps
} // namespace backends
} // namespace executorch