Compare commits

..

10 commits

Author SHA1 Message Date
Shautvast
0bd18ba314 changed to f32: ~30s -> ~25s 2025-02-04 08:09:00 +01:00
Shautvast
f0876967d2 added the convert in README.md 2025-02-02 16:46:37 +01:00
Shautvast
aff6e78326 updated results 2025-02-02 15:42:41 +01:00
Shautvast
4259b9c646 updated results 2025-02-02 15:41:52 +01:00
Shautvast
221e5aa058 added convert_pickle.py 2025-02-02 15:40:49 +01:00
Shautvast
69d518e975 added results 2023-03-03 16:49:12 +01:00
Shautvast
8fba0ba300 reformat readme 2023-03-03 16:47:29 +01:00
Shautvast
0940071724 reorganized 2023-03-03 16:45:52 +01:00
Shautvast
9ca26feac2 extends readme 2023-03-03 16:43:47 +01:00
Shautvast
0b62f3cbc2 tests fixed, removed unused broadcast add 2023-03-03 16:40:30 +01:00
9 changed files with 260 additions and 350 deletions

173
Cargo.lock generated
View file

@ -1,6 +1,6 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
version = 4
[[package]]
name = "approx"
@ -17,6 +17,12 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "bitflags"
version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36"
[[package]]
name = "bytemuck"
version = "1.13.0"
@ -31,13 +37,14 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "getrandom"
version = "0.2.8"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31"
checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8"
dependencies = [
"cfg-if",
"libc",
"wasi",
"windows-targets",
]
[[package]]
@ -48,9 +55,9 @@ checksum = "fad582f4b9e86b6caa621cabeb0963332d92eea04729ab12892c2533951e6440"
[[package]]
name = "libc"
version = "0.2.139"
version = "0.2.169"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "201de327520df007757c1f0adce6e827fe8562fbc28bfd9c15571c66ca1f5f79"
checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a"
[[package]]
name = "libm"
@ -69,7 +76,7 @@ dependencies = [
[[package]]
name = "mnist-rs"
version = "0.1.0"
version = "1.0.0"
dependencies = [
"nalgebra",
"rand",
@ -80,9 +87,9 @@ dependencies = [
[[package]]
name = "nalgebra"
version = "0.32.1"
version = "0.33.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6515c882ebfddccaa73ead7320ca28036c4bc84c9bcca3cc0cbba8efe89223a"
checksum = "26aecdf64b707efd1310e3544d709c5c0ac61c13756046aaaba41be5c4f66a3b"
dependencies = [
"approx",
"matrixmultiply",
@ -96,13 +103,13 @@ dependencies = [
[[package]]
name = "nalgebra-macros"
version = "0.2.0"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d232c68884c0c99810a5a4d333ef7e47689cfd0edc85efc9e54e1e6bf5212766"
checksum = "254a5372af8fc138e36684761d3c0cdb758a4410e938babcff1c860ce14ddbfc"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.98",
]
[[package]]
@ -159,38 +166,38 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]]
name = "proc-macro2"
version = "1.0.51"
version = "1.0.93"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6"
checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.23"
version = "1.0.38"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b"
checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc"
dependencies = [
"proc-macro2",
]
[[package]]
name = "rand"
version = "0.8.5"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
"zerocopy",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
dependencies = [
"ppv-lite86",
"rand_core",
@ -198,18 +205,19 @@ dependencies = [
[[package]]
name = "rand_core"
version = "0.6.4"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
checksum = "b08f3c9802962f7e1b25113931d94f43ed9725bebc59db9d0c3e9a23b67e15ff"
dependencies = [
"getrandom",
"zerocopy",
]
[[package]]
name = "rand_distr"
version = "0.4.3"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31"
checksum = "ddc3b5afe4c995c44540865b8ca5c52e6a59fa362da96c5d30886930ddc8da1c"
dependencies = [
"num-traits",
"rand",
@ -253,7 +261,7 @@ checksum = "af487d118eecd09402d70a5d72551860e788df87b464af30e5ea6a38c75c541e"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 1.0.107",
]
[[package]]
@ -269,9 +277,9 @@ dependencies = [
[[package]]
name = "simba"
version = "0.8.0"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50582927ed6f77e4ac020c057f37a268fc6aebc29225050365aacbb9deeeddc4"
checksum = "b3a386a501cd104797982c15ae17aafe8b9261315b5d07e3ec803f2ea26be0fa"
dependencies = [
"approx",
"num-complex",
@ -291,6 +299,17 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "syn"
version = "2.0.98"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "36147f1a48ae0ec2b5b3bc5b537d267457555a10dc06f3dbc8cb11ba3006d3b1"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "typenum"
version = "1.16.0"
@ -305,9 +324,12 @@ checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc"
[[package]]
name = "wasi"
version = "0.11.0+wasi-snapshot-preview1"
version = "0.13.3+wasi-0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2"
dependencies = [
"wit-bindgen-rt",
]
[[package]]
name = "wide"
@ -318,3 +340,96 @@ dependencies = [
"bytemuck",
"safe_arch",
]
[[package]]
name = "windows-targets"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973"
dependencies = [
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_gnullvm",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
[[package]]
name = "windows_aarch64_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
[[package]]
name = "windows_i686_gnu"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
[[package]]
name = "windows_i686_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
[[package]]
name = "windows_i686_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
[[package]]
name = "windows_x86_64_gnu"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
[[package]]
name = "windows_x86_64_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "wit-bindgen-rt"
version = "0.33.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c"
dependencies = [
"bitflags",
]
[[package]]
name = "zerocopy"
version = "0.8.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1e101d4bc320b6f9abb68846837b70e25e380ca2f467ab494bf29fcc435fcc3"
dependencies = [
"zerocopy-derive",
]
[[package]]
name = "zerocopy-derive"
version = "0.8.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03a73df1008145cd135b3c780d275c57c3e6ba8324a41bd5e0008fe167c3bc7c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.98",
]

View file

@ -1,13 +1,11 @@
[package]
name = "mnist-rs"
version = "0.1.0"
version = "1.0.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
rand = "0.8"
rand_distr = "0.4"
nalgebra = "0.32"
rand = "0.9"
rand_distr = "0.5"
nalgebra = "0.33"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"

View file

@ -2,14 +2,15 @@ rust port of python in http://neuralnetworksanddeeplearning.com/chap1.html
main goal: me understanding what's going on
done:
* implementation that 'works' without runtime errors
results: ~95% accuracy in 30 epochs, learning rate = 3.0, batchsize = 10
to do:
* verify correctness
* add unit tests
* train using actual training data
* evaluate with test/validation data
* make more efficient
*training_data/test_data not included*<br/> too big for github
training_data/test_data not included
Format: json: <br/>
[{"x":[float;784], "y": u32}]<br/>
=> x: 28x28 gray image as float<br/>
=> y: label 0.. 9
the data can be found here:
https://github.com/mnielsen/neural-networks-and-deep-learning/blob/master/data/mnist.pkl.gz
albeit in a different format, convert it with `convert_pickle.py`

22
convert_pickle.py Normal file
View file

@ -0,0 +1,22 @@
import pickle
import gzip
import json
# Load the data from the .pkl.gz file
with gzip.open("mnist.pkl.gz", "rb") as f:
training_data, validation_data, test_data = pickle.load(f, encoding="latin1")
# Define a helper function to convert the data into JSON serializable format
def convert_data(data):
features, labels = data
return [{"x": features[i].tolist(), "y": int(labels[i])} for i in range(len(features))]
# Convert and save to JSON
with open("training_data.json", "w") as train_json:
json.dump(convert_data(training_data), train_json)
with open("validation_data.json", "w") as val_json:
json.dump(convert_data(validation_data), val_json)
with open("test_data.json", "w") as test_json:
json.dump(convert_data(test_data), test_json)

View file

@ -1,17 +1,21 @@
use std::fmt::Debug;
use nalgebra::DMatrix;
use rand::prelude::*;
use serde::Deserialize;
pub fn load_data() -> (Data<f64, OneHotVector>, Data<f64, OneHotVector>) {
pub fn load_data() -> (Data<f32, OneHotVector>, Data<f32, OneHotVector>)
{
// the mnist data is structured as
// x: [[[pixels]],[[pixels]], etc],
// y: [label1, label2, etc]
// this is transformed to:
// Data : Vec<DataLine>
// DataLine {inputs: Vec<pixels as f64>, label: f64}
let raw_training_data: Vec<RawData> = serde_json::from_slice(include_bytes!("data/training.json")).unwrap();
let raw_test_data: Vec<RawData> = serde_json::from_slice(include_bytes!("data/test.json")).unwrap();
let raw_training_data: Vec<RawData> =
serde_json::from_slice(include_bytes!("data/training_data.json")).unwrap();
let raw_test_data: Vec<RawData> =
serde_json::from_slice(include_bytes!("data/test_data.json")).unwrap();
let train = vectorize(raw_training_data);
let test = vectorize(raw_test_data);
@ -19,17 +23,19 @@ pub fn load_data() -> (Data<f64, OneHotVector>, Data<f64, OneHotVector>) {
(Data(train), Data(test))
}
fn vectorize(raw_training_data: Vec<RawData>) -> Vec<DataLine<f64, OneHotVector>>{
fn vectorize(raw_data: Vec<RawData>) -> Vec<DataLine<f32, OneHotVector>>
{
let mut result = Vec::new();
for line in raw_training_data {
result.push(DataLine { inputs: line.x, label: onehot(line.y) });
for line in raw_data {
result.push(DataLine { inputs: DMatrix::from_vec(line.x.len(), 1, line.x), label: onehot(line.y) });
}
result
}
#[derive(Deserialize)]
struct RawData {
x: Vec<f64>,
struct RawData
{
x: Vec<f32>,
y: u8,
}
@ -37,24 +43,22 @@ struct RawData {
/// Y is type of output
#[derive(Debug, Clone)]
pub struct DataLine<X, Y> where X: Clone, Y: Clone {
pub inputs: Vec<X>,
pub inputs: DMatrix<X>,
pub label: Y,
}
/// simple way to encode a onehot vector. An object that returns 1.0 if you get the 'right' index, or 0.0 otherwise
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
pub struct OneHotVector {
pub val: usize,
}
impl OneHotVector{
pub fn new(val: usize) -> Self {
Self {
val
}
Self { val }
}
pub fn get(&self, index: usize) -> f64 {
pub fn get(&self, index: usize) -> f32 {
if self.val == index {
1.0
} else {
@ -68,7 +72,7 @@ pub struct Data<X, Y>(pub Vec<DataLine<X, Y>>) where X: Clone, Y: Clone ;
impl<X, Y> Data<X, Y> where X: Clone, Y: Clone {
pub fn shuffle(&mut self) {
let mut rng = thread_rng();
let mut rng = rand::rng();
self.0.shuffle(&mut rng);
}

View file

@ -1,3 +1,2 @@
pub mod net;
pub mod dataloader;
mod mat;

View file

@ -1,16 +1,12 @@
use mnist_rs::dataloader::load_data;
use std::time::Instant;
fn main() {
let mut net = mnist_rs::net::Network::gaussian(vec![784, 30, 10]);
let (training_data, test_data) = load_data();
net.sgd(training_data, 30, 1, 0.01, Some(test_data));
// let sizes = vec![5,3,2];
// let net = mnist_rs::net::Network::from(sizes);
// println!("biases {:?}", net.biases.iter().map(|b|b.shape()).collect::<Vec<(usize,usize)>>());
// println!("weights {:?}", net.weights.iter().map(|b|b.shape()).collect::<Vec<(usize,usize)>>());
let t0 = Instant::now();
net.sgd(training_data, 30, 10, 3.0, Some(test_data));
println!("{}", t0.elapsed().as_millis());
}

View file

@ -1,209 +0,0 @@
use core::ops::Add;
use std::fmt::Debug;
use std::ops::AddAssign;
use nalgebra::DMatrix;
/// matrix add with broadcasting
/// like the numpy add operation
/// not sure I even need it anymore, after fixing inconsistencies with the matrix shapes
/// TODO see if it's still needed, or that standard matrix addition suffices
pub fn add<T>(v1: DMatrix<T>, v2: DMatrix<T>) -> Result<DMatrix<T>, String>
where T: PartialEq + Copy + Clone + Debug + Add + Add<Output=T> + AddAssign + 'static
{
let (r1, c1) = v1.shape();
let (r2, c2) = v2.shape();
if r1 == r2 && c1 == c2 {
// same size, no broadcasting needed
Ok(v1 + v2)
} else if r1 == 1 && c2 == 1 {
Ok(DMatrix::from_fn(r2, c1, |r, c| *v1.get(c).unwrap() + *v2.get(r).unwrap()))
} else if c1 == 1 && r2 == 1 {
Ok(DMatrix::from_fn(r1, c2, |r, c| *v1.get(r).unwrap() + *v2.get(c).unwrap()))
} else if r1 == 1 && c1 == c2 {
Ok(DMatrix::from_fn(r2, c1, |r, c| *v1.get(c).unwrap() + *v2.get(c * r2 + r).unwrap()))
} else if r2 == 1 && c1 == c2 {
Ok(DMatrix::from_fn(r1, c2, |r, c| *v2.get(c).unwrap() + *v1.get(c * r1 + r).unwrap()))
} else if c1 == 1 && r1 == r2 {
Ok(DMatrix::from_fn(r1, c2, |r, c| *v1.get(r).unwrap() + *v2.get(c * r2 + r).unwrap()))
} else if c2 == 1 && r1 == r2 {
Ok(DMatrix::from_fn(r2, c1, |r, c| *v2.get(r).unwrap() + *v1.get(c * r2 + r).unwrap()))
} else {
Err(format!("ValueError: operands could not be broadcast together ({},{}), ({},{})", r1,c1, r2, c2))
}
}
#[cfg(test)]
mod test {
use nalgebra::dmatrix;
use super::*;
#[test]
fn stretch_row_column_to_square() {
let v1: DMatrix<u32> = dmatrix![1,2,3];
let v2: DMatrix<u32> = dmatrix![1;2;3];
let sum = add(v1, v2).unwrap();
assert_eq!(sum.shape(), (3, 3));
assert_eq!(sum, dmatrix![2,3,4;3,4,5;4,5,6]);
}
#[test]
fn stretch_row_column_to_rect() {
let v1: DMatrix<u32> = dmatrix![1,2,3];
let v2: DMatrix<u32> = dmatrix![1;2];
let sum = add(v1, v2).unwrap();
assert_eq!(sum.shape(), (2, 3));
assert_eq!(sum, dmatrix![2,3,4;3,4,5]);
}
#[test]
fn stretch_column_row_to_square() {
let v1: DMatrix<u32> = dmatrix![1;2;3];
let v2: DMatrix<u32> = dmatrix![1,2,3];
let sum = add(v1, v2).unwrap();
assert_eq!(sum.shape(), (3, 3));
assert_eq!(sum, dmatrix![2,3,4;3,4,5;4,5,6]);
}
#[test]
fn stretch_column_row_to_rect() {
let v1: DMatrix<u32> = dmatrix![1;2;3];
let v2: DMatrix<u32> = dmatrix![1,2];
let sum = add(v1, v2).unwrap();
assert_eq!(sum.shape(), (3, 2));
assert_eq!(sum, dmatrix![2,3;3,4;4,5]);
}
#[test]
fn stretch_row() {
let v1: DMatrix<u32> = dmatrix![1,2,3];
let v2: DMatrix<u32> = dmatrix![1,2,3;4,5,6];
let sum = add(v1, v2).unwrap();
assert_eq!(sum.shape(), (2, 3));
assert_eq!(sum, dmatrix![2,4,6;5,7,9]);
}
#[test]
fn stretch_row_commute() {
let v1: DMatrix<u32> = dmatrix![1,2,3;4,5,6];
let v2: DMatrix<u32> = dmatrix![1,2,3];
let sum = add(v1, v2).unwrap();
assert_eq!(sum.shape(), (2, 3));
assert_eq!(sum, dmatrix![2,4,6;5,7,9]);
}
#[test]
fn stretch_column() {
let v1: DMatrix<u32> = dmatrix![1;2];
let v2: DMatrix<u32> = dmatrix![1,2,3;4,5,6];
let sum = add(v1, v2).unwrap();
assert_eq!(sum.shape(), (2, 3));
assert_eq!(sum, dmatrix![2,3,4;6,7,8]);
}
#[test]
fn stretch_column_commute() {
let v1: DMatrix<u32> = dmatrix![1,2,3;4,5,6];
let v2: DMatrix<u32> = dmatrix![1;2];
let sum = add(v1, v2).unwrap();
assert_eq!(sum.shape(), (2, 3));
assert_eq!(sum, dmatrix![2,3,4;6,7,8]);
}
#[test]
fn test_broadcast_2dims() {
let v1: DMatrix<u32> = dmatrix![1,2,3];
let v2: DMatrix<u32> = dmatrix![1;2;3];
let sum = add(v1, v2).unwrap();
assert_eq!(sum.shape(), (3, 3));
assert_eq!(sum, dmatrix![2,3,4;3,4,5;4,5,6]);
}
#[test]
fn test_add_commutative() {
let v1 = dmatrix![1,2,3];
let v2 = dmatrix![1;2;3];
let sum = add(v2, v1).unwrap();
assert_eq!(sum.shape(), (3, 3));
assert_eq!(sum, dmatrix![2,3,4;3,4,5;4,5,6]);
}
#[test]
fn test_add_same_size() {
let v1 = dmatrix![1,2;3,4];
let v2 = dmatrix![3,4;5,6];
let sum = add(v2, v1).unwrap();
assert_eq!(sum.shape(), (2, 2));
assert_eq!(sum, dmatrix![4,6;8,10]);
}
#[test]
fn test_add_row_broadcast() {//
let v1 = dmatrix![1,2;3,4];
let v2 = dmatrix![3,4];
let sum = add(v1, v2).unwrap();
assert_eq!(sum, dmatrix![4,6;6,8]);
}
#[test]
fn test_add_row_broadcast2() {
let v1 = dmatrix![1,1];
let v2 = dmatrix![1,2;3,4];
let sum = add(v1, v2).unwrap();
assert_eq!(sum.shape(), (2, 2));
assert_eq!(sum, dmatrix![2,3;4,5]);
}
#[test]
fn test_column_broadcast() {
let v1 = dmatrix![1;1];
let v2 = dmatrix![1,2;3,4];
let sum = add(v1, v2).unwrap();
assert_eq!(sum.shape(), (2, 2));
assert_eq!(sum, dmatrix![2,3;4,5]);
}
#[test]
fn test_column_broadcast2() {
let v1 = dmatrix![1,2;3,4];
let v2 = dmatrix![1;1];
let sum = add(v1, v2).unwrap();
assert_eq!(sum.shape(), (2, 2));
assert_eq!(sum, dmatrix![2,3;4,5]);
}
#[test]
fn column_too_long() {
let v1 = dmatrix![1;1;1];
let v2 = dmatrix![1,2;3,4];
let result = add(v1, v2);
assert_eq!(result, Err("ValueError: operands could not be broadcast together".to_owned()));
}
#[test]
fn row_too_long() {
let v1 = dmatrix![1,1,1];
let v2 = dmatrix![1,2;3,4];
let result = add(v1, v2);
assert_eq!(result, Err("ValueError: operands could not be broadcast together".to_owned()));
}
}

View file

@ -6,14 +6,13 @@ use rand::prelude::*;
use rand_distr::Normal;
use crate::dataloader::{Data, DataLine, OneHotVector};
use crate::mat::add;
#[derive(Debug)]
pub struct Network {
_sizes: Vec<usize>,
num_layers: usize,
pub biases: Vec<DMatrix<f64>>,
pub weights: Vec<DMatrix<f64>>,
pub biases: Vec<DMatrix<f32>>,
pub weights: Vec<DMatrix<f32>>,
}
impl Network{
@ -49,20 +48,20 @@ impl Network {
}
}
fn feed_forward(&self, input: Vec<f64>) -> Vec<f64> {
fn feed_forward(&self, input: &DMatrix<f32>) -> DMatrix<f32> {
self.feed_forward_activation(input, sigmoid_inplace)
}
fn feed_forward_activation(&self, input: Vec<f64>, activation: fn(&mut f64)) -> Vec<f64> {
let mut a = DMatrix::from_vec(input.len(), 1, input);
fn feed_forward_activation(&self, input: &DMatrix<f32>, activation: fn(&mut f32)) -> DMatrix<f32> {
let mut a = input.clone();
for (b, w) in zip(&self.biases, &self.weights) {
a = add(b.clone(), w * a).unwrap();
a = b + w * a;
a.apply(activation);
}
a.column(0).iter().copied().collect()
a
}
pub fn sgd(&mut self, mut training_data: Data<f64, OneHotVector>, epochs: usize, minibatch_size: usize, eta: f64, test_data: Option<Data<f64, OneHotVector>>) {
pub fn sgd(&mut self, mut training_data: Data<f32, OneHotVector>, epochs: usize, minibatch_size: usize, eta: f32, test_data: Option<Data<f32, OneHotVector>>) {
for j in 0..epochs {
training_data.shuffle();
let mini_batches = training_data.as_batches(minibatch_size);
@ -82,38 +81,29 @@ impl Network {
/// gradient descent using backpropagation to a single mini batch.
/// The ``mini_batch`` is a list of tuples ``(x, y)``, and ``eta``
/// is the learning rate.
fn update_mini_batch(&mut self, mini_batch: &[DataLine<f64, OneHotVector>], eta: f64) {
let mut nabla_b: Vec<DMatrix<f64>> = self.biases.iter()
.map(|b| b.shape())
.map(|s| DMatrix::zeros(s.0, s.1))
.collect();
let mut nabla_w: Vec<DMatrix<f64>> = self.weights.iter()
.map(|w| w.shape())
.map(|s| DMatrix::zeros(s.0, s.1))
.collect();
fn update_mini_batch(&mut self, mini_batch: &[DataLine<f32, OneHotVector>], eta: f32) {
let (mut nabla_b, mut nabla_w) = self.zero_gradient();
for line in mini_batch.iter() {
let (delta_nabla_b, delta_nabla_w) = self.backprop(line.inputs.to_vec(), &line.label);
let (delta_nabla_b, delta_nabla_w) = self.backprop(&line.inputs, &line.label);
// nabla_b = [nb + dnb for nb, dnb in zip(nabla_b, delta_nabla_b)]
// nabla_w = [nw + dnw for nw, dnw in zip(nabla_w, delta_nabla_w)]
nabla_b = zip(&nabla_b, &delta_nabla_b).map(|(nb, dnb)| nb.add(dnb)).collect();
nabla_w = zip(&nabla_w, &delta_nabla_w).map(|(nw, dnw)| nw.add(dnw)).collect();
}
self.weights = zip(&self.weights, &nabla_w)
.map(|(w, nw)| w.sub(nw.scale(eta / mini_batch.len() as f64))).collect();
.map(|(w, nw)| w.sub(nw.scale(eta / mini_batch.len() as f32))).collect();
self.biases = zip(&self.biases, &nabla_b)
.map(|(b, nb)| b.sub(nb.scale(eta / mini_batch.len() as f64))).collect();
.map(|(b, nb)| b.sub(nb.scale(eta / mini_batch.len() as f32))).collect();
}
/// Return the number of test inputs for which the neural
/// network outputs the correct result. Note that the neural
/// network's output is assumed to be the index of whichever
/// neuron in the final layer has the highest activation.
fn evaluate(&self, test_data: &Data<f64, OneHotVector>) -> usize {
fn evaluate(&self, test_data: &Data<f32, OneHotVector>) -> usize {
let test_results: Vec<(usize, usize)> = test_data.0.iter()
.map(|line| (argmax(self.feed_forward(line.inputs.clone())), line.label.val))
.map(|line| (argmax(self.feed_forward(&line.inputs)), line.label.val))
.collect();
test_results.into_iter().filter(|(x, y)| *x == *y).count()
@ -123,38 +113,27 @@ impl Network {
/// gradient for the cost function C_x. `nabla_b` and
/// `nabla_w` are layer-by-layer lists of matrices, similar
/// to `self.biases` and `self.weights`.
fn backprop(&self, x: Vec<f64>, y: &OneHotVector) -> (Vec<DMatrix<f64>>, Vec<DMatrix<f64>>) {
// zero_grad ie. set gradient to zero
let mut nabla_b: Vec<DMatrix<f64>> = self.biases.iter()
.map(|b| b.shape())
.map(|s| DMatrix::zeros(s.0, s.1))
.collect();
let mut nabla_w: Vec<DMatrix<f64>> = self.weights.iter()
.map(|w| w.shape())
.map(|s| DMatrix::zeros(s.0, s.1))
.collect();
fn backprop(&self, x: &DMatrix<f32>, y: &OneHotVector) -> (Vec<DMatrix<f32>>, Vec<DMatrix<f32>>) {
let (mut nabla_b, mut nabla_w) = self.zero_gradient();
// feedforward
let mut activation = DMatrix::from_vec(x.len(), 1, x);
let mut activation = x.clone();
let mut activations = vec![activation.clone()];
let mut zs = vec![];
for (b, w) in zip(&self.biases, &self.weights) {
let z = (w * &activation)+b.clone();
let z = (w * activation) + b;
zs.push(z.clone());
activation = z.map(sigmoid);
activations.push(activation.clone());
}
// backward pass
// delta = self.cost_derivative(activations[-1], y) * sigmoid_prime(zs[-1])
let delta: DMatrix<f64> = cost_derivative(&activations[activations.len() - 1], y).component_mul(&zs[zs.len() - 1].map(sigmoid_prime));
// println!("delta {:?}", delta);
let delta: DMatrix<f32> = cost_derivative(&activations[activations.len() - 1], y).component_mul(&zs[zs.len() - 1].map(sigmoid_prime));
let index = nabla_b.len() - 1;
nabla_b[index] = delta.clone();
let index = nabla_w.len() - 1;
let ac = &activations[activations.len() - 2].transpose();
nabla_w[index] = &delta * ac;
nabla_w[index] = &delta * (&activations[activations.len() - 2].transpose());
let lens_zs = zs.len();
for l in 2..self.num_layers {
let z = &zs[lens_zs - l];
@ -168,71 +147,76 @@ impl Network {
(nabla_b, nabla_w)
}
fn zero_gradient(&self) -> (Vec<DMatrix<f32>>, Vec<DMatrix<f32>>) {
let nabla_b: Vec<DMatrix<f32>> = self.biases.iter()
.map(|b| b.shape())
.map(|s| DMatrix::zeros(s.0, s.1))
.collect();
let nabla_w: Vec<DMatrix<f32>> = self.weights.iter()
.map(|w| w.shape())
.map(|s| DMatrix::zeros(s.0, s.1))
.collect();
(nabla_b, nabla_w)
}
}
fn cost_derivative(output_activations: &DMatrix<f64>, y: &OneHotVector) -> DMatrix<f64> {
// output_activations - y
// println!("output {:?}", output_activations);
// println!("expected {:?}", y);
fn cost_derivative(output_activations: &DMatrix<f32>, y: &OneHotVector) -> DMatrix<f32> {
let shape = output_activations.shape();
let t = DMatrix::from_iterator(shape.0, shape.1, output_activations.iter().enumerate()
.map(|(index, a)| a - y.get(index)));
// println!("t {:?}",t);
t
DMatrix::from_iterator(shape.0, shape.1, output_activations.iter().enumerate()
.map(|(index, a)| a - y.get(index)))
}
fn argmax(val: Vec<f64>) -> usize {
/// index of max value
/// only meaningful for single row or column matrix
fn argmax(val: DMatrix<f32>) -> usize {
let mut max = 0.0;
let mut index = 0;
for (i, x) in val.iter().enumerate() {
// print!("{},",x);
if *x > max {
index = i;
max = *x;
}
}
// println!();
index
}
fn biases(sizes: Vec<usize>, init: fn(&usize) -> DMatrix<f64>) -> Vec<DMatrix<f64>> {
fn biases(sizes: Vec<usize>, init: fn(&usize) -> DMatrix<f32>) -> Vec<DMatrix<f32>> {
sizes.iter().map(init).collect()
}
fn weights(sizes: Vec<(usize, usize)>, init: fn(&(usize, usize)) -> DMatrix<f64>) -> Vec<DMatrix<f64>> {
fn weights(sizes: Vec<(usize, usize)>, init: fn(&(usize, usize)) -> DMatrix<f32>) -> Vec<DMatrix<f32>> {
sizes.iter().map(init).collect()
}
fn random_matrix(rows: usize, cols: usize) -> DMatrix<f64> {
let normal: Normal<f64> = Normal::new(0.0, 1.0).unwrap();
fn random_matrix(rows: usize, cols: usize) -> DMatrix<f32> {
let normal: Normal<f32> = Normal::new(0.0, 1.0).unwrap();
DMatrix::from_fn(rows, cols, |_, _| normal.sample(&mut thread_rng()))
DMatrix::from_fn(rows, cols, |_, _| normal.sample(&mut rand::rng()))
}
fn sigmoid_inplace(val: &mut f64) {
fn sigmoid_inplace(val: &mut f32) {
*val = sigmoid(*val);
}
fn sigmoid(val: f64) -> f64 {
fn sigmoid(val: f32) -> f32 {
1.0 / (1.0 + (-val).exp())
}
/// Derivative of the sigmoid function.
fn sigmoid_prime(val: f64) -> f64 {
fn sigmoid_prime(val: f32) -> f32 {
sigmoid(val) * (1.0 - sigmoid(val))
}
#[cfg(test)]
mod test {
use std::convert::identity;
use nalgebra::DMatrix;
use super::*;
#[test]
fn test_sigmoid() {
let mut mat: DMatrix<f64> = DMatrix::from_vec(1, 1, vec![0.0]);
let mut mat: DMatrix<f32> = DMatrix::from_vec(1, 1, vec![0.0]);
mat.apply(sigmoid_inplace);
assert_eq!(mat.get(0), Some(&0.5));
}
@ -241,7 +225,7 @@ mod test {
fn test_sigmoid_inplace() {
let mut v = 10.0;
sigmoid_inplace(&mut v);
assert_eq!(0.9999546, v);
assert_eq!(0.9999546021312976, v);
}
#[test]
@ -251,7 +235,7 @@ mod test {
#[test]
fn test_argmax() {
assert_eq!(5, argmax(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 4.0, 3.0, 2.0, 1.0]));
assert_eq!(5, argmax(DMatrix::from_vec(10, 1, vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 4.0, 3.0, 2.0, 1.0])));
}
#[test]
@ -266,15 +250,15 @@ mod test {
// 2 layers of 2 units
let mut net = Network::ones(vec![2, 2]);
let prediction = net.feed_forward_activation(vec![2.0, 2.0], |a| {});
assert_eq!(prediction, vec![5.0, 5.0])
let prediction = net.feed_forward_activation(&DMatrix::from_vec(2, 1, vec![2.0, 2.0]), |a| {});
assert_eq!(prediction, DMatrix::from_vec(2, 1, vec![5.0, 5.0]))
}
#[test]
fn test_sgd() {
// 2 layers of 2 units
let mut net = Network::ones(vec![2, 2]);
let data = Data(vec![DataLine { inputs: vec![1.0, 1.0], label: OneHotVector::new(1) }]);
let data = Data(vec![DataLine { inputs: DMatrix::from_vec(2, 1, vec![1.0, 1.0]), label: OneHotVector::new(1) }]);
net.sgd(data, 1, 1, 0.001, None);
println!("{:?}", net);
}