Training an MNIST classifier in PyGlow

This notebook shows how closely PyGlow's API mirrors the Keras API. The point of this style of API is to let someone just entering the field quickly experience the power of these tools rather than be overwhelmed by the complexity of the theory itself.

Importing glow modules

[1]:
from glow.layers import Dense, Dropout, Conv2d, Flatten
from glow.models import Sequential
from glow.datasets import mnist
import numpy as np

Load dataset

Currently PyGlow supports two standard datasets, mnist and cifar10, so you can use either of them for experimentation.

[2]:
# hyperparameters
batch_size = 64
num_workers = 3
validation_split = 0.2
num_epochs = 2

# load the dataset
train_loader, val_loader, test_loader = mnist.load_data(
    batch_size=batch_size, num_workers=num_workers, validation_split=validation_split
)
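To see where the batch counts in the training output below come from, note that MNIST ships 60,000 training images. A quick sketch of the arithmetic (assuming the loaders round the last partial batch up, as PyTorch-style `DataLoader`s do):

```python
import math

num_train_images = 60_000   # MNIST training set size
validation_split = 0.2
batch_size = 64

num_val = int(num_train_images * validation_split)   # 12,000 held out
num_train = num_train_images - num_val               # 48,000 remain

train_batches = math.ceil(num_train / batch_size)    # 48,000 / 64 = 750
val_batches = math.ceil(num_val / batch_size)        # ceil(12,000 / 64) = 188

print(train_batches, val_batches)
```

These match the `750` and `188` iteration counts shown by the progress bars during training.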

Keras-like Sequential model

We use the basic Keras-like Sequential model to build our network the 'Keras' way!

[6]:
model = Sequential(input_shape=(1, 28, 28), gpu=True)
model.add(Conv2d(filters=16, kernel_size=3, stride=1, padding=1, activation='relu'))
model.add(Flatten())
model.add(Dropout(0.4))
model.add(Dense(500, activation='relu'))
model.add(Dropout(0.4))
model.add(Dense(200, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(10, activation='softmax'))
Running on CUDA enabled device !
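It is worth checking how the shapes line up between the conv layer and the first Dense layer. With `kernel_size=3`, `stride=1`, `padding=1`, the 28x28 input keeps its spatial size, so `Flatten` produces 16 x 28 x 28 = 12544 features. A small sketch of that arithmetic (the `conv_out` helper here is just for illustration, not part of PyGlow):

```python
# Output spatial size of a conv layer:
# floor((in + 2*padding - kernel) / stride) + 1
def conv_out(size, kernel, stride, padding):
    return (size + 2 * padding - kernel) // stride + 1

h = conv_out(28, 3, 1, 1)       # padding=1 with a 3x3 kernel preserves 28
flat_features = 16 * h * h      # 16 feature maps, each h x h, flattened

print(h, flat_features)
```

This 12544-dimensional vector is what the first `Dense(500)` layer receives.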

Compile the model

[4]:
model.compile(optimizer='SGD', loss='cross_entropy', metrics=['accuracy'])
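The `cross_entropy` loss paired with the final softmax layer measures the mean negative log-probability assigned to the true class. A minimal numpy sketch of that computation (not PyGlow's internal implementation):

```python
import numpy as np

def softmax(logits):
    # Subtract the row max for numerical stability before exponentiating
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def cross_entropy(probs, labels):
    # Mean negative log-probability of the correct class
    return -np.log(probs[np.arange(len(labels)), labels]).mean()

logits = np.array([[2.0, 0.5, 0.1],
                   [0.2, 3.0, 0.3]])
labels = np.array([0, 1])
loss = cross_entropy(softmax(logits), labels)
```

The SGD optimizer then nudges the weights in the direction that reduces this loss, batch by batch.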

Train the model

[5]:
# training the model
model.fit_generator(train_loader, val_loader, num_epochs, show_plot=False)
Epoch 1/2
Training loop:
100%|██████████| 750/750 [00:34<00:00, 19.50it/s]
loss: 2.24 - acc: 0.39
Validation loop:
100%|██████████| 188/188 [00:02<00:00, 93.60it/s]
loss: 1.99 - acc: 0.50


Epoch 2/2
Training loop:
100%|██████████| 750/750 [00:32<00:00, 22.96it/s]
loss: 1.81 - acc: 0.83
Validation loop:
100%|██████████| 188/188 [00:01<00:00, 105.00it/s]


loss: 1.66 - acc: 0.84
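Under the hood, each iteration of the training loop performs a forward pass, computes the loss gradient, and takes an SGD step. The per-batch update can be sketched in numpy on a toy softmax classifier (shapes mirror the tutorial: 64 flattened 784-pixel images, 10 classes; this is an illustration, not PyGlow's code):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy batch of 64 flattened "images" and integer class labels
X = rng.normal(size=(64, 784))
y = rng.integers(0, 10, size=64)
W = np.zeros((784, 10))
b = np.zeros(10)
lr = 0.1

# Forward pass: logits -> stable softmax probabilities
logits = X @ W + b
z = logits - logits.max(axis=1, keepdims=True)
probs = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)

# Gradient of mean cross-entropy w.r.t. logits: (probs - one_hot) / batch
grad = probs.copy()
grad[np.arange(len(y)), y] -= 1.0
grad /= len(y)

# SGD update on weights and bias
W -= lr * (X.T @ grad)
b -= lr * grad.sum(axis=0)
```

`fit_generator` repeats this for every batch of `train_loader`, then runs a forward-only pass over `val_loader` to report the validation loss and accuracy shown above.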