From 5c532ec5dc51cd17cd4bb9ae940ecf2c9baf89f6 Mon Sep 17 00:00:00 2001 From: rainiwu Date: Fri, 26 Jan 2024 00:29:35 -0800 Subject: [PATCH 01/10] remove deprecated ftz intrinsics --- dfdx-core/src/lib.rs | 38 -------------------------------------- dfdx/examples/12-mnist.rs | 3 --- 2 files changed, 41 deletions(-) diff --git a/dfdx-core/src/lib.rs b/dfdx-core/src/lib.rs index 31e61643..c126db2c 100644 --- a/dfdx-core/src/lib.rs +++ b/dfdx-core/src/lib.rs @@ -128,44 +128,6 @@ pub mod prelude { pub use crate::tensor_ops::*; } -/// Sets a CPU `sse` flag to flush denormal floating point numbers to zero. The opposite of this is [keep_denormals()]. -/// -/// Some resources: -/// 1. [Effects of Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/the-effects-of-using-flush-to-zero-mode?lang=en) -/// 2. [When to use Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/when-to-use-flush-to-zero-mode?lang=en) -pub fn flush_denormals_to_zero() { - #[cfg(all(target_arch = "x86", target_feature = "sse"))] - { - use std::arch::x86::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE}; - unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) } - } - - #[cfg(all(target_arch = "x86_64", target_feature = "sse"))] - { - use std::arch::x86_64::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE}; - unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) } - } -} - -/// Sets a CPU flag to keep denormal floating point numbers. The opposite of this is [flush_denormals_to_zero()]. -/// -/// Some resources: -/// 1. [Effects of Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/the-effects-of-using-flush-to-zero-mode?lang=en) -/// 2. [When to use Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/when-to-use-flush-to-zero-mode?lang=en) -pub fn keep_denormals() { - #[cfg(all(target_arch = "x86", target_feature = "sse"))] - { - use std::arch::x86::{_MM_FLUSH_ZERO_OFF, _MM_SET_FLUSH_ZERO_MODE}; - unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF) } - } - - #[cfg(all(target_arch = "x86_64", target_feature = "sse"))] - { - use std::arch::x86_64::{_MM_FLUSH_ZERO_OFF, _MM_SET_FLUSH_ZERO_MODE}; - unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF) } - } -} - #[cfg(test)] pub(crate) mod tests { pub use num_traits::{Float, NumCast, Zero}; diff --git a/dfdx/examples/12-mnist.rs b/dfdx/examples/12-mnist.rs index 705d14c8..00d43452 100644 --- a/dfdx/examples/12-mnist.rs +++ b/dfdx/examples/12-mnist.rs @@ -62,9 +62,6 @@ type Mlp = ( const BATCH_SIZE: usize = 32; fn main() { - // ftz substantially improves performance - dfdx::flush_denormals_to_zero(); - let mnist_path = std::env::args() .nth(1) .unwrap_or_else(|| "./datasets/MNIST/raw".to_string()); From fb91f13314fb24a67c2d8e14ad40345d2d334805 Mon Sep 17 00:00:00 2001 From: rainiwu Date: Fri, 26 Jan 2024 00:55:48 -0800 Subject: [PATCH 02/10] suppress spurious cargo clippy warning --- dfdx-core/src/data/collate.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/dfdx-core/src/data/collate.rs b/dfdx-core/src/data/collate.rs index d38a2a67..5f52d636 100644 --- a/dfdx-core/src/data/collate.rs +++ b/dfdx-core/src/data/collate.rs @@ -55,6 +55,7 @@ impl Collate for Vec<(A, B)> { impl<'a, A, B> Collate for Vec<&'a (A, B)> { type Collated = (Vec<&'a A>, Vec<&'a B>); fn collated(self) -> Self::Collated { + #[allow(clippy::map_identity)] self.into_iter().map(|(a, b)| (a, b)).unzip() } } From 4e3f7c7a24728668f72cf3617a66f4476280f6fb Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Tue, 6 Feb 2024 18:27:46 -0500 Subject: [PATCH 03/10] avoid conv1d bound for cudnn --- dfdx-core/src/tensor_ops/utilities/device.rs | 50 +++++++++++++++----- 1 file changed, 39 insertions(+), 11 deletions(-) diff --git a/dfdx-core/src/tensor_ops/utilities/device.rs b/dfdx-core/src/tensor_ops/utilities/device.rs index 8cbc2137..91f87cf6 100644 --- a/dfdx-core/src/tensor_ops/utilities/device.rs +++ b/dfdx-core/src/tensor_ops/utilities/device.rs @@ -114,25 +114,49 @@ pub trait Device: + crate::tensor_ops::axpy::AxpyKernel // conv1d - + super::super::conv1d::Conv1DKernel + + NonCudnnCuda +{ +} + +#[cfg(feature = "cudnn")] +pub trait NonCudnnCuda {} + +#[cfg(not(feature = "cudnn"))] +pub trait NonCudnnCuda: + // conv1d + super::super::conv1d::Conv1DKernel { } #[cfg(feature = "f16")] -impl Device for crate::tensor::Cpu {} -#[cfg(feature = "f16")] -impl Device> for crate::tensor::Cpu {} +mod f16_ { + use super::*; + impl Device for crate::tensor::Cpu {} + impl NonCudnnCuda for crate::tensor::Cpu {} + impl Device> for crate::tensor::Cpu {} + impl NonCudnnCuda> for crate::tensor::Cpu {} +} impl Device for crate::tensor::Cpu {} +impl NonCudnnCuda for crate::tensor::Cpu {} impl Device for crate::tensor::Cpu {} +impl NonCudnnCuda for crate::tensor::Cpu {} #[cfg(all(feature = "cuda", feature = "f16"))] -impl Device for crate::tensor::Cuda {} -#[cfg(all(feature = "cuda", feature = "f16"))] -impl Device> for crate::tensor::Cuda {} -#[cfg(feature = "cuda")] -impl Device for crate::tensor::Cuda {} +mod cuda_f16 { + use super::*; + impl Device for crate::tensor::Cuda {} + impl NonCudnnCuda for crate::tensor::Cuda {} + impl Device> for crate::tensor::Cuda {} + impl NonCudnnCuda> for crate::tensor::Cuda {} +} #[cfg(feature = "cuda")] -impl Device for crate::tensor::Cuda {} +mod cuda { + use super::*; + impl Device for crate::tensor::Cuda {} + impl NonCudnnCuda for crate::tensor::Cuda {} + impl Device for crate::tensor::Cuda {} + impl NonCudnnCuda for crate::tensor::Cuda {} +} // TODO: How can we implement this for f16 when WGSL doesn't support f16 yet? // #[cfg(all(feature = "webgpu", feature = "f16"))] @@ -140,7 +164,11 @@ impl Device for crate::tensor::Cuda {} // #[cfg(all(feature = "webgpu", feature = "f16"))] // impl Device> for crate::tensor::Webgpu {} #[cfg(feature = "webgpu")] -impl Device for crate::tensor::Webgpu {} +mod webgpu { + use super::*; + impl Device for crate::tensor::Webgpu {} + impl NonCudnnCuda for crate::tensor::Webgpu {} +} // TODO: How can we implement this for f64 when WGSL doesn't support f64 yet? // #[cfg(feature = "webgpu")] From a8bc54c5c8e02c68fe09e72fc94ba0a8b3273b9a Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Fri, 9 Feb 2024 11:53:40 -0500 Subject: [PATCH 04/10] bump gemm --- dfdx-core/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dfdx-core/Cargo.toml b/dfdx-core/Cargo.toml index 5309ef7c..0f6cd5c6 100644 --- a/dfdx-core/Cargo.toml +++ b/dfdx-core/Cargo.toml @@ -35,7 +35,7 @@ num-traits = { workspace = true } safetensors = { workspace = true, optional = true } memmap2 = { workspace = true, optional = true } half = { version = "2.3.1", optional = true, features = ["num-traits", "rand_distr"] } -gemm = { version = "0.16.14", default-features = false, optional = true, features = ["rayon"] } +gemm = { version = "0.17.1", default-features = false, optional = true, features = ["rayon"] } rayon = { version = "1.7.0", optional = true } libm = { workspace = true } wgpu = { version = "0.18.0", features = ["glsl", "spirv"], optional = true } From 557687c0a9e29dfba2311fe67414863c6c5137bf Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Fri, 9 Feb 2024 12:52:05 -0500 Subject: [PATCH 05/10] clippy fix --- dfdx-core/src/tensor/gradients.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dfdx-core/src/tensor/gradients.rs b/dfdx-core/src/tensor/gradients.rs index 86974ec6..d24e2e32 100644 --- a/dfdx-core/src/tensor/gradients.rs +++ b/dfdx-core/src/tensor/gradients.rs @@ -153,7 +153,7 @@ impl> Gradients { #[inline] pub(crate) fn many_and_ref( &mut self, - ls: &Vec>, + ls: &[impl Tensorlike], r: &impl Tensorlike, ) -> (Vec<&mut D::Vec>, &D::Vec) { for i in 0..ls.len() { From ea879a155ef38a68973ea512226b73ab07860176 Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Tue, 5 Dec 2023 16:58:20 -0500 Subject: [PATCH 06/10] reduce tensor sizes for 'slow' tests --- dfdx/src/nn/layers/conv1d.rs | 59 +++++---- dfdx/src/nn/layers/conv2d.rs | 67 +++++----- dfdx/src/nn/layers/conv_trans2d.rs | 30 +++-- dfdx/src/nn/layers/multi_head_attention.rs | 57 ++++---- dfdx/src/nn/layers/transformer.rs | 145 +++++++++------------ 5 files changed, 171 insertions(+), 187 deletions(-) diff --git a/dfdx/src/nn/layers/conv1d.rs b/dfdx/src/nn/layers/conv1d.rs index 5241b912..0986d1af 100644 --- a/dfdx/src/nn/layers/conv1d.rs +++ b/dfdx/src/nn/layers/conv1d.rs @@ -174,47 +174,50 @@ mod tests { fn test_grouped_forward_sizes() { let dev: TestDevice = Default::default(); - let x = dev.ones::>(); + let x = dev.ones::>(); - let m = dev.build_module::(>::default()); - let _: Tensor, _, _> = m.weight; - let _: Tensor, _, _> = m.forward(x.clone()); + let m = dev.build_module::(>::default()); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); - let m = - dev.build_module::(>::default()); - let _: Tensor, _, _> = m.weight; - let _: Tensor, _, _> = m.forward(x.clone()); + let m = dev.build_module::(>::default()); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); - let m = - dev.build_module::(>::default()); - let _: Tensor, _, _> = m.weight; - let _: Tensor, _, _> = m.forward(x.clone()); + let x = dev.ones::>(); - let m = - dev.build_module::(>::default()); - let _: Tensor, _, _> = m.weight; - let _: Tensor, _, _> = m.forward(x.clone()); + let m = dev.build_module::(>::default()); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); + + let x = dev.ones::>(); + + let m = dev.build_module::(>::default()); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); + + let x = dev.ones::>(); let m = dev.build_module::(>::default()); let _: Tensor, _, _> = m.weight; - let _: Tensor, _, _> = m.forward(x); + let _: Tensor, _, _> = m.forward(x); } #[rustfmt::skip] #[test] fn test_forward_4d_sizes() { let dev: TestDevice = Default::default(); - let x = dev.zeros::>(); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let x = dev.zeros::>(); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); } #[test] @@ -248,7 +251,7 @@ mod tests { let weight_init = m.weight.clone(); let mut opt = crate::nn::optim::Sgd::new(&m, Default::default()); - let out = m.forward(dev.sample_normal::>().leaky_trace()); + let out = m.forward(dev.sample_normal::>().leaky_trace()); let g = out.square().mean().backward(); assert_ne!(g.get(&m.weight).array(), [[[TestDtype::zero(); 3]; 2]; 4]); diff --git a/dfdx/src/nn/layers/conv2d.rs b/dfdx/src/nn/layers/conv2d.rs index c88ea821..a4cd0b5e 100644 --- a/dfdx/src/nn/layers/conv2d.rs +++ b/dfdx/src/nn/layers/conv2d.rs @@ -197,48 +197,53 @@ mod tests { fn test_grouped_forward_sizes() { let dev: TestDevice = Default::default(); - let x = dev.zeros::>(); + let x = dev.zeros::>(); - let m = - dev.build_module::(>::default()); - let _: Tensor, _, _> = m.weight; - let _: Tensor, _, _> = m.forward(x.clone()); + let m = dev.build_module::(>::default()); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); - let m = - dev.build_module::(>::default()); - let _: Tensor, _, _> = m.weight; - let _: Tensor, _, _> = m.forward(x.clone()); + let m = dev.build_module::(>::default()); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); - let m = - dev.build_module::(>::default()); - let _: Tensor, _, _> = m.weight; - let _: Tensor, _, _> = m.forward(x.clone()); + let x = dev.zeros::>(); - let m = - dev.build_module::(>::default()); - let _: Tensor, _, _> = m.weight; - let _: Tensor, _, _> = m.forward(x.clone()); + let m = dev.build_module::(>::default()); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); + + let x = dev.zeros::>(); + + let m = dev.build_module::(>::default()); + let _: Tensor, _, _> = m.weight; + let _: Tensor, _, _> = m.forward(x.clone()); + + let x = dev.zeros::>(); let m = dev.build_module::(>::default()); let _: Tensor, _, _> = m.weight; - let _: Tensor, _, _> = m.forward(x); + let _: Tensor, _, _> = m.forward(x); } #[rustfmt::skip] #[test] fn test_forward_4d_sizes() { let dev: TestDevice = Default::default(); - let x = dev.zeros::>(); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let x = dev.zeros::>(); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let x = dev.zeros::>(); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let x = dev.zeros::>(); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let x = dev.zeros::>(); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); } #[test] @@ -267,17 +272,17 @@ mod tests { fn test_conv_with_optimizer() { let dev: TestDevice = Default::default(); - let mut m = dev.build_module::(Conv2DConstConfig::<2, 4, 3>::default()); + let mut m = dev.build_module::(Conv2DConstConfig::<2, 3, 2>::default()); let weight_init = m.weight.clone(); let mut opt = crate::nn::optim::Sgd::new(&m, Default::default()); - let out = m.forward(dev.sample_normal::>().leaky_trace()); + let out = m.forward(dev.sample_normal::>().leaky_trace()); let g = out.square().mean().backward(); assert_ne!( g.get(&m.weight).array(), - [[[[TestDtype::zero(); 3]; 3]; 2]; 4] + [[[[TestDtype::zero(); 2]; 2]; 2]; 3] ); opt.update(&mut m, &g).expect("unused params"); diff --git a/dfdx/src/nn/layers/conv_trans2d.rs b/dfdx/src/nn/layers/conv_trans2d.rs index b7683676..943f3d85 100644 --- a/dfdx/src/nn/layers/conv_trans2d.rs +++ b/dfdx/src/nn/layers/conv_trans2d.rs @@ -180,16 +180,24 @@ mod tests { #[test] fn test_forward_4d_sizes() { let dev: TestDevice = Default::default(); - let x = dev.zeros::>(); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); - let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + + let x = dev.zeros::>(); + + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + + let x = dev.zeros::>(); + + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); + + let x = dev.zeros::>(); + + let _: Tensor, _, _, _> = dev.build_module::(>::default()).forward(x.clone()); } #[test] @@ -225,7 +233,7 @@ mod tests { let weight_init = m.weight.clone(); let mut opt = crate::nn::optim::Sgd::new(&m, Default::default()); - let out = m.forward(dev.sample_normal::>().leaky_trace()); + let out = m.forward(dev.sample_normal::>().leaky_trace()); let g = out.square().mean().backward(); assert_ne!( diff --git a/dfdx/src/nn/layers/multi_head_attention.rs b/dfdx/src/nn/layers/multi_head_attention.rs index 1232b433..fba2d08c 100644 --- a/dfdx/src/nn/layers/multi_head_attention.rs +++ b/dfdx/src/nn/layers/multi_head_attention.rs @@ -208,11 +208,11 @@ mod tests { fn test_mha_batched() { let dev = TestDevice::seed_from_u64(1); - const BATCH: usize = 5; - const M: usize = 8; + const BATCH: usize = 2; + const M: usize = 4; const NUM_HEADS: usize = 2; const S1: usize = 3; - const S2: usize = 4; + const S2: usize = 2; type Dtype = f32; @@ -224,42 +224,37 @@ mod tests { let k: Tensor, Dtype, _> = dev.sample_normal(); let v: Tensor, Dtype, _> = dev.sample_normal(); - let y = mha.forward((q, k, v)); + // uncomment to save for this specific test params and inputs + // + // mha.save_safetensors("mha.safetensor").unwrap(); + // q.save_safetensors("q.safetensor").unwrap(); + // k.save_safetensors("k.safetensor").unwrap(); + // v.save_safetensors("v.safetensor").unwrap(); + + let y = mha.forward((q.clone(), k.clone(), v.clone())); + + // uncomment to save for this specific test params and inputs + // + // y.save_safetensors("y.safetensor").unwrap(); // This expected y was generated by: // 1. saving `mha` parameters, `q`, `k`, `v` to a file // 2. Running pytorch with the same values // 3. printing out the output // See https://github.com/coreylowman/dfdx/wiki/Exporting-MultiHeadAttention-to-pytorch-for-unit-tests - #[rustfmt::skip] assert_close_to_literal!( y, [ [ - [-0.32666653, 0.23977730, 0.25563523,-0.46537930, 0.19651681,-0.37467819, 0.44978297, 0.04501118], - [-0.32847843, 0.22905068, 0.24268147,-0.49660331, 0.17547092,-0.41919118, 0.45197228,-0.01052883], - [-0.28976738, 0.26420441, 0.24134403,-0.41927847, 0.21895495,-0.35072452, 0.44843924, 0.07374063], + [-0.16630043, 0.01757687, 0.22978050, 0.50355506], + [-0.19439587, 0.02942148, 0.23266082, 0.48612449], + [-0.19675586, 0.06542480, 0.18101424, 0.43833256] ], [ - [-0.10029950, 0.15455982, 0.23578438,-0.36703593, 0.03778699,-0.41743413, 0.50207543, 0.11432818], - [-0.04076880, 0.24567264, 0.23325926,-0.19454414, 0.11575195,-0.22209120, 0.49752438, 0.30388331], - [-0.06600001, 0.20277922, 0.24651963,-0.24732135, 0.08645092,-0.28015324, 0.49499762, 0.23243824], - ], - [ - [-0.18352799, 0.15783942, 0.36657059,-0.24797240, 0.11065251,-0.22565264, 0.46300891, 0.18687661], - [-0.15986431, 0.26687002, 0.30500177,-0.22695602, 0.18453379,-0.21377291, 0.46498343, 0.30064404], - [-0.09165541, 0.31019136, 0.20057595,-0.29627919, 0.15811513,-0.33667034, 0.48559439, 0.32546705], - ], - [ - [-0.45827997, 0.08988418, 0.44279462,-0.45245945, 0.16884868,-0.26618001, 0.40024126, 0.01272556], - [-0.43258160, 0.11801003, 0.42784777,-0.41539627, 0.19628736,-0.23836099, 0.39999473, 0.05304383], - [-0.44729146, 0.09233949, 0.45179683,-0.41795415, 0.16631508,-0.22713992, 0.39473629, 0.04260518], - ], - [ - [-0.51776350, 0.05404706, 0.39951840,-0.61738086, 0.21067555,-0.51225299, 0.41040331,-0.25894681], - [-0.47914022, 0.09410305, 0.36355501,-0.59280866, 0.24956036,-0.50058168, 0.40235144,-0.16756263], - [-0.55189615,-0.06088167, 0.41224611,-0.76746291, 0.09680001,-0.70136547, 0.40278757,-0.45541200], - ], + [-0.23499183, -0.21414454, 0.32811928, 0.46780989], + [-0.25318044, -0.20085460, 0.37180322, 0.52941465], + [-0.22117066, -0.23581570, 0.36783585, 0.53560883] + ] ] ); } @@ -269,11 +264,11 @@ mod tests { let dev: TestDevice = Default::default(); let mut mha = dev - .build_module::(, Const<4>>>::default()); + .build_module::(, Const<2>>>::default()); - let q: Tensor, TestDtype, _> = dev.sample_normal(); - let k: Tensor, TestDtype, _> = dev.sample_normal(); - let v: Tensor, TestDtype, _> = dev.sample_normal(); + let q: Tensor, TestDtype, _> = dev.sample_normal(); + let k: Tensor, TestDtype, _> = dev.sample_normal(); + let v: Tensor, TestDtype, _> = dev.sample_normal(); let y = mha.forward((q.leaky_trace(), k, v)); let g = y.square().mean().backward(); diff --git a/dfdx/src/nn/layers/transformer.rs b/dfdx/src/nn/layers/transformer.rs index fa7ab76a..a0e50a30 100644 --- a/dfdx/src/nn/layers/transformer.rs +++ b/dfdx/src/nn/layers/transformer.rs @@ -204,38 +204,31 @@ mod tests { fn test_transformer_forward() { let dev = TestDevice::seed_from_u64(0); let mut t = dev.build_module::(TransformerConfig::new( - Const::<16>, - Const::<4>, - Const::<8>, - 3, - 3, + Const::<2>, Const::<2>, Const::<2>, 2, 2, )); // unbatched - let src = dev.sample_normal::>(); - let tgt = dev.sample_normal::>(); - let _: Tensor, _, _, _> = t.forward_mut((src, tgt)); + let src = dev.sample_normal::>(); + let tgt = dev.sample_normal::>(); + let _: Tensor, _, _, _> = t.forward_mut((src, tgt)); // batched - let src = dev.sample_normal::>(); - let tgt = dev.sample_normal::>(); - let _: Tensor, _, _, _> = t.forward_mut((src, tgt)); + let src = dev.sample_normal::>(); + let tgt = dev.sample_normal::>(); + let _: Tensor, _, _, _> = t.forward_mut((src, tgt)); } #[test] fn test_transformer_backward() { let dev = TestDevice::seed_from_u64(0); + let mut t = dev.build_module::(TransformerConfig::new( - Const::<16>, - Const::<4>, - Const::<8>, - 3, - 3, + Const::<2>, Const::<2>, Const::<2>, 2, 2, )); - let src = dev.sample_normal::>(); - let tgt = dev.sample_normal::>(); - let out: Tensor, _, _, _> = t.forward_mut((src.leaky_trace(), tgt)); + let src = dev.sample_normal::>(); + let tgt = dev.sample_normal::>(); + let out: Tensor, _, _, _> = t.forward_mut((src.leaky_trace(), tgt)); let g = out.mean().backward(); let mut opt = crate::nn::optim::Sgd::new(&t, Default::default()); @@ -246,11 +239,11 @@ mod tests { fn test_encoder_block_forward() { let dev = TestDevice::seed_from_u64(2); - const BATCH: usize = 3; - const SEQ_LEN: usize = 5; - const EMBED_DIM: usize = 9; - const NUM_HEADS: usize = 3; - const FF_DIM: usize = 16; + const BATCH: usize = 2; + const SEQ_LEN: usize = 3; + const EMBED_DIM: usize = 4; + const NUM_HEADS: usize = 2; + const FF_DIM: usize = 2; type Dtype = f32; @@ -261,38 +254,36 @@ mod tests { )); let x: Tensor, Dtype, _> = dev.sample_normal(); + + // uncomment to save for this specific test params and inputs + // + // encoder.save_safetensors("encoder.safetensor").unwrap(); + // x.save_safetensors("x.safetensor").unwrap(); + let y = encoder.forward(x); + // uncomment to save for this specific test params and inputs + // + // y.save_safetensors("y.safetensor").unwrap(); + // This expected y was generated by: // 1. saving `encoder` parameters, `x` and `y` to a npz files // 2. Running pytorch with the same values // 3. printing out the output // See https://github.com/coreylowman/dfdx/wiki/Exporting-MultiHeadAttention-to-pytorch-for-unit-tests - #[rustfmt::skip] assert_close_to_literal!( y, [ [ - [0.83316803, 0.85057360, 0.37431455, 1.48506296,-0.38405111,-1.89352179,-1.07049453,-0.50913972, 0.31408834], - [-0.57205188, 0.64078861,-0.56589824, 0.67155081, 0.65419787, 0.28409126,-1.75282931, 1.68111539,-1.04096484], - [-0.01414229, 1.34985816, 0.09684382, 0.13165890,-1.39875984,-1.61741352, 1.28747427, 0.75574619,-0.59126562], - [0.12542287, 2.60457349, 0.21064451,-0.81285846,-0.15861531,-0.87273139,-0.81707120,-0.17004849,-0.10931605], - [-1.54970682,-0.77183282, 1.37495196,-0.69562960,-0.66684282, 0.24720824, 1.38581741,-0.35962212, 1.03565681], + [-1.7209842, 0.6216407, 0.7037436, 0.39559996], + [0.53576326, -1.4666773, 1.2166189, -0.28570476], + [-1.3280064, 0.42387456, -0.45566577, 1.3597975] ], [ - [-0.15229249,-0.90768278,-0.85165489, 0.12768827, 1.61459768, 1.25826979,-0.46860829, 0.87496787,-1.49528503], - [-1.35595357, 1.13305736,-0.08542954, 1.01601434,-0.04678532,-1.69470263, 0.76144469,-0.68443829, 0.95679283], - [-1.49877191, 0.64559501, 0.33383703, 1.73698330,-0.14289393, 1.17869902,-1.01659226,-0.61038357,-0.62647283], - [0.78263682, 0.78481543,-0.16064386, 1.03396618, 1.49144781,-1.55002558,-1.11833119,-0.62120575,-0.64265978], - [-1.58957553, 1.75000548, 0.01272983, 0.11212827,-0.34744453,-1.45086825, 0.95842224, 0.50071126, 0.05389150], - ], - [ - [-1.13160479,-0.21202824, 0.25907388,-0.64313424,-0.76302397,-0.16797650,-0.75345570, 2.01765633, 1.39449334], - [-0.16463053,-0.73241645,-0.69120175, 0.13771832, 0.72443259,-2.06525135, 1.02475107, 1.40244913, 0.36414924], - [0.38766465,-0.19543301,-1.80767059, 1.11545098, 0.21692322,-1.22834778, 0.13580292, 1.63094711,-0.25533777], - [1.22877085, 0.05472810, 0.65142977, 0.73869365,-0.74706972,-1.29277837, 1.07350135, 0.06228387,-1.76955938], - [-0.01733636,-1.57447529, 0.79691470, 1.00687420, 1.65637493,-0.75668150,-0.54616517, 0.45799020,-1.02349579], - ], + [0.89139193, -1.2803736, 1.0577338, -0.668752], + [-0.41001588, 1.6245831, -1.084222, -0.13034514], + [0.9247901, -1.1639801, -0.8187512, 1.0579412] + ] ] ); } @@ -301,11 +292,11 @@ mod tests { fn test_decoder_block_forward() { let dev = TestDevice::seed_from_u64(2); - const BATCH: usize = 4; - const S1: usize = 8; - const S2: usize = 6; - const EMBED_DIM: usize = 12; - const NUM_HEADS: usize = 6; + const BATCH: usize = 2; + const S1: usize = 3; + const S2: usize = 2; + const EMBED_DIM: usize = 4; + const NUM_HEADS: usize = 2; const FF_DIM: usize = 2; type Dtype = f32; @@ -318,57 +309,39 @@ mod tests { let tgt: Tensor, Dtype, _> = dev.sample_normal(); let mem: Tensor, Dtype, _> = dev.sample_normal(); + + // uncomment to save for this specific test params and inputs + // + // decoder.save_safetensors("decoder.safetensor").unwrap(); + // tgt.save_safetensors("tgt.safetensor").unwrap(); + // mem.save_safetensors("mem.safetensor").unwrap(); + let y = decoder.forward((tgt, mem)); + // uncomment to save for this specific test params and inputs + // + // y.save_safetensors("y.safetensor").unwrap(); + + println!("{:?}", y.array()); + // This expected y was generated by: // 1. saving `decoder` parameters, `tgt`, `mem` and `y` to a npz files // 2. Running pytorch with the same values // 3. printing out the output // See https://github.com/coreylowman/dfdx/wiki/Exporting-MultiHeadAttention-to-pytorch-for-unit-tests - #[rustfmt::skip] assert_close_to_literal!( y, [ [ - [-1.87558722, 0.45965099, 0.20498508,-1.73645127, 1.19475269,-0.07198015, 1.87802076, 0.18534835, 0.09591459,-0.19824848,-0.35261178, 0.21620668], - [-1.65146410, 0.36979428, 2.44077325, 0.06124005,-1.35236311, 0.06834260, 0.15826070,-0.82507777, 0.37757808, 0.65084165,-0.26028851,-0.03763753], - [-0.30696073,-0.83636290, 1.20258296, 0.11318116, 2.23617601,-0.58318114, 0.66371393,-0.26198950,-0.46798199,-1.64899850, 0.63527161,-0.74545103], - [-0.23854624,-1.12693906, 1.16869855,-0.19282928, 1.83873713,-0.11721543, 1.00944722,-0.97332841,-0.75959450,-0.69980252, 1.23692346,-1.14555120], - [1.36781275,-1.00360036,-0.45941362, 1.16563404, 0.24138503, 0.51682448,-0.20305091,-0.68849629, 0.21949562,-2.32909155, 1.11119950, 0.06130134], - [-0.70381856, 1.24304760, 1.32746470, 0.43500248,-1.45963287,-0.33785006, 0.95192397,-0.72454590,-0.56011575,-1.33778274, 1.46311414,-0.29680732], - [-0.72720474,-1.29362297, 0.24656427, 0.25788289,-1.20061839, 0.20161679,-0.18183309,-0.28182927, 1.85331190,-0.41204709, 2.05122447,-0.51344484], - [-0.45356780, 1.31273413, 0.69735909,-1.96937740, 0.33488208,-0.99047261, 0.59060574,-0.65752614, 1.89437556,-0.41522720,-0.09553659,-0.24824893], - ], - [ - [0.92695564,-0.37954834, 0.74523187, 0.91893858, 0.26190025,-1.12540352, 0.87693417,-0.56255865, 0.20910029,-2.21528411, 1.21251309,-0.86877924], - [-0.94927889,-1.28225541, 1.38664925,-0.47819123, 1.60083365,-0.25243780, 1.21168947,-0.77403182, 0.60282439,-0.67139530, 0.72949010,-1.12389636], - [0.32318670, 0.44635653, 0.69037175,-2.00356507, 0.31796345,-1.09540510, 1.65720248, 0.18892130, 0.52996045,-0.80869401, 0.91539401,-1.16169262], - [-0.93624949, 0.90174866,-0.35485053, 0.28630549,-0.67549163,-1.74944031, 0.75101191, 0.73161471, 2.11734390,-0.91214812, 0.20135719,-0.36120197], - [-0.12938653,-0.65747797, 2.05397773,-1.01142454,-0.12065405,-2.02726126, 0.42845321, 0.56529117, 1.02239680, 0.41882706, 0.12460811,-0.66735017], - [1.61325872, 1.18383896, 0.58100909,-1.39098096,-0.86362296, 0.16341744,-0.44804084,-0.85499638,-0.94598162, 0.20620863, 1.56031752,-0.80442756], - [0.15400597, 0.30694833,-0.10923728,-1.54726267, 2.59482384,-0.72448921,-0.47337827, 0.94458705,-0.74652761, 0.43154043,-0.49556813,-0.33544219], - [0.06703589,-1.33028281, 1.29519308, 0.01789100, 1.73138475, 0.11349702, 0.98292470,-1.37452459,-0.57708341,-0.04158162, 0.54672015,-1.43117404], - ], - [ - [-1.13928354,-0.41951340, 1.02809525, 1.10831285,-0.37338197, 0.62760144,-0.49609870, 0.89603722, 0.28748062,-2.46635914, 0.32486960, 0.62223953], - [0.66343045, 0.17840990,-0.32520610,-0.91180247,-1.24669814, 0.98684084, 1.03520977,-0.66813290, 2.06043386,-1.47457957, 0.05163103,-0.34953672], - [0.70942575,-1.41629028, 0.57625329, 1.22837853, 0.26442787,-1.24242258,-0.38967255,-0.10485345, 1.34950197,-1.88799143, 0.64463151, 0.26861122], - [-0.90124643, 2.06094766, 0.20568365, 0.06078637, 1.68658400,-0.19301027,-0.56969130,-0.80906254,-1.20984066, 0.12565698, 0.62286967,-1.07967734], - [-0.58323914,-0.91550159, 2.76294446,-0.23104562, 1.03537095,-0.79180622,-0.30585235,-0.37028444, 0.06941666,-0.66646379, 0.61295509,-0.61649406], - [-0.69953281,-0.53587002, 0.10623999,-1.43030167,-1.28995168,-0.84757996,-0.18267554,-0.03703059, 1.55741370, 1.54363191, 0.52537125, 1.29028559], - [-0.70696884,-0.75943643, 1.45195222,-0.89612883,-0.74769866, 0.21710433,-0.64992350,-1.06435382,-0.16617794, 2.16994262, 1.05082333, 0.10086535], - [-0.37381354,-0.70111430, 1.83576059, 0.72364914,-1.35405958, 0.72988695, 0.52067578,-0.01720174,-0.46059695, 1.23575497,-0.43288255,-1.70605886], + [0.94532686, -0.46526614, 0.93781346, -1.4178741], + [1.6348482, -1.0348053, -0.49546495, -0.10457793], + [0.8033758, 1.1668185, -0.823479, -1.146715] ], [ - [-1.20804095, 0.38654494, 1.65309286,-1.20736289, 1.07261550, 0.46114275, 0.83086872,-0.01955486,-1.26059496,-0.11887560, 0.79357809,-1.38341355], - [-0.56300515,-0.59784967, 2.81054258,-0.37848800,-0.41372916,-0.90938121, 0.82510620, 0.12329611, 0.14460202, 0.12636989,-1.24349451, 0.07603064], - [-1.36658132,-1.11734688, 1.74118745, 0.56276298, 0.35426524, 0.82628661,-1.63426054,-0.80171925, 0.09229738, 0.71951282,-0.27681157, 0.90040714], - [-0.47256982,-0.39320827,-1.71228957, 0.24000385, 0.71217608, 1.75911832,-1.24219942,-0.00148612, 0.80727738,-1.04095078, 0.02052352, 1.32360506], - [-0.00462395, 0.10117173, 1.83498573,-0.69001645, 0.46190643,-1.00014806, 1.14456511, 0.55384815, 0.36776620,-0.55358148,-0.00812254,-2.20775104], - [-0.59229124,-1.63409364, 1.70002937, 0.40580338, 0.76335514,-0.50594056, 0.32149875, 1.17081654,-1.73462892, 0.50679129,-0.56456679, 0.16322602], - [-0.28135568, 0.12212670, 1.39109802,-1.15742660, 0.81334966, 0.21747869,-0.01345161, 0.15832950, 0.68586451,-1.60281539, 1.38292646,-1.71612430], - [0.52762824,-1.20023167, 1.34064293,-0.40414453, 0.61767668,-0.24842866, 0.06679908, 1.13988364,-0.66101944,-0.71850598, 1.43029106,-1.89059174], - ], + [1.2232355, -1.5628394, 0.2116476, 0.12795626], + [0.99152863, -0.98818815, 1.0083598, -1.0117002], + [-1.4775288, 0.47518563, -0.23662777, 1.2389709] + ] ] ); } From 6b15ab753b56736f00eacd3b815e2a040ffadc4f Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Wed, 6 Dec 2023 12:12:58 -0500 Subject: [PATCH 07/10] reduce tensor sizes for 'slow' tests --- dfdx-core/src/tensor_ops/conv2d/tests.rs | 22 ++++++------ dfdx-core/src/tensor_ops/convtrans2d/tests.rs | 10 +++--- dfdx-core/src/tensor_ops/log_softmax.rs | 2 +- dfdx-core/src/tensor_ops/matmul/mod.rs | 34 +++++++++---------- dfdx-core/src/tensor_ops/softmax.rs | 2 +- 5 files changed, 35 insertions(+), 35 deletions(-) diff --git a/dfdx-core/src/tensor_ops/conv2d/tests.rs b/dfdx-core/src/tensor_ops/conv2d/tests.rs index 85603de8..b7110a22 100644 --- a/dfdx-core/src/tensor_ops/conv2d/tests.rs +++ b/dfdx-core/src/tensor_ops/conv2d/tests.rs @@ -218,10 +218,10 @@ fn test_conv2d_s4p3k2() { #[test] fn test_batched_conv2d() { let dev: TestDevice = Default::default(); - let x: Tensor, TestDtype, _> = dev.sample_normal(); + let x: Tensor, TestDtype, _> = dev.sample_normal(); let w: Tensor, TestDtype, _> = dev.sample_normal(); - let y: Tensor, _, _, _> = + let y: Tensor, _, _, _> = (x.leaky_trace(), w.clone()).conv2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>); let y0 = y.retaped::(); let grads0 = y.square().mean().backward(); @@ -229,11 +229,11 @@ fn test_batched_conv2d() { let w0 = grads0.get(&w); let x = x - .broadcast::, _>() - .reshape::>(); + .broadcast::, _>() + .reshape::>(); assert_eq!(x.strides, x.shape.strides()); - let y: Tensor, _, _, _> = + let y: Tensor, _, _, _> = (x.leaky_trace(), w.clone()).conv2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>); for i in 0..10 { assert_close_to_tensor!(y0, y.retaped::().select(dev.tensor(i))); @@ -245,7 +245,7 @@ fn test_batched_conv2d() { let x_grad = grads.get(&x) * 10.0; for i in 0..10 { - assert_close_to_tensor!(x0, x_grad.clone().select(dev.tensor(i))); + assert_close_to_tensor!(x0, x_grad.clone().select(dev.tensor(i)), 3e-6); } } @@ -405,7 +405,7 @@ fn test_conv2d_grouped() { fn test_conv2d_grouped_slices() { const NUM_GROUPS: usize = 3; let dev: TestDevice = Default::default(); - let x: Tensor, TestDtype, _> = dev.sample_normal(); + let x: Tensor, TestDtype, _> = dev.sample_normal(); let w: Tensor, TestDtype, _> = dev.sample_normal(); let y = (x.leaky_trace(), w.clone()).conv2d( @@ -419,7 +419,7 @@ fn test_conv2d_grouped_slices() { let x_group = x .clone() .slice((.., 3 * i..3 * (i + 1), .., ..)) - .realize::<(Const<2>, Const<3>, Const<14>, Const<14>)>(); + .realize::<(Const<2>, Const<3>, Const<3>, Const<3>)>(); let w_group = w .clone() .slice((5 * i..5 * (i + 1), .., .., ..)) @@ -428,7 +428,7 @@ fn test_conv2d_grouped_slices() { let y_group_true = y .retaped::() .slice((.., 5 * i..5 * (i + 1), .., ..)) - .realize::<(Const<2>, Const<5>, Const<12>, Const<12>)>(); + .realize::<(Const<2>, Const<5>, Const<1>, Const<1>)>(); assert_close_to_tensor!(y_group, y_group_true); } @@ -440,7 +440,7 @@ fn test_conv2d_grouped_slices() { let x_group = x .clone() .slice((.., 3 * i..3 * (i + 1), .., ..)) - .realize::<(Const<2>, Const<3>, Const<14>, Const<14>)>(); + .realize::<(Const<2>, Const<3>, Const<3>, Const<3>)>(); let w_group = w .clone() .slice((5 * i..5 * (i + 1), .., .., ..)) @@ -452,7 +452,7 @@ fn test_conv2d_grouped_slices() { let x_grad_group_true = x_grad .clone() .slice((.., 3 * i..3 * (i + 1), .., ..)) - .realize::<(Const<2>, Const<3>, Const<14>, Const<14>)>(); + .realize::<(Const<2>, Const<3>, Const<3>, Const<3>)>(); let w_grad_group_true = w_grad .clone() .slice((5 * i..5 * (i + 1), .., .., ..)) diff --git a/dfdx-core/src/tensor_ops/convtrans2d/tests.rs b/dfdx-core/src/tensor_ops/convtrans2d/tests.rs index 3d64acbf..c3670294 100644 --- a/dfdx-core/src/tensor_ops/convtrans2d/tests.rs +++ b/dfdx-core/src/tensor_ops/convtrans2d/tests.rs @@ -280,10 +280,10 @@ fn test_convtrans2d_padded() { #[test] fn test_convtrans2d_batched() { let dev: TestDevice = Default::default(); - let x: Tensor, TestDtype, _> = dev.sample_normal(); + let x: Tensor, TestDtype, _> = dev.sample_normal(); let w: Tensor, TestDtype, _> = dev.sample_normal(); - let y: Tensor, _, _, _> = + let y: Tensor, _, _, _> = (x.leaky_trace(), w.clone()).convtrans2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>); let y0 = y.retaped::(); let grads0 = y.square().mean().backward(); @@ -291,10 +291,10 @@ fn test_convtrans2d_batched() { let w0 = grads0.get(&w); let x = x - .broadcast::, _>() - .reshape::>(); + .broadcast::, _>() + .reshape::>(); - let y: Tensor, _, _, _> = + let y: Tensor, _, _, _> = (x.leaky_trace(), w.clone()).convtrans2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>); for i in 0..10 { assert_close_to_tensor!(y0, y.retaped::().select(dev.tensor(i)), 1e-5); diff --git a/dfdx-core/src/tensor_ops/log_softmax.rs b/dfdx-core/src/tensor_ops/log_softmax.rs index 487c33e5..d98bc330 100644 --- a/dfdx-core/src/tensor_ops/log_softmax.rs +++ b/dfdx-core/src/tensor_ops/log_softmax.rs @@ -81,7 +81,7 @@ mod tests { #[test] fn test_log_softmax_equivalence() { let dev: TestDevice = Default::default(); - let t: Tensor, TestDtype, _> = dev.sample_normal(); + let t: Tensor, TestDtype, _> = dev.sample_normal(); let p = t.leaky_trace().log_softmax::>(); let p_truth = t.leaky_trace() - t.leaky_trace().logsumexp::<_, Axis<3>>().broadcast(); // we can't create an array as it will overflow the stack diff --git a/dfdx-core/src/tensor_ops/matmul/mod.rs b/dfdx-core/src/tensor_ops/matmul/mod.rs index 5e4d03b3..d133b9ab 100644 --- a/dfdx-core/src/tensor_ops/matmul/mod.rs +++ b/dfdx-core/src/tensor_ops/matmul/mod.rs @@ -346,21 +346,21 @@ mod tests { } { - let a: Tensor, TestDtype, _> = dev.zeros(); + let a: Tensor, TestDtype, _> = dev.zeros(); let b: Tensor, TestDtype, _> = dev.zeros(); - let _: Tensor, TestDtype, _> = a.matmul(b); + let _: Tensor, TestDtype, _> = a.matmul(b); } { - let a: Tensor, TestDtype, _> = dev.zeros(); - let b: Tensor, TestDtype, _> = dev.zeros(); - let _: Tensor, TestDtype, _> = a.matmul(b); + let a: Tensor, TestDtype, _> = dev.zeros(); + let b: Tensor, TestDtype, _> = dev.zeros(); + let _: Tensor, TestDtype, _> = a.matmul(b); } { - let a: Tensor, TestDtype, _> = dev.zeros(); - let b: Tensor, TestDtype, _> = dev.zeros(); - let _: Tensor, TestDtype, _> = a.matmul(b); + let a: Tensor, TestDtype, _> = dev.zeros(); + let b: Tensor, TestDtype, _> = dev.zeros(); + let _: Tensor, TestDtype, _> = a.matmul(b); } } @@ -427,7 +427,7 @@ mod tests { #[test] fn test_matmul_broadcast() { - const N: usize = 5; + const N: usize = 2; let dev: TestDevice = Default::default(); let a: Tensor, TestDtype, _> = dev.sample_normal(); let a_array = a.array(); @@ -458,7 +458,7 @@ mod tests { #[test] fn test_matmul_broadcast_actual() { - const N: usize = 5; + const N: usize = 2; let dev: TestDevice = Default::default(); let a: Tensor, TestDtype, _> = dev.sample_normal(); let b: Tensor, TestDtype, _> = dev.sample_normal(); @@ -476,9 +476,9 @@ mod tests { fn test_matmul_batched_3d() { let dev: TestDevice = Default::default(); - let a: Tensor, TestDtype, _> = dev.sample_normal(); + let a: Tensor, TestDtype, _> = dev.sample_normal(); let a_array = a.array(); - let b: Tensor, TestDtype, _> = dev.sample_normal(); + let b: Tensor, TestDtype, _> = dev.sample_normal(); let b_array = b.array(); let c = a.leaky_trace().matmul(b.clone()); let c_array = c.array(); @@ -487,7 +487,7 @@ mod tests { let g_a = g.get(&a).array(); let g_b = g.get(&b).array(); - for i in 0..5 { + for i in 0..2 { let sub_a = dev.tensor(a_array[i]); let sub_b = dev.tensor(b_array[i]); let sub_c = sub_a.leaky_trace().matmul(sub_b.clone()); @@ -502,9 +502,9 @@ mod tests { fn test_matmul_batched_4d() { let dev: TestDevice = Default::default(); - let a: Tensor, TestDtype, _> = dev.sample_normal(); + let a: Tensor, TestDtype, _> = dev.sample_normal(); let a_array = a.array(); - let b: Tensor, TestDtype, _> = dev.sample_normal(); + let b: Tensor, TestDtype, _> = dev.sample_normal(); let b_array = b.array(); let c = a.leaky_trace().matmul(b.clone()); let c_array = c.array(); @@ -513,8 +513,8 @@ mod tests { let g_a = g.get(&a).array(); let g_b = g.get(&b).array(); - for i in 0..7 { - for j in 0..5 { + for i in 0..2 { + for j in 0..3 { let sub_a = dev.tensor(a_array[i][j]); let sub_b = dev.tensor(b_array[i][j]); let sub_c = sub_a.leaky_trace().matmul(sub_b.clone()); diff --git a/dfdx-core/src/tensor_ops/softmax.rs b/dfdx-core/src/tensor_ops/softmax.rs index 0a6ec8aa..a45436c8 100644 --- a/dfdx-core/src/tensor_ops/softmax.rs +++ b/dfdx-core/src/tensor_ops/softmax.rs @@ -91,7 +91,7 @@ mod tests { #[test] fn test_softmax_equivalence() { let dev: TestDevice = Default::default(); - let t: Tensor, TestDtype, _> = dev.sample_normal(); + let t: Tensor, TestDtype, _> = dev.sample_normal(); let p = t.leaky_trace().softmax::>(); let p_truth = t.leaky_trace().log_softmax::>().exp(); // we can't create an array as it will overflow the stack From 44a41c06421178600d69ed96f975b78509b2b2f1 Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Tue, 5 Dec 2023 07:58:03 -0500 Subject: [PATCH 08/10] patch crossbeam Note that crossbeam is an indirect dependency and they are still to release a version that [passes](https://github.com/crossbeam-rs/crossbeam/pull/996) on this miri check. So to temporarily get this out of the way, you can patch crossbeam on the workspace `dfdx/Cargo.toml`. --- Cargo.toml | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 68cc915c..97427663 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,4 +8,13 @@ safetensors = { version = "0.4.0", default-features = false } memmap2 = { version = "0.9.0", default-features = false } rand = { version = "0.8.5", default-features = false, features = ["std_rng"] } rand_distr = { version = "0.4.3", default-features = false } -libm = "0.2.8" \ No newline at end of file +libm = "0.2.8" + +[patch.crates-io] +crossbeam = { git = "https://github.com/crossbeam-rs/crossbeam", rev = "a57e655eef415c21babddc4ba0217b6ca7acd0a2" } +crossbeam-epoch = { git = "https://github.com/crossbeam-rs/crossbeam", rev = "a57e655eef415c21babddc4ba0217b6ca7acd0a2" } +crossbeam-channel = { git = "https://github.com/crossbeam-rs/crossbeam", rev = "a57e655eef415c21babddc4ba0217b6ca7acd0a2" } +crossbeam-deque = { git = "https://github.com/crossbeam-rs/crossbeam", rev = "a57e655eef415c21babddc4ba0217b6ca7acd0a2" } +crossbeam-queue = { git = "https://github.com/crossbeam-rs/crossbeam", rev = "a57e655eef415c21babddc4ba0217b6ca7acd0a2" } +crossbeam-skiplist = { git = "https://github.com/crossbeam-rs/crossbeam", rev = "a57e655eef415c21babddc4ba0217b6ca7acd0a2" } +crossbeam-utils = { git = "https://github.com/crossbeam-rs/crossbeam", rev = "a57e655eef415c21babddc4ba0217b6ca7acd0a2" } From 4251a777ec24b2b24248d9bb1b124bae56b6a055 Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Tue, 5 Dec 2023 08:30:19 -0500 Subject: [PATCH 09/10] miri pass nn::layers::add_into::tests::longer_network Thread main finished before other threads were still active. The only thing related to threads were the gemm usage of rayon, and disabling rayon usage for matmul seems to be sufficient to fix this. --- dfdx-core/src/tensor_ops/matmul/cpu_kernel.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dfdx-core/src/tensor_ops/matmul/cpu_kernel.rs b/dfdx-core/src/tensor_ops/matmul/cpu_kernel.rs index bf3e6ce0..e22af5a0 100644 --- a/dfdx-core/src/tensor_ops/matmul/cpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/matmul/cpu_kernel.rs @@ -90,7 +90,7 @@ impl MatMulImpl> for Cpu { false, false, false, - gemm::Parallelism::Rayon(rayon::current_num_threads()), + gemm::Parallelism::None, ) } } @@ -138,7 +138,7 @@ impl MatMulImpl for Cpu { false, false, false, - gemm::Parallelism::Rayon(rayon::current_num_threads()), + gemm::Parallelism::None, ) } } @@ -180,7 +180,7 @@ impl MatMulImpl for Cpu { false, false, false, - gemm::Parallelism::Rayon(rayon::current_num_threads()), + gemm::Parallelism::None, ) } } @@ -222,7 +222,7 @@ impl MatMulImpl for Cpu { false, false, false, - gemm::Parallelism::Rayon(rayon::current_num_threads()), + gemm::Parallelism::None, ) } } From 8f43ce54ee1fe9ff14ca7846dd097e7ce4a0a7cb Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Wed, 6 Dec 2023 17:33:47 -0500 Subject: [PATCH 10/10] Free cache memory on `TensorCache` drop - Moved the impl of `Cache::try_empty_cache()` to `TensorCache::clear()`. - This can be invoked both by `Cache::try_empty_cache()` and by `drop(TensorCache)`. - Moved the device cache ptr deallocation to `BytesPtr`, `CudaBytesPtr` (newtype over `CUdeviceptr`) and `Buffer`. - This is abstracted by the `CachePtr` trait. - Can be called by `TensorCache::clear()`. - This method may require some "extra" device information, such as in the cuda case. That information is held by `TensorCache`. --- dfdx-core/src/tensor/cache.rs | 74 ++++++++++++++++++++++- dfdx-core/src/tensor/cpu/device.rs | 84 +++++++++++++-------------- dfdx-core/src/tensor/cuda/device.rs | 38 ++++++------ dfdx-core/src/tensor/webgpu/device.rs | 26 ++++----- 4 files changed, 146 insertions(+), 76 deletions(-) diff --git a/dfdx-core/src/tensor/cache.rs b/dfdx-core/src/tensor/cache.rs index e785cb64..ad18e5c1 100644 --- a/dfdx-core/src/tensor/cache.rs +++ b/dfdx-core/src/tensor/cache.rs @@ -33,21 +33,35 @@ pub(crate) struct AllocationKey { /// valid allocation. When the last value is removed from the list, the key /// is removed. #[derive(Debug)] -pub(crate) struct TensorCache { +pub(crate) struct TensorCache, DeviceDev = ()> { pub(crate) allocations: RwLock>>, pub(crate) enabled: RwLock, + device_dev: DeviceDev, } -impl Default for TensorCache { +impl, DeviceDev: Default> Default for TensorCache { fn default() -> Self { Self { allocations: Default::default(), enabled: RwLock::new(false), + device_dev: DeviceDev::default(), } } } -impl TensorCache { +#[allow(dead_code)] +impl, DeviceDev> TensorCache { + /// Initiate an empty [TensorCache] with a given `device_dev`. + pub(crate) fn new(device_dev: DeviceDev) -> Self { + Self { + allocations: Default::default(), + enabled: RwLock::new(false), + device_dev, + } + } +} + +impl, DeviceDev> TensorCache { /// Returns the number of allocations in the cache. #[allow(unused)] pub(crate) fn len(&self) -> usize { @@ -183,6 +197,60 @@ impl TensorCache { } } +impl, DeviceDev> TensorCache { + /// Deallocates all cached memory on the device and empties the cache. + pub(crate) fn try_clear(&self) -> Result<(), crate::prelude::Error> { + let mut cache = { + #[cfg(not(feature = "no-std"))] + { + self.allocations.write().unwrap() + } + #[cfg(feature = "no-std")] + { + self.allocations.write() + } + }; + + for (&key, allocations) in cache.iter_mut() { + for alloc in allocations.drain(..) { + alloc.dealloc(&key, &self.device_dev); + } + } + cache.clear(); + Ok(()) + } +} + +impl, DeviceDev> Drop for TensorCache { + fn drop(&mut self) { + self.try_clear().unwrap(); + } +} + +/// Functionality internalized by the pointer. +pub(crate) trait CachePtr: Sized { + // by default no deallocation is made for any cache ptr + // ie. they leak + /// Deallocates the memory referred by this pointer. + fn dealloc(self, _key: &AllocationKey, _dev: &Dev) {} +} + +impl CachePtr for bool {} +impl CachePtr for u8 {} +impl CachePtr for u16 {} +impl CachePtr for u32 {} +impl CachePtr for u64 {} +impl CachePtr for u128 {} +impl CachePtr for usize {} +impl CachePtr for i8 {} +impl CachePtr for i16 {} +impl CachePtr for i32 {} +impl CachePtr for i64 {} +impl CachePtr for i128 {} +impl CachePtr for isize {} +impl CachePtr for f32 {} +impl CachePtr for f64 {} + #[cfg(test)] mod test { use super::*; diff --git a/dfdx-core/src/tensor/cpu/device.rs b/dfdx-core/src/tensor/cpu/device.rs index d3ce936f..1c6789fc 100644 --- a/dfdx-core/src/tensor/cpu/device.rs +++ b/dfdx-core/src/tensor/cpu/device.rs @@ -25,7 +25,7 @@ pub struct Cpu { /// A thread safe random number generator. pub(crate) rng: Arc>, /// A thread safe cache of memory allocations that can be reused. - pub(crate) cache: Arc>, + pub(crate) cache: Arc>, } impl Default for Cpu { @@ -47,6 +47,45 @@ impl Cpu { } } +/// Unit struct to represent information needed for managing allocations on the Cpu. +#[derive(Clone, Debug, Default)] +pub(crate) struct CpuDevice; + +impl crate::tensor::cache::CachePtr for BytesPtr { + fn dealloc(self, key: &crate::tensor::cache::AllocationKey, _dev: &CpuDevice) { + assert!(key.num_bytes % key.size == 0); + assert!(key.num_bytes < isize::MAX as usize); + let len = key.num_bytes / key.size; + let cap = len; + // SAFETY: + // - "ptr must have been allocated using the global allocator, such as via the alloc::alloc function." + // - ✅ cpu uses global allocator + // - "T needs to have the same alignment as what ptr was allocated with." + // - ✅ we are matching on the alignment below + // - "The size of T times the capacity needs to be the same size as the pointer was allocated with." + // - ✅ covered by `key.num_bytes / key.size` and the `key.num_bytes % key.size == 0` assertion above + // - "length needs to be less than or equal to capacity." + // - ✅ they are equal + // - "The first length values must be properly initialized values of type T." + // - ✅ any bit pattern is valid for unsigned ints used below + // - "capacity needs to be the capacity that the pointer was allocated with." + // - ✅ handled by assertion above (key.num_bytes % key.size == 0) + // - "The allocated size in bytes must be no larger than isize::MAX. See the safety documentation of pointer::offset." + // - ✅ handled by assertion above + debug_assert_eq!(std::alloc::Layout::new::().align(), 1); + debug_assert_eq!(std::alloc::Layout::new::().align(), 2); + debug_assert_eq!(std::alloc::Layout::new::().align(), 4); + debug_assert_eq!(std::alloc::Layout::new::().align(), 8); + match key.alignment { + 1 => unsafe { drop(Vec::from_raw_parts(self.0, len, cap)) }, + 2 => unsafe { drop(Vec::from_raw_parts(self.0 as *mut u16, len, cap)) }, + 4 => unsafe { drop(Vec::from_raw_parts(self.0 as *mut u32, len, cap)) }, + 8 => unsafe { drop(Vec::from_raw_parts(self.0 as *mut u64, len, cap)) }, + _ => unreachable!(), + }; + } +} + /// A [Vec] that can be cloned without allocating new memory. /// When [Drop]ed it will insert it's data into the cache. #[derive(Debug)] @@ -54,7 +93,7 @@ pub struct CachableVec { /// The data stored in this vector. pub(crate) data: Vec, /// A cache of memory allocations that can be reused. - pub(crate) cache: Arc>, + pub(crate) cache: Arc>, } impl Clone for CachableVec { @@ -166,45 +205,6 @@ impl Cache for Cpu { } fn try_empty_cache(&self) -> Result<(), Error> { - #[cfg(not(feature = "no-std"))] - let mut cache = self.cache.allocations.write().unwrap(); - #[cfg(feature = "no-std")] - let mut cache = self.cache.allocations.write(); - for (&key, allocations) in cache.iter_mut() { - assert!(key.num_bytes % key.size == 0); - assert!(key.num_bytes < isize::MAX as usize); - let len = key.num_bytes / key.size; - let cap = len; - for alloc in allocations.drain(..) { - // SAFETY: - // - "ptr must have been allocated using the global allocator, such as via the alloc::alloc function." - // - ✅ cpu uses global allocator - // - "T needs to have the same alignment as what ptr was allocated with." - // - ✅ we are matching on the alignment below - // - "The size of T times the capacity needs to be the same size as the pointer was allocated with." - // - ✅ covered by `key.num_bytes / key.size` and the `key.num_bytes % key.size == 0` assertion above - // - "length needs to be less than or equal to capacity." - // - ✅ they are equal - // - "The first length values must be properly initialized values of type T." - // - ✅ any bit pattern is valid for unsigned ints used below - // - "capacity needs to be the capacity that the pointer was allocated with." - // - ✅ handled by assertion above (key.num_bytes % key.size == 0) - // - "The allocated size in bytes must be no larger than isize::MAX. See the safety documentation of pointer::offset." - // - ✅ handled by assertion above - debug_assert_eq!(std::alloc::Layout::new::().align(), 1); - debug_assert_eq!(std::alloc::Layout::new::().align(), 2); - debug_assert_eq!(std::alloc::Layout::new::().align(), 4); - debug_assert_eq!(std::alloc::Layout::new::().align(), 8); - match key.alignment { - 1 => unsafe { drop(Vec::from_raw_parts(alloc.0, len, cap)) }, - 2 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u16, len, cap)) }, - 4 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u32, len, cap)) }, - 8 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u64, len, cap)) }, - _ => unreachable!(), - }; - } - } - cache.clear(); - Ok(()) + self.cache.try_clear() } } diff --git a/dfdx-core/src/tensor/cuda/device.rs b/dfdx-core/src/tensor/cuda/device.rs index de6f7196..fc9c8225 100644 --- a/dfdx-core/src/tensor/cuda/device.rs +++ b/dfdx-core/src/tensor/cuda/device.rs @@ -29,7 +29,7 @@ pub struct Cuda { /// A second stream for kernels to optionally execute on. pub(crate) par_stream: Arc, pub(crate) workspace: Arc>>, - pub(crate) cache: Arc>, + pub(crate) cache: Arc>>, } impl From for Error { @@ -77,6 +77,7 @@ impl Cuda { let cudnn = cudarc::cudnn::Cudnn::new(dev.clone())?; let par_stream = Arc::new(dev.fork_default_stream()?); let workspace = Arc::new(Mutex::new(dev.alloc_zeros::(1)?)); + let cache = Arc::new(TensorCache::new(Arc::clone(&dev))); Ok(Self { cpu, dev, @@ -85,7 +86,7 @@ impl Cuda { cudnn, par_stream, workspace, - cache: Default::default(), + cache, }) } } @@ -100,7 +101,7 @@ impl Cuda { ) -> Result, Error> { let data = self.cache.try_pop::(len).map_or_else( || self.dev.alloc::(len), - |ptr| Ok(self.dev.upgrade_device_ptr(ptr, len)), + |ptr| Ok(self.dev.upgrade_device_ptr(ptr.0, len)), )?; Ok(data) } @@ -122,6 +123,18 @@ impl Cuda { } } +/// A pointer to a bytes on the Cuda device. Used in conjunction with [TensorCache]. +#[repr(transparent)] +#[derive(Clone, Debug)] +pub struct CudaBytesPtr(pub(crate) CUdeviceptr); + +impl crate::tensor::cache::CachePtr> for CudaBytesPtr { + fn dealloc(self, key: &crate::tensor::cache::AllocationKey, dev: &Arc) { + let data = unsafe { dev.upgrade_device_ptr::(self.0, key.num_bytes) }; + drop(data); + } +} + /// A [CudaSlice] that can be cloned without allocating new memory. /// When [Drop]ed it will insert it's data into the cache. #[derive(Debug)] @@ -129,7 +142,7 @@ pub struct CachableCudaSlice { /// The actual data. pub(crate) data: CudaSlice, /// A cache of device pointers that can be reused. - pub(crate) cache: Arc>, + pub(crate) cache: Arc>>, } impl Clone for CachableCudaSlice { @@ -142,7 +155,7 @@ impl Clone for CachableCudaSlice { // SAFETY: // 1. we know that ptr is valid for `num_bytes` because it was registered for that. // 2. we are about to set the memory with dtod_copy - let mut slice = unsafe { dev.upgrade_device_ptr(ptr, len) }; + let mut slice = unsafe { dev.upgrade_device_ptr(ptr.0, len) }; dev.dtod_copy(&self.data, &mut slice).unwrap(); slice }, @@ -209,7 +222,7 @@ impl Drop for CachableCudaSlice { let numel = data.len(); // Get access to the raw pointer without freeing it. let ptr = data.leak(); - self.cache.insert::(numel, ptr); + self.cache.insert::(numel, CudaBytesPtr(ptr)); } } } @@ -232,18 +245,7 @@ impl Cache for Cuda { } fn try_empty_cache(&self) -> Result<(), Error> { - #[cfg(not(feature = "no-std"))] - let mut cache = self.cache.allocations.write().unwrap(); - #[cfg(feature = "no-std")] - let mut cache = self.cache.allocations.write(); - for (&key, allocations) in cache.iter_mut() { - for alloc in allocations.drain(..) { - let data = unsafe { self.dev.upgrade_device_ptr::(alloc, key.num_bytes) }; - drop(data); - } - } - cache.clear(); - Ok(()) + self.cache.try_clear() } } diff --git a/dfdx-core/src/tensor/webgpu/device.rs b/dfdx-core/src/tensor/webgpu/device.rs index 1c23989b..52acf5e5 100644 --- a/dfdx-core/src/tensor/webgpu/device.rs +++ b/dfdx-core/src/tensor/webgpu/device.rs @@ -109,7 +109,7 @@ pub struct Webgpu { pub(crate) dev: Arc, pub(crate) queue: Arc, - pub(crate) cache: Arc>, + pub(crate) cache: Arc>, pub(crate) cs_cache: Arc>>>, } @@ -297,12 +297,22 @@ impl Webgpu { // } } +/// Unit struct to represent information needed for managing allocations on the WebGpu. +#[derive(Clone, Debug, Default)] +pub(crate) struct WebGpuDevice; + +impl crate::tensor::cache::CachePtr for Buffer { + fn dealloc(self, _key: &crate::tensor::cache::AllocationKey, _dev: &WebGpuDevice) { + drop(self) + } +} + #[derive(Debug)] pub struct CachableBuffer { pub(crate) dev: Arc, pub(crate) queue: Arc, pub(crate) data: Buffer, - pub(crate) cache: Arc>, + pub(crate) cache: Arc>, pub(crate) _phantom: PhantomData, } @@ -397,17 +407,7 @@ impl Cache for Webgpu { } fn try_empty_cache(&self) -> Result<(), Error> { - #[cfg(not(feature = "no-std"))] - let mut cache = self.cache.allocations.write().unwrap(); - #[cfg(feature = "no-std")] - let mut cache = self.cache.allocations.write(); - for (&_key, allocations) in cache.iter_mut() { - for alloc in allocations.drain(..) { - drop(alloc); - } - } - cache.clear(); - Ok(()) + self.cache.try_clear() } }