From 8d39285ff9ad01a7a74998195755274367a0a48b Mon Sep 17 00:00:00 2001 From: Sander Hautvast Date: Wed, 11 May 2022 18:39:57 +0200 Subject: [PATCH] started developing Max Likelihood Estimator, counts vocabulary --- src/lm/mle.rs | 106 ++++++++++++++++++++++++++++++++++++++++ src/lm/mod.rs | 3 +- src/lm/preprocessing.rs | 8 +-- 3 files changed, 112 insertions(+), 5 deletions(-) create mode 100644 src/lm/mle.rs diff --git a/src/lm/mle.rs b/src/lm/mle.rs new file mode 100644 index 0000000..b22b805 --- /dev/null +++ b/src/lm/mle.rs @@ -0,0 +1,106 @@ +use std::collections::BTreeMap; + +struct Vocabulary<'a> { + cutoff: usize, + counter: Counter<'a>, + unk_label: &'a str, +} + +impl <'a>Vocabulary<'a> { + pub(crate) fn new(cutoff: usize) -> Self { + Self { + cutoff, + counter: Counter::new(), + unk_label: "", + } + } + + pub fn update_word(&mut self, word: &'a str) { + self.counter.update_word(word); + } + + pub fn update_sentence(&mut self, sentence: impl Iterator) { + self.counter.update_sentence(sentence); + } + + pub fn lookup_word(&self, word: &'a str) -> &str { + return if self.counter.get(word) > self.cutoff { + word + } else { + self.unk_label + }; + } + + pub(crate) fn lookup_sentence(&self, words: impl Iterator + 'a) -> impl Iterator + '_{ + words.map(|word| if self.counter.get(word) > self.cutoff { + word + } else { + self.unk_label + }) + } +} + +struct Counter<'a> { + counts: BTreeMap<&'a str, usize>, //may just need hashmap, not sure yet, do we need ordered keys? 
+} + +impl<'a> Counter<'a> { + pub(crate) fn new() -> Self { + Self { + counts: BTreeMap::new() + } + } + + pub(crate) fn update_word(&mut self, word: &'a str) { + let count = self.counts.entry(word).or_insert(0); + *count += 1; + } + + pub(crate) fn update_sentence(&mut self, sentence: impl Iterator) { + sentence.for_each(|word| self.update_word(word)); + } + + pub(crate) fn get(&self, word: &str) -> usize { + *self.counts.get(word).unwrap_or(&0) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_lookup() { + let mut vocab = Vocabulary::new(1); + vocab.update_sentence(["a", "b", "c", "a", "b", "c"].iter()); + let looked_up: Vec<&str> = vocab.lookup_sentence(["a", "b", "c"].iter()).collect(); + assert_eq!(looked_up, vec!["a", "b", "c"]); + let looked_up: Vec<&str> = vocab.lookup_sentence(["Aliens", "from", "Mars"].iter()).collect(); + assert_eq!(looked_up, vec!["", "", ""]); + } + + #[test] + fn test_lookup_below_cutoff() { + let mut vocab = Vocabulary::new(1); + vocab.update_sentence(["a", "b", "c"].iter()); + let looked_up: Vec<&str> = vocab.lookup_words(["a", "b", "c"].iter()).collect(); + assert_eq!(looked_up, vec!["", "", ""]); + } + + #[test] + fn test_count_words() { + let mut counter = Counter::new(); + counter.update_word("a"); + + assert_eq!(counter.get("a"), 1); + } + + #[test] + fn test_count_sentence() { + let mut counter = Counter::new(); + counter.update_sentence(["a", "b", "a"].iter()); + + assert_eq!(counter.get("a"), 2); + assert_eq!(counter.get("b"), 1); + } +} \ No newline at end of file diff --git a/src/lm/mod.rs b/src/lm/mod.rs index fbfe084..4447e47 100644 --- a/src/lm/mod.rs +++ b/src/lm/mod.rs @@ -1 +1,2 @@ -pub mod preprocessing; \ No newline at end of file +pub mod preprocessing; +pub mod mle; \ No newline at end of file diff --git a/src/lm/preprocessing.rs b/src/lm/preprocessing.rs index f4cff45..58100f3 100644 --- a/src/lm/preprocessing.rs +++ b/src/lm/preprocessing.rs @@ -1,4 +1,4 @@ -use crate::util::flatten; 
+// use crate::util::flatten; /// Pads a sequence of words with defaults; prepends "" and appends "" /// @@ -12,9 +12,9 @@ pub fn padded_everygrams<'a>(sentence: impl Iterator + 'a, ord crate::util::everygrams(pad_both_ends(sentence, order), order) } -pub fn padded_everygram_pipeline<'a>(text: impl Iterator + 'a, order: usize) -> (impl Iterator){ - (text.iter().map(|sent| rltk::lm::preprocessing::pad_both_ends(sent.iter(), order)).flatten())//vocab -} +// pub fn padded_everygram_pipeline<'a>(text: impl Iterator + 'a, order: usize) -> (impl Iterator){ +// (text.map(|sent| crate::lm::preprocessing::pad_both_ends(sent), order)).flatten())//vocab +// } #[cfg(test)] mod tests{