tests fixed, removed unused broadcast add
This commit is contained in:
parent
8acf2a11d5
commit
0b62f3cbc2
3 changed files with 18 additions and 232 deletions
|
|
@ -1,3 +1,2 @@
|
||||||
pub mod net;
|
pub mod net;
|
||||||
pub mod dataloader;
|
pub mod dataloader;
|
||||||
mod mat;
|
|
||||||
209
src/mat.rs
209
src/mat.rs
|
|
@ -1,209 +0,0 @@
|
||||||
use core::ops::Add;
|
|
||||||
use std::fmt::Debug;
|
|
||||||
use std::ops::AddAssign;
|
|
||||||
use nalgebra::DMatrix;
|
|
||||||
|
|
||||||
/// matrix add with broadcasting
|
|
||||||
/// like the numpy add operation
|
|
||||||
/// not sure I even need it anymore, after fixing inconsistencies with the matrix shapes
|
|
||||||
/// TODO see if it's still needed, or that standard matrix addition suffices
|
|
||||||
pub fn add<T>(v1: DMatrix<T>, v2: DMatrix<T>) -> Result<DMatrix<T>, String>
|
|
||||||
where T: PartialEq + Copy + Clone + Debug + Add + Add<Output=T> + AddAssign + 'static
|
|
||||||
{
|
|
||||||
let (r1, c1) = v1.shape();
|
|
||||||
let (r2, c2) = v2.shape();
|
|
||||||
|
|
||||||
if r1 == r2 && c1 == c2 {
|
|
||||||
// same size, no broadcasting needed
|
|
||||||
Ok(v1 + v2)
|
|
||||||
} else if r1 == 1 && c2 == 1 {
|
|
||||||
Ok(DMatrix::from_fn(r2, c1, |r, c| *v1.get(c).unwrap() + *v2.get(r).unwrap()))
|
|
||||||
} else if c1 == 1 && r2 == 1 {
|
|
||||||
Ok(DMatrix::from_fn(r1, c2, |r, c| *v1.get(r).unwrap() + *v2.get(c).unwrap()))
|
|
||||||
} else if r1 == 1 && c1 == c2 {
|
|
||||||
Ok(DMatrix::from_fn(r2, c1, |r, c| *v1.get(c).unwrap() + *v2.get(c * r2 + r).unwrap()))
|
|
||||||
} else if r2 == 1 && c1 == c2 {
|
|
||||||
Ok(DMatrix::from_fn(r1, c2, |r, c| *v2.get(c).unwrap() + *v1.get(c * r1 + r).unwrap()))
|
|
||||||
} else if c1 == 1 && r1 == r2 {
|
|
||||||
Ok(DMatrix::from_fn(r1, c2, |r, c| *v1.get(r).unwrap() + *v2.get(c * r2 + r).unwrap()))
|
|
||||||
} else if c2 == 1 && r1 == r2 {
|
|
||||||
Ok(DMatrix::from_fn(r2, c1, |r, c| *v2.get(r).unwrap() + *v1.get(c * r2 + r).unwrap()))
|
|
||||||
} else {
|
|
||||||
Err(format!("ValueError: operands could not be broadcast together ({},{}), ({},{})", r1,c1, r2, c2))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
mod test {
    use nalgebra::dmatrix;
    use super::*;

    // --- row/column vector pairs stretched to a full matrix ---

    #[test]
    fn stretch_row_column_to_square() {
        let v1: DMatrix<u32> = dmatrix![1,2,3];
        let v2: DMatrix<u32> = dmatrix![1;2;3];

        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (3, 3));
        assert_eq!(sum, dmatrix![2,3,4;3,4,5;4,5,6]);
    }

    #[test]
    fn stretch_row_column_to_rect() {
        let v1: DMatrix<u32> = dmatrix![1,2,3];
        let v2: DMatrix<u32> = dmatrix![1;2];

        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (2, 3));
        assert_eq!(sum, dmatrix![2,3,4;3,4,5]);
    }

    #[test]
    fn stretch_column_row_to_square() {
        let v1: DMatrix<u32> = dmatrix![1;2;3];
        let v2: DMatrix<u32> = dmatrix![1,2,3];

        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (3, 3));
        assert_eq!(sum, dmatrix![2,3,4;3,4,5;4,5,6]);
    }

    #[test]
    fn stretch_column_row_to_rect() {
        let v1: DMatrix<u32> = dmatrix![1;2;3];
        let v2: DMatrix<u32> = dmatrix![1,2];

        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (3, 2));
        assert_eq!(sum, dmatrix![2,3;3,4;4,5]);
    }

    // --- a vector stretched against a full matrix ---

    #[test]
    fn stretch_row() {
        let v1: DMatrix<u32> = dmatrix![1,2,3];
        let v2: DMatrix<u32> = dmatrix![1,2,3;4,5,6];

        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (2, 3));
        assert_eq!(sum, dmatrix![2,4,6;5,7,9]);
    }

    #[test]
    fn stretch_row_commute() {
        let v1: DMatrix<u32> = dmatrix![1,2,3;4,5,6];
        let v2: DMatrix<u32> = dmatrix![1,2,3];

        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (2, 3));
        assert_eq!(sum, dmatrix![2,4,6;5,7,9]);
    }

    #[test]
    fn stretch_column() {
        let v1: DMatrix<u32> = dmatrix![1;2];
        let v2: DMatrix<u32> = dmatrix![1,2,3;4,5,6];

        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (2, 3));
        assert_eq!(sum, dmatrix![2,3,4;6,7,8]);
    }

    #[test]
    fn stretch_column_commute() {
        let v1: DMatrix<u32> = dmatrix![1,2,3;4,5,6];
        let v2: DMatrix<u32> = dmatrix![1;2];

        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (2, 3));
        assert_eq!(sum, dmatrix![2,3,4;6,7,8]);
    }

    // NOTE(review): duplicates stretch_row_column_to_square; kept for now.
    #[test]
    fn test_broadcast_2dims() {
        let v1: DMatrix<u32> = dmatrix![1,2,3];
        let v2: DMatrix<u32> = dmatrix![1;2;3];

        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (3, 3));
        assert_eq!(sum, dmatrix![2,3,4;3,4,5;4,5,6]);
    }

    #[test]
    fn test_add_commutative() {
        let v1 = dmatrix![1,2,3];
        let v2 = dmatrix![1;2;3];

        let sum = add(v2, v1).unwrap();
        assert_eq!(sum.shape(), (3, 3));
        assert_eq!(sum, dmatrix![2,3,4;3,4,5;4,5,6]);
    }

    #[test]
    fn test_add_same_size() {
        let v1 = dmatrix![1,2;3,4];
        let v2 = dmatrix![3,4;5,6];

        let sum = add(v2, v1).unwrap();
        assert_eq!(sum.shape(), (2, 2));
        assert_eq!(sum, dmatrix![4,6;8,10]);
    }

    #[test]
    fn test_add_row_broadcast() {
        let v1 = dmatrix![1,2;3,4];
        let v2 = dmatrix![3,4];

        let sum = add(v1, v2).unwrap();
        assert_eq!(sum, dmatrix![4,6;6,8]);
    }

    #[test]
    fn test_add_row_broadcast2() {
        let v1 = dmatrix![1,1];
        let v2 = dmatrix![1,2;3,4];

        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (2, 2));
        assert_eq!(sum, dmatrix![2,3;4,5]);
    }

    #[test]
    fn test_column_broadcast() {
        let v1 = dmatrix![1;1];
        let v2 = dmatrix![1,2;3,4];

        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (2, 2));
        assert_eq!(sum, dmatrix![2,3;4,5]);
    }

    #[test]
    fn test_column_broadcast2() {
        let v1 = dmatrix![1,2;3,4];
        let v2 = dmatrix![1;1];

        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (2, 2));
        assert_eq!(sum, dmatrix![2,3;4,5]);
    }

    // --- incompatible shapes: the error message includes both shapes,
    //     matching the format! in `add` ---

    #[test]
    fn column_too_long() {
        let v1 = dmatrix![1;1;1];
        let v2 = dmatrix![1,2;3,4];

        let result = add(v1, v2);
        assert_eq!(result, Err("ValueError: operands could not be broadcast together (3,1), (2,2)".to_owned()));
    }

    #[test]
    fn row_too_long() {
        let v1 = dmatrix![1,1,1];
        let v2 = dmatrix![1,2;3,4];

        let result = add(v1, v2);
        assert_eq!(result, Err("ValueError: operands could not be broadcast together (1,3), (2,2)".to_owned()));
    }
}
|
|
||||||
38
src/net.rs
38
src/net.rs
|
|
@ -6,7 +6,6 @@ use rand::prelude::*;
|
||||||
use rand_distr::Normal;
|
use rand_distr::Normal;
|
||||||
|
|
||||||
use crate::dataloader::{Data, DataLine, OneHotVector};
|
use crate::dataloader::{Data, DataLine, OneHotVector};
|
||||||
use crate::mat::add;
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Network {
|
pub struct Network {
|
||||||
|
|
@ -56,7 +55,7 @@ impl Network {
|
||||||
fn feed_forward_activation(&self, input: Vec<f64>, activation: fn(&mut f64)) -> Vec<f64> {
|
fn feed_forward_activation(&self, input: Vec<f64>, activation: fn(&mut f64)) -> Vec<f64> {
|
||||||
let mut a = DMatrix::from_vec(input.len(), 1, input);
|
let mut a = DMatrix::from_vec(input.len(), 1, input);
|
||||||
for (b, w) in zip(&self.biases, &self.weights) {
|
for (b, w) in zip(&self.biases, &self.weights) {
|
||||||
a = add(b.clone(), w * a).unwrap();
|
a = b.clone()+ w * a;
|
||||||
a.apply(activation);
|
a.apply(activation);
|
||||||
}
|
}
|
||||||
a.column(0).iter().copied().collect()
|
a.column(0).iter().copied().collect()
|
||||||
|
|
@ -83,14 +82,7 @@ impl Network {
|
||||||
/// The ``mini_batch`` is a list of tuples ``(x, y)``, and ``eta``
|
/// The ``mini_batch`` is a list of tuples ``(x, y)``, and ``eta``
|
||||||
/// is the learning rate.
|
/// is the learning rate.
|
||||||
fn update_mini_batch(&mut self, mini_batch: &[DataLine<f64, OneHotVector>], eta: f64) {
|
fn update_mini_batch(&mut self, mini_batch: &[DataLine<f64, OneHotVector>], eta: f64) {
|
||||||
let mut nabla_b: Vec<DMatrix<f64>> = self.biases.iter()
|
let (mut nabla_b, mut nabla_w) = self.zero_gradient();
|
||||||
.map(|b| b.shape())
|
|
||||||
.map(|s| DMatrix::zeros(s.0, s.1))
|
|
||||||
.collect();
|
|
||||||
let mut nabla_w: Vec<DMatrix<f64>> = self.weights.iter()
|
|
||||||
.map(|w| w.shape())
|
|
||||||
.map(|s| DMatrix::zeros(s.0, s.1))
|
|
||||||
.collect();
|
|
||||||
for line in mini_batch.iter() {
|
for line in mini_batch.iter() {
|
||||||
let (delta_nabla_b, delta_nabla_w) = self.backprop(line.inputs.to_vec(), &line.label);
|
let (delta_nabla_b, delta_nabla_w) = self.backprop(line.inputs.to_vec(), &line.label);
|
||||||
|
|
||||||
|
|
@ -124,15 +116,7 @@ impl Network {
|
||||||
/// `nabla_w` are layer-by-layer lists of matrices, similar
|
/// `nabla_w` are layer-by-layer lists of matrices, similar
|
||||||
/// to `self.biases` and `self.weights`.
|
/// to `self.biases` and `self.weights`.
|
||||||
fn backprop(&self, x: Vec<f64>, y: &OneHotVector) -> (Vec<DMatrix<f64>>, Vec<DMatrix<f64>>) {
|
fn backprop(&self, x: Vec<f64>, y: &OneHotVector) -> (Vec<DMatrix<f64>>, Vec<DMatrix<f64>>) {
|
||||||
// zero_grad ie. set gradient to zero
|
let (mut nabla_b, mut nabla_w) = self.zero_gradient();
|
||||||
let mut nabla_b: Vec<DMatrix<f64>> = self.biases.iter()
|
|
||||||
.map(|b| b.shape())
|
|
||||||
.map(|s| DMatrix::zeros(s.0, s.1))
|
|
||||||
.collect();
|
|
||||||
let mut nabla_w: Vec<DMatrix<f64>> = self.weights.iter()
|
|
||||||
.map(|w| w.shape())
|
|
||||||
.map(|s| DMatrix::zeros(s.0, s.1))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
// feedforward
|
// feedforward
|
||||||
let mut activation = DMatrix::from_vec(x.len(), 1, x);
|
let mut activation = DMatrix::from_vec(x.len(), 1, x);
|
||||||
|
|
@ -168,6 +152,19 @@ impl Network {
|
||||||
|
|
||||||
(nabla_b, nabla_w)
|
(nabla_b, nabla_w)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Build zero-initialised gradient accumulators shaped like the network's
/// parameters ("zero_grad").
///
/// Returns `(nabla_b, nabla_w)`, where each entry is a zero matrix with the
/// same shape as the corresponding matrix in `self.biases` / `self.weights`.
fn zero_gradient(&self) -> (Vec<DMatrix<f64>>, Vec<DMatrix<f64>>) {
    let nabla_b: Vec<DMatrix<f64>> = self.biases.iter()
        .map(|b| b.shape())
        .map(|s| DMatrix::zeros(s.0, s.1))
        .collect();
    let nabla_w: Vec<DMatrix<f64>> = self.weights.iter()
        .map(|w| w.shape())
        .map(|s| DMatrix::zeros(s.0, s.1))
        .collect();
    (nabla_b, nabla_w)
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cost_derivative(output_activations: &DMatrix<f64>, y: &OneHotVector) -> DMatrix<f64> {
|
fn cost_derivative(output_activations: &DMatrix<f64>, y: &OneHotVector) -> DMatrix<f64> {
|
||||||
|
|
@ -225,7 +222,6 @@ fn sigmoid_prime(val: f64) -> f64 {
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test {
|
mod test {
|
||||||
use std::convert::identity;
|
|
||||||
use nalgebra::DMatrix;
|
use nalgebra::DMatrix;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
@ -241,7 +237,7 @@ mod test {
|
||||||
fn test_sigmoid_inplace() {
|
fn test_sigmoid_inplace() {
|
||||||
let mut v = 10.0;
|
let mut v = 10.0;
|
||||||
sigmoid_inplace(&mut v);
|
sigmoid_inplace(&mut v);
|
||||||
assert_eq!(0.9999546, v);
|
assert_eq!(0.9999546021312976, v);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue