230. Kth Smallest Element in a BST - cocoder39/coco39_LC GitHub Wiki

230. Kth Smallest Element in a BST

2020 Notes:

Leverage inorder traversal

Time complexity: O(H + K), where H is the height of the tree. At most K + H nodes are pushed onto the stack (the initial descent to the smallest node plus the nodes reached while visiting K elements), and K nodes are popped. So if H >> K the cost is dominated by O(H), and if H << K it is dominated by O(K).

class Solution:
    def kthSmallest(self, root: TreeNode, k: int) -> int:
        """Return the k-th smallest value (1-indexed) in the BST rooted at ``root``.

        Values are produced lazily, in sorted order, by an iterative in-order
        traversal.  The traversal stops as soon as the k-th value is reached.
        If the tree holds fewer than k nodes (or k < 1), the traversal is
        exhausted and None is returned.

        Time: O(H + k), extra space: O(H), where H is the tree height.
        """

        def in_order(node):
            # Ancestors whose left subtrees are still being visited.
            path = []
            while node or path:
                # Walk down to the leftmost node that has not been visited yet.
                while node:
                    path.append(node)
                    node = node.left
                node = path.pop()
                yield node.val
                # Continue with the in-order successor's subtree.
                node = node.right

        for rank, value in enumerate(in_order(root), start=1):
            if rank == k:
                return value
        return None

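For the follow-up question (frequent insertions/deletions combined with frequent k-th smallest queries), LeetCode posted a solution based on an LRU cache. However, it still takes O(k) time to find the k-th element. A better approach is to record each node's left-subtree and right-subtree node counts (equivalently, its subtree size), so that the k-th element can be found in O(H) time, where H is the height of the tree.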

======================================================================

method 1: inorder traversal

// Returns the k-th smallest value (1-indexed) in the BST rooted at `root`,
// using an iterative in-order traversal with an explicit stack.
// Time: O(H + k), extra space: O(H), where H is the tree height.
// Throws std::out_of_range if k < 1 or k exceeds the number of nodes.
// (The previous version dereferenced a null `root` in those cases.)
int kthSmallest(TreeNode* root, int k) {
    if (k < 1) {
        throw std::out_of_range("kthSmallest: k must be >= 1");
    }
    std::stack<TreeNode*> s;
    while (root || !s.empty()) {
        if (root) {
            // Keep descending left; the stack remembers unvisited ancestors.
            s.push(root);
            root = root->left;
        } else {
            // root == nullptr && !s.empty(): visit the next node in sorted order.
            TreeNode* node = s.top();
            s.pop();
            if (--k == 0) {
                return node->val;
            }
            root = node->right;
        }
    }
    // Every node was visited without reaching the k-th one; `root` is
    // nullptr here, so there is no value to return.
    throw std::out_of_range("kthSmallest: k exceeds the number of nodes");
}

method 2: take advantage of 272. Closest Binary Search Tree Value II with target = INT_MIN (the k-th closest value to INT_MIN is the k-th smallest), O(h + k) time and O(h) space

class Solution {
public:
    int kthSmallest(TreeNode* root, int k) {
        return closestKValues(root, INT_MIN, k);
    }
int closestKValues(TreeNode* root, int target, int k) {
        int res;
        stack<TreeNode*> succ;
        stack<TreeNode*> pred;

        initStack(pred, succ, root, target);
        while (k-- > 0) {
            if (succ.empty()) { 
                res = nextPred(pred);
            } 
            else if (pred.empty()) {
                res = nextSucc(succ);
            } 
            else {
                double succ_diff = abs((double)succ.top()->val - target);
                double pred_diff = abs((double)pred.top()->val - target);
                res = succ_diff < pred_diff ? nextSucc(succ) : nextPred(pred);
            }
        }
        return res;
    }

    void initStack(stack<TreeNode*>& pred, stack<TreeNode*>& succ, TreeNode* root, int target) { //O(log n)
        while (root) {
            if (root->val <= target) {
                pred.push(root);
                root = root->right;
            }
            else{
                succ.push(root);
                root = root->left;
            }
        }
    }

    int nextSucc(stack<TreeNode*>& succ) {
        TreeNode* cur = succ.top();
        succ.pop();
        int res = cur->val;
        cur = cur->right;
        while (cur) {
            succ.push(cur);
            cur = cur->left;
        }
        return res;
    }

    int nextPred(stack<TreeNode*>& pred) {
        TreeNode* cur = pred.top();
        pred.pop();
        int res = cur->val;
        cur = cur->left;
        while (cur) {
            pred.push(cur);
            cur = cur->right;
        }
        return res;
    }
};

method 3: augment each node with `count`, the number of nodes in the subtree rooted at that node. After an insertion or deletion, only the counts of the changed node's ancestors need updating, which takes O(height) time, because the update propagates upward. If each node instead stored its rank in sorted order, an update could touch O(n) nodes. That is why `count` is the subtree size rather than the node's rank.

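O(n) time to build the augmented tree; after that, O(height) time per query and per modification (O(log n) when the tree is balanced). Note that the code below rebuilds the augmented tree on every call, so as written each call costs O(n).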

class Solution {
public:
    // Returns the k-th smallest value (1-indexed) in the BST rooted at `root`.
    // Builds a size-augmented copy of the tree (O(n)), answers the query by
    // descending it (O(H)), then frees the copy. The previous version leaked
    // the copy on every call and dereferenced null for bad k.
    // Throws std::out_of_range if k < 1 or k exceeds the number of nodes.
    int kthSmallest(TreeNode* root, int k) {
        TreeNodeWithCount* rootWithCount = buildTreeWithCount(root);
        const int total = rootWithCount ? rootWithCount->count : 0;
        if (k < 1 || k > total) {
            destroyTreeWithCount(rootWithCount);
            throw std::out_of_range("kthSmallest: k must be in [1, number of nodes]");
        }
        const int res = kthSmallest(rootWithCount, k);
        destroyTreeWithCount(rootWithCount);
        return res;
    }

    // Node of the augmented tree: a copy of a TreeNode's value plus the size
    // of the subtree rooted at this node.
    class TreeNodeWithCount {
    public:
        int val;
        int count; // number of nodes in the subtree rooted at this node
        TreeNodeWithCount* left;
        TreeNodeWithCount* right;
        TreeNodeWithCount(int x) : val(x), count(1), left(nullptr), right(nullptr) {}
    };

private:
    // Recursively copies `root` into heap-allocated TreeNodeWithCount nodes,
    // filling in subtree sizes bottom-up. The caller owns the result and must
    // release it with destroyTreeWithCount.
    TreeNodeWithCount* buildTreeWithCount(TreeNode* root) {
        if (!root) {
            return nullptr;
        }

        TreeNodeWithCount* rootWithCount = new TreeNodeWithCount(root->val);
        rootWithCount->left = buildTreeWithCount(root->left);
        rootWithCount->right = buildTreeWithCount(root->right);
        if (rootWithCount->left) {
            rootWithCount->count += rootWithCount->left->count;
        }
        if (rootWithCount->right) {
            rootWithCount->count += rootWithCount->right->count;
        }
        return rootWithCount;
    }

    // Frees every node of a tree produced by buildTreeWithCount (post-order).
    void destroyTreeWithCount(TreeNodeWithCount* node) {
        if (!node) {
            return;
        }
        destroyTreeWithCount(node->left);
        destroyTreeWithCount(node->right);
        delete node;
    }

    // Iteratively descends the augmented tree to the k-th smallest node.
    // Expects 1 <= k <= rootWithCount->count. If the counts are inconsistent,
    // it throws instead of dereferencing null.
    int kthSmallest(TreeNodeWithCount* rootWithCount, int k) {
        TreeNodeWithCount* node = rootWithCount;
        while (node) {
            const int leftCount = node->left ? node->left->count : 0;
            if (k <= leftCount) {
                node = node->left;              // answer is in the left subtree
            } else if (k == leftCount + 1) {
                return node->val;               // this node is the k-th
            } else {
                k -= leftCount + 1;             // skip left subtree and this node
                node = node->right;
            }
        }
        throw std::out_of_range("kthSmallest: k exceeds the number of nodes");
    }
};
⚠️ **GitHub.com Fallback** ⚠️