diff --git a/src/lib.rs b/src/lib.rs index ca821ee..f154dc1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,2 @@ pub mod net; -pub mod dataloader; -mod mat; \ No newline at end of file +pub mod dataloader; \ No newline at end of file diff --git a/src/mat.rs b/src/mat.rs deleted file mode 100644 index 4e35de8..0000000 --- a/src/mat.rs +++ /dev/null @@ -1,209 +0,0 @@ -use core::ops::Add; -use std::fmt::Debug; -use std::ops::AddAssign; -use nalgebra::DMatrix; - -/// matrix add with broadcasting -/// like the numpy add operation -/// not sure I even need it anymore, after fixing inconsistencies with the matrix shapes -/// TODO see if it's still needed, or that standard matrix addition suffices -pub fn add(v1: DMatrix, v2: DMatrix) -> Result, String> - where T: PartialEq + Copy + Clone + Debug + Add + Add + AddAssign + 'static -{ - let (r1, c1) = v1.shape(); - let (r2, c2) = v2.shape(); - - if r1 == r2 && c1 == c2 { - // same size, no broadcasting needed - Ok(v1 + v2) - } else if r1 == 1 && c2 == 1 { - Ok(DMatrix::from_fn(r2, c1, |r, c| *v1.get(c).unwrap() + *v2.get(r).unwrap())) - } else if c1 == 1 && r2 == 1 { - Ok(DMatrix::from_fn(r1, c2, |r, c| *v1.get(r).unwrap() + *v2.get(c).unwrap())) - } else if r1 == 1 && c1 == c2 { - Ok(DMatrix::from_fn(r2, c1, |r, c| *v1.get(c).unwrap() + *v2.get(c * r2 + r).unwrap())) - } else if r2 == 1 && c1 == c2 { - Ok(DMatrix::from_fn(r1, c2, |r, c| *v2.get(c).unwrap() + *v1.get(c * r1 + r).unwrap())) - } else if c1 == 1 && r1 == r2 { - Ok(DMatrix::from_fn(r1, c2, |r, c| *v1.get(r).unwrap() + *v2.get(c * r2 + r).unwrap())) - } else if c2 == 1 && r1 == r2 { - Ok(DMatrix::from_fn(r2, c1, |r, c| *v2.get(r).unwrap() + *v1.get(c * r2 + r).unwrap())) - } else { - Err(format!("ValueError: operands could not be broadcast together ({},{}), ({},{})", r1,c1, r2, c2)) - } -} - -#[cfg(test)] -mod test { - use nalgebra::dmatrix; - use super::*; - - #[test] - fn stretch_row_column_to_square() { - let v1: DMatrix = dmatrix![1,2,3]; - let v2: DMatrix = dmatrix![1;2;3]; - - let sum = add(v1, v2).unwrap(); - assert_eq!(sum.shape(), (3, 3)); - assert_eq!(sum, dmatrix![2,3,4;3,4,5;4,5,6]); - } - - #[test] - fn stretch_row_column_to_rect() { - let v1: DMatrix = dmatrix![1,2,3]; - let v2: DMatrix = dmatrix![1;2]; - - let sum = add(v1, v2).unwrap(); - assert_eq!(sum.shape(), (2, 3)); - assert_eq!(sum, dmatrix![2,3,4;3,4,5]); - } - - #[test] - fn stretch_column_row_to_square() { - let v1: DMatrix = dmatrix![1;2;3]; - let v2: DMatrix = dmatrix![1,2,3]; - - let sum = add(v1, v2).unwrap(); - assert_eq!(sum.shape(), (3, 3)); - assert_eq!(sum, dmatrix![2,3,4;3,4,5;4,5,6]); - } - - #[test] - fn stretch_column_row_to_rect() { - let v1: DMatrix = dmatrix![1;2;3]; - let v2: DMatrix = dmatrix![1,2]; - - let sum = add(v1, v2).unwrap(); - assert_eq!(sum.shape(), (3, 2)); - assert_eq!(sum, dmatrix![2,3;3,4;4,5]); - } - - #[test] - fn stretch_row() { - let v1: DMatrix = dmatrix![1,2,3]; - let v2: DMatrix = dmatrix![1,2,3;4,5,6]; - - let sum = add(v1, v2).unwrap(); - assert_eq!(sum.shape(), (2, 3)); - assert_eq!(sum, dmatrix![2,4,6;5,7,9]); - } - - #[test] - fn stretch_row_commute() { - let v1: DMatrix = dmatrix![1,2,3;4,5,6]; - let v2: DMatrix = dmatrix![1,2,3]; - - let sum = add(v1, v2).unwrap(); - assert_eq!(sum.shape(), (2, 3)); - assert_eq!(sum, dmatrix![2,4,6;5,7,9]); - } - - #[test] - fn stretch_column() { - let v1: DMatrix = dmatrix![1;2]; - let v2: DMatrix = dmatrix![1,2,3;4,5,6]; - - let sum = add(v1, v2).unwrap(); - assert_eq!(sum.shape(), (2, 3)); - assert_eq!(sum, dmatrix![2,3,4;6,7,8]); - } - - #[test] - fn stretch_column_commute() { - let v1: DMatrix = dmatrix![1,2,3;4,5,6]; - let v2: DMatrix = dmatrix![1;2]; - - let sum = add(v1, v2).unwrap(); - assert_eq!(sum.shape(), (2, 3)); - assert_eq!(sum, dmatrix![2,3,4;6,7,8]); - } - - #[test] - fn test_broadcast_2dims() { - let v1: DMatrix = dmatrix![1,2,3]; - let v2: DMatrix = dmatrix![1;2;3]; - - let sum = add(v1, v2).unwrap(); - assert_eq!(sum.shape(), (3, 3)); - assert_eq!(sum, dmatrix![2,3,4;3,4,5;4,5,6]); - } - - #[test] - fn test_add_commutative() { - let v1 = dmatrix![1,2,3]; - let v2 = dmatrix![1;2;3]; - - let sum = add(v2, v1).unwrap(); - assert_eq!(sum.shape(), (3, 3)); - assert_eq!(sum, dmatrix![2,3,4;3,4,5;4,5,6]); - } - - #[test] - fn test_add_same_size() { - let v1 = dmatrix![1,2;3,4]; - let v2 = dmatrix![3,4;5,6]; - - let sum = add(v2, v1).unwrap(); - assert_eq!(sum.shape(), (2, 2)); - assert_eq!(sum, dmatrix![4,6;8,10]); - } - - #[test] - fn test_add_row_broadcast() {// - let v1 = dmatrix![1,2;3,4]; - let v2 = dmatrix![3,4]; - - let sum = add(v1, v2).unwrap(); - assert_eq!(sum, dmatrix![4,6;6,8]); - } - - #[test] - fn test_add_row_broadcast2() { - let v1 = dmatrix![1,1]; - let v2 = dmatrix![1,2;3,4]; - - let sum = add(v1, v2).unwrap(); - assert_eq!(sum.shape(), (2, 2)); - assert_eq!(sum, dmatrix![2,3;4,5]); - } - - #[test] - fn test_column_broadcast() { - let v1 = dmatrix![1;1]; - let v2 = dmatrix![1,2;3,4]; - - let sum = add(v1, v2).unwrap(); - assert_eq!(sum.shape(), (2, 2)); - assert_eq!(sum, dmatrix![2,3;4,5]); - } - - #[test] - fn test_column_broadcast2() { - let v1 = dmatrix![1,2;3,4]; - let v2 = dmatrix![1;1]; - - let sum = add(v1, v2).unwrap(); - assert_eq!(sum.shape(), (2, 2)); - assert_eq!(sum, dmatrix![2,3;4,5]); - } - - #[test] - fn column_too_long() { - let v1 = dmatrix![1;1;1]; - let v2 = dmatrix![1,2;3,4]; - - let result = add(v1, v2); - assert_eq!(result, Err("ValueError: operands could not be broadcast together".to_owned())); - } - - #[test] - fn row_too_long() { - let v1 = dmatrix![1,1,1]; - let v2 = dmatrix![1,2;3,4]; - - let result = add(v1, v2); - assert_eq!(result, Err("ValueError: operands could not be broadcast together".to_owned())); - } - - -} \ No newline at end of file diff --git a/src/net.rs b/src/net.rs index 020c64d..b91bc2e 100644 --- a/src/net.rs +++ b/src/net.rs @@ -6,7 +6,6 @@ use rand::prelude::*; use rand_distr::Normal; use crate::dataloader::{Data, DataLine, OneHotVector}; -use crate::mat::add; #[derive(Debug)] pub struct Network { @@ -56,7 +55,7 @@ impl Network { fn feed_forward_activation(&self, input: Vec, activation: fn(&mut f64)) -> Vec { let mut a = DMatrix::from_vec(input.len(), 1, input); for (b, w) in zip(&self.biases, &self.weights) { - a = add(b.clone(), w * a).unwrap(); + a = b.clone()+ w * a; a.apply(activation); } a.column(0).iter().copied().collect() @@ -83,14 +82,7 @@ impl Network { /// The ``mini_batch`` is a list of tuples ``(x, y)``, and ``eta`` /// is the learning rate. fn update_mini_batch(&mut self, mini_batch: &[DataLine], eta: f64) { - let mut nabla_b: Vec> = self.biases.iter() - .map(|b| b.shape()) - .map(|s| DMatrix::zeros(s.0, s.1)) - .collect(); - let mut nabla_w: Vec> = self.weights.iter() - .map(|w| w.shape()) - .map(|s| DMatrix::zeros(s.0, s.1)) - .collect(); + let (mut nabla_b, mut nabla_w) = self.zero_gradient(); for line in mini_batch.iter() { let (delta_nabla_b, delta_nabla_w) = self.backprop(line.inputs.to_vec(), &line.label); @@ -124,15 +116,7 @@ impl Network { /// `nabla_w` are layer-by-layer lists of matrices, similar /// to `self.biases` and `self.weights`. fn backprop(&self, x: Vec, y: &OneHotVector) -> (Vec>, Vec>) { - // zero_grad ie. set gradient to zero - let mut nabla_b: Vec> = self.biases.iter() - .map(|b| b.shape()) - .map(|s| DMatrix::zeros(s.0, s.1)) - .collect(); - let mut nabla_w: Vec> = self.weights.iter() - .map(|w| w.shape()) - .map(|s| DMatrix::zeros(s.0, s.1)) - .collect(); + let (mut nabla_b, mut nabla_w) = self.zero_gradient(); // feedforward let mut activation = DMatrix::from_vec(x.len(), 1, x); @@ -168,6 +152,19 @@ impl Network { (nabla_b, nabla_w) } + + fn zero_gradient(&self) -> (Vec>, Vec>) { + let nabla_b: Vec> = self.biases.iter() + .map(|b| b.shape()) + .map(|s| DMatrix::zeros(s.0, s.1)) + .collect(); + let nabla_w: Vec> = self.weights.iter() + .map(|w| w.shape()) + .map(|s| DMatrix::zeros(s.0, s.1)) + .collect(); + (nabla_b, nabla_w) + } + } fn cost_derivative(output_activations: &DMatrix, y: &OneHotVector) -> DMatrix { @@ -225,7 +222,6 @@ fn sigmoid_prime(val: f64) -> f64 { #[cfg(test)] mod test { - use std::convert::identity; use nalgebra::DMatrix; use super::*; @@ -241,7 +237,7 @@ mod test { fn test_sigmoid_inplace() { let mut v = 10.0; sigmoid_inplace(&mut v); - assert_eq!(0.9999546, v); + assert_eq!(0.9999546021312976, v); } #[test]