
Leetcode: Kth Smallest Element in a BST

2015-12-19 13:06
Given a binary search tree, write a function kthSmallest to find the kth smallest element in it.

You may assume k is always valid, 1 ≤ k ≤ BST's total elements.

Follow up:
What if the BST is modified (insert/delete operations) often and you need to find the kth smallest frequently? How would you optimize the kthSmallest routine?


Try to utilize the property of a BST.
What if you could modify the BST node's structure?
The optimal runtime complexity is O(height of BST).

Java Solution 1 - Inorder Traversal

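An in-order traversal of a BST visits the nodes in ascending order, so the kth node visited is the kth smallest element. Time is O(n) in the worst case (O(h + k) when the traversal stops at the kth node, where h is the height), and the stack uses O(h) extra space.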

import java.util.Stack;

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
public class Solution {
    public int kthSmallest(TreeNode root, int k) {
        TreeNode node = root;
        Stack<TreeNode> st = new Stack<TreeNode>();
        int counter = 0; // number of nodes visited so far, in ascending order
        while (!st.isEmpty() || node != null) {
            if (node != null) {
                st.push(node); // defer this node until its left subtree is done
                node = node.left;
            } else {
                node = st.pop();
                counter++;
                if (counter == k) return node.val;
                node = node.right;
            }
        }
        return -1; // not reached when 1 <= k <= number of nodes
    }
}

Recursive method:

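import java.util.ArrayList;
import java.util.List;

public class Solution {
    int count = 0; // number of nodes visited so far, in ascending order

    public int kthSmallest(TreeNode root, int k) {
        List<Integer> res = new ArrayList<Integer>();
        helper(root, k, res);
        return res.get(0);
    }

    public void helper(TreeNode root, int k, List<Integer> res) {
        if (root == null || count >= k) return; // stop once the answer is found
        helper(root.left, k, res);
        if (++count == k) res.add(root.val); // add, not set: the list starts empty
        helper(root.right, k, res);
    }
}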
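
To check the in-order approach by hand, a small driver can help. The following is only a sketch under assumptions: the class name KthSmallestDemo, the nested TreeNode, and the sample() helper are hypothetical names introduced for illustration, and the traversal uses ArrayDeque in place of Stack. It builds a four-node BST and queries every valid k.

```java
import java.util.ArrayDeque;
import java.util.Deque;

class KthSmallestDemo {
    // Same shape as the TreeNode in the problem statement.
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int x) { val = x; }
    }

    // Iterative in-order traversal that stops at the kth visited node.
    static int kthSmallest(TreeNode root, int k) {
        Deque<TreeNode> stack = new ArrayDeque<>();
        TreeNode node = root;
        while (node != null || !stack.isEmpty()) {
            while (node != null) {      // walk as far left as possible
                stack.push(node);
                node = node.left;
            }
            node = stack.pop();         // next node in ascending order
            if (--k == 0) return node.val;
            node = node.right;
        }
        throw new IllegalArgumentException("k is out of range");
    }

    // Hypothetical helper that builds this BST:
    //       3
    //      / \
    //     1   4
    //      \
    //       2
    static TreeNode sample() {
        TreeNode root = new TreeNode(3);
        root.left = new TreeNode(1);
        root.right = new TreeNode(4);
        root.left.right = new TreeNode(2);
        return root;
    }

    public static void main(String[] args) {
        TreeNode root = sample();
        for (int k = 1; k <= 4; k++) {
            System.out.println("k=" + k + " -> " + kthSmallest(root, k));
        }
    }
}
```

Because the tree holds the values 1 to 4, the kth smallest element equals k itself, which makes a wrong counter or a missing push easy to spot.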

Java Solution 2 - Extra Data Structure

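We can let each node track its order, i.e., the number of elements smaller than itself (the size of its left subtree). If that size is stored in the node and kept up to date on insert/delete, a query follows a single root-to-node path, so time is O(height), which is O(log n) for a balanced tree. The version below does not modify TreeNode; it recomputes the size with countNodes on every call, so a query costs O(n) on a balanced tree and up to O(n^2) on a tree skewed to the left.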



public int kthSmallest(TreeNode root, int k) {
    int count = countNodes(root.left); // nodes smaller than root
    if (k <= count) {
        return kthSmallest(root.left, k);
    } else if (k > count + 1) {
        return kthSmallest(root.right, k - 1 - count); // skip the left subtree and the current node
    }
    return root.val; // k == count + 1, so root is the answer
}

public int countNodes(TreeNode n) {
    if (n == null) return 0;
    return 1 + countNodes(n.left) + countNodes(n.right);
}
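
For the follow-up, where the tree changes often, one possible approach is to store the subtree size in each node so that no counting is needed at query time. The following is a minimal sketch under assumptions, not the original author's code: the class SizedBst, its Node type, and the size field are hypothetical names, duplicates are ignored, deletion is omitted, and the tree is not self-balancing, so the O(height) bound equals O(log n) only while the tree stays balanced.

```java
class SizedBst {
    // Hypothetical node type: a BST node augmented with its subtree size.
    static class Node {
        int val;
        int size = 1;       // number of nodes in the subtree rooted here
        Node left, right;
        Node(int v) { val = v; }
    }

    Node root;

    static int size(Node n) { return n == null ? 0 : n.size; }

    // Plain (unbalanced) BST insert that refreshes sizes on the way back up.
    void insert(int v) { root = insert(root, v); }

    private static Node insert(Node n, int v) {
        if (n == null) return new Node(v);
        if (v < n.val) n.left = insert(n.left, v);
        else if (v > n.val) n.right = insert(n.right, v);
        // equal values fall through: duplicates are ignored
        n.size = 1 + size(n.left) + size(n.right);
        return n;
    }

    // Follows one root-to-node path, so the cost is O(height).
    int kthSmallest(int k) {
        if (k < 1 || k > size(root)) {
            throw new IllegalArgumentException("k is out of range");
        }
        Node n = root;
        while (true) {
            int leftSize = size(n.left);
            if (k <= leftSize) {
                n = n.left;                 // answer is in the left subtree
            } else if (k == leftSize + 1) {
                return n.val;               // current node is the kth smallest
            } else {
                k -= leftSize + 1;          // skip left subtree and current node
                n = n.right;
            }
        }
    }
}
```

A delete operation would need the same size bookkeeping along the path it touches, and pairing this with a self-balancing scheme would be needed to guarantee logarithmic height.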