程式人生 > 二叉排序樹刪除節點

二叉排序樹刪除節點 (Binary search tree: deleting a node)

#include<iostream>
using namespace std;

// A binary-tree node: an integer payload plus links to the two children.
class TreeNode {
public:
    int val;          // value stored in this node
    TreeNode *left;   // left child, or null if absent
    TreeNode *right;  // right child, or null if absent

    // Builds a leaf node holding `val`; both child links start out null.
    TreeNode(int val) : val(val), left(nullptr), right(nullptr) {}
};

class Solution {
public:
    /**
     * Removes one node whose value equals `val` from the binary search tree
     * rooted at `root`, and frees that node with `delete`.
     *
     * @param root Root of the tree; may be null.
     * @param val  Value to remove. The search descends left when `val` is
     *             smaller than the current node and right otherwise, so the
     *             first matching node on that path is the one removed.
     * @return The (possibly new) root of the tree. The tree is returned
     *         unchanged when it is empty or holds no node equal to `val`.
     */
    TreeNode* removeNode(TreeNode* root, int val) {
        // Locate the target and remember its parent so the link can be rewired.
        TreeNode* parent = nullptr;
        TreeNode* curr = root;
        while (curr && curr->val != val) {
            parent = curr;
            curr = (val < curr->val) ? curr->left : curr->right;
        }
        if (!curr)
            return root;  // empty tree, or value not present

        if (curr->left && curr->right) {
            // Two children: overwrite curr's value with its in-order
            // predecessor (the maximum of the left subtree), then unlink the
            // predecessor node instead of curr.
            TreeNode* predParent = curr;
            TreeNode* pred = curr->left;
            while (pred->right) {
                predParent = pred;
                pred = pred->right;
            }
            curr->val = pred->val;

            // The predecessor has no right child, but it may still have a left
            // subtree; splice that subtree into the predecessor's old slot so
            // it is not lost.
            if (predParent == curr)
                predParent->left = pred->left;
            else
                predParent->right = pred->left;
            delete pred;
            return root;
        }

        // Zero or one child: replace curr by its only child (or by null for a leaf).
        TreeNode* child = curr->left ? curr->left : curr->right;
        if (!parent)
            root = child;  // removing the root itself
        else if (parent->left == curr)
            parent->left = child;
        else
            parent->right = child;
        delete curr;
        return root;
    }
};

// Prints the tree rooted at `node` in pre-order (node, then its left subtree,
// then its right subtree), one value per line. Despite the name, this is not
// a level-order (breadth-first) traversal. A null `node` prints nothing.
void pLevel(TreeNode* node) {
    // Recurse only into left subtrees; walking down the right spine in a loop
    // visits nodes in the same order as recursing on the right child last.
    for (TreeNode* cur = node; cur != nullptr; cur = cur->right) {
        std::cout << cur->val << std::endl;
        pLevel(cur->left);
    }
}

int main() {
    TreeNode* node5 = new TreeNode(5);
    TreeNode* node3 = new TreeNode(3);
    TreeNode* node6 = new TreeNode(6);
    TreeNode* node2 = new TreeNode(2);
    TreeNode* node4 = new TreeNode(4);

    node5->left = node3;
    node5->right = node6;
    node3->left = node2;
    node3->right = node4;

    Solution s;
    TreeNode* ret = s.removeNode(node5, 3);
    pLevel(ret);
    return 0;
}