fixed clippy warnings

This commit is contained in:
Shautvast 2023-02-22 10:14:43 +01:00
parent f4fd0a5fe0
commit 068cf2a1d1
3 changed files with 12 additions and 6 deletions

View file

@ -1,6 +1,4 @@
use std::iter::zip;
use nalgebra::DMatrix;
use rand::prelude::*;
use serde::Deserialize;
@ -33,7 +31,6 @@ pub struct DataLine<X, Y> {
pub label: Y,
}
pub struct Data<X, Y>(pub Vec<DataLine<X, Y>>);
pub struct OneHotVector{
pub val: usize
@ -57,6 +54,8 @@ impl OneHotVector{
}
pub struct Data<X, Y>(pub Vec<DataLine<X, Y>>);
impl<X, Y> Data<X, Y> {
pub fn shuffle(&mut self) {
let mut rng = thread_rng();
@ -67,6 +66,10 @@ impl<X, Y> Data<X, Y> {
self.0.len()
}
pub fn is_empty(&self, ) -> bool{
self.0.is_empty()
}
pub fn as_batches(&self, batch_size: usize) -> Vec<&[DataLine<X, Y>]> {
let mut batches = Vec::with_capacity(self.0.len() / batch_size + 1);
let mut offset = 0;

View file

@ -3,6 +3,10 @@ use std::fmt::Debug;
use std::ops::AddAssign;
use nalgebra::DMatrix;
/// matrix add with broadcasting
/// like the numpy add operation
/// not sure I even need it anymore, after fixing inconsistencies with the matrix shapes
/// TODO see if it's still needed, or that standard matrix addition suffices
pub fn add<T>(v1: DMatrix<T>, v2: DMatrix<T>) -> Result<DMatrix<T>, String>
where T: PartialEq + Copy + Clone + Debug + Add + Add<Output=T> + AddAssign + 'static
{

View file

@ -6,7 +6,6 @@ use rand::prelude::*;
use rand_distr::Normal;
use crate::dataloader::{Data, DataLine, OneHotVector};
use crate::mat;
use crate::mat::add;
#[derive(Debug)]
@ -43,7 +42,7 @@ impl Network {
a = add(b.clone(), w * a).unwrap();
a.apply(sigmoid_inplace);
}
a.column(1).iter().map(|v| *v).collect()
a.column(1).iter().copied().collect()
}
pub fn sgd(&mut self, mut training_data: Data<f32, OneHotVector>, epochs: usize, minibatch_size: usize, eta: f32, test_data: &Option<Data<f32, OneHotVector>>) {
@ -127,7 +126,7 @@ impl Network {
}
// backward pass
// delta = self.cost_derivative(activations[-1], y) * sigmoid_prime(zs[-1])
let delta: DMatrix<f32> = self.cost_derivative(&activations[activations.len() - 1], y).component_mul((&zs[zs.len() - 1].map(sigmoid_prime)));
let delta: DMatrix<f32> = self.cost_derivative(&activations[activations.len() - 1], y).component_mul(&zs[zs.len() - 1].map(sigmoid_prime));
let index = nabla_b.len() - 1;
nabla_b[index] = delta.clone();