use std::iter::zip; use rand::prelude::*; use serde::Deserialize; pub fn load_data() -> Data { /// the mnist data is structured as /// x: [[[pixels]],[[pixels]], etc], /// y: [label1, label2, etc] /// this is transformed to: /// Data : Vec /// DataLine {inputs: Vec, label: f32} let raw_data: RawData = serde_json::from_slice(include_bytes!("data/unittest.json")).unwrap(); let mut vec = Vec::new(); for (x, y) in zip(raw_data.x, raw_data.y) { vec.push(DataLine { inputs: x, label: y}); } Data(vec) } #[derive(Deserialize)] struct RawData { x: Vec>, y: Vec, } /// X is type of input /// Y is type of output pub struct DataLine { pub inputs: Vec, pub label: Y, } pub struct Data(pub Vec>); impl Data { pub fn shuffle(&mut self) { let mut rng = rand::thread_rng(); self.0.shuffle(&mut rng); } pub fn len(&self) -> usize { self.0.len() } pub fn as_batches(&self, batch_size: usize) -> Vec<&[DataLine]> { let mut batches = Vec::with_capacity(self.0.len() / batch_size + 1); let mut offset = 0; for _ in 0..self.0.len() / batch_size { batches.push(&self.0[offset..offset + batch_size]); offset += batch_size; } batches.push(&self.0[offset..self.0.len()]); batches } }