# note: this script uses the older standalone Keras API (np_utils, 'acc' history keys);
# in recent Keras/TF versions use keras.utils.to_categorical instead of np_utils
from keras.utils import np_utils
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Activation
import numpy as np
np.random.seed(3)  # fix the random seed for reproducibility
# prepare the dataset
# load the MNIST training and test sets
(data_train, label_train), (data_test, label_test) = mnist.load_data()
# data_train, data_test are grayscale images with shape (num_samples, 28, 28)
# split the original training set into training and validation sets
data_val = data_train[50000:]
label_val = label_train[50000:]
data_train = data_train[:50000]
label_train = label_train[:50000]
# dataset preprocessing: flatten each 28x28 image to a 784-dim vector and scale pixels to [0, 1]
data_train = data_train.reshape(50000,784).astype('float32')/255.0
data_val = data_val.reshape(10000, 784).astype('float32')/255.0
data_test = data_test.reshape(10000, 784).astype('float32')/255.0
# randomly subsample 700 training and 300 validation examples
# (replace=False so the same index is not drawn twice)
train_rand_idxs = np.random.choice(50000, 700, replace=False)
val_rand_idxs = np.random.choice(10000, 300, replace=False)
data_train = data_train[train_rand_idxs]
label_train = label_train[train_rand_idxs]
data_val = data_val[val_rand_idxs]
label_val = label_val[val_rand_idxs]
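# (optional sanity check, not in the original script) confirm the subsampled array shapes
# before one-hot encoding: data is (700, 784)/(300, 784), labels are still raw digits (700,)/(300,)
print('train:', data_train.shape, label_train.shape)
print('val:  ', data_val.shape, label_val.shape)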
# convert the label data to one-hot encoding
label_train = np_utils.to_categorical(label_train)
label_val = np_utils.to_categorical(label_val)
label_test = np_utils.to_categorical(label_test)
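# (illustration, example label assumed) to_categorical maps a digit label k to a 10-dim
# vector with a 1 at index k, e.g. 3 -> [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
print(np_utils.to_categorical([3], 10))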
# Build model
model = Sequential()
model.add(Dense(units=2, input_dim=28*28, activation='relu'))  # very small hidden layer (only 2 units)
model.add(Dense(units=10, activation='softmax'))  # one output unit per digit class
# units = dimensionality of the layer's output space
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
# see https://keras.io/losses/, https://keras.io/optimizers/
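# (optional, not in the original) print layer output shapes and parameter counts;
# with only 2 hidden units the model has 784*2 + 2 + 2*10 + 10 = 1600 parameters
model.summary()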
# train the model
from keras.callbacks import EarlyStopping
# stop training early if the validation loss has not improved for 20 epochs
early_stopping = EarlyStopping(patience=20)
hist = model.fit(data_train, label_train, epochs=1000, batch_size=10, validation_data=(data_val, label_val), callbacks=[early_stopping])
# plot the learning curves (loss and accuracy per epoch)
%matplotlib inline
import matplotlib.pyplot as plt
fig, loss_ax = plt.subplots()
acc_ax = loss_ax.twinx()
loss_ax.plot(hist.history['loss'], 'y', label='train loss')
loss_ax.plot(hist.history['val_loss'], 'r', label='val loss')
# note: older Keras versions record accuracy under 'acc'/'val_acc'; newer versions use 'accuracy'/'val_accuracy'
acc_ax.plot(hist.history['acc'], 'b', label='train acc')
acc_ax.plot(hist.history['val_acc'], 'g', label='val acc')
loss_ax.set_xlabel('epoch')
loss_ax.set_ylabel('loss')
acc_ax.set_ylabel('accuracy')
loss_ax.legend(loc='upper left')
acc_ax.legend(loc='lower left')
plt.show()
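# (sketch, not in the original) evaluate the trained model on the held-out test set;
# model.evaluate returns the loss followed by the metrics passed to compile()
loss_and_metrics = model.evaluate(data_test, label_test, batch_size=32)
print('test loss: %.4f, test accuracy: %.4f' % (loss_and_metrics[0], loss_and_metrics[1]))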