use std::iter::zip; use nalgebra::DMatrix; use rand::prelude::*; use serde::Deserialize; pub fn load_data() -> Data { // the mnist data is structured as // x: [[[pixels]],[[pixels]], etc], // y: [label1, label2, etc] // this is transformed to: // Data : Vec // DataLine {inputs: Vec, label: f32} let raw_data: RawData = serde_json::from_slice(include_bytes!("data/unittest.json")).unwrap(); let mut vec = Vec::new(); for (x, y) in zip(raw_data.x, raw_data.y) { vec.push(DataLine { inputs: x, label: onehot(y) }); } Data(vec) } #[derive(Deserialize)] struct RawData { x: Vec>, y: Vec, } /// X is type of input /// Y is type of output pub struct DataLine { pub inputs: Vec, pub label: Y, } pub struct Data(pub Vec>); pub struct OneHotVector{ pub val: usize } impl OneHotVector{ fn new(val: usize) -> Self{ Self{ val } } pub fn get(&self, index: usize) -> f32{ if self.val == index { 1.0 } else { 0.0 } } } impl Data { pub fn shuffle(&mut self) { let mut rng = thread_rng(); self.0.shuffle(&mut rng); } pub fn len(&self) -> usize { self.0.len() } pub fn as_batches(&self, batch_size: usize) -> Vec<&[DataLine]> { let mut batches = Vec::with_capacity(self.0.len() / batch_size + 1); let mut offset = 0; for _ in 0..self.0.len() / batch_size { batches.push(&self.0[offset..offset + batch_size]); offset += batch_size; } batches.push(&self.0[offset..self.0.len()]); batches } } /// returns a vector as matrix where y is one-hot encoded fn onehot(y: u8) -> OneHotVector { OneHotVector::new(y as usize) }