Import Libraries#
import torch
import torch.nn as nn
import torch.nn.init as init
import torchvision
from torchvision import transforms
from torchvision import models
from torchvision import datasets
from torch.utils.data import DataLoader
Device Info#
# Select the compute device: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Bare expression so the notebook cell echoes the chosen device.
device
device(type='cpu')
Loading Dataset#
# Fetch (downloading into ./data if missing) the MNIST training and test
# splits, in that order; each split converts its images with ToTensor().
train_data, test_data = (
    datasets.MNIST(
        root='./data',
        train=is_train_split,
        download=True,
        transform=transforms.ToTensor(),
    )
    for is_train_split in (True, False)
)
100%|██████████| 9.91M/9.91M [00:00<00:00, 16.2MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 484kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.49MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.89MB/s]
Linear VAE Model (W/O using Conv Layers)#
class LinearVAE(nn.Module):
    """Variational autoencoder built only from fully connected layers.

    The encoder flattens its input and maps it to a 48-dimensional hidden
    representation; two separate heads turn that representation into the mean
    and the log-variance of a diagonal Gaussian over the latent space. The
    decoder maps a latent sample back to ``input_dim`` values.

    Args:
        input_dim: Number of values in one flattened input sample.
        latent_dim: Dimensionality of the latent space.
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(input_dim, 156),
            nn.Tanh(),
            nn.Linear(156, 48),
            nn.Tanh(),
        )
        # Head producing the mean of the approximate posterior.
        self.mean_fc = nn.Sequential(
            nn.Linear(48, 16),
            nn.Tanh(),
            nn.Linear(16, latent_dim),
        )
        # Head producing the log-variance of the approximate posterior.
        self.logvar_fc = nn.Sequential(
            nn.Linear(48, 16),
            nn.Tanh(),
            nn.Linear(16, latent_dim),
        )
        # The last layer has no activation, so reconstructions are unbounded.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 16),
            nn.Tanh(),
            nn.Linear(16, 48),
            nn.Tanh(),
            nn.Linear(48, 156),
            nn.Tanh(),
            nn.Linear(156, input_dim),
        )
        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-initialize every linear layer's weights and zero its bias."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # NOTE(review): the gain is computed for ReLU although the
                # network uses Tanh; kept as-is so initialization matches the
                # original training runs.
                init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
                if m.bias is not None:
                    init.constant_(m.bias, 0)  # Initialize biases to zero

    def encode(self, X):
        """Return ``(mu, logvar)`` of the approximate posterior for ``X``."""
        hidden = self.encoder(X)
        mu = self.mean_fc(hidden)
        logvar = self.logvar_fc(hidden)
        return mu, logvar

    def reparameterize(self, mu, sigma):
        """Draw a latent sample with the reparameterization trick.

        Args:
            mu: Posterior mean.
            sigma: Posterior log-variance (the name is kept for backward
                compatibility; it is the output of ``logvar_fc``).

        Returns:
            ``mu + eps * std`` with ``eps`` drawn from a standard normal.
        """
        std = torch.exp(0.5 * sigma)
        # The noise must be standard normal. Uniform noise (rand_like) lies in
        # [0, 1), which would bias every sample above mu and would not match
        # the Gaussian prior assumed by the KL term.
        epsilon = torch.randn_like(std)
        return mu + epsilon * std

    # Backward-compatible alias for the original (misspelled) method name.
    reparamterize = reparameterize

    def decode(self, latent_sample):
        """Map latent samples to flat reconstructions of size ``input_dim``."""
        return self.decoder(latent_sample)

    def forward(self, X):
        """Encode ``X``, sample a latent code and decode it.

        Returns:
            Tuple ``(mu, logvar, X_reconstructed)``. The reconstruction is
            reshaped to ``(batch, 1, 28, 28)``, so the decoder output must hold
            784 values per sample.
        """
        batch_size = X.shape[0]
        mu, logvar = self.encode(X)
        latent_sample = self.reparameterize(mu, logvar)
        X_reconstructed = self.decode(latent_sample)
        X_reconstructed = X_reconstructed.view(batch_size, 1, 28, 28)
        return mu, logvar, X_reconstructed

    def fit(self, epochs=10, train_dataset=None, batch_size=32, lr=0.0001,
            kl_weight=0.000001):
        """Train the model and print the averaged losses after every epoch.

        Args:
            epochs: Number of passes over the training data.
            train_dataset: Dataset yielding ``(image, label)`` pairs. When
                omitted, the module-level ``train_data`` is used.
            batch_size: Mini-batch size.
            lr: Adam learning rate.
            kl_weight: Weight of the KL term in the objective.
        """
        if train_dataset is None:
            train_dataset = train_data
        train_loader = DataLoader(batch_size=batch_size, shuffle=True, dataset=train_dataset)
        loss_fn = nn.MSELoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        # Send batches to wherever the model lives instead of relying on a
        # module-level ``device`` variable.
        model_device = next(self.parameters()).device
        self.train()
        for epoch in range(epochs):
            reconstruct_loss_epoch = 0
            kl_loss_epoch = 0
            total_loss_epoch = 0
            for X, _ in train_loader:
                X = X.to(model_device)
                mu, logvar, X_reconstructed = self(X)
                reconstruction_loss = loss_fn(X_reconstructed, X)
                # KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions
                # and averaged over the batch.
                kl_loss = torch.mean(0.5 * torch.sum(torch.exp(logvar) + mu**2 - 1 - logvar, dim=-1))
                loss = reconstruction_loss + kl_weight * kl_loss
                reconstruct_loss_epoch = reconstruct_loss_epoch + reconstruction_loss.item()
                kl_loss_epoch = kl_loss_epoch + kl_loss.item()
                # Track the objective that is actually optimized; adding the
                # unweighted KL term made the reported total diverge from it.
                total_loss_epoch = total_loss_epoch + loss.item()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss_epoch = total_loss_epoch/len(train_loader)
            reconstruct_loss_epoch = reconstruct_loss_epoch/len(train_loader)
            kl_loss_epoch = kl_loss_epoch/len(train_loader)
            print(f'Epoch:- {epoch},Total Loss:-{total_loss_epoch} , Reconstrucion Loss after Epoch:-{reconstruct_loss_epoch}, KL Loss after Epoch:-{kl_loss_epoch} ')

    def save_model(self, path='linear_vae.pth'):
        """Write this model's ``state_dict`` to ``path`` as a checkpoint dict."""
        checkpoint = {
            # Serialize this instance rather than a module-level variable.
            'model_state_dict': self.state_dict()
        }
        torch.save(checkpoint, path)
Training Loop#
# Build a VAE for 784-value inputs with a 2-dimensional latent space, place it
# on the selected device, then train it for 10 epochs.
vae_model = LinearVAE(input_dim=784, latent_dim=2).to(device)
vae_model.fit(epochs=10)
Epoch:- 0,Total Loss:-1.620043286939462 , Reconstrucion Loss after Epoch:-0.0916821592926979, KL Loss after Epoch:-1.5283611276467641
Epoch:- 1,Total Loss:-2.4466339523136615 , Reconstrucion Loss after Epoch:-0.05852550930778185, KL Loss after Epoch:-2.38810844300588
Epoch:- 2,Total Loss:-3.3052036174257595 , Reconstrucion Loss after Epoch:-0.05529772346417109, KL Loss after Epoch:-3.2499058939615884
Epoch:- 3,Total Loss:-4.257120461044709 , Reconstrucion Loss after Epoch:-0.05370230486591657, KL Loss after Epoch:-4.203418156178793
Epoch:- 4,Total Loss:-5.246161178302765 , Reconstrucion Loss after Epoch:-0.05191369269688924, KL Loss after Epoch:-5.194247485605875
Epoch:- 5,Total Loss:-6.198763153584798 , Reconstrucion Loss after Epoch:-0.05033046340942383, KL Loss after Epoch:-6.148432690175374
Epoch:- 6,Total Loss:-7.065699731987714 , Reconstrucion Loss after Epoch:-0.04910271024902662, KL Loss after Epoch:-7.016597021738688
Epoch:- 7,Total Loss:-7.844319595227639 , Reconstrucion Loss after Epoch:-0.048096689242124555, KL Loss after Epoch:-7.796222905985514
Epoch:- 8,Total Loss:-8.491182029853265 , Reconstrucion Loss after Epoch:-0.0473443729420503, KL Loss after Epoch:-8.443837656911214
Epoch:- 9,Total Loss:-8.97091471409003 , Reconstrucion Loss after Epoch:-0.04683623305161794, KL Loss after Epoch:-8.92407848103841
Original Image#
# Imported here because this cell runs before the cell that originally
# imported matplotlib; without it `plt` is undefined at this point.
import matplotlib.pyplot as plt

# Take the image tensor of test sample 223 and drop its leading axis so that
# imshow receives a 2-D array.
ori_img = test_data[223][0]
ori_img = ori_img.squeeze(0).numpy()
plt.imshow(ori_img)
<matplotlib.image.AxesImage at 0x781be1e880d0>

Reconstructed Image => Inference with Encoder + Latent + Decoder#
import matplotlib.pyplot as plt

# Reconstruct the same test sample (index 223) that is shown as the original.
img = test_data[223][0]
vae_model.eval()
# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    # The second return value is the log-variance head's output.
    mu, sigma, X_reconstructed = vae_model(img.unsqueeze(0).to(device))
# The model returns (1, 1, 28, 28); drop the batch and channel axes. The
# tensor is moved to host memory first because .numpy() only works on CPU
# tensors, so this also runs when `device` is CUDA.
X_reconstructed = X_reconstructed.squeeze(0, 1).cpu().numpy()
plt.imshow(X_reconstructed)
<matplotlib.image.AxesImage at 0x781be1d393d0>

Images Generated from Gaussian Distribution => Inference with Decoder#
vae_model.eval()
num_samples = 16
# Draw latent codes from a standard normal; the width follows the model's own
# latent size instead of a hard-coded 2.
latent_samples = torch.randn(num_samples, vae_model.latent_dim).to(device)
# Generate images using the decoder (inference only, so no autograd graph)
with torch.no_grad():
    generated_images = vae_model.decoder(latent_samples)
# Reshape the generated images to (batch_size, channels, height, width)
# Assuming grayscale images (1 channel) and 28x28 size
generated_images = generated_images.view(num_samples, 1, 28, 28)
# The decoder's last layer is linear, so values can fall outside [0, 1].
# imshow would clip them anyway (with a warning); clamp explicitly instead.
generated_images = generated_images.clamp(0, 1)
# Create a grid of images
grid = torchvision.utils.make_grid(generated_images, nrow=4, padding=2)  # Adjust nrow as needed
# Reorder the grid to (height, width, channels) and convert it to a NumPy
# array for displaying
grid_np = grid.permute(1, 2, 0).cpu().numpy()
# Display the grid of images. The grid is handed to imshow as RGB data, for
# which matplotlib ignores `cmap`, so none is passed.
plt.imshow(grid_np)
plt.axis('off')  # Hide axes
plt.show()
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-1.1588097..1.2557548].
