Fix tests; remove unused broadcasting `add` (standard matrix addition suffices now that matrix shapes are consistent)

This commit is contained in:
Shautvast 2023-03-03 16:40:30 +01:00
parent 8acf2a11d5
commit 0b62f3cbc2
3 changed files with 18 additions and 232 deletions

View file

@ -1,3 +1,2 @@
pub mod net;
pub mod dataloader;
mod mat;

View file

@ -1,209 +0,0 @@
use core::ops::Add;
use std::fmt::Debug;
use std::ops::AddAssign;
use nalgebra::DMatrix;
/// matrix add with broadcasting
/// like the numpy add operation
/// not sure I even need it anymore, after fixing inconsistencies with the matrix shapes
/// TODO see if it's still needed, or that standard matrix addition suffices
/// Adds two matrices with numpy-style broadcasting.
///
/// A dimension is broadcast-compatible when the two sizes are equal or at
/// least one of them is 1; a size-1 dimension is stretched to match the
/// other. This covers every case the previous branch ladder handled, plus
/// combinations it missed (e.g. a 1x1 matrix added to an m x n matrix),
/// and always computes `v1 + v2` in that order (two branches previously
/// evaluated `v2 + v1`, which only coincided for commutative `T`).
///
/// # Errors
/// Returns `Err` with a numpy-style `ValueError` message when the shapes
/// are not broadcast-compatible.
pub fn add<T>(v1: DMatrix<T>, v2: DMatrix<T>) -> Result<DMatrix<T>, String>
where T: PartialEq + Copy + Clone + Debug + Add + Add<Output=T> + AddAssign + 'static
{
    let (r1, c1) = v1.shape();
    let (r2, c2) = v2.shape();
    if (r1, c1) == (r2, c2) {
        // Same shape: no broadcasting needed, use plain matrix addition.
        return Ok(v1 + v2);
    }
    // Keep the exact message format expected by callers/tests.
    let err = || format!("ValueError: operands could not be broadcast together ({},{}), ({},{})", r1, c1, r2, c2);
    // Resolve each output dimension independently (numpy broadcasting rule).
    let rows = match (r1, r2) {
        _ if r1 == r2 => r1,
        (1, _) => r2,
        (_, 1) => r1,
        _ => return Err(err()),
    };
    let cols = match (c1, c2) {
        _ if c1 == c2 => c1,
        (1, _) => c2,
        (_, 1) => c1,
        _ => return Err(err()),
    };
    // A size-1 dimension always reads index 0; otherwise read the output index.
    Ok(DMatrix::from_fn(rows, cols, |r, c| {
        v1[(if r1 == 1 { 0 } else { r }, if c1 == 1 { 0 } else { c })]
            + v2[(if r2 == 1 { 0 } else { r }, if c2 == 1 { 0 } else { c })]
    }))
}
#[cfg(test)]
mod test {
    use nalgebra::dmatrix;
    use super::*;

    // Row (1,n) + column (m,1) broadcasts both operands to (m,n).
    #[test]
    fn stretch_row_column_to_square() {
        let v1: DMatrix<u32> = dmatrix![1,2,3];
        let v2: DMatrix<u32> = dmatrix![1;2;3];
        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (3, 3));
        assert_eq!(sum, dmatrix![2,3,4;3,4,5;4,5,6]);
    }

    #[test]
    fn stretch_row_column_to_rect() {
        let v1: DMatrix<u32> = dmatrix![1,2,3];
        let v2: DMatrix<u32> = dmatrix![1;2];
        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (2, 3));
        assert_eq!(sum, dmatrix![2,3,4;3,4,5]);
    }

    // Column (m,1) + row (1,n) — mirror of the previous cases.
    #[test]
    fn stretch_column_row_to_square() {
        let v1: DMatrix<u32> = dmatrix![1;2;3];
        let v2: DMatrix<u32> = dmatrix![1,2,3];
        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (3, 3));
        assert_eq!(sum, dmatrix![2,3,4;3,4,5;4,5,6]);
    }

    #[test]
    fn stretch_column_row_to_rect() {
        let v1: DMatrix<u32> = dmatrix![1;2;3];
        let v2: DMatrix<u32> = dmatrix![1,2];
        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (3, 2));
        assert_eq!(sum, dmatrix![2,3;3,4;4,5]);
    }

    // Row (1,n) stretched down the rows of an (m,n) matrix.
    #[test]
    fn stretch_row() {
        let v1: DMatrix<u32> = dmatrix![1,2,3];
        let v2: DMatrix<u32> = dmatrix![1,2,3;4,5,6];
        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (2, 3));
        assert_eq!(sum, dmatrix![2,4,6;5,7,9]);
    }

    #[test]
    fn stretch_row_commute() {
        let v1: DMatrix<u32> = dmatrix![1,2,3;4,5,6];
        let v2: DMatrix<u32> = dmatrix![1,2,3];
        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (2, 3));
        assert_eq!(sum, dmatrix![2,4,6;5,7,9]);
    }

    // Column (m,1) stretched across the columns of an (m,n) matrix.
    #[test]
    fn stretch_column() {
        let v1: DMatrix<u32> = dmatrix![1;2];
        let v2: DMatrix<u32> = dmatrix![1,2,3;4,5,6];
        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (2, 3));
        assert_eq!(sum, dmatrix![2,3,4;6,7,8]);
    }

    #[test]
    fn stretch_column_commute() {
        let v1: DMatrix<u32> = dmatrix![1,2,3;4,5,6];
        let v2: DMatrix<u32> = dmatrix![1;2];
        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (2, 3));
        assert_eq!(sum, dmatrix![2,3,4;6,7,8]);
    }

    #[test]
    fn test_broadcast_2dims() {
        let v1: DMatrix<u32> = dmatrix![1,2,3];
        let v2: DMatrix<u32> = dmatrix![1;2;3];
        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (3, 3));
        assert_eq!(sum, dmatrix![2,3,4;3,4,5;4,5,6]);
    }

    #[test]
    fn test_add_commutative() {
        let v1 = dmatrix![1,2,3];
        let v2 = dmatrix![1;2;3];
        let sum = add(v2, v1).unwrap();
        assert_eq!(sum.shape(), (3, 3));
        assert_eq!(sum, dmatrix![2,3,4;3,4,5;4,5,6]);
    }

    #[test]
    fn test_add_same_size() {
        let v1 = dmatrix![1,2;3,4];
        let v2 = dmatrix![3,4;5,6];
        let sum = add(v2, v1).unwrap();
        assert_eq!(sum.shape(), (2, 2));
        assert_eq!(sum, dmatrix![4,6;8,10]);
    }

    #[test]
    fn test_add_row_broadcast() {
        let v1 = dmatrix![1,2;3,4];
        let v2 = dmatrix![3,4];
        let sum = add(v1, v2).unwrap();
        assert_eq!(sum, dmatrix![4,6;6,8]);
    }

    #[test]
    fn test_add_row_broadcast2() {
        let v1 = dmatrix![1,1];
        let v2 = dmatrix![1,2;3,4];
        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (2, 2));
        assert_eq!(sum, dmatrix![2,3;4,5]);
    }

    #[test]
    fn test_column_broadcast() {
        let v1 = dmatrix![1;1];
        let v2 = dmatrix![1,2;3,4];
        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (2, 2));
        assert_eq!(sum, dmatrix![2,3;4,5]);
    }

    #[test]
    fn test_column_broadcast2() {
        let v1 = dmatrix![1,2;3,4];
        let v2 = dmatrix![1;1];
        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (2, 2));
        assert_eq!(sum, dmatrix![2,3;4,5]);
    }

    // Incompatible shapes must fail. The expected strings include the operand
    // shapes, matching the message `add` actually formats — the previous
    // expectations omitted the shape suffix and could never pass.
    #[test]
    fn column_too_long() {
        let v1 = dmatrix![1;1;1];
        let v2 = dmatrix![1,2;3,4];
        let result = add(v1, v2);
        assert_eq!(result, Err("ValueError: operands could not be broadcast together (3,1), (2,2)".to_owned()));
    }

    #[test]
    fn row_too_long() {
        let v1 = dmatrix![1,1,1];
        let v2 = dmatrix![1,2;3,4];
        let result = add(v1, v2);
        assert_eq!(result, Err("ValueError: operands could not be broadcast together (1,3), (2,2)".to_owned()));
    }
}

View file

@ -6,7 +6,6 @@ use rand::prelude::*;
use rand_distr::Normal;
use crate::dataloader::{Data, DataLine, OneHotVector};
use crate::mat::add;
#[derive(Debug)]
pub struct Network {
@ -56,7 +55,7 @@ impl Network {
fn feed_forward_activation(&self, input: Vec<f64>, activation: fn(&mut f64)) -> Vec<f64> {
let mut a = DMatrix::from_vec(input.len(), 1, input);
for (b, w) in zip(&self.biases, &self.weights) {
a = add(b.clone(), w * a).unwrap();
a = b.clone()+ w * a;
a.apply(activation);
}
a.column(0).iter().copied().collect()
@ -83,14 +82,7 @@ impl Network {
/// The ``mini_batch`` is a list of tuples ``(x, y)``, and ``eta``
/// is the learning rate.
fn update_mini_batch(&mut self, mini_batch: &[DataLine<f64, OneHotVector>], eta: f64) {
let mut nabla_b: Vec<DMatrix<f64>> = self.biases.iter()
.map(|b| b.shape())
.map(|s| DMatrix::zeros(s.0, s.1))
.collect();
let mut nabla_w: Vec<DMatrix<f64>> = self.weights.iter()
.map(|w| w.shape())
.map(|s| DMatrix::zeros(s.0, s.1))
.collect();
let (mut nabla_b, mut nabla_w) = self.zero_gradient();
for line in mini_batch.iter() {
let (delta_nabla_b, delta_nabla_w) = self.backprop(line.inputs.to_vec(), &line.label);
@ -124,15 +116,7 @@ impl Network {
/// `nabla_w` are layer-by-layer lists of matrices, similar
/// to `self.biases` and `self.weights`.
fn backprop(&self, x: Vec<f64>, y: &OneHotVector) -> (Vec<DMatrix<f64>>, Vec<DMatrix<f64>>) {
// zero_grad ie. set gradient to zero
let mut nabla_b: Vec<DMatrix<f64>> = self.biases.iter()
.map(|b| b.shape())
.map(|s| DMatrix::zeros(s.0, s.1))
.collect();
let mut nabla_w: Vec<DMatrix<f64>> = self.weights.iter()
.map(|w| w.shape())
.map(|s| DMatrix::zeros(s.0, s.1))
.collect();
let (mut nabla_b, mut nabla_w) = self.zero_gradient();
// feedforward
let mut activation = DMatrix::from_vec(x.len(), 1, x);
@ -168,6 +152,19 @@ impl Network {
(nabla_b, nabla_w)
}
/// Builds zero-initialised gradient accumulators (`nabla_b`, `nabla_w`)
/// with the same shapes as `self.biases` and `self.weights` respectively.
/// Factored out of `update_mini_batch` and `backprop`, which both start
/// from an all-zeros gradient (the "zero_grad" step).
fn zero_gradient(&self) -> (Vec<DMatrix<f64>>, Vec<DMatrix<f64>>) {
// one zero matrix per bias layer, shaped (rows, cols) like that layer's biases
let nabla_b: Vec<DMatrix<f64>> = self.biases.iter()
.map(|b| b.shape())
.map(|s| DMatrix::zeros(s.0, s.1))
.collect();
// likewise one zero matrix per weight layer
let nabla_w: Vec<DMatrix<f64>> = self.weights.iter()
.map(|w| w.shape())
.map(|s| DMatrix::zeros(s.0, s.1))
.collect();
(nabla_b, nabla_w)
}
}
fn cost_derivative(output_activations: &DMatrix<f64>, y: &OneHotVector) -> DMatrix<f64> {
@ -225,7 +222,6 @@ fn sigmoid_prime(val: f64) -> f64 {
#[cfg(test)]
mod test {
use std::convert::identity;
use nalgebra::DMatrix;
use super::*;
@ -241,7 +237,7 @@ mod test {
fn test_sigmoid_inplace() {
let mut v = 10.0;
sigmoid_inplace(&mut v);
assert_eq!(0.9999546, v);
assert_eq!(0.9999546021312976, v);
}
#[test]