class Node(object):
    """Binary tree node: an integer payload plus left/right child links."""

    def __init__(self, data):
        # The payload is always stored as an int; int() raises for values it
        # cannot convert.
        self.data = int(data)
        # Both children start out empty and are exposed via the properties
        # defined below.
        self._left = None
        self._right = None

    @property
    def left(self):
        """The left child node, or None when there is none."""
        return self._left

    @left.setter
    def left(self, child):
        self._left = child

    @property
    def right(self):
        """The right child node, or None when there is none."""
        return self._right

    @right.setter
    def right(self, child):
        self._right = child

    def __repr__(self):
        return 'Node({})'.format(self.data)
class BinarySearchTree(object):
    """Unbalanced binary search tree whose nodes are ordered by ``node.data``.

    Nodes are built with ``Node`` unless a ``node_constructor`` callable is
    passed to ``insert``.  Keys are unique: inserting a key that is already
    present overwrites that node's ``data`` instead of adding a duplicate.

    Insert, remove, min and max follow a single root-to-leaf path, so they
    cost O(h) for tree height h: O(log N) while the tree happens to be
    balanced, degrading to O(N) (e.g. after inserting keys in sorted order).
    """

    def __init__(self, root=None):
        self.root = root
        # Order used by traversal()/search() when no explicit order is given.
        self.search_mode = 'in_order'

    # O(logN) time complexity if balanced, it could reduce to O(N)
    def insert(self, data, **kwargs):
        """Insert ``data`` starting from the root.

        An empty tree gets a new root node.  Accepts ``node_constructor=``
        (see ``insert_node``).
        """
        # insert_node returns the (possibly new) subtree root; storing it is
        # what lets the first insert into an empty tree create the root.
        self.root = BinarySearchTree.insert_node(self.root, data, **kwargs)

    # O(logN) time complexity if balanced, it could reduce to O(N)
    def remove(self, data):
        """Remove the node holding ``data``, starting from the root.

        Does nothing if ``data`` is not in the tree.  ``self.root`` is
        reassigned so that removing the root node itself takes effect.
        """
        self.root = BinarySearchTree.remove_node(self.root, data)

    @staticmethod
    def insert_node(node, data, **kwargs):
        """Insert ``data`` into the subtree rooted at ``node``.

        Args:
            node: subtree root, or None for an empty subtree.
            data: key to insert; compared with ``<`` / ``>`` against
                ``node.data``.
            node_constructor: optional keyword; callable that builds a node
                from ``data`` (defaults to ``Node``).

        Returns:
            The subtree root: ``node`` itself, or a freshly built node when
            ``node`` was None.
        """
        node_constructor = kwargs.get('node_constructor', None) or Node
        if node is None:
            return node_constructor(data)
        # Descend iteratively so degenerate (list-shaped) trees cannot hit
        # the interpreter's recursion limit.
        current = node
        while True:
            if data < current.data:
                if current.left is None:
                    current.left = node_constructor(data)
                    return node
                current = current.left
            elif data > current.data:
                if current.right is None:
                    current.right = node_constructor(data)
                    return node
                current = current.right
            else:
                # Key already present: overwrite in place, no duplicate node.
                current.data = data
                return node

    @staticmethod
    def remove_node(node, data):
        """Remove ``data`` from the subtree rooted at ``node``.

        Returns the new subtree root (None if the subtree became empty), which
        the caller must store back into the parent link.
        """
        if node is None:
            return None
        if data < node.data:
            node.left = BinarySearchTree.remove_node(node.left, data)
        elif data > node.data:
            node.right = BinarySearchTree.remove_node(node.right, data)
        elif node.left is None:
            # Leaf or right-only child: splice in the right child (None for a
            # leaf) so the remaining subtree is kept.
            return node.right
        elif node.right is None:
            # Left-only child: splice it in.
            return node.left
        else:
            # Two children: copy in the in-order predecessor (max of the left
            # subtree), then delete that predecessor from the left subtree.
            predecessor = BinarySearchTree.get_max_node(node.left)
            node.data = predecessor.data
            node.left = BinarySearchTree.remove_node(node.left,
                                                     predecessor.data)
        return node

    def get_min(self):
        """Return the node with the smallest data, or None for an empty tree."""
        return self.get_min_node(self.root)

    @staticmethod
    def get_min_node(node):
        """Return the leftmost node of the subtree at ``node`` (None if empty)."""
        while node is not None and node.left is not None:
            node = node.left
        return node

    def get_max(self):
        """Return the node with the largest data, or None for an empty tree."""
        return self.get_max_node(self.root)

    @staticmethod
    def get_max_node(node):
        """Return the rightmost node of the subtree at ``node`` (None if empty)."""
        while node is not None and node.right is not None:
            node = node.right
        return node

    def search_decorator(func):
        """Turn a traversal function into an optional search.

        The wrapped function must return the list of visited nodes.  When it
        is called with a ``data=`` keyword, the first node in that list whose
        ``data`` equals the keyword is returned instead; if none matches, the
        full list is returned unchanged.

        This is a plain function used as a decorator inside the class body,
        not a method meant to be called on instances.
        """
        from functools import wraps

        @wraps(func)
        def interface(*args, **kwargs):
            res = func(*args, **kwargs)
            if 'data' in kwargs:
                for node in res:
                    if node.data == kwargs['data']:
                        return node
            return res
        return interface

    @staticmethod
    def _collect(root, order, out):
        """Append the subtree at ``root`` to ``out`` in ``order``; return ``out``.

        ``order`` is 'in_order', 'pre_order' or 'post_order'.  Appending to
        one shared list keeps a full traversal O(N), whereas concatenating
        per-subtree lists at every node costs O(N * height).
        """
        if root is None:
            return out
        if order == 'pre_order':
            out.append(root)
        BinarySearchTree._collect(root.left, order, out)
        if order == 'in_order':
            out.append(root)
        BinarySearchTree._collect(root.right, order, out)
        if order == 'post_order':
            out.append(root)
        return out

    @staticmethod
    @search_decorator
    def in_order(root, **kwargs):
        """left -> root -> right

        Returns the list of nodes in that order, or the matching node when
        called with ``data=`` (see ``search_decorator``).
        """
        return BinarySearchTree._collect(root, 'in_order', [])

    @staticmethod
    @search_decorator
    def pre_order(root, **kwargs):
        """root -> left -> right

        Returns the list of nodes in that order, or the matching node when
        called with ``data=`` (see ``search_decorator``).
        """
        return BinarySearchTree._collect(root, 'pre_order', [])

    @staticmethod
    @search_decorator
    def post_order(root, **kwargs):
        """left -> right -> root

        Returns the list of nodes in that order, or the matching node when
        called with ``data=`` (see ``search_decorator``).
        """
        return BinarySearchTree._collect(root, 'post_order', [])

    def traversal(self,
                  order: "in_order|pre_order|post_order" = None,
                  data=None):
        """Traverse the whole tree in ``order`` (default ``self.search_mode``).

        Returns the list of visited nodes.  Because ``data`` is always
        forwarded, the result is instead the first node whose data equals
        ``data`` when one exists.

        Raises:
            NotImplementedError: ``order`` is not one of the three names.
        """
        order = order or self.search_mode
        if order == 'in_order':
            return BinarySearchTree.in_order(self.root, data=data)
        elif order == 'pre_order':
            return BinarySearchTree.pre_order(self.root, data=data)
        elif order == 'post_order':
            return BinarySearchTree.post_order(self.root, data=data)
        else:
            raise NotImplementedError(
                'unsupported traversal order: %r' % (order,))

    def search(self, data, *args, **kwargs):
        """Return the node whose data equals ``data``, or None if absent.

        Extra arguments (e.g. an ``order`` name) are forwarded to
        ``traversal``.
        """
        res = self.traversal(*args, data=data, **kwargs)
        # traversal() falls back to the full node list when nothing matched;
        # report that miss as None rather than a (truthy) list.
        return None if isinstance(res, list) else res