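# Approach: a plain binary search down the BST. Compare the target with the
# current node's value and move right if the target is larger, left if it is
# smaller, stopping on a match or when we fall off the tree.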
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, x):
#         self.val = x
#         self.left = None
#         self.right = None
class Solution:
    # Annotations are quoted because TreeNode is only defined in a comment in
    # this file (presumably supplied by the judge at runtime); unquoted, they
    # would raise NameError when the method is defined.
    def searchBST(self, root: "TreeNode | None", val: int) -> "TreeNode | None":
        """Return the node of the BST rooted at *root* whose value equals *val*.

        Walks down from the root like a binary search: at each node, move to
        the right subtree if the node's value is smaller than *val*, to the
        left subtree if it is larger, and stop on an exact match.

        Args:
            root: Root of a binary search tree, or None for an empty tree.
            val: Value to look for.

        Returns:
            The node on the search path whose ``val`` equals *val* (the node
            itself, so its subtree is returned with it), or None if the path
            ends without a match.
        """
        node = root
        while node is not None:
            if node.val == val:
                return node
            # BST ordering: larger values live in the right subtree.
            node = node.right if node.val < val else node.left
        return None