fixed clippy warnings
This commit is contained in:
parent
f4fd0a5fe0
commit
068cf2a1d1
3 changed files with 12 additions and 6 deletions
|
|
@ -1,6 +1,4 @@
|
||||||
use std::iter::zip;
|
use std::iter::zip;
|
||||||
use nalgebra::DMatrix;
|
|
||||||
|
|
||||||
use rand::prelude::*;
|
use rand::prelude::*;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
|
@ -33,7 +31,6 @@ pub struct DataLine<X, Y> {
|
||||||
pub label: Y,
|
pub label: Y,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Data<X, Y>(pub Vec<DataLine<X, Y>>);
|
|
||||||
|
|
||||||
pub struct OneHotVector{
|
pub struct OneHotVector{
|
||||||
pub val: usize
|
pub val: usize
|
||||||
|
|
@ -57,6 +54,8 @@ impl OneHotVector{
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct Data<X, Y>(pub Vec<DataLine<X, Y>>);
|
||||||
|
|
||||||
impl<X, Y> Data<X, Y> {
|
impl<X, Y> Data<X, Y> {
|
||||||
pub fn shuffle(&mut self) {
|
pub fn shuffle(&mut self) {
|
||||||
let mut rng = thread_rng();
|
let mut rng = thread_rng();
|
||||||
|
|
@ -67,6 +66,10 @@ impl<X, Y> Data<X, Y> {
|
||||||
self.0.len()
|
self.0.len()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn is_empty(&self, ) -> bool{
|
||||||
|
self.0.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn as_batches(&self, batch_size: usize) -> Vec<&[DataLine<X, Y>]> {
|
pub fn as_batches(&self, batch_size: usize) -> Vec<&[DataLine<X, Y>]> {
|
||||||
let mut batches = Vec::with_capacity(self.0.len() / batch_size + 1);
|
let mut batches = Vec::with_capacity(self.0.len() / batch_size + 1);
|
||||||
let mut offset = 0;
|
let mut offset = 0;
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,10 @@ use std::fmt::Debug;
|
||||||
use std::ops::AddAssign;
|
use std::ops::AddAssign;
|
||||||
use nalgebra::DMatrix;
|
use nalgebra::DMatrix;
|
||||||
|
|
||||||
|
/// matrix add with broadcasting
|
||||||
|
/// like the numpy add operation
|
||||||
|
/// not sure I even need it anymore, after fixing inconsistencies with the matrix shapes
|
||||||
|
/// TODO see if it's still needed, or that standard matrix addition suffices
|
||||||
pub fn add<T>(v1: DMatrix<T>, v2: DMatrix<T>) -> Result<DMatrix<T>, String>
|
pub fn add<T>(v1: DMatrix<T>, v2: DMatrix<T>) -> Result<DMatrix<T>, String>
|
||||||
where T: PartialEq + Copy + Clone + Debug + Add + Add<Output=T> + AddAssign + 'static
|
where T: PartialEq + Copy + Clone + Debug + Add + Add<Output=T> + AddAssign + 'static
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,6 @@ use rand::prelude::*;
|
||||||
use rand_distr::Normal;
|
use rand_distr::Normal;
|
||||||
|
|
||||||
use crate::dataloader::{Data, DataLine, OneHotVector};
|
use crate::dataloader::{Data, DataLine, OneHotVector};
|
||||||
use crate::mat;
|
|
||||||
use crate::mat::add;
|
use crate::mat::add;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
|
@ -43,7 +42,7 @@ impl Network {
|
||||||
a = add(b.clone(), w * a).unwrap();
|
a = add(b.clone(), w * a).unwrap();
|
||||||
a.apply(sigmoid_inplace);
|
a.apply(sigmoid_inplace);
|
||||||
}
|
}
|
||||||
a.column(1).iter().map(|v| *v).collect()
|
a.column(1).iter().copied().collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn sgd(&mut self, mut training_data: Data<f32, OneHotVector>, epochs: usize, minibatch_size: usize, eta: f32, test_data: &Option<Data<f32, OneHotVector>>) {
|
pub fn sgd(&mut self, mut training_data: Data<f32, OneHotVector>, epochs: usize, minibatch_size: usize, eta: f32, test_data: &Option<Data<f32, OneHotVector>>) {
|
||||||
|
|
@ -127,7 +126,7 @@ impl Network {
|
||||||
}
|
}
|
||||||
// backward pass
|
// backward pass
|
||||||
// delta = self.cost_derivative(activations[-1], y) * sigmoid_prime(zs[-1])
|
// delta = self.cost_derivative(activations[-1], y) * sigmoid_prime(zs[-1])
|
||||||
let delta: DMatrix<f32> = self.cost_derivative(&activations[activations.len() - 1], y).component_mul((&zs[zs.len() - 1].map(sigmoid_prime)));
|
let delta: DMatrix<f32> = self.cost_derivative(&activations[activations.len() - 1], y).component_mul(&zs[zs.len() - 1].map(sigmoid_prime));
|
||||||
let index = nabla_b.len() - 1;
|
let index = nabla_b.len() - 1;
|
||||||
nabla_b[index] = delta.clone();
|
nabla_b[index] = delta.clone();
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue