From 0bd18ba31485af354a3bbf25b3f9e5c43090ccc6 Mon Sep 17 00:00:00 2001 From: Shautvast Date: Tue, 4 Feb 2025 08:09:00 +0100 Subject: [PATCH] changed to f32: ~30s -> ~25s --- Cargo.lock | 171 ++++++++++++++++++++++++++++++++++++++-------- Cargo.toml | 6 +- src/dataloader.rs | 31 +++++---- src/net.rs | 52 +++++++------- 4 files changed, 189 insertions(+), 71 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f7fce15..cbf7685 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "approx" @@ -17,6 +17,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "bitflags" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" + [[package]] name = "bytemuck" version = "1.13.0" @@ -31,13 +37,14 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "getrandom" -version = "0.2.8" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31" +checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" dependencies = [ "cfg-if", "libc", "wasi", + "windows-targets", ] [[package]] @@ -48,9 +55,9 @@ checksum = "fad582f4b9e86b6caa621cabeb0963332d92eea04729ab12892c2533951e6440" [[package]] name = "libc" -version = "0.2.139" +version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "201de327520df007757c1f0adce6e827fe8562fbc28bfd9c15571c66ca1f5f79" +checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" [[package]] name = "libm" @@ -80,9 +87,9 @@ dependencies = [ [[package]] name = "nalgebra" -version = "0.32.1" +version = "0.33.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6515c882ebfddccaa73ead7320ca28036c4bc84c9bcca3cc0cbba8efe89223a" +checksum = "26aecdf64b707efd1310e3544d709c5c0ac61c13756046aaaba41be5c4f66a3b" dependencies = [ "approx", "matrixmultiply", @@ -96,13 +103,13 @@ dependencies = [ [[package]] name = "nalgebra-macros" -version = "0.2.0" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d232c68884c0c99810a5a4d333ef7e47689cfd0edc85efc9e54e1e6bf5212766" +checksum = "254a5372af8fc138e36684761d3c0cdb758a4410e938babcff1c860ce14ddbfc" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.98", ] [[package]] @@ -159,38 +166,38 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.51" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6" +checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.23" +version = "1.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b" +checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc" dependencies = [ "proc-macro2", ] [[package]] name = "rand" -version = "0.8.5" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" dependencies = [ - "libc", "rand_chacha", "rand_core", + "zerocopy", ] [[package]] name = "rand_chacha" -version = "0.3.1" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", "rand_core", @@ -198,18 +205,19 @@ dependencies = [ [[package]] name = "rand_core" -version = "0.6.4" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +checksum = "b08f3c9802962f7e1b25113931d94f43ed9725bebc59db9d0c3e9a23b67e15ff" dependencies = [ "getrandom", + "zerocopy", ] [[package]] name = "rand_distr" -version = "0.4.3" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +checksum = "ddc3b5afe4c995c44540865b8ca5c52e6a59fa362da96c5d30886930ddc8da1c" dependencies = [ "num-traits", "rand", @@ -253,7 +261,7 @@ checksum = "af487d118eecd09402d70a5d72551860e788df87b464af30e5ea6a38c75c541e" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.107", ] [[package]] @@ -269,9 +277,9 @@ dependencies = [ [[package]] name = "simba" -version = "0.8.0" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50582927ed6f77e4ac020c057f37a268fc6aebc29225050365aacbb9deeeddc4" +checksum = "b3a386a501cd104797982c15ae17aafe8b9261315b5d07e3ec803f2ea26be0fa" dependencies = [ "approx", "num-complex", @@ -291,6 +299,17 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "syn" +version = "2.0.98" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36147f1a48ae0ec2b5b3bc5b537d267457555a10dc06f3dbc8cb11ba3006d3b1" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "typenum" version = "1.16.0" @@ -305,9 +324,12 @@ checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc" [[package]] name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" +version = "0.13.3+wasi-0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +dependencies = [ + "wit-bindgen-rt", +] [[package]] name = "wide" @@ -318,3 +340,96 @@ dependencies = [ "bytemuck", "safe_arch", ] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "wit-bindgen-rt" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +dependencies = [ + "bitflags", +] + +[[package]] +name = "zerocopy" +version = "0.8.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1e101d4bc320b6f9abb68846837b70e25e380ca2f467ab494bf29fcc435fcc3" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03a73df1008145cd135b3c780d275c57c3e6ba8324a41bd5e0008fe167c3bc7c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.98", +] diff --git a/Cargo.toml b/Cargo.toml index 04414ae..327b762 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,8 +4,8 @@ version = "1.0.0" edition = "2021" [dependencies] -rand = "0.8" -rand_distr = "0.4" -nalgebra = "0.32" +rand = "0.9" +rand_distr = "0.5" +nalgebra = "0.33" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" \ No newline at end of file diff --git a/src/dataloader.rs b/src/dataloader.rs index 4787e63..478238f 100644 --- a/src/dataloader.rs +++ b/src/dataloader.rs @@ -4,15 +4,18 @@ use nalgebra::DMatrix; use rand::prelude::*; use serde::Deserialize; -pub fn load_data() -> (Data, Data) { +pub fn load_data() -> (Data, Data) +{ // the mnist data is structured as // x: [[[pixels]],[[pixels]], etc], // y: [label1, label2, etc] // this is transformed to: // Data : Vec // DataLine {inputs: Vec, label: f64} - let raw_training_data: Vec = serde_json::from_slice(include_bytes!("data/training_data.json")).unwrap(); - let raw_test_data: Vec = serde_json::from_slice(include_bytes!("data/test_data.json")).unwrap(); + let raw_training_data: Vec = + serde_json::from_slice(include_bytes!("data/training_data.json")).unwrap(); + let raw_test_data: Vec = + serde_json::from_slice(include_bytes!("data/test_data.json")).unwrap(); let train = vectorize(raw_training_data); let test = vectorize(raw_test_data); @@ -20,17 +23,19 @@ pub fn load_data() -> (Data, Data) { (Data(train), Data(test)) } -fn vectorize(raw_training_data: Vec) -> Vec> { +fn vectorize(raw_data: Vec) -> Vec> +{ let mut result = Vec::new(); - for line in raw_training_data { + for line in raw_data { result.push(DataLine { inputs: DMatrix::from_vec(line.x.len(), 1, line.x), label: onehot(line.y) }); } result } #[derive(Deserialize)] -struct RawData { - x: Vec, +struct RawData +{ + x: Vec, y: u8, } @@ -43,19 +48,17 @@ pub struct DataLine where X: Clone, Y: Clone { } /// simple way to encode a onehot vector. An object that returns 1.0 if you get the 'right' index, or 0.0 otherwise -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct OneHotVector { pub val: usize, } -impl OneHotVector { +impl OneHotVector{ pub fn new(val: usize) -> Self { - Self { - val - } + Self { val } } - pub fn get(&self, index: usize) -> f64 { + pub fn get(&self, index: usize) -> f32 { if self.val == index { 1.0 } else { @@ -69,7 +72,7 @@ pub struct Data(pub Vec>) where X: Clone, Y: Clone; impl Data where X: Clone, Y: Clone { pub fn shuffle(&mut self) { - let mut rng = thread_rng(); + let mut rng = rand::rng(); self.0.shuffle(&mut rng); } diff --git a/src/net.rs b/src/net.rs index 7a7089d..eb4b861 100644 --- a/src/net.rs +++ b/src/net.rs @@ -11,11 +11,11 @@ use crate::dataloader::{Data, DataLine, OneHotVector}; pub struct Network { _sizes: Vec, num_layers: usize, - pub biases: Vec>, - pub weights: Vec>, + pub biases: Vec>, + pub weights: Vec>, } -impl Network { +impl Network{ /// The list `sizes` contains the number of neurons in the /// respective layers of the network. For example, if the list /// was [2, 3, 1] then it would be a three-layer network, with the @@ -48,11 +48,11 @@ impl Network { } } - fn feed_forward(&self, input: &DMatrix) -> DMatrix { + fn feed_forward(&self, input: &DMatrix) -> DMatrix { self.feed_forward_activation(input, sigmoid_inplace) } - fn feed_forward_activation(&self, input: &DMatrix, activation: fn(&mut f64)) -> DMatrix { + fn feed_forward_activation(&self, input: &DMatrix, activation: fn(&mut f32)) -> DMatrix { let mut a = input.clone(); for (b, w) in zip(&self.biases, &self.weights) { a = b + w * a; @@ -61,7 +61,7 @@ impl Network { a } - pub fn sgd(&mut self, mut training_data: Data, epochs: usize, minibatch_size: usize, eta: f64, test_data: Option>) { + pub fn sgd(&mut self, mut training_data: Data, epochs: usize, minibatch_size: usize, eta: f32, test_data: Option>) { for j in 0..epochs { training_data.shuffle(); let mini_batches = training_data.as_batches(minibatch_size); @@ -81,7 +81,7 @@ impl Network { /// gradient descent using backpropagation to a single mini batch. /// The ``mini_batch`` is a list of tuples ``(x, y)``, and ``eta`` /// is the learning rate. - fn update_mini_batch(&mut self, mini_batch: &[DataLine], eta: f64) { + fn update_mini_batch(&mut self, mini_batch: &[DataLine], eta: f32) { let (mut nabla_b, mut nabla_w) = self.zero_gradient(); for line in mini_batch.iter() { let (delta_nabla_b, delta_nabla_w) = self.backprop(&line.inputs, &line.label); @@ -91,17 +91,17 @@ impl Network { } self.weights = zip(&self.weights, &nabla_w) - .map(|(w, nw)| w.sub(nw.scale(eta / mini_batch.len() as f64))).collect(); + .map(|(w, nw)| w.sub(nw.scale(eta / mini_batch.len() as f32))).collect(); self.biases = zip(&self.biases, &nabla_b) - .map(|(b, nb)| b.sub(nb.scale(eta / mini_batch.len() as f64))).collect(); + .map(|(b, nb)| b.sub(nb.scale(eta / mini_batch.len() as f32))).collect(); } /// Return the number of test inputs for which the neural /// network outputs the correct result. Note that the neural /// network's output is assumed to be the index of whichever /// neuron in the final layer has the highest activation. - fn evaluate(&self, test_data: &Data) -> usize { + fn evaluate(&self, test_data: &Data) -> usize { let test_results: Vec<(usize, usize)> = test_data.0.iter() .map(|line| (argmax(self.feed_forward(&line.inputs)), line.label.val)) .collect(); @@ -113,7 +113,7 @@ impl Network { /// gradient for the cost function C_x. `nabla_b` and /// `nabla_w` are layer-by-layer lists of matrices, similar /// to `self.biases` and `self.weights`. - fn backprop(&self, x: &DMatrix, y: &OneHotVector) -> (Vec>, Vec>) { + fn backprop(&self, x: &DMatrix, y: &OneHotVector) -> (Vec>, Vec>) { let (mut nabla_b, mut nabla_w) = self.zero_gradient(); // feedforward @@ -128,7 +128,7 @@ impl Network { activations.push(activation.clone()); } // backward pass - let delta: DMatrix = cost_derivative(&activations[activations.len() - 1], y).component_mul(&zs[zs.len() - 1].map(sigmoid_prime)); + let delta: DMatrix = cost_derivative(&activations[activations.len() - 1], y).component_mul(&zs[zs.len() - 1].map(sigmoid_prime)); let index = nabla_b.len() - 1; nabla_b[index] = delta.clone(); @@ -148,12 +148,12 @@ impl Network { (nabla_b, nabla_w) } - fn zero_gradient(&self) -> (Vec>, Vec>) { - let nabla_b: Vec> = self.biases.iter() + fn zero_gradient(&self) -> (Vec>, Vec>) { + let nabla_b: Vec> = self.biases.iter() .map(|b| b.shape()) .map(|s| DMatrix::zeros(s.0, s.1)) .collect(); - let nabla_w: Vec> = self.weights.iter() + let nabla_w: Vec> = self.weights.iter() .map(|w| w.shape()) .map(|s| DMatrix::zeros(s.0, s.1)) .collect(); @@ -161,7 +161,7 @@ impl Network { } } -fn cost_derivative(output_activations: &DMatrix, y: &OneHotVector) -> DMatrix { +fn cost_derivative(output_activations: &DMatrix, y: &OneHotVector) -> DMatrix { let shape = output_activations.shape(); DMatrix::from_iterator(shape.0, shape.1, output_activations.iter().enumerate() .map(|(index, a)| a - y.get(index))) @@ -169,7 +169,7 @@ fn cost_derivative(output_activations: &DMatrix, y: &OneHotVector) -> DMatr /// index of max value /// only meaningful for single row or column matrix -fn argmax(val: DMatrix) -> usize { +fn argmax(val: DMatrix) -> usize { let mut max = 0.0; let mut index = 0; for (i, x) in val.iter().enumerate() { @@ -181,30 +181,30 @@ fn argmax(val: DMatrix) -> usize { index } -fn biases(sizes: Vec, init: fn(&usize) -> DMatrix) -> Vec> { +fn biases(sizes: Vec, init: fn(&usize) -> DMatrix) -> Vec> { sizes.iter().map(init).collect() } -fn weights(sizes: Vec<(usize, usize)>, init: fn(&(usize, usize)) -> DMatrix) -> Vec> { +fn weights(sizes: Vec<(usize, usize)>, init: fn(&(usize, usize)) -> DMatrix) -> Vec> { sizes.iter().map(init).collect() } -fn random_matrix(rows: usize, cols: usize) -> DMatrix { - let normal: Normal = Normal::new(0.0, 1.0).unwrap(); +fn random_matrix(rows: usize, cols: usize) -> DMatrix { + let normal: Normal = Normal::new(0.0, 1.0).unwrap(); - DMatrix::from_fn(rows, cols, |_, _| normal.sample(&mut thread_rng())) + DMatrix::from_fn(rows, cols, |_, _| normal.sample(&mut rand::rng())) } -fn sigmoid_inplace(val: &mut f64) { +fn sigmoid_inplace(val: &mut f32) { *val = sigmoid(*val); } -fn sigmoid(val: f64) -> f64 { +fn sigmoid(val: f32) -> f32 { 1.0 / (1.0 + (-val).exp()) } /// Derivative of the sigmoid function. -fn sigmoid_prime(val: f64) -> f64 { +fn sigmoid_prime(val: f32) -> f32 { sigmoid(val) * (1.0 - sigmoid(val)) } @@ -216,7 +216,7 @@ mod test { #[test] fn test_sigmoid() { - let mut mat: DMatrix = DMatrix::from_vec(1, 1, vec![0.0]); + let mut mat: DMatrix = DMatrix::from_vec(1, 1, vec![0.0]); mat.apply(sigmoid_inplace); assert_eq!(mat.get(0), Some(&0.5)); }