diff --git a/bintree.h b/bintree.h new file mode 100644 index 0000000..521e594 --- /dev/null +++ b/bintree.h @@ -0,0 +1,455 @@ +#include + +#ifdef DUMP +#include +using std::cout; +using std::endl; +#endif + +typedef unsigned int Umword; +typedef int Mword; + +template +class Bin_tree; + +template +struct Bin_tree_node { +private: + friend class Bin_tree<_Key>; + + void recalc() + { + Umword leftHeight = + (left == nullptr) ? 0 : left->height + 1; + Umword rightHeight = + (right == nullptr) ? 0 : right->height + 1; + + balance = leftHeight - rightHeight; + height = (leftHeight < rightHeight ? rightHeight : leftHeight); + } + +public: + typedef _Key Key; + +protected: + typedef Bin_tree_node Self; + typedef Bin_tree Tree; + + constexpr Bin_tree_node(Key key) + : key(key), parent(nullptr), left(nullptr), + right(nullptr), tree(nullptr), + balance(0), height(0) + {} + + constexpr Bin_tree_node() + : Bin_tree_node(0) + {} + + Self* parent; + Self* left; + Self* right; + Tree* tree; + + // AVL Tree + Mword balance; + Umword height; + +public: + Key key; +}; + +template +class Bin_tree { +protected: + typedef Bin_tree_node Node; +private: + Node* root; + + Node* findByKey(Key key) const + { + Node* iterparent = nullptr; + Node* iter = root; + + while(iter != nullptr) + { + iterparent = iter; + + if(key < iter->key) { + iter = iter->left; + } else if(key > iter->key) { + iter = iter->right; + } else { + break; + } + } + + return iterparent; + } + + Node* findLeft(Node* start) const + { + assert(start != nullptr); + Node* iter = start; + while(iter->left != nullptr) + { + iter = iter->left; + } + return iter; + } + + void replaceInParent(Node* node, Node* replace) { + assert(node->parent); + assert(node->parent->left == node + || node->parent->right == node); + + assert(node->parent); + + if(node->parent->left == node) + { + node->parent->left = replace; + if(replace) replace->parent = node->parent; + } + else if(node->parent->right == node) + { + node->parent->right = replace; + if(replace) replace->parent = node->parent; + } + } + + void leafRemove(Node* leaf) + { + assert(leaf->left == nullptr + || leaf->right == nullptr); + + if(leaf->left == nullptr && leaf->right == nullptr) + { + if(leaf->parent != nullptr) + { + assert(leaf->parent->left == leaf + || leaf->parent->right == leaf); + replaceInParent(leaf, nullptr); + rebalance(leaf->parent, 1, 2); + } + else + root = nullptr; + } + else if(leaf->left == nullptr) + { + if(leaf->parent != nullptr) + { + assert(leaf->parent->left == leaf + || leaf->parent->right == leaf); + replaceInParent(leaf, leaf->right); + rebalance(leaf->parent, 1, 2); + } + else + { + root = leaf->right; + leaf->right->parent = nullptr; + } + } + else if(leaf->right == nullptr) + { + if(leaf->parent != nullptr) + { + replaceInParent(leaf, leaf->left); + rebalance(leaf->parent, 1, 2); + } + else + { + root = leaf->left; + leaf->left->parent = nullptr; + } + } + + leaf->parent = nullptr; + leaf->left = nullptr; + leaf->right = nullptr; + leaf->tree = nullptr; + + leaf->recalc(); + } + + void + rebalance(Node* start, Mword abortCond, Mword rebalanceCond) + { + for(Node* iter = start; iter != nullptr; iter = iter->parent) + { + iter->recalc(); + + Mword absBalance = + (iter->balance < 0) ? -(iter->balance) : iter->balance; + + if(absBalance == abortCond) + { + break; + } + else if(absBalance == rebalanceCond) + { + // Left side is taller + if(iter->balance > 0) + { + // Left side's right side is taller + if(iter->left && iter->left->balance < 0) + { + rotateLeft(iter->left); + } + rotateRight(iter); + } + // Right side is taller + else + { + // Right side's left side is taller + if(iter->right && iter->right->balance > 0) + { + rotateRight(iter->right); + } + rotateLeft(iter); + } + + // then abort + break; + } + } + } + + void rotateLeft(Node* node) { + assert(node); + assert(node->right); + + Node* partner = node->right; + + node->right = partner->left; + + if(node->right) + node->right->parent = node; + + // No parent, partner is the new root + if(node->parent == nullptr) + root = partner; + else + replaceInParent(node, partner); + + partner->parent = node->parent; + node->parent = partner; + + partner->left = node; + + assert(node->parent == nullptr + || node->parent->left == node + || node->parent->right == node); + + assert(partner->parent == nullptr + || partner->parent->left == partner + || partner->parent->right == partner); + + assert(node->left == nullptr + || node->left->parent == node); + + assert(node->right == nullptr + || node->right->parent == node); + + assert(partner->left == nullptr + || partner->left->parent == partner); + + assert(partner->right == nullptr + || partner->right->parent == partner); + + node->recalc(); + partner->recalc(); + } + + void rotateRight(Node* node) { + assert(node); + assert(node->left); + + Node* partner = node->left; + + node->left = partner->right; + + if(node->left) + node->left->parent = node; + + // No parent, partner is the new root + if(node->parent == nullptr) + root = partner; + else + replaceInParent(node,partner); + + partner->parent = node->parent; + node->parent = partner; + + partner->right = node; + + assert(node->parent == nullptr + || node->parent->left == node + || node->parent->right == node); + + assert(partner->parent == nullptr + || partner->parent->left == partner + || partner->parent->right == partner); + + assert(node->left == nullptr + || node->left->parent == node); + + assert(node->right == nullptr + || node->right->parent == node); + + assert(partner->left == nullptr + || partner->left->parent == partner); + + assert(partner->right == nullptr + || partner->right->parent == partner); + + node->recalc(); + partner->recalc(); + + } + + #ifdef DUMP + void dump(Node* node, int indent = 0) + { + if(node == nullptr) return; + assert(node->tree == this); + + assert(node->left == nullptr + || node->left->parent == node); + dump(node->left, indent+1); + for(int i = 0; i < indent; i++) + cout << " "; + cout << node->key << endl; + + assert(node->right == nullptr + || node->right->parent == node); + dump(node->right, indent+1); + } + #endif + +public: + constexpr Bin_tree() : root(nullptr) + {} + + bool insert(Node* node) + { + assert(node->tree == nullptr); + + if(root == nullptr) + { + root = node; + root->parent = nullptr; + root->left = nullptr; + root->right = nullptr; + root->tree = this; + return true; + } + + Node* parent = findByKey(node->key); + + if(node->key < parent->key) + { + parent->left = node; + } + else if(node->key > parent->key) + { + parent->right = node; + } + else + { + return false; + } + + node->parent = parent; + node->left = nullptr; + node->right = nullptr; + node->tree = this; + node->height = 0; + node->balance = 0; + + rebalance(parent,0,2); + + return true; + } + + void remove(Node* node) + { + assert(node); + assert(node->tree == this); + + if(node->left == nullptr || node->right == nullptr) + { + leafRemove(node); + return; + } + + Node* replacement = findLeft(node->right); + + leafRemove(replacement); + + if(node == root) + { + root = replacement; + replacement->parent = nullptr; + } + else + replaceInParent(node, replacement); + + replacement->left = node->left; + replacement->right = node->right; + replacement->tree = this; + + if(node->left) + node->left->parent = replacement; + + if(node->right) + node->right->parent = replacement; + + replacement->recalc(); + + node->left = nullptr; + node->right = nullptr; + node->parent = nullptr; + node->tree = nullptr; + } + + Node* lookup(Key key) const { + Node* node = findByKey(key); + + if(node == nullptr || node->key != key) + return nullptr; + + return node; + } + + #ifdef DUMP + void dump() + { + cout << "---------------------------------------------" << endl; + dump(root); + cout << "---------------------------------------------" << endl; + } + #endif +}; + +template +class Bin_tree_t : Bin_tree { +private: + typedef Bin_tree Base; +public: + inline + bool insert(T* node) + { return Base::insert(static_cast(node)); } + + inline + void remove(T* node) + { Base::remove(static_cast(node)); } + + inline + T* lookup(Key key) + { return static_cast(Base::lookup(key)); } + + #ifdef DUMP + inline + void dump() + { Base::dump(); } + #endif +};