
Rust bindings for PyTorch

License: Apache-2.0 and MIT (LICENSE-APACHE, LICENSE-MIT)

sailfish009/tch-rs

 
 


tch-rs

Rust bindings for PyTorch. The goal of the tch crate is to provide thin wrappers around the C++ PyTorch API (a.k.a. libtorch). It aims to stay as close as possible to the original C++ API. More idiomatic Rust bindings could then be developed on top of it. The documentation can be found on docs.rs.


The code generation for the C API on top of libtorch comes from ocaml-torch.

Getting Started

This crate requires the C++ PyTorch library (libtorch) in version v1.3.0 to be available on your system. You can either install it manually and point the build script to it via the LIBTORCH environment variable, or leave LIBTORCH unset, in which case the build script tries to download and extract a pre-built binary version of libtorch.
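The choice described above can be sketched in plain Rust. This is a hypothetical illustration, not tch's actual build script: LibtorchSource and libtorch_source are names made up for this sketch, and the only behaviour modelled is the one stated here (use LIBTORCH if it is set, otherwise download a pre-built binary).

```rust
use std::path::PathBuf;

/// Hypothetical: where libtorch would be taken from (not a tch type).
#[derive(Debug, PartialEq)]
pub enum LibtorchSource {
    /// LIBTORCH is set: use the manually installed copy at this path.
    Manual(PathBuf),
    /// LIBTORCH is not set: fall back to a pre-built download.
    Download,
}

/// Decide the source from the value of the LIBTORCH variable, if any.
pub fn libtorch_source(libtorch_var: Option<&str>) -> LibtorchSource {
    match libtorch_var {
        Some(path) => LibtorchSource::Manual(PathBuf::from(path)),
        None => LibtorchSource::Download,
    }
}

/// Same decision, reading LIBTORCH from the current process environment.
pub fn libtorch_source_from_env() -> LibtorchSource {
    let var = std::env::var("LIBTORCH").ok();
    libtorch_source(var.as_deref())
}
```

Keeping the decision in a pure function that takes an Option<&str> makes it testable without touching the real environment.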

Libtorch Manual Install

  • Get libtorch from the PyTorch website download section and extract the content of the zip file.
  • For Linux users, add the following to your .bashrc or equivalent, where /path/to/libtorch is the path to the directory that was created when unzipping the file.
export LIBTORCH=/path/to/libtorch
export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH
  • For Windows users, navigate to Control Panel -> View advanced system settings -> Environment variables. Create a LIBTORCH variable and set it to X:\path\to\libtorch, where X:\path\to\libtorch is the unzipped libtorch directory. Then append X:\path\to\libtorch\lib (the lib directory, as on Linux) to the Path variable. To set these variables only temporarily, you can run the following in PowerShell:
$Env:LIBTORCH = "X:\path\to\libtorch"
$Env:Path += ";X:\path\to\libtorch\lib"
  • You should now be able to run some examples, e.g. cargo run --example basics.
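Before running the examples, it can help to check that the Linux setup above took effect. The sketch below uses only the standard library to test whether ${LIBTORCH}/lib is one of the entries of a PATH-style list such as LD_LIBRARY_PATH. The function names are hypothetical, and the separator is passed explicitly (':' on Linux, ';' on Windows) so the logic does not depend on the host platform.

```rust
use std::path::{Path, PathBuf};

/// Hypothetical helper: the directory the Linux instructions add to
/// LD_LIBRARY_PATH, i.e. ${LIBTORCH}/lib.
pub fn libtorch_lib_dir(libtorch: &str) -> PathBuf {
    Path::new(libtorch).join("lib")
}

/// Hypothetical helper: true if `lib_dir` is one of the entries of a
/// PATH-style list. Entries are compared as paths, not strings, so a
/// trailing slash such as "/opt/libtorch/lib/" still matches.
pub fn dir_in_path_list(path_list: &str, lib_dir: &Path, sep: char) -> bool {
    path_list
        .split(sep)
        .filter(|entry| !entry.is_empty()) // skip empty entries like "a::b"
        .any(|entry| Path::new(entry) == lib_dir)
}
```

In a real check, path_list would come from std::env::var("LD_LIBRARY_PATH") and the libtorch directory from std::env::var("LIBTORCH").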

Examples

Basic Tensor Operations

This crate provides a tensor type which wraps PyTorch tensors. Here is a minimal example of how to perform some tensor operations.

extern crate tch;
use tch::Tensor;

fn main() {
    let t = Tensor::of_slice(&[3, 1, 4, 1, 5]);
    let t = t * 2;
    t.print();
}

Training a Model via Gradient Descent

PyTorch provides automatic differentiation for most tensor operations it supports. This is commonly used to train models using gradient descent. The optimization is performed over variables which are created via a nn::VarStore by defining their shapes and initializations.

In the example below, my_module uses two variables x1 and x2 whose initial values are 0. The forward pass applied to tensor xs returns xs * x1 + exp(xs) * x2.

Once the model has been generated, a nn::Sgd optimizer is created. Then on each step of the training loop:

  • The forward pass is applied to a mini-batch of data.
  • A loss is computed as the sum of squared errors between the model output and the mini-batch ground truth.
  • Finally, an optimization step is performed: gradients are computed and the variables in the VarStore are updated accordingly.
use tch::{kind, nn, nn::Module, nn::OptimizerConfig, Device, Tensor};

fn my_module(p: nn::Path, dim: i64) -> impl nn::Module {
    let x1 = p.zeros("x1", &[dim]);
    let x2 = p.zeros("x2", &[dim]);
    nn::func(move |xs| xs * &x1 + xs.exp() * &x2)
}

fn gradient_descent() {
    let vs = nn::VarStore::new(Device::Cpu);
    let my_module = my_module(vs.root(), 7);
    let opt = nn::Sgd::default().build(&vs, 1e-2).unwrap();
    for _idx in 1..50 {
        // Dummy mini-batches made of zeros.
        let xs = Tensor::zeros(&[7], kind::FLOAT_CPU);
        let ys = Tensor::zeros(&[7], kind::FLOAT_CPU);
        let loss = (my_module.forward(&xs) - ys).pow(2).sum();
        opt.backward_step(&loss);
    }
}
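To make the update rule concrete, here is the same model trained with plain SGD in dependency-free Rust. Unlike the tch version, which relies on PyTorch's automatic differentiation, this sketch derives the gradients of the summed squared error by hand. It also uses a small non-zero dataset (an assumption made here), because with the all-zero mini-batches above the loss is already 0 at initialization and the variables never move. The helper names forward, loss, sgd_step and train are hypothetical.

```rust
/// Forward pass of the toy model, element-wise: xs * x1 + exp(xs) * x2.
pub fn forward(xs: &[f64], x1: &[f64], x2: &[f64]) -> Vec<f64> {
    xs.iter()
        .zip(x1.iter().zip(x2))
        .map(|(x, (a, b))| x * a + x.exp() * b)
        .collect()
}

/// Sum of squared errors, mirroring `.pow(2).sum()` in the example above.
pub fn loss(pred: &[f64], ys: &[f64]) -> f64 {
    pred.iter().zip(ys).map(|(p, y)| (p - y).powi(2)).sum()
}

/// One SGD step with hand-derived gradients:
/// dL/dx1_i = 2 (pred_i - y_i) xs_i and dL/dx2_i = 2 (pred_i - y_i) exp(xs_i).
pub fn sgd_step(xs: &[f64], ys: &[f64], x1: &mut [f64], x2: &mut [f64], lr: f64) {
    let pred = forward(xs, x1, x2);
    for i in 0..xs.len() {
        let err = 2.0 * (pred[i] - ys[i]);
        x1[i] -= lr * err * xs[i];
        x2[i] -= lr * err * xs[i].exp();
    }
}

/// Train from zero-initialized variables, as in my_module, and return
/// the loss before and after training.
pub fn train(xs: &[f64], ys: &[f64], lr: f64, steps: usize) -> (f64, f64) {
    let mut x1 = vec![0.0; xs.len()];
    let mut x2 = vec![0.0; xs.len()];
    let initial = loss(&forward(xs, &x1, &x2), ys);
    for _ in 0..steps {
        sgd_step(xs, ys, &mut x1, &mut x2, lr);
    }
    (initial, loss(&forward(xs, &x1, &x2), ys))
}
```

For example, train(&[-1.0, 0.0, 0.5, 1.0], &[1.0, 2.0, 3.0, 4.0], 0.01, 200) starts from a loss of 30 and shrinks it steadily. Because each element's error is scaled by 1 - 2 * lr * (x^2 + exp(2x)) per step, the learning rate has to stay below 1 / (x^2 + exp(2x)), roughly 0.12 for an input of 1, or the updates overshoot.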

Writing a Simple Neural Network

The nn API can be used to create neural network architectures. For example, the following code defines a simple model with one hidden layer and trains it on the MNIST dataset using the Adam optimizer.

extern crate tch;
use tch::{nn, nn::Module, nn::OptimizerConfig, Device};

const IMAGE_DIM: i64 = 784;
const HIDDEN_NODES: i64 = 128;
const LABELS: i64 = 10;

fn net(vs: &nn::Path) -> impl Module {
    nn::seq()
        .add(nn::linear(vs / "layer1", IMAGE_DIM, HIDDEN_NODES, Default::default()))
        .add_fn(|xs| xs.relu())
        .add(nn::linear(vs, HIDDEN_NODES, LABELS, Default::default()))
}

pub fn run() -> failure::Fallible<()> {
    let m = tch::vision::mnist::load_dir("data")?;
    let vs = nn::VarStore::new(Device::Cpu);
    let net = net(&vs.root());
    let opt = nn::Adam::default().build(&vs, 1e-3)?;
    for epoch in 1..200 {
        let loss = net
            .forward(&m.train_images)
            .cross_entropy_for_logits(&m.train_labels);
        opt.backward_step(&loss);
        let test_accuracy = net
            .forward(&m.test_images)
            .accuracy_for_logits(&m.test_labels);
        println!(
            "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
            epoch,
            f64::from(&loss),
            100. * f64::from(&test_accuracy),
        );
    }
    Ok(())
}

More details on the training loop can be found in the detailed tutorial.
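The two tensor methods used in the loop, cross_entropy_for_logits and accuracy_for_logits, can be read as follows. This plain-Rust sketch assumes the standard definitions (the mean negative log-likelihood of the true label under a softmax of the logits, and the fraction of rows whose largest logit is at the true label). It is an illustration, not tch's implementation, and it works on a slice of rows instead of a tensor.

```rust
/// Numerically stable log-softmax of one row of logits.
fn log_softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let log_sum_exp = max + logits.iter().map(|l| (l - max).exp()).sum::<f64>().ln();
    logits.iter().map(|l| l - log_sum_exp).collect()
}

/// Index of the largest value in a row (first one wins on ties).
fn argmax(row: &[f64]) -> usize {
    row.iter()
        .enumerate()
        .fold((0, f64::NEG_INFINITY), |(bi, bv), (i, &v)| {
            if v > bv { (i, v) } else { (bi, bv) }
        })
        .0
}

/// Mean negative log-likelihood of the true labels under softmax(logits).
pub fn cross_entropy_for_logits(logits: &[Vec<f64>], labels: &[usize]) -> f64 {
    let total: f64 = logits
        .iter()
        .zip(labels)
        .map(|(row, &label)| -log_softmax(row)[label])
        .sum();
    total / labels.len() as f64
}

/// Fraction of rows whose highest logit is at the true label.
pub fn accuracy_for_logits(logits: &[Vec<f64>], labels: &[usize]) -> f64 {
    let correct = logits
        .iter()
        .zip(labels)
        .filter(|&(row, &label)| argmax(row) == label)
        .count();
    correct as f64 / labels.len() as f64
}
```

Subtracting the row maximum before exponentiating keeps the log-sum-exp finite even for large logits, which is why a "for logits" loss is usually preferred over applying softmax and log separately.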

Using a Pre-Trained Model

The pretrained-models example illustrates how to run a pre-trained computer vision model on an image. The weights, which were extracted from the PyTorch implementation, can be downloaded here: resnet18.ot and resnet34.ot.

The example can then be run via the following command:

cargo run --example pretrained-models -- resnet18.ot tiger.jpg

This should print the top 5 ImageNet categories for the image. The core of the example is short:

    // First the image is loaded and resized to 224x224.
    let image = imagenet::load_image_and_resize(image_file)?;

    // A variable store is created to hold the model parameters.
    let vs = tch::nn::VarStore::new(tch::Device::Cpu);

    // Then the model is built on this variable store, and the weights are loaded.
    let resnet18 = tch::vision::resnet::resnet18(vs.root(), imagenet::CLASS_COUNT);
    vs.load(weight_file)?;

    // Apply the forward pass of the model to get the logits and convert them
    // to probabilities via a softmax.
    let output = resnet18
        .forward_t(&image.unsqueeze(0), /*train=*/ false)
        .softmax(-1);

    // Finally print the top 5 categories and their associated probabilities.
    for (probability, class) in imagenet::top(&output, 5).iter() {
        println!("{:50} {:5.2}%", class, 100.0 * probability)
    }
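The last two steps, a softmax over the logits followed by picking the five most likely classes, can be sketched without tch. Here softmax and top_k are hypothetical helpers standing in for the tensor softmax and imagenet::top used above, and the class names are passed in as a slice of strings instead of the ImageNet label table.

```rust
/// Softmax over one row of logits, shifted by the max for numerical stability.
pub fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|l| (l - max).exp()).collect();
    let total: f64 = exps.iter().sum();
    exps.into_iter().map(|e| e / total).collect()
}

/// The k most probable classes as (probability, class name), best first.
pub fn top_k<'a>(probs: &[f64], class_names: &[&'a str], k: usize) -> Vec<(f64, &'a str)> {
    let mut pairs: Vec<(f64, &'a str)> = probs
        .iter()
        .cloned()
        .zip(class_names.iter().cloned())
        .collect();
    // Descending order; total_cmp gives a total order on f64, so no unwrap is needed.
    pairs.sort_by(|a, b| b.0.total_cmp(&a.0));
    pairs.truncate(k);
    pairs
}
```

Printing each pair with the same "{:50} {:5.2}%" format and 100.0 * probability mirrors the layout of the example's output.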

Further examples include:

External material:

  • A tutorial showing how to use Torch to compute option prices and greeks.

License

tch-rs is distributed under the terms of both the MIT license and the Apache license (version 2.0), at your option.

See LICENSE-APACHE, LICENSE-MIT for more details.

Languages

  • Rust 57.4%
  • C 26.3%
  • C++ 15.1%
  • OCaml 1.1%
  • Python 0.1%
  • CMake 0.0%