diff --git a/.gitea/workflows/build.yaml b/.gitea/workflows/build.yaml new file mode 100644 index 0000000..51c9fb1 --- /dev/null +++ b/.gitea/workflows/build.yaml @@ -0,0 +1,18 @@ +name: CI +on: + push: + branches: [ master ] +jobs: + build: + runs-on: woryzen + steps: + - name: Checkout sources + uses: actions/checkout@v4 + - name: Run unit tests + run: | + cargo test + - name: Publish artifacts + env: + CARGO_REGISTRIES_GITEA_TOKEN: Bearer ${{ secrets.PUBLISHER_TOKEN }} + run: | + cargo publish diff --git a/Cargo.lock b/Cargo.lock index df06ffe..c5c8c40 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "heck" @@ -10,7 +10,7 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] name = "levtree" -version = "0.1.0" +version = "0.1.1" dependencies = [ "sealed", "trait-group", diff --git a/Cargo.toml b/Cargo.toml index 8a4a2ae..f343459 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,9 +1,9 @@ [package] name = "levtree" -version = "0.1.0" +version = "0.1.1" authors = ["Walter Oggioni "] license = "MIT" -rust-version = "1.60" +edition = "2024" [dependencies] trait-group = "0.1.0" @@ -19,8 +19,3 @@ bench = false name = "levtree_benchmark" path = "examples/benchmark.rs" -[profile.release] -strip = true -lto = true -debug-assertions = false -codegen-units = 1 \ No newline at end of file diff --git a/examples/benchmark.rs b/examples/benchmark.rs index ddfd1d4..bd2bcc0 100644 --- a/examples/benchmark.rs +++ b/examples/benchmark.rs @@ -47,6 +47,7 @@ fn main() { for key in keys { let word = &key.into_char_slice()[..]; let results = trie.fuzzy_search::(word, 6); + println!("needle: {}", key); for result in results { let word: String = trie.lineal_descendant(result.word).into_iter().collect(); println!("distance: {}, wordkey: {}", result.distance, word); diff --git a/src/levtrie.rs b/src/levtrie.rs index ec3d14a..f5a13d2 100644 --- a/src/levtrie.rs +++ b/src/levtrie.rs @@ -1,9 +1,9 @@ extern crate sealed; use self::sealed::sealed; -use std::collections::BTreeSet; +use std::collections::BinaryHeap; use super::keychecker::KeyChecker; -use super::result::Result; +use super::search_result::SearchResult; use super::trie::Trie; use super::trie::VisitOutcome; use super::trienode::TrieKey; @@ -20,8 +20,8 @@ where { fn compute( workspace: &mut Vec>, - nodes: &Vec>, - stack: &Vec, + nodes: &[LevTrieNode], + stack: &[usize], wordkey: &[KEY], worst_case: Option, ) -> VisitOutcome; @@ -33,10 +33,10 @@ where KEYCHECKER: KeyChecker, { pub fn new() -> LevTrie { - Trie::empty() + Trie::default() } - pub fn from_words(wordlist: U) -> LevTrie + pub fn from_words(wordlist: U) -> LevTrie where T: IntoIterator, U: IntoIterator, @@ -48,14 +48,13 @@ where result } - pub fn fuzzy_search(&mut self, word: &[KEY], max_result: usize) -> BTreeSet + pub fn fuzzy_search(&mut self, word: &[KEY], max_result: usize) -> Vec where DC: DistanceCalculator, { - let word_len = word.into_iter().count(); - let mut workspace: &mut Vec> = - &mut (0..self.nodes()).map(|_| Vec::new()).collect(); - let mut results = BTreeSet::new(); + let word_len = word.len(); + let workspace: &mut Vec> = &mut (0..self.nodes()).map(|_| Vec::new()).collect(); + let mut result_heap = BinaryHeap::::with_capacity(max_result + 1); let required_size = word_len + 1; let visit_pre = |stack: &Vec| -> VisitOutcome { let stack_size = stack.len(); @@ -63,32 +62,33 @@ where let payload = &mut workspace[current_node_id]; payload.resize(required_size, usize::default()); if stack_size == 1 { - for i in 0..required_size { - payload[i] = i; + for (i, item) in payload.iter_mut().enumerate().take(required_size) { + *item = i; } } else { - for i in 0..required_size { - payload[i] = if i == 0 { stack_size - 1 } else { 0 } + for (i, item) in payload.iter_mut().enumerate().take(required_size) { + *item = if i == 0 { stack_size - 1 } else { 0 } } } if stack_size > 1 { let current_node = &mut self.get_node(current_node_id); if current_node.key.is_none() { let distance = workspace[stack[stack_size - 2]][word_len]; - results.insert(Result { - distance: distance, + let search_result = SearchResult { + distance, word: current_node_id, - }); - if results.len() > max_result { - results.pop_last(); + }; + result_heap.push(search_result); + if result_heap.len() > max_result { + result_heap.pop(); } VisitOutcome::Skip } else { - let worst_case = results - .last() - .filter(|_| results.len() == max_result) + let worst_case = result_heap + .peek() + .filter(|_| result_heap.len() == max_result) .map(|it| it.distance); - DC::compute(&mut workspace, &self.nodes, stack, word, worst_case) + DC::compute(workspace, &self.nodes, stack, word, worst_case) } } else { VisitOutcome::Continue @@ -96,7 +96,8 @@ where }; let visit_post = |_: &Vec| {}; self.walk(visit_pre, visit_post); - results + + result_heap.into_sorted_vec() } } @@ -110,13 +111,13 @@ where { fn compute( workspace: &mut Vec>, - nodes: &Vec>, - stack: &Vec, + nodes: &[LevTrieNode], + stack: &[usize], wordkey: &[KEY], worst_case: Option, ) -> VisitOutcome { let sz = stack.len(); - let key_size = wordkey.into_iter().count(); + let key_size = wordkey.len(); for i in 1..=key_size { if KEYCHECKER::check(Some(wordkey[i - 1]), nodes[stack[sz - 1]].key) { workspace[stack[sz - 1]][i] = workspace[stack[sz - 2]][i - 1]; @@ -131,7 +132,7 @@ where } } let condition = worst_case - .map(|wv| wv <= *workspace[stack[sz - 1]][..].into_iter().min().unwrap()) + .map(|wv| wv <= *workspace[stack[sz - 1]][..].iter().min().unwrap()) .unwrap_or(false); if condition { VisitOutcome::Skip @@ -151,13 +152,13 @@ where { fn compute( workspace: &mut Vec>, - nodes: &Vec>, - stack: &Vec, + nodes: &[LevTrieNode], + stack: &[usize], wordkey: &[KEY], worst_case: Option, ) -> VisitOutcome { let sz = stack.len(); - let key_size = wordkey.into_iter().count(); + let key_size = wordkey.len(); for i in 1..=key_size { if KEYCHECKER::check( Some(wordkey[i - 1]), @@ -185,7 +186,7 @@ where } } let condition = worst_case - .map(|wv| wv <= *workspace[stack[sz - 2]][..].into_iter().min().unwrap()) + .map(|wv| wv <= *workspace[stack[sz - 2]][..].iter().min().unwrap()) .unwrap_or(false); if condition { VisitOutcome::Skip diff --git a/src/lib.rs b/src/lib.rs index ae4e1e4..65e8ecf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,5 +21,8 @@ pub use self::keychecker::KeyChecker; pub type CaseSensitiveLevTrie = LevTrie; pub type CaseInSensitiveLevTrie = LevTrie; -mod result; -pub use self::result::Result; +mod search_result; +pub use self::search_result::SearchResult; + +#[cfg(test)] +mod tests; diff --git a/src/result.rs b/src/search_result.rs similarity index 51% rename from src/result.rs rename to src/search_result.rs index d78795a..acf8444 100644 --- a/src/result.rs +++ b/src/search_result.rs @@ -1,27 +1,26 @@ use std::cmp::Ordering; -pub struct Result { +#[derive(Clone)] +pub struct SearchResult { pub word: usize, pub distance: usize, } -impl PartialOrd for Result { +impl PartialOrd for SearchResult { fn partial_cmp(&self, other: &Self) -> Option { - Some(self.distance.cmp(&other.distance)) - .filter(|it| it != &Ordering::Equal) - .or_else(|| Some(self.word.cmp(&other.word))) + Some(self.cmp(other)) } } -impl PartialEq for Result { +impl PartialEq for SearchResult { fn eq(&self, other: &Self) -> bool { self.distance == other.distance && self.word == other.word } } -impl Eq for Result {} +impl Eq for SearchResult {} -impl Ord for Result { +impl Ord for SearchResult { fn cmp(&self, other: &Self) -> std::cmp::Ordering { match self.distance.cmp(&other.distance) { std::cmp::Ordering::Equal => self.word.cmp(&other.word), @@ -30,21 +29,3 @@ impl Ord for Result { } } } - -//struct Standing { -// size: usize, -// results: Vec, -//} -// -//impl Standing { -// pub fn new(size: usize) -> Standing { -// Standing { -// size, -// results: BTreeSet::new(), -// } -// } -// -// pub fn addResult(&mut self, res: Result) { -// self.results.push(res) -// } -//} diff --git a/src/tests.rs b/src/tests.rs new file mode 100644 index 0000000..61543ed --- /dev/null +++ b/src/tests.rs @@ -0,0 +1,138 @@ +use super::{ + CaseSensitiveLevTrie, DamerauLevenshteinDistanceCalculator, KeyChecker, LevTrie, + LevenshteinDistanceCalculator, SearchResult, +}; +use std::collections::BTreeMap; +use std::fmt::Display; +use std::io::Write; + +struct ExpectedResults { + data: Vec<(usize, usize)>, +} + +impl ExpectedResults { + fn new(id_map: &BTreeMap, results: &[(String, usize)]) -> ExpectedResults { + let data = results + .iter() + .map(|(key, distance)| { + ( + *id_map + .get(key) + .ok_or_else(|| format!("Id not found for key '{key}'")) + .unwrap(), + *distance, + ) + }) + .collect::>(); + ExpectedResults { data } + } + + fn check(&self, search_results: &[SearchResult]) { + for i in 0..self.data.len() { + let SearchResult { word, distance } = search_results[i]; + let data = self.data[i]; + if data != (word, distance) { + panic!("({}, {}) <> ({}, {})", data.0, data.1, word, distance); + } + } + } +} + +fn print_search_results>( + trie: &LevTrie, + search_results: &[SearchResult], + key_separator: &str, +) -> Result<(), std::io::Error> { + for result in search_results { + let mut word = Vec::::new(); + for (i, fragment) in trie.lineal_descendant(result.word).enumerate() { + if i > 0 { + word.write(format!("{}{}", key_separator, fragment).as_bytes())?; + } else { + word.write(format!("{}", fragment).as_bytes())?; + } + } + println!( + "distance: {}, wordkey: {}, id: {}", + result.distance, + String::from_utf8(word).unwrap(), + result.word + ); + } + Ok(()) +} + +const WORDLIST: [&str; 16] = [ + "skyscraper", + "camel", + "coal", + "caos", + "copper", + "hello", + "Bugis", + "Kembangan", + "Singapore", + "Fullerton", + "Lavender", + "aircraft", + "boat", + "ship", + "cargo", + "tanker", +]; + +#[test] +fn test_damerau_levenshtein_strings() { + let mut trie: CaseSensitiveLevTrie = LevTrie::new(); + let mut id_map = BTreeMap::::new(); + for word in WORDLIST { + let (_, id) = trie.add(word.chars()); + id_map.insert(String::from(word), id); + } + let results = trie.fuzzy_search::( + &"coat".chars().collect::>(), + 6, + ); + + print_search_results(&trie, &results, "").unwrap(); + + let expected_results = ExpectedResults::new( + &id_map, + &[ + (String::from("coal"), 1), + (String::from("boat"), 1), + (String::from("caos"), 3), + (String::from("camel"), 4), + (String::from("copper"), 4), + (String::from("ship"), 4), + ], + ); + expected_results.check(&results); +} + +#[test] +fn test_levenshtein_strings() { + let mut trie: CaseSensitiveLevTrie = LevTrie::new(); + let mut id_map = BTreeMap::::new(); + for word in WORDLIST { + let (_, id) = trie.add(word.chars()); + id_map.insert(String::from(word), id); + } + let results = trie + .fuzzy_search::(&"coat".chars().collect::>(), 6); + + print_search_results(&trie, &results, "").unwrap(); + + let expected_results = ExpectedResults::new( + &id_map, + &[ + (String::from("coal"), 1), + (String::from("boat"), 1), + (String::from("caos"), 3), + (String::from("camel"), 4), + (String::from("copper"), 4), + (String::from("ship"), 4), + ], + ); + expected_results.check(&results); +} diff --git a/src/trie.rs b/src/trie.rs index 3b93842..fcba0a9 100644 --- a/src/trie.rs +++ b/src/trie.rs @@ -1,9 +1,9 @@ -use std::collections::BTreeSet; -use std::marker::PhantomData; - use super::keychecker::KeyChecker; use super::trienode::TrieKey; use super::trienode::TrieNode; +use std::collections::BTreeSet; +use std::iter::Iterator; +use std::marker::PhantomData; pub enum VisitOutcome { Continue, @@ -16,47 +16,49 @@ where KEY: TrieKey, KEYCHECKER: KeyChecker, { - pub (crate) nodes: Vec>, + pub(crate) nodes: Vec>, tails: BTreeSet, checker: PhantomData, } - +impl Default for Trie +where + KEY: TrieKey, + KEYCHECKER: KeyChecker, +{ + fn default() -> Self { + Trie { + nodes: vec![TrieNode::new0(None)], + tails: BTreeSet::new(), + checker: PhantomData, + } + } +} impl Trie where KEY: TrieKey, KEYCHECKER: KeyChecker, { - pub fn empty() -> Trie { - Trie { - nodes: vec![TrieNode::new0(None)], - tails: BTreeSet::new(), - checker: PhantomData::default(), - } - } - - pub fn trie_from_words( - wordlist: U, - ) -> Trie + pub fn trie_from_words(wordlist: U) -> Trie where T: IntoIterator, U: IntoIterator, { - let mut result = Trie::empty(); + let mut result = Trie::default(); for word in wordlist { result.add(word); } result } - pub (crate) fn get_node_mut(&mut self, index: usize) -> &mut TrieNode { + pub(crate) fn get_node_mut(&mut self, index: usize) -> &mut TrieNode { &mut self.nodes[index] } - pub (crate) fn get_node(&self, index: usize) -> &TrieNode { + pub(crate) fn get_node(&self, index: usize) -> &TrieNode { &self.nodes[index] } - pub (crate) fn nodes(&self) -> usize { + pub(crate) fn nodes(&self) -> usize { self.nodes.len() } @@ -98,7 +100,7 @@ where result_index } - pub fn add(&mut self, path: T) -> (bool, usize) + pub fn add(&mut self, path: T) -> (bool, usize) where T: IntoIterator, { @@ -106,22 +108,15 @@ where let mut pnode = 0; 'wordLoop: for key in path { let mut cnode = self.get_node(pnode).child; - loop { - match cnode { - Some(cnode_index) => { - let cnode_node = self.get_node(cnode_index); - if KEYCHECKER::check(cnode_node.key, Some(key)) { - pnode = cnode_index; - continue 'wordLoop; - } else if self.get_node(cnode_index).next.is_none() { - break; - } else { - cnode = self.get_node(cnode_index).next; - } - } - None => { - break; - } + while let Some(cnode_index) = cnode { + let cnode_node = self.get_node(cnode_index); + if KEYCHECKER::check(cnode_node.key, Some(key)) { + pnode = cnode_index; + continue 'wordLoop; + } else if self.get_node(cnode_index).next.is_none() { + break; + } else { + cnode = self.get_node(cnode_index).next; } } pnode = self.add_node(Some(key), pnode, cnode); @@ -131,17 +126,10 @@ where let tail = self.add_node(None, pnode, None); self.tails.insert(tail); let mut node = Some(tail); - loop { - match node { - Some(n) => { - let current_node = self.get_node_mut(n); - current_node.ref_count += 1; - node = current_node.parent; - } - None => { - break; - } - } + while let Some(n) = node { + let current_node = self.get_node_mut(n); + current_node.ref_count += 1; + node = current_node.parent; } (true, tail) } else { @@ -177,31 +165,23 @@ where result } - pub fn lineal_descendant(&self, start: usize) -> Vec<&KEY> { - let mut chars: Vec<&KEY> = vec![]; + pub fn lineal_descendant(&self, start: usize) -> impl Iterator { + let mut nodes: Vec = vec![]; let mut node_option = Some(start); - loop { - match node_option { - Some(node) => { - let key = &self.get_node(node).key; - match key { - Some(key) => { - chars.push(key); - } - None => {} - } - node_option = self.get_node(node).parent; - } - None => { - break; - } + while let Some(node) = node_option { + let key = &self.get_node(node).key; + if key.is_some() { + nodes.push(node); } + node_option = self.get_node(node).parent; } - chars.reverse(); - chars + nodes + .into_iter() + .rev() + .map(|node_index| self.get_node(node_index).key.as_ref().unwrap()) } - pub (crate) fn walk(&self, mut visit_pre: CB1, mut visit_post: CB2) + pub(crate) fn walk(&self, mut visit_pre: CB1, mut visit_post: CB2) where CB1: FnMut(&Vec) -> VisitOutcome, CB2: FnMut(&Vec), @@ -243,4 +223,3 @@ where &self.tails } } - diff --git a/src/trienode.rs b/src/trienode.rs index c8a0a85..5bad84a 100644 --- a/src/trienode.rs +++ b/src/trienode.rs @@ -28,10 +28,10 @@ where { TrieNode { key, - prev: prev, - next: next, - child: child, - parent: parent, + prev, + next, + child, + parent, ref_count: 0, } }