Building a Basic Model for Understanding ML
Santosh Premi Adhikari
Posted on August 31, 2024
In this post we'll cover:
- A simple neural network model
- Training the model and saving it (.pth)
- Loading the model and using it for prediction
For demonstration we'll use the classic MNIST dataset of handwritten digits (0 to 9). It is small by modern standards: 60,000 training images and 10,000 test images, each a 28×28 grayscale picture.
Step 1: Import Libraries and Define the Model
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Define a simple neural network: 784 inputs -> 128 hidden units -> 10 outputs
class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)  # flattened image -> hidden layer
        self.fc2 = nn.Linear(128, 10)       # hidden layer -> one score per digit

    def forward(self, x):
        x = x.view(-1, 28 * 28)             # flatten (N, 1, 28, 28) to (N, 784)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)                     # raw logits; CrossEntropyLoss applies log-softmax internally
        return x

# Instantiate the model, define the loss function and optimizer
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
```
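Before training, a quick sanity check on the architecture can help: pass a random batch shaped like MNIST images through the model and count its trainable parameters. This is a minimal sketch; the class is repeated so the snippet runs on its own, and the dummy batch is random noise rather than real digits.

```python
import torch
import torch.nn as nn


class SimpleNN(nn.Module):
    """Same architecture as in Step 1: 784 -> 128 -> 10."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)        # flatten (N, 1, 28, 28) -> (N, 784)
        x = torch.relu(self.fc1(x))    # hidden layer with ReLU
        return self.fc2(x)             # raw scores (logits) for 10 classes


model = SimpleNN()

# A random batch shaped like MNIST images: 4 images, 1 channel, 28x28 pixels
dummy = torch.randn(4, 1, 28, 28)
logits = model(dummy)
print(logits.shape)  # torch.Size([4, 10])

# Trainable parameters: weights + biases of both layers
num_params = sum(p.numel() for p in model.parameters())
print(num_params)  # 784*128 + 128 + 128*10 + 10 = 101770
```

Most of the parameters live in the first layer, because every one of the 784 pixels connects to each of the 128 hidden units.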
Step 2: Load the Dataset and Train the Model
```python
# Load the MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
```
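`transforms.ToTensor()` converts 8-bit pixel values (0 to 255) into floats in [0, 1], and `Normalize((0.5,), (0.5,))` then computes `(x - 0.5) / 0.5` for the single grayscale channel, so the network sees inputs in [-1, 1]. The arithmetic can be checked with plain tensor operations; this sketch mimics the two transforms rather than calling them.

```python
import torch

# Three example pixel values: black, mid-gray, white
pixels = torch.tensor([0, 128, 255], dtype=torch.uint8)

# What ToTensor() does to 8-bit image data: scale into [0, 1]
scaled = pixels.float() / 255.0

# What Normalize((0.5,), (0.5,)) does per channel: (x - mean) / std
mean, std = 0.5, 0.5
normalized = (scaled - mean) / std
print(normalized)  # black maps to -1, white to 1, mid-gray to roughly 0
```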
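```python
# Train the model
model.train()  # make sure layers are in training mode
for epoch in range(1):  # train for 1 epoch for simplicity
    running_loss = 0.0
    for images, labels in trainloader:
        optimizer.zero_grad()              # clear gradients from the previous batch
        outputs = model(images)            # forward pass
        loss = criterion(outputs, labels)  # compare predictions with the true labels
        loss.backward()                    # backpropagate
        optimizer.step()                   # update the weights
        running_loss += loss.item()
    print(f'Epoch {epoch + 1}: average loss {running_loss / len(trainloader):.4f}')

print('Training complete!')
```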
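The `optimizer.zero_grad()` call at the start of each iteration matters because PyTorch accumulates gradients: every `backward()` call adds into each parameter's `.grad` instead of overwriting it. A tiny sketch with a single scalar parameter (a toy example, not part of the MNIST model) makes this visible.

```python
import torch
import torch.optim as optim

w = torch.tensor(2.0, requires_grad=True)
optimizer = optim.SGD([w], lr=0.1)

# Two backward passes without clearing: d(3w)/dw = 3 is added twice
for _ in range(2):
    loss = 3 * w
    loss.backward()
accumulated = w.grad.item()   # 6.0: gradients from both backward() calls were summed

optimizer.zero_grad()         # resets w.grad (to None by default in PyTorch 2.x)
loss = 3 * w
loss.backward()
fresh = w.grad.item()         # 3.0: only the latest gradient
print(accumulated, fresh)     # 6.0 3.0
```

Without the reset, each batch in the training loop would be updated with the sum of all previous batches' gradients.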
Step 3: Save the Model
```python
# Save the model state dictionary
torch.save(model.state_dict(), 'simple_nn.pth')
print('Model saved!')
```
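`model.state_dict()` is an ordered dictionary that maps parameter names to tensors. It stores the learned weights but not the class code, which is why you load it into a model with the same layer names and shapes. The sketch below saves to an in-memory `io.BytesIO` buffer instead of `simple_nn.pth`, purely so it runs without touching disk, then checks that a restored copy produces identical outputs.

```python
import io

import torch
import torch.nn as nn


class SimpleNN(nn.Module):
    # Same layers as Step 1; the attribute names become state_dict keys
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        return self.fc2(torch.relu(self.fc1(x)))


model = SimpleNN()
state = model.state_dict()
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
# fc1.weight (128, 784)
# fc1.bias (128,)
# fc2.weight (10, 128)
# fc2.bias (10,)

# Round-trip through an in-memory buffer instead of a file on disk
buffer = io.BytesIO()
torch.save(state, buffer)
buffer.seek(0)

restored = SimpleNN()
restored.load_state_dict(torch.load(buffer, weights_only=True))
restored.eval()
model.eval()

# Same weights and same input should give exactly the same logits
x = torch.randn(2, 1, 28, 28)
with torch.no_grad():
    same = torch.equal(model(x), restored(x))
print(same)  # True
```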
Step 4: Load the Model and Make Predictions
```python
# Load the model state dictionary into a fresh instance of the same class
loaded_model = SimpleNN()
loaded_model.load_state_dict(torch.load('simple_nn.pth', weights_only=True))
loaded_model.eval()  # Set the model to evaluation mode

# Make a prediction on a single image
test_image, label = trainset[20]  # the image at index 20 (the 21st image) of the training set
test_image = test_image.unsqueeze(0)  # add a batch dimension: (1, 28, 28) -> (1, 1, 28, 28)

# Display the image
plt.imshow(test_image.squeeze().numpy(), cmap='gray')
plt.title(f'Actual Label: {label}')
plt.axis('off')
plt.show()

with torch.no_grad():  # gradients are not needed for inference
    output = loaded_model(test_image)
_, predicted = torch.max(output, 1)  # index of the highest score is the predicted digit
print('Predicted label:', predicted.item())
```
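Predicting on a single training image is a quick smoke test. To estimate how well the model generalizes, you would typically measure accuracy on the held-out MNIST test split, loaded with `torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)`. The helper below is a hypothetical sketch (the name `evaluate_accuracy` is not from the original post). It runs here on random tensors so it works without the dataset, which means the printed accuracy only shows that the code runs. It also shows how `softmax` turns the raw outputs into class probabilities.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def evaluate_accuracy(model, loader):
    """Hypothetical helper: fraction of samples in `loader` that `model` classifies correctly."""
    model.eval()                        # evaluation mode
    correct, total = 0, 0
    with torch.no_grad():               # no gradient tracking needed for inference
        for images, labels in loader:
            predicted = model(images).argmax(dim=1)  # index of the highest logit per image
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total


torch.manual_seed(0)
# Stand-ins for a trained model and the MNIST test set (random data, for illustration only)
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10))
images = torch.randn(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
loader = DataLoader(TensorDataset(images, labels), batch_size=32)

acc = evaluate_accuracy(model, loader)
print(f'Accuracy: {acc:.2%}')  # untrained model on random labels, so expect roughly chance level

# Class probabilities for one image: softmax turns logits into values that sum to 1
with torch.no_grad():
    probs = F.softmax(model(images[:1]), dim=1)
confidence, digit = probs.max(dim=1)
print(f'Predicted {digit.item()} with probability {confidence.item():.2f}')
```

With the real `loaded_model` and a `DataLoader` over the MNIST test split, the same helper gives an accuracy figure that reflects performance on digits the model never saw during training.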
Hope you found this post helpful and enjoyable.
Thank you!