{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Multi-Layer Perceptron: Classification of handwritten digits (MNIST)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# importing functions and classes from our framework\n", "from dataset import Dataset\n", "from nn import MLP\n", "from layers import Dense\n", "# other imports\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "np.random.seed(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Demo" ] }, { "cell_type": "markdown", "metadata": {}, "source": [
"In the following, we demonstrate how our Multi-Layer Perceptron can be used to classify handwritten digits.\n", "For that purpose, we load a pretrained model whose architecture is printed below. Note that, unlike the hidden layers, the output layer uses Softmax as its activation function. Combined with the cross-entropy loss function, this is a standard setup for classification tasks: Softmax turns the network outputs into a probability distribution over the available classes, and cross-entropy usually converges faster on such outputs than the Mean Squared Error, since its gradient with respect to the inputs of the Softmax is simply the predicted distribution minus the one-hot target."
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-------------------- MULTI LAYER PERCEPTRON (MLP) --------------------\n", "\n", "HIDDEN LAYERS = 1 \n", "TOTAL PARAMETERS = 84060 \n", "\n", " *** 1. Layer: *** \n", "------------------------\n", "DENSE 784 -> 100 [ReLU]\n", "------------------------\n", "Total parameters: 78500 \n", "---> WEIGHTS: (100, 784)\n", "---> BIASES: (100,)\n", "------------------------\n", "\n", " *** 2. Layer: *** \n", "-----------------------\n", "DENSE 100 -> 50 [ReLU]\n", "-----------------------\n", "Total parameters: 5050 \n", "---> WEIGHTS: (50, 100)\n", "---> BIASES: (50,)\n", "-----------------------\n", "\n", " *** 3. Layer: *** \n", "-------------------------\n", "DENSE 50 -> 10 [Softmax]\n", "-------------------------\n", "Total parameters: 510 \n", "---> WEIGHTS: (10, 50)\n", "---> BIASES: (10,)\n", "-------------------------\n", "\n", "----------------------------------------------------------------------\n", "\n" ] } ], "source": [ "classifier = MLP()\n", "classifier.load(\"mnist_classifier\") # classifier is saved in '/models/mnist_classifier'\n", "print(classifier)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now load a mini-batch of size 1 from the shuffled MNIST dataset and take the input image and its label from the test set." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "dataset = Dataset(name = \"mnist\", train_size = 60000, test_size = 10000, batch_size = 1)\n", "randomBatch = next(dataset.batches())\n", "inputImage = randomBatch[1][0]\n", "inputLabel = randomBatch[1][1]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The selected test image is shown below."
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAPeklEQVR4nO3df5BddXnH8c+T3+QHkPBjTZMUCCYiWAx1DVBSRWkxIBaoFmE6TFoZF0foyAx2iHQ6Mk6npkgARaSGQo0Wg0z52Q4WQkSZAA3Z0BgSgga2QRJDYowjgZKf+/SPPTgL7Hnucn9nn/drZufePc899zwc9pNz7/3ec77m7gIw9A1rdQMAmoOwA0kQdiAJwg4kQdiBJEY0c2OjbLSP0bhmbhJIZZde0x7fbQPVagq7mc2V9HVJwyX9i7sviB4/RuN0sp1RyyYBBFb4stJa1S/jzWy4pJslnSXpeEkXmdnx1T4fgMaq5T37bEnPu3uPu++RdKekc+vTFoB6qyXsUyS91O/3TcWyNzGzLjPrNrPuvdpdw+YA1KLhn8a7+yJ373T3zpEa3ejNAShRS9g3S5rW7/epxTIAbaiWsK+UNMPMjjGzUZIulPRAfdoCUG9VD725+z4zu1zSQ+obervd3dfVrTMAdVXTOLu7PyjpwTr1AqCB+LoskARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSdQ0iyuGvn0f/UBY3/iJkWF9+SevK60dOXxsuO6ne84M6z/bfmRYd7fS2hHfPihcd9RD3WH9QFRT2M1so6SdkvZL2ufunfVoCkD91ePI/hF3316H5wHQQLxnB5KoNewu6WEzW2VmXQM9wMy6zKzbzLr3aneNmwNQrVpfxs9x981mdqSkpWb2nLs/1v8B7r5I0iJJOtgmeY3bA1Clmo7s7r65uN0m6V5Js+vRFID6qzrsZjbOzCa8cV/SmZLW1qsxAPVVy8v4Dkn3mtkbz/N9d/+vunSFutl1Tvxi64JrfxjWPz3hG2F94rB4vLpX5fVexe/qlkx/KKwPm14+jl7p+ecf9cFw3XXbTwjrvmpdWG9HVYfd3Xskvb+OvQBoIIbegCQIO5AEYQeSIOxAEoQdSIJTXNuAjYj/NwyfNiWs/+9fltef/NzCcN2xNiqsS2Mq1Ku31/eH9VMXXhHWx23pjdf/26dKawvetTJcd+4/TQ3rI/4kLLcljuxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATj7G2g0jj6fcvvCevDVH6qZ6/icfS5688P61cdHZ8Ce8ZB1V9q7B+3x5epnnr/L8P6vp6NYf0JO6W8eF08zl7pv3uh4lNg2xFHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2NrD7qMPCejSOXskJ37s8rM+a8/OwXss4uiR95Jm/KK2NvnZiuO6InlU1bTtSyz49UHFkB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkGGdvAz2fHBnWK01tHI0Zr7v4m+G6r/bG4+j/sL0zrP/4qtPC+riHny4v9vaE6zZSpX06FFU8spvZ7Wa2zczW9ls2ycyWmtmG4jb+dgSAlhvMy/jvSJr7lmXzJS1z9xmSlhW/A2hjFcPu7o9J2vGWxedKWlzcXyzpvDr3BaDOqn3P3uHuW4r7L0vqKHugmXVJ6pKkMRpb5eYA1KrmT+Pd3aXyTzvcfZG7d7p750iNrnVzAKpUbdi3mtlkSSput9WvJQCNUG3YH5A0r7g/T9L99WkHQKNUfM9uZksknS7pcDP
bJOnLkhZIusvMLpH0oqQLGtnkUHfos/G/uY+fHY/DR7668eywvvW+3w/rHTc9EdZHKb7+eiMNP+E9Yf2qr/xb1c+9bOeBd134SiqG3d0vKimdUedeADQQX5cFkiDsQBKEHUiCsANJEHYgCU5xbQNH3PJkWP/qLSfW8OybwmpHhXo7e/HP4ktwf3zsb6t+7hMOivfLak2t+rlbhSM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBODsOWLuO7A3r0SW2f9P7erjuNxaWTzUtSYcp/m5EO+LIDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMM6OA9b0EzeH9Wha5pP//cpw3XffeuCNo1fCkR1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkmCcHS1jI+I/v62fmx3Wu4/7Zli/77WJpbXjbno5XHdfWD0wVTyym9ntZrbNzNb2W3aNmW02s9XFTzwJOICWG8zL+O9ImjvA8hvcfVbx82B92wJQbxXD7u6PSdrRhF4ANFAtH9BdbmZripf5pW+OzKzLzLrNrHuvdtewOQC1qDbst0g6VtIsSVskLSx7oLsvcvdOd+8cqdFVbg5AraoKu7tvdff97t4r6VZJ8cemAFquqrCb2eR+v54vaW3ZYwG0h4rj7Ga2RNLpkg43s02SvizpdDObJcklbZR0aQN7xBBVaRz9qS/dFNY/v3lOWF9z4/tLawf3/He47lBUMezuftEAi29rQC8AGoivywJJEHYgCcIOJEHYgSQIO5BEmlNcd3zm1LD++jmvVP3ca0+5I6zv93hq4UrufPWIsL7g2Y+V1vb99NBw3WNuei6s7/91badFRPt98RevD9d99PUJYf3xu08K61O+/0RYz4YjO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kYe7l09rW28E2yU+2M5q2vf5mrIyvknPD78Vjskt2dpTWfrHn8HDd/R7/mzph+K6w/jcTN4T1YbLSWjRtsST9z574OwD/1xvvt79+5JKw/pOzbiitTbB4v3z4xi+G9ckLGUd/qxW+TK/4jgH/IDiyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASQ+Z89g03nRzW75p8Y1if//JpYX1d13tLa75qXbhuJTY6Pm/7hx/8UFj/5R+PLa3tPfG1cN3//KNvhfWTRsXHg+fP+XZY79VBpbXj7rosXPfdjKPXFUd2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUhiyIyzj53yaly3UWH9npWdYX3mqqfecU+D5bt3h/Vhy1eH9anLy2vbu+Lr5X9r5ofD+tfetSKs1+LJTy0M63Nej89nP+bqJ+vZzpBX8chuZtPM7FEze9bM1pnZF4rlk8xsqZltKG4nNr5dANUazMv4fZKudPfjJZ0i6TIzO17SfEnL3H2GpGXF7wDaVMWwu/sWd3+6uL9T0npJUySdK2lx8bDFks5rVJMAaveO3rOb2dGSTpK0QlKHu28pSi9LGvAibWbWJalLksao/DvcABpr0J/Gm9l4SXdLusLd3zQLovddtXLAKxu6+yJ373T3zpGKL14IoHEGFXYzG6m+oN/h7vcUi7ea2eSiPlnStsa0CKAeKr6MNzOTdJuk9e7ef47dByTNk7SguL2/IR0O0mlTe8J6dLllSTp8xfB6tlNXu86ZHdY75r9QWnvwmJtr2vapqy8M64d9Kd5vr7znkNLavK/8R7ju+nlx72c+8tmwPuJHq8J6NoN5z36apIslPWNmbwz4Xq2+kN9lZpdIelHSBY1pEUA9VAy7uy+XSg+LrZnxAcA7xtdlgSQIO5AEYQeSIOxAEoQdSGLInOLaszOeNrnS1MXvu3RtWH9u1ymltfGb4lNUX/hU/M3BmX/wUlj/5+nxZbBnjhxTWrv+NzPCdf91ycfC+lELnw7rvbvi6abHrymv3bc07u36RfFgz/hj49OWD/tRWE6HIzuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJGF9F5lpjoNtkp9sjTl
Rzj5wQliffPMvwvqiaT+uetuVzpWvNMZfybo9+8L6n//k86W19/791nDdfS9tqqontKcVvkyv+I4B/yA5sgNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEkPmfHZftS6sb/34pLA+82uXhvWxh7xeWjv0BxPCdX87Pf439ZCe3rB+6OPxdwRmbC6/Pno8Qo9MOLIDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKDmZ99mqTvSuqQ5JIWufvXzewaSZ+V9KvioVe7+4ONarRW+3+9I6zP/Excr8X4GtdnrBz1MJgv1eyTdKW7P21mEyStMrOlRe0Gd7+uce0BqJfBzM++RdKW4v5OM1svaUqjGwNQX+/oPbuZHS3pJEkrikWXm9kaM7vdzCaWrNNlZt1m1r1X8TRJABpn0GE3s/GS7pZ0hbu/IukWScdKmqW+I//CgdZz90Xu3ununSMVz3kGoHEGFXYzG6m+oN/h7vdIkrtvdff97t4r6VZJsxvXJoBaVQy7mZmk2yStd/fr+y2f3O9h50uKp0EF0FKD+TT+NEkXS3rGzFYXy66WdJGZzVLfcNxGSfE5ogBaajCfxi+XBrwwetuOqQN4O75BByRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSMLcvXkbM/uVpBf7LTpc0vamNfDOtGtv7dqXRG/VqmdvR7n7EQMVmhr2t23crNvdO1vWQKBde2vXviR6q1azeuNlPJAEYQeSaHXYF7V4+5F27a1d+5LorVpN6a2l79kBNE+rj+wAmoSwA0m0JOxmNtfMfmZmz5vZ/Fb0UMbMNprZM2a22sy6W9zL7Wa2zczW9ls2ycyWmtmG4nbAOfZa1Ns1Zra52HerzezsFvU2zcweNbNnzWydmX2hWN7SfRf01ZT91vT37GY2XNLPJf2ppE2SVkq6yN2fbWojJcxso6ROd2/5FzDM7EOSXpX0XXd/X7HsWkk73H1B8Q/lRHe/qk16u0bSq62exruYrWhy/2nGJZ0n6a/Uwn0X9HWBmrDfWnFkny3peXfvcfc9ku6UdG4L+mh77v6YpB1vWXyupMXF/cXq+2NpupLe2oK7b3H3p4v7OyW9Mc14S/dd0FdTtCLsUyS91O/3TWqv+d5d0sNmtsrMulrdzAA63H1Lcf9lSR2tbGYAFafxbqa3TDPeNvuumunPa8UHdG83x93/UNJZki4rXq62Je97D9ZOY6eDmsa7WQaYZvx3Wrnvqp3+vFatCPtmSdP6/T61WNYW3H1zcbtN0r1qv6mot74xg25xu63F/fxOO03jPdA042qDfdfK6c9bEfaVkmaY2TFmNkrShZIeaEEfb2Nm44oPTmRm4ySdqfabivoBSfOK+/Mk3d/CXt6kXabxLptmXC3edy2f/tzdm/4j6Wz1fSL/gqS/a0UPJX1Nl/TT4mddq3uTtER9L+v2qu+zjUskHSZpmaQNkh6RNKmNevuepGckrVFfsCa3qLc56nuJvkbS6uLn7Fbvu6Cvpuw3vi4LJMEHdEAShB1IgrADSRB2IAnCDiRB2IEkCDuQxP8DH3F/xflsqvEAAAAASUVORK5CYII=\n", "text/plain": [ "