From 4615ac105a16f45e45d6ccfd403099daf437a944 Mon Sep 17 00:00:00 2001 From: Tim Wedde Date: Mon, 4 Dec 2023 23:18:32 +0100 Subject: [PATCH] Implementations for u8->u64 and i8->i64 (#873) --- dfdx-core/src/tensor/numpy.rs | 251 ++++++++++++++++++++++++++++++++++ 1 file changed, 251 insertions(+) diff --git a/dfdx-core/src/tensor/numpy.rs b/dfdx-core/src/tensor/numpy.rs index b362be2ad..0a22a97f3 100644 --- a/dfdx-core/src/tensor/numpy.rs +++ b/dfdx-core/src/tensor/numpy.rs @@ -282,6 +282,166 @@ impl NumpyDtype for f64 { } } +impl NumpyDtype for u8 { + const NUMPY_DTYPE_STR: &'static str = "u1"; + fn read_endian(r: &mut R, endian: Endian) -> io::Result { + let mut bytes = [0; 1]; + r.read_exact(&mut bytes)?; + Ok(match endian { + Endian::Big => Self::from_be_bytes(bytes), + Endian::Little => Self::from_le_bytes(bytes), + Endian::Native => Self::from_ne_bytes(bytes), + }) + } + fn write_endian(&self, w: &mut W, endian: Endian) -> io::Result<()> { + match endian { + Endian::Big => w.write_all(&self.to_be_bytes()), + Endian::Little => w.write_all(&self.to_le_bytes()), + Endian::Native => w.write_all(&self.to_ne_bytes()), + } + } +} + +impl NumpyDtype for u16 { + const NUMPY_DTYPE_STR: &'static str = "u2"; + fn read_endian(r: &mut R, endian: Endian) -> io::Result { + let mut bytes = [0; 2]; + r.read_exact(&mut bytes)?; + Ok(match endian { + Endian::Big => Self::from_be_bytes(bytes), + Endian::Little => Self::from_le_bytes(bytes), + Endian::Native => Self::from_ne_bytes(bytes), + }) + } + fn write_endian(&self, w: &mut W, endian: Endian) -> io::Result<()> { + match endian { + Endian::Big => w.write_all(&self.to_be_bytes()), + Endian::Little => w.write_all(&self.to_le_bytes()), + Endian::Native => w.write_all(&self.to_ne_bytes()), + } + } +} + +impl NumpyDtype for u32 { + const NUMPY_DTYPE_STR: &'static str = "u4"; + fn read_endian(r: &mut R, endian: Endian) -> io::Result { + let mut bytes = [0; 4]; + r.read_exact(&mut bytes)?; + Ok(match endian { + Endian::Big => Self::from_be_bytes(bytes), + Endian::Little => Self::from_le_bytes(bytes), + Endian::Native => Self::from_ne_bytes(bytes), + }) + } + fn write_endian(&self, w: &mut W, endian: Endian) -> io::Result<()> { + match endian { + Endian::Big => w.write_all(&self.to_be_bytes()), + Endian::Little => w.write_all(&self.to_le_bytes()), + Endian::Native => w.write_all(&self.to_ne_bytes()), + } + } +} + +impl NumpyDtype for u64 { + const NUMPY_DTYPE_STR: &'static str = "u8"; + fn read_endian(r: &mut R, endian: Endian) -> io::Result { + let mut bytes = [0; 8]; + r.read_exact(&mut bytes)?; + Ok(match endian { + Endian::Big => Self::from_be_bytes(bytes), + Endian::Little => Self::from_le_bytes(bytes), + Endian::Native => Self::from_ne_bytes(bytes), + }) + } + fn write_endian(&self, w: &mut W, endian: Endian) -> io::Result<()> { + match endian { + Endian::Big => w.write_all(&self.to_be_bytes()), + Endian::Little => w.write_all(&self.to_le_bytes()), + Endian::Native => w.write_all(&self.to_ne_bytes()), + } + } +} + +impl NumpyDtype for i8 { + const NUMPY_DTYPE_STR: &'static str = "i1"; + fn read_endian(r: &mut R, endian: Endian) -> io::Result { + let mut bytes = [0; 1]; + r.read_exact(&mut bytes)?; + Ok(match endian { + Endian::Big => Self::from_be_bytes(bytes), + Endian::Little => Self::from_le_bytes(bytes), + Endian::Native => Self::from_ne_bytes(bytes), + }) + } + fn write_endian(&self, w: &mut W, endian: Endian) -> io::Result<()> { + match endian { + Endian::Big => w.write_all(&self.to_be_bytes()), + Endian::Little => w.write_all(&self.to_le_bytes()), + Endian::Native => w.write_all(&self.to_ne_bytes()), + } + } +} + +impl NumpyDtype for i16 { + const NUMPY_DTYPE_STR: &'static str = "i2"; + fn read_endian(r: &mut R, endian: Endian) -> io::Result { + let mut bytes = [0; 2]; + r.read_exact(&mut bytes)?; + Ok(match endian { + Endian::Big => Self::from_be_bytes(bytes), + Endian::Little => Self::from_le_bytes(bytes), + Endian::Native => Self::from_ne_bytes(bytes), + }) + } + fn write_endian(&self, w: &mut W, endian: Endian) -> io::Result<()> { + match endian { + Endian::Big => w.write_all(&self.to_be_bytes()), + Endian::Little => w.write_all(&self.to_le_bytes()), + Endian::Native => w.write_all(&self.to_ne_bytes()), + } + } +} + +impl NumpyDtype for i32 { + const NUMPY_DTYPE_STR: &'static str = "i4"; + fn read_endian(r: &mut R, endian: Endian) -> io::Result { + let mut bytes = [0; 4]; + r.read_exact(&mut bytes)?; + Ok(match endian { + Endian::Big => Self::from_be_bytes(bytes), + Endian::Little => Self::from_le_bytes(bytes), + Endian::Native => Self::from_ne_bytes(bytes), + }) + } + fn write_endian(&self, w: &mut W, endian: Endian) -> io::Result<()> { + match endian { + Endian::Big => w.write_all(&self.to_be_bytes()), + Endian::Little => w.write_all(&self.to_le_bytes()), + Endian::Native => w.write_all(&self.to_ne_bytes()), + } + } +} + +impl NumpyDtype for i64 { + const NUMPY_DTYPE_STR: &'static str = "i8"; + fn read_endian(r: &mut R, endian: Endian) -> io::Result { + let mut bytes = [0; 8]; + r.read_exact(&mut bytes)?; + Ok(match endian { + Endian::Big => Self::from_be_bytes(bytes), + Endian::Little => Self::from_le_bytes(bytes), + Endian::Native => Self::from_ne_bytes(bytes), + }) + } + fn write_endian(&self, w: &mut W, endian: Endian) -> io::Result<()> { + match endian { + Endian::Big => w.write_all(&self.to_be_bytes()), + Endian::Little => w.write_all(&self.to_le_bytes()), + Endian::Native => w.write_all(&self.to_ne_bytes()), + } + } +} + #[derive(Debug)] pub enum NpyError { /// Magic number did not match the expected value. @@ -560,4 +720,95 @@ mod tests { .load_from_npy(file.path()) .expect_err(""); } + + #[test] + fn test_0d_u8_save() { + let dev: TestDevice = Default::default(); + + let x = dev.tensor(0u8); + + let file = NamedTempFile::new().expect("failed to create tempfile"); + + x.save_to_npy(file.path()).expect("Saving failed"); + + let mut f = File::open(file.path()).expect("No file found"); + + let mut found = Vec::new(); + f.read_to_end(&mut found).expect("Reading failed"); + + assert_eq!( + &found, + &[ + 147, 78, 85, 77, 80, 89, 1, 0, 64, 0, 123, 39, 100, 101, 115, 99, 114, 39, 58, 32, + 39, 60, 117, 49, 39, 44, 32, 39, 102, 111, 114, 116, 114, 97, 110, 95, 111, 114, + 100, 101, 114, 39, 58, 32, 70, 97, 108, 115, 101, 44, 32, 39, 115, 104, 97, 112, + 101, 39, 58, 32, 40, 41, 44, 32, 125, 32, 32, 32, 32, 32, 32, 32, 32, 10, 0 + ] + ); + } + + #[test] + fn test_0d_u8_load() { + let dev: TestDevice = Default::default(); + let x = dev.tensor(2u8); + + let file = NamedTempFile::new().expect("failed to create tempfile"); + + x.save_to_npy(file.path()).expect("Saving failed"); + + let mut v = dev.tensor(0u8); + v.load_from_npy(file.path()).expect("Loading failed"); + assert_eq!(v.array(), x.array()); + + dev.tensor(0u16).load_from_npy(file.path()).expect_err(""); + dev.tensor([0u8; 1]) + .load_from_npy(file.path()) + .expect_err(""); + } + + #[test] + fn test_0d_i8_save() { + let dev: TestDevice = Default::default(); + + let x = dev.tensor(0i8); + + let file = NamedTempFile::new().expect("failed to create tempfile"); + + x.save_to_npy(file.path()).expect("Saving failed"); + x.save_to_npy("out.npy").expect("Saving failed"); + + let mut f = File::open(file.path()).expect("No file found"); + + let mut found = Vec::new(); + f.read_to_end(&mut found).expect("Reading failed"); + + assert_eq!( + &found, + &[ + 147, 78, 85, 77, 80, 89, 1, 0, 64, 0, 123, 39, 100, 101, 115, 99, 114, 39, 58, 32, + 39, 60, 105, 49, 39, 44, 32, 39, 102, 111, 114, 116, 114, 97, 110, 95, 111, 114, + 100, 101, 114, 39, 58, 32, 70, 97, 108, 115, 101, 44, 32, 39, 115, 104, 97, 112, + 101, 39, 58, 32, 40, 41, 44, 32, 125, 32, 32, 32, 32, 32, 32, 32, 32, 10, 0 + ] + ); + } + + #[test] + fn test_0d_i8_load() { + let dev: TestDevice = Default::default(); + let x = dev.tensor(2i8); + + let file = NamedTempFile::new().expect("failed to create tempfile"); + + x.save_to_npy(file.path()).expect("Saving failed"); + + let mut v = dev.tensor(0i8); + v.load_from_npy(file.path()).expect("Loading failed"); + assert_eq!(v.array(), x.array()); + + dev.tensor(0i16).load_from_npy(file.path()).expect_err(""); + dev.tensor([0i8; 1]) + .load_from_npy(file.path()) + .expect_err(""); + } }