0230. Kth Smallest Element in a BST

https://leetcode.com/problems/kth-smallest-element-in-a-bst

Description

Given the root of a binary search tree, and an integer k, return the kth smallest value (1-indexed) of all the values of the nodes in the tree.

Example 1:

**Input:** root = [3,1,4,null,2], k = 1
**Output:** 1

Example 2:

**Input:** root = [5,3,6,2,4,null,null,1], k = 3
**Output:** 3

Constraints:

  • The number of nodes in the tree is n.

  • 1 <= k <= n <= 10^4

  • 0 <= Node.val <= 10^4

Follow up: If the BST is modified often (i.e., we can do insert and delete operations) and you need to find the kth smallest frequently, how would you optimize?

ac1: Recursive

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */

// BST-> in-order traversal
/**
 * Finds the kth smallest value of a BST with a recursive in-order traversal
 * that stops descending as soon as the kth node has been visited.
 */
class Solution {
    /**
     * Returns the kth smallest (1-indexed) value stored in the tree.
     *
     * @param root root of the binary search tree; may be {@code null}
     * @param k    1-indexed rank of the value to look up
     * @return the kth smallest value, or {@code -1} when the tree is empty,
     *         {@code k < 1}, or the tree holds fewer than {@code k} nodes
     */
    public int kthSmallest(TreeNode root, int k) {
        // edge cases
        if (root == null || k < 1) {
            return -1;
        }

        // Single-slot out-parameter. It keeps the -1 default when k exceeds the
        // number of nodes, so an oversized k no longer triggers an
        // IndexOutOfBoundsException on an empty result list.
        int[] found = {-1};
        helper(root, k, found);

        return found[0];
    }

    /**
     * Visits {@code node}'s subtree in order, counting down towards the target.
     *
     * @param node      subtree root; may be {@code null}
     * @param remaining how many nodes (>= 1) still have to be visited, the
     *                  target itself included
     * @param found     one-element array that receives the target's value
     * @return {@code 0} once the target has been recorded, otherwise the count
     *         that is still outstanding after the whole subtree was visited
     */
    private int helper(TreeNode node, int remaining, int[] found) {
        // exit: an empty subtree consumes nothing
        if (node == null) {
            return remaining;
        }

        remaining = helper(node.left, remaining, found);
        if (remaining == 0) {
            // Target was inside the left subtree; skip this node and the right side.
            return 0;
        }
        if (remaining == 1) {
            // This node is the one we were counting down to.
            found[0] = node.val;
            return 0;
        }

        // This node uses up one rank; keep counting in the right subtree.
        return helper(node.right, remaining - 1, found);
    }
}

// in-order traversal
/**
 * Finds the kth smallest value of a BST by counting nodes during an in-order
 * walk and remembering the node whose in-order position equals k.
 */
class Solution {
    /** Node at in-order position k for the current call; {@code null} if none was found. */
    private TreeNode kThNode;

    /**
     * Returns the kth smallest (1-indexed) value stored in the tree.
     *
     * @param root root of the binary search tree; may be {@code null}
     * @param k    1-indexed rank of the value to look up
     * @return the kth smallest value, or {@code -1} when no node has rank
     *         {@code k} (empty tree, {@code k < 1}, or {@code k} larger than
     *         the node count)
     */
    public int kthSmallest(TreeNode root, int k) {
        // Forget any node remembered by an earlier call on this instance, so a
        // lookup that finds nothing cannot report a stale answer.
        kThNode = null;
        countNode(root, k);

        // Previously a missing node surfaced as a NullPointerException here.
        return kThNode == null ? -1 : kThNode.val;
    }

    /**
     * Counts the nodes visited in {@code node}'s subtree and records the node
     * at local in-order position {@code k} in {@link #kThNode}.
     *
     * @param node subtree root; may be {@code null}
     * @param k    rank to look for, relative to this subtree
     * @return number of nodes counted in this subtree; subtrees entered with
     *         {@code k <= 0} are skipped and contribute {@code 0}
     */
    private int countNode(TreeNode node, int k) {
        // k <= 0 means the target lies before this subtree in the in-order
        // sequence (or the rank was never valid). Nothing here can match, so
        // skip the subtree instead of walking it with a negative k.
        if (node == null || k <= 0) {
            return 0;
        }

        int leftCount = countNode(node.left, k);
        if (leftCount == k - 1) {
            // Exactly k - 1 nodes precede this one, so it has rank k.
            kThNode = node;
        }

        // The right subtree looks for the rank left over after this node and
        // everything counted on the left.
        return 1 + leftCount + countNode(node.right, k - 1 - leftCount);
    }
}

ac2: iterative

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
/**
 * Finds the kth smallest value of a BST with an iterative in-order traversal
 * driven by an explicit stack.
 */
class Solution {
    /**
     * Returns the kth smallest (1-indexed) value stored in the tree.
     *
     * @param root root of the binary search tree; may be {@code null}
     * @param k    1-indexed rank of the value to look up
     * @return the kth smallest value, or {@code 0} when no node has rank
     *         {@code k} (empty tree, {@code k < 1}, or {@code k} larger than
     *         the node count)
     */
    public int kthSmallest(TreeNode root, int k) {
        Stack<TreeNode> stack = new Stack<>();
        TreeNode r = root;

        while (r != null || !stack.isEmpty()) {
            // Descend to the leftmost unvisited node, stacking the path on the way.
            while (r != null) {
                stack.push(r);
                r = r.left;
            }

            // The top of the stack is the next node in ascending order.
            r = stack.pop();
            k--;
            if (k == 0) {
                // Stop here; walking the rest of the tree cannot change the answer.
                return r.val;
            }
            r = r.right;
        }

        // No node had rank k; keep the original fallback value.
        return 0;
    }
}

Another way, but less intuitive:

/**
 * Finds the kth smallest value of a BST with an iterative in-order traversal
 * that handles "go left" and "visit, then go right" in a single loop.
 */
class Solution {
    /**
     * Returns the kth smallest (1-indexed) value stored in the tree.
     *
     * @param root root of the binary search tree; may be {@code null}
     * @param k    1-indexed rank of the value to look up
     * @return the kth smallest value, or {@code -1} when no node has rank
     *         {@code k} (empty tree, {@code k < 1}, or {@code k} larger than
     *         the node count)
     */
    public int kthSmallest(TreeNode root, int k) {
        Stack<TreeNode> stack = new Stack<>();
        TreeNode curr = root;

        while (curr != null || !stack.isEmpty()) {
            if (curr != null) {
                // Keep going left, remembering the path back up.
                stack.push(curr);
                curr = curr.left;
            } else {
                // Nothing further to the left: the stack top is next in order.
                TreeNode tmp = stack.pop();
                curr = tmp.right;

                k--;
                if (k == 0) {
                    return tmp.val;
                }
            }
        }

        // The loop only ends with curr == null, so dereferencing curr here
        // always threw a NullPointerException. Report "not found" instead.
        return -1;
    }
}

Last updated