not working

This commit is contained in:
Sander Hautvast 2023-02-14 17:42:02 +01:00
commit d9ba8cc079
7 changed files with 622 additions and 0 deletions

3
.gitignore vendored Normal file
View file

@ -0,0 +1,3 @@
/target
*.iml
.idea

320
Cargo.lock generated Normal file
View file

@ -0,0 +1,320 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "approx"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
dependencies = [
"num-traits",
]
[[package]]
name = "autocfg"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "bytemuck"
version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c041d3eab048880cb0b86b256447da3f18859a163c3b8d8893f4e6368abe6393"
[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "getrandom"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31"
dependencies = [
"cfg-if",
"libc",
"wasi",
]
[[package]]
name = "itoa"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fad582f4b9e86b6caa621cabeb0963332d92eea04729ab12892c2533951e6440"
[[package]]
name = "libc"
version = "0.2.139"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "201de327520df007757c1f0adce6e827fe8562fbc28bfd9c15571c66ca1f5f79"
[[package]]
name = "libm"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "348108ab3fba42ec82ff6e9564fc4ca0247bdccdc68dd8af9764bbc79c3c8ffb"
[[package]]
name = "matrixmultiply"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "add85d4dd35074e6fedc608f8c8f513a3548619a9024b751949ef0e8e45a4d84"
dependencies = [
"rawpointer",
]
[[package]]
name = "mnist-rs"
version = "0.1.0"
dependencies = [
"nalgebra",
"rand",
"rand_distr",
"serde",
"serde_json",
]
[[package]]
name = "nalgebra"
version = "0.32.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6515c882ebfddccaa73ead7320ca28036c4bc84c9bcca3cc0cbba8efe89223a"
dependencies = [
"approx",
"matrixmultiply",
"nalgebra-macros",
"num-complex",
"num-rational",
"num-traits",
"simba",
"typenum",
]
[[package]]
name = "nalgebra-macros"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d232c68884c0c99810a5a4d333ef7e47689cfd0edc85efc9e54e1e6bf5212766"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "num-complex"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d"
dependencies = [
"num-traits",
]
[[package]]
name = "num-integer"
version = "0.1.45"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9"
dependencies = [
"autocfg",
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0"
dependencies = [
"autocfg",
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd"
dependencies = [
"autocfg",
"libm",
]
[[package]]
name = "paste"
version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d01a5bd0424d00070b0098dd17ebca6f961a959dead1dbcbbbc1d1cd8d3deeba"
[[package]]
name = "ppv-lite86"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]]
name = "proc-macro2"
version = "1.0.51"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b"
dependencies = [
"proc-macro2",
]
[[package]]
name = "rand"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core",
]
[[package]]
name = "rand_core"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
"getrandom",
]
[[package]]
name = "rand_distr"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31"
dependencies = [
"num-traits",
"rand",
]
[[package]]
name = "rawpointer"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
[[package]]
name = "ryu"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b4b9743ed687d4b4bcedf9ff5eaa7398495ae14e61cba0a295704edbc7decde"
[[package]]
name = "safe_arch"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "794821e4ccb0d9f979512f9c1973480123f9bd62a90d74ab0f9426fcf8f4a529"
dependencies = [
"bytemuck",
]
[[package]]
name = "serde"
version = "1.0.152"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb7d1f0d3021d347a83e556fc4683dea2ea09d87bccdf88ff5c12545d89d5efb"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.152"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af487d118eecd09402d70a5d72551860e788df87b464af30e5ea6a38c75c541e"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "serde_json"
version = "1.0.93"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cad406b69c91885b5107daf2c29572f6c8cdb3c66826821e286c533490c0bc76"
dependencies = [
"itoa",
"ryu",
"serde",
]
[[package]]
name = "simba"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50582927ed6f77e4ac020c057f37a268fc6aebc29225050365aacbb9deeeddc4"
dependencies = [
"approx",
"num-complex",
"num-traits",
"paste",
"wide",
]
[[package]]
name = "syn"
version = "1.0.107"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f4064b5b16e03ae50984a5a8ed5d4f8803e6bc1fd170a3cda91a1be4b18e3f5"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "typenum"
version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba"
[[package]]
name = "unicode-ident"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc"
[[package]]
name = "wasi"
version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "wide"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "feff0a412894d67223777b6cc8d68c0dab06d52d95e9890d5f2d47f10dd9366c"
dependencies = [
"bytemuck",
"safe_arch",
]

13
Cargo.toml Normal file
View file

@ -0,0 +1,13 @@
[package]
name = "mnist-rs"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
rand = "0.8"
rand_distr = "0.4"
nalgebra = "0.32"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"

60
src/dataloader.rs Normal file
View file

@ -0,0 +1,60 @@
use std::iter::zip;
use rand::prelude::*;
use serde::Deserialize;
pub fn load_data() -> Data<f32, u8> {
/// the mnist data is structured as
/// x: [[[pixels]],[[pixels]], etc],
/// y: [label1, label2, etc]
/// this is transformed to:
/// Data : Vec<DataLine>
/// DataLine {inputs: Vec<pixels as f32>, label: f32}
let raw_data: RawData = serde_json::from_slice(include_bytes!("data/unittest.json")).unwrap();
let mut vec = Vec::new();
for (x, y) in zip(raw_data.x, raw_data.y) {
vec.push(DataLine { inputs: x, label: y});
}
Data(vec)
}
#[derive(Deserialize)]
struct RawData {
x: Vec<Vec<f32>>,
y: Vec<u8>,
}
/// X is type of input
/// Y is type of output
pub struct DataLine<X,Y> {
pub inputs: Vec<X>,
pub label: Y,
}
pub struct Data<X,Y>(pub Vec<DataLine<X,Y>>);
impl<X,Y> Data<X,Y> {
pub fn shuffle(&mut self) {
let mut rng = rand::thread_rng();
self.0.shuffle(&mut rng);
}
pub fn len(&self) -> usize {
self.0.len()
}
pub fn as_batches(&self, batch_size: usize) -> Vec<&[DataLine<X,Y>]> {
let mut batches = Vec::with_capacity(self.0.len() / batch_size + 1);
let mut offset = 0;
for _ in 0..self.0.len() / batch_size {
batches.push(&self.0[offset..offset + batch_size]);
offset += batch_size;
}
batches.push(&self.0[offset..self.0.len()]);
batches
}
}

2
src/lib.rs Normal file
View file

@ -0,0 +1,2 @@
pub mod net;
pub mod dataloader;

15
src/main.rs Normal file
View file

@ -0,0 +1,15 @@
use mnist_rs::dataloader::load_data;
fn main() {
let mut net = mnist_rs::net::Network::from(vec![784, 30, 10]);
for w in net.weights.iter() {
println!("{}, {}", w.shape().0, w.shape().1);
}
println!();
for b in net.biases.iter() {
println!("{:?}", b.shape());
}
let training_data = load_data();
net.sgd(training_data, 30, 10, 3.0, &None);
}

209
src/net.rs Normal file
View file

@ -0,0 +1,209 @@
use std::convert::identity;
use std::iter::zip;
use std::ops::{Add, Sub};
use nalgebra::{DMatrix, Matrix, OMatrix};
use rand::prelude::*;
use rand_distr::Normal;
use crate::dataloader::{Data, DataLine};
#[derive(Debug)]
pub struct Network {
_sizes: Vec<usize>,
_num_layers: usize,
pub biases: Vec<DMatrix<f32>>,
pub weights: Vec<DMatrix<f32>>,
}
impl Network {
/// The list `sizes` contains the number of neurons in the
/// respective layers of the network. For example, if the list
/// was [2, 3, 1] then it would be a three-layer network, with the
/// first layer containing 2 neurons, the second layer 3 neurons,
/// and the third layer 1 neuron. The biases and weights for the
/// network are initialized randomly, using a Gaussian
/// distribution with mean 0, and variance 1. Note that the first
/// layer is assumed to be an input layer, and by convention we
/// won't set any biases for those neurons, since biases are only
/// ever used in computing the outputs from later layers.
pub fn from(sizes: Vec<usize>) -> Self {
Self {
_sizes: sizes.clone(),
_num_layers: sizes.len(),
biases: biases(sizes[1..].to_vec()),
weights: weights(zip(sizes[..sizes.len() - 1].to_vec(), sizes[1..].to_vec()).collect()),
}
}
fn feed_forward(&self, input: Vec<f32>) -> Vec<f32> {
let mut a = DMatrix::from_vec(input.len(), 1, input);
for (b, w) in zip(&self.biases, &self.weights) {
a = b.add_scalar(w.dot(&a));
a.apply(sigmoid_inplace);
}
a.column(1).iter().map(|v| *v).collect()
}
pub fn sgd(&mut self, mut training_data: Data<f32, u8>, epochs: usize, minibatch_size: usize, eta: f32, test_data: &Option<Data<f32, u8>>) {
for j in 0..epochs {
training_data.shuffle();
let mini_batches = training_data.as_batches(minibatch_size);
for mini_batch in mini_batches {
self.update_mini_batch(mini_batch, eta);
}
if let Some(test_data) = test_data {
println!("Epoch {}: {} / {}", j, self.evaluate(test_data), test_data.len());
} else {
println!("Epoch {} complete", j);
}
}
}
/// Update the network's weights and biases by applying
/// gradient descent using backpropagation to a single mini batch.
/// The ``mini_batch`` is a list of tuples ``(x, y)``, and ``eta``
/// is the learning rate.
fn update_mini_batch(&mut self, mini_batch: &[DataLine<f32, u8>], eta: f32) {
let mut nabla_b: Vec<DMatrix<f32>> = self.biases.iter()
.map(|b| b.shape())
.map(|s| DMatrix::zeros(s.0, s.1))
.collect();
let mut nabla_w: Vec<DMatrix<f32>> = self.weights.iter()
.map(|w| w.shape())
.map(|s| DMatrix::zeros(s.0, s.1))
.collect();
for line in mini_batch.iter() {
let (delta_nabla_b, delta_nabla_w) = self.backprop(line.inputs.to_vec(), line.label);
nabla_b = zip(&nabla_b, &delta_nabla_b).map(|(nb, dnb)| nb.add(dnb)).collect();
nabla_w = zip(&nabla_w, &delta_nabla_w).map(|(nw, dnw)| nw.add(dnw)).collect();
}
self.weights = zip(&self.weights, &nabla_w)
.map(|(w, nw)| w.add_scalar(-eta / mini_batch.len() as f32)).collect();
self.biases = zip(&self.biases, &nabla_b)
.map(|(b, nb)| b.add_scalar(-eta / mini_batch.len() as f32)).collect();
}
/// Return the number of test inputs for which the neural
/// network outputs the correct result. Note that the neural
/// network's output is assumed to be the index of whichever
/// neuron in the final layer has the highest activation.
fn evaluate(&self, test_data: &Data<f32, u8>) -> usize {
let test_results: Vec<(usize, u8)> = test_data.0.iter()
.map(|line| (argmax(self.feed_forward(line.inputs.clone())), line.label))
.collect();
test_results.into_iter().filter(|(x, y)| *x == *y as usize).count()
}
/// Return a tuple `(nabla_b, nabla_w)` representing the
/// gradient for the cost function C_x. `nabla_b` and
/// `nabla_w` are layer-by-layer lists of matrices, similar
/// to `self.biases` and `self.weights`.
fn backprop(&self, x: Vec<f32>, y: u8) -> (Vec<DMatrix<f32>>, Vec<DMatrix<f32>>) {
// zero_grad ie. set gradient to zero
let mut nabla_b: Vec<DMatrix<f32>> = self.biases.iter()
.map(|b| b.shape())
.map(|s| DMatrix::zeros(s.0, s.1))
.collect();
let mut nabla_w: Vec<DMatrix<f32>> = self.weights.iter()
.map(|w| w.shape())
.map(|s| DMatrix::zeros(s.0, s.1))
.collect();
// feedforward
let mut activation = DMatrix::from_vec(x.len(), 1, x);
let mut activations = vec![activation.clone()];
let mut zs = vec![];
for (b, w) in zip(&self.biases, &self.weights) {
// println!("{:?}", w.shape());
// println!("{:?}", activation.shape());
// println!("{:?}", b.shape());
let mut z: DMatrix<f32> = w * &activation + b;
zs.push(z.clone());
activation = z.map(sigmoid);
activations.push(activation.clone());
}
// backward pass
let delta: DMatrix<f32> = self.cost_derivative(
&activations[activations.len() - 1],
y as f32);
println!("delta {:?}", delta.shape());
println!("z {:?}", &zs[zs.len() - 1].transpose().shape());
let delta = delta * (&zs[zs.len() - 1].transpose().map(sigmoid_prime));
println!("delta {:?}", delta.shape());
let index = nabla_b.len() - 1;
nabla_b[index] = delta.clone();
println!("delta {:?}", delta.shape());
println!("activation {:?}", activations[activations.len() - 2].shape());
let index = nabla_w.len() - 1;
nabla_w[index] = delta * &activations[activations.len() - 2];
(nabla_b, nabla_w)
}
fn cost_derivative(&self, output_activations: &DMatrix<f32>, y: f32) -> DMatrix<f32> {
output_activations.add_scalar(-y)
}
}
fn argmax(val: Vec<f32>) -> usize {
let mut max = 0.0;
let mut index = 0;
for (i, x) in val.iter().enumerate() {
if *x > max {
index = i;
max = *x;
}
}
index
}
fn biases(sizes: Vec<usize>) -> Vec<DMatrix<f32>> {
sizes.iter().map(|size| random_matrix(*size, 1)).collect()
}
fn weights(sizes: Vec<(usize, usize)>) -> Vec<DMatrix<f32>> {
println!("{:?}", sizes);
sizes.iter().map(|size| random_matrix(size.1, size.0)).collect()
}
fn random_matrix(rows: usize, cols: usize) -> DMatrix<f32> {
let normal: Normal<f32> = Normal::new(0.0, 1.0).unwrap();
DMatrix::from_fn(rows, cols, |_, _| normal.sample(&mut thread_rng()))
}
fn sigmoid_inplace(val: &mut f32) {
*val = sigmoid(*val);
}
fn sigmoid(val: f32) -> f32 {
1.0 / (1.0 + (-val).exp())
}
/// Derivative of the sigmoid function.
fn sigmoid_prime(val: f32) -> f32 {
sigmoid(val) * (1.0 - sigmoid(val))
}
#[cfg(test)]
mod test {
use nalgebra::DMatrix;
use super::*;
#[test]
fn test_sigmoid() {
let mut mat: DMatrix<f32> = DMatrix::from_vec(1, 1, vec![0.0]);
mat.apply(sigmoid_inplace);
assert_eq!(mat.get(0), Some(&0.5));
}
}