From 068cf2a1d1c1550ea8651df0b4aa533d7fe99507 Mon Sep 17 00:00:00 2001 From: Shautvast Date: Wed, 22 Feb 2023 10:14:43 +0100 Subject: [PATCH] fixed clippy warnings --- src/dataloader.rs | 9 ++++++--- src/mat.rs | 4 ++++ src/net.rs | 5 ++--- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/dataloader.rs b/src/dataloader.rs index a3ba0e5..44bcd6f 100644 --- a/src/dataloader.rs +++ b/src/dataloader.rs @@ -1,6 +1,4 @@ use std::iter::zip; -use nalgebra::DMatrix; - use rand::prelude::*; use serde::Deserialize; @@ -33,7 +31,6 @@ pub struct DataLine { pub label: Y, } -pub struct Data(pub Vec>); pub struct OneHotVector{ pub val: usize @@ -57,6 +54,8 @@ impl OneHotVector{ } +pub struct Data(pub Vec>); + impl Data { pub fn shuffle(&mut self) { let mut rng = thread_rng(); @@ -67,6 +66,10 @@ impl Data { self.0.len() } + pub fn is_empty(&self, ) -> bool{ + self.0.is_empty() + } + pub fn as_batches(&self, batch_size: usize) -> Vec<&[DataLine]> { let mut batches = Vec::with_capacity(self.0.len() / batch_size + 1); let mut offset = 0; diff --git a/src/mat.rs b/src/mat.rs index 4564cc1..4e35de8 100644 --- a/src/mat.rs +++ b/src/mat.rs @@ -3,6 +3,10 @@ use std::fmt::Debug; use std::ops::AddAssign; use nalgebra::DMatrix; +/// matrix add with broadcasting +/// like the numpy add operation +/// not sure I even need it anymore, after fixing inconsistencies with the matrix shapes +/// TODO see if it's still needed, or that standard matrix addition suffices pub fn add(v1: DMatrix, v2: DMatrix) -> Result, String> where T: PartialEq + Copy + Clone + Debug + Add + Add + AddAssign + 'static { diff --git a/src/net.rs b/src/net.rs index 0b68ce6..0681a34 100644 --- a/src/net.rs +++ b/src/net.rs @@ -6,7 +6,6 @@ use rand::prelude::*; use rand_distr::Normal; use crate::dataloader::{Data, DataLine, OneHotVector}; -use crate::mat; use crate::mat::add; #[derive(Debug)] @@ -43,7 +42,7 @@ impl Network { a = add(b.clone(), w * a).unwrap(); a.apply(sigmoid_inplace); } - a.column(1).iter().map(|v| *v).collect() + a.column(1).iter().copied().collect() } pub fn sgd(&mut self, mut training_data: Data, epochs: usize, minibatch_size: usize, eta: f32, test_data: &Option>) { @@ -127,7 +126,7 @@ impl Network { } // backward pass // delta = self.cost_derivative(activations[-1], y) * sigmoid_prime(zs[-1]) - let delta: DMatrix = self.cost_derivative(&activations[activations.len() - 1], y).component_mul((&zs[zs.len() - 1].map(sigmoid_prime))); + let delta: DMatrix = self.cost_derivative(&activations[activations.len() - 1], y).component_mul(&zs[zs.len() - 1].map(sigmoid_prime)); let index = nabla_b.len() - 1; nabla_b[index] = delta.clone();