Skip to content

Commit 65722fe

Browse files
authored
Use torch.accelerator API in mnist examples (#1334)
1 parent 6967ff5 commit 65722fe

File tree

9 files changed

+48
-52
lines changed

9 files changed

+48
-52
lines changed

mnist/main.py

+12-16
Original file line numberDiff line numberDiff line change
@@ -82,39 +82,35 @@ def main():
8282
help='learning rate (default: 1.0)')
8383
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
8484
help='Learning rate step gamma (default: 0.7)')
85-
parser.add_argument('--no-cuda', action='store_true', default=False,
86-
help='disables CUDA training')
87-
parser.add_argument('--no-mps', action='store_true', default=False,
88-
help='disables macOS GPU training')
89-
parser.add_argument('--dry-run', action='store_true', default=False,
85+
parser.add_argument('--no-accel', action='store_true',
86+
help='disables accelerator')
87+
parser.add_argument('--dry-run', action='store_true',
9088
help='quickly check a single pass')
9189
parser.add_argument('--seed', type=int, default=1, metavar='S',
9290
help='random seed (default: 1)')
9391
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
9492
help='how many batches to wait before logging training status')
95-
parser.add_argument('--save-model', action='store_true', default=False,
93+
parser.add_argument('--save-model', action='store_true',
9694
help='For Saving the current Model')
9795
args = parser.parse_args()
98-
use_cuda = not args.no_cuda and torch.cuda.is_available()
99-
use_mps = not args.no_mps and torch.backends.mps.is_available()
96+
97+
use_accel = not args.no_accel and torch.accelerator.is_available()
10098

10199
torch.manual_seed(args.seed)
102100

103-
if use_cuda:
104-
device = torch.device("cuda")
105-
elif use_mps:
106-
device = torch.device("mps")
101+
if use_accel:
102+
device = torch.accelerator.current_accelerator()
107103
else:
108104
device = torch.device("cpu")
109105

110106
train_kwargs = {'batch_size': args.batch_size}
111107
test_kwargs = {'batch_size': args.test_batch_size}
112-
if use_cuda:
113-
cuda_kwargs = {'num_workers': 1,
108+
if use_accel:
109+
accel_kwargs = {'num_workers': 1,
114110
'pin_memory': True,
115111
'shuffle': True}
116-
train_kwargs.update(cuda_kwargs)
117-
test_kwargs.update(cuda_kwargs)
112+
train_kwargs.update(accel_kwargs)
113+
test_kwargs.update(accel_kwargs)
118114

119115
transform=transforms.Compose([
120116
transforms.ToTensor(),

mnist/requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
torch
2-
torchvision==0.20.0
2+
torchvision

mnist_forward_forward/README.md

+1-2
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ optional arguments:
1616
-h, --help show this help message and exit
1717
--epochs EPOCHS number of epochs to train (default: 1000)
1818
--lr LR learning rate (default: 0.03)
19-
--no_cuda disables CUDA training
20-
--no_mps disables MPS training
19+
--no_accel disables accelerator
2120
--seed SEED random seed (default: 1)
2221
--save_model For saving the current Model
2322
--train_size TRAIN_SIZE

mnist_forward_forward/main.py

+8-16
Original file line numberDiff line numberDiff line change
@@ -102,18 +102,14 @@ def train(self, x_pos, x_neg):
102102
help="learning rate (default: 0.03)",
103103
)
104104
parser.add_argument(
105-
"--no_cuda", action="store_true", default=False, help="disables CUDA training"
106-
)
107-
parser.add_argument(
108-
"--no_mps", action="store_true", default=False, help="disables MPS training"
105+
"--no_accel", action="store_true", help="disables accelerator"
109106
)
110107
parser.add_argument(
111108
"--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
112109
)
113110
parser.add_argument(
114111
"--save_model",
115112
action="store_true",
116-
default=False,
117113
help="For saving the current Model",
118114
)
119115
parser.add_argument(
@@ -126,7 +122,6 @@ def train(self, x_pos, x_neg):
126122
parser.add_argument(
127123
"--save-model",
128124
action="store_true",
129-
default=False,
130125
help="For Saving the current Model",
131126
)
132127
parser.add_argument(
@@ -137,22 +132,19 @@ def train(self, x_pos, x_neg):
137132
help="how many batches to wait before logging training status",
138133
)
139134
args = parser.parse_args()
140-
use_cuda = not args.no_cuda and torch.cuda.is_available()
141-
use_mps = not args.no_mps and torch.backends.mps.is_available()
142-
if use_cuda:
143-
device = torch.device("cuda")
144-
elif use_mps:
145-
device = torch.device("mps")
135+
use_accel = not args.no_accel and torch.accelerator.is_available()
136+
if use_accel:
137+
device = torch.accelerator.current_accelerator()
146138
else:
147139
device = torch.device("cpu")
148140

149141
train_kwargs = {"batch_size": args.train_size}
150142
test_kwargs = {"batch_size": args.test_size}
151143

152-
if use_cuda:
153-
cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
154-
train_kwargs.update(cuda_kwargs)
155-
test_kwargs.update(cuda_kwargs)
144+
if use_accel:
145+
accel_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
146+
train_kwargs.update(accel_kwargs)
147+
test_kwargs.update(accel_kwargs)
156148

157149
transform = Compose(
158150
[
+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
torch
2-
torchvision==0.20.0
2+
torchvision

mnist_rnn/README.md

+15
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,18 @@ pip install -r requirements.txt
88
python main.py
99
# CUDA_VISIBLE_DEVICES=2 python main.py # to specify GPU id to ex. 2
1010
```
11+
12+
```bash
13+
optional arguments:
14+
-h, --help show this help message and exit
15+
--batch_size input batch_size for training (default:64)
16+
--testing_batch_size input batch size for testing (default: 1000)
17+
--epochs EPOCHS number of epochs to train (default: 14)
18+
--lr LR learning rate (default: 0.1)
19+
--gamma learning rate step gamma (default: 0.7)
20+
--accel enables accelerator
21+
--seed SEED random seed (default: 1)
22+
--save_model For saving the current Model
23+
--log_interval how many batches to wait before logging training status
24+
--dry-run quickly check a single pass
25+
```

mnist_rnn/main.py

+8-14
Original file line numberDiff line numberDiff line change
@@ -91,32 +91,26 @@ def main():
9191
help='learning rate (default: 0.1)')
9292
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
9393
help='learning rate step gamma (default: 0.7)')
94-
parser.add_argument('--cuda', action='store_true', default=False,
95-
help='enables CUDA training')
96-
parser.add_argument('--mps', action="store_true", default=False,
97-
help="enables MPS training")
98-
parser.add_argument('--dry-run', action='store_true', default=False,
94+
parser.add_argument('--accel', action='store_true',
95+
help='enables accelerator')
96+
parser.add_argument('--dry-run', action='store_true',
9997
help='quickly check a single pass')
10098
parser.add_argument('--seed', type=int, default=1, metavar='S',
10199
help='random seed (default: 1)')
102100
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
103101
help='how many batches to wait before logging training status')
104-
parser.add_argument('--save-model', action='store_true', default=False,
102+
parser.add_argument('--save-model', action='store_true',
105103
help='for Saving the current Model')
106104
args = parser.parse_args()
107105

108-
if args.cuda and not args.mps:
109-
device = "cuda"
110-
elif args.mps and not args.cuda:
111-
device = "mps"
106+
if args.accel:
107+
device = torch.accelerator.current_accelerator()
112108
else:
113-
device = "cpu"
114-
115-
device = torch.device(device)
109+
device = torch.device("cpu")
116110

117111
torch.manual_seed(args.seed)
118112

119-
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
113+
kwargs = {'num_workers': 1, 'pin_memory': True} if args.accel else {}
120114
train_loader = torch.utils.data.DataLoader(
121115
datasets.MNIST('../data', train=True, download=True,
122116
transform=transforms.Compose([

mnist_rnn/requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
torch
2-
torchvision==0.20.0
2+
torchvision

run_python_examples.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ function mnist() {
9393
uv run main.py --epochs 1 --dry-run || error "mnist example failed"
9494
}
9595
function mnist_forward_forward() {
96-
uv run main.py --epochs 1 --no_mps --no_cuda || error "mnist forward forward failed"
96+
uv run main.py --epochs 1 --no_accel || error "mnist forward forward failed"
9797

9898
}
9999
function mnist_hogwild() {

0 commit comments

Comments
 (0)