diff --git a/bintree.h b/bintree.h index 1e6ba7d..2b3a940 100644 --- a/bintree.h +++ b/bintree.h @@ -12,17 +12,9 @@ typedef int Mword; template class Bin_tree; -template struct Bin_tree_node { -public: - typedef _Key Key; - -protected: - typedef Bin_tree_node Self; - typedef Bin_tree Tree; - private: - friend class Bin_tree<_Key>; + template friend class Bin_tree; void recalc() { @@ -35,7 +27,7 @@ private: height = (leftHeight < rightHeight ? rightHeight : leftHeight); } - void replaceChild(Self* child, Self* replace) + void replaceChild(Bin_tree_node* child, Bin_tree_node* replace) { assert(child == left || child == right); @@ -52,38 +44,61 @@ private: recalc(); } -public: - constexpr Bin_tree_node(Key key) - : key(key), parent(nullptr), left(nullptr), +protected: + constexpr Bin_tree_node() + : parent(nullptr), left(nullptr), right(nullptr), tree(nullptr), balance(0), height(0) {} - constexpr Bin_tree_node() - : Bin_tree_node(0) - {} - - Self* parent; - Self* left; - Self* right; - Tree* tree; +private: + Bin_tree_node* parent; + Bin_tree_node* left; + Bin_tree_node* right; + void* tree; // AVL Tree Mword balance; Umword height; - -public: - Key key; }; -template +template +class Bin_tree_node_t : public Bin_tree_node { +public: + class Key_trait { + public: + typedef Key Key_type; + + static inline + Key_type get_key(Bin_tree_node* node) + { return static_cast* >(node)->key(); } + + static inline + bool compare(Key_type a, Key_type b) + { return (a < b); } + }; + + Bin_tree_node_t(Key k) : Bin_tree_node(), _key(k) + {} + + inline + Key& key() + { return _key; } + +private: + Key _key; + +}; + +template class Bin_tree { protected: - typedef Bin_tree_node Node; + typedef Bin_tree_node Node; + typedef typename Key_trait::Key_type Key_type; private: Node* root; - Node* findByKey(Key key) const + Node* findByKey(Key_type key) const { Node* iterparent = nullptr; Node* iter = root; @@ -92,13 +107,12 @@ private: { iterparent = iter; - if(key < iter->key) { + if(Key_trait::compare(key, Key_trait::get_key(iter))) iter = iter->left; - } else if(key > iter->key) { + else if(Key_trait::compare(Key_trait::get_key(iter), key)) iter = iter->right; - } else { + else break; - } } return iterparent; @@ -116,25 +130,6 @@ private: return iter; } - // void replaceInParent(Node* node, Node* replace) { - // assert(node->parent); - // assert(node->parent->left == node - // || node->parent->right == node); - - // assert(node->parent); - - // if(node->parent->left == node) - // node->parent->left = replace; - // else if(node->parent->right == node) - // node->parent->right = replace; - - // if(replace) - // { - // replace->parent = node->parent; - // replace->recalc(); - // } - // } - void leafRemove(Node* leaf) { assert(leaf); @@ -164,43 +159,6 @@ private: child->parent = nullptr; } - // if(leaf->left == nullptr && leaf->right == nullptr) - // { - // if(leaf->parent != nullptr) - // { - // replaceInParent(leaf, nullptr); - // rebalance(leaf->parent, 1, 2); - // } - // else - // root = nullptr; - // } - // else if(leaf->left == nullptr) - // { - // if(leaf->parent != nullptr) - // { - // replaceInParent(leaf, leaf->right); - // rebalance(leaf->parent, 1, 2); - // } - // else - // { - // root = leaf->right; - // leaf->right->parent = nullptr; - // } - // } - // else if(leaf->right == nullptr) - // { - // if(leaf->parent != nullptr) - // { - // replaceInParent(leaf, leaf->left); - // rebalance(leaf->parent, 1, 2); - // } - // else - // { - // root = leaf->left; - // leaf->left->parent = nullptr; - // } - // } - leaf->parent = nullptr; leaf->left = nullptr; leaf->right = nullptr; @@ -270,7 +228,6 @@ private: else node->parent->replaceChild(node, partner); - // partner->parent = node->parent; node->parent = partner; partner->left = node; @@ -318,7 +275,6 @@ private: else node->parent->replaceChild(node,partner); - // partner->parent = node->parent; node->parent = partner; partner->right = node; @@ -359,7 +315,7 @@ private: dump(node->left, indent+1); for(int i = 0; i < indent; i++) cout << " "; - cout << node->key << endl; + cout << Key_trait::get_key(node) << endl; assert(node->right == nullptr || node->right->parent == node); @@ -383,12 +339,14 @@ public: else { // else, find a suitable parent - parent = findByKey(node->key); + parent = findByKey(Key_trait::get_key(node)); assert(parent); - if(node->key < parent->key) + if(Key_trait::compare(Key_trait::get_key(node), + Key_trait::get_key(parent))) parent->left = node; - else if(node->key > parent->key) + else if(Key_trait::compare(Key_trait::get_key(parent), + Key_trait::get_key(node))) parent->right = node; else // Node with same key exists return false; @@ -451,7 +409,7 @@ public: node->recalc(); } - Node* lookup(Key key) const + Node* lookup(Key_type key) const { if(root == nullptr) return nullptr; @@ -461,7 +419,8 @@ public: if(node == nullptr) __builtin_unreachable(); - if(node->key != key) + if(Key_trait::compare(Key_trait::get_key(node),key) + || Key_trait::compare(key, Key_trait::get_key(node))) return nullptr; return node; @@ -477,10 +436,12 @@ public: #endif }; -template -class Bin_tree_t : Bin_tree { +template +class Bin_tree_t : Bin_tree { private: - typedef Bin_tree Base; + typedef Bin_tree Base; + typedef typename Key_trait::Key_type Key_type; + public: inline bool insert(T* node) @@ -491,7 +452,7 @@ public: { Base::remove(static_cast(node)); } inline - T* lookup(Key key) + T* find(Key_type key) { return static_cast(Base::lookup(key)); } #ifdef DUMP