Trace-Driven Simulation in Astra-Sim with the PyTorch Profiler
An end-to-end tutorial on collecting traces from PyTorch, converting them to Chakra Execution Traces, and simulating them in Astra-Sim for hardware design.
In this tutorial, we will be:
Creating a “Hello World” MNIST classifier in PyTorch
Training the model for a single epoch
Collecting traces using PyTorch Profiler + Execution Trace Observer
Merging traces and converting them to Chakra Execution Traces
Building Astra-Sim in Docker
Running an Astra-Sim simulation with the Chakra Execution Traces
Hardware and Network Simulators
Simulators are well-established proxies for real-world systems. A well-calibrated hardware or network simulator produces timings and output characteristics close to those of the real system. This lets new hardware and network designs be evaluated in simulation before they are built, so the results can guide real-world design decisions.
Getting Started
To get started with simulation, we will begin with the classic example of classifying handwritten digits from the MNIST dataset. We will profile a few training steps from a single epoch of this model and collect traces with PyTorch Profiler and the PyTorch Execution Trace Observer. We will then merge these traces, convert them to Chakra Execution Traces, and simulate them with Astra-Sim.
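Next, we import torch and select CUDA as the device when it is available.
```python
import torch
import torch.nn as nn

# Pick the GPU if one is available; the rest of this tutorial assumes CUDA is present
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```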
Data Loading
To load the MNIST data, we will use a DataLoader from PyTorch. The DataLoader serves the data in batches and shuffles it for training.
```python
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

# Loading the MNIST training set
mnist_train = datasets.MNIST(root='data', train=True, transform=ToTensor(), download=True)

# Creating a data loader that yields shuffled batches of 64 images
train_data = DataLoader(mnist_train, batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
```
Now our data loader is configured to provide training batches to the model.
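Before profiling, it helps to know how many steps one epoch contains. A full epoch is a lot of iterations to trace, which is one reason to profile only a handful of steps. The sketch below assumes the standard 60,000-image MNIST training split and the loader's default `drop_last=False`. The helper name `steps_per_epoch` is hypothetical.

```python
import math


def steps_per_epoch(num_examples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a DataLoader yields in one epoch."""
    if drop_last:
        # the final partial batch is discarded
        return num_examples // batch_size
    # the final partial batch is kept, so round up
    return math.ceil(num_examples / batch_size)


# MNIST has 60,000 training images; the loader above uses batch_size=64
print(steps_per_epoch(60_000, 64))  # 938 (the last batch holds 60000 - 937 * 64 = 32 images)
```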
Defining a Model
We will define a simple Convolutional Neural Network (CNN) for classifying the MNIST dataset. The model is deliberately simple and serves only as a demonstration.
```python
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=40,
                kernel_size=5,
                stride=1,
                padding=2,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(40, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # fully connected layer, output 10 classes
        self.out = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        x = x.view(x.size(0), -1)
        output = self.out(x)
        return output, x  # return x for visualization


cnn = CNN()
cnn.cuda()  # move model to GPU
```
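The `32 * 7 * 7` input size of the final linear layer follows from standard shape arithmetic. The sketch below works it out and also counts the model's parameters as a quick sanity check on model size. It assumes 28x28 inputs and PyTorch's default max-pooling stride, which equals the kernel size. The helper names are hypothetical.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a 2D convolution (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel):
    """Spatial output size of max pooling whose stride equals its kernel size."""
    return (size - kernel) // kernel + 1


# 28x28 input -> conv (k=5, s=1, p=2) keeps 28 -> pool halves to 14 -> conv keeps 14 -> pool to 7
size = 28
for _ in range(2):
    size = pool_out(conv2d_out(size, kernel=5, stride=1, padding=2), kernel=2)
flat_features = 32 * size * size
print(size, flat_features)  # 7 1568

# Parameter count: weights + biases of each layer
conv1_params = 40 * 1 * 5 * 5 + 40      # 1040
conv2_params = 32 * 40 * 5 * 5 + 32     # 32032
fc_params = flat_features * 10 + 10     # 15690
total_params = conv1_params + conv2_params + fc_params
print(total_params)  # 48762
```

With the model built, `sum(p.numel() for p in cnn.parameters())` should report the same total.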
Now, with the structure of the model defined, we can move on to model training.
Defining a Training Step Function
To train the model, we first define a basic function that takes one batch from the data loader and uses it to perform a single training step.
```python
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cnn.parameters(), lr=0.001)


def train(batch_data):
    """Run a single training step of the CNN on one batch."""
    # data movement: copy the batch to the GPU
    images, labels = batch_data
    images, labels = images.cuda(), labels.cuda()
    # forward pass
    output = cnn(images)[0]
    loss = loss_func(output, labels)
    # clear gradients for this training step
    optimizer.zero_grad()
    # backpropagation, compute gradients
    loss.backward()
    # apply gradients
    optimizer.step()
```
Now all that remains is to run the training loop with the profiler set up to collect our traces.
Trace Collection
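After setting up the model, we can collect traces with PyTorch Profiler. The profiler is a context manager that records traces inside the model's training loop. For trace analysis, the traces must capture both CPU and GPU activity and be saved to disk. Here, the on_trace_ready callback writes them out with export_chrome_trace. Alongside the profiler, the ExecutionTraceObserver records the PyTorch Execution Trace of the operators that run.
```python
from torch.profiler import profile, schedule, ProfilerActivity, ExecutionTraceObserver

# The tracing schedule helps control for startup costs and warm-up in PyTorch:
# skip 5 steps, wait 5 more, warm up for 2, then record 2 active steps.
tracing_schedule = schedule(skip_first=5, wait=5, warmup=2, active=2, repeat=1)
# The loop must run at least this many steps for the schedule to finish its active phase
num_steps = 5 + 5 + 2 + 2

# The Execution Trace Observer records the PyTorch Execution Trace (operator graph)
et = ExecutionTraceObserver()
et.register_callback("pytorch_et.json")
et.start()

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=tracing_schedule,
    # write the Kineto (CPU + GPU) trace once the active steps have been recorded
    on_trace_ready=lambda p: p.export_chrome_trace("kineto_trace.json"),
    profile_memory=True,
    record_shapes=True,
    with_stack=True,
) as prof:
    for epoch in range(1):  # number of epochs
        for step, batch_data in enumerate(train_data):
            if step >= num_steps:
                break  # only the first num_steps batches of the epoch are traced
            train(batch_data)
            prof.step()  # mark the end of one training iteration for the profiler

et.stop()
et.unregister_callback()
```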
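To see which iterations the schedule `schedule(skip_first=5, wait=5, warmup=2, active=2, repeat=1)` actually records, the sketch below is a simplified pure-Python re-implementation of the documented `torch.profiler.schedule` semantics. It returns action names as plain strings rather than `ProfilerAction` values. It is for illustration only. In real code, the callable returned by `torch.profiler.schedule` can be queried directly.

```python
def schedule_action(step, *, wait, warmup, active, repeat=0, skip_first=0):
    """Simplified mirror of torch.profiler.schedule: profiler action for a step."""
    if step < skip_first:
        return "NONE"
    step -= skip_first
    cycle = wait + warmup + active
    # after `repeat` full cycles (repeat=0 means forever), the profiler stays idle
    if repeat > 0 and step // cycle >= repeat:
        return "NONE"
    pos = step % cycle
    if pos < wait:
        return "NONE"
    if pos < wait + warmup:
        return "WARMUP"
    # the last active step of a cycle also triggers on_trace_ready
    return "RECORD" if pos < cycle - 1 else "RECORD_AND_SAVE"


if __name__ == "__main__":
    params = dict(skip_first=5, wait=5, warmup=2, active=2, repeat=1)
    for step in range(16):
        print(step, schedule_action(step, **params))
```

Steps 0 to 9 do nothing, 10 and 11 warm up, and 12 and 13 are recorded. A loop that stops before 14 `prof.step()` calls therefore never reaches the save point, and `kineto_trace.json` is not written.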
Following this, two files will be created: kineto_trace.json and pytorch_et.json. The kineto_trace.json file contains the traces for the CPU and GPU, while the pytorch_et.json file contains the PyTorch Execution Traces. Both files will be merged to create a single trace file.
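Before merging, a quick sanity check can confirm that both files contain data. This is a minimal sketch under a few assumptions: the Kineto file follows the Chrome trace layout, with a top-level `traceEvents` list where GPU kernels carry `"cat": "kernel"`, and the Execution Trace JSON stores operators in a top-level `nodes` list. The helper names `load_json` and `summarize_traces` are hypothetical.

```python
import json


def load_json(path):
    """Read a JSON trace file from disk."""
    with open(path) as f:
        return json.load(f)


def summarize_traces(kineto, et):
    """Count events in a Kineto trace and operator nodes in a PyTorch Execution Trace."""
    events = kineto.get("traceEvents", [])
    # assumption: Kineto tags GPU kernel launches with the category "kernel"
    kernels = [e for e in events if e.get("cat") == "kernel"]
    return {
        "kineto_events": len(events),
        "gpu_kernels": len(kernels),
        "et_nodes": len(et.get("nodes", [])),
    }
```

For example, `summarize_traces(load_json("kineto_trace.json"), load_json("pytorch_et.json"))` should report a non-zero count of GPU kernels and ET nodes if the profiling run succeeded.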
Converting to Chakra Execution Traces (ET)
Chakra is a library produced by Meta AI Research. Its goal is to provide a standardized trace format for simulators, analyzers, and collectors.
Using Chakra involves two steps. First, link the PyTorch Execution Trace produced earlier with the Kineto trace. Then, convert the merged trace into the Chakra format with Chakra's execution trace converter.
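The two steps could be scripted roughly as follows. This sketch assumes the `chakra_trace_link` and `chakra_converter` command-line tools installed with the Chakra Python package, with flag names taken from the Chakra README at the time of writing. These flags have changed between Chakra releases, so confirm them with `--help` before relying on this. The intermediate file name `chakra_host_device_trace.json` is hypothetical.

```python
import shutil
import subprocess


def chakra_commands(pytorch_et="pytorch_et.json", kineto_trace="kineto_trace.json",
                    merged="chakra_host_device_trace.json", output="chakra_traces", rank=0):
    """Build the link and convert commands (flag names are assumptions; verify with --help)."""
    # Step 1: link the host-side PyTorch ET with the device-side Kineto trace
    link = [
        "chakra_trace_link",
        "--rank", str(rank),
        "--chakra-host-trace", pytorch_et,
        "--chakra-device-trace", kineto_trace,
        "--output-file", merged,
    ]
    # Step 2: convert the merged trace into the Chakra format
    convert = [
        "chakra_converter", "PyTorch",
        "--input", merged,
        "--output", output,
    ]
    return [link, convert]


def run_chakra_pipeline(**kwargs):
    """Run the link and convert steps in order, stopping on the first failure."""
    for cmd in chakra_commands(**kwargs):
        if shutil.which(cmd[0]) is None:
            raise FileNotFoundError(f"{cmd[0]} not found on PATH; is Chakra installed?")
        subprocess.run(cmd, check=True)
```

Keeping command construction separate from execution makes it easy to print the commands and compare them with your installed Chakra version before running anything.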
We have now produced chakra_traces, a file in the MLCommons Chakra format built from both the PyTorch Execution Trace and the Kineto trace. This file may now be used as input to Astra-Sim.
Setting Up Astra-Sim
Astra-Sim is a trace-driven simulator designed for both hardware and network simulation. To get started, we will build Astra-Sim from source in a Docker container.
For details on configuring and running Astra-Sim, see the Run Astra-Sim Documentation. For a comprehensive overview of the simulator, see the Astra-Sim Tutorial from ASPLOS.
Summary
In this tutorial, we defined a simple CNN for classifying the MNIST dataset. We profiled its training with PyTorch Profiler and the Execution Trace Observer, then merged the collected traces and converted them into Chakra Execution Traces. Finally, we used Astra-Sim to simulate how these traces would execute on a network.