HSIC Network - Models that train without back-prop

In this notebook we explain how to use the HSIC Bottleneck paradigm to train a feed-forward neural network. For more information, refer to the paper 'The HSIC Bottleneck: Deep Learning without Back-Propagation'.

Importing glow modules

[1]:
import glow
from glow.layers import Dense, Dropout, Conv2d, Flatten, HSICoutput
from glow.datasets import mnist, cifar10
from glow.models import IBSequential, Sequential, HSICSequential, Network
from glow.information_bottleneck.estimator import HSIC
from glow.information_bottleneck import Estimator
import torch

Load dataset

[2]:
# hyperparameters
batch_size = 64
num_workers = 3
validation_split = 0.2
num_epochs = 3

# load the dataset
train_loader, val_loader, test_loader = mnist.load_data(
    batch_size=batch_size, num_workers=num_workers, validation_split=validation_split
)
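The progress bars further down show 750 iterations per pre-training epoch. The arithmetic below shows where that number could come from. It assumes that `mnist.load_data` takes the validation split out of the standard 60,000-image MNIST training set. That assumption fits the logs, but this notebook does not confirm it. `batches_per_epoch` is a hypothetical helper, not part of glow.

```python
import math

def batches_per_epoch(num_train_images, validation_split, batch_size):
    # hypothetical helper: assumes the validation set is carved out of the training set
    num_val = round(num_train_images * validation_split)
    num_train = num_train_images - num_val
    # the last batch may be partial, so round up
    return math.ceil(num_train / batch_size)

# 60,000 MNIST training images, 20% held out, batches of 64
print(batches_per_epoch(60_000, 0.2, 64))  # 750
```

With these settings 48,000 images remain for training, and 48,000 / 64 = 750. That matches the `750/750` shown by each pre-training epoch.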

HSIC Bottleneck-based Model

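This type of network uses a different training paradigm, described in the paper 'The HSIC Bottleneck: Deep Learning without Back-Propagation'. The network does not back-propagate a single loss from the output through every layer. Instead, each hidden layer is trained on its own objective, built from the Hilbert-Schmidt Independence Criterion (HSIC), a kernel-based measure of statistical dependence.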
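To make the HSIC part concrete, here is a minimal NumPy sketch of the biased empirical HSIC estimator with Gaussian kernels, tr(K H L H) / (m - 1)^2, where H is the centering matrix. It is an illustration under stated assumptions and not glow's `HSIC` class. The function names and the `sigma_x` / `sigma_y` bandwidth parameters are hypothetical, and glow's `sigma` argument may be defined differently.

```python
import numpy as np

def gaussian_kernel(x, sigma):
    # x: (m, d) samples -> (m, m) Gaussian Gram matrix exp(-||xi - xj||^2 / (2 sigma^2))
    sq = np.sum(x ** 2, axis=1)
    d2 = np.maximum(sq[:, None] + sq[None, :] - 2.0 * x @ x.T, 0.0)
    return np.exp(-d2 / (2.0 * sigma ** 2))

def hsic(x, y, sigma_x=1.0, sigma_y=1.0):
    # biased empirical HSIC: tr(K H L H) / (m - 1)^2
    m = x.shape[0]
    k = gaussian_kernel(x, sigma_x)
    l = gaussian_kernel(y, sigma_y)
    h = np.eye(m) - np.ones((m, m)) / m  # centering matrix H
    return np.trace(k @ h @ l @ h) / (m - 1) ** 2

rng = np.random.default_rng(0)
x = rng.normal(size=(200, 2))
dependent = x + 0.1 * rng.normal(size=(200, 2))  # nearly a copy of x
independent = rng.normal(size=(200, 2))          # drawn separately from x
print(hsic(x, dependent), hsic(x, independent))
```

The dependent pair scores much higher than the independent one. The HSIC bottleneck relies on exactly this property. It pushes each hidden representation to have low HSIC with the input and high HSIC with the labels.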

[5]:
model = HSICSequential(input_shape=(1, 28, 28), gpu=True)
model.add(Conv2d(filters=16, kernel_size=3, stride=1, padding=1, activation='relu'))
model.add(Flatten())
model.add(Dense(500, activation='relu'), HSIC(kernel='gaussian', gpu=True, sigma=5), regularize_coeff=100)
model.add(Dense(200, activation='relu'))
Running on CUDA enabled device !
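The shape arithmetic below shows what the `Flatten` layer passes to `Dense(500)` in the model above. It assumes that glow's `Conv2d` follows the usual PyTorch convention, where `padding` is added on each side of the image, and that `Dense` infers its input size from the flattened tensor. Both assumptions are not confirmed here. `conv2d_output_size` is a hypothetical helper.

```python
def conv2d_output_size(size, kernel_size, stride=1, padding=0, dilation=1):
    # standard convolution output-size formula (same as torch.nn.Conv2d)
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

# Conv2d(filters=16, kernel_size=3, stride=1, padding=1) on a 1x28x28 MNIST image
h = conv2d_output_size(28, kernel_size=3, stride=1, padding=1)
w = conv2d_output_size(28, kernel_size=3, stride=1, padding=1)
filters = 16

flat_features = filters * h * w           # input size of Dense(500) after Flatten
dense_params = flat_features * 500 + 500  # weights + biases of Dense(500)
print(h, w, flat_features, dense_params)  # 28 28 12544 6272500
```

With `padding=1` and `kernel_size=3`, the spatial size stays at 28x28. As a result, most of the model's parameters sit in the first `Dense` layer.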

Compile and Pre-Training Phase

Compile the model with the HSIC bottleneck (IB-based) loss objective, then train the network to learn its intermediate (hidden-layer) representations. This step is called the pre-training phase.
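The sketch below shows what one layer-wise pre-training step could look like, written in plain PyTorch. It is not glow's `pre_training_loop`. Each layer i minimises HSIC(X, Z_i) - beta * HSIC(Y, Z_i) with its own optimizer, and its input is detached so that no gradient flows back into earlier layers. Several details are assumptions for illustration only:
- It uses a small MLP instead of the notebook's `Conv2d` model.
- It treats `regularize_coeff` as the paper's beta.
- It applies a single Gaussian `sigma` to inputs, representations and one-hot labels.
- `hsic_pretrain_step` and the other names are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def gaussian_gram(x, sigma):
    # Gaussian kernel matrix from squared distances (no sqrt, so gradients stay finite)
    x = x.flatten(1)
    sq = (x * x).sum(dim=1)
    d2 = (sq[:, None] + sq[None, :] - 2.0 * x @ x.T).clamp_min(0.0)
    return torch.exp(-d2 / (2.0 * sigma ** 2))

def hsic(k, l):
    # biased empirical HSIC from two Gram matrices: tr(K H L H) / (m - 1)^2
    m = k.shape[0]
    h = torch.eye(m, device=k.device) - 1.0 / m
    return torch.trace(k @ h @ l @ h) / (m - 1) ** 2

def hsic_pretrain_step(layers, optimizers, x, y, num_classes=10, beta=100.0, sigma=10.0):
    # assumption: beta and sigma mirror regularize_coeff=100 and sigma=10 from compile()
    kx = gaussian_gram(x, sigma)
    ky = gaussian_gram(F.one_hot(y, num_classes).float(), sigma)
    losses = []
    h = x
    for layer, opt in zip(layers, optimizers):
        z = layer(h.detach())  # detach: this layer's loss never reaches earlier layers
        kz = gaussian_gram(z, sigma)
        loss = hsic(kx, kz) - beta * hsic(ky, kz)  # HSIC bottleneck objective for layer i
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
        h = z
    return losses

torch.manual_seed(0)
layers = [
    nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 500), nn.ReLU()),
    nn.Sequential(nn.Linear(500, 200), nn.ReLU()),
]
optimizers = [torch.optim.SGD(layer.parameters(), lr=0.01) for layer in layers]
x = torch.randn(64, 1, 28, 28)   # stand-in for one MNIST batch
y = torch.randint(0, 10, (64,))
print(hsic_pretrain_step(layers, optimizers, x, y))
```

Because every layer has its own loss and optimizer, the layers never need a global backward pass through the whole network. That is the sense in which the model trains "without back-prop".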

[4]:
model.compile(loss_criterion=HSIC(kernel='gaussian', gpu=True, sigma=10), optimizer='SGD', regularize_coeff=100)
model.pre_training_loop(num_epochs, train_loader, val_loader)
Pre-Train-Epoch 1/3
100%|██████████| 750/750 [05:34<00:00,  2.20it/s]
Pre-Train-Epoch 2/3
100%|██████████| 750/750 [05:23<00:00,  2.03it/s]
Pre-Train-Epoch 3/3
100%|██████████| 750/750 [05:29<00:00,  2.35it/s]