90 lines
1.8 KiB
Python
90 lines
1.8 KiB
Python
from mnist_loader import load_data
|
|
import numpy as np
|
|
import os
|
|
from PIL import Image
|
|
import resource
|
|
|
|
|
|
def vectorized_result(j):
    """Build the one-hot column vector for index ``j``.

    The returned array has shape (10, 1) and is filled with 0.0 except
    for a 1.0 stored at position ``j``. It serves as the desired network
    output for the digit ``j`` (0...9).
    """
    one_hot = np.zeros(shape=(10, 1))
    one_hot[j] = 1.0
    return one_hot
|
|
|
|
def _max_rss_go():
    """Return the process's peak resident set size divided by 1024 * 1024.

    NOTE(review): callers label this value "Go" (gigabytes). That is only
    accurate if ``ru_maxrss`` is reported in kilobytes, as on Linux --
    confirm before relying on it on other platforms.
    """
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / (1024 * 1024)


def loadSet(path, limit=100):
    """Load up to ``limit`` PNG images found under ``path``.

    The directory tree rooted at ``path`` is walked recursively. Every file
    whose name ends with ".png" is opened with PIL and flattened into a
    list of pixel values divided by 255. Progress and peak memory usage are
    printed while loading.

    Args:
        path: Directory to search (recursively) for PNG files.
        limit: Maximum number of images to load. Defaults to 100, the cap
            that used to be hard-coded.

    Returns:
        A tuple ``(pixels, result)``. ``pixels[k]`` is the flat list of
        scaled pixel values of the k-th loaded image, ordered column by
        column (x in the outer loop, y in the inner loop). ``result[k]`` is
        the first character of that image's file name.

    NOTE(review): ``pix[x, y] / 255`` presumably expects single-channel
    images, where each pixel is a number rather than a tuple -- confirm
    against the data set.
    """
    filelist = [
        os.path.join(root, file)
        for root, _dirs, files in os.walk(path)
        for file in files
    ]

    # Denominator of the progress percentage. It is only used inside the
    # loop body, which runs only when filelist is non-empty and limit > 0,
    # so it cannot be zero there.
    expected = min(len(filelist), limit)

    pixels = []
    result = []
    loaded = 0

    for name in filelist:
        if loaded >= limit:
            break

        # Test the suffix rather than a substring, so "3.png.bak" or a file
        # inside a directory called "x.png" is not mistaken for an image.
        if not name.endswith(".png"):
            continue

        # ``name`` is already the full path built from os.walk; opening it
        # directly also works for images located in sub-directories.
        with Image.open(name) as im:
            pix = im.load()
            width, height = im.size
            temparray = [
                pix[x, y] / 255
                for x in range(width)
                for y in range(height)
            ]

        # The label is the first character of the file name.
        result.append(os.path.basename(name)[0])
        pixels.append(temparray)

        # NOTE(review): this dumps every pixel value of the image and looks
        # like leftover debug output; kept so that stdout is unchanged.
        print(temparray)

        percent = "%.2f" % round(loaded / expected * 100, 2)
        ram = "%.2f" % round(_max_rss_go(), 2)
        # '\r' makes the next progress report overwrite this one.
        print(percent + "% Done, ram usage: " + ram + "Go", end='\r')
        loaded += 1

    print("max ram usage: " + str(_max_rss_go()) + "Go")

    return (pixels, result)
|
|
|
|
|
|
def loadTrainingSet(path):
    """Load the images under ``path`` as (input, target) training pairs.

    Each flat pixel list returned by ``loadSet`` is reshaped into a
    (262144, 1) column vector (262144 = 512 * 512), and each label is
    converted with ``int`` and turned into a 10-dimensional unit vector by
    ``vectorized_result``.

    Returns a ``zip`` object pairing inputs with targets; being an
    iterator, it can be consumed only once.
    """
    print("importing training set...")

    pixel_rows, labels = loadSet(path)

    column_vectors = []
    for row in pixel_rows:
        column_vectors.append(np.reshape(row, (262144, 1)))

    targets = []
    for label in labels:
        targets.append(vectorized_result(int(label)))

    return zip(column_vectors, targets)
|
|
|
|
def loadTestSet(path):
    """Load the images under ``path`` as (input, label) test pairs.

    Each flat pixel list returned by ``loadSet`` is reshaped into a
    (262144, 1) column vector (262144 = 512 * 512). Labels are passed
    through exactly as ``loadSet`` produced them, i.e. the first character
    of each file name as a string; unlike ``loadTrainingSet`` they are not
    converted with ``int``.

    Returns a ``zip`` object pairing inputs with labels; being an
    iterator, it can be consumed only once.
    """
    print("importing test set...")

    pixel_rows, labels = loadSet(path)

    column_vectors = []
    for row in pixel_rows:
        column_vectors.append(np.reshape(row, (262144, 1)))

    return zip(column_vectors, labels)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke run: load the images from the "set" directory and dump
    # their pixel lists.
    loaded_pixels, _labels = loadSet("set")
    print(loaded_pixels)