From d9ba8cc07998a72ce605fb2508962a582d0e3ab7 Mon Sep 17 00:00:00 2001 From: Sander Hautvast Date: Tue, 14 Feb 2023 17:42:02 +0100 Subject: [PATCH] not working --- .gitignore | 3 + Cargo.lock | 320 ++++++++++++++++++++++++++++++++++++++++++++++ Cargo.toml | 13 ++ src/dataloader.rs | 60 +++++++++ src/lib.rs | 2 + src/main.rs | 15 +++ src/net.rs | 209 ++++++++++++++++++++++++++++++ 7 files changed, 622 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.lock create mode 100644 Cargo.toml create mode 100644 src/dataloader.rs create mode 100644 src/lib.rs create mode 100644 src/main.rs create mode 100644 src/net.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8b7a06b --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +/target +*.iml +.idea \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..131fe6a --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,320 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "bytemuck" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c041d3eab048880cb0b86b256447da3f18859a163c3b8d8893f4e6368abe6393" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "getrandom" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "itoa" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fad582f4b9e86b6caa621cabeb0963332d92eea04729ab12892c2533951e6440" + +[[package]] +name = "libc" +version = "0.2.139" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "201de327520df007757c1f0adce6e827fe8562fbc28bfd9c15571c66ca1f5f79" + +[[package]] +name = "libm" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "348108ab3fba42ec82ff6e9564fc4ca0247bdccdc68dd8af9764bbc79c3c8ffb" + +[[package]] +name = "matrixmultiply" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "add85d4dd35074e6fedc608f8c8f513a3548619a9024b751949ef0e8e45a4d84" +dependencies = [ + "rawpointer", +] + +[[package]] +name = "mnist-rs" +version = "0.1.0" +dependencies = [ + "nalgebra", + "rand", + "rand_distr", + "serde", + "serde_json", +] + +[[package]] +name = "nalgebra" +version = "0.32.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6515c882ebfddccaa73ead7320ca28036c4bc84c9bcca3cc0cbba8efe89223a" +dependencies = [ + "approx", + "matrixmultiply", + "nalgebra-macros", + "num-complex", + "num-rational", + "num-traits", + "simba", + "typenum", +] + +[[package]] +name = "nalgebra-macros" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d232c68884c0c99810a5a4d333ef7e47689cfd0edc85efc9e54e1e6bf5212766" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "num-complex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +dependencies = [ + "autocfg", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "paste" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d01a5bd0424d00070b0098dd17ebca6f961a959dead1dbcbbbc1d1cd8d3deeba" + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "proc-macro2" +version = "1.0.51" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "ryu" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b4b9743ed687d4b4bcedf9ff5eaa7398495ae14e61cba0a295704edbc7decde" + +[[package]] +name = "safe_arch" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "794821e4ccb0d9f979512f9c1973480123f9bd62a90d74ab0f9426fcf8f4a529" +dependencies = [ + "bytemuck", +] + +[[package]] +name = "serde" +version = "1.0.152" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb7d1f0d3021d347a83e556fc4683dea2ea09d87bccdf88ff5c12545d89d5efb" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.152" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af487d118eecd09402d70a5d72551860e788df87b464af30e5ea6a38c75c541e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cad406b69c91885b5107daf2c29572f6c8cdb3c66826821e286c533490c0bc76" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "simba" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50582927ed6f77e4ac020c057f37a268fc6aebc29225050365aacbb9deeeddc4" +dependencies = [ + "approx", + "num-complex", + "num-traits", + "paste", + "wide", +] + +[[package]] +name = "syn" +version = "1.0.107" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f4064b5b16e03ae50984a5a8ed5d4f8803e6bc1fd170a3cda91a1be4b18e3f5" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "typenum" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" + +[[package]] +name = "unicode-ident" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wide" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "feff0a412894d67223777b6cc8d68c0dab06d52d95e9890d5f2d47f10dd9366c" +dependencies = [ + "bytemuck", + "safe_arch", +] diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..d87fb52 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "mnist-rs" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +rand = "0.8" +rand_distr = "0.4" +nalgebra = "0.32" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" \ No newline at end of file diff --git a/src/dataloader.rs b/src/dataloader.rs new file mode 100644 index 0000000..7fae50b --- /dev/null +++ b/src/dataloader.rs @@ -0,0 +1,60 @@ +use std::iter::zip; + +use rand::prelude::*; +use serde::Deserialize; + +pub fn load_data() -> Data { + /// the mnist data is structured as + /// x: [[[pixels]],[[pixels]], etc], + /// y: [label1, label2, etc] + /// this is transformed to: + /// Data : Vec + /// DataLine {inputs: Vec, label: f32} + let raw_data: RawData = serde_json::from_slice(include_bytes!("data/unittest.json")).unwrap(); + let mut vec = Vec::new(); + for (x, y) in zip(raw_data.x, raw_data.y) { + vec.push(DataLine { inputs: x, label: y}); + } + + Data(vec) +} + +#[derive(Deserialize)] +struct RawData { + x: Vec>, + y: Vec, +} + +/// X is type of input +/// Y is type of output +pub struct DataLine { + pub inputs: Vec, + pub label: Y, +} + +pub struct Data(pub Vec>); + + +impl Data { + pub fn shuffle(&mut self) { + let mut rng = rand::thread_rng(); + self.0.shuffle(&mut rng); + } + + pub fn len(&self) -> usize { + self.0.len() + } + + pub fn as_batches(&self, batch_size: usize) -> Vec<&[DataLine]> { + let mut batches = Vec::with_capacity(self.0.len() / batch_size + 1); + let mut offset = 0; + for _ in 0..self.0.len() / batch_size { + batches.push(&self.0[offset..offset + batch_size]); + offset += batch_size; + } + batches.push(&self.0[offset..self.0.len()]); + batches + } + + +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..f154dc1 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,2 @@ +pub mod net; +pub mod dataloader; \ No newline at end of file diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..ebd727d --- /dev/null +++ b/src/main.rs @@ -0,0 +1,15 @@ +use mnist_rs::dataloader::load_data; + +fn main() { + let mut net = mnist_rs::net::Network::from(vec![784, 30, 10]); + for w in net.weights.iter() { + println!("{}, {}", w.shape().0, w.shape().1); + } + println!(); + for b in net.biases.iter() { + println!("{:?}", b.shape()); + } + let training_data = load_data(); + + net.sgd(training_data, 30, 10, 3.0, &None); +} \ No newline at end of file diff --git a/src/net.rs b/src/net.rs new file mode 100644 index 0000000..49504ef --- /dev/null +++ b/src/net.rs @@ -0,0 +1,209 @@ +use std::convert::identity; +use std::iter::zip; +use std::ops::{Add, Sub}; + +use nalgebra::{DMatrix, Matrix, OMatrix}; +use rand::prelude::*; +use rand_distr::Normal; + +use crate::dataloader::{Data, DataLine}; + +#[derive(Debug)] +pub struct Network { + _sizes: Vec, + _num_layers: usize, + pub biases: Vec>, + pub weights: Vec>, +} + +impl Network { + /// The list `sizes` contains the number of neurons in the + /// respective layers of the network. For example, if the list + /// was [2, 3, 1] then it would be a three-layer network, with the + /// first layer containing 2 neurons, the second layer 3 neurons, + /// and the third layer 1 neuron. The biases and weights for the + /// network are initialized randomly, using a Gaussian + /// distribution with mean 0, and variance 1. Note that the first + /// layer is assumed to be an input layer, and by convention we + /// won't set any biases for those neurons, since biases are only + /// ever used in computing the outputs from later layers. + pub fn from(sizes: Vec) -> Self { + Self { + _sizes: sizes.clone(), + _num_layers: sizes.len(), + biases: biases(sizes[1..].to_vec()), + weights: weights(zip(sizes[..sizes.len() - 1].to_vec(), sizes[1..].to_vec()).collect()), + } + } + + fn feed_forward(&self, input: Vec) -> Vec { + let mut a = DMatrix::from_vec(input.len(), 1, input); + for (b, w) in zip(&self.biases, &self.weights) { + a = b.add_scalar(w.dot(&a)); + a.apply(sigmoid_inplace); + } + a.column(1).iter().map(|v| *v).collect() + } + + pub fn sgd(&mut self, mut training_data: Data, epochs: usize, minibatch_size: usize, eta: f32, test_data: &Option>) { + for j in 0..epochs { + training_data.shuffle(); + let mini_batches = training_data.as_batches(minibatch_size); + for mini_batch in mini_batches { + self.update_mini_batch(mini_batch, eta); + } + + if let Some(test_data) = test_data { + println!("Epoch {}: {} / {}", j, self.evaluate(test_data), test_data.len()); + } else { + println!("Epoch {} complete", j); + } + } + } + + /// Update the network's weights and biases by applying + /// gradient descent using backpropagation to a single mini batch. + /// The ``mini_batch`` is a list of tuples ``(x, y)``, and ``eta`` + /// is the learning rate. + fn update_mini_batch(&mut self, mini_batch: &[DataLine], eta: f32) { + let mut nabla_b: Vec> = self.biases.iter() + .map(|b| b.shape()) + .map(|s| DMatrix::zeros(s.0, s.1)) + .collect(); + let mut nabla_w: Vec> = self.weights.iter() + .map(|w| w.shape()) + .map(|s| DMatrix::zeros(s.0, s.1)) + .collect(); + for line in mini_batch.iter() { + let (delta_nabla_b, delta_nabla_w) = self.backprop(line.inputs.to_vec(), line.label); + + nabla_b = zip(&nabla_b, &delta_nabla_b).map(|(nb, dnb)| nb.add(dnb)).collect(); + nabla_w = zip(&nabla_w, &delta_nabla_w).map(|(nw, dnw)| nw.add(dnw)).collect(); + } + + self.weights = zip(&self.weights, &nabla_w) + .map(|(w, nw)| w.add_scalar(-eta / mini_batch.len() as f32)).collect(); + self.biases = zip(&self.biases, &nabla_b) + .map(|(b, nb)| b.add_scalar(-eta / mini_batch.len() as f32)).collect(); + } + + /// Return the number of test inputs for which the neural + /// network outputs the correct result. Note that the neural + /// network's output is assumed to be the index of whichever + /// neuron in the final layer has the highest activation. + fn evaluate(&self, test_data: &Data) -> usize { + let test_results: Vec<(usize, u8)> = test_data.0.iter() + .map(|line| (argmax(self.feed_forward(line.inputs.clone())), line.label)) + .collect(); + test_results.into_iter().filter(|(x, y)| *x == *y as usize).count() + } + + /// Return a tuple `(nabla_b, nabla_w)` representing the + /// gradient for the cost function C_x. `nabla_b` and + /// `nabla_w` are layer-by-layer lists of matrices, similar + /// to `self.biases` and `self.weights`. + fn backprop(&self, x: Vec, y: u8) -> (Vec>, Vec>) { + // zero_grad ie. set gradient to zero + let mut nabla_b: Vec> = self.biases.iter() + .map(|b| b.shape()) + .map(|s| DMatrix::zeros(s.0, s.1)) + .collect(); + let mut nabla_w: Vec> = self.weights.iter() + .map(|w| w.shape()) + .map(|s| DMatrix::zeros(s.0, s.1)) + .collect(); + + // feedforward + let mut activation = DMatrix::from_vec(x.len(), 1, x); + let mut activations = vec![activation.clone()]; + let mut zs = vec![]; + + for (b, w) in zip(&self.biases, &self.weights) { + // println!("{:?}", w.shape()); + // println!("{:?}", activation.shape()); + // println!("{:?}", b.shape()); + + let mut z: DMatrix = w * &activation + b; + zs.push(z.clone()); + activation = z.map(sigmoid); + activations.push(activation.clone()); + } + + // backward pass + let delta: DMatrix = self.cost_derivative( + &activations[activations.len() - 1], + y as f32); + println!("delta {:?}", delta.shape()); + println!("z {:?}", &zs[zs.len() - 1].transpose().shape()); + let delta = delta * (&zs[zs.len() - 1].transpose().map(sigmoid_prime)); + println!("delta {:?}", delta.shape()); + let index = nabla_b.len() - 1; + nabla_b[index] = delta.clone(); + + println!("delta {:?}", delta.shape()); + println!("activation {:?}", activations[activations.len() - 2].shape()); + let index = nabla_w.len() - 1; + nabla_w[index] = delta * &activations[activations.len() - 2]; + + + (nabla_b, nabla_w) + } + + fn cost_derivative(&self, output_activations: &DMatrix, y: f32) -> DMatrix { + output_activations.add_scalar(-y) + } +} + +fn argmax(val: Vec) -> usize { + let mut max = 0.0; + let mut index = 0; + for (i, x) in val.iter().enumerate() { + if *x > max { + index = i; + max = *x; + } + } + index +} + +fn biases(sizes: Vec) -> Vec> { + sizes.iter().map(|size| random_matrix(*size, 1)).collect() +} + +fn weights(sizes: Vec<(usize, usize)>) -> Vec> { + println!("{:?}", sizes); + sizes.iter().map(|size| random_matrix(size.1, size.0)).collect() +} + +fn random_matrix(rows: usize, cols: usize) -> DMatrix { + let normal: Normal = Normal::new(0.0, 1.0).unwrap(); + + DMatrix::from_fn(rows, cols, |_, _| normal.sample(&mut thread_rng())) +} + +fn sigmoid_inplace(val: &mut f32) { + *val = sigmoid(*val); +} + +fn sigmoid(val: f32) -> f32 { + 1.0 / (1.0 + (-val).exp()) +} + +/// Derivative of the sigmoid function. +fn sigmoid_prime(val: f32) -> f32 { + sigmoid(val) * (1.0 - sigmoid(val)) +} + +#[cfg(test)] +mod test { + use nalgebra::DMatrix; + + use super::*; + + #[test] + fn test_sigmoid() { + let mut mat: DMatrix = DMatrix::from_vec(1, 1, vec![0.0]); + mat.apply(sigmoid_inplace); + assert_eq!(mat.get(0), Some(&0.5)); + } +} \ No newline at end of file