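We can solve this problem recursively. Any tree can always be split into a sub-tree whose values are greater than the key and another sub-tree whose values are smaller than or equal to the key. If the root's value is greater than the key, the root and its right subtree belong to the "greater" tree, so we split the left subtree and attach its "greater" part as the root's new left child. Otherwise, symmetrically, we split the right subtree and attach its "smaller or equal" part as the root's new right child.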
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, x):
#         self.val = x
#         self.left = None
#         self.right = None
class Solution:
    def splitBST(self, root: "TreeNode", V: int) -> "List[TreeNode]":
        """Split the BST rooted at ``root`` around the key ``V``.

        Returns ``[small, large]``:

        - ``small`` is a BST of every node whose ``val`` is ``<= V``.
        - ``large`` is a BST of every node whose ``val`` is ``> V``.

        Either result may be ``None``. No nodes are allocated. The existing
        nodes are re-linked in place, so the input tree is consumed.

        This is the recursive split unrolled into a loop. Only nodes on the
        search path for ``V`` get a new child pointer, so one root-to-leaf
        walk (O(height)) is enough. Because there is no recursion, deep or
        skewed trees cannot hit Python's recursion limit.

        The annotations are strings so that the method can be defined even
        where ``TreeNode`` / ``List`` are not in scope.
        """
        small_root = large_root = None
        # Most recent node placed in each output tree. Its child on the walk
        # side (right for small, left for large) is still "open" and receives
        # the next node that lands in the same output tree.
        small_tail = large_tail = None
        node = root
        while node is not None:
            if node.val > V:
                # node and its right subtree are > V; only the left subtree
                # may still contain nodes <= V, so keep walking left.
                if large_tail is None:
                    large_root = node
                else:
                    large_tail.left = node
                large_tail = node
                node = node.left
            else:
                # node and its left subtree are <= V; only the right subtree
                # may still contain nodes > V, so keep walking right.
                if small_tail is None:
                    small_root = node
                else:
                    small_tail.right = node
                small_tail = node
                node = node.right
        # Close the open links. Each still points at its original child,
        # which (with the rest of the path below it) went to the other
        # tree or is None.
        if small_tail is not None:
            small_tail.right = None
        if large_tail is not None:
            large_tail.left = None
        return [small_root, large_root]