First commit
This commit is contained in:
parent
26247b7afb
commit
09ec5c7e62
84 changed files with 2578 additions and 0 deletions
60
src/expand_mnist.py
Normal file
60
src/expand_mnist.py
Normal file
|
@ -0,0 +1,60 @@
|
|||
"""expand_mnist.py
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Take the 50,000 MNIST training images, and create an expanded set of
|
||||
250,000 images, by displacing each training image up, down, left and
|
||||
right, by one pixel. Save the resulting file to
|
||||
../data/mnist_expanded.pkl.gz.
|
||||
|
||||
Note that this program is memory intensive, and may not run on small
|
||||
systems.
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
#### Libraries
|
||||
|
||||
# Standard library
|
||||
import cPickle
|
||||
import gzip
|
||||
import os.path
|
||||
import random
|
||||
|
||||
# Third-party libraries
|
||||
import numpy as np
|
||||
|
||||
print("Expanding the MNIST training set")
|
||||
|
||||
if os.path.exists("../data/mnist_expanded.pkl.gz"):
|
||||
print("The expanded training set already exists. Exiting.")
|
||||
else:
|
||||
f = gzip.open("../data/mnist.pkl.gz", 'rb')
|
||||
training_data, validation_data, test_data = cPickle.load(f)
|
||||
f.close()
|
||||
expanded_training_pairs = []
|
||||
j = 0 # counter
|
||||
for x, y in zip(training_data[0], training_data[1]):
|
||||
expanded_training_pairs.append((x, y))
|
||||
image = np.reshape(x, (-1, 28))
|
||||
j += 1
|
||||
if j % 1000 == 0: print("Expanding image number", j)
|
||||
# iterate over data telling us the details of how to
|
||||
# do the displacement
|
||||
for d, axis, index_position, index in [
|
||||
(1, 0, "first", 0),
|
||||
(-1, 0, "first", 27),
|
||||
(1, 1, "last", 0),
|
||||
(-1, 1, "last", 27)]:
|
||||
new_img = np.roll(image, d, axis)
|
||||
if index_position == "first":
|
||||
new_img[index, :] = np.zeros(28)
|
||||
else:
|
||||
new_img[:, index] = np.zeros(28)
|
||||
expanded_training_pairs.append((np.reshape(new_img, 784), y))
|
||||
random.shuffle(expanded_training_pairs)
|
||||
expanded_training_data = [list(d) for d in zip(*expanded_training_pairs)]
|
||||
print("Saving expanded data. This may take a few minutes.")
|
||||
f = gzip.open("../data/mnist_expanded.pkl.gz", "w")
|
||||
cPickle.dump((expanded_training_data, validation_data, test_data), f)
|
||||
f.close()
|
Loading…
Add table
Add a link
Reference in a new issue