Lessons Learned from Using Torch Inductor For Inference
Aaron Langford
Posted on November 16, 2024
The purpose of this blog post is to give an introduction to compiling models with Torch Inductor, along with some advice for avoiding pitfalls. I began using Torch Inductor this year as part of a dive into optimizing PyTorch models for inference. My team needed a way to cut latency for a proprietary diffusion architecture, and that is when I stumbled on Torch's Inductor toolkit. Since then I've been able to help three different teams at Amazon speed up their inference using Torch Inductor.
Just In Time compilation with Inductor
I first played with Inductor via a "Just In Time" method. This is not to be confused with the JIT tracing feature (torch.jit.trace) available in Torch. Here, Torch wraps a model with torch.compile() and lets the compiler find optimizations on the fly. It compiles on the first calls and recompiles later only if it needs to. Torch's documentation currently steers users heavily toward this approach for compilation.
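Here is a minimal sketch of the Just In Time workflow, assuming PyTorch 2.x. TinyNet is a hypothetical stand-in for a real model. The backend parameter is only there so the same helper can later point at other backends. Calling torch.compile() does not compile anything; compilation happens on the first forward call.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a real model."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(64, 64)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.gelu(self.proj(x))


def build_compiled(backend: str = "inductor") -> tuple[nn.Module, nn.Module]:
    """Return (eager_model, compiled_model); both share the same weights."""
    model = TinyNet().eval()
    # torch.compile only wraps the model here; tracing and code generation
    # happen lazily on the first forward call.
    compiled = torch.compile(model, backend=backend)
    return model, compiled
```

In a server, the first compiled(x) call pays the compilation cost. Later calls with compatible inputs reuse the compiled code.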
Advantages
The advantage of torch.compile() is its ease of use. I recommend that anyone just getting started explore how much Torch can accelerate their models with this approach. Here are some of the big benefits I found:
1) More Forgiving Compilation with Graph Breaks
Compilation with this API is more forgiving than with other APIs because of a feature called "graph breaks". When Torch compiles code, it first creates a graph representation of the model. It helps to picture this as a sequence of tasks chained together. The compiler scans this chain and produces more optimized versions of each link in the chain (or node in the graph). For some links, the compiler just can't figure out how to compile them. Instead of erroring out, it leaves the original Python for that part in place and optimizes all the parts of the graph it can manage.
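The sketch below shows a graph break in practice, assuming PyTorch 2.x. scale_by_sign is a hypothetical function whose Python if depends on tensor data, which Dynamo cannot capture in a single graph.

- By default, torch.compile() runs that branch as plain Python and compiles the rest.
- Passing fullgraph=True turns the same graph break into an error. This is a quick way to find graph breaks before they cost performance.

Graph breaks happen in Dynamo before any backend runs, so the cheap "eager" debug backend is enough to observe them.

```python
import torch
import torch._dynamo


def scale_by_sign(x: torch.Tensor) -> torch.Tensor:
    # A Python `if` on a tensor's value is data-dependent control flow.
    # Dynamo cannot capture it in a single graph, so it inserts a graph break.
    if x.sum() > 0:
        return x * 2
    return x * -2


def run_with_graph_breaks(x: torch.Tensor, backend: str = "inductor") -> torch.Tensor:
    """Default mode: the branch runs as plain Python, the rest is still compiled."""
    torch._dynamo.reset()  # start from an empty compile cache
    return torch.compile(scale_by_sign, backend=backend)(x)


def fails_with_fullgraph(x: torch.Tensor, backend: str = "inductor") -> bool:
    """fullgraph=True turns any graph break into an error instead of a fallback."""
    torch._dynamo.reset()  # make sure this call really recompiles
    try:
        torch.compile(scale_by_sign, backend=backend, fullgraph=True)(x)
    except Exception:  # the exact exception class varies across PyTorch releases
        return True
    return False
```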
2) Dynamic Shapes
Additionally, if inputs arrive with different shapes (like image size or number of frames in a video), Torch can recompile the network on the fly to handle those shapes.
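Here is a sketch of handling variable input sizes, assuming PyTorch 2.x. By default, recent releases compile for the first shape they see. When a dimension later changes, they recompile with that dimension treated as dynamic. Passing dynamic=True asks for symbolic sizes from the start. The helper names are hypothetical.

```python
import torch
import torch.nn as nn


def compile_for_varying_batch(model: nn.Module, backend: str = "inductor") -> nn.Module:
    # dynamic=True asks Dynamo to treat input sizes as symbolic from the start,
    # rather than specializing on the first shape it sees.
    return torch.compile(model, backend=backend, dynamic=True)


@torch.no_grad()
def run_batches(compiled: nn.Module, batch_sizes: list[int], features: int) -> list[torch.Size]:
    # Note: sizes 0 and 1 are typically still specialized, so a batch of 1
    # may end up with its own graph.
    return [compiled(torch.randn(b, features)).shape for b in batch_sizes]
```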
3) Decoupled From Weights
This is not applicable in all cases, but a popular feature in AI services is letting customers swap in their own weights for a given architecture, or change the default weights through fine-tuning offered by the model provider. Swapping weights is easy with Just In Time compilation, because compilation happens right before an inference request is served. Ahead-of-time approaches make this harder because they typically embed the weights in the compiled artifact, so the same model must be recompiled for each set of weights.
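Here is a sketch of weight swapping with torch.compile(), assuming PyTorch 2.x. It also assumes Inductor's optional freezing setting stays off (its default), because freezing folds weights into the compiled code as constants. SwappableServer is a hypothetical wrapper. load_state_dict() copies new values into the existing parameter tensors in place, so the compiled wrapper should pick them up on its next call without a rebuild.

```python
import torch
import torch.nn as nn


class SwappableServer:
    """Hypothetical serving wrapper: compile once, swap weights many times."""

    def __init__(self, model: nn.Module, backend: str = "inductor") -> None:
        self.model = model.eval()
        self.compiled = torch.compile(self.model, backend=backend)

    def load_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
        # Copies into the existing Parameter tensors in place, so the
        # compiled wrapper reads the new values on its next call.
        self.model.load_state_dict(state_dict)

    @torch.no_grad()
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return self.compiled(x)
```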
4) Multiple Backends
With the torch.compile() interface, there are multiple choices of backend. While Inductor is the default, NVIDIA's TensorRT can also be used as a backend. This allows quick evaluation of multiple optimization platforms for a Torch model.
Hopefully in the future, this will become an easy compilation path to more hardware platforms (like Amazon's Neuron platform or Google's TPU platform).
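Here is a sketch for comparing backends, assuming PyTorch 2.x. torch._dynamo.list_backends() reports the backend names registered in the current environment; by default it hides debug and experimental ones. Third-party backends such as TensorRT's only become available after their package (torch_tensorrt) is installed and imported. Check that project's documentation for the exact backend name. time_backends is a hypothetical helper, and its numbers are rough wall-clock averages, not a rigorous benchmark.

```python
import time

import torch
import torch._dynamo
import torch.nn as nn


def available_backends() -> list[str]:
    """Backend names registered with Dynamo here (debug/experimental ones hidden by default)."""
    return torch._dynamo.list_backends()


@torch.no_grad()
def time_backends(model: nn.Module, example: torch.Tensor,
                  backends: list[str], iters: int = 10) -> dict[str, float]:
    """Average seconds per call for each backend, measured after a warm-up call."""
    results = {}
    for name in backends:
        torch._dynamo.reset()  # don't reuse graphs compiled for a previous backend
        compiled = torch.compile(model, backend=name)
        compiled(example)  # warm-up call pays the compilation cost
        if example.is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            compiled(example)
        if example.is_cuda:
            torch.cuda.synchronize()  # wait for queued GPU work before stopping the clock
        results[name] = (time.perf_counter() - start) / iters
    return results
```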
5) Flexible Targets
A model with dozens of layers can be compiled, and so can a single attention layer. This allows compilation to be applied to a model in many ways, even if that model does more complex things like caching, dynamic runtime pruning, or cross-device communication.
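Here is a sketch of compiling only part of a model, assuming PyTorch 2.x. SelfAttention and Pipeline are hypothetical modules. Only the attention submodule is swapped for its compiled wrapper, so any Python logic in the outer forward() stays eager. One side effect: the wrapper stores the original module as _orig_mod, so the outer model's state_dict keys for that submodule gain an "_orig_mod." segment.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    """Hypothetical single-head attention block."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.qkv = nn.Linear(dim, 3 * dim)
        self.out = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        return self.out(F.scaled_dot_product_attention(q, k, v))


class Pipeline(nn.Module):
    """Hypothetical outer model whose own Python logic stays eager."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.attn = SelfAttention(dim)
        self.head = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Caching, pruning, or communication logic could live here untouched.
        return self.head(self.attn(x))


def compile_attention_only(model: Pipeline, backend: str = "inductor") -> Pipeline:
    # Replace just the submodule with its compiled wrapper.
    model.attn = torch.compile(model.attn, backend=backend)
    return model
```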
Disadvantages
1) No way to "save" an optimized model
The disadvantage of this approach is that all the decisions for optimizations are gone once the process ends. This means that the optimal model must be found each time an inference server is started.
2) Slow start for Inference
Inference servers using torch.compile() will be slower at first unless something runs some warm-up requests before the server takes its first real requests.
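Here is a minimal warm-up sketch. It assumes you know, or can sample, the input shapes your traffic will use; the shape list and the warm_up helper are hypothetical. Running dummy requests for each expected shape before the server reports healthy moves the compilation cost out of the first real requests.

```python
import time

import torch
import torch.nn as nn


@torch.no_grad()
def warm_up(compiled: nn.Module, example_shapes: list[tuple[int, ...]],
            device: str = "cpu") -> list[float]:
    """Run one dummy request per expected input shape; return seconds per call."""
    timings = []
    for shape in example_shapes:
        start = time.perf_counter()
        compiled(torch.randn(*shape, device=device))
        if device.startswith("cuda"):
            torch.cuda.synchronize()  # include queued GPU work in the timing
        timings.append(time.perf_counter() - start)
    return timings
```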
3) Recompilations Introduce Latency Spikes
Any recompilation that Torch deems necessary results in higher latency for the requests that trigger it. Torch should only need to recompile when input shapes or control flow in the forward pass change.
Ahead of Time Compilation with AOTInductor
I was motivated enough by these disadvantages to look for an "Ahead of Time" version of torch.compile(). Sure enough, Torch has one! It is called AOTInductor, and together with the torch.export package it can do much of what the "Just In Time" approach does.
How AOTInductor works
AOTInductor takes any torch.nn.Module (something with a forward method) or a plain old Python function. First, the arbitrary Python code is traced with Torch's Dynamo module, and the execution is translated into a Torch FX Graph. This is the intermediate representation of the code that will be optimized.
The Torch FX Graph version of the model is then handed off to Inductor, which generates an optimized version of the model composed of Triton kernels and a C++ orchestrator. The kernels are compiled to .cubin files and the C++ is compiled to a .so file. Torch provides both a C++ and a Python wrapper, which can be called the same way the original function was.
Benefits
In my experience with AOTInductor, I was able to cut 30-40% off the latency of the eager PyTorch forward pass for UNet- and UViT-based diffusion models.
Additionally, the AOTInductor compiler enables an automated build pipeline. Scientists can stay in PyTorch, NumPy, etc., while an automated step produces an optimized, inference-ready version of their model.
Best Practices
I compiled a list of best practices to use when writing models that can be compiled by the AOTInductor compiler.
1) Only use torch.Tensor types as input to the model.
```python
import torch
import torch.nn as nn


class IncorrectModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 10)

    def forward(self, input_tensor, config_dict, mode="default"):
        # Check the config_dict and mode string to determine model behavior
        if mode == "special" and config_dict.get("use_special", False):
            # Apply some custom logic based on the non-tensor config
            x = input_tensor * 2
        else:
            x = input_tensor
        return self.layer(x)
```
Prefer to keep non-tensor inputs as properties of the module. If the model's forward pass requires something that's fundamentally not a tensor, lift it out of the forward pass. Then make a wrapper, and compile only the logic that does not depend on this kind of parameter. If all else fails, try boxing parameters in a tensor. But beware: in my opinion this is a bit of an anti-pattern, and a smell of an overcomplicated forward pass.
```python
class CorrectModel(nn.Module):
    def __init__(self, use_special=False, mode="default"):
        super().__init__()
        self.layer = nn.Linear(10, 10)
        self.use_special = use_special
        self.mode = mode

    def forward(self, input_tensor):
        # Use model properties instead of passing non-tensor inputs
        if self.mode == "special" and self.use_special:
            x = input_tensor * 2
        else:
            x = input_tensor
        return self.layer(x)
```
2) Avoid using optional arguments, kwargs, and argument expansions anywhere in the forward pass.
Prefer to use explicit arguments everywhere in forward passes. Either the tensor will be there or it won't.
3) Avoid problematic Python code in the forward pass
Calls to things like isfunction will break the compiler.
Calls to regex libraries also tend to break the compiler. This tends to come up when people write code that checks whether certain versions of packages (like xformers) are installed before choosing one implementation of a layer over another.
Instead, take care of selecting layers in initialization logic.
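Here is a sketch of choosing an implementation in __init__, assuming PyTorch 2.x. Attention and manual_attention are hypothetical. The check here (whether torch.nn.functional has scaled_dot_product_attention) stands in for any version or package check. A presence check such as importlib.util.find_spec("xformers") would belong in the same place. Either way, the forward pass only calls whatever was selected.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def manual_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Reference softmax(QK^T / sqrt(d)) V, matching SDPA's default scaling.
    scale = q.shape[-1] ** -0.5
    weights = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    return weights @ v


class Attention(nn.Module):
    """Hypothetical layer that picks its attention kernel once, at construction."""

    def __init__(self, dim: int, prefer_fused: bool = True) -> None:
        super().__init__()
        self.qkv = nn.Linear(dim, 3 * dim)
        # All environment checks happen here, never in forward().
        if prefer_fused and hasattr(F, "scaled_dot_product_attention"):
            self.attend = F.scaled_dot_product_attention
        else:
            self.attend = manual_attention

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        return self.attend(q, k, v)
```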
4) Forward pass should be a pure function.
For a model that will be compiled, the forward pass should be stateless. This means that forward(x) should produce the same output no matter how many times it is called, or what else is called between calls to forward(x), as long as x is the same. State can be managed outside of the target compiled model.
When the forward call changes the state of the model, this causes a few problems. For example, writes to self are completely disallowed by the Inductor compiler.
Another common pattern I've seen is caching in between forward passes. This is common in diffusion models where one image or video requires many steps, and each step requires a forward pass. This leads to frequent state sharing in models that are intended to be used from step to step. Here's an example:
```python
class IncorrectCachingModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 10)
        self.layer2 = nn.Linear(10, 10)
        # Register a buffer for caching, marked as non-persistent
        self.register_buffer("cached_output", None, persistent=False)

    def forward(self, x):
        # Check if there's a cached output to reuse
        if self.cached_output is not None:
            print("Using cached output")
            return self.cached_output
        # Perform the full forward pass and cache the result
        x = self.layer1(x)
        x = torch.relu(x)
        x = self.layer2(x)
        # Cache the result by writing to the registered buffer (a write to self)
        self.cached_output = x
        return x
```
Instead, lift the tensors that need to be cached out of the forward pass, or make them a parameter of the class if they can be computed ahead of time.
Sometimes, the caching is of data that can't be known until after a forward pass takes place. For example, the DeepCache paper suggests that for UNet based diffusion, the middle blocks of the network may be skipped. In the place of the output of those middle blocks, the output of the middle blocks from a prior step is retrieved from a cache.
In this case, I'd recommend including the part of the network that needs to be cached as an output of the forward pass. It might be concatenated with the actual output of the forward pass. The downside is that making the shape of the cached data match the output of the forward pass can be awkward.
Another approach to consider could be compiling individual blocks of the model, and coordinating the caching in Python (or whatever runtime is used to execute the network).
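Here is a sketch of the block-by-block alternative, assuming PyTorch 2.x. Each block is compiled separately with torch.compile(), and a plain Python loop decides when to recompute the middle block and when to reuse its cached output. BlockwiseDenoiser, cache_interval, and the toy update rule are hypothetical; this is not a real diffusion sampler.

```python
import torch
import torch.nn as nn


class BlockwiseDenoiser:
    """Hypothetical: blocks compiled individually; caching decisions stay in Python."""

    def __init__(self, dim: int, cache_interval: int = 3, backend: str = "inductor") -> None:
        self.outer_in = torch.compile(nn.Linear(dim, dim).eval(), backend=backend)
        self.middle = torch.compile(
            nn.Sequential(nn.Linear(dim, dim), nn.SiLU()).eval(), backend=backend
        )
        self.outer_out = torch.compile(nn.Linear(2 * dim, dim).eval(), backend=backend)
        self.cache_interval = cache_interval
        self.middle_calls = 0

    @torch.no_grad()
    def run(self, x: torch.Tensor, num_steps: int) -> torch.Tensor:
        cached_mid = None
        for step in range(num_steps):
            skip = self.outer_in(x)
            # Recompute the expensive middle block only every cache_interval steps.
            if cached_mid is None or step % self.cache_interval == 0:
                cached_mid = self.middle(skip)
                self.middle_calls += 1
            x = self.outer_out(torch.cat([skip, cached_mid], dim=-1))
        return x
```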
5) Minimize complexity in the forward pass.
Most forward pass complexity that I've seen stems from branching. While the compiler can handle branching, and is getting better at this in more recent versions of Torch, the compiler will usually try to just pick the right branch at compile time.
Most of this trouble comes when engineers try to create highly configurable models with loads of flags. This is a desirable property, especially in a rapidly evolving model. Scientists want to be able to include a new optional architecture configuration and test it with different combinations of previous features.
```python
import torch
import torch.nn as nn


class ComplexConfigurableModel(nn.Module):
    def __init__(self, use_special_layer=False, apply_dropout=False, activation="relu"):
        super().__init__()
        self.use_special_layer = use_special_layer
        self.apply_dropout = apply_dropout
        self.activation = activation
        # Define layers, some of which are optional
        self.layer1 = nn.Linear(10, 20)
        self.special_layer = nn.Linear(20, 20) if use_special_layer else None
        self.dropout = nn.Dropout(p=0.5) if apply_dropout else None
        self.layer2 = nn.Linear(20, 10)

    def forward(self, x):
        # Forward pass with branching logic based on configuration flags
        x = self.layer1(x)
        # Apply special layer if specified
        if self.use_special_layer and self.special_layer is not None:
            x = self.special_layer(x)
        # Apply specified activation function
        if self.activation == "relu":
            x = torch.relu(x)
        elif self.activation == "sigmoid":
            x = torch.sigmoid(x)
        elif self.activation == "tanh":
            x = torch.tanh(x)
        # Apply dropout if specified
        if self.apply_dropout and self.dropout is not None:
            x = self.dropout(x)
        x = self.layer2(x)
        return x
```
Instead of forward pass code with tons of branches, aim to push the complexity inherent in rapidly evolving models into the initialization of the model. The builder, composite, and strategy patterns are classic object-oriented patterns for addressing this.
```python
class SimplifiedModel(nn.Module):
    def __init__(self, activation="relu"):
        super().__init__()
        self.layer1 = nn.Linear(10, 20)
        self.layer2 = nn.Linear(20, 10)
        # Set activation function based on input; done at initialization
        if activation == "relu":
            self.activation = torch.relu
        elif activation == "sigmoid":
            self.activation = torch.sigmoid
        elif activation == "tanh":
            self.activation = torch.tanh
        else:
            raise ValueError("Unsupported activation")

    def forward(self, x):
        # Forward pass without branching
        x = self.layer1(x)
        x = self.activation(x)
        x = self.layer2(x)
        return x
```
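Here is a builder-style sketch that uses the same layer sizes as ComplexConfigurableModel. build_model() is a hypothetical function: it resolves every flag once and returns an nn.Sequential, so the forward pass has no branches at all.

```python
import torch
import torch.nn as nn

ACTIVATIONS = {"relu": nn.ReLU, "sigmoid": nn.Sigmoid, "tanh": nn.Tanh}


def build_model(use_special_layer: bool = False, apply_dropout: bool = False,
                activation: str = "relu") -> nn.Sequential:
    """Builder: resolve every config flag here so forward() is a straight line."""
    if activation not in ACTIVATIONS:
        raise ValueError(f"Unsupported activation: {activation}")
    layers: list[nn.Module] = [nn.Linear(10, 20)]
    if use_special_layer:
        layers.append(nn.Linear(20, 20))
    layers.append(ACTIVATIONS[activation]())
    if apply_dropout:
        layers.append(nn.Dropout(p=0.5))
    layers.append(nn.Linear(20, 10))
    return nn.Sequential(*layers)
```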
Alternatively, it may be better to maintain experimental classes that define networks that are actively evolving. Then when an architecture configuration is settled on, create a class that represents a stripped down version of the model. This is beneficial because it makes debugging issues with compiled models simpler.
6) Use a package distribution service to distribute compiled artifacts
Depending on your inference-time environment, the right choice might differ. I opted to wrap my .so and .cubin files in a Python wheel and publish the wheel to AWS CodeArtifact. Because some of these artifacts can be many GB in size, a limit increase with AWS CodeArtifact may be necessary.
This made it extremely easy to distribute the model because the entire package was versioned. We had to produce new versions of this model nearly weekly as changes to weights and architecture rolled in. A package that could be easily pulled in via a bump to someone's requirements.txt made version upgrades very easy.
Finally, these artifacts have strict requirements for dependencies matching. Python wheels gave us an easy way to communicate that a specific version of Torch had to be there in the inference runtime environment.
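Here is a sketch of a runtime guard to ship alongside a pinned torch requirement in the wheel's metadata. BUILT_WITH_TORCH and check_runtime_torch are hypothetical. The idea is that the build pipeline writes the Torch version it compiled with into the package. The package then refuses to load artifacts under a different version.

```python
import torch

# Hypothetical constant that a build pipeline would write into the wheel.
BUILT_WITH_TORCH = "2.5.1"


def check_runtime_torch(built_with: str = BUILT_WITH_TORCH) -> None:
    """Fail fast if the installed Torch differs from the one used at compile time."""
    # Drop local build suffixes such as "+cu121" before comparing.
    installed = str(torch.__version__).split("+")[0]
    expected = built_with.split("+")[0]
    if installed != expected:
        raise RuntimeError(
            f"Compiled artifacts were built with torch {expected}, "
            f"but torch {installed} is installed."
        )
```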
7) Use an absolute path when writing out compiled artifacts
If you use a relative path, Inductor will dump .cubin files into a directory in /tmp. The C++ uses absolute paths to refer to the kernels defined in the .cubin files, and /tmp is not a great place to keep executables that need to be around for a while.
I recommend choosing a place in /opt/ to write these files. This does mean that at run time, the files need to be at the same path they were written to during compilation. This is another reason why a wheel is a good choice for distribution: it allowed us to include logic for setting up symlinks along with the compiled artifacts.
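Here is a sketch of the symlink setup step, using only the Python standard library. link_artifacts is hypothetical, and both paths are placeholders. installed_dir is wherever the wheel put the files, and expected_dir is the absolute path used at compile time.

```python
import os
from pathlib import Path


def link_artifacts(installed_dir: str, expected_dir: str) -> Path:
    """Expose compiled artifacts at the absolute path baked in at compile time.

    installed_dir: where the wheel actually placed the files (e.g. site-packages).
    expected_dir: the absolute path used during compilation (e.g. under /opt/).
    """
    source = Path(installed_dir).resolve()
    target = Path(expected_dir)
    if not target.is_absolute():
        raise ValueError(f"expected_dir must be an absolute path, got {expected_dir!r}")
    if target.is_symlink() or target.exists():
        if target.resolve() == source:
            return target  # already set up; safe to call on every start
        raise FileExistsError(f"{target} already exists and does not point to {source}")
    target.parent.mkdir(parents=True, exist_ok=True)
    # Create a directory symlink: expected_dir -> installed_dir.
    os.symlink(source, target, target_is_directory=True)
    return target
```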
8) Test for numerical differences
The compiled version of a model is not guaranteed to be bit-exact with the original. Even if the compiler can produce a bit-exact version, it's likely that some of the PyTorch model had to change to become compilable. If that model has already been trained, the risk of mangling some of the architecture is high, which will translate into bad output from the compiled model.
For example, I once needed to compile a model that had a NormLayer in it, and the weights of the NormLayer were being cast to bfloat16 before inference. NormLayer is not numerically stable at this precision, and compiling the model made this more obvious. The compiled version of the NormLayer produced outputs that differed from the eager version by as much as 1e-1.
So test, test, and test again. The acceptable difference may vary by model, but starting with the defaults of torch.allclose(...) is a good idea. I've seen some models do OK with differences as large as 1e-3, but this is a matter to resolve through experiments on a specific model!
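Here is a sketch of an eager-versus-compiled comparison, assuming PyTorch 2.x. max_abs_diff and assert_close_enough are hypothetical helpers. The default tolerances mirror torch.allclose (rtol=1e-05, atol=1e-08) and should only be loosened after experiments justify it. Running the comparison over many realistic inputs, not just one random tensor, is what catches problems like the bfloat16 normalization issue.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def max_abs_diff(eager: nn.Module, compiled, inputs: list[torch.Tensor]) -> float:
    """Largest absolute elementwise difference between eager and compiled outputs."""
    worst = 0.0
    for x in inputs:
        diff = (eager(x).float() - compiled(x).float()).abs().max().item()
        worst = max(worst, diff)
    return worst


@torch.no_grad()
def assert_close_enough(eager: nn.Module, compiled, inputs: list[torch.Tensor],
                        rtol: float = 1e-5, atol: float = 1e-8) -> None:
    # Defaults mirror torch.allclose; loosen only after experiments justify it.
    for i, x in enumerate(inputs):
        ref, out = eager(x), compiled(x)
        if not torch.allclose(out, ref, rtol=rtol, atol=atol):
            diff = (out - ref).abs().max().item()
            raise AssertionError(f"input {i}: max abs diff {diff:.3e} exceeds tolerance")
```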
9) Turn on trace logs when compiling
Use an environment variable to enable trace logging during compilation: TORCH_TRACE="/tmp/tracedir". The folks at Meta have this option on by default when they compile their models. This is the best way to understand exactly what decisions the compiler made. The trace logs will include copies of the Triton Kernels that were generated.
10) Use an IDE with a Debugger when compiling
Compilation, especially for more custom models, is bound to fail at some point. The errors printed in the output can be extremely cryptic, making it nearly impossible to trace a problem back to the Python it originated from. IDE debuggers (like the one in VS Code) provide easy tools to walk back through the stack trace and piece together which part of the model the issue comes from.