import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
# Device configuration: run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Architecture
Unlike LeNet-5, the famous AlexNet network operated on 3-channel images of size 224x224x3 (the code below resizes its inputs to 227x227, the size for which the convolution and pooling arithmetic produces the 9216 features expected by the first fully connected layer). It used max pooling for subsampling together with ReLU activations. The kernels used for convolutions were 11x11, 5x5, or 3x3, while the kernels used for max pooling were 3x3 in size. It classified images into 1000 classes. It also utilized multiple GPUs.
Data Loading
Dataset
Let’s start by loading and then pre-processing the data. For our purposes, we will be using the CIFAR10 dataset. The dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
Classes in the dataset are completely mutually exclusive; there is no overlap between them.
importing the libraries
Let’s start by importing the required libraries along with defining a variable device, so that the Notebook knows to use a GPU to train the model if it’s available.
Loading the dataset
Using torchvision (a helper library for computer vision tasks), we will load our dataset. This library has some helper functions that make pre-processing pretty easy and straightforward. Let’s define the functions get_train_valid_loader and get_test_loader, and then call them to load in and process our CIFAR-10 data.
def get_train_valid_loader(data_dir,
                           batch_size,
                           augment,
                           random_seed,
                           valid_size=0.1,
                           shuffle=True):
    """Build the CIFAR-10 training and validation data loaders.

    Both loaders read the CIFAR-10 *training* split (``train=True``); the
    split is divided into disjoint index sets, one per loader.

    Args:
        data_dir: Directory passed as ``root`` to ``datasets.CIFAR10``; the
            dataset is downloaded there if missing (``download=True``).
        batch_size: Batch size used by both loaders.
        augment: If true, apply random crop and horizontal flip to the
            training images before resizing them.
        random_seed: Seed given to ``np.random.seed`` before the indices
            are shuffled (only used when ``shuffle`` is true).
        valid_size: Fraction of the training split reserved for
            validation; must lie in ``[0, 1]``.
        shuffle: Whether to shuffle the indices before splitting them.

    Returns:
        A ``(train_loader, valid_loader)`` tuple.

    Raises:
        ValueError: If ``valid_size`` is outside ``[0, 1]``.
    """
    if not 0 <= valid_size <= 1:
        raise ValueError("valid_size should be in the range [0, 1].")

    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
    )

    # define transforms
    valid_transform = transforms.Compose([
        transforms.Resize((227, 227)),
        transforms.ToTensor(),
        normalize,
    ])
    if augment:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            # Augmented images must end up the same size as the validation
            # images (227x227); without this resize they stay 32x32.
            transforms.Resize((227, 227)),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.Resize((227, 227)),
            transforms.ToTensor(),
            normalize,
        ])

    # load the dataset (same underlying split, different transforms)
    train_dataset = datasets.CIFAR10(
        root=data_dir, train=True,
        download=True, transform=train_transform,
    )
    valid_dataset = datasets.CIFAR10(
        root=data_dir, train=True,
        download=True, transform=valid_transform,
    )

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    # The first `split` indices go to validation, the rest to training.
    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=batch_size, sampler=valid_sampler)

    return (train_loader, valid_loader)
def get_test_loader(data_dir,
                    batch_size,
                    shuffle=True):
    """Build the CIFAR-10 test data loader.

    Args:
        data_dir: Directory passed as ``root`` to ``datasets.CIFAR10``; the
            dataset is downloaded there if missing (``download=True``).
        batch_size: Batch size used by the loader.
        shuffle: Forwarded to ``DataLoader(shuffle=...)``.

    Returns:
        A ``DataLoader`` over the CIFAR-10 test split (``train=False``).
    """
    # Use the same CIFAR-10 channel statistics as the training/validation
    # pipeline in this file, so the test images are normalized the same way
    # as the images the model is trained on.
    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
    )

    # define transform
    transform = transforms.Compose([
        transforms.Resize((227, 227)),
        transforms.ToTensor(),
        normalize,
    ])

    dataset = datasets.CIFAR10(
        root=data_dir, train=False,
        download=True, transform=transform,
    )

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle
    )

    return data_loader
# CIFAR10 dataset: build the training/validation loaders (no augmentation,
# fixed seed for the split) and the test loader, all with batches of 64.
train_loader, valid_loader = get_train_valid_loader(data_dir='./data',
                                                    batch_size=64,
                                                    augment=False,
                                                    random_seed=1)

test_loader = get_test_loader(data_dir='./data',
                              batch_size=64)
AlexNet from Scratch
class AlexNet(nn.Module):
    """AlexNet-style convolutional network.

    Five convolutional stages (Conv2d -> BatchNorm2d -> ReLU, with 3x3
    stride-2 max pooling after stages 1, 2 and 5) followed by three fully
    connected layers. The first fully connected layer takes 9216 features
    (256 channels x 6 x 6), which is what the convolutional stack produces
    for 3x227x227 inputs.

    Args:
        num_classes: Number of output logits of the final linear layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2))
        self.layer3 = nn.Sequential(
            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU())
        self.layer4 = nn.Sequential(
            nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU())
        self.layer5 = nn.Sequential(
            nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2))
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(9216, 4096),
            nn.ReLU())
        self.fc1 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU())
        self.fc2 = nn.Sequential(
            nn.Linear(4096, num_classes))

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for ``x``."""
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        # Flatten everything but the batch dimension for the linear layers.
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        out = self.fc1(out)
        out = self.fc2(out)
        return out
Setting Hyperparameters
num_classes = 10
num_epochs = 20
# NOTE: the data loaders earlier in this file are created with a literal
# batch size of 64; this variable is not passed to them.
batch_size = 64
learning_rate = 0.005

model = AlexNet(num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=0.005, momentum=0.9)

# Train the model
total_step = len(train_loader)
Training
# Number of batches per epoch, used only for the progress message.
total_step = len(train_loader)

for epoch in range(num_epochs):
    # Dropout and BatchNorm must be in training mode while optimizing;
    # validation below switches the model to eval mode.
    model.train()
    for i, (images, labels) in enumerate(train_loader):
        # Move tensors to the configured device
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Report the loss of the last batch of the epoch.
    print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
          .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

    # Validation
    # Eval mode disables Dropout and makes BatchNorm use its running
    # statistics instead of updating them from validation batches.
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in valid_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            # Drop references so the batch tensors can be freed early.
            del images, labels, outputs

        # Report the number of images actually evaluated rather than a
        # hard-coded count.
        print('Accuracy of the network on the {} validation images: {} %'.format(total, 100 * correct / total))