view src/lib.rs @ 11:b3cd822e1983 default tip

Better deletion but still with crucial bug
author Lewin Bormann <lbo@spheniscida.de>
date Sat, 09 Jul 2022 15:40:10 -0700
parents eec91e90ede7
children
line wrap: on
line source

use ptree::TreeBuilder;
use std::cmp::Ordering;

// Maximum number of children per node; each node holds at most ORDER - 1 entries.
// Minimum supported order is 4: with ORDER = 3, `split` moves no entries into the right
// half, so splitting a full node for a value smaller than the pivot leaves an empty node.
const ORDER: usize = 4;

/// A single B-tree node, addressed by its index in the tree's node vector.
///
/// Occupied entries are kept contiguous from slot 0 in ascending order.
/// `children[i]` leads to values ordered before `entries[i]`, and
/// `children[i + 1]` to values ordered after it. A leaf has no children.
#[derive(Debug)]
struct Node<T> {
    /// Sorted entry slots; unused slots are `None`.
    entries: [Option<T>; ORDER - 1],
    /// Child node indices; unused slots are `None`.
    children: [Option<usize>; ORDER],
}

impl<T: Ord> Node<T> {
    fn new() -> Node<T> {
        Node {
            entries: Default::default(),
            children: Default::default(),
        }
    }
}

/// Supplies the total order a `BTree` uses to arrange its elements.
trait Compare<T> {
    /// Returns how `v1` is ordered relative to `v2`.
    fn cmp(v1: &T, v2: &T) -> Ordering;
}

/// Comparator that uses the element type's own `Ord` implementation.
struct StandardCompare;

impl<T: Ord> Compare<T> for StandardCompare {
    fn cmp(v1: &T, v2: &T) -> Ordering {
        Ord::cmp(v1, v2)
    }
}

/// A B-tree of unique `T` values, ordered by the comparator `Cmp`.
///
/// Nodes are stored in `v` and point at each other by index; `root` is the
/// index of the current root node. `Cmp` only exists at the type level.
#[derive(Debug)]
struct BTree<T, Cmp: Compare<T>> {
    /// Node storage, indexed by the child pointers in `Node::children`.
    v: Vec<Node<T>>,
    /// Index into `v` of the root node.
    root: usize,
    /// Zero-sized marker binding the comparator type to the tree.
    cmp: std::marker::PhantomData<Cmp>,
}

impl<T: std::fmt::Debug + Ord> BTree<T, StandardCompare> {
    /// Creates an empty tree that orders elements by `T`'s own `Ord`.
    ///
    /// `initial_capacity` is the number of *nodes* to reserve room for in the
    /// backing vector; no node exists until the first insertion.
    fn new(initial_capacity: usize) -> BTree<T, StandardCompare> {
        let v = Vec::with_capacity(initial_capacity);
        BTree {
            v,
            root: 0,
            cmp: std::marker::PhantomData,
        }
    }
}

impl<T: std::fmt::Debug + Ord, Cmp: Compare<T>> BTree<T, Cmp> {
    fn free(&self, node: usize) -> usize {
        self.v[node]
            .entries
            .iter()
            .map(|e| if e.is_some() { 0 } else { 1 })
            .sum()
    }
    fn occupied(&self, node: usize) -> usize {
        self.v[node]
            .entries
            .iter()
            .map(|e| if e.is_some() { 1 } else { 0 })
            .sum()
    }
    fn get(&self, node: usize, elem: usize) -> &Option<T> {
        &self.v[node].entries[elem]
    }
    fn get_mut(&mut self, node: usize, elem: usize) -> &mut Option<T> {
        &mut self.v[node].entries[elem]
    }
    fn get_child(&self, node: usize, child: usize) -> &Option<usize> {
        &self.v[node].children[child]
    }
    fn get_child_mut(&mut self, node: usize, child: usize) -> &mut Option<usize> {
        &mut self.v[node].children[child]
    }

    fn count(&self) -> usize {
        self.count_from(self.root)
    }
    fn count_in(&self, node: usize) -> usize {
        self.v[node].entries.iter().filter(|e| e.is_some()).count()
    }
    fn count_from(&self, node: usize) -> usize {
        let mut s = self.occupied(node);
        for i in 0..ORDER {
            if let Some(ch) = self.v[node].children[i] {
                s += self.count_from(ch);
            }
        }
        s
    }

    pub fn find<'a>(&'a self, value: &T) -> Option<&'a T> {
        self.find_from(self.root, value).map(|e| e.0)
    }
    // Returns reference to value if found, and node/entry indices.
    fn find_from<'a>(&'a self, node: usize, value: &T) -> Option<(&'a T, usize, usize)> {
        let mut j = 0;
        // Break once we arrived at the entry right of the correct child, if we don't return
        // earlier.

        if false {
            // Comment this in if binary search within nodes makes sense. Usually it takes longer than a contiguous scan.
            j = self.binary_node_search(node, value);
            if j < ORDER - 1 {
                if let Some(el) = self.get(node, j) {
                    if Cmp::cmp(value, el) == Ordering::Equal {
                        return Some((el, node, j));
                    }
                }
            }
        } else {
            for i in 0..ORDER {
                if i < ORDER - 1 {
                    if let Some(el) = self.get(node, i) {
                        match Cmp::cmp(value, el) {
                            Ordering::Less => {
                                j = i;
                                break;
                            }
                            Ordering::Equal => return Some((el, node, i)),
                            Ordering::Greater => continue,
                        }
                    } else {
                        j = i;
                        break;
                    }
                } else {
                    j = i;
                    break;
                }
            }
        }
        if let Some(ch) = self.get_child(node, j) {
            return self.find_from(*ch, value);
        }
        None
    }

    // Find equal or next-greater element in node array.
    fn binary_node_search(&self, node: usize, value: &T) -> usize {
        let mut lo = 0;
        let mut hi = ORDER - 2 - self.free(node);
        let mut mid = hi / 2;

        while lo < hi {
            if let Some(el) = self.get(node, mid) {
                match Cmp::cmp(value, el) {
                    Ordering::Less => {
                        hi = mid;
                        mid = lo + (hi - lo) / 2;
                        continue;
                    }
                    Ordering::Equal => return mid,
                    Ordering::Greater => {
                        lo = mid + 1; // We already know that mid cannot be a potential result; skip it.
                        mid = lo + (hi - lo) / 2;
                        continue;
                    }
                }
            } else {
                // This shouldn't occur in a contiguous node.
                unimplemented!();
            }
        }
        if let Some(el) = self.get(node, lo) {
            if Ordering::Greater == Cmp::cmp(value, el) {
                return lo + 1;
            }
        }
        return lo;
    }

    /// Insert a value into an existing node, preserving the order.
    /// Requires that at least one slot be free.
    fn insert_inline(
        &mut self,
        node: usize,
        value: T,
        left: Option<usize>,
        right: Option<usize>,
    ) -> usize {
        assert!(self.free(node) >= 1);
        let mut loc = ORDER - 2;

        let node = &mut self.v[node];
        for i in 0..ORDER - 1 {
            if let Some(ref el) = node.entries[i] {
                if Cmp::cmp(&value, el) == Ordering::Less {
                    // Found right location. Shift remaining elements.
                    loc = i;
                    break;
                }
            } else {
                loc = i;
                break;
            }
        }

        for i in 0..(ORDER - 2 - loc) {
            node.entries[ORDER - 2 - i] = node.entries[ORDER - 3 - i].take();
            node.children[ORDER - 1 - i] = node.children[ORDER - 2 - i].take();
        }
        node.entries[loc] = Some(value);
        node.children[loc] = left;
        node.children[loc + 1] = right;

        loc
    }

    /// Split a node, returning the IDs of the two new nodes and the pivot value.
    fn split(&mut self, node: usize) -> (usize, T, usize) {
        assert!(self.free(node) <= 1);
        // `node` will become left node.
        let leftix = node;
        let rightix = self.v.len();

        let mut right = Node::<T>::new();

        let pivotix = (ORDER - 1) / 2;
        let pivot = self.get_mut(node, pivotix).take();

        // Pivot-left child remains in left node.
        for i in pivotix + 1..ORDER {
            right.children[i - pivotix - 1] = self.get_child_mut(leftix, i).take();
        }
        // Make sure to move all children pointers!
        //right.children[ORDER - pivotix] = self.get_child_mut(leftix, ORDER - 1).take();
        for i in pivotix + 1..ORDER - 1 {
            right.entries[i - pivotix - 1] = self.get_mut(leftix, i).take();
        }

        self.v.push(right);
        (leftix, pivot.unwrap(), rightix)
    }

    /// Returns true if the given node is a leaf (doesn't have children).
    fn is_leaf(&self, node: usize) -> bool {
        self.v[node].children.iter().all(|c| c.is_none())
    }

    /// Insert a value into the BTree.
    pub fn insert(&mut self, value: T) -> bool {
        if let Some((l, pivot, r)) = self.insert_into(self.root, value) {
            // Split root
            let mut newroot = Node::new();
            let rootix = self.v.len();
            newroot.children[0] = Some(l);
            newroot.children[1] = Some(r);
            newroot.entries[0] = Some(pivot);

            self.root = rootix;

            self.v.push(newroot);
            true
        } else {
            true
        }
    }

    /// Insert value recursively. If a split occurs, returns (left, pivot, right) tuple.
    fn insert_into(&mut self, node: usize, value: T) -> Option<(usize, T, usize)> {
        // BTree initialized?
        if self.v.len() == 0 {
            self.v.push(Node::new());
            self.insert_inline(self.root, value, None, None);
            return None;
        }
        // Leaf node with free space? We are finished.
        if self.is_leaf(node) && self.free(node) >= 1 {
            self.insert_inline(node, value, None, None);
            return None;
        } else if self.is_leaf(node) && self.free(node) < 1 {
            // Leaf node but full - needs to be split.
            let (l, pivot, r) = self.split(node);

            // Insert value into left or right subtree.
            match Cmp::cmp(&value, &pivot) {
                Ordering::Less => assert!(self.insert_into(l, value).is_none()),
                Ordering::Greater => assert!(self.insert_into(r, value).is_none()),
                _ => panic!("Attempting to insert duplicate element"),
            };
            // Return fragments upwards to be inserted.
            return Some((l, pivot, r));
        } else {
            // Descend into node.

            let mut split = None;
            // Find correct child to descend into.
            for i in 0..ORDER {
                if i < ORDER - 1 {
                    if let Some(el) = self.get(node, i) {
                        match Cmp::cmp(&value, el) {
                            Ordering::Less => {
                                split = self.insert_into(self.get_child(node, i).unwrap(), value);
                                break;
                            }
                            Ordering::Equal => {
                                panic!("Attempting to insert duplicate element {:?}", value)
                            }
                            Ordering::Greater => continue,
                        }
                    } else {
                        if let Some(ch) = self.get_child(node, i) {
                            split = self.insert_into(*ch, value);
                            break;
                        }
                        panic!(
                            "Internal node missing child? {} => {:?}",
                            node, self.v[node]
                        );
                    }
                } else {
                    if let Some(ch) = self.get_child(node, i) {
                        split = self.insert_into(*ch, value);
                        break;
                    } else {
                        // This code is dead. Protect it, later remove it:
                        panic!("This code should not be used");

                        let (l, pivot, r) = self.split(node);

                        // Insert value into left or right subtree.
                        match Cmp::cmp(&value, &pivot) {
                            Ordering::Less => assert!(self.insert_into(l, value).is_none()),
                            Ordering::Greater => assert!(self.insert_into(r, value).is_none()),
                            _ => panic!("Attempting to insert duplicate element"),
                        };

                        return Some((l, pivot, r));
                    }
                }
            }

            // Merge split if one happened below us.
            if let Some((l, pivot, r)) = split {
                if self.free(node) >= 1 {
                    // We can insert the split here.
                    self.insert_inline(node, pivot, Some(l), Some(r));
                    return None;
                } else {
                    // Split this node too.
                    let (nl, npivot, nr) = self.split(node);
                    match pivot.cmp(&npivot) {
                        Ordering::Less => self.insert_inline(nl, pivot, Some(l), Some(r)),
                        Ordering::Greater => self.insert_inline(nr, pivot, Some(l), Some(r)),
                        _ => panic!("Attempting to insert duplicate element"),
                    };
                    return Some((nl, npivot, nr));
                }
            } else {
                return None;
            }
        }
    }

    pub fn delete(&mut self, entry: &T) -> Option<T> {
        match self.find_from(self.root, entry) {
            None => None,
            Some((_, node, nodeentry)) => {
                let d = self.yank(self.root, entry);
                d
            }
        }
    }

    // Yank-Refill-Rebalance strategy:
    // 1. find element and remove it, leaving a hole in a node.
    // 2. Fill hole by removing an element from one of the children, creating a new hole below.
    //    Recursively bubble down the whole until a leaf is reached.
    // 3. Rebalance: check if either child is below half capacity and rotate or merge nodes.

    fn yank(&mut self, node: usize, entry: &T) -> Option<T> {
        println!("yank {}", node);
        for i in 0..ORDER-1 {
            if let Some(val) = self.get(node, i) {
                match Cmp::cmp(entry, val) {
                    Ordering::Less => {
                        let d = self.yank(self.get_child(node, i).unwrap(), entry);
                        self.rebalance(node, i);
                        return d;
                    },
                    Ordering::Equal => {
                        let d = self.pop(node, i);
                        self.rebalance(node, i);
                        return d;
                    },
                    Ordering::Greater => {
                        // Last element or no more entries
                        if i == ORDER-2 || self.get(node, i+1).is_none() {
                            let d = self.yank(self.get_child(node, i+1).unwrap(), entry);
                            self.rebalance(node, i);
                            return d;
                        } else {
                            continue
                        }
                    }
                }
            }
        }
        None
    }

    // Fill hole in node at entryix by rotating or merging children's elements.
    fn pop(&mut self, node: usize, entryix: usize) -> Option<T> {
        println!("pop {}/{}", node, entryix);
        if self.is_leaf(node) {
            let d = self.get_mut(node, entryix).take();
            for i in entryix..ORDER-2 {
                *self.get_mut(node, i) = self.get_mut(node, i+1).take();
            }
            return d;
        }
        // Invariant: every entry has two children unless in leaf node.
        let (left, right) = (self.get_child(node, entryix).unwrap(), self.get_child(node, entryix+1).unwrap());
        let (leftcount, rightcount) = (self.count_in(left), self.count_in(right));
        // Replace with element from fuller child.
        let d = self.get_mut(node, entryix).take();
        if leftcount > rightcount {
            *self.get_mut(node, entryix) = self.pop(left, leftcount-1);
        } else {
            *self.get_mut(node, entryix) = self.pop(right, 0);
        }
        // TODO: illegal "crossing"!
        d
    }

    fn rebalance(&mut self, node: usize, entryix: usize) {
        if self.is_leaf(node) {
            return;
        }
        println!("rebalance {}/{}", node, entryix);
        let (left, right) = (self.get_child(node, entryix).unwrap(), self.get_child(node, entryix+1).unwrap());
        let (leftcount, rightcount) = (self.count_in(left), self.count_in(right));
        if leftcount >= ORDER/2 && rightcount >= ORDER/2 {
            return;
        }
        assert!(leftcount < ORDER-1 || rightcount < ORDER-1);
        // Merge or rotate?
        if leftcount + rightcount + 1 > ORDER-1 {
            assert!(leftcount >= ORDER/2 || rightcount >= ORDER/2);
            // Rotate
            let piv = self.get_mut(node, entryix).take();
            if leftcount > rightcount {
                // Move one element from left to right child.
                println!("before {:?}", self.v[right]);

                *self.get_child_mut(right, ORDER-1) = self.get_child_mut(right, ORDER-2).take();
                for i in (1..ORDER-1).rev() {
                    // Shuffle right node.
                    *self.get_child_mut(right, i) = self.get_child_mut(right, i-1).take();
                    *self.get_mut(right, i) = self.get_mut(right, i-1).take();
                }
                *self.get_mut(right, 0) = piv;


                *self.get_child_mut(right, 0) = self.get_child_mut(left, leftcount).take();
                *self.get_mut(node, entryix) = self.get_mut(left, leftcount-1).take();
                println!("after {:?}", self.v[right]);
            } else {
                println!("before {:?}", self.v[left]);
                *self.get_mut(left, leftcount) = piv;
                *self.get_child_mut(left, leftcount+1) = self.get_child_mut(right, 0).take();
                //*self.get_child_mut(node, entryix) = self.get_child_mut(right, 0).take();
                *self.get_mut(node, entryix) = self.get_mut(right, 0).take();

                for i in 0..ORDER-2 {
                    // Shuffle right node.
                    *self.get_child_mut(right, i) = self.get_child_mut(right, i+1).take();
                    *self.get_mut(right, i) = self.get_mut(right, i+1).take();
                }
                println!("after: {:?}", self.v[left]);
                *self.get_child_mut(right, ORDER-2) = self.get_child_mut(right, ORDER-1).take();
            }
        } else {
            // Merge
            let piv = self.get_mut(node, entryix).take();

            let mut newnode = Node::new();
            let mut count = 0;
            for i in 0..ORDER-1 {
                newnode.entries[i] = self.get_mut(left, i).take();
                newnode.children[i] = self.get_child_mut(left, i).take();
                if newnode.entries[i].is_none() {
                    break
                }
                count += 1;
            }
            // Copy right-most child
            if count == ORDER-1 {
                panic!("Encountered full left node in merge!");
                // newnode.children[ORDER-1] = self.get_child_mut(left, ORDER-1).take();
            }
            newnode.entries[count] = piv;
            count += 1;
            let off = count;
            for i in off..ORDER-1 {
                newnode.entries[i] = self.get_mut(right, i-off).take();
                newnode.children[i] = self.get_child_mut(right, i-off).take();
                if newnode.entries[i].is_none() {
                    break;
                }
                count += 1;
            }
            // Copy right-most child.
            newnode.children[count] = self.get_child_mut(right, count-off).take();

            // TODO: recycle nodes.
            let newnodeix = self.v.len();
            self.v.push(newnode);

            // Replace lchild-piv-rchild in `node`.
            *self.get_child_mut(node, entryix) = Some(newnodeix);
            for i in entryix..ORDER-2 {
                *self.get_mut(node, i) = self.get_mut(node, i+1).take();
                if i > entryix { // Discard former right node.
                    *self.get_child_mut(node, i) = self.get_child_mut(node, i+1).take();
                }
            }
            *self.get_child_mut(node, ORDER-2) = self.get_child_mut(node, ORDER-1).take();

            // Replace root if there is no active entry left.
            if node == self.root && self.count_in(node) == 0 {
                self.root = newnodeix;
            }
        }
    }
}
impl<T: std::fmt::Debug, Cmp: Compare<T>> BTree<T, Cmp> {
    /// Render the tree as a text diagram (via `ptree`) for debugging.
    ///
    /// Nodes are labeled `<index>` (their position in `self.v`). Each node's entry slots,
    /// including empty `None` slots, are listed in order and interleaved with its child
    /// subtrees.
    fn debug_print(&self) -> String {
        let mut t = TreeBuilder::new(format!("<{}>", self.root));
        // A tree that was never inserted into has no nodes; print just the root label.
        if !self.v.is_empty() {
            self.print_subtree(&mut t, self.root);
        }
        let mut o: Vec<u8> = Vec::new();
        ptree::write_tree(&t.build(), &mut o).expect("writing tree to an in-memory buffer");
        String::from_utf8(o).unwrap()
    }
    /// Add `node`'s entries and (recursively) its children to `tb`.
    fn print_subtree(&self, tb: &mut TreeBuilder, node: usize) {
        let node = &self.v[node];
        for i in 0..ORDER {
            if let Some(ch) = node.children[i] {
                self.print_subtree(tb.begin_child(format!("<{}>", ch)), ch);
                tb.end_child();
            }
            if i < ORDER - 1 {
                tb.add_empty_child(format!("{:?}", node.entries[i]));
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::BTree;
    use rand::{rngs::StdRng, RngCore, SeedableRng};

    /// Fixed key set shared by the deterministic insert/delete tests.
    const NUMBERS: [i32; 25] = [
        15, 2, 21, 35, 4, 62, 38, 27, 88, 1, 12, 45, 11, 16, 23, 20, 22, 24, 25, 26, 28, 29, 30,
        31, 32,
    ];

    #[test]
    fn basic_test() {
        let mut bt = BTree::new(16);
        for n in &NUMBERS {
            bt.insert(*n);
        }
        for (i, e) in bt.v.iter().enumerate() {
            println!("{} {:?}", i, e);
        }
        assert_eq!(bt.count(), NUMBERS.len());
    }

    #[test]
    fn test_delete() {
        let mut bt = BTree::new(16);
        for n in &NUMBERS {
            bt.insert(*n);
        }
        for (deleted, e) in NUMBERS.iter().enumerate() {
            println!("\n\ndeleting {}", e);
            println!("{}", bt.debug_print());
            assert_eq!(bt.delete(e), Some(*e));
            // Deleting the same key again must find nothing.
            assert_eq!(bt.delete(e), None);
            // Every key not yet deleted must still be counted and reachable.
            let rest = &NUMBERS[deleted + 1..];
            assert_eq!(bt.count(), rest.len());
            for r in rest {
                assert_eq!(bt.find(r), Some(r));
            }
        }
    }

    #[test]
    fn randomized_all_reachable() {
        time_test::time_test!("10 x 100 inserts");
        for seed in &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10] {
            let mut rng = StdRng::seed_from_u64(*seed);
            let n = 100;
            let mut bt = BTree::new(n / 2);
            for i in 0..n {
                let rn = rng.next_u32();
                bt.insert(rn);
                if i + 1 != bt.count() {
                    println!("inserting {}", rn);
                    println!("seed = {}\n{}", seed, bt.debug_print());
                }
                assert_eq!(i + 1, bt.count());
            }
        }
    }

    #[test]
    fn test_find() {
        time_test::time_test!("1000 x insert/find");
        let mut rng = StdRng::seed_from_u64(1);
        let n = 1000;
        let mut items = Vec::with_capacity(n);
        let mut bt = BTree::new(n / 3);
        for _ in 0..n {
            let rn = rng.next_u32();
            bt.insert(rn);
            items.push(rn);

            for rn in items.iter() {
                assert_eq!(bt.find(rn), Some(rn));
            }
        }

        for rn in items.iter() {
            assert_eq!(bt.find(rn), Some(rn));
        }
    }

    #[test]
    fn test_find_sequential() {
        time_test::time_test!("100 x insert/find sequential");
        let n = 100;
        let mut bt = BTree::new(n / 3);
        for i in 0..n {
            bt.insert(i);
        }

        for i in 0..n {
            assert_eq!(bt.find(&i), Some(&i));
        }
    }

    #[test]
    fn randomized_stdbtree_test() {
        time_test::time_test!("10 x 100 inserts (std BTreeSet)");
        for seed in &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10] {
            let mut rng = StdRng::seed_from_u64(*seed);
            let n = 100;
            let mut bt = std::collections::BTreeSet::new();
            for i in 0..n {
                let rn = rng.next_u32();
                bt.insert(rn);
                assert_eq!(i + 1, bt.len());
            }
        }
    }
}