started developing Max Likelihood Estimator, counts vocabulary

This commit is contained in:
Sander Hautvast 2022-05-11 18:39:57 +02:00
parent b9fc3386f0
commit 8d39285ff9
3 changed files with 112 additions and 5 deletions

106
src/lm/mle.rs Normal file
View file

@ -0,0 +1,106 @@
use std::collections::BTreeMap;
/// A vocabulary with a frequency cutoff: words seen at most `cutoff`
/// times are replaced by the unknown-word label (`"<UNK>"`) on lookup.
struct Vocabulary<'a> {
    cutoff: usize,
    counter: Counter<'a>,
    unk_label: &'a str,
}

impl<'a> Vocabulary<'a> {
    /// Creates an empty vocabulary with the given frequency cutoff.
    pub(crate) fn new(cutoff: usize) -> Self {
        Self {
            cutoff,
            counter: Counter::new(),
            unk_label: "<UNK>",
        }
    }

    /// Records one occurrence of `word`.
    pub fn update_word(&mut self, word: &'a str) {
        self.counter.update_word(word);
    }

    /// Records one occurrence of every word yielded by `sentence`.
    pub fn update_sentence(&mut self, sentence: impl Iterator<Item = &'a &'a str>) {
        self.counter.update_sentence(sentence);
    }

    /// Returns `word` itself when its count is strictly above the cutoff,
    /// otherwise the unknown-word label.
    pub fn lookup_word(&self, word: &'a str) -> &str {
        if self.counter.get(word) > self.cutoff {
            word
        } else {
            self.unk_label
        }
    }

    /// Lazily maps each word of `words` through the cutoff test
    /// (same rule as `lookup_word`).
    pub(crate) fn lookup_sentence(
        &self,
        words: impl Iterator<Item = &'a &'a str> + 'a,
    ) -> impl Iterator<Item = &'a str> + '_ {
        words.map(|word| {
            if self.counter.get(word) > self.cutoff {
                *word
            } else {
                self.unk_label
            }
        })
    }
}
/// Word-frequency table over borrowed string slices.
struct Counter<'a> {
    // May only need a HashMap, not sure yet — do we need ordered keys?
    counts: BTreeMap<&'a str, usize>,
}

impl<'a> Counter<'a> {
    /// Creates an empty counter.
    pub(crate) fn new() -> Self {
        Self {
            counts: BTreeMap::new(),
        }
    }

    /// Increments the count for `word`, starting from zero if unseen.
    pub(crate) fn update_word(&mut self, word: &'a str) {
        *self.counts.entry(word).or_insert(0) += 1;
    }

    /// Counts every word yielded by `sentence`.
    pub(crate) fn update_sentence(&mut self, sentence: impl Iterator<Item = &'a &'a str>) {
        for word in sentence {
            self.update_word(word);
        }
    }

    /// Returns how often `word` has been seen; zero for unseen words.
    pub(crate) fn get(&self, word: &str) -> usize {
        self.counts.get(word).copied().unwrap_or(0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lookup() {
        // Each word occurs twice, strictly above the cutoff of 1,
        // so known words map to themselves.
        let mut vocab = Vocabulary::new(1);
        vocab.update_sentence(["a", "b", "c", "a", "b", "c"].iter());
        let looked_up: Vec<&str> = vocab.lookup_sentence(["a", "b", "c"].iter()).collect();
        assert_eq!(looked_up, vec!["a", "b", "c"]);
        // Words never seen fall back to the unknown-word label.
        let looked_up: Vec<&str> = vocab.lookup_sentence(["Aliens", "from", "Mars"].iter()).collect();
        assert_eq!(looked_up, vec!["<UNK>", "<UNK>", "<UNK>"]);
    }

    #[test]
    fn test_lookup_below_cutoff() {
        // Words seen exactly `cutoff` times (not strictly more) are unknown.
        let mut vocab = Vocabulary::new(1);
        vocab.update_sentence(["a", "b", "c"].iter());
        // Fixed: the original called the non-existent `lookup_words`;
        // the method is named `lookup_sentence`.
        let looked_up: Vec<&str> = vocab.lookup_sentence(["a", "b", "c"].iter()).collect();
        assert_eq!(looked_up, vec!["<UNK>", "<UNK>", "<UNK>"]);
    }

    #[test]
    fn test_count_words() {
        let mut counter = Counter::new();
        counter.update_word("a");
        assert_eq!(counter.get("a"), 1);
    }

    #[test]
    fn test_count_sentence() {
        let mut counter = Counter::new();
        counter.update_sentence(["a", "b", "a"].iter());
        assert_eq!(counter.get("a"), 2);
        assert_eq!(counter.get("b"), 1);
    }
}

View file

@ -1 +1,2 @@
pub mod preprocessing;
pub mod preprocessing;
pub mod mle;

View file

@ -1,4 +1,4 @@
use crate::util::flatten;
// use crate::util::flatten;
/// Pads a sequence of words with defaults; prepends "<s>" and appends "<s>"
///
@ -12,9 +12,9 @@ pub fn padded_everygrams<'a>(sentence: impl Iterator<Item=&'a &'a str> + 'a, ord
crate::util::everygrams(pad_both_ends(sentence, order), order)
}
// NOTE(review): work-in-progress — this does not compile as written:
// `text` is already an iterator, so it has no `.iter()` method, and the
// `rltk::` crate path should presumably be `crate::` (compare the
// commented-out variant below). Also, `sent` here is a single `&&str`,
// not a sentence, so `sent.iter()` cannot work either — TODO: confirm
// whether `text` was meant to yield whole sentences (e.g. slices of words).
pub fn padded_everygram_pipeline<'a>(text: impl Iterator<Item=&'a &'a str> + 'a, order: usize) -> (impl Iterator<Item=&'a &'a str>){
(text.iter().map(|sent| rltk::lm::preprocessing::pad_both_ends(sent.iter(), order)).flatten())//vocab
}
// pub fn padded_everygram_pipeline<'a>(text: impl Iterator<Item=&'a &'a str> + 'a, order: usize) -> (impl Iterator<Item=&'a &'a str>){
// (text.map(|sent| crate::lm::preprocessing::pad_both_ends(sent), order)).flatten())//vocab
// }
#[cfg(test)]
mod tests{