rltk/src/lm/mle.rs

106 lines
No EOL
2.8 KiB
Rust

use std::collections::BTreeMap;
/// Vocabulary that counts words and masks the ones seen too rarely.
///
/// A word is kept only when its recorded count is strictly greater than
/// `cutoff`; every other word is replaced by `unk_label`.
struct Vocabulary<'a> {
    /// Threshold a word's count must exceed for the word to be kept.
    cutoff: usize,
    /// Occurrence counts of every word fed in through the `update_*` methods.
    counter: Counter<'a>,
    /// Label returned in place of words whose count does not exceed `cutoff`.
    unk_label: &'a str,
}

impl<'a> Vocabulary<'a> {
    /// Creates an empty vocabulary with the given `cutoff` and the
    /// unknown-word label `"<UNK>"`.
    pub(crate) fn new(cutoff: usize) -> Self {
        Self {
            cutoff,
            counter: Counter::new(),
            unk_label: "<UNK>",
        }
    }

    /// Records one occurrence of `word`.
    pub fn update_word(&mut self, word: &'a str) {
        self.counter.update_word(word);
    }

    /// Records one occurrence of every word yielded by `sentence`.
    pub fn update_sentence(&mut self, sentence: impl Iterator<Item = &'a &'a str>) {
        self.counter.update_sentence(sentence);
    }

    /// Returns `word` when its count is strictly greater than the cutoff,
    /// otherwise the unknown-word label.
    ///
    /// The result carries the word lifetime `'a` rather than the lifetime of
    /// the `&self` borrow: both possible return values already live for `'a`,
    /// so callers are free to mutate the vocabulary while holding the result.
    pub fn lookup_word(&self, word: &'a str) -> &'a str {
        if self.counter.get(word) > self.cutoff {
            word
        } else {
            self.unk_label
        }
    }

    /// Lazily maps every word of `words` through [`Vocabulary::lookup_word`].
    ///
    /// The returned iterator borrows the vocabulary for as long as it is used.
    pub(crate) fn lookup_sentence(
        &self,
        words: impl Iterator<Item = &'a &'a str> + 'a,
    ) -> impl Iterator<Item = &'a str> + '_ {
        // `move` copies the `&self` reference into the closure so the returned
        // iterator does not hold a borrow of this function's local variable.
        // Dereferencing `word` makes the closure yield `&'a str` directly.
        words.map(move |word| self.lookup_word(*word))
    }
}
/// Tally of how many times each word has been recorded.
struct Counter<'a> {
    /// Occurrence count per word.
    // Open question: a `HashMap` may be enough here if ordered keys turn out
    // not to be needed.
    counts: BTreeMap<&'a str, usize>,
}

impl<'a> Counter<'a> {
    /// Creates a counter with no recorded words.
    pub(crate) fn new() -> Self {
        let counts = BTreeMap::new();
        Self { counts }
    }

    /// Adds one to the count of `word`, starting from zero for a new word.
    pub(crate) fn update_word(&mut self, word: &'a str) {
        *self.counts.entry(word).or_insert(0) += 1;
    }

    /// Adds one to the count of every word yielded by `sentence`.
    pub(crate) fn update_sentence(&mut self, sentence: impl Iterator<Item = &'a &'a str>) {
        for &word in sentence {
            self.update_word(word);
        }
    }

    /// Returns the recorded count of `word`, or `0` if it was never recorded.
    pub(crate) fn get(&self, word: &str) -> usize {
        match self.counts.get(word) {
            Some(&seen) => seen,
            None => 0,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Input arrays are bound to locals declared before the vocabulary/counter
    // so the borrowed words clearly outlive the structure that stores them.

    /// Words counted above the cutoff pass through; unseen words are masked.
    #[test]
    fn test_lookup() {
        let corpus = ["a", "b", "c", "a", "b", "c"];
        let known = ["a", "b", "c"];
        let unknown = ["Aliens", "from", "Mars"];
        let mut vocab = Vocabulary::new(1);
        vocab.update_sentence(corpus.iter());
        let looked_up: Vec<&str> = vocab.lookup_sentence(known.iter()).collect();
        assert_eq!(looked_up, vec!["a", "b", "c"]);
        let looked_up: Vec<&str> = vocab.lookup_sentence(unknown.iter()).collect();
        assert_eq!(looked_up, vec!["<UNK>", "<UNK>", "<UNK>"]);
    }

    /// A count equal to the cutoff is not enough: the word is still masked.
    #[test]
    fn test_lookup_below_cutoff() {
        let corpus = ["a", "b", "c"];
        let mut vocab = Vocabulary::new(1);
        vocab.update_sentence(corpus.iter());
        let looked_up: Vec<&str> = vocab.lookup_sentence(corpus.iter()).collect();
        assert_eq!(looked_up, vec!["<UNK>", "<UNK>", "<UNK>"]);
    }

    /// A single recorded word has a count of one.
    #[test]
    fn test_count_words() {
        let mut counter = Counter::new();
        counter.update_word("a");
        assert_eq!(counter.get("a"), 1);
    }

    /// Every word of a sentence is counted once per occurrence.
    #[test]
    fn test_count_sentence() {
        let sentence = ["a", "b", "a"];
        let mut counter = Counter::new();
        counter.update_sentence(sentence.iter());
        assert_eq!(counter.get("a"), 2);
        assert_eq!(counter.get("b"), 1);
    }
}