Building a Basic Model for Understanding ML

santoshpremi

Santosh Premi Adhikari

Posted on August 31, 2024

In this post we will build a basic model to understand the machine-learning workflow end to end:
  1. Define a simple neural network model.
  2. Train the model and save it (.pth).
  3. Load the model and use it for prediction.

For demonstration we'll use a small, classic dataset: MNIST, which contains 60,000 training and 10,000 test images of handwritten digits, each a 28×28 grayscale picture labeled 0 to 9.
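Each MNIST pixel is an 8-bit intensity from 0 (black background) to 255 (white stroke). The transform used in Step 2 first applies `ToTensor`, which divides by 255 to get values in [0, 1], then `Normalize((0.5,), (0.5,))`, which computes (x - 0.5) / 0.5. Here is a minimal plain-Python sketch of that arithmetic for a single pixel; the helper `preprocess_pixel` is hypothetical and only for illustration, since torchvision applies the same formula to whole tensors.

```python
def preprocess_pixel(value: int, mean: float = 0.5, std: float = 0.5) -> float:
    """Mimic ToTensor() followed by Normalize((mean,), (std,)) for one pixel.

    Hypothetical helper for illustration only.
    """
    if not 0 <= value <= 255:
        raise ValueError("MNIST pixels are 8-bit values in [0, 255]")
    scaled = value / 255.0          # ToTensor: [0, 255] -> [0.0, 1.0]
    return (scaled - mean) / std    # Normalize: [0.0, 1.0] -> [-1.0, 1.0]


# Black background, mid-gray, and full-white stroke pixels.
for pixel in (0, 128, 255):
    print(pixel, "->", round(preprocess_pixel(pixel), 4))
```

So a black pixel becomes -1.0 and a white one 1.0: the inputs end up centered around zero, which tends to make training with gradient-based optimizers better behaved.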

Step 1: Import Libraries and Define the Model

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Define a simple neural network
class SimpleNN(nn.Module):
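    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 128)  # 784 input pixels -> 128 hidden units
        self.fc2 = nn.Linear(128, 10)     # 128 hidden units -> 10 digit classes

    def forward(self, x):
        x = x.view(-1, 28*28)        # flatten each 1x28x28 image into a 784-value vector
        x = torch.relu(self.fc1(x))  # hidden layer with ReLU activation
        x = self.fc2(x)              # raw scores (logits); CrossEntropyLoss applies log-softmax itself
        return x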

# Instantiate the model, define loss function and optimizer
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

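`nn.Linear(in_features, out_features)` stores a weight matrix of shape (out_features, in_features) plus a bias, and computes y = xWᵀ + b. So `forward` flattens each image, applies `fc1`, zeroes negative values with ReLU, then applies `fc2` to get 10 raw scores, one per digit. The sketch below reproduces that computation in plain Python on toy sizes (a 2×2 "image", 3 hidden units, 2 outputs) with made-up weights; it only illustrates the math and is not how PyTorch implements it internally.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # row i holds the weights of output unit i


def linear(x: Vector, weight: Matrix, bias: Vector) -> Vector:
    """y = W x + b, with W shaped (out_features, in_features) like nn.Linear."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def relu(x: Vector) -> Vector:
    """Element-wise max(0, v), the same rule torch.relu applies."""
    return [max(0.0, v) for v in x]


def forward(image: Matrix, w1: Matrix, b1: Vector, w2: Matrix, b2: Vector) -> Vector:
    x = [pixel for row in image for pixel in row]  # flatten, like x.view(-1, 28*28)
    hidden = relu(linear(x, w1, b1))               # like torch.relu(self.fc1(x))
    return linear(hidden, w2, b2)                  # like self.fc2(x): raw logits


# Toy 2x2 "image" and hypothetical weights (4 inputs -> 3 hidden -> 2 outputs).
image = [[1.0, 0.0],
         [0.5, 2.0]]
w1 = [[1.0, 0.0, 0.0, 0.0],
      [0.0, 1.0, -1.0, 0.0],
      [0.0, 0.0, 0.0, 1.0]]
b1 = [0.0, 0.0, -1.0]
w2 = [[1.0, 1.0, 1.0],
      [2.0, 0.0, -1.0]]
b2 = [0.0, 0.5]
print(forward(image, w1, b1, w2, b2))
```

In the real model the same steps run with 784 inputs, 128 hidden units and 10 outputs, on a whole batch at once.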

Step 2: Load the Dataset and Train the Model

# Load the MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)


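# Train the model
model.train()  # make sure the model is in training mode
for epoch in range(1):  # Train for 1 epoch for simplicity
    running_loss = 0.0
    for images, labels in trainloader:
        optimizer.zero_grad()              # clear gradients from the previous step
        outputs = model(images)            # forward pass: logits of shape (batch, 10)
        loss = criterion(outputs, labels)  # mean cross-entropy over the batch
        loss.backward()                    # backpropagate to compute gradients
        optimizer.step()                   # update the weights with Adam
        running_loss += loss.item()
    print(f'Epoch {epoch + 1}, average loss: {running_loss / len(trainloader):.4f}')

print('Training complete!')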
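`nn.CrossEntropyLoss` takes the raw logits from `fc2` directly (which is why the model has no softmax layer) and, for each sample, computes the negative log of the softmax probability given to the correct class; by default it averages that value over the batch. A plain-Python sketch of the per-sample formula, using the log-sum-exp trick for numerical stability (the helper name `cross_entropy` is hypothetical):

```python
import math
from typing import List


def cross_entropy(logits: List[float], target: int) -> float:
    """-log(softmax(logits)[target]) for one sample.

    Subtracting the max logit before exponentiating (log-sum-exp trick)
    avoids overflow without changing the result.
    """
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


# Three-class toy example: the model favors class 0.
logits = [2.0, 1.0, 0.1]
print(cross_entropy(logits, target=0))  # small loss: correct class has the top score
print(cross_entropy(logits, target=2))  # larger loss: correct class scored lowest
```

The loss shrinks toward 0 as the correct class's score pulls ahead of the others, which is exactly what `loss.backward()` and `optimizer.step()` push the weights toward.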

Step 3: Save the Model

# Save the model state dictionary
torch.save(model.state_dict(), 'simple_nn.pth')
print('Model saved!')
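`model.state_dict()` is an ordered mapping from parameter names to tensors; for `SimpleNN` it holds `fc1.weight`, `fc1.bias`, `fc2.weight` and `fc2.bias`, and that is all `simple_nn.pth` contains. The sketch below tallies how many values that is, assuming the default float32 parameters (4 bytes each). The shapes are written out by hand rather than read from the model, and the file on disk will be somewhat larger because `torch.save` adds serialization metadata.

```python
from math import prod

# Parameter shapes of SimpleNN as nn.Linear stores them: weight is (out, in).
STATE_DICT_SHAPES = {
    "fc1.weight": (128, 28 * 28),
    "fc1.bias": (128,),
    "fc2.weight": (10, 128),
    "fc2.bias": (10,),
}
BYTES_PER_FLOAT32 = 4

total_params = 0
for name, shape in STATE_DICT_SHAPES.items():
    count = prod(shape)
    total_params += count
    print(f"{name:<11} shape={shape!s:<11} values={count}")

print("total parameters:", total_params)
print(f"raw float32 payload: {total_params * BYTES_PER_FLOAT32 / 1024:.1f} KiB")
```

Because only these tensors are saved (not the class definition), Step 4 has to recreate `SimpleNN()` before calling `load_state_dict`.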

Step 4: Load the Model and Make Predictions

# Load the model state dictionary
loaded_model = SimpleNN()
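# weights_only=True (PyTorch 1.13+) restricts unpickling to tensors and basic Python types, which is all a state_dict needs
loaded_model.load_state_dict(torch.load('simple_nn.pth', weights_only=True))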
loaded_model.eval()         # Set the model to evaluation mode

# Make a prediction on a single image
test_image, label = trainset[20]   # Use the image at index 20 (the 21st training image); the model saw it during training, so a test-set image is a fairer check
test_image = test_image.unsqueeze(0)  # Add a batch dimension


# Display the image
plt.imshow(test_image.squeeze(), cmap='gray')
plt.title(f'Actual Label: {label}')
plt.axis('off')
plt.show()

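with torch.no_grad():                    # no gradients are needed for inference
    output = loaded_model(test_image)    # logits of shape (1, 10)
    _, predicted = torch.max(output, 1)  # index of the highest logit = predicted digit

print('Predicted label:', predicted.item())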

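`torch.max(output, 1)` returns both the largest logit in each row and its index; the index is the predicted digit. Because softmax preserves ordering, the argmax of the logits is the same class as the argmax of the probabilities, so softmax is only needed if you also want a confidence score. A plain-Python sketch of both steps on one row of hypothetical logits (the helpers `softmax` and `predict` are illustrative, not PyTorch functions):

```python
import math
from typing import List, Tuple


def softmax(logits: List[float]) -> List[float]:
    """Turn raw scores into probabilities that sum to 1 (max-shifted for stability)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict(logits: List[float]) -> Tuple[int, float]:
    """Return (predicted class, its probability).

    The class index matches what torch.max(output, 1) returns for this row;
    the probability comes from softmax.
    """
    probs = softmax(logits)
    best = max(range(len(logits)), key=lambda i: logits[i])  # argmax over classes
    return best, probs[best]


# Hypothetical logits for the 10 digit classes; index 7 has the highest score.
logits = [-1.2, 0.3, 0.8, -0.5, 0.1, -2.0, 0.0, 4.1, 1.5, -0.7]
digit, confidence = predict(logits)
print(f"Predicted label: {digit} (probability {confidence:.3f})")
```

In PyTorch the equivalent for the batch above would be `torch.softmax(output, dim=1)` for the probabilities and `output.argmax(dim=1)` for the labels.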

Result

Hope you found this post helpful and enjoyable.
Thank you!
