Added basic unit tests.

This commit is contained in:
2025-07-30 11:03:19 +08:00
parent 575281869b
commit 3d2a08801f
10 changed files with 259 additions and 143 deletions

View File

@@ -0,0 +1,18 @@
# Continuous integration: run the unit tests and, if they pass,
# publish the crate on every push to master.
name: CI
on:
  push:
    branches: [ master ]
jobs:
  build:
    # Label of a self-hosted runner.
    runs-on: woryzen
    steps:
      - name: Checkout sources
        uses: actions/checkout@v4
      - name: Run unit tests
        run: |
          cargo test
      - name: Publish artifacts
        env:
          # Gitea's Cargo registry expects the literal "Bearer " prefix
          # inside the token value itself.
          CARGO_REGISTRIES_GITEA_TOKEN: Bearer ${{ secrets.PUBLISHER_TOKEN }}
        run: |
          cargo publish

4
Cargo.lock generated
View File

@@ -1,6 +1,6 @@
# This file is automatically @generated by Cargo. # This file is automatically @generated by Cargo.
# It is not intended for manual editing. # It is not intended for manual editing.
version = 3 version = 4
[[package]] [[package]]
name = "heck" name = "heck"
@@ -10,7 +10,7 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
[[package]] [[package]]
name = "levtree" name = "levtree"
version = "0.1.0" version = "0.1.1"
dependencies = [ dependencies = [
"sealed", "sealed",
"trait-group", "trait-group",

View File

@@ -1,9 +1,9 @@
[package] [package]
name = "levtree" name = "levtree"
version = "0.1.0" version = "0.1.1"
authors = ["Walter Oggioni <oggioni.walter@gmail.com>"] authors = ["Walter Oggioni <oggioni.walter@gmail.com>"]
license = "MIT" license = "MIT"
rust-version = "1.60" edition = "2024"
[dependencies] [dependencies]
trait-group = "0.1.0" trait-group = "0.1.0"
@@ -19,8 +19,3 @@ bench = false
name = "levtree_benchmark" name = "levtree_benchmark"
path = "examples/benchmark.rs" path = "examples/benchmark.rs"
[profile.release]
strip = true
lto = true
debug-assertions = false
codegen-units = 1

View File

@@ -47,6 +47,7 @@ fn main() {
for key in keys { for key in keys {
let word = &key.into_char_slice()[..]; let word = &key.into_char_slice()[..];
let results = trie.fuzzy_search::<DamerauLevenshteinDistanceCalculator>(word, 6); let results = trie.fuzzy_search::<DamerauLevenshteinDistanceCalculator>(word, 6);
println!("needle: {}", key);
for result in results { for result in results {
let word: String = trie.lineal_descendant(result.word).into_iter().collect(); let word: String = trie.lineal_descendant(result.word).into_iter().collect();
println!("distance: {}, wordkey: {}", result.distance, word); println!("distance: {}, wordkey: {}", result.distance, word);

View File

@@ -1,9 +1,9 @@
extern crate sealed; extern crate sealed;
use self::sealed::sealed; use self::sealed::sealed;
use std::collections::BTreeSet; use std::collections::BinaryHeap;
use super::keychecker::KeyChecker; use super::keychecker::KeyChecker;
use super::result::Result; use super::search_result::SearchResult;
use super::trie::Trie; use super::trie::Trie;
use super::trie::VisitOutcome; use super::trie::VisitOutcome;
use super::trienode::TrieKey; use super::trienode::TrieKey;
@@ -20,8 +20,8 @@ where
{ {
fn compute( fn compute(
workspace: &mut Vec<Vec<usize>>, workspace: &mut Vec<Vec<usize>>,
nodes: &Vec<LevTrieNode<KEY>>, nodes: &[LevTrieNode<KEY>],
stack: &Vec<usize>, stack: &[usize],
wordkey: &[KEY], wordkey: &[KEY],
worst_case: Option<usize>, worst_case: Option<usize>,
) -> VisitOutcome; ) -> VisitOutcome;
@@ -33,10 +33,10 @@ where
KEYCHECKER: KeyChecker<KEY>, KEYCHECKER: KeyChecker<KEY>,
{ {
pub fn new() -> LevTrie<KEY, KEYCHECKER> { pub fn new() -> LevTrie<KEY, KEYCHECKER> {
Trie::empty() Trie::default()
} }
pub fn from_words<T: IntoIterator, U: IntoIterator>(wordlist: U) -> LevTrie<KEY, KEYCHECKER> pub fn from_words<T, U>(wordlist: U) -> LevTrie<KEY, KEYCHECKER>
where where
T: IntoIterator<Item = KEY>, T: IntoIterator<Item = KEY>,
U: IntoIterator<Item = T>, U: IntoIterator<Item = T>,
@@ -48,14 +48,13 @@ where
result result
} }
pub fn fuzzy_search<DC>(&mut self, word: &[KEY], max_result: usize) -> BTreeSet<Result> pub fn fuzzy_search<DC>(&mut self, word: &[KEY], max_result: usize) -> Vec<SearchResult>
where where
DC: DistanceCalculator<KEY, KEYCHECKER>, DC: DistanceCalculator<KEY, KEYCHECKER>,
{ {
let word_len = word.into_iter().count(); let word_len = word.len();
let mut workspace: &mut Vec<Vec<usize>> = let workspace: &mut Vec<Vec<usize>> = &mut (0..self.nodes()).map(|_| Vec::new()).collect();
&mut (0..self.nodes()).map(|_| Vec::new()).collect(); let mut result_heap = BinaryHeap::<SearchResult>::with_capacity(max_result + 1);
let mut results = BTreeSet::new();
let required_size = word_len + 1; let required_size = word_len + 1;
let visit_pre = |stack: &Vec<usize>| -> VisitOutcome { let visit_pre = |stack: &Vec<usize>| -> VisitOutcome {
let stack_size = stack.len(); let stack_size = stack.len();
@@ -63,32 +62,33 @@ where
let payload = &mut workspace[current_node_id]; let payload = &mut workspace[current_node_id];
payload.resize(required_size, usize::default()); payload.resize(required_size, usize::default());
if stack_size == 1 { if stack_size == 1 {
for i in 0..required_size { for (i, item) in payload.iter_mut().enumerate().take(required_size) {
payload[i] = i; *item = i;
} }
} else { } else {
for i in 0..required_size { for (i, item) in payload.iter_mut().enumerate().take(required_size) {
payload[i] = if i == 0 { stack_size - 1 } else { 0 } *item = if i == 0 { stack_size - 1 } else { 0 }
} }
} }
if stack_size > 1 { if stack_size > 1 {
let current_node = &mut self.get_node(current_node_id); let current_node = &mut self.get_node(current_node_id);
if current_node.key.is_none() { if current_node.key.is_none() {
let distance = workspace[stack[stack_size - 2]][word_len]; let distance = workspace[stack[stack_size - 2]][word_len];
results.insert(Result { let search_result = SearchResult {
distance: distance, distance,
word: current_node_id, word: current_node_id,
}); };
if results.len() > max_result { result_heap.push(search_result);
results.pop_last(); if result_heap.len() > max_result {
result_heap.pop();
} }
VisitOutcome::Skip VisitOutcome::Skip
} else { } else {
let worst_case = results let worst_case = result_heap
.last() .peek()
.filter(|_| results.len() == max_result) .filter(|_| result_heap.len() == max_result)
.map(|it| it.distance); .map(|it| it.distance);
DC::compute(&mut workspace, &self.nodes, stack, word, worst_case) DC::compute(workspace, &self.nodes, stack, word, worst_case)
} }
} else { } else {
VisitOutcome::Continue VisitOutcome::Continue
@@ -96,7 +96,8 @@ where
}; };
let visit_post = |_: &Vec<usize>| {}; let visit_post = |_: &Vec<usize>| {};
self.walk(visit_pre, visit_post); self.walk(visit_pre, visit_post);
results
result_heap.into_sorted_vec()
} }
} }
@@ -110,13 +111,13 @@ where
{ {
fn compute( fn compute(
workspace: &mut Vec<Vec<usize>>, workspace: &mut Vec<Vec<usize>>,
nodes: &Vec<LevTrieNode<KEY>>, nodes: &[LevTrieNode<KEY>],
stack: &Vec<usize>, stack: &[usize],
wordkey: &[KEY], wordkey: &[KEY],
worst_case: Option<usize>, worst_case: Option<usize>,
) -> VisitOutcome { ) -> VisitOutcome {
let sz = stack.len(); let sz = stack.len();
let key_size = wordkey.into_iter().count(); let key_size = wordkey.len();
for i in 1..=key_size { for i in 1..=key_size {
if KEYCHECKER::check(Some(wordkey[i - 1]), nodes[stack[sz - 1]].key) { if KEYCHECKER::check(Some(wordkey[i - 1]), nodes[stack[sz - 1]].key) {
workspace[stack[sz - 1]][i] = workspace[stack[sz - 2]][i - 1]; workspace[stack[sz - 1]][i] = workspace[stack[sz - 2]][i - 1];
@@ -131,7 +132,7 @@ where
} }
} }
let condition = worst_case let condition = worst_case
.map(|wv| wv <= *workspace[stack[sz - 1]][..].into_iter().min().unwrap()) .map(|wv| wv <= *workspace[stack[sz - 1]][..].iter().min().unwrap())
.unwrap_or(false); .unwrap_or(false);
if condition { if condition {
VisitOutcome::Skip VisitOutcome::Skip
@@ -151,13 +152,13 @@ where
{ {
fn compute( fn compute(
workspace: &mut Vec<Vec<usize>>, workspace: &mut Vec<Vec<usize>>,
nodes: &Vec<LevTrieNode<KEY>>, nodes: &[LevTrieNode<KEY>],
stack: &Vec<usize>, stack: &[usize],
wordkey: &[KEY], wordkey: &[KEY],
worst_case: Option<usize>, worst_case: Option<usize>,
) -> VisitOutcome { ) -> VisitOutcome {
let sz = stack.len(); let sz = stack.len();
let key_size = wordkey.into_iter().count(); let key_size = wordkey.len();
for i in 1..=key_size { for i in 1..=key_size {
if KEYCHECKER::check( if KEYCHECKER::check(
Some(wordkey[i - 1]), Some(wordkey[i - 1]),
@@ -185,7 +186,7 @@ where
} }
} }
let condition = worst_case let condition = worst_case
.map(|wv| wv <= *workspace[stack[sz - 2]][..].into_iter().min().unwrap()) .map(|wv| wv <= *workspace[stack[sz - 2]][..].iter().min().unwrap())
.unwrap_or(false); .unwrap_or(false);
if condition { if condition {
VisitOutcome::Skip VisitOutcome::Skip

View File

@@ -21,5 +21,8 @@ pub use self::keychecker::KeyChecker;
pub type CaseSensitiveLevTrie = LevTrie<char, CaseSensitiveKeyChecker>; pub type CaseSensitiveLevTrie = LevTrie<char, CaseSensitiveKeyChecker>;
pub type CaseInSensitiveLevTrie = LevTrie<char, CaseInsensitiveKeyChecker>; pub type CaseInSensitiveLevTrie = LevTrie<char, CaseInsensitiveKeyChecker>;
mod result; mod search_result;
pub use self::result::Result; pub use self::search_result::SearchResult;
#[cfg(test)]
mod tests;

View File

@@ -1,27 +1,26 @@
use std::cmp::Ordering; use std::cmp::Ordering;
pub struct Result { #[derive(Clone)]
pub struct SearchResult {
pub word: usize, pub word: usize,
pub distance: usize, pub distance: usize,
} }
impl PartialOrd for Result { impl PartialOrd for SearchResult {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> { fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.distance.cmp(&other.distance)) Some(self.cmp(other))
.filter(|it| it != &Ordering::Equal)
.or_else(|| Some(self.word.cmp(&other.word)))
} }
} }
impl PartialEq for Result { impl PartialEq for SearchResult {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.distance == other.distance && self.word == other.word self.distance == other.distance && self.word == other.word
} }
} }
impl Eq for Result {} impl Eq for SearchResult {}
impl Ord for Result { impl Ord for SearchResult {
fn cmp(&self, other: &Self) -> std::cmp::Ordering { fn cmp(&self, other: &Self) -> std::cmp::Ordering {
match self.distance.cmp(&other.distance) { match self.distance.cmp(&other.distance) {
std::cmp::Ordering::Equal => self.word.cmp(&other.word), std::cmp::Ordering::Equal => self.word.cmp(&other.word),
@@ -30,21 +29,3 @@ impl Ord for Result {
} }
} }
} }
//struct Standing {
// size: usize,
// results: Vec<Result>,
//}
//
//impl Standing {
// pub fn new(size: usize) -> Standing {
// Standing {
// size,
// results: BTreeSet::new(),
// }
// }
//
// pub fn addResult(&mut self, res: Result) {
// self.results.push(res)
// }
//}

138
src/tests.rs Normal file
View File

@@ -0,0 +1,138 @@
use super::{
CaseSensitiveLevTrie, DamerauLevenshteinDistanceCalculator, KeyChecker, LevTrie,
LevenshteinDistanceCalculator, SearchResult,
};
use std::collections::BTreeMap;
use std::fmt::Display;
use std::io::Write;
struct ExpectedResults {
data: Vec<(usize, usize)>,
}
impl ExpectedResults {
fn new(id_map: &BTreeMap<String, usize>, results: &[(String, usize)]) -> ExpectedResults {
let data = results
.iter()
.map(|(key, distance)| {
(
*id_map
.get(key)
.ok_or_else(|| format!("Id not found for key '{key}'"))
.unwrap(),
*distance,
)
})
.collect::<Vec<(usize, usize)>>();
ExpectedResults { data }
}
fn check(&self, search_results: &[SearchResult]) {
for i in 0..self.data.len() {
let SearchResult { word, distance } = search_results[i];
let data = self.data[i];
if data != (word, distance) {
panic!("({}, {}) <> ({}, {})", data.0, data.1, word, distance);
}
}
}
}
/// Prints each search result as "distance, reconstructed word, node id".
///
/// The word is rebuilt from the trie by walking the lineal descendant chain
/// of `result.word`, joining key fragments with `key_separator`.
///
/// Returns any I/O error raised while building the word buffer.
fn print_search_results<T: Display + Copy, C: KeyChecker<T>>(
    trie: &LevTrie<T, C>,
    search_results: &[SearchResult],
    key_separator: &str,
) -> Result<(), std::io::Error> {
    for result in search_results {
        let mut word = Vec::<u8>::new();
        for (i, fragment) in trie.lineal_descendant(result.word).enumerate() {
            // `write_all` instead of `write`: `write` may perform a partial
            // write and its byte count was being discarded.
            if i > 0 {
                word.write_all(format!("{key_separator}{fragment}").as_bytes())?;
            } else {
                word.write_all(format!("{fragment}").as_bytes())?;
            }
        }
        println!(
            "distance: {}, wordkey: {}, id: {}",
            result.distance,
            String::from_utf8(word).unwrap(),
            result.word
        );
    }
    Ok(())
}
// Shared fixture: the word list inserted into the trie by every test.
// Insertion order matters — node ids (and thus tie-breaking among equal
// distances) follow it.
const WORDLIST: [&str; 16] = [
    "skyscraper",
    "camel",
    "coal",
    "caos",
    "copper",
    "hello",
    "Bugis",
    "Kembangan",
    "Singapore",
    "Fullerton",
    "Lavender",
    "aircraft",
    "boat",
    "ship",
    "cargo",
    "tanker",
];
/// Fuzzy-searches "coat" with the Damerau-Levenshtein calculator and checks
/// the six nearest words and their distances.
#[test]
fn test_damerau_levenshtein_strings() {
    // Insert the fixture words, remembering each word's node id.
    let mut trie: CaseSensitiveLevTrie = LevTrie::new();
    let id_map: BTreeMap<String, usize> = WORDLIST
        .iter()
        .map(|&word| {
            let (_, id) = trie.add(word.chars());
            (String::from(word), id)
        })
        .collect();
    let needle: Vec<char> = "coat".chars().collect();
    let results = trie.fuzzy_search::<DamerauLevenshteinDistanceCalculator>(&needle, 6);
    print_search_results(&trie, &results, "").unwrap();
    // Six closest matches, ordered by distance then node id.
    let expected = [
        ("coal", 1),
        ("boat", 1),
        ("caos", 3),
        ("camel", 4),
        ("copper", 4),
        ("ship", 4),
    ]
    .map(|(word, distance)| (String::from(word), distance));
    ExpectedResults::new(&id_map, &expected).check(&results);
}
/// Fuzzy-searches "coat" with the plain Levenshtein calculator and checks
/// the six nearest words and their distances.
#[test]
fn test_levenshtein_strings() {
    let mut trie: CaseSensitiveLevTrie = LevTrie::new();
    let mut id_map = BTreeMap::<String, usize>::new();
    // Populate the trie; keep each word's node id for later lookup.
    for &word in WORDLIST.iter() {
        let (_, id) = trie.add(word.chars());
        id_map.insert(String::from(word), id);
    }
    let needle = "coat".chars().collect::<Vec<char>>();
    let results = trie.fuzzy_search::<LevenshteinDistanceCalculator>(&needle, 6);
    print_search_results(&trie, &results, "").unwrap();
    // Six closest matches, ordered by distance then node id.
    let expected: Vec<(String, usize)> = [
        ("coal", 1),
        ("boat", 1),
        ("caos", 3),
        ("camel", 4),
        ("copper", 4),
        ("ship", 4),
    ]
    .iter()
    .map(|&(word, distance)| (String::from(word), distance))
    .collect();
    ExpectedResults::new(&id_map, &expected).check(&results);
}

View File

@@ -1,9 +1,9 @@
use std::collections::BTreeSet;
use std::marker::PhantomData;
use super::keychecker::KeyChecker; use super::keychecker::KeyChecker;
use super::trienode::TrieKey; use super::trienode::TrieKey;
use super::trienode::TrieNode; use super::trienode::TrieNode;
use std::collections::BTreeSet;
use std::iter::Iterator;
use std::marker::PhantomData;
pub enum VisitOutcome { pub enum VisitOutcome {
Continue, Continue,
@@ -20,28 +20,30 @@ where
tails: BTreeSet<usize>, tails: BTreeSet<usize>,
checker: PhantomData<KEYCHECKER>, checker: PhantomData<KEYCHECKER>,
} }
impl<KEY, KEYCHECKER> Default for Trie<KEY, KEYCHECKER>
where
KEY: TrieKey,
KEYCHECKER: KeyChecker<KEY>,
{
fn default() -> Self {
Trie {
nodes: vec![TrieNode::new0(None)],
tails: BTreeSet::new(),
checker: PhantomData,
}
}
}
impl<KEY, KEYCHECKER> Trie<KEY, KEYCHECKER> impl<KEY, KEYCHECKER> Trie<KEY, KEYCHECKER>
where where
KEY: TrieKey, KEY: TrieKey,
KEYCHECKER: KeyChecker<KEY>, KEYCHECKER: KeyChecker<KEY>,
{ {
pub fn empty() -> Trie<KEY, KEYCHECKER> { pub fn trie_from_words<T, U>(wordlist: U) -> Trie<KEY, KEYCHECKER>
Trie {
nodes: vec![TrieNode::new0(None)],
tails: BTreeSet::new(),
checker: PhantomData::default(),
}
}
pub fn trie_from_words<T: IntoIterator, U: IntoIterator>(
wordlist: U,
) -> Trie<KEY, KEYCHECKER>
where where
T: IntoIterator<Item = KEY>, T: IntoIterator<Item = KEY>,
U: IntoIterator<Item = T>, U: IntoIterator<Item = T>,
{ {
let mut result = Trie::empty(); let mut result = Trie::default();
for word in wordlist { for word in wordlist {
result.add(word); result.add(word);
} }
@@ -98,7 +100,7 @@ where
result_index result_index
} }
pub fn add<T: IntoIterator>(&mut self, path: T) -> (bool, usize) pub fn add<T>(&mut self, path: T) -> (bool, usize)
where where
T: IntoIterator<Item = KEY>, T: IntoIterator<Item = KEY>,
{ {
@@ -106,9 +108,7 @@ where
let mut pnode = 0; let mut pnode = 0;
'wordLoop: for key in path { 'wordLoop: for key in path {
let mut cnode = self.get_node(pnode).child; let mut cnode = self.get_node(pnode).child;
loop { while let Some(cnode_index) = cnode {
match cnode {
Some(cnode_index) => {
let cnode_node = self.get_node(cnode_index); let cnode_node = self.get_node(cnode_index);
if KEYCHECKER::check(cnode_node.key, Some(key)) { if KEYCHECKER::check(cnode_node.key, Some(key)) {
pnode = cnode_index; pnode = cnode_index;
@@ -119,11 +119,6 @@ where
cnode = self.get_node(cnode_index).next; cnode = self.get_node(cnode_index).next;
} }
} }
None => {
break;
}
}
}
pnode = self.add_node(Some(key), pnode, cnode); pnode = self.add_node(Some(key), pnode, cnode);
result = true; result = true;
} }
@@ -131,18 +126,11 @@ where
let tail = self.add_node(None, pnode, None); let tail = self.add_node(None, pnode, None);
self.tails.insert(tail); self.tails.insert(tail);
let mut node = Some(tail); let mut node = Some(tail);
loop { while let Some(n) = node {
match node {
Some(n) => {
let current_node = self.get_node_mut(n); let current_node = self.get_node_mut(n);
current_node.ref_count += 1; current_node.ref_count += 1;
node = current_node.parent; node = current_node.parent;
} }
None => {
break;
}
}
}
(true, tail) (true, tail)
} else { } else {
(false, pnode) (false, pnode)
@@ -177,28 +165,20 @@ where
result result
} }
pub fn lineal_descendant(&self, start: usize) -> Vec<&KEY> { pub fn lineal_descendant(&self, start: usize) -> impl Iterator<Item = &KEY> {
let mut chars: Vec<&KEY> = vec![]; let mut nodes: Vec<usize> = vec![];
let mut node_option = Some(start); let mut node_option = Some(start);
loop { while let Some(node) = node_option {
match node_option {
Some(node) => {
let key = &self.get_node(node).key; let key = &self.get_node(node).key;
match key { if key.is_some() {
Some(key) => { nodes.push(node);
chars.push(key);
}
None => {}
} }
node_option = self.get_node(node).parent; node_option = self.get_node(node).parent;
} }
None => { nodes
break; .into_iter()
} .rev()
} .map(|node_index| self.get_node(node_index).key.as_ref().unwrap())
}
chars.reverse();
chars
} }
pub(crate) fn walk<CB1, CB2>(&self, mut visit_pre: CB1, mut visit_post: CB2) pub(crate) fn walk<CB1, CB2>(&self, mut visit_pre: CB1, mut visit_post: CB2)
@@ -243,4 +223,3 @@ where
&self.tails &self.tails
} }
} }

View File

@@ -28,10 +28,10 @@ where
{ {
TrieNode { TrieNode {
key, key,
prev: prev, prev,
next: next, next,
child: child, child,
parent: parent, parent,
ref_count: 0, ref_count: 0,
} }
} }