diff --git a/.gitignore b/.gitignore index 8b7a06b..8462796 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /target *.iml -.idea \ No newline at end of file +.idea +src/data/training.json \ No newline at end of file diff --git a/src/data/unittest.json b/src/data/unittest.json new file mode 100644 index 0000000..da13f6f --- /dev/null +++ b/src/data/unittest.json @@ -0,0 +1 @@ +{"x":[[0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.001171875,0.00703125,0.00703125,0.00703125,0.04921875,0.053125,0.068359375,0.01015625,0.06484375,0.099609375,0.096484375,0.049609375,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.01171875,0.0140625,0.03671875,0.06015625,0.06640625,0.098828125,0.098828125,0.098828125,0.098828125,0.098828125,0.087890625,0.0671875,0.098828125,0.09453125,0.076171875,0.025,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.019140625,0.09296875,0.098828125,0.098828125,0.098828125,0.098828125,0.098828125,0.098828125,0.098828125,0.098828125,0.098046875,0.036328125,0.03203125,0.03203125,0.021875,0.015234375,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.00703125,0.085546875,0.098828125,0.098828125,0.098828125,0.098828125,0.098828125,0.07734375,0.07109375,0.096484375,0.094140625,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.03125,0.0609375,0.041796875,0.098828125,0.098828125,0.080078125,0.004296875,0.0,0.016796875,0.06015625,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.00546875,0.000390625,0.06015625,0.098828125,0.03515625,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.054296875,0.098828125,0.07421875,0.00078125,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.004296875,0.07421875,0.098828125,0.02734375,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.013671875,0.094140625,0.087890625,0.0625,0.0421875,0.000390625,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.031640625,0.09375,0.098828125,0.098828125,0.046484375,0.009765625,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.017578125,0.07265625,0.098828125,0.098828125,0.05859375,0.010546875,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.00625,0.036328125,0.0984375,0.098828125,0.073046875,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.097265625,0.098828125,0.097265625,0.025,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.01796875,0.05078125,0.071484375,0.098828125,0.098828125,0.080859375,0.00078125,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.015234375,0.0578125,0.089453125,0.098828125,0.098828125,0.098828125,0.09765625,0.07109375,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.00937
5,0.04453125,0.086328125,0.098828125,0.098828125,0.098828125,0.098828125,0.078515625,0.03046875,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.008984375,0.02578125,0.083203125,0.098828125,0.098828125,0.098828125,0.098828125,0.07734375,0.031640625,0.00078125,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.00703125,0.066796875,0.085546875,0.098828125,0.098828125,0.098828125,0.098828125,0.076171875,0.03125,0.003515625,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.021484375,0.0671875,0.08828125,0.098828125,0.098828125,0.098828125,0.098828125,0.0953125,0.051953125,0.004296875,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.053125,0.098828125,0.098828125,0.098828125,0.0828125,0.052734375,0.0515625,0.00625,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0]], "y":[5]} \ No newline at end of file diff --git a/src/dataloader.rs b/src/dataloader.rs index 7fae50b..a3ba0e5 100644 --- a/src/dataloader.rs +++ b/src/dataloader.rs @@ -1,19 +1,20 @@ use std::iter::zip; +use nalgebra::DMatrix; use rand::prelude::*; use serde::Deserialize; -pub fn load_data() -> Data { - /// the mnist data is structured as - /// x: [[[pixels]],[[pixels]], etc], - /// y: [label1, label2, etc] - /// this is transformed to: - /// Data : Vec - /// DataLine {inputs: Vec, label: f32} +pub fn load_data() -> Data { + // the mnist data is structured as + // x: [[[pixels]],[[pixels]], etc], + // y: [label1, label2, etc] + // this is transformed to: + // Data : Vec + // DataLine {inputs: Vec, label: f32} let raw_data: RawData = serde_json::from_slice(include_bytes!("data/unittest.json")).unwrap(); let mut vec = Vec::new(); for (x, y) in zip(raw_data.x, raw_data.y) { - vec.push(DataLine { inputs: x, label: y}); + vec.push(DataLine { inputs: x, label: onehot(y) }); } Data(vec) @@ -27,17 +28,38 @@ struct RawData { /// X is type of input /// Y is type of output -pub struct DataLine { +pub struct DataLine { pub inputs: Vec, pub label: Y, } -pub struct Data(pub Vec>); +pub struct Data(pub Vec>); + +pub struct OneHotVector{ + pub val: usize +} + +impl OneHotVector{ + fn new(val: usize) -> Self{ + Self{ + val + } + } + + pub fn get(&self, index: usize) -> f32{ + if self.val == index { + 1.0 + } else { + 0.0 + } + } -impl Data { +} + +impl Data { pub fn shuffle(&mut self) { - let mut rng = rand::thread_rng(); + let mut rng = thread_rng(); self.0.shuffle(&mut rng); } @@ -45,7 +67,7 @@ impl Data { self.0.len() } - pub fn as_batches(&self, batch_size: usize) -> Vec<&[DataLine]> { + pub fn as_batches(&self, batch_size: usize) -> Vec<&[DataLine]> { let mut batches = Vec::with_capacity(self.0.len() / batch_size + 1); let mut offset = 0; for _ in 0..self.0.len() / batch_size { @@ -55,6 +77,9 @@ impl Data { batches.push(&self.0[offset..self.0.len()]); batches } - - } + +/// returns a vector as matrix where y is one-hot encoded +fn onehot(y: u8) -> OneHotVector { + OneHotVector::new(y as usize) +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index f154dc1..ca821ee 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,2 +1,3 @@ pub mod net; -pub mod dataloader; \ No newline at end of file +pub mod dataloader; +mod 
mat;
\ No newline at end of file
diff --git a/src/main.rs b/src/main.rs
index ebd727d..1f1a7f2 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -2,14 +2,15 @@ use mnist_rs::dataloader::load_data;
 
 fn main() {
     let mut net = mnist_rs::net::Network::from(vec![784, 30, 10]);
-    for w in net.weights.iter() {
-        println!("{}, {}", w.shape().0, w.shape().1);
-    }
-    println!();
-    for b in net.biases.iter() {
-        println!("{:?}", b.shape());
-    }
     let training_data = load_data();
     net.sgd(training_data, 30, 10, 3.0, &None);
+
+
+    // let sizes = vec![5,3,2];
+    // let net = mnist_rs::net::Network::from(sizes);
+    // println!("biases {:?}", net.biases.iter().map(|b|b.shape()).collect::<Vec<_>>());
+    // println!("weights {:?}", net.weights.iter().map(|b|b.shape()).collect::<Vec<_>>());
+
+
 }
\ No newline at end of file
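
For reference, a rough sketch (not part of this change) of the parameter shapes `Network::from(vec![784, 30, 10])` is expected to produce, assuming `biases()` yields one `(n, 1)` column per non-input layer and `weights()` builds `(outputs, inputs)` matrices via `random_matrix(size.1, size.0)` as in `src/net.rs` further down; it stands in for the shape-printing loops removed from `main` above.

```rust
// Hypothetical shape check, mirroring the debug printouts deleted from main().
#[test]
fn network_shapes_sketch() {
    let net = mnist_rs::net::Network::from(vec![784, 30, 10]);
    let bias_shapes: Vec<_> = net.biases.iter().map(|b| b.shape()).collect();
    let weight_shapes: Vec<_> = net.weights.iter().map(|w| w.shape()).collect();
    // one (n, 1) bias column and one (outputs, inputs) weight matrix per layer transition
    assert_eq!(bias_shapes, vec![(30, 1), (10, 1)]);
    assert_eq!(weight_shapes, vec![(30, 784), (10, 30)]);
}
```
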
diff --git a/src/mat.rs b/src/mat.rs
new file mode 100644
index 0000000..4564cc1
--- /dev/null
+++ b/src/mat.rs
@@ -0,0 +1,205 @@
+use core::ops::Add;
+use std::fmt::Debug;
+use std::ops::AddAssign;
+use nalgebra::DMatrix;
+
+pub fn add<T>(v1: DMatrix<T>, v2: DMatrix<T>) -> Result<DMatrix<T>, String>
+    where T: PartialEq + Copy + Clone + Debug + Add<Output = T> + AddAssign + 'static
+{
+    let (r1, c1) = v1.shape();
+    let (r2, c2) = v2.shape();
+
+    if r1 == r2 && c1 == c2 {
+        // same size, no broadcasting needed
+        Ok(v1 + v2)
+    } else if r1 == 1 && c2 == 1 {
+        Ok(DMatrix::from_fn(r2, c1, |r, c| *v1.get(c).unwrap() + *v2.get(r).unwrap()))
+    } else if c1 == 1 && r2 == 1 {
+        Ok(DMatrix::from_fn(r1, c2, |r, c| *v1.get(r).unwrap() + *v2.get(c).unwrap()))
+    } else if r1 == 1 && c1 == c2 {
+        Ok(DMatrix::from_fn(r2, c1, |r, c| *v1.get(c).unwrap() + *v2.get(c * r2 + r).unwrap()))
+    } else if r2 == 1 && c1 == c2 {
+        Ok(DMatrix::from_fn(r1, c2, |r, c| *v2.get(c).unwrap() + *v1.get(c * r1 + r).unwrap()))
+    } else if c1 == 1 && r1 == r2 {
+        Ok(DMatrix::from_fn(r1, c2, |r, c| *v1.get(r).unwrap() + *v2.get(c * r2 + r).unwrap()))
+    } else if c2 == 1 && r1 == r2 {
+        Ok(DMatrix::from_fn(r2, c1, |r, c| *v2.get(r).unwrap() + *v1.get(c * r2 + r).unwrap()))
+    } else {
+        Err(format!("ValueError: operands could not be broadcast together ({},{}), ({},{})", r1, c1, r2, c2))
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use nalgebra::dmatrix;
+    use super::*;
+
+    #[test]
+    fn stretch_row_column_to_square() {
+        let v1: DMatrix<i32> = dmatrix![1,2,3];
+        let v2: DMatrix<i32> = dmatrix![1;2;3];
+
+        let sum = add(v1, v2).unwrap();
+        assert_eq!(sum.shape(), (3, 3));
+        assert_eq!(sum, dmatrix![2,3,4;3,4,5;4,5,6]);
+    }
+
+    #[test]
+    fn stretch_row_column_to_rect() {
+        let v1: DMatrix<i32> = dmatrix![1,2,3];
+        let v2: DMatrix<i32> = dmatrix![1;2];
+
+        let sum = add(v1, v2).unwrap();
+        assert_eq!(sum.shape(), (2, 3));
+        assert_eq!(sum, dmatrix![2,3,4;3,4,5]);
+    }
+
+    #[test]
+    fn stretch_column_row_to_square() {
+        let v1: DMatrix<i32> = dmatrix![1;2;3];
+        let v2: DMatrix<i32> = dmatrix![1,2,3];
+
+        let sum = add(v1, v2).unwrap();
+        assert_eq!(sum.shape(), (3, 3));
+        assert_eq!(sum, dmatrix![2,3,4;3,4,5;4,5,6]);
+    }
+
+    #[test]
+    fn stretch_column_row_to_rect() {
+        let v1: DMatrix<i32> = dmatrix![1;2;3];
+        let v2: DMatrix<i32> = dmatrix![1,2];
+
+        let sum = add(v1, v2).unwrap();
+        assert_eq!(sum.shape(), (3, 2));
+        assert_eq!(sum, dmatrix![2,3;3,4;4,5]);
+    }
+
+    #[test]
+    fn stretch_row() {
+        let v1: DMatrix<i32> = dmatrix![1,2,3];
+        let v2: DMatrix<i32> = dmatrix![1,2,3;4,5,6];
+
+        let sum = add(v1, v2).unwrap();
+        assert_eq!(sum.shape(), (2, 3));
+        assert_eq!(sum, dmatrix![2,4,6;5,7,9]);
+    }
+
+    #[test]
+    fn stretch_row_commute() {
+        let v1: DMatrix<i32> = dmatrix![1,2,3;4,5,6];
+        let v2: DMatrix<i32> = dmatrix![1,2,3];
+
+        let sum = add(v1, v2).unwrap();
+        assert_eq!(sum.shape(), (2, 3));
+        assert_eq!(sum, dmatrix![2,4,6;5,7,9]);
+    }
+
+    #[test]
+    fn stretch_column() {
+        let v1: DMatrix<i32> = dmatrix![1;2];
+        let v2: DMatrix<i32> = dmatrix![1,2,3;4,5,6];
+
+        let sum = add(v1, v2).unwrap();
+        assert_eq!(sum.shape(), (2, 3));
+        assert_eq!(sum, dmatrix![2,3,4;6,7,8]);
+    }
+
+    #[test]
+    fn stretch_column_commute() {
+        let v1: DMatrix<i32> = dmatrix![1,2,3;4,5,6];
+        let v2: DMatrix<i32> = dmatrix![1;2];
+
+        let sum = add(v1, v2).unwrap();
+        assert_eq!(sum.shape(), (2, 3));
+        assert_eq!(sum, dmatrix![2,3,4;6,7,8]);
+    }
+
+    #[test]
+    fn test_broadcast_2dims() {
+        let v1: DMatrix<i32> = dmatrix![1,2,3];
+        let v2: DMatrix<i32> = dmatrix![1;2;3];
+
+        let sum = add(v1, v2).unwrap();
+        assert_eq!(sum.shape(), (3, 3));
+        assert_eq!(sum, dmatrix![2,3,4;3,4,5;4,5,6]);
+    }
+
+    #[test]
+    fn test_add_commutative() {
+        let v1 = dmatrix![1,2,3];
+        let v2 = dmatrix![1;2;3];
+
+        let sum = add(v2, v1).unwrap();
+        assert_eq!(sum.shape(), (3, 3));
+        assert_eq!(sum, dmatrix![2,3,4;3,4,5;4,5,6]);
+    }
+
+    #[test]
+    fn test_add_same_size() {
+        let v1 = dmatrix![1,2;3,4];
+        let v2 = dmatrix![3,4;5,6];
+
+        let sum = add(v2, v1).unwrap();
+        assert_eq!(sum.shape(), (2, 2));
+        assert_eq!(sum, dmatrix![4,6;8,10]);
+    }
+
+    #[test]
+    fn test_add_row_broadcast() {
+        let v1 = dmatrix![1,2;3,4];
+        let v2 = dmatrix![3,4];
+
+        let sum = add(v1, v2).unwrap();
+        assert_eq!(sum, dmatrix![4,6;6,8]);
+    }
+
+    #[test]
+    fn test_add_row_broadcast2() {
+        let v1 = dmatrix![1,1];
+        let v2 = dmatrix![1,2;3,4];
+
+        let sum = add(v1, v2).unwrap();
+        assert_eq!(sum.shape(), (2, 2));
+        assert_eq!(sum, dmatrix![2,3;4,5]);
+    }
+
+    #[test]
+    fn test_column_broadcast() {
+        let v1 = dmatrix![1;1];
+        let v2 = dmatrix![1,2;3,4];
+
+        let sum = add(v1, v2).unwrap();
+        assert_eq!(sum.shape(), (2, 2));
+        assert_eq!(sum, dmatrix![2,3;4,5]);
+    }
+
+    #[test]
+    fn test_column_broadcast2() {
+        let v1 = dmatrix![1,2;3,4];
+        let v2 = dmatrix![1;1];
+
+        let sum = add(v1, v2).unwrap();
+        assert_eq!(sum.shape(), (2, 2));
+        assert_eq!(sum, dmatrix![2,3;4,5]);
+    }
+
+    #[test]
+    fn column_too_long() {
+        let v1 = dmatrix![1;1;1];
+        let v2 = dmatrix![1,2;3,4];
+
+        let result = add(v1, v2);
+        assert_eq!(result, Err("ValueError: operands could not be broadcast together (3,1), (2,2)".to_owned()));
+    }
+
+    #[test]
+    fn row_too_long() {
+        let v1 = dmatrix![1,1,1];
+        let v2 = dmatrix![1,2;3,4];
+
+        let result = add(v1, v2);
+        assert_eq!(result, Err("ValueError: operands could not be broadcast together (1,3), (2,2)".to_owned()));
+    }
+
+
+}
\ No newline at end of file
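
To make the broadcasting rules above concrete: the network code below only ever calls `add` with a bias column against a matrix of the same height, but the helper accepts any NumPy-style pairing and reports mismatches as an `Err`. A minimal sketch, assuming it lives inside the crate and is reached as `crate::mat::add` (the diff adds `mat` as a private module, so an external path such as `mnist_rs::mat::add` would additionally require `pub mod mat` in `lib.rs`):

```rust
// Example only: bias-style column broadcast plus the error path of crate::mat::add.
#[cfg(test)]
mod broadcast_example {
    use nalgebra::dmatrix;
    use crate::mat::add;

    #[test]
    fn bias_column_broadcasts_over_matrix() {
        // a (2, 1) bias column is stretched across a (2, 3) matrix, NumPy-style
        let b = dmatrix![0.5; -0.5];
        let z = dmatrix![1.0, 2.0, 3.0; 4.0, 5.0, 6.0];
        let sum = add(b, z).unwrap();
        assert_eq!(sum, dmatrix![1.5, 2.5, 3.5; 3.5, 4.5, 5.5]);

        // incompatible shapes come back as Err instead of panicking
        assert!(add(dmatrix![1.0, 2.0, 3.0], dmatrix![1.0, 2.0; 3.0, 4.0]).is_err());
    }
}
```
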
diff --git a/src/net.rs b/src/net.rs
index 49504ef..0b68ce6 100644
--- a/src/net.rs
+++ b/src/net.rs
@@ -1,17 +1,18 @@
-use std::convert::identity;
 use std::iter::zip;
-use std::ops::{Add, Sub};
+use std::ops::Add;
 
-use nalgebra::{DMatrix, Matrix, OMatrix};
+use nalgebra::DMatrix;
 use rand::prelude::*;
 use rand_distr::Normal;
 
-use crate::dataloader::{Data, DataLine};
+use crate::dataloader::{Data, DataLine, OneHotVector};
+use crate::mat;
+use crate::mat::add;
 
 #[derive(Debug)]
 pub struct Network {
     _sizes: Vec<usize>,
-    _num_layers: usize,
+    num_layers: usize,
     pub biases: Vec<DMatrix<f32>>,
     pub weights: Vec<DMatrix<f32>>,
 }
@@ -30,7 +31,7 @@ impl Network {
     pub fn from(sizes: Vec<usize>) -> Self {
         Self {
             _sizes: sizes.clone(),
-            _num_layers: sizes.len(),
+            num_layers: sizes.len(),
             biases: biases(sizes[1..].to_vec()),
             weights: weights(zip(sizes[..sizes.len() - 1].to_vec(), sizes[1..].to_vec()).collect()),
         }
@@ -39,13 +40,13 @@
     fn feed_forward(&self, input: Vec<f32>) -> Vec<f32> {
         let mut a = DMatrix::from_vec(input.len(), 1, input);
         for (b, w) in zip(&self.biases, &self.weights) {
-            a = b.add_scalar(w.dot(&a));
+            a = add(b.clone(), w * a).unwrap();
             a.apply(sigmoid_inplace);
         }
         a.column(1).iter().map(|v| *v).collect()
     }
 
-    pub fn sgd(&mut self, mut training_data: Data<Vec<f32>, u8>, epochs: usize, minibatch_size: usize, eta: f32, test_data: &Option<Data<Vec<f32>, u8>>) {
+    pub fn sgd(&mut self, mut training_data: Data<Vec<f32>, OneHotVector>, epochs: usize, minibatch_size: usize, eta: f32, test_data: &Option<Data<Vec<f32>, OneHotVector>>) {
         for j in 0..epochs {
             training_data.shuffle();
             let mini_batches = training_data.as_batches(minibatch_size);
@@ -65,7 +66,7 @@
     /// gradient descent using backpropagation to a single mini batch.
     /// The ``mini_batch`` is a list of tuples ``(x, y)``, and ``eta``
    /// is the learning rate.
-    fn update_mini_batch(&mut self, mini_batch: &[DataLine<Vec<f32>, u8>], eta: f32) {
+    fn update_mini_batch(&mut self, mini_batch: &[DataLine<Vec<f32>, OneHotVector>], eta: f32) {
         let mut nabla_b: Vec<DMatrix<f32>> = self.biases.iter()
             .map(|b| b.shape())
             .map(|s| DMatrix::zeros(s.0, s.1))
             .collect();
@@ -75,34 +76,34 @@
             .map(|s| DMatrix::zeros(s.0, s.1))
             .collect();
         for line in mini_batch.iter() {
-            let (delta_nabla_b, delta_nabla_w) = self.backprop(line.inputs.to_vec(), line.label);
+            let (delta_nabla_b, delta_nabla_w) = self.backprop(line.inputs.to_vec(), &line.label);
             nabla_b = zip(&nabla_b, &delta_nabla_b).map(|(nb, dnb)| nb.add(dnb)).collect();
             nabla_w = zip(&nabla_w, &delta_nabla_w).map(|(nw, dnw)| nw.add(dnw)).collect();
         }
         self.weights = zip(&self.weights, &nabla_w)
-            .map(|(w, nw)| w.add_scalar(-eta / mini_batch.len() as f32)).collect();
+            .map(|(w, nw)| w - nw * (eta / mini_batch.len() as f32)).collect();
         self.biases = zip(&self.biases, &nabla_b)
-            .map(|(b, nb)| b.add_scalar(-eta / mini_batch.len() as f32)).collect();
+            .map(|(b, nb)| b - nb * (eta / mini_batch.len() as f32)).collect();
     }
 
     /// Return the number of test inputs for which the neural
     /// network outputs the correct result. Note that the neural
     /// network's output is assumed to be the index of whichever
     /// neuron in the final layer has the highest activation.
-    fn evaluate(&self, test_data: &Data<Vec<f32>, u8>) -> usize {
-        let test_results: Vec<(usize, u8)> = test_data.0.iter()
-            .map(|line| (argmax(self.feed_forward(line.inputs.clone())), line.label))
+    fn evaluate(&self, test_data: &Data<Vec<f32>, OneHotVector>) -> usize {
+        let test_results: Vec<(usize, usize)> = test_data.0.iter()
+            .map(|line| (argmax(self.feed_forward(line.inputs.clone())), line.label.val))
             .collect();
-        test_results.into_iter().filter(|(x, y)| *x == *y as usize).count()
+        test_results.into_iter().filter(|(x, y)| x == y).count()
     }
 
     /// Return a tuple `(nabla_b, nabla_w)` representing the
     /// gradient for the cost function C_x. `nabla_b` and
     /// `nabla_w` are layer-by-layer lists of matrices, similar
     /// to `self.biases` and `self.weights`.
-    fn backprop(&self, x: Vec<f32>, y: u8) -> (Vec<DMatrix<f32>>, Vec<DMatrix<f32>>) {
+    fn backprop(&self, x: Vec<f32>, y: &OneHotVector) -> (Vec<DMatrix<f32>>, Vec<DMatrix<f32>>) {
         // zero_grad ie. set gradient to zero
         let mut nabla_b: Vec<DMatrix<f32>> = self.biases.iter()
             .map(|b| b.shape())
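
Worth noting for the `evaluate` and `backprop` changes above: the label is now the `OneHotVector` added in `dataloader.rs`, which stores only the class index and materialises its 0/1 entries through `get`. A small sketch (not part of the diff) of how that type behaves:

```rust
// Example only: the one-hot label carries just the index; get(i) expands it lazily.
use mnist_rs::dataloader::OneHotVector;

fn one_hot_label_example() {
    let label = OneHotVector { val: 5 };
    let dense: Vec<f32> = (0..10).map(|i| label.get(i)).collect();
    assert_eq!(dense[5], 1.0);
    assert_eq!(dense.iter().sum::<f32>(), 1.0);
    // evaluate() compares argmax(feed_forward(x)) against label.val directly;
    // the dense 0/1 view is only needed by cost_derivative() further down.
}
```
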
@@ -119,38 +120,40 @@
         let mut zs = vec![];
 
         for (b, w) in zip(&self.biases, &self.weights) {
-            // println!("{:?}", w.shape());
-            // println!("{:?}", activation.shape());
-            // println!("{:?}", b.shape());
-
-            let mut z: DMatrix<f32> = w * &activation + b;
+            let z = add(w * &activation, b.clone()).unwrap();
             zs.push(z.clone());
             activation = z.map(sigmoid);
             activations.push(activation.clone());
         }
 
-        // backward pass
-        let delta: DMatrix<f32> = self.cost_derivative(
-            &activations[activations.len() - 1],
-            y as f32);
-        println!("delta {:?}", delta.shape());
-        println!("z {:?}", &zs[zs.len() - 1].transpose().shape());
-        let delta = delta * (&zs[zs.len() - 1].transpose().map(sigmoid_prime));
-        println!("delta {:?}", delta.shape());
+        // delta = self.cost_derivative(activations[-1], y) * sigmoid_prime(zs[-1])
+        let mut delta: DMatrix<f32> = self.cost_derivative(&activations[activations.len() - 1], y).component_mul(&zs[zs.len() - 1].map(sigmoid_prime));
         let index = nabla_b.len() - 1;
         nabla_b[index] = delta.clone();
-        println!("delta {:?}", delta.shape());
-        println!("activation {:?}", activations[activations.len() - 2].shape());
         let index = nabla_w.len() - 1;
-        nabla_w[index] = delta * &activations[activations.len() - 2];
-
+        let ac = &activations[activations.len() - 2].transpose();
+        nabla_w[index] = &delta * ac;
+        let lens_zs = zs.len();
+        for l in 2..self.num_layers {
+            let z = &zs[lens_zs - l];
+            let sp = z.map(sigmoid_prime);
+            let weight = self.weights[self.weights.len() - l + 1].transpose();
+            delta = (weight * &delta).component_mul(&sp);
+            let len_nb = nabla_b.len();
+            nabla_b[len_nb - l] = delta.clone();
+            let len_nw = nabla_w.len();
+            nabla_w[len_nw - l] = &delta * activations[activations.len() - l - 1].transpose();
+        }
         (nabla_b, nabla_w)
     }
 
-    fn cost_derivative(&self, output_activations: &DMatrix<f32>, y: f32) -> DMatrix<f32> {
-        output_activations.add_scalar(-y)
+    fn cost_derivative(&self, output_activations: &DMatrix<f32>, y: &OneHotVector) -> DMatrix<f32> {
+        // output_activations - y
+        let shape = output_activations.shape();
+        DMatrix::from_iterator(shape.0, shape.1, output_activations.iter().enumerate()
+            .map(|(index, a)| a - y.get(index)))
     }
 }
@@ -171,7 +174,6 @@ fn biases(sizes: Vec<usize>) -> Vec<DMatrix<f32>> {
 }
 
 fn weights(sizes: Vec<(usize, usize)>) -> Vec<DMatrix<f32>> {
-    println!("{:?}", sizes);
     sizes.iter().map(|size| random_matrix(size.1, size.0)).collect()
 }
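
Finally, a tiny numeric sketch (not part of the diff) of the rewritten `cost_derivative`: for the quadratic cost that an `a - y` derivative corresponds to, the output-layer error is just the activations minus the one-hot target, evaluated element-wise, which is exactly what the `y.get(index)` mapping computes.

```rust
// Example only: element-wise a - y against a one-hot target, as in cost_derivative().
use nalgebra::{dmatrix, DMatrix};
use mnist_rs::dataloader::OneHotVector;

fn cost_derivative_example() {
    let a: DMatrix<f32> = dmatrix![0.25; 0.75; 0.5]; // output activations
    let y = OneHotVector { val: 1 };                 // correct class is index 1
    let dc: Vec<f32> = a.iter().enumerate().map(|(i, v)| v - y.get(i)).collect();
    assert_eq!(dc, vec![0.25, -0.25, 0.5]); // only the target entry is shifted by 1.0
}
```
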