{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# HSIC Network - Models that train without back-prop\n", "This notebook explains how to use the HSIC Bottleneck paradigm to train a feed-forward neural network.\n", "For more information, refer to the paper: [HSIC Bottleneck](https://arxiv.org/abs/1908.01580)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Importing glow modules" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import glow\n", "from glow.layers import Dense, Dropout, Conv2d, Flatten, HSICoutput\n", "from glow.datasets import mnist, cifar10\n", "from glow.models import IBSequential, Sequential, HSICSequential, Network\n", "from glow.information_bottleneck.estimator import HSIC\n", "from glow.information_bottleneck import Estimator\n", "import torch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load dataset" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# hyperparameters\n", "batch_size = 64\n", "num_workers = 3\n", "validation_split = 0.2\n", "num_epochs = 3\n", "\n", "# load the dataset\n", "train_loader, val_loader, test_loader = mnist.load_data(\n", " batch_size=batch_size, num_workers=num_workers, validation_split=validation_split\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## HSIC Bottleneck-based Model\n", "This type of network uses a different training paradigm, described in the paper 'The HSIC Bottleneck: Deep Learning without Back-Propagation'."
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on CUDA enabled device !\n" ] } ], "source": [ "model = HSICSequential(input_shape=(1, 28, 28), gpu=True)\n", "model.add(Conv2d(filters=16, kernel_size=3, stride=1, padding=1, activation='relu'))\n", "model.add(Flatten())\n", "model.add(Dense(500, activation='relu'), HSIC(kernel='gaussian', gpu=True, sigma=5), regularize_coeff=100)\n", "model.add(Dense(200, activation='relu'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compile and Pre-Training Phase\n", "Compile the model with the HSIC IB-based loss objective, then train the network to obtain optimal intermediate representations. This step is called the pre-training phase." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\r", " 0%| | 0/750 [00:00