{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Multi-Layer Perceptron: Classification of handwritten digits (MNIST)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# importing functions and classes from our framework\n", "from dataset import Dataset\n", "from nn import MLP\n", "from layers import Dense\n", "# other imports\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "np.random.seed(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Demo" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "

\n", "In the following, we demonstrate how our Multi-Layer Perceptron can be used to classify handwritten digits.\n", "For that purpose, we load a pretrained model whose architecture is printed below. Note that, unlike the hidden layers, the output layer of the neural network uses Softmax as its activation function. Combined with the cross-entropy loss function, this is a good setup for classification tasks: Softmax returns a probability distribution over the available classes, and cross-entropy typically converges faster on probability distributions than the Mean-Squared Error.\n", "

" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-------------------- MULTI LAYER PERCEPTRON (MLP) --------------------\n", "\n", "HIDDEN LAYERS = 1 \n", "TOTAL PARAMETERS = 84060 \n", "\n", " *** 1. Layer: *** \n", "------------------------\n", "DENSE 784 -> 100 [ReLU]\n", "------------------------\n", "Total parameters: 78500 \n", "---> WEIGHTS: (100, 784)\n", "---> BIASES: (100,)\n", "------------------------\n", "\n", " *** 2. Layer: *** \n", "-----------------------\n", "DENSE 100 -> 50 [ReLU]\n", "-----------------------\n", "Total parameters: 5050 \n", "---> WEIGHTS: (50, 100)\n", "---> BIASES: (50,)\n", "-----------------------\n", "\n", " *** 3. Layer: *** \n", "-------------------------\n", "DENSE 50 -> 10 [Softmax]\n", "-------------------------\n", "Total parameters: 510 \n", "---> WEIGHTS: (10, 50)\n", "---> BIASES: (10,)\n", "-------------------------\n", "\n", "----------------------------------------------------------------------\n", "\n" ] } ], "source": [ "classifier = MLP()\n", "classifier.load(\"mnist_classifier\") # classifier is saved in '/models/mnist_classifier'\n", "print(classifier)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now load a mini-batch of size 1 from the shuffled MNIST dataset and take the input image from the test set." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "dataset = Dataset(name = \"mnist\", train_size = 60000, test_size = 10000, batch_size = 1)\n", "randomBatch = next(dataset.batches())\n", "inputImage = randomBatch[1][0]\n", "inputLabel = randomBatch[1][1]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The first image from the test set is shown below." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAPeklEQVR4nO3df5BddXnH8c+T3+QHkPBjTZMUCCYiWAx1DVBSRWkxIBaoFmE6TFoZF0foyAx2iHQ6Mk6npkgARaSGQo0Wg0z52Q4WQkSZAA3Z0BgSgga2QRJDYowjgZKf+/SPPTgL7Hnucn9nn/drZufePc899zwc9pNz7/3ec77m7gIw9A1rdQMAmoOwA0kQdiAJwg4kQdiBJEY0c2OjbLSP0bhmbhJIZZde0x7fbQPVagq7mc2V9HVJwyX9i7sviB4/RuN0sp1RyyYBBFb4stJa1S/jzWy4pJslnSXpeEkXmdnx1T4fgMaq5T37bEnPu3uPu++RdKekc+vTFoB6qyXsUyS91O/3TcWyNzGzLjPrNrPuvdpdw+YA1KLhn8a7+yJ373T3zpEa3ejNAShRS9g3S5rW7/epxTIAbaiWsK+UNMPMjjGzUZIulPRAfdoCUG9VD725+z4zu1zSQ+obervd3dfVrTMAdVXTOLu7PyjpwTr1AqCB+LoskARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSdQ0iyuGvn0f/UBY3/iJkWF9+SevK60dOXxsuO6ne84M6z/bfmRYd7fS2hHfPihcd9RD3WH9QFRT2M1so6SdkvZL2ufunfVoCkD91ePI/hF3316H5wHQQLxnB5KoNewu6WEzW2VmXQM9wMy6zKzbzLr3aneNmwNQrVpfxs9x981mdqSkpWb2nLs/1v8B7r5I0iJJOtgmeY3bA1Clmo7s7r65uN0m6V5Js+vRFID6qzrsZjbOzCa8cV/SmZLW1qsxAPVVy8v4Dkn3mtkbz/N9d/+vunSFutl1Tvxi64JrfxjWPz3hG2F94rB4vLpX5fVexe/qlkx/KKwPm14+jl7p+ecf9cFw3XXbTwjrvmpdWG9HVYfd3Xskvb+OvQBoIIbegCQIO5AEYQeSIOxAEoQdSIJTXNuAjYj/NwyfNiWs/+9fltef/NzCcN2xNiqsS2Mq1Ku31/eH9VMXXhHWx23pjdf/26dKawvetTJcd+4/TQ3rI/4kLLcljuxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATj7G2g0jj6fcvvCevDVH6qZ6/icfS5688P61cdHZ8Ce8ZB1V9q7B+3x5epnnr/L8P6vp6NYf0JO6W8eF08zl7pv3uh4lNg2xFHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2NrD7qMPCejSOXskJ37s8rM+a8/OwXss4uiR95Jm/KK2NvnZiuO6InlU1bTtSyz49UHFkB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkGGdvAz2fHBnWK01tHI0Zr7v4m+G6r/bG4+j/sL0zrP/4qtPC+riHny4v9vaE6zZSpX06FFU8spvZ7Wa2zczW9ls2ycyWmtmG4jb+dgSAlhvMy/jvSJr7lmXzJS1z9xmSlhW/A2hjFcPu7o9J2vGWxedKWlzcXyzpvDr3BaDOqn3P3uHuW4r7L0vqKHugmXVJ6pKkMRpb5eYA1KrmT+Pd3aXyTzvcfZG7d7p750iNrnVzAKpUbdi3mtlkSSput9WvJQCNUG3YH5A0r7g/T9L99WkHQKNUfM9uZksknS7pcDP
bJOnLkhZIusvMLpH0oqQLGtnkUHfos/G/uY+fHY/DR7668eywvvW+3w/rHTc9EdZHKb7+eiMNP+E9Yf2qr/xb1c+9bOeBd134SiqG3d0vKimdUedeADQQX5cFkiDsQBKEHUiCsANJEHYgCU5xbQNH3PJkWP/qLSfW8OybwmpHhXo7e/HP4ktwf3zsb6t+7hMOivfLak2t+rlbhSM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBODsOWLuO7A3r0SW2f9P7erjuNxaWTzUtSYcp/m5EO+LIDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMM6OA9b0EzeH9Wha5pP//cpw3XffeuCNo1fCkR1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkmCcHS1jI+I/v62fmx3Wu4/7Zli/77WJpbXjbno5XHdfWD0wVTyym9ntZrbNzNb2W3aNmW02s9XFTzwJOICWG8zL+O9ImjvA8hvcfVbx82B92wJQbxXD7u6PSdrRhF4ANFAtH9BdbmZripf5pW+OzKzLzLrNrHuvdtewOQC1qDbst0g6VtIsSVskLSx7oLsvcvdOd+8cqdFVbg5AraoKu7tvdff97t4r6VZJ8cemAFquqrCb2eR+v54vaW3ZYwG0h4rj7Ga2RNLpkg43s02SvizpdDObJcklbZR0aQN7xBBVaRz9qS/dFNY/v3lOWF9z4/tLawf3/He47lBUMezuftEAi29rQC8AGoivywJJEHYgCcIOJEHYgSQIO5BEmlNcd3zm1LD++jmvVP3ca0+5I6zv93hq4UrufPWIsL7g2Y+V1vb99NBw3WNuei6s7/91badFRPt98RevD9d99PUJYf3xu08K61O+/0RYz4YjO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kYe7l09rW28E2yU+2M5q2vf5mrIyvknPD78Vjskt2dpTWfrHn8HDd/R7/mzph+K6w/jcTN4T1YbLSWjRtsST9z574OwD/1xvvt79+5JKw/pOzbiitTbB4v3z4xi+G9ckLGUd/qxW+TK/4jgH/IDiyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASQ+Z89g03nRzW75p8Y1if//JpYX1d13tLa75qXbhuJTY6Pm/7hx/8UFj/5R+PLa3tPfG1cN3//KNvhfWTRsXHg+fP+XZY79VBpbXj7rosXPfdjKPXFUd2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUhiyIyzj53yaly3UWH9npWdYX3mqqfecU+D5bt3h/Vhy1eH9anLy2vbu+Lr5X9r5ofD+tfetSKs1+LJTy0M63Nej89nP+bqJ+vZzpBX8chuZtPM7FEze9bM1pnZF4rlk8xsqZltKG4nNr5dANUazMv4fZKudPfjJZ0i6TIzO17SfEnL3H2GpGXF7wDaVMWwu/sWd3+6uL9T0npJUySdK2lx8bDFks5rVJMAaveO3rOb2dGSTpK0QlKHu28pSi9LGvAibWbWJalLksao/DvcABpr0J/Gm9l4SXdLusLd3zQLovddtXLAKxu6+yJ373T3zpGKL14IoHEGFXYzG6m+oN/h7vcUi7ea2eSiPlnStsa0CKAeKr6MNzOTdJuk9e7ef47dByTNk7SguL2/IR0O0mlTe8J6dLllSTp8xfB6tlNXu86ZHdY75r9QWnvwmJtr2vapqy8M64d9Kd5vr7znkNLavK/8R7ju+nlx72c+8tmwPuJHq8J6NoN5z36apIslPWNmbwz4Xq2+kN9lZpdIelHSBY1pEUA9VAy7uy+XSg+LrZnxAcA7xtdlgSQIO5AEYQeSIOxAEoQdSGLInOLaszOeNrnS1MXvu3RtWH9u1ymltfGb4lNUX/hU/M3BmX/wUlj/5+nxZbBnjhxTWrv+NzPCdf91ycfC+lELnw7rvbvi6abHrymv3bc07u36RfFgz/hj49OWD/tRWE6HIzuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJGF9F5lpjoNtkp9sjTl
Rzj5wQliffPMvwvqiaT+uetuVzpWvNMZfybo9+8L6n//k86W19/791nDdfS9tqqontKcVvkyv+I4B/yA5sgNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEkPmfHZftS6sb/34pLA+82uXhvWxh7xeWjv0BxPCdX87Pf439ZCe3rB+6OPxdwRmbC6/Pno8Qo9MOLIDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKDmZ99mqTvSuqQ5JIWufvXzewaSZ+V9KvioVe7+4ONarRW+3+9I6zP/Excr8X4GtdnrBz1MJgv1eyTdKW7P21mEyStMrOlRe0Gd7+uce0BqJfBzM++RdKW4v5OM1svaUqjGwNQX+/oPbuZHS3pJEkrikWXm9kaM7vdzCaWrNNlZt1m1r1X8TRJABpn0GE3s/GS7pZ0hbu/IukWScdKmqW+I//CgdZz90Xu3ununSMVz3kGoHEGFXYzG6m+oN/h7vdIkrtvdff97t4r6VZJsxvXJoBaVQy7mZmk2yStd/fr+y2f3O9h50uKp0EF0FKD+TT+NEkXS3rGzFYXy66WdJGZzVLfcNxGSfE5ogBaajCfxi+XBrwwetuOqQN4O75BByRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSMLcvXkbM/uVpBf7LTpc0vamNfDOtGtv7dqXRG/VqmdvR7n7EQMVmhr2t23crNvdO1vWQKBde2vXviR6q1azeuNlPJAEYQeSaHXYF7V4+5F27a1d+5LorVpN6a2l79kBNE+rj+wAmoSwA0m0JOxmNtfMfmZmz5vZ/Fb0UMbMNprZM2a22sy6W9zL7Wa2zczW9ls2ycyWmtmG4nbAOfZa1Ns1Zra52HerzezsFvU2zcweNbNnzWydmX2hWN7SfRf01ZT91vT37GY2XNLPJf2ppE2SVkq6yN2fbWojJcxso6ROd2/5FzDM7EOSXpX0XXd/X7HsWkk73H1B8Q/lRHe/qk16u0bSq62exruYrWhy/2nGJZ0n6a/Uwn0X9HWBmrDfWnFkny3peXfvcfc9ku6UdG4L+mh77v6YpB1vWXyupMXF/cXq+2NpupLe2oK7b3H3p4v7OyW9Mc14S/dd0FdTtCLsUyS91O/3TWqv+d5d0sNmtsrMulrdzAA63H1Lcf9lSR2tbGYAFafxbqa3TDPeNvuumunPa8UHdG83x93/UNJZki4rXq62Je97D9ZOY6eDmsa7WQaYZvx3Wrnvqp3+vFatCPtmSdP6/T61WNYW3H1zcbtN0r1qv6mot74xg25xu63F/fxOO03jPdA042qDfdfK6c9bEfaVkmaY2TFmNkrShZIeaEEfb2Nm44oPTmRm4ySdqfabivoBSfOK+/Mk3d/CXt6kXabxLptmXC3edy2f/tzdm/4j6Wz1fSL/gqS/a0UPJX1Nl/TT4mddq3uTtER9L+v2qu+zjUskHSZpmaQNkh6RNKmNevuepGckrVFfsCa3qLc56nuJvkbS6uLn7Fbvu6Cvpuw3vi4LJMEHdEAShB1IgrADSRB2IAnCDiRB2IEkCDuQxP8DH3F/xflsqvEAAAAASUVORK5CYII=\n", "text/plain": [ "

" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.imshow(inputImage.reshape(28,28))\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# helper function\n", "def prettyPerc(val, length = 5):\n", " out = str(round(val*100,2))\n", " return (length - len(out)) * \" \" + out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we feed the sample image through the neural network and show the predicted probability distribution over the classes." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Probability distribution of classification:\n", "\n", "0: 0.0 %\n", "1: 0.0 %\n", "2: 99.46 %\n", "3: 0.45 %\n", "4: 0.0 %\n", "5: 0.02 %\n", "6: 0.0 %\n", "7: 0.0 %\n", "8: 0.07 %\n", "9: 0.0 %\n" ] } ], "source": [ "prediction = classifier.predict(inputImage)\n", "print(\"Probability distribution of classification:\\n\")\n", "for i in range(10):\n", " print(f\"{i}: {prettyPerc(prediction[i][0])} %\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, the first handwritten has been predicted very accurately." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted digit: 2\n", "Label: 2\n" ] } ], "source": [ "print(f\"Predicted digit: {np.argmax(prediction)}\")\n", "print(f\"Label: {np.argmax(inputLabel)}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Obviously our network still misclassifies some images. Thus we will loop over the test set, until we find an image, where the classification doesn't match the label of the image." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "while np.argmax(prediction) == np.argmax(inputLabel):\n", " randomBatch = next(dataset.batches())\n", " inputImage = randomBatch[1][0]\n", " inputLabel = randomBatch[1][1]\n", " prediction = classifier.predict(inputImage)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAPIElEQVR4nO3df5BV9XnH8c/DilARRCDubHGrojgtDVMSN5hUxphx4ihmAjgdC40pmSHddEYabVNTtTNK80dLYhMnjT8ya2CCLdHEJEQmoVGyZcamUWQxlB8Sg2GwQFYwIQpag+zu0z/2YFfd8z3LPfcXPO/XzM69e5577n24w2fPvfd7v+dr7i4Ap75RjW4AQH0QdiAIwg4EQdiBIAg7EMRp9Xyw022Mj9W4ej4kEMpv9Zre8KM2XK1U2M3saklfltQi6Wvuvjx1+7Eap0vtyjIPCSBho3fn1ip+GW9mLZLulXSNpBmSFpnZjErvD0BtlXnPPlvS8+6+293fkPSwpHnVaQtAtZUJ+1RJe4f8vi/b9hZm1mlmPWbWc0xHSzwcgDJq/mm8u3e5e4e7d4zWmFo/HIAcZcK+X1L7kN/PzbYBaEJlwr5J0nQzu8DMTpe0UNLa6rQFoNoqHnpz9z4zWyrpMQ0Ova109x1V6wxAVZUaZ3f3dZLWVakXADXE12WBIAg7EARhB4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EARhB4Ig7EAQhB0IgrADQZRastnM9kg6IqlfUp+7d1SjKQDVVyrsmQ+5+6+qcD8AaoiX8UAQZcPukh43s81m1jncDcys08x6zKznmI6WfDgAlSr7Mn6Ou+83s3MkrTezn7n7E0Nv4O5dkrokaYJN8pKPB6BCpY7s7r4/uzwoaY2k2dVoCkD1VRx2MxtnZuOPX5d0laTt1WoMQHWVeRnfKmmNmR2/n2+4+w+r0hWax+yZyfIvbm6p+K5P23VGst43/X+T9b+e1Z2sD3j+seybey9J7rth5iPJ+vz3z0vW+/buS9YboeKwu/tuSX9UxV4A1BBDb0AQhB0IgrADQRB2IAjCDgRRjYkwOInt/sIHkvXuhXcl620tv5OsD2ggtzbqg+ljTWpfSRpVcKxK7f+XE3eXeuyWf+tL1vs+mCw3BEd2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQiCcfbgbv7I95P1onH0UbJk/asvX5RbWzB+R3LfazYPe6azN722b3yy/tx19+XWRlt6au6xgnMqHb3lXekb6MWCev1xZAeCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIBhnPwX8ekn+nPSrlv5Xct/Os/Yk60XzuouOF9/4n/yFfX9wx+XJfX9307ZkfddXLk3WU70XjaPf+/KF6Rs8ne6tGXFkB4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEgGGdvAi0Tz0rWj743f064JG383L25tQGlB5SL5qNfse1Pk/Uz7xiXrE
9IjEcXDHUX2nXd/cn6pqP5x7KPf3tpct9pn32yop6aWeGR3cxWmtlBM9s+ZNskM1tvZruyy7Nr2yaAskbyMv7rkq5+27ZbJXW7+3RJ3dnvAJpYYdjd/QlJh962eZ6kVdn1VZLmV7kvAFVW6Xv2Vnfvza6/KKk174Zm1impU5LG6owKHw5AWaU/jXd3V+KzFnfvcvcOd+8YrTFlHw5AhSoN+wEza5Ok7PJg9VoCUAuVhn2tpMXZ9cWSHq1OOwBqpfA9u5k9JOkKSVPMbJ+kOyUtl/QtM1si6QVJ19eyyZNd0Rron752XbLeOfFHyfpA4m920Xz02f90U7Le9uD2ZL3/cHqd85Te7/1Bsu6e/g7AgDYn6zc8+cnc2oWn4Dh6kcKwu/uinNKVVe4FQA3xdVkgCMIOBEHYgSAIOxAEYQeCYIprFbz2w2nJ+rMz70nWi6aZpobWJOlA/+u5tQXLbknue87KnyTr/cmqdFr7ucn6z/+qPbf2s/flT82ViqfnXvzvn0rWZyzrza31Jfc8NXFkB4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEgGGevgg0zH0nWyy57fPnW9Azis+7IP93XpE0lp3LOnpkst9z1UrI+7ZZXc2ub/iQ9jv43t92YrF/8zaeS9Yhj6Skc2YEgCDsQBGEHgiDsQBCEHQiCsANBEHYgCMbZq+CSTTck6x85b0ey/vg9lyXrk1ekx8pTo9VF8813fnZqsv7cdfcl612vnJ+sX/LwntzanQs/kdx3/NPpcXScGI7sQBCEHQiCsANBEHYgCMIOBEHYgSAIOxAE4+xV0DZ/Z7K+ueBv6mSVm3P++rzZubU5y9Jj1d8759FkvWgu/oCn/22fu3ZhfnHntuS+qK7CI7uZrTSzg2a2fci2ZWa238y2ZD9za9smgLJG8jL+65KuHmb73e4+K/tZV922AFRbYdjd/QlJh+rQC4AaKvMB3VIz25q9zD8770Zm1mlmPWbWc0xHSzwcgDIqDfv9ki6UNEtSr6Qv5t3Q3bvcvcPdO0ZrTIUPB6CsisLu7gfcvd/dByQ9ICn/42AATaGisJtZ25BfF0janndbAM2hcJzdzB6SdIWkKWa2T9Kdkq4ws1kanEq9R1J6oWwktcy4OFlvW/nLZL2r/au5taI1zovWhi86Htz9H8MN1Py/37ffFNw/6qUw7O6+aJjNK2rQC4Aa4uuyQBCEHQiCsANBEHYgCMIOBMEU1ybwyyunJOtr2lcn6wOJv9lll4su2r/oVNNPX5s/tPf5venJkv0fa0nW+/btT9bxVhzZgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiAIxtmbQOvGI8n6gf706bxSf7Gv2dyZ3Nefmpisvzb9jWT9b//4sWS986w9ubU1F6XPU3rfYxck69//w9yzoWEYHNmBIAg7EARhB4Ig7EAQhB0IgrADQRB2IAhzT59quJom2CS/1K6s2+OdKux9Myve1zfVdlnklgkTkvXeP393bu3p276S3LdoLv1ld3w6WZ+8otxS2Cejjd6tw35o2JMIcGQHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSCYz34SqPVYeRn9hw8n6+fc85Pc2qjbyi0X/dspRftjqMIju5m1m9kGM3vWzHaY2U3Z9klmtt7MdmWXnEkAaGIjeRnfJ+kz7j5D0vsl3WhmMyTdKqnb3adL6s5+B9CkCsPu7r3u/kx2/YiknZKmSponaVV2s1WS5teqSQDlndB7djM7X9J7JG2U1OruvVnpRUmtOft0SuqUpLE6o9I+AZQ04k/jzexMSd+RdLO7v+VTGR+cTTPsjBp373L3DnfvGK0xpZoFULkRhd3MRmsw6Kvd/bvZ5gNm1pbV2yQdrE2LAKqh8GW8mZmkFZJ2uvuXhpTWSlosaXl2+WhNOsRJ7ddLPpBbG9Dm5L7Fy03jRIzkPftlkj4uaZuZbcm23a7BkH/LzJZIekHS9bVpEUA1FIbd3X8sKe/bC5yJAjhJ8H
VZIAjCDgRB2IEgCDsQBGEHgmCKK5JOaz83WX/lgdOT9Sdn3pNbO9D/enLfDz10S7I+7fP502fxThzZgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiAIxtlHanb+sslj7nopuWv/x1rS923pUyK/8Ge/l96/hKOT00t2L5+/Oln/6LjfJOupOekLlhWMo6+Mt+RyLXFkB4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEgbHAxl/qYYJP8Ujs5T0h7yU/zx4v/sXVrct9j3p+sj7b0OHzR/qNyT/4r3ffyBcl9/+UHc5P1sdNfSdb9qYnJ+nmrX8it9e3bn9wXJ26jd+uwHxr2PwRHdiAIwg4EQdiBIAg7EARhB4Ig7EAQhB0IYiTrs7dLelBSqySX1OXuXzazZZL+QtLxydy3u/u6WjXaaI88fllu7R9u+Gly36J1xi/a8MmKejru/K/lj7Ofvvn55L7TDtd2znhfTe8dJ2IkJ6/ok/QZd3/GzMZL2mxm67Pa3e7+z7VrD0C1jGR99l5Jvdn1I2a2U9LUWjcGoLpO6D27mZ0v6T2SNmablprZVjNbaWZn5+zTaWY9ZtZzTEdLNQugciMOu5mdKek7km5298OS7pd0oaRZGjzyf3G4/dy9y9073L1jtMZUoWUAlRhR2M1stAaDvtrdvytJ7n7A3fvdfUDSA5Jm165NAGUVht3MTNIKSTvd/UtDtrcNudkCSdur3x6Aaimc4mpmcyT9p6Rt0ptjSLdLWqTBl/AuaY+kT2Uf5uU6mae4AieD1BTXkXwa/2Np2AnTp+yYOnAq4ht0QBCEHQiCsANBEHYgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIOq6ZLOZvSRp6Bq+UyT9qm4NnJhm7a1Z+5LorVLV7O08d3/XcIW6hv0dD27W4+4dDWsgoVl7a9a+JHqrVL1642U8EARhB4JodNi7Gvz4Kc3aW7P2JdFbperSW0PfswOon0Yf2QHUCWEHgmhI2M3sajN7zsyeN7NbG9FDHjPbY2bbzGyLmfU0uJeVZnbQzLYP2TbJzNab2a7sctg19hrU2zIz2589d1vMbG6Dems3sw1m9qyZ7TCzm7LtDX3uEn3V5Xmr+3t2M2uR9HNJH5a0T9ImSYvc/dm6NpLDzPZI6nD3hn8Bw8wul/SqpAfd/d3Zti9IOuTuy7M/lGe7+981SW/LJL3a6GW8s9WK2oYuMy5pvqRPqIHPXaKv61WH560RR/bZkp53993u/oakhyXNa0AfTc/dn5B06G2b50lalV1fpcH/LHWX01tTcPded38mu35E0vFlxhv63CX6qotGhH2qpL1Dft+n5lrv3SU9bmabzayz0c0Mo3XIMlsvSmptZDPDKFzGu57etsx40zx3lSx/XhYf0L3THHd/r6RrJN2YvVxtSj74HqyZxk5HtIx3vQyzzPibGvncVbr8eVmNCPt+Se1Dfj8329YU3H1/dnlQ0ho131LUB46voJtdHmxwP29qpmW8h1tmXE3w3DVy+fNGhH2TpOlmdoGZnS5poaS1DejjHcxsXPbBicxsnKSr1HxLUa+VtDi7vljSow3s5S2aZRnvvGXG1eDnruHLn7t73X8kzdXgJ/K/kPT3jeghp69pkv47+9nR6N4kPaTBl3XHNPjZxhJJkyV1S9ol6UeSJjVRb/+qwaW9t2owWG0N6m2OBl+ib5W0JfuZ2+jnLtFXXZ43vi4LBMEHdEAQhB0IgrADQRB2IAjCDgRB2IEgCDsQxP8BjIqHLPgR5ZoAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.imshow(inputImage.reshape(28,28))\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "

\n", "We feed an image of a 3 into the neural network, but the network classifies it as an 8. The handwriting is somewhat sloppy, which is the likely reason for the misclassification.\n", "

" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Probability distribution of classification:\n", "\n", "0: 0.0 %\n", "1: 0.0 %\n", "2: 0.0 %\n", "3: 4.99 %\n", "4: 0.0 %\n", "5: 0.33 %\n", "6: 0.0 %\n", "7: 0.0 %\n", "8: 90.48 %\n", "9: 4.21 %\n" ] } ], "source": [ "prediction = classifier.predict(inputImage)\n", "print(\"Probability distribution of classification:\\n\")\n", "for i in range(10):\n", " print(f\"{i}: {prettyPerc(prediction[i][0])} %\")" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted digit: 8\n", "Label: 3\n" ] } ], "source": [ "print(f\"Predicted digit: {np.argmax(prediction)}\")\n", "print(f\"Label: {np.argmax(inputLabel)}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Compute accuracy" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "testDataset = Dataset(name = \"mnist\", train_size = 60000, test_size = 10000, batch_size = 10000)\n", "sample = next(testDataset.batches())" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# forward propagate the test data\n", "_ = classifier.feedforward(sample[1][0])" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy on test set: 97.79 %\n" ] } ], "source": [ "print(f\"Accuracy on test set: {round(classifier.getAccuracy(sample[1][1]),2)} %\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " Note that our approach is very simple, but we still get an about human level performance in terns of prediction accuracy." 
] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 4 }