PyTorch Quick Ref
Lam
Posted on December 17, 2023
Imports
import torch # root package
from torch.utils.data import Dataset, DataLoader # dataset representation and loading
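This minimal sketch shows how `Dataset` and `DataLoader` fit together. The `SquaresDataset` class and its size are hypothetical and exist only for illustration. A map-style dataset implements `__len__` and `__getitem__`, and `DataLoader` groups the individual samples into batches.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical map-style dataset: sample i is the pair (i, i**2)."""

    def __init__(self, n: int):
        self.x = torch.arange(n, dtype=torch.float32)

    def __len__(self) -> int:
        # DataLoader uses this to know how many samples exist
        return len(self.x)

    def __getitem__(self, idx: int):
        # Return one (input, target) pair as 0-d tensors
        x = self.x[idx]
        return x, x ** 2


dataset = SquaresDataset(10)
loader = DataLoader(dataset, batch_size=4, shuffle=False)

# The default collate function stacks the 0-d tensors into 1-d batches
batches = [(xb, yb) for xb, yb in loader]
for xb, yb in batches:
    print(xb.shape, yb.shape)
```

With 10 samples and `batch_size=4`, the loader yields batches of 4, 4 and 2 samples. Passing `drop_last=True` discards the short final batch.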
Neural Network API
import torch.autograd as autograd # computation graph
from torch import Tensor # tensor node in the computation graph
import torch.nn as nn # neural networks
import torch.nn.functional as F # layers, activations and more
import torch.optim as optim # optimizers, e.g. gradient descent, Adam, etc.
from torch.jit import script, trace # hybrid frontend decorator and tracing JIT
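The autograd, `nn`, `F` and `optim` imports usually appear together in a training loop. This sketch first calls `autograd.grad` on a scalar function. It then fits a tiny linear model to synthetic data. The model class `TinyNet`, the data, the learning rate and the step count are illustrative assumptions, not recommendations.

```python
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# autograd on its own: d/dt (t**3) = 3 * t**2, which is 12 at t = 2
t = torch.tensor(2.0, requires_grad=True)
(dt,) = autograd.grad(t ** 3, t)
print(dt)  # tensor(12.)

torch.manual_seed(0)

# Synthetic data (assumption): y = 3x + 1 plus a little noise
x = torch.linspace(-1, 1, 64).unsqueeze(1)  # shape (64, 1)
y = 3 * x + 1 + 0.05 * torch.randn_like(x)


class TinyNet(nn.Module):
    """Hypothetical one-layer regression model."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        return self.linear(inp)


model = TinyNet()
optimizer = optim.SGD(model.parameters(), lr=0.1)

for step in range(200):
    optimizer.zero_grad()           # clear gradients left from the previous step
    loss = F.mse_loss(model(x), y)  # forward pass records the autograd graph
    loss.backward()                 # autograd computes d(loss)/d(parameters)
    optimizer.step()                # plain gradient descent update

print(f"final loss {loss.item():.4f}")
print(model.linear.weight.item(), model.linear.bias.item())
```

The learned weight and bias should end up close to 3 and 1. Swapping `optim.SGD` for `optim.Adam` changes only the optimizer line.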
TorchScript and JIT
torch.jit.trace(model, example_input) # takes your module or function and an example
# data input, and traces the computational steps
# that the data encounters as it progresses through the model
@script # decorator that compiles a function to TorchScript, preserving
# data-dependent control flow that tracing alone would not capture
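Tracing and scripting behave differently when a function contains data-dependent control flow. In this sketch, the function `flip_if_negative` is hypothetical. Tracing records only the branch taken by the example input, while `torch.jit.script` compiles the `if` itself. Scripting reads the function's Python source, so the sketch assumes it runs from a file or notebook. It may fail for a function typed at a plain interactive prompt.

```python
import torch


def flip_if_negative(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent branch: which path runs depends on the values in x
    if x.sum() < 0:
        return -x
    return x


# Tracing records the operations run for this positive example input,
# so only the "return x" path is captured (PyTorch typically emits a
# TracerWarning about converting a tensor to a Python boolean here).
traced = torch.jit.trace(flip_if_negative, torch.ones(3))

# Scripting compiles the Python source, so the if/else is preserved.
scripted = torch.jit.script(flip_if_negative)

neg = -torch.ones(3)
print(traced(neg))    # branch frozen at trace time: neg comes back unchanged
print(scripted(neg))  # branch evaluated at run time: the sign is flipped
```

This difference is why the hybrid approach scripts only the parts with control flow and traces the rest.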
ONNX
import onnx # separate ONNX package, used for loading and checking models
torch.onnx.export(model, dummy_input, "xxxx.proto") # exports an ONNX formatted
# model using a trained model, dummy
# input data and the desired file name
model = onnx.load("alexnet.proto") # load an ONNX model
onnx.checker.check_model(model) # check that the model
# IR is well formed
print(onnx.helper.printable_graph(model.graph)) # print a human readable
# representation of the graph