From 5f37c8262531ed3e919a0340deb4ab9aefdc649d Mon Sep 17 00:00:00 2001
From: Thiago Machado
Date: Wed, 6 Dec 2023 17:33:47 -0500
Subject: [PATCH] Free cache memory on `TensorCache` drop

- Moved the impl of `Cache::try_empty_cache()` to `TensorCache::try_clear()`.
  - This can be invoked both by `Cache::try_empty_cache()` and by `drop(TensorCache)`.
- Moved the device cache ptr deallocation to `BytesPtr`, `CudaBytesPtr` (newtype over `CUdeviceptr`) and `Buffer`.
  - This is abstracted by the `CachePtr` trait.
  - Can be called by `TensorCache::try_clear()`.
  - This method may require some "extra" device information, such as in the cuda case. That information is held by `TensorCache`.
---
 dfdx-core/src/tensor/cache.rs         | 74 ++++++++++++++++++++++-
 dfdx-core/src/tensor/cpu/device.rs    | 84 +++++++++++++--------------
 dfdx-core/src/tensor/cuda/device.rs   | 38 ++++++------
 dfdx-core/src/tensor/webgpu/device.rs | 26 ++++-----
 4 files changed, 146 insertions(+), 76 deletions(-)

diff --git a/dfdx-core/src/tensor/cache.rs b/dfdx-core/src/tensor/cache.rs
index e785cb647..ad18e5c1e 100644
--- a/dfdx-core/src/tensor/cache.rs
+++ b/dfdx-core/src/tensor/cache.rs
@@ -33,21 +33,35 @@ pub(crate) struct AllocationKey {
 /// valid allocation. When the last value is removed from the list, the key
 /// is removed.
 #[derive(Debug)]
-pub(crate) struct TensorCache<Ptr> {
+pub(crate) struct TensorCache<Ptr: CachePtr<DeviceDev>, DeviceDev = ()> {
     pub(crate) allocations: RwLock<BTreeMap<AllocationKey, Vec<Ptr>>>,
     pub(crate) enabled: RwLock<bool>,
+    device_dev: DeviceDev,
 }
 
-impl<Ptr> Default for TensorCache<Ptr> {
+impl<Ptr: CachePtr<DeviceDev>, DeviceDev: Default> Default for TensorCache<Ptr, DeviceDev> {
     fn default() -> Self {
         Self {
             allocations: Default::default(),
             enabled: RwLock::new(false),
+            device_dev: DeviceDev::default(),
         }
     }
 }
 
-impl<Ptr> TensorCache<Ptr> {
+#[allow(dead_code)]
+impl<Ptr: CachePtr<DeviceDev>, DeviceDev> TensorCache<Ptr, DeviceDev> {
+    /// Initiates an empty [TensorCache] with a given `device_dev`.
+    pub(crate) fn new(device_dev: DeviceDev) -> Self {
+        Self {
+            allocations: Default::default(),
+            enabled: RwLock::new(false),
+            device_dev,
+        }
+    }
+}
+
+impl<Ptr: CachePtr<DeviceDev>, DeviceDev> TensorCache<Ptr, DeviceDev> {
     /// Returns the number of allocations in the cache.
     #[allow(unused)]
     pub(crate) fn len(&self) -> usize {
@@ -183,6 +197,60 @@ impl<Ptr> TensorCache<Ptr> {
     }
 }
 
+impl<Ptr: CachePtr<DeviceDev>, DeviceDev> TensorCache<Ptr, DeviceDev> {
+    /// Deallocates all cached memory on the device and empties the cache.
+    pub(crate) fn try_clear(&self) -> Result<(), crate::prelude::Error> {
+        let mut cache = {
+            #[cfg(not(feature = "no-std"))]
+            {
+                self.allocations.write().unwrap()
+            }
+            #[cfg(feature = "no-std")]
+            {
+                self.allocations.write()
+            }
+        };
+
+        for (&key, allocations) in cache.iter_mut() {
+            for alloc in allocations.drain(..) {
+                alloc.dealloc(&key, &self.device_dev);
+            }
+        }
+        cache.clear();
+        Ok(())
+    }
+}
+
+impl<Ptr: CachePtr<DeviceDev>, DeviceDev> Drop for TensorCache<Ptr, DeviceDev> {
+    fn drop(&mut self) {
+        self.try_clear().unwrap();
+    }
+}
+
+/// Functionality internalized by the pointer.
+pub(crate) trait CachePtr<Dev>: Sized {
+    // by default no deallocation is made for any cache ptr
+    // i.e. they leak
+    /// Deallocates the memory referred to by this pointer.
+    fn dealloc(self, _key: &AllocationKey, _dev: &Dev) {}
+}
+
+impl<Dev> CachePtr<Dev> for bool {}
+impl<Dev> CachePtr<Dev> for u8 {}
+impl<Dev> CachePtr<Dev> for u16 {}
+impl<Dev> CachePtr<Dev> for u32 {}
+impl<Dev> CachePtr<Dev> for u64 {}
+impl<Dev> CachePtr<Dev> for u128 {}
+impl<Dev> CachePtr<Dev> for usize {}
+impl<Dev> CachePtr<Dev> for i8 {}
+impl<Dev> CachePtr<Dev> for i16 {}
+impl<Dev> CachePtr<Dev> for i32 {}
+impl<Dev> CachePtr<Dev> for i64 {}
+impl<Dev> CachePtr<Dev> for i128 {}
+impl<Dev> CachePtr<Dev> for isize {}
+impl<Dev> CachePtr<Dev> for f32 {}
+impl<Dev> CachePtr<Dev> for f64 {}
+
 #[cfg(test)]
 mod test {
     use super::*;
diff --git a/dfdx-core/src/tensor/cpu/device.rs b/dfdx-core/src/tensor/cpu/device.rs
index d3ce936f1..1c6789fcc 100644
--- a/dfdx-core/src/tensor/cpu/device.rs
+++ b/dfdx-core/src/tensor/cpu/device.rs
@@ -25,7 +25,7 @@ pub struct Cpu {
     /// A thread safe random number generator.
     pub(crate) rng: Arc<Mutex<StdRng>>,
     /// A thread safe cache of memory allocations that can be reused.
-    pub(crate) cache: Arc<TensorCache<BytesPtr>>,
+    pub(crate) cache: Arc<TensorCache<BytesPtr, CpuDevice>>,
 }
 
 impl Default for Cpu {
@@ -47,6 +47,45 @@ impl Cpu {
     }
 }
 
+/// Unit struct to represent information needed for managing allocations on the Cpu.
+#[derive(Clone, Debug, Default)]
+pub(crate) struct CpuDevice;
+
+impl crate::tensor::cache::CachePtr<CpuDevice> for BytesPtr {
+    fn dealloc(self, key: &crate::tensor::cache::AllocationKey, _dev: &CpuDevice) {
+        assert!(key.num_bytes % key.size == 0);
+        assert!(key.num_bytes < isize::MAX as usize);
+        let len = key.num_bytes / key.size;
+        let cap = len;
+        // SAFETY:
+        // - "ptr must have been allocated using the global allocator, such as via the alloc::alloc function."
+        //   - ✅ cpu uses global allocator
+        // - "T needs to have the same alignment as what ptr was allocated with."
+        //   - ✅ we are matching on the alignment below
+        // - "The size of T times the capacity needs to be the same size as the pointer was allocated with."
+        //   - ✅ covered by `key.num_bytes / key.size` and the `key.num_bytes % key.size == 0` assertion above
+        // - "length needs to be less than or equal to capacity."
+        //   - ✅ they are equal
+        // - "The first length values must be properly initialized values of type T."
+        //   - ✅ any bit pattern is valid for unsigned ints used below
+        // - "capacity needs to be the capacity that the pointer was allocated with."
+        //   - ✅ handled by assertion above (key.num_bytes % key.size == 0)
+        // - "The allocated size in bytes must be no larger than isize::MAX. See the safety documentation of pointer::offset."
+        //   - ✅ handled by assertion above
+        debug_assert_eq!(std::alloc::Layout::new::<u8>().align(), 1);
+        debug_assert_eq!(std::alloc::Layout::new::<u16>().align(), 2);
+        debug_assert_eq!(std::alloc::Layout::new::<u32>().align(), 4);
+        debug_assert_eq!(std::alloc::Layout::new::<u64>().align(), 8);
+        match key.alignment {
+            1 => unsafe { drop(Vec::from_raw_parts(self.0, len, cap)) },
+            2 => unsafe { drop(Vec::from_raw_parts(self.0 as *mut u16, len, cap)) },
+            4 => unsafe { drop(Vec::from_raw_parts(self.0 as *mut u32, len, cap)) },
+            8 => unsafe { drop(Vec::from_raw_parts(self.0 as *mut u64, len, cap)) },
+            _ => unreachable!(),
+        };
+    }
+}
+
 /// A [Vec] that can be cloned without allocating new memory.
 /// When [Drop]ed it will insert it's data into the cache.
 #[derive(Debug)]
@@ -54,7 +93,7 @@ pub struct CachableVec<E> {
     /// The data stored in this vector.
     pub(crate) data: Vec<E>,
     /// A cache of memory allocations that can be reused.
-    pub(crate) cache: Arc<TensorCache<BytesPtr>>,
+    pub(crate) cache: Arc<TensorCache<BytesPtr, CpuDevice>>,
 }
 
 impl<E: Clone> Clone for CachableVec<E> {
@@ -166,45 +205,6 @@ impl Cache for Cpu {
     }
 
     fn try_empty_cache(&self) -> Result<(), Error> {
-        #[cfg(not(feature = "no-std"))]
-        let mut cache = self.cache.allocations.write().unwrap();
-        #[cfg(feature = "no-std")]
-        let mut cache = self.cache.allocations.write();
-        for (&key, allocations) in cache.iter_mut() {
-            assert!(key.num_bytes % key.size == 0);
-            assert!(key.num_bytes < isize::MAX as usize);
-            let len = key.num_bytes / key.size;
-            let cap = len;
-            for alloc in allocations.drain(..) {
-                // SAFETY:
-                // - "ptr must have been allocated using the global allocator, such as via the alloc::alloc function."
-                //   - ✅ cpu uses global allocator
-                // - "T needs to have the same alignment as what ptr was allocated with."
-                //   - ✅ we are matching on the alignment below
-                // - "The size of T times the capacity needs to be the same size as the pointer was allocated with."
-                //   - ✅ covered by `key.num_bytes / key.size` and the `key.num_bytes % key.size == 0` assertion above
-                // - "length needs to be less than or equal to capacity."
-                //   - ✅ they are equal
-                // - "The first length values must be properly initialized values of type T."
-                //   - ✅ any bit pattern is valid for unsigned ints used below
-                // - "capacity needs to be the capacity that the pointer was allocated with."
-                //   - ✅ handled by assertion above (key.num_bytes % key.size == 0)
-                // - "The allocated size in bytes must be no larger than isize::MAX. See the safety documentation of pointer::offset."
-                //   - ✅ handled by assertion above
-                debug_assert_eq!(std::alloc::Layout::new::<u8>().align(), 1);
-                debug_assert_eq!(std::alloc::Layout::new::<u16>().align(), 2);
-                debug_assert_eq!(std::alloc::Layout::new::<u32>().align(), 4);
-                debug_assert_eq!(std::alloc::Layout::new::<u64>().align(), 8);
-                match key.alignment {
-                    1 => unsafe { drop(Vec::from_raw_parts(alloc.0, len, cap)) },
-                    2 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u16, len, cap)) },
-                    4 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u32, len, cap)) },
-                    8 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u64, len, cap)) },
-                    _ => unreachable!(),
-                };
-            }
-        }
-        cache.clear();
-        Ok(())
+        self.cache.try_clear()
     }
 }
diff --git a/dfdx-core/src/tensor/cuda/device.rs b/dfdx-core/src/tensor/cuda/device.rs
index de6f7196b..fc9c8225f 100644
--- a/dfdx-core/src/tensor/cuda/device.rs
+++ b/dfdx-core/src/tensor/cuda/device.rs
@@ -29,7 +29,7 @@ pub struct Cuda {
     /// A second stream for kernels to optionally execute on.
     pub(crate) par_stream: Arc<CudaStream>,
     pub(crate) workspace: Arc<Mutex<CudaSlice<u8>>>,
-    pub(crate) cache: Arc<TensorCache<CUdeviceptr>>,
+    pub(crate) cache: Arc<TensorCache<CudaBytesPtr, Arc<CudaDevice>>>,
 }
 
 impl From for Error {
@@ -77,6 +77,7 @@ impl Cuda {
         let cudnn = cudarc::cudnn::Cudnn::new(dev.clone())?;
         let par_stream = Arc::new(dev.fork_default_stream()?);
         let workspace = Arc::new(Mutex::new(dev.alloc_zeros::<u8>(1)?));
+        let cache = Arc::new(TensorCache::new(Arc::clone(&dev)));
         Ok(Self {
             cpu,
             dev,
@@ -85,7 +86,7 @@
             cudnn,
             par_stream,
             workspace,
-            cache: Default::default(),
+            cache,
         })
     }
 }
@@ -100,7 +101,7 @@
     ) -> Result<CudaSlice<E>, Error> {
         let data = self.cache.try_pop::<E>(len).map_or_else(
             || self.dev.alloc::<E>(len),
-            |ptr| Ok(self.dev.upgrade_device_ptr(ptr, len)),
+            |ptr| Ok(self.dev.upgrade_device_ptr(ptr.0, len)),
         )?;
         Ok(data)
     }
@@ -122,6 +123,18 @@
     }
 }
 
+/// A pointer to bytes on the Cuda device. Used in conjunction with [TensorCache].
+#[repr(transparent)]
+#[derive(Clone, Debug)]
+pub struct CudaBytesPtr(pub(crate) CUdeviceptr);
+
+impl crate::tensor::cache::CachePtr<Arc<CudaDevice>> for CudaBytesPtr {
+    fn dealloc(self, key: &crate::tensor::cache::AllocationKey, dev: &Arc<CudaDevice>) {
+        let data = unsafe { dev.upgrade_device_ptr::<u8>(self.0, key.num_bytes) };
+        drop(data);
+    }
+}
+
 /// A [CudaSlice] that can be cloned without allocating new memory.
 /// When [Drop]ed it will insert it's data into the cache.
 #[derive(Debug)]
@@ -129,7 +142,7 @@ pub struct CachableCudaSlice<E> {
     /// The actual data.
     pub(crate) data: CudaSlice<E>,
     /// A cache of device pointers that can be reused.
-    pub(crate) cache: Arc<TensorCache<CUdeviceptr>>,
+    pub(crate) cache: Arc<TensorCache<CudaBytesPtr, Arc<CudaDevice>>>,
 }
 
 impl<E: DeviceRepr> Clone for CachableCudaSlice<E> {
@@ -142,7 +155,7 @@ impl<E: DeviceRepr> Clone for CachableCudaSlice<E> {
             // SAFETY:
             // 1. we know that ptr is valid for `num_bytes` because it was registered for that.
             // 2. we are about to set the memory with dtod_copy
-            let mut slice = unsafe { dev.upgrade_device_ptr(ptr, len) };
+            let mut slice = unsafe { dev.upgrade_device_ptr(ptr.0, len) };
             dev.dtod_copy(&self.data, &mut slice).unwrap();
             slice
         },
@@ -209,7 +222,7 @@ impl<E> Drop for CachableCudaSlice<E> {
             let numel = data.len();
             // Get access to the raw pointer without freeing it.
             let ptr = data.leak();
-            self.cache.insert::<E>(numel, ptr);
+            self.cache.insert::<E>(numel, CudaBytesPtr(ptr));
         }
     }
 }
@@ -232,18 +245,7 @@ impl Cache for Cuda {
     }
 
     fn try_empty_cache(&self) -> Result<(), Error> {
-        #[cfg(not(feature = "no-std"))]
-        let mut cache = self.cache.allocations.write().unwrap();
-        #[cfg(feature = "no-std")]
-        let mut cache = self.cache.allocations.write();
-        for (&key, allocations) in cache.iter_mut() {
-            for alloc in allocations.drain(..) {
-                let data = unsafe { self.dev.upgrade_device_ptr::<u8>(alloc, key.num_bytes) };
-                drop(data);
-            }
-        }
-        cache.clear();
-        Ok(())
+        self.cache.try_clear()
     }
 }
 
diff --git a/dfdx-core/src/tensor/webgpu/device.rs b/dfdx-core/src/tensor/webgpu/device.rs
index 1c23989b3..52acf5e56 100644
--- a/dfdx-core/src/tensor/webgpu/device.rs
+++ b/dfdx-core/src/tensor/webgpu/device.rs
@@ -109,7 +109,7 @@ pub struct Webgpu {
     pub(crate) dev: Arc<Device>,
     pub(crate) queue: Arc<Queue>,
 
-    pub(crate) cache: Arc<TensorCache<Buffer>>,
+    pub(crate) cache: Arc<TensorCache<Buffer, WebGpuDevice>>,
     pub(crate) cs_cache: Arc>>>,
 }
 
@@ -297,12 +297,22 @@ impl Webgpu {
     // }
 }
 
+/// Unit struct to represent information needed for managing allocations on the WebGpu.
+#[derive(Clone, Debug, Default)]
+pub(crate) struct WebGpuDevice;
+
+impl crate::tensor::cache::CachePtr<WebGpuDevice> for Buffer {
+    fn dealloc(self, _key: &crate::tensor::cache::AllocationKey, _dev: &WebGpuDevice) {
+        drop(self)
+    }
+}
+
 #[derive(Debug)]
 pub struct CachableBuffer<E> {
     pub(crate) dev: Arc<Device>,
     pub(crate) queue: Arc<Queue>,
     pub(crate) data: Buffer,
-    pub(crate) cache: Arc<TensorCache<Buffer>>,
+    pub(crate) cache: Arc<TensorCache<Buffer, WebGpuDevice>>,
     pub(crate) _phantom: PhantomData<E>,
 }
 
@@ -397,17 +407,7 @@ impl Cache for Webgpu {
     }
 
     fn try_empty_cache(&self) -> Result<(), Error> {
-        #[cfg(not(feature = "no-std"))]
-        let mut cache = self.cache.allocations.write().unwrap();
-        #[cfg(feature = "no-std")]
-        let mut cache = self.cache.allocations.write();
-        for (&_key, allocations) in cache.iter_mut() {
-            for alloc in allocations.drain(..) {
-                drop(alloc);
-            }
-        }
-        cache.clear();
-        Ok(())
+        self.cache.try_clear()
    }
 }
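
Note (not part of the diff): below is a minimal, self-contained Rust sketch of the ownership pattern this patch introduces, for readers who want to see it outside the dfdx-core sources. The cache owns whatever device state deallocation needs (`device_dev`), each pointer type decides how to free itself through `CachePtr::dealloc`, and dropping the cache clears it. The simplified `AllocationKey` and the `NullDevice`/`HostPtr` types are illustrative stand-ins only, not types from the patch; on the Cuda side the same slots are filled by `Arc<CudaDevice>` and `CudaBytesPtr`.

use std::collections::BTreeMap;
use std::sync::RwLock;

/// Stand-in for the patch's `AllocationKey`: just enough to rebuild the allocation.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
struct AllocationKey {
    num_bytes: usize,
}

/// Mirrors `CachePtr<Dev>`: by default a cached pointer is simply forgotten (leaked);
/// a backend overrides `dealloc` and receives its extra device state through `Dev`.
trait CachePtr<Dev>: Sized {
    fn dealloc(self, _key: &AllocationKey, _dev: &Dev) {}
}

/// Mirrors `TensorCache<Ptr, DeviceDev>`: the cache owns the device handle that
/// deallocation needs, so clearing can happen from `Drop` without outside help.
struct TensorCache<Ptr: CachePtr<Dev>, Dev = ()> {
    allocations: RwLock<BTreeMap<AllocationKey, Vec<Ptr>>>,
    device_dev: Dev,
}

impl<Ptr: CachePtr<Dev>, Dev> TensorCache<Ptr, Dev> {
    fn new(device_dev: Dev) -> Self {
        Self {
            allocations: Default::default(),
            device_dev,
        }
    }

    fn insert(&self, key: AllocationKey, ptr: Ptr) {
        self.allocations
            .write()
            .unwrap()
            .entry(key)
            .or_default()
            .push(ptr);
    }

    /// Same shape as the patch's `try_clear`: drain every list and hand each
    /// pointer to the backend-specific `dealloc` together with the device state.
    fn try_clear(&self) -> Result<(), ()> {
        let mut cache = self.allocations.write().unwrap();
        for (&key, allocations) in cache.iter_mut() {
            for alloc in allocations.drain(..) {
                alloc.dealloc(&key, &self.device_dev);
            }
        }
        cache.clear();
        Ok(())
    }
}

impl<Ptr: CachePtr<Dev>, Dev> Drop for TensorCache<Ptr, Dev> {
    fn drop(&mut self) {
        self.try_clear().unwrap();
    }
}

/// Toy device and pointer types standing in for `CpuDevice`/`BytesPtr`
/// (or `Arc<CudaDevice>`/`CudaBytesPtr` on the Cuda side).
struct NullDevice;
struct HostPtr(*mut u8);

impl CachePtr<NullDevice> for HostPtr {
    fn dealloc(self, key: &AllocationKey, _dev: &NullDevice) {
        // A real backend rebuilds the allocation here, e.g. `Vec::from_raw_parts`
        // on the Cpu or `upgrade_device_ptr` on Cuda. `vec![0u8; n]` allocates
        // with capacity == n, so len == cap == num_bytes is valid below.
        unsafe { drop(Vec::from_raw_parts(self.0, key.num_bytes, key.num_bytes)) };
    }
}

fn main() {
    let cache = TensorCache::<HostPtr, NullDevice>::new(NullDevice);
    let bytes = vec![0u8; 64];
    let key = AllocationKey { num_bytes: bytes.len() };
    cache.insert(key, HostPtr(bytes.leak().as_mut_ptr()));
    // Dropping the cache runs `try_clear`, freeing the 64 bytes instead of leaking them.
    drop(cache);
}

Keeping the device handle inside the cache is what makes the `Drop` impl possible: previously only the device's `Cache::try_empty_cache()` knew how to free the cached pointers, so a `TensorCache` dropped on its own leaked them.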