Convert my PyTorch model to PyTorch Lightning
Quoc Bao
Posted on June 29, 2022
Hello, everybody! Today I am going to show you how to convert my model from PyTorch to PyTorch Lightning. PyTorch Lightning is a lightweight deep learning framework built on top of PyTorch. It removes a lot of boilerplate code (standard code that appears in almost every deep learning pipeline, such as the training loop) and adds many hooks that let you intervene in training at specific points.
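To make "boilerplate" concrete, here is a minimal sketch of the hand-written training loop that plain PyTorch requires. The synthetic data, the Adam optimizer and the learning rate are my own assumptions for illustration; the network mirrors the Linear, sigmoid, Linear model used later in this post.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic data (hypothetical): 256 samples, 10 features, 3 learnable classes
X = torch.randn(256, 10)
y = X[:, :3].argmax(dim=1)
loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

# Same architecture as the NeuralNet below: Linear -> sigmoid -> Linear
model = nn.Sequential(nn.Linear(10, 50), nn.Sigmoid(), nn.Linear(50, 3))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

epoch_losses = []
for epoch in range(5):
    model.train()
    running = 0.0
    for xb, yb in loader:
        optimizer.zero_grad()          # clear gradients from the previous step
        loss = loss_fn(model(xb), yb)  # forward pass and loss
        loss.backward()                # backpropagation
        optimizer.step()               # parameter update
        running += loss.item() * xb.size(0)
    epoch_losses.append(running / len(loader.dataset))
    print(f"epoch {epoch}: mean loss {epoch_losses[-1]:.4f}")
```

Every line inside the nested loop is the kind of code Lightning's `Trainer` runs for you once the model is a `LightningModule`.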
First, I install the library and import it.
```
pip install pytorch-lightning
```

```python
import pytorch_lightning as pl
```
A LightningModule is a subclass of nn.Module, so it is written the same way: layers are declared in `__init__`, and the `forward` method is defined exactly as in a plain PyTorch class.
```python
import torch
import torch.nn as nn
import pytorch_lightning as pl

# An nn class can be converted to a pl class by changing its base class
# from nn.Module to pl.LightningModule:
class NeuralNet(nn.Module):  # --> class NeuralNet(pl.LightningModule):
    def __init__(self, input_size, num_classes):
        super(NeuralNet, self).__init__()
        # Layers from torch.nn stay unchanged after the conversion!
        self.fc1 = nn.Linear(input_size, 50)
        self.fc2 = nn.Linear(50, num_classes)

    # forward() is kept as-is as well
    def forward(self, x):
        out = self.fc1(x)
        out = torch.sigmoid(out)
        out = self.fc2(out)
        return out
```