-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinstall_check.py
310 lines (263 loc) · 9.64 KB
/
install_check.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
import paddle
__all__ = []
def _simple_network():
    """
    Build a tiny static-graph network: one linear transform followed by a sum.

    Returns:
        tuple: ``(input, out, weight)`` -- the feed variable named "input",
        the summed output of the linear transform, and the weight parameter.
    """
    feed_var = paddle.static.data(
        name="input", shape=[None, 2, 2], dtype="float32"
    )

    # The weight is filled with the constant 0.1; the bias is created
    # without an explicit attr.
    const_init = paddle.nn.initializer.Constant(0.1)
    weight_param = paddle.create_parameter(
        shape=[2, 3],
        dtype="float32",
        attr=paddle.ParamAttr(initializer=const_init),
    )
    bias_param = paddle.create_parameter(shape=[3], dtype="float32")

    projected = paddle.nn.functional.linear(
        x=feed_var, weight=weight_param, bias=bias_param
    )
    return feed_var, paddle.tensor.sum(projected), weight_param
def _prepare_data():
    """
    Build the feed array for the simple network.

    Returns:
        numpy.ndarray: float32 array of shape [1, 2, 2] holding
        [[[1, 2], [3, 4]]].
    """
    # Lay the four values out flat, then view them as one 2x2 sample.
    flat_values = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
    return flat_values.reshape(1, 2, 2)
def _is_cuda_available():
    """
    Check whether CUDA is available.

    Returns:
        bool: True when ``paddle.static.cuda_places()`` yields at least one
        place. False, after logging a warning, when it yields none or when
        the query raises.
    """
    try:
        # Explicit check instead of ``assert``: assert statements are removed
        # when Python runs with -O, which would make an empty place list look
        # like a usable GPU setup.
        if len(paddle.static.cuda_places()) == 0:
            raise RuntimeError(
                "paddle.static.cuda_places() returned no CUDA place"
            )
        return True
    except Exception as e:
        logging.warning(
            "You are using GPU version PaddlePaddle, but there is no GPU "
            "detected on your machine. Maybe CUDA devices is not set properly."
            "\n Original Error is {}".format(e)
        )
        return False
def _is_xpu_available():
    """
    Check whether XPU is available.

    Returns:
        bool: True when ``paddle.static.xpu_places()`` yields at least one
        place. False, after logging a warning, when it yields none or when
        the query raises.
    """
    try:
        # Explicit check instead of ``assert``: assert statements are removed
        # when Python runs with -O, which would make an empty place list look
        # like a usable XPU setup.
        if len(paddle.static.xpu_places()) == 0:
            raise RuntimeError(
                "paddle.static.xpu_places() returned no XPU place"
            )
        return True
    except Exception as e:
        logging.warning(
            "You are using XPU version PaddlePaddle, but there is no XPU "
            "detected on your machine. Maybe XPU devices is not set properly."
            "\n Original Error is {}".format(e)
        )
        return False
def _run_dygraph_single(use_cuda, use_xpu, use_custom, custom_device_name):
    """
    Run one forward/backward/optimizer step of a small linear layer in
    dygraph mode on a single CPU/GPU/XPU/custom device.

    Args:
        use_cuda (bool): Whether running with CUDA.
        use_xpu (bool): Whether running with XPU.
        use_custom (bool): Whether running with a custom device.
        custom_device_name (str): Device name handed to ``paddle.set_device``
            when ``use_custom`` is selected.
    """
    paddle.disable_static()

    # Flags are honoured in priority order: CUDA, XPU, custom, then CPU.
    if use_cuda:
        device = 'gpu'
    elif use_xpu:
        device = 'xpu'
    elif use_custom:
        device = custom_device_name
    else:
        device = 'cpu'
    paddle.set_device(device)

    layer = paddle.nn.Linear(
        2,
        4,
        weight_attr=paddle.ParamAttr(
            name="weight",
            initializer=paddle.nn.initializer.Constant(value=0.5),
        ),
        bias_attr=paddle.ParamAttr(
            name="bias",
            initializer=paddle.nn.initializer.Constant(value=1.0),
        ),
    )

    batch = paddle.to_tensor(_prepare_data())
    loss = paddle.tensor.sum(layer(batch))
    loss.backward()

    optimizer = paddle.optimizer.Adam(
        learning_rate=0.001, parameters=layer.parameters()
    )
    optimizer.step()
def _run_static_single(use_cuda, use_xpu, use_custom, custom_device_name):
    """
    Testing the simple network with executor running directly, using one
    CPU/GPU/XPU/custom device.

    Static-graph mode is enabled for the duration of the call and is always
    switched off again on exit, including when building or running the
    program raises.

    Args:
        use_cuda (bool): Whether running with CUDA.
        use_xpu (bool): Whether running with XPU.
        use_custom (bool): Whether running with a custom device.
        custom_device_name (str): Device type passed to ``paddle.CustomPlace``
            when ``use_custom`` is selected.
    """
    paddle.enable_static()
    try:
        # A fresh scope and fresh programs keep this check separate from the
        # default scope/programs.
        with paddle.static.scope_guard(paddle.static.Scope()):
            train_prog = paddle.static.Program()
            startup_prog = paddle.static.Program()
            startup_prog.random_seed = 1
            with paddle.static.program_guard(train_prog, startup_prog):
                feed_var, out, weight = _simple_network()
                # Keep only the first entry returned for the weight.
                param_grads = paddle.static.append_backward(
                    out, parameter_list=[weight.name]
                )[0]

            # Flags are honoured in priority order: CUDA, XPU, custom, CPU.
            if use_cuda:
                place = paddle.CUDAPlace(0)
            elif use_xpu:
                place = paddle.XPUPlace(0)
            elif use_custom:
                place = paddle.CustomPlace(custom_device_name, 0)
            else:
                place = paddle.CPUPlace()

            exe = paddle.static.Executor(place)
            exe.run(startup_prog)
            exe.run(
                train_prog,
                feed={feed_var.name: _prepare_data()},
                # param_grads[1] is presumably the gradient variable paired
                # with the weight -- confirm against append_backward docs.
                fetch_list=[out.name, param_grads[1].name],
            )
    finally:
        # Never leave the process stuck in static-graph mode on failure.
        paddle.disable_static()
def train_for_run_parallel():
    """
    Training script executed by each process of the parallel training check.

    Initializes the parallel environment, wraps a small two-layer linear
    network in ``paddle.DataParallel`` and performs a single
    forward/backward/Adam step on random data.
    """

    class LinearNet(paddle.nn.Layer):
        """
        Simple fully-connected network (10 -> 10 -> 1) for the parallel
        training check.
        """

        def __init__(self):
            super().__init__()
            self._linear1 = paddle.nn.Linear(10, 10)
            self._linear2 = paddle.nn.Linear(10, 1)

        def forward(self, x):
            """
            Apply the two linear layers in sequence.
            """
            hidden = self._linear1(x)
            return self._linear2(hidden)

    paddle.distributed.init_parallel_env()

    model = paddle.DataParallel(LinearNet())
    criterion = paddle.nn.MSELoss()
    optimizer = paddle.optimizer.Adam(
        learning_rate=0.001, parameters=model.parameters()
    )

    # One training step on random inputs and random labels.
    features = paddle.randn([10, 10], 'float32')
    predictions = model(features)
    targets = paddle.randn([10, 1], 'float32')
    criterion(predictions, targets).backward()

    optimizer.step()
    optimizer.clear_grad()
def _run_parallel(device_list):
    """
    Run the parallel training check in data parallel mode across several
    devices.

    Args:
        device_list (list): The specified devices; one process is spawned
            per entry.
    """
    process_count = len(device_list)
    paddle.distributed.spawn(train_for_run_parallel, nprocs=process_count)
def run_check():
    """
    Check whether PaddlePaddle is installed correctly and running successfully
    on your system.

    The check runs a small network on a single device, first through the
    static-graph executor and then in dygraph mode. When more than one device
    is detected it additionally runs a data-parallel training step across all
    of them.

    Raises:
        Exception: Whatever the multi-device run raised, re-raised unchanged
            after troubleshooting hints have been logged.

    Examples:
        .. code-block:: python

            >>> import paddle

            >>> paddle.utils.run_check()
            >>> # doctest: +SKIP('the output will change in different run')
            Running verify PaddlePaddle program ...
            I0818 15:35:08.335391 30540 program_interpreter.cc:173] New Executor is Running.
            I0818 15:35:08.398319 30540 interpreter_util.cc:529] Standalone Executor is Used.
            PaddlePaddle works well on 1 CPU.
            PaddlePaddle is installed successfully! Let's start deep learning with PaddlePaddle now.
    """
    print("Running verify PaddlePaddle program ... ")

    use_cuda = False
    use_xpu = False
    use_custom = False
    custom_device_name = None
    custom_types = []

    # Pick the device family from how Paddle was built. Custom devices are
    # only considered for builds without CUDA and without XPU.
    if paddle.is_compiled_with_cuda():
        use_cuda = _is_cuda_available()
    elif paddle.is_compiled_with_xpu():
        use_xpu = _is_xpu_available()
    else:
        # Query the custom device types once and reuse the result below.
        custom_types = paddle.framework.core.get_all_custom_device_type()
        use_custom = len(custom_types) > 0
        if len(custom_types) > 1:
            logging.warning(
                "More than one kind of custom devices detected, but run check would only be executed on {}.".format(
                    custom_types[0]
                )
            )

    if use_cuda:
        device_str = "GPU"
        device_list = paddle.static.cuda_places()
    elif use_xpu:
        device_str = "XPU"
        device_list = paddle.static.xpu_places()
    elif use_custom:
        device_str = custom_types[0]
        custom_device_name = device_str
        device_list = list(
            range(
                paddle.framework.core.get_custom_device_count(
                    custom_device_name
                )
            )
        )
    else:
        device_str = "CPU"
        device_list = paddle.static.cpu_places(device_count=1)
    device_count = len(device_list)

    # Single-device checks: static graph first, then dygraph.
    _run_static_single(use_cuda, use_xpu, use_custom, custom_device_name)
    _run_dygraph_single(use_cuda, use_xpu, use_custom, custom_device_name)
    print(f"PaddlePaddle works well on 1 {device_str}.")

    if device_count > 1:
        # Only the parallel run is guarded, so that the multi-device
        # troubleshooting hints are logged for multi-device failures only.
        try:
            if use_custom:
                import os

                os.environ['PADDLE_DISTRI_BACKEND'] = "xccl"
            _run_parallel(device_list)
        except Exception as e:
            logging.warning(
                "PaddlePaddle meets some problem with {} {}s. This may be caused by:"
                "\n 1. There is not enough GPUs visible on your system"
                "\n 2. Some GPUs are occupied by other process now"
                "\n 3. NVIDIA-NCCL2 is not installed correctly on your system. Please follow instruction on https://github.com/NVIDIA/nccl-tests "
                "\n to test your NCCL, or reinstall it following https://docs.nvidia.com/deeplearning/sdk/nccl-install-guide/index.html".format(
                    device_count, device_str
                )
            )
            logging.warning(f"\n Original Error is: {e}")
            print(
                "PaddlePaddle is installed successfully ONLY for single {}! "
                "Let's start deep learning with PaddlePaddle now.".format(
                    device_str
                )
            )
            # Bare raise re-raises the active exception with its original
            # traceback.
            raise
        print(
            "PaddlePaddle works well on {} {}s.".format(
                device_count, device_str
            )
        )

    print(
        "PaddlePaddle is installed successfully! Let's start deep learning with PaddlePaddle now."
    )