fixed runtime errors
This commit is contained in:
parent
d9ba8cc079
commit
74b228ead0
7 changed files with 297 additions and 61 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -1,3 +1,4 @@
|
||||||
/target
|
/target
|
||||||
*.iml
|
*.iml
|
||||||
.idea
|
.idea
|
||||||
|
src/data/training.json
|
||||||
1
src/data/unittest.json
Normal file
1
src/data/unittest.json
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
{"x":[[0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.001171875,0.00703125,0.00703125,0.00703125,0.04921875,0.053125,0.068359375,0.01015625,0.06484375,0.099609375,0.096484375,0.049609375,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.01171875,0.0140625,0.03671875,0.06015625,0.06640625,0.098828125,0.098828125,0.098828125,0.098828125,0.098828125,0.087890625,0.0671875,0.098828125,0.09453125,0.076171875,0.025,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.019140625,0.09296875,0.098828125,0.098828125,0.098828125,0.098828125,0.098828125,0.098828125,0.098828125,0.098828125,0.098046875,0.036328125,0.03203125,0.03203125,0.021875,0.015234375,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.00703125,0.085546875,0.098828125,0.098828125,0.098828125,0.098828125,0.098828125,0.07734375,0.07109375,0.096484375,0.094140625,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.03125,0.0609375,0.041796875,0.098828125,0.098828125,0.080078125,0.004296875,0.0,0.016796875,0.06015625,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.00546875,0.000390625,0.06015625,0.098828125,0.03515625,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.054296875,0.098828125,0.07421875,0.00078125,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.004296875,0.07421875,0.098828125,0.02734375,0.0,0.0,0.0,0.0,0.0,
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.013671875,0.094140625,0.087890625,0.0625,0.0421875,0.000390625,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.031640625,0.09375,0.098828125,0.098828125,0.046484375,0.009765625,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.017578125,0.07265625,0.098828125,0.098828125,0.05859375,0.010546875,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.00625,0.036328125,0.0984375,0.098828125,0.073046875,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.097265625,0.098828125,0.097265625,0.025,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.01796875,0.05078125,0.071484375,0.098828125,0.098828125,0.080859375,0.00078125,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.015234375,0.0578125,0.089453125,0.098828125,0.098828125,0.098828125,0.09765625,0.07109375,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.009375,0.04453125,0.086328125,0.098828125,0.098828125,0.098828125,0.098828125,0.078515625,0.03046875,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.008984375,0.02578125,0.083203125,0.098828125,0.098828125,0.098828125,0.098828125,0.07734375,0.031640625,0.00078125,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.00703125,0.066796875,0.085546875,0.098828125,0.098828125,0.098828125,0.098828125,0.076171875,0.03125,0.003515625,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.021484375,0.0671875,0.08828125,0.098828125,0.098828125,0.098828125,0.098828125,0.0953125,0.051953125,0.004296875,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.053125,0.098828125,0.098828125,0.098828125,0.0828125,0.052734375,0.0515625,0.00625,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0
.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0]], "y":[5]}
|
||||||
|
|
@ -1,19 +1,20 @@
|
||||||
use std::iter::zip;
|
use std::iter::zip;
|
||||||
|
use nalgebra::DMatrix;
|
||||||
|
|
||||||
use rand::prelude::*;
|
use rand::prelude::*;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
|
||||||
pub fn load_data() -> Data<f32, u8> {
|
pub fn load_data() -> Data<f32, OneHotVector> {
|
||||||
/// the mnist data is structured as
|
// the mnist data is structured as
|
||||||
/// x: [[[pixels]],[[pixels]], etc],
|
// x: [[[pixels]],[[pixels]], etc],
|
||||||
/// y: [label1, label2, etc]
|
// y: [label1, label2, etc]
|
||||||
/// this is transformed to:
|
// this is transformed to:
|
||||||
/// Data : Vec<DataLine>
|
// Data : Vec<DataLine>
|
||||||
/// DataLine {inputs: Vec<pixels as f32>, label: f32}
|
// DataLine {inputs: Vec<pixels as f32>, label: f32}
|
||||||
let raw_data: RawData = serde_json::from_slice(include_bytes!("data/unittest.json")).unwrap();
|
let raw_data: RawData = serde_json::from_slice(include_bytes!("data/unittest.json")).unwrap();
|
||||||
let mut vec = Vec::new();
|
let mut vec = Vec::new();
|
||||||
for (x, y) in zip(raw_data.x, raw_data.y) {
|
for (x, y) in zip(raw_data.x, raw_data.y) {
|
||||||
vec.push(DataLine { inputs: x, label: y});
|
vec.push(DataLine { inputs: x, label: onehot(y) });
|
||||||
}
|
}
|
||||||
|
|
||||||
Data(vec)
|
Data(vec)
|
||||||
|
|
@ -27,17 +28,38 @@ struct RawData {
|
||||||
|
|
||||||
/// A single labelled example.
///
/// `X` is the element type of the input vector and `Y` is the type of the
/// expected output (the label).
#[derive(Debug, Clone, PartialEq)]
pub struct DataLine<X, Y> {
    /// Input values of this example.
    pub inputs: Vec<X>,
    /// Expected output for `inputs`.
    pub label: Y,
}

/// An ordered collection of [`DataLine`]s, e.g. a whole training or test set.
#[derive(Debug, Clone, PartialEq)]
pub struct Data<X, Y>(pub Vec<DataLine<X, Y>>);
|
||||||
|
|
||||||
|
/// A one-hot encoded label, stored compactly as the index of its single hot entry.
///
/// Conceptually the vector is `1.0` at position `val` and `0.0` everywhere
/// else; no backing buffer is allocated.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OneHotVector {
    /// Index of the single component that equals `1.0`.
    pub val: usize,
}

impl OneHotVector {
    /// Creates the one-hot vector whose hot component sits at index `val`.
    fn new(val: usize) -> Self {
        Self { val }
    }

    /// Returns the component at `index`: `1.0` if `index` is the hot position,
    /// `0.0` otherwise.
    ///
    /// The vector carries no length, so every `index` is accepted.
    pub fn get(&self, index: usize) -> f32 {
        if self.val == index {
            1.0
        } else {
            0.0
        }
    }
}
|
||||||
|
|
||||||
|
impl<X, Y> Data<X, Y> {
|
||||||
pub fn shuffle(&mut self) {
|
pub fn shuffle(&mut self) {
|
||||||
let mut rng = rand::thread_rng();
|
let mut rng = thread_rng();
|
||||||
self.0.shuffle(&mut rng);
|
self.0.shuffle(&mut rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -45,7 +67,7 @@ impl<X,Y> Data<X,Y> {
|
||||||
self.0.len()
|
self.0.len()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn as_batches(&self, batch_size: usize) -> Vec<&[DataLine<X,Y>]> {
|
pub fn as_batches(&self, batch_size: usize) -> Vec<&[DataLine<X, Y>]> {
|
||||||
let mut batches = Vec::with_capacity(self.0.len() / batch_size + 1);
|
let mut batches = Vec::with_capacity(self.0.len() / batch_size + 1);
|
||||||
let mut offset = 0;
|
let mut offset = 0;
|
||||||
for _ in 0..self.0.len() / batch_size {
|
for _ in 0..self.0.len() / batch_size {
|
||||||
|
|
@ -55,6 +77,9 @@ impl<X,Y> Data<X,Y> {
|
||||||
batches.push(&self.0[offset..self.0.len()]);
|
batches.push(&self.0[offset..self.0.len()]);
|
||||||
batches
|
batches
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// returns a vector as matrix where y is one-hot encoded
|
||||||
|
fn onehot(y: u8) -> OneHotVector {
|
||||||
|
OneHotVector::new(y as usize)
|
||||||
|
}
|
||||||
|
|
@ -1,2 +1,3 @@
|
||||||
pub mod net;
|
pub mod net;
|
||||||
pub mod dataloader;
|
pub mod dataloader;
|
||||||
|
mod mat;
|
||||||
15
src/main.rs
15
src/main.rs
|
|
@ -2,14 +2,15 @@ use mnist_rs::dataloader::load_data;
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
let mut net = mnist_rs::net::Network::from(vec![784, 30, 10]);
|
let mut net = mnist_rs::net::Network::from(vec![784, 30, 10]);
|
||||||
for w in net.weights.iter() {
|
|
||||||
println!("{}, {}", w.shape().0, w.shape().1);
|
|
||||||
}
|
|
||||||
println!();
|
|
||||||
for b in net.biases.iter() {
|
|
||||||
println!("{:?}", b.shape());
|
|
||||||
}
|
|
||||||
let training_data = load_data();
|
let training_data = load_data();
|
||||||
|
|
||||||
net.sgd(training_data, 30, 10, 3.0, &None);
|
net.sgd(training_data, 30, 10, 3.0, &None);
|
||||||
|
|
||||||
|
|
||||||
|
// let sizes = vec![5,3,2];
|
||||||
|
// let net = mnist_rs::net::Network::from(sizes);
|
||||||
|
// println!("biases {:?}", net.biases.iter().map(|b|b.shape()).collect::<Vec<(usize,usize)>>());
|
||||||
|
// println!("weights {:?}", net.weights.iter().map(|b|b.shape()).collect::<Vec<(usize,usize)>>());
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
205
src/mat.rs
Normal file
205
src/mat.rs
Normal file
|
|
@ -0,0 +1,205 @@
|
||||||
|
use core::ops::Add;
|
||||||
|
use std::fmt::Debug;
|
||||||
|
use std::ops::AddAssign;
|
||||||
|
use nalgebra::DMatrix;
|
||||||
|
|
||||||
|
pub fn add<T>(v1: DMatrix<T>, v2: DMatrix<T>) -> Result<DMatrix<T>, String>
|
||||||
|
where T: PartialEq + Copy + Clone + Debug + Add + Add<Output=T> + AddAssign + 'static
|
||||||
|
{
|
||||||
|
let (r1, c1) = v1.shape();
|
||||||
|
let (r2, c2) = v2.shape();
|
||||||
|
|
||||||
|
if r1 == r2 && c1 == c2 {
|
||||||
|
// same size, no broadcasting needed
|
||||||
|
Ok(v1 + v2)
|
||||||
|
} else if r1 == 1 && c2 == 1 {
|
||||||
|
Ok(DMatrix::from_fn(r2, c1, |r, c| *v1.get(c).unwrap() + *v2.get(r).unwrap()))
|
||||||
|
} else if c1 == 1 && r2 == 1 {
|
||||||
|
Ok(DMatrix::from_fn(r1, c2, |r, c| *v1.get(r).unwrap() + *v2.get(c).unwrap()))
|
||||||
|
} else if r1 == 1 && c1 == c2 {
|
||||||
|
Ok(DMatrix::from_fn(r2, c1, |r, c| *v1.get(c).unwrap() + *v2.get(c * r2 + r).unwrap()))
|
||||||
|
} else if r2 == 1 && c1 == c2 {
|
||||||
|
Ok(DMatrix::from_fn(r1, c2, |r, c| *v2.get(c).unwrap() + *v1.get(c * r1 + r).unwrap()))
|
||||||
|
} else if c1 == 1 && r1 == r2 {
|
||||||
|
Ok(DMatrix::from_fn(r1, c2, |r, c| *v1.get(r).unwrap() + *v2.get(c * r2 + r).unwrap()))
|
||||||
|
} else if c2 == 1 && r1 == r2 {
|
||||||
|
Ok(DMatrix::from_fn(r2, c1, |r, c| *v2.get(r).unwrap() + *v1.get(c * r2 + r).unwrap()))
|
||||||
|
} else {
|
||||||
|
Err(format!("ValueError: operands could not be broadcast together ({},{}), ({},{})", r1,c1, r2, c2))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod test {
    use nalgebra::dmatrix;
    use super::*;

    /// Row + column of equal length broadcast to a square matrix.
    #[test]
    fn stretch_row_column_to_square() {
        let v1: DMatrix<u32> = dmatrix![1,2,3];
        let v2: DMatrix<u32> = dmatrix![1;2;3];

        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (3, 3));
        assert_eq!(sum, dmatrix![2,3,4;3,4,5;4,5,6]);
    }

    /// Row + shorter column broadcast to a rectangular matrix.
    #[test]
    fn stretch_row_column_to_rect() {
        let v1: DMatrix<u32> = dmatrix![1,2,3];
        let v2: DMatrix<u32> = dmatrix![1;2];

        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (2, 3));
        assert_eq!(sum, dmatrix![2,3,4;3,4,5]);
    }

    /// Column + row of equal length broadcast to a square matrix.
    #[test]
    fn stretch_column_row_to_square() {
        let v1: DMatrix<u32> = dmatrix![1;2;3];
        let v2: DMatrix<u32> = dmatrix![1,2,3];

        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (3, 3));
        assert_eq!(sum, dmatrix![2,3,4;3,4,5;4,5,6]);
    }

    /// Column + shorter row broadcast to a rectangular matrix.
    #[test]
    fn stretch_column_row_to_rect() {
        let v1: DMatrix<u32> = dmatrix![1;2;3];
        let v2: DMatrix<u32> = dmatrix![1,2];

        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (3, 2));
        assert_eq!(sum, dmatrix![2,3;3,4;4,5]);
    }

    /// A row vector is added to every row of a matrix.
    #[test]
    fn stretch_row() {
        let v1: DMatrix<u32> = dmatrix![1,2,3];
        let v2: DMatrix<u32> = dmatrix![1,2,3;4,5,6];

        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (2, 3));
        assert_eq!(sum, dmatrix![2,4,6;5,7,9]);
    }

    /// Same as `stretch_row` with the operands swapped.
    #[test]
    fn stretch_row_commute() {
        let v1: DMatrix<u32> = dmatrix![1,2,3;4,5,6];
        let v2: DMatrix<u32> = dmatrix![1,2,3];

        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (2, 3));
        assert_eq!(sum, dmatrix![2,4,6;5,7,9]);
    }

    /// A column vector is added to every column of a matrix.
    #[test]
    fn stretch_column() {
        let v1: DMatrix<u32> = dmatrix![1;2];
        let v2: DMatrix<u32> = dmatrix![1,2,3;4,5,6];

        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (2, 3));
        assert_eq!(sum, dmatrix![2,3,4;6,7,8]);
    }

    /// Same as `stretch_column` with the operands swapped.
    #[test]
    fn stretch_column_commute() {
        let v1: DMatrix<u32> = dmatrix![1,2,3;4,5,6];
        let v2: DMatrix<u32> = dmatrix![1;2];

        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (2, 3));
        assert_eq!(sum, dmatrix![2,3,4;6,7,8]);
    }

    #[test]
    fn test_broadcast_2dims() {
        let v1: DMatrix<u32> = dmatrix![1,2,3];
        let v2: DMatrix<u32> = dmatrix![1;2;3];

        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (3, 3));
        assert_eq!(sum, dmatrix![2,3,4;3,4,5;4,5,6]);
    }

    #[test]
    fn test_add_commutative() {
        let v1 = dmatrix![1,2,3];
        let v2 = dmatrix![1;2;3];

        let sum = add(v2, v1).unwrap();
        assert_eq!(sum.shape(), (3, 3));
        assert_eq!(sum, dmatrix![2,3,4;3,4,5;4,5,6]);
    }

    #[test]
    fn test_add_same_size() {
        let v1 = dmatrix![1,2;3,4];
        let v2 = dmatrix![3,4;5,6];

        let sum = add(v2, v1).unwrap();
        assert_eq!(sum.shape(), (2, 2));
        assert_eq!(sum, dmatrix![4,6;8,10]);
    }

    #[test]
    fn test_add_row_broadcast() {
        let v1 = dmatrix![1,2;3,4];
        let v2 = dmatrix![3,4];

        let sum = add(v1, v2).unwrap();
        assert_eq!(sum, dmatrix![4,6;6,8]);
    }

    #[test]
    fn test_add_row_broadcast2() {
        let v1 = dmatrix![1,1];
        let v2 = dmatrix![1,2;3,4];

        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (2, 2));
        assert_eq!(sum, dmatrix![2,3;4,5]);
    }

    #[test]
    fn test_column_broadcast() {
        let v1 = dmatrix![1;1];
        let v2 = dmatrix![1,2;3,4];

        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (2, 2));
        assert_eq!(sum, dmatrix![2,3;4,5]);
    }

    #[test]
    fn test_column_broadcast2() {
        let v1 = dmatrix![1,2;3,4];
        let v2 = dmatrix![1;1];

        let sum = add(v1, v2).unwrap();
        assert_eq!(sum.shape(), (2, 2));
        assert_eq!(sum, dmatrix![2,3;4,5]);
    }

    /// A 3x1 column cannot be broadcast against a 2x2 matrix.
    #[test]
    fn column_too_long() {
        let v1 = dmatrix![1;1;1];
        let v2 = dmatrix![1,2;3,4];

        let result = add(v1, v2);
        // The error message always ends with both operand shapes.
        assert_eq!(result, Err("ValueError: operands could not be broadcast together (3,1), (2,2)".to_owned()));
    }

    /// A 1x3 row cannot be broadcast against a 2x2 matrix.
    #[test]
    fn row_too_long() {
        let v1 = dmatrix![1,1,1];
        let v2 = dmatrix![1,2;3,4];

        let result = add(v1, v2);
        // The error message always ends with both operand shapes.
        assert_eq!(result, Err("ValueError: operands could not be broadcast together (1,3), (2,2)".to_owned()));
    }
}
|
||||||
76
src/net.rs
76
src/net.rs
|
|
@ -1,17 +1,18 @@
|
||||||
use std::convert::identity;
|
|
||||||
use std::iter::zip;
|
use std::iter::zip;
|
||||||
use std::ops::{Add, Sub};
|
use std::ops::Add;
|
||||||
|
|
||||||
use nalgebra::{DMatrix, Matrix, OMatrix};
|
use nalgebra::DMatrix;
|
||||||
use rand::prelude::*;
|
use rand::prelude::*;
|
||||||
use rand_distr::Normal;
|
use rand_distr::Normal;
|
||||||
|
|
||||||
use crate::dataloader::{Data, DataLine};
|
use crate::dataloader::{Data, DataLine, OneHotVector};
|
||||||
|
use crate::mat;
|
||||||
|
use crate::mat::add;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Network {
|
pub struct Network {
|
||||||
_sizes: Vec<usize>,
|
_sizes: Vec<usize>,
|
||||||
_num_layers: usize,
|
num_layers: usize,
|
||||||
pub biases: Vec<DMatrix<f32>>,
|
pub biases: Vec<DMatrix<f32>>,
|
||||||
pub weights: Vec<DMatrix<f32>>,
|
pub weights: Vec<DMatrix<f32>>,
|
||||||
}
|
}
|
||||||
|
|
@ -30,7 +31,7 @@ impl Network {
|
||||||
pub fn from(sizes: Vec<usize>) -> Self {
|
pub fn from(sizes: Vec<usize>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
_sizes: sizes.clone(),
|
_sizes: sizes.clone(),
|
||||||
_num_layers: sizes.len(),
|
num_layers: sizes.len(),
|
||||||
biases: biases(sizes[1..].to_vec()),
|
biases: biases(sizes[1..].to_vec()),
|
||||||
weights: weights(zip(sizes[..sizes.len() - 1].to_vec(), sizes[1..].to_vec()).collect()),
|
weights: weights(zip(sizes[..sizes.len() - 1].to_vec(), sizes[1..].to_vec()).collect()),
|
||||||
}
|
}
|
||||||
|
|
@ -39,13 +40,13 @@ impl Network {
|
||||||
fn feed_forward(&self, input: Vec<f32>) -> Vec<f32> {
|
fn feed_forward(&self, input: Vec<f32>) -> Vec<f32> {
|
||||||
let mut a = DMatrix::from_vec(input.len(), 1, input);
|
let mut a = DMatrix::from_vec(input.len(), 1, input);
|
||||||
for (b, w) in zip(&self.biases, &self.weights) {
|
for (b, w) in zip(&self.biases, &self.weights) {
|
||||||
a = b.add_scalar(w.dot(&a));
|
a = add(b.clone(), w * a).unwrap();
|
||||||
a.apply(sigmoid_inplace);
|
a.apply(sigmoid_inplace);
|
||||||
}
|
}
|
||||||
a.column(1).iter().map(|v| *v).collect()
|
a.column(1).iter().map(|v| *v).collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn sgd(&mut self, mut training_data: Data<f32, u8>, epochs: usize, minibatch_size: usize, eta: f32, test_data: &Option<Data<f32, u8>>) {
|
pub fn sgd(&mut self, mut training_data: Data<f32, OneHotVector>, epochs: usize, minibatch_size: usize, eta: f32, test_data: &Option<Data<f32, OneHotVector>>) {
|
||||||
for j in 0..epochs {
|
for j in 0..epochs {
|
||||||
training_data.shuffle();
|
training_data.shuffle();
|
||||||
let mini_batches = training_data.as_batches(minibatch_size);
|
let mini_batches = training_data.as_batches(minibatch_size);
|
||||||
|
|
@ -65,7 +66,7 @@ impl Network {
|
||||||
/// gradient descent using backpropagation to a single mini batch.
|
/// gradient descent using backpropagation to a single mini batch.
|
||||||
/// The ``mini_batch`` is a list of tuples ``(x, y)``, and ``eta``
|
/// The ``mini_batch`` is a list of tuples ``(x, y)``, and ``eta``
|
||||||
/// is the learning rate.
|
/// is the learning rate.
|
||||||
fn update_mini_batch(&mut self, mini_batch: &[DataLine<f32, u8>], eta: f32) {
|
fn update_mini_batch(&mut self, mini_batch: &[DataLine<f32, OneHotVector>], eta: f32) {
|
||||||
let mut nabla_b: Vec<DMatrix<f32>> = self.biases.iter()
|
let mut nabla_b: Vec<DMatrix<f32>> = self.biases.iter()
|
||||||
.map(|b| b.shape())
|
.map(|b| b.shape())
|
||||||
.map(|s| DMatrix::zeros(s.0, s.1))
|
.map(|s| DMatrix::zeros(s.0, s.1))
|
||||||
|
|
@ -75,34 +76,34 @@ impl Network {
|
||||||
.map(|s| DMatrix::zeros(s.0, s.1))
|
.map(|s| DMatrix::zeros(s.0, s.1))
|
||||||
.collect();
|
.collect();
|
||||||
for line in mini_batch.iter() {
|
for line in mini_batch.iter() {
|
||||||
let (delta_nabla_b, delta_nabla_w) = self.backprop(line.inputs.to_vec(), line.label);
|
let (delta_nabla_b, delta_nabla_w) = self.backprop(line.inputs.to_vec(), &line.label);
|
||||||
|
|
||||||
nabla_b = zip(&nabla_b, &delta_nabla_b).map(|(nb, dnb)| nb.add(dnb)).collect();
|
nabla_b = zip(&nabla_b, &delta_nabla_b).map(|(nb, dnb)| nb.add(dnb)).collect();
|
||||||
nabla_w = zip(&nabla_w, &delta_nabla_w).map(|(nw, dnw)| nw.add(dnw)).collect();
|
nabla_w = zip(&nabla_w, &delta_nabla_w).map(|(nw, dnw)| nw.add(dnw)).collect();
|
||||||
}
|
}
|
||||||
|
|
||||||
self.weights = zip(&self.weights, &nabla_w)
|
self.weights = zip(&self.weights, &nabla_w)
|
||||||
.map(|(w, nw)| w.add_scalar(-eta / mini_batch.len() as f32)).collect();
|
.map(|(w, nw)| (w.add_scalar(-eta / mini_batch.len() as f32)).component_mul(nw)).collect();
|
||||||
self.biases = zip(&self.biases, &nabla_b)
|
self.biases = zip(&self.biases, &nabla_b)
|
||||||
.map(|(b, nb)| b.add_scalar(-eta / mini_batch.len() as f32)).collect();
|
.map(|(b, nb)| (b.add_scalar(-eta / mini_batch.len() as f32)).component_mul(nb)).collect();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return the number of test inputs for which the neural
|
/// Return the number of test inputs for which the neural
|
||||||
/// network outputs the correct result. Note that the neural
|
/// network outputs the correct result. Note that the neural
|
||||||
/// network's output is assumed to be the index of whichever
|
/// network's output is assumed to be the index of whichever
|
||||||
/// neuron in the final layer has the highest activation.
|
/// neuron in the final layer has the highest activation.
|
||||||
fn evaluate(&self, test_data: &Data<f32, u8>) -> usize {
|
fn evaluate(&self, test_data: &Data<f32, OneHotVector>) -> usize {
|
||||||
let test_results: Vec<(usize, u8)> = test_data.0.iter()
|
let test_results: Vec<(usize, usize)> = test_data.0.iter()
|
||||||
.map(|line| (argmax(self.feed_forward(line.inputs.clone())), line.label))
|
.map(|line| (argmax(self.feed_forward(line.inputs.clone())), line.label.val))
|
||||||
.collect();
|
.collect();
|
||||||
test_results.into_iter().filter(|(x, y)| *x == *y as usize).count()
|
test_results.into_iter().filter(|(x, y)| x == y).count()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return a tuple `(nabla_b, nabla_w)` representing the
|
/// Return a tuple `(nabla_b, nabla_w)` representing the
|
||||||
/// gradient for the cost function C_x. `nabla_b` and
|
/// gradient for the cost function C_x. `nabla_b` and
|
||||||
/// `nabla_w` are layer-by-layer lists of matrices, similar
|
/// `nabla_w` are layer-by-layer lists of matrices, similar
|
||||||
/// to `self.biases` and `self.weights`.
|
/// to `self.biases` and `self.weights`.
|
||||||
fn backprop(&self, x: Vec<f32>, y: u8) -> (Vec<DMatrix<f32>>, Vec<DMatrix<f32>>) {
|
fn backprop(&self, x: Vec<f32>, y: &OneHotVector) -> (Vec<DMatrix<f32>>, Vec<DMatrix<f32>>) {
|
||||||
// zero_grad ie. set gradient to zero
|
// zero_grad ie. set gradient to zero
|
||||||
let mut nabla_b: Vec<DMatrix<f32>> = self.biases.iter()
|
let mut nabla_b: Vec<DMatrix<f32>> = self.biases.iter()
|
||||||
.map(|b| b.shape())
|
.map(|b| b.shape())
|
||||||
|
|
@ -119,38 +120,40 @@ impl Network {
|
||||||
let mut zs = vec![];
|
let mut zs = vec![];
|
||||||
|
|
||||||
for (b, w) in zip(&self.biases, &self.weights) {
|
for (b, w) in zip(&self.biases, &self.weights) {
|
||||||
// println!("{:?}", w.shape());
|
let z = add(w * &activation, b.clone()).unwrap();
|
||||||
// println!("{:?}", activation.shape());
|
|
||||||
// println!("{:?}", b.shape());
|
|
||||||
|
|
||||||
let mut z: DMatrix<f32> = w * &activation + b;
|
|
||||||
zs.push(z.clone());
|
zs.push(z.clone());
|
||||||
activation = z.map(sigmoid);
|
activation = z.map(sigmoid);
|
||||||
activations.push(activation.clone());
|
activations.push(activation.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
// backward pass
|
// backward pass
|
||||||
let delta: DMatrix<f32> = self.cost_derivative(
|
// delta = self.cost_derivative(activations[-1], y) * sigmoid_prime(zs[-1])
|
||||||
&activations[activations.len() - 1],
|
let delta: DMatrix<f32> = self.cost_derivative(&activations[activations.len() - 1], y).component_mul((&zs[zs.len() - 1].map(sigmoid_prime)));
|
||||||
y as f32);
|
|
||||||
println!("delta {:?}", delta.shape());
|
|
||||||
println!("z {:?}", &zs[zs.len() - 1].transpose().shape());
|
|
||||||
let delta = delta * (&zs[zs.len() - 1].transpose().map(sigmoid_prime));
|
|
||||||
println!("delta {:?}", delta.shape());
|
|
||||||
let index = nabla_b.len() - 1;
|
let index = nabla_b.len() - 1;
|
||||||
nabla_b[index] = delta.clone();
|
nabla_b[index] = delta.clone();
|
||||||
|
|
||||||
println!("delta {:?}", delta.shape());
|
|
||||||
println!("activation {:?}", activations[activations.len() - 2].shape());
|
|
||||||
let index = nabla_w.len() - 1;
|
let index = nabla_w.len() - 1;
|
||||||
nabla_w[index] = delta * &activations[activations.len() - 2];
|
let ac = &activations[activations.len() - 2].transpose();
|
||||||
|
nabla_w[index] = &delta * ac;
|
||||||
|
let lens_zs = zs.len();
|
||||||
|
for l in 2..self.num_layers {
|
||||||
|
let z = &zs[lens_zs - l];
|
||||||
|
let sp = z.map(sigmoid_prime);
|
||||||
|
let weight = self.weights[self.weights.len() - l + 1].transpose();
|
||||||
|
let delta2 = (weight * &delta).component_mul(&sp);
|
||||||
|
let len_nb = nabla_b.len();
|
||||||
|
nabla_b[len_nb - l] = delta2.clone();
|
||||||
|
let len_nw = nabla_w.len();
|
||||||
|
nabla_w[len_nw - l] = delta2 * activations[activations.len() - l - 1].transpose();
|
||||||
|
}
|
||||||
|
|
||||||
(nabla_b, nabla_w)
|
(nabla_b, nabla_w)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cost_derivative(&self, output_activations: &DMatrix<f32>, y: f32) -> DMatrix<f32> {
|
fn cost_derivative(&self, output_activations: &DMatrix<f32>, y: &OneHotVector) -> DMatrix<f32> {
|
||||||
output_activations.add_scalar(-y)
|
// output_activations - y
|
||||||
|
let shape = output_activations.shape();
|
||||||
|
DMatrix::from_iterator(shape.0, shape.1, output_activations.iter().enumerate()
|
||||||
|
.map(|(index, a)| a - y.get(index)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -171,7 +174,6 @@ fn biases(sizes: Vec<usize>) -> Vec<DMatrix<f32>> {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn weights(sizes: Vec<(usize, usize)>) -> Vec<DMatrix<f32>> {
|
fn weights(sizes: Vec<(usize, usize)>) -> Vec<DMatrix<f32>> {
|
||||||
println!("{:?}", sizes);
|
|
||||||
sizes.iter().map(|size| random_matrix(size.1, size.0)).collect()
|
sizes.iter().map(|size| random_matrix(size.1, size.0)).collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue