# data_mnist.py -- forked from HIPS/autograd.
import array
import gzip
import os
import struct
from urllib.request import urlretrieve
import numpy as np
def download(url, filename):
    """Fetch ``url`` into ``data/<filename>`` unless that file already exists.

    The ``data`` directory is created relative to the current working
    directory when it is missing. The transfer is written to a sibling
    ``<filename>.part`` file and only moved to its final name once
    ``urlretrieve`` has returned, so an interrupted download cannot leave a
    truncated file that later calls would mistake for a finished one.

    Args:
        url: Location to retrieve; passed unchanged to ``urlretrieve``.
        filename: Name of the file to create inside ``data``.

    Raises:
        Any exception raised by ``urlretrieve`` or the final rename is
        propagated after the partial file has been removed.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs("data", exist_ok=True)
    out_file = os.path.join("data", filename)
    if os.path.isfile(out_file):
        return

    partial_file = out_file + ".part"
    try:
        urlretrieve(url, partial_file)
        # Same-directory rename: the final name appears only when complete.
        os.replace(partial_file, out_file)
    finally:
        # Still present only if the fetch or the rename failed.
        if os.path.exists(partial_file):
            os.remove(partial_file)
def mnist():
    """Download (if necessary) and load the four MNIST archive files.

    Each gzip archive is fetched into the local ``data`` directory through
    ``download`` and then decoded from its big-endian IDX header plus raw
    unsigned-byte payload.

    Returns:
        A tuple ``(train_images, train_labels, test_images, test_labels)``.
        Image arrays are ``uint8`` with shape ``(num_data, rows, cols)`` as
        given by the file header; label arrays are one-dimensional ``uint8``.

    Raises:
        ValueError: If a file's magic number is not the expected IDX value,
            if a label file's payload length differs from its header count,
            or (from ``reshape``) if an image payload does not match the
            header dimensions.
    """
    # NOTE(review): plain-HTTP origin; confirm it is still reachable or point
    # this at a maintained mirror.
    base_url = "http://yann.lecun.com/exdb/mnist/"

    # IDX magic numbers: unsigned-byte data with 1 (labels) or 3 (images) dims.
    label_magic = 0x00000801
    image_magic = 0x00000803

    def parse_labels(filename):
        """Decode a gzip-compressed IDX label file into a 1-D uint8 array."""
        with gzip.open(filename, "rb") as fh:
            magic, num_data = struct.unpack(">II", fh.read(8))
            if magic != label_magic:
                raise ValueError(
                    f"{filename}: unexpected magic number {magic} for a label file"
                )
            labels = np.array(array.array("B", fh.read()), dtype=np.uint8)
        if labels.size != num_data:
            raise ValueError(
                f"{filename}: header declares {num_data} labels, found {labels.size}"
            )
        return labels

    def parse_images(filename):
        """Decode a gzip-compressed IDX image file into a 3-D uint8 array."""
        with gzip.open(filename, "rb") as fh:
            magic, num_data, rows, cols = struct.unpack(">IIII", fh.read(16))
            if magic != image_magic:
                raise ValueError(
                    f"{filename}: unexpected magic number {magic} for an image file"
                )
            pixels = np.array(array.array("B", fh.read()), dtype=np.uint8)
        # reshape raises ValueError when the payload size disagrees with the header.
        return pixels.reshape(num_data, rows, cols)

    # Single source of truth for the archive names, in download order.
    files = {
        "train_images": "train-images-idx3-ubyte.gz",
        "train_labels": "train-labels-idx1-ubyte.gz",
        "test_images": "t10k-images-idx3-ubyte.gz",
        "test_labels": "t10k-labels-idx1-ubyte.gz",
    }

    for filename in files.values():
        download(base_url + filename, filename)

    def local_path(key):
        # Mirrors how download() builds its output path.
        return os.path.join("data", files[key])

    train_images = parse_images(local_path("train_images"))
    train_labels = parse_labels(local_path("train_labels"))
    test_images = parse_images(local_path("test_images"))
    test_labels = parse_labels(local_path("test_labels"))

    return train_images, train_labels, test_images, test_labels