This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-547] Tutorial explaining how to use the profiler #11274

Merged · 12 commits · Jun 19, 2018
1 change: 1 addition & 0 deletions docs/tutorials/index.md
@@ -42,6 +42,7 @@ Select API:
* [Inference using an ONNX model](/tutorials/onnx/inference_on_onnx_model.html)
* [Fine-tuning an ONNX model on Gluon](/tutorials/onnx/fine_tuning_gluon.html)
* [Visualizing Decisions of Convolutional Neural Networks](/tutorials/vision/cnn_visualization.html)
* [Profiling MXNet Models](/tutorials/python/profiler.html)
* API Guides
* Core APIs
* NDArray
190 changes: 190 additions & 0 deletions docs/tutorials/python/profiler.md
@@ -0,0 +1,190 @@
# Profiling MXNet Models

It is often helpful to understand how much time each operation takes while running a model; this helps in optimizing the model to run faster. In this tutorial, we will learn how to profile MXNet models to measure their running time and memory consumption using the MXNet profiler.

## The incorrect way to profile
> **Reviewer (Contributor):** This is not incorrect. You can still use `wait_to_read` to time the dot operation.

> **Author:** Agreed, but I don't want to suggest `wait_to_read` as the recommended way to measure the time taken by operations. While it might work for toy problems like this one:
> * it is harder to use when measuring the execution time of multiple operations (it requires `wait_to_read` both before and after the measured operation, in multiple places);
> * it is hard to use to measure the running time of a block inside a sequence (which is common);
> * it won't work for hybrid networks.
>
> The goal of this tutorial is to point people to a recommended way of profiling that works for almost all cases. However, I can add a note along the lines of: "While it is possible to use `wait_to_read()` before and after an operation to get its running time, it is not a scalable method for measuring the running time of multiple operations, especially in a Sequential or Hybrid network."

If you have just begun using MXNet, you might be tempted to measure the execution time of your model using Python's `time` module, as shown below:

```python
from time import time
from mxnet import autograd, nd
import mxnet as mx

start = time()
x = nd.random_uniform(shape=(2000,2000))
y = nd.dot(x, x)
print('Time for matrix multiplication: %f sec\n' % (time() - start))

start = time()
print(y.asnumpy())
print('Time for printing the output: %f sec' % (time() - start))
```


**Time for matrix multiplication: 0.005051 sec**<!--notebook-skip-line-->

[[501.1584 508.29724 495.65237 ... 492.84705 492.69092 490.0481 ]<!--notebook-skip-line-->

[508.81058 507.1822 495.1743 ... 503.10526 497.29315 493.67917]<!--notebook-skip-line-->

[489.56598 499.47015 490.17722 ... 490.99945 488.05008 483.28836]<!--notebook-skip-line-->

...<!--notebook-skip-line-->

[484.0019 495.7179 479.92142 ... 493.69952 478.89194 487.2074 ]<!--notebook-skip-line-->

[499.64932 507.65094 497.5938 ... 493.0474 500.74512 495.82712]<!--notebook-skip-line-->

[516.0143 519.1715 506.354 ... 510.08878 496.35608 495.42523]]<!--notebook-skip-line-->

**Time for printing the output: 0.167693 sec**<!--notebook-skip-line-->


From the output above, it seems as if printing the output takes a lot more time than multiplying two large matrices. That doesn't feel right.

This is because, in MXNet, all operations are executed asynchronously. So, when `nd.dot(x, x)` returns, the matrix multiplication is not complete; it has only been queued for execution. `asnumpy` in `print(y.asnumpy())`, however, waits for the result to be computed and hence takes longer.

While it is possible to use `NDArray.waitall()` before and after operations to get their running time, it is not a scalable method for measuring the running time of multiple sets of operations, especially in a Sequential or Hybrid network.

## The correct way to profile

The correct way to measure the running time of MXNet models is to use the MXNet profiler. In the rest of this tutorial, we will use it to measure the running time and memory consumption of a model.

To use the profiler, you need to build MXNet with `USE_PROFILER=1`. Check the [installation](http://mxnet.incubator.apache.org/install/index.html) page for more information on how to install MXNet from source. After building with `USE_PROFILER=1` and installing, you can import the profiler and configure it from Python code.

```python
from mxnet import profiler
profiler.set_config(profile_all=True, aggregate_stats=True, filename='profile_output.json')
```

`profile_all` enables all types of profiling. You can also individually enable the following types of profiling:

- `profile_symbolic` (boolean): whether to profile symbolic operators
- `profile_imperative` (boolean): whether to profile imperative operators
- `profile_memory` (boolean): whether to profile memory usage
- `profile_api` (boolean): whether to profile the C API

`aggregate_stats` aggregates statistics in memory which can then be printed to console by calling `profiler.dumps()`.

### Setup: Build a model

Let's build a small convolutional neural network that we can use for profiling.

```python
from mxnet import gluon
net = gluon.nn.HybridSequential()
with net.name_scope():
    net.add(gluon.nn.Conv2D(channels=20, kernel_size=5, activation='relu'))
    net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
    net.add(gluon.nn.Conv2D(channels=50, kernel_size=5, activation='relu'))
    net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
    net.add(gluon.nn.Flatten())
    net.add(gluon.nn.Dense(512, activation="relu"))
    net.add(gluon.nn.Dense(10))
```

We need data that we can run through the network for profiling. We'll use the MNIST dataset.

```python
from mxnet.gluon.data.vision import transforms
train_data = gluon.data.DataLoader(
    gluon.data.vision.MNIST(train=True).transform_first(transforms.ToTensor()),
    batch_size=64, shuffle=True)
```

Let's define a method that will run one training iteration given data and label.

```python
# Use GPU if available
ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()
# Reviewer note: mx.test_utils.list_gpus() may fail to detect GPUs on
# Windows because it relies on the nvidia-smi command (known issue).

# Initialize the parameters with random weights
net.collect_params().initialize(mx.init.Xavier(), ctx=ctx)

# Use SGD optimizer
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': .1})

# Softmax Cross Entropy is a frequently used loss function for multi-class classification
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()

# A helper function to run one training iteration
def run_training_iteration(data, label):

    # Load data and label into the right context
    data = data.as_in_context(ctx)
    label = label.as_in_context(ctx)

    # Run the forward pass
    with autograd.record():
        output = net(data)
        loss = softmax_cross_entropy(output, label)

    # Run the backward pass
    loss.backward()

    # Apply changes to parameters
    trainer.step(data.shape[0])
```

### Starting and stopping the profiler

When the first forward pass is run on a network, MXNet performs a number of housekeeping tasks, including inferring the shapes of various parameters and allocating memory for intermediate and final outputs. For these reasons, profiling the first iteration doesn't provide accurate results. We will, therefore, skip the first iteration.

```python
# Run the first iteration without profiling
itr = iter(train_data)
run_training_iteration(*next(itr))
```

We'll run the next iteration with the profiler turned on.

```python
data, label = next(itr)

# Ask the profiler to start recording
profiler.set_state('run')

run_training_iteration(data, label)

# Ask the profiler to stop recording
profiler.set_state('stop')
```

Between starting and stopping the profiler, you can also pause and resume it using `profiler.pause()` and `profiler.resume()` respectively, to profile only the parts of the code you are interested in.

### Viewing profiler output

There are two ways to view the information collected by the profiler. You can either view it in the console or you can view a more graphical version in a browser.

#### 1. View in console

You can use the `profiler.dumps()` method to view the information collected by the profiler in the console. The collected information contains the time taken by each operator, the time taken by each C API call, and the memory consumed on both CPU and GPU.

```python
print(profiler.dumps())
```

![Profile Statistics](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/tutorials/python/profiler/profile_stats.png)<!--notebook-skip-line-->

#### 2. View in browser

You can also dump the information collected by the profiler into a `json` file using the `profiler.dump()` function and view it in a browser.

```python
profiler.dump()
```

`dump()` creates a `json` file which can be viewed using a trace consumer like `chrome://tracing` in the Chrome browser. (The trailing `s` in `dumps()` indicates that it returns a string, like `pickle.dumps()` or `json.dumps()`, whereas `dump()` writes to a file.) Here is a snapshot that shows the output of the profiling we did above.

![Tracing Screenshot](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/tutorials/python/profiler/profiler_output_chrome.png)

Let's zoom in to check the time taken by individual operators.

![Operator profiling](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/tutorials/python/profiler/profile_operators.png)

The above picture visualizes the sequence in which the operators were executed and the time taken by each operator.

If you would like to learn more about the profiler, there are more examples available [here](https://github.com/apache/incubator-mxnet/tree/master/example/profiler).

<!-- INSERT SOURCE DOWNLOAD BUTTONS -->

3 changes: 3 additions & 0 deletions tests/tutorials/test_tutorials.py
@@ -213,3 +213,6 @@ def test_vision_large_scale_classification():

def test_vision_cnn_visualization():
assert _test_tutorial_nb('vision/cnn_visualization')

def test_python_profiler():
assert _test_tutorial_nb('python/profiler')