1. 程式人生 > >線段樹(區間樹)之區間染色和4n推導過程

線段樹(區間樹)之區間染色和4n推導過程

  前言

  線段樹(區間樹)是什麼呢?有了二叉樹、二分搜尋樹,線段樹又是幹什麼的呢?最經典的線段樹問題:區間染色;正如它的名字而言,主要解決區間的問題

  一、線段樹說明

  1、什麼是線段樹?

  線段樹首先是二叉樹,並且是平衡二叉樹(它是一 棵空樹或它的左右兩個子樹的高度差的絕對值不超過1,並且左右兩個子樹都是一棵平衡二叉樹),並且具有二分性質。

如下圖,就是一顆線段樹:

  

  假如,用陣列表示線段樹,如果區間有n個元素,陣列表示需要有多少節點?

  2、4n節點推導過程

  要進行一下,如果對推導過程不感興趣的,可以直接記住結論,需要4n個節點,推導過程如下圖:  PS:依舊是全部落格園最醜圖,當感覺有進步啊!是不是推薦一下,鼓勵一下啊

 

  說明:感覺用盡了洪荒之力,才推匯出來了。感覺高考之後再也不會用到等比公式了,但又用到了,還是緣分未盡啊,哈哈哈!最後,都放棄了,一直推導不出來,忘卻了最後一層的null,假設是滿二叉樹,按最大值進行估算,所以4n是完全夠大的!

  二、為什麼要使用線段樹

  線段樹主要解決一些區間問題的,如下:  

  1、區間染色

  有一面牆,長度為n,每次選擇一段牆進行染色,m次操作之後,我們可以看見多少種顏色?

  2、區間查詢

  查詢區間[i,j]的最大值、最小值,或者區間數字和;實質:基於區間的統計查詢。

  例如:2018年註冊使用者中消費最高的使用者?消費最低的使用者?學習最長時間的使用者?

  三、程式碼實現

  1、建立線段樹

  二叉樹具有天然遞迴性質,所以用遞迴相對簡單,用迭代也是可以的,我才用遞迴實現,程式碼如下:

template<class T>
class SegmentTree {
private:
    T *tree;
    T *data;
    int size;
    std::function<T(T, T)> function;
    
    int leftChild(int index) {  //左孩子下標;例如用陣列儲存,根節點是下標0,則左孩子為1,右孩子為2
        return
index * 2 + 1; } int rightChild(int index) {  //右孩子下標 return index * 2 + 2; } void buildSegmentTree(int treeIndex, int l, int r) { if (l == r) { tree[treeIndex] = data[l]; return; } int leftTreeIndex = leftChild(treeIndex); int rightTreeIndex = rightChild(treeIndex); int mid = l + (r - l) / 2;  //中間值求法,防止整型溢位 buildSegmentTree(leftTreeIndex, l, mid);  //構建左子樹 buildSegmentTree(rightTreeIndex, mid + 1, r);  //構建右子樹 tree[treeIndex] = function(tree[leftTreeIndex], tree[rightTreeIndex]); } public: SegmentTree(T arr[], int n, std::function<T(T, T)> function) {  //建構函式,構建一棵樹 this->function = function; data = new T[n]; for (int i = 0; i < n; ++i) { data[i] = arr[i]; } tree = new T[n * 4];  //分配4n節點 size = n; buildSegmentTree(0, 0, size - 1); } };

  2、線段樹查詢

   線段樹具有二分查詢性質,所以二分查詢那種思路就可以了,程式碼如下:

T query(int treeIndex, int l, int r, int queryL, int queryR) {
        if (l == queryL && r == queryR) {
            return tree[treeIndex];
        }

        int mid = l + (r - l) / 2;
        int leftTreeIndex = leftChild(treeIndex);
        int rightTreeIndex = rightChild(treeIndex);

        if (queryL >= mid + 1) {
            return query(rightTreeIndex, mid + 1, r, queryL, queryR);
        } else if (queryR <= mid) {
            return query(leftTreeIndex, l, mid, queryL, queryR);
        }

        T leftResult = query(leftTreeIndex, l, mid, queryL, mid);
        T rightResult = query(rightTreeIndex, mid + 1, r, mid + 1, queryR);
        return function(leftResult, rightResult);
    }

T query(int queryL, int queryR) {
        assert(queryL >= 0 && queryL < size && queryR >= 0 && queryR < size && queryL <= queryR);
        return query(0, 0, size - 1, queryL, queryR);
    }

  3、整體程式碼

  SegmentTree.h如下:

#ifndef SEGMENT_TREE_SEGMENTTREE_H
#define SEGMENT_TREE_SEGMENTTREE_H

#include <cassert>
#include <functional>

template<class T>
class SegmentTree {
private:
    T *tree;
    T *data;
    int size;
    std::function<T(T, T)> function;
    
    int leftChild(int index) {
        return index * 2 + 1;
    }

    int rightChild(int index) {
        return index * 2 + 2;
    }
    
    void buildSegmentTree(int treeIndex, int l, int r) {
        if (l == r) {
            tree[treeIndex] = data[l];
            return;
        }
        int leftTreeIndex = leftChild(treeIndex);
        int rightTreeIndex = rightChild(treeIndex);
        int mid = l + (r - l) / 2;

        buildSegmentTree(leftTreeIndex, l, mid);
        buildSegmentTree(rightTreeIndex, mid + 1, r);
        tree[treeIndex] = function(tree[leftTreeIndex], tree[rightTreeIndex]);
    }
    
    T query(int treeIndex, int l, int r, int queryL, int queryR) {
        if (l == queryL && r == queryR) {
            return tree[treeIndex];
        }

        int mid = l + (r - l) / 2;
        int leftTreeIndex = leftChild(treeIndex);
        int rightTreeIndex = rightChild(treeIndex);

        if (queryL >= mid + 1) {
            return query(rightTreeIndex, mid + 1, r, queryL, queryR);
        } else if (queryR <= mid) {
            return query(leftTreeIndex, l, mid, queryL, queryR);
        }

        T leftResult = query(leftTreeIndex, l, mid, queryL, mid);
        T rightResult = query(rightTreeIndex, mid + 1, r, mid + 1, queryR);
        return function(leftResult, rightResult);
    }

public:
    SegmentTree(T arr[], int n, std::function<T(T, T)> function) {
        this->function = function;
        data = new T[n];
        for (int i = 0; i < n; ++i) {
            data[i] = arr[i];
        }
        tree = new T[n * 4];
        size = n;
        buildSegmentTree(0, 0, size - 1);
    }

    int getSize() {
        return size;
    }

    T get(int index) {
        assert(index >= 0 && index < size);
        return data[index];
    }
    
    T query(int queryL, int queryR) {
        assert(queryL >= 0 && queryL < size && queryR >= 0 && queryR < size && queryL <= queryR);
        return query(0, 0, size - 1, queryL, queryR);
    }

    void print() {
        std::cout << "[";
        for (int i = 0; i < size * 4; ++i) {
            if (tree[i] != NULL) {
                std::cout << tree[i];
            } else {
                std::cout << "0";
            }
            if (i != size * 4 - 1) {
                std::cout << ", ";
            }
        }
        std::cout << "]" << std::endl;
    }
};

#endif //SEGMENT_TREE_SEGMENTTREE_H
View Code

  main.cpp如下:

#include <iostream>
#include "SegmentTree.h"

int main() {
    int nums[] = {-2, 0, 3, -5, 2, -1};
    SegmentTree<int> *segmentTree = new SegmentTree<int>(nums, sizeof(nums) / sizeof(int), [](int a, int b) -> int {
        return a + b;
    });
    std::cout << segmentTree->query(2,5) << std::endl;
    segmentTree->print();
    return 0;
}

  4、演示

   執行結果,如下:  

  

  5、時間複雜度分析

  更新  O(logn)

  查詢  O(logn)