changeset 1:59ba73797553 default tip

Implement MNIST convolutional network
author Lewin Bormann <lbo@spheniscida.de>
date Fri, 25 Dec 2020 11:53:54 +0100
parents e93aac1287c0
children
files sec6_3/complete_convolutional.py
diffstat 1 files changed, 46 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/sec6_3/complete_convolutional.py	Fri Dec 25 11:53:54 2020 +0100
@@ -0,0 +1,46 @@
+#!/usr/bin/env python3
+
+import numpy as np
+import keras
+from keras.models import Sequential
+from keras.layers import (Dense, Dropout, Activation, Flatten,
+        Conv2D, MaxPooling2D)
+from keras.utils import np_utils
+from keras.datasets import mnist
+
+
+# Import MNIST data
+(train_samples, train_labels), (test_samples, test_labels) = mnist.load_data()
+
+# Preprocess MNIST images
+train_samples = train_samples.reshape(train_samples.shape[0], 28, 28, 1)
+test_samples = test_samples.reshape(test_samples.shape[0], 28, 28, 1)
+train_samples = train_samples.astype(np.float32)
+test_samples = test_samples.astype(np.float32)
+train_samples = train_samples / 255
+test_samples = test_samples / 255
+
+# Convert labels into array of 10-dim vectors
+cat_train_labels = np_utils.to_categorical(train_labels, 10)
+cat_test_labels = np_utils.to_categorical(test_labels, 10)
+
+# Build convolutional network
+convnet = Sequential()
+convnet.add(Conv2D(32, 4, 4, activation='relu', input_shape=(28,28,1)))
+convnet.add(MaxPooling2D(pool_size=(2,2), strides=1))
+convnet.add(Conv2D(32, 3, 3, activation='relu'))
+convnet.add(MaxPooling2D(pool_size=(2,2), strides=1))
+convnet.add(Dropout(0.3))
+convnet.add(Flatten())
+convnet.add(Dense(10, activation='softmax'))
+
+convnet.compile(loss=keras.losses.CategoricalCrossentropy(from_logits=False), optimizer='sgd',
+                metrics=['accuracy'])
+convnet.summary()
+
+convnet.fit(train_samples, cat_train_labels, batch_size=32, epochs=20,
+            verbose=1)
+metrics = convnet.evaluate(test_samples, cat_test_labels, verbose=1)
+print('{}: {:.2f}'.format(convnet.metrics_names[1], metrics[1]*100))
+print(convnet.metrics_names)
+predictions = convnet.predict(test_samples)