
torch.utils.tensorboard

Warning

This code is EXPERIMENTAL and might change in the future. It also currently does not support all model types for add_graph, which we are actively working on.

Before going further, see https://www.tensorflow.org/tensorboard/ for more details on TensorBoard.

Once you’ve installed TensorBoard, these utilities let you log PyTorch models and metrics into a directory for visualization within the TensorBoard UI. Scalars, images, histograms, graphs, and embedding visualizations are all supported for PyTorch models and tensors as well as Caffe2 nets and blobs.

The SummaryWriter class is your main entry point for logging data for consumption and visualization by TensorBoard. For example:

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms

# Writer will output to ./runs/ directory by default
writer = SummaryWriter()

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = datasets.MNIST('mnist_train', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
model = torchvision.models.resnet50()  # randomly initialized, no pretrained weights
# Have ResNet model take in grayscale rather than RGB
model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
images, labels = next(iter(trainloader))

grid = torchvision.utils.make_grid(images)
writer.add_image('images', grid, 0)
writer.add_graph(model, images)
writer.close()

This can then be visualized with TensorBoard, which can be installed and launched with:

pip install tensorboard
tensorboard --logdir=runs

class torch.utils.tensorboard.writer.SummaryWriter(log_dir=None, comment='', **kwargs)

Writes entries directly to event files in the log_dir to be consumed by TensorBoard.

The SummaryWriter class provides a high-level API to create an event file in a given directory and add summaries and events to it. The class updates the file contents asynchronously. This allows a training program to call methods to add data to the file directly from the training loop, without slowing down training.
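The sketch below shows the basic writer lifecycle: construct a SummaryWriter pointed at a directory, add a summary, and close it so pending events are flushed to disk. It assumes the tensorboard package is installed; the temporary directory and the "sanity/constant" tag are hypothetical choices for illustration.

```python
import os
import tempfile

from torch.utils.tensorboard import SummaryWriter

# Hypothetical location for this sketch; in practice you would usually keep
# the default ./runs/ directory or choose a project-specific path.
log_dir = os.path.join(tempfile.mkdtemp(), "experiment_1")

# The writer creates log_dir if needed and opens an event file inside it.
writer = SummaryWriter(log_dir=log_dir)
writer.add_scalar("sanity/constant", 1.0, 0)

# close() flushes anything still queued by the asynchronous writer and
# closes the event file; call it once logging is finished.
writer.close()

event_files = [f for f in os.listdir(log_dir) if f.startswith("events.out.tfevents")]
print(event_files)
```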

add_scalar(tag, scalar_value, global_step=None, walltime=None)

Add scalar data to summary.

Parameters
  • tag (string) – Data identifier

  • scalar_value (float or string/blobname) – Value to save

  • global_step (int) – Global step value to record

  • walltime (float) – Optional. Overrides the default walltime (time.time()) with the given number of seconds after the epoch of the event
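A minimal sketch of logging a scalar from a training loop, assuming the tensorboard package is installed. The "Loss/train" tag and the decaying loss values are made up for illustration. To confirm what was written, it reads the events back with EventAccumulator from the tensorboard package, an internal utility used here only as a check.

```python
import math
import tempfile

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
from torch.utils.tensorboard import SummaryWriter

log_dir = tempfile.mkdtemp()
writer = SummaryWriter(log_dir=log_dir)

for step in range(100):
    # Hypothetical metric: an exponentially decaying "loss".
    loss = math.exp(-step / 30.0)
    # Tags containing "/" are grouped into sections in the TensorBoard UI.
    writer.add_scalar("Loss/train", loss, global_step=step)

writer.close()

# Read the events back from disk to check what was logged.
acc = EventAccumulator(log_dir)
acc.Reload()
events = acc.Scalars("Loss/train")
print(len(events), events[0].step, events[-1].step)
```

Each returned event carries wall_time, step, and value, so the same check could also compare the logged values.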

add_histogram(tag, values, global_step=None, bins='tensorflow', walltime=None, max_bins=None)

Add histogram to summary.

Parameters
  • tag (string) – Data identifier

  • values (torch.Tensor, numpy.array, or string/blobname) – Values to build histogram

  • global_step (int) – Global step value to record

  • bins (string) – One of {‘tensorflow’, ‘auto’, ‘fd’, …}; this determines how the bins are computed. Other options are listed at: https://docs.scipy.org/doc/numpy/reference/generated/numpy.histogram.html

  • walltime (float) – Optional. Overrides the default walltime (time.time()) with the given number of seconds after the epoch of the event
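A sketch of logging a histogram once per step, assuming tensorboard is installed. The drifting random "weights" and the tag names are hypothetical; the second call passes one of NumPy's binning strategies through bins.

```python
import os
import tempfile

import torch
from torch.utils.tensorboard import SummaryWriter

log_dir = tempfile.mkdtemp()
writer = SummaryWriter(log_dir=log_dir)

torch.manual_seed(0)
for step in range(10):
    # Hypothetical "weights" whose mean drifts upward over time.
    values = torch.randn(1000) + step * 0.5
    # Default binning (bins='tensorflow').
    writer.add_histogram("weights/layer1", values, global_step=step)
    # NumPy's 'auto' strategy for the same data.
    writer.add_histogram("weights/layer1_auto", values, global_step=step, bins="auto")

writer.close()
```

In the TensorBoard UI, the histograms of successive steps are stacked, so the upward drift of the mean shows up as a shift across steps.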

add_image(tag, img_tensor, global_step=None, walltime=None, dataformats='CHW')

Add image data to summary.

Note that this requires the pillow package.

Parameters
  • tag (string) – Data identifier

  • img_tensor (torch.Tensor, numpy.array, or string/blobname) – Image data

  • global_step (int) – Global step value to record

  • walltime (float) – Optional. Overrides the default walltime (time.time()) with the given number of seconds after the epoch of the event

Shape:

img_tensor: Default is \((3, H, W)\). You can use torchvision.utils.make_grid() to convert a batch of tensors into 3xHxW format, or call add_images and let it do the job. Tensors with shape \((1, H, W)\), \((H, W)\), or \((H, W, 3)\) are also suitable, as long as the corresponding dataformats argument is passed, e.g. CHW, HWC, or HW.
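The following sketch illustrates the dataformats argument with one CHW tensor, one HWC array, and one HW array. It assumes tensorboard and pillow are installed; the gradient and noise images and their tags are hypothetical. The float images are kept in [0, 1], the range expected for float image data.

```python
import os
import tempfile

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

log_dir = tempfile.mkdtemp()
writer = SummaryWriter(log_dir=log_dir)

# A CHW float tensor (the default dataformats='CHW'):
# the red channel ramps from 0 to 1 from left to right.
img_chw = torch.zeros(3, 64, 64)
img_chw[0] = torch.linspace(0, 1, 64).unsqueeze(0).expand(64, 64)
writer.add_image("gradient_chw", img_chw, 0)

# The same image as an HWC NumPy array, so dataformats='HWC' is required.
img_hwc = img_chw.permute(1, 2, 0).numpy()
writer.add_image("gradient_hwc", img_hwc, 0, dataformats="HWC")

# A single-channel HW array of random noise, so dataformats='HW' is required.
img_hw = np.random.rand(64, 64)
writer.add_image("noise_hw", img_hw, 0, dataformats="HW")

writer.close()
```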

add_figure(tag, figure, global_step=None, close=True, walltime=None)

Render matplotlib figure into an image and add it to summary.

Note that this requires the matplotlib package.

Parameters
  • tag (string) – Data identifier

  • figure (matplotlib.pyplot.figure) – figure or a list of figures

  • global_step (int) – Global step value to record

  • close (bool) – Flag to automatically close the figure

  • walltime (float) – Optional. Overrides the default walltime (time.time()) with the given number of seconds after the epoch of the event
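A sketch of logging a matplotlib figure, assuming matplotlib and tensorboard are installed. The Agg backend is selected so that no display is needed; the plotted data and the tag are hypothetical.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend; renders without a display
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

log_dir = tempfile.mkdtemp()
writer = SummaryWriter(log_dir=log_dir)

fig, ax = plt.subplots()
ax.plot([0, 1, 2, 3], [0, 1, 4, 9], marker="o")
ax.set_title("y = x^2")

# The figure is rendered to an image; with close=True (the default)
# it is also closed afterwards, which avoids piling up open figures.
writer.add_figure("plots/quadratic", fig, global_step=0)
writer.close()
```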

add_video(tag, vid_tensor, global_step=None, fps=4, walltime=None)

Add video data to summary.

Note that this requires the moviepy package.

Parameters
  • tag (string) – Data identifier

  • vid_tensor (torch.Tensor) – Video data

  • global_step (int) – Global step value to record

  • fps (float or int) – Frames per second

  • walltime (float) – Optional. Overrides the default walltime (time.time()) with the given number of seconds after the epoch of the event

Shape:

vid_tensor: \((N, T, C, H, W)\). The values should lie in [0, 255] for type uint8 or [0, 1] for type float.

add_audio(tag, snd_tensor, global_step=None, sample_rate=44100, walltime=None)

Add audio data to summary.

Parameters
  • tag (string) – Data identifier

  • snd_tensor (torch.Tensor) – Sound data

  • global_step (int) – Global step value to record

  • sample_rate (int) – sample rate in Hz

  • walltime (float) – Optional. Overrides the default walltime (time.time()) with the given number of seconds after the epoch of the event

Shape:

snd_tensor: \((1, L)\). The values should lie between [-1, 1].
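A sketch that logs a one-second 440 Hz sine tone, assuming tensorboard is installed. The tone, its amplitude, and the 16 kHz sample rate are illustrative choices; the tensor is shaped (1, L) with values inside [-1, 1], as the shape note above requires.

```python
import math
import os
import tempfile

import torch
from torch.utils.tensorboard import SummaryWriter

log_dir = tempfile.mkdtemp()
writer = SummaryWriter(log_dir=log_dir)

sample_rate = 16000  # Hz; the add_audio default is 44100
duration_s = 1.0
# Sample times in seconds: 0, 1/sample_rate, 2/sample_rate, ...
t = torch.arange(int(sample_rate * duration_s), dtype=torch.float32) / sample_rate
# A 440 Hz sine at half amplitude, reshaped from (L,) to (1, L).
tone = (0.5 * torch.sin(2 * math.pi * 440.0 * t)).unsqueeze(0)

writer.add_audio("tones/a440", tone, global_step=0, sample_rate=sample_rate)
writer.close()
```

The sample_rate passed to add_audio should match the rate used to generate the samples; otherwise playback in TensorBoard will be sped up or slowed down.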

add_text(tag, text_string, global_step=None, walltime=None)

Add text data to summary.

Parameters
  • tag (string) – Data identifier

  • text_string (string) – String to save

  • global_step (int) – Global step value to record

  • walltime (float) – Optional. Overrides the default walltime (time.time()) with the given number of seconds after the epoch of the event

Examples:

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
writer.add_text('lstm', 'This is an lstm', 0)
writer.add_text('rnn', 'This is an rnn', 10)

add_graph(model, input_to_model=None, verbose=False, **kwargs)

Add graph data to summary.

Parameters
  • model (torch.nn.Module) – model to draw.

  • input_to_model (torch.Tensor or list of torch.Tensor) – A tensor or a tuple of tensors to be fed to the model.

  • verbose (bool) – Whether to print graph structure in console.

  • omit_useless_nodes (bool) – Defaults to True, which eliminates unused nodes.

  • operator_export_type (string) – One of: "ONNX", "RAW". This determines the optimization level of the graph. If an error occurs while exporting the graph, using "RAW" may help.
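A sketch of add_graph on a small network, assuming tensorboard is installed. TinyNet is a hypothetical model defined only for this example. add_graph traces the model by running the dummy input through it, so the input must have a shape the model accepts.

```python
import os
import tempfile

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter


class TinyNet(nn.Module):
    """Hypothetical two-layer network used only to illustrate add_graph."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


log_dir = tempfile.mkdtemp()
writer = SummaryWriter(log_dir=log_dir)

model = TinyNet()
dummy_input = torch.randn(1, 4)  # batch of one 4-feature example
writer.add_graph(model, dummy_input)
writer.close()
```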

add_embedding(mat, metadata=None, label_img=None, global_step=None, tag='default', metadata_header=None)

Add embedding projector data to summary.

Parameters
  • mat (torch.Tensor or numpy.array) – A matrix in which each row is the feature vector of a data point

  • metadata (list) – A list of labels; each element will be converted to a string

  • label_img (torch.Tensor) – Images corresponding to each data point

  • global_step (int) – Global step value to record

  • tag (string) – Name for the embedding

Shape:

mat: \((N, D)\), where N is the number of data points and D is the feature dimension

label_img: \((N, C, H, W)\)

Examples:

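import keyword
import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()

# Build 100 string labels by cycling through Python's keywords
meta = []
while len(meta) < 100:
    meta = meta + keyword.kwlist
meta = meta[:100]

# Make each label unique by appending its index
for i, v in enumerate(meta):
    meta[i] = v + str(i)

# One small image per data point, darkened progressively by index
label_img = torch.rand(100, 3, 10, 32)
for i in range(100):
    label_img[i] *= i / 100.0

# Distinct tags keep the three embeddings from overwriting each other
writer.add_embedding(torch.randn(100, 5), metadata=meta, label_img=label_img, tag='meta_and_img')
writer.add_embedding(torch.randn(100, 5), label_img=label_img, tag='img_only')
writer.add_embedding(torch.randn(100, 5), metadata=meta, tag='meta_only')
writer.close()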

add_pr_curve(tag, labels, predictions, global_step=None, num_thresholds=127, weights=None, walltime=None)

Adds a precision-recall curve.

Parameters
  • tag (string) – Data identifier

  • labels (torch.Tensor, numpy.array, or string/blobname) – Ground truth data. Binary label for each element.

  • predictions (torch.Tensor, numpy.array, or string/blobname) – The probability that an element is classified as true. Values should lie in [0, 1].

  • global_step (int) – Global step value to record

  • num_thresholds (int) – Number of thresholds used to draw the curve.

  • walltime (float) – Optional. Overrides the default walltime (time.time()) with the given number of seconds after the epoch of the event
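A sketch of add_pr_curve for a binary classifier, assuming tensorboard and NumPy are installed. The labels and scores are synthetic: the scores are random but shifted upward for positive examples, and they stay inside [0, 1] as the predictions parameter requires.

```python
import os
import tempfile

import numpy as np
from torch.utils.tensorboard import SummaryWriter

rng = np.random.default_rng(0)
labels = rng.integers(0, 2, size=100)  # binary ground truth, 0 or 1
# Hypothetical classifier scores: noisy, but higher on average for label 1.
predictions = np.clip(0.6 * labels + 0.4 * rng.random(100), 0.0, 1.0)

log_dir = tempfile.mkdtemp()
writer = SummaryWriter(log_dir=log_dir)
writer.add_pr_curve("pr/classifier", labels, predictions, global_step=0)
writer.close()
```

Logging the curve at several global_step values (for example once per epoch) lets the TensorBoard PR-curve view show how the tradeoff changes during training.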

add_custom_scalars(layout)

Creates a special chart by collecting chart tags in ‘scalars’. Note that this function can only be called once per SummaryWriter object. Because it only provides metadata to TensorBoard, it can be called before or after the training loop.

Parameters

layout (dict) – {categoryName: charts}, where charts is also a dictionary {chartName: ListOfProperties}. The first element in ListOfProperties is the chart’s type (one of Multiline or Margin), and the second element should be a list containing the tags you have used in the add_scalar function, which will be collected into the new chart.

Examples:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()
layout = {'Taiwan': {'twse': ['Multiline', ['twse/0050', 'twse/2330']]},
          'USA': {'dow': ['Margin', ['dow/aaa', 'dow/bbb', 'dow/ccc']],
                  'nasdaq': ['Margin', ['nasdaq/aaa', 'nasdaq/bbb', 'nasdaq/ccc']]}}

writer.add_custom_scalars(layout)
writer.close()
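The layout only tells TensorBoard how to group tags; the charts stay empty until scalars are logged under those tags. The sketch below, assuming tensorboard is installed, pairs a hypothetical layout with matching add_scalar calls. The Margin chart is given three tags (value, lower bound, upper bound), following the three-tag structure of the Margin entries in the layout example.

```python
import math
import os
import tempfile

from torch.utils.tensorboard import SummaryWriter

log_dir = tempfile.mkdtemp()
writer = SummaryWriter(log_dir=log_dir)

# Hypothetical layout: one category holding a multiline and a margin chart.
layout = {
    "Training": {
        "loss": ["Multiline", ["loss/train", "loss/val"]],
        # Margin: [value tag, lower-bound tag, upper-bound tag]
        "accuracy": ["Margin", ["acc/mean", "acc/low", "acc/high"]],
    },
}
writer.add_custom_scalars(layout)  # only once per writer

for step in range(50):
    writer.add_scalar("loss/train", math.exp(-step / 20.0), step)
    writer.add_scalar("loss/val", 1.1 * math.exp(-step / 20.0), step)
    acc = 1.0 - math.exp(-step / 15.0)
    writer.add_scalar("acc/mean", acc, step)
    writer.add_scalar("acc/low", acc - 0.05, step)
    writer.add_scalar("acc/high", acc + 0.05, step)

writer.close()
```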
