1. 程式人生 > >hdu 4358樹狀陣列

hdu 4358樹狀陣列

題解:

資料結構題,利用離線加樹狀陣列統計答案

            我們考慮這個問題的弱化版:給你N個數,求一個區間內不同的數字有多少個,這個問題可以通過記錄每個權值上次出現的位置,在O(N lgN)時間內得到解決,具體的做法是將查詢按右端點排序,從左向右列舉i維護一個樹狀陣列,其k位表示ki不同的數字有多少個,我們考慮將i個數字加入,這時候,會樹狀陣列中發生改變的僅僅是last[v] + 1i個區間內的值,其中v是第i位數字,last[v]表示v這個數上次出現的位置,如果沒有出現過就為0。所以,我們把樹狀陣列last[v]+1這個位置的值+1,把i+1這個位置-1,這樣我們求和的時候,只要是左端點(對左端點求和)在last[v]+1之後(包括自身)的,就能將v這點給包括進來了。然後列舉到

i之後,我們考慮右端點是i的全部查詢,看其左端點在樹狀陣列中的值是多少就可以了。

      接下來我們來考慮我們這個問題,將樹形結構轉化成線性結構,那麼我們就可以將問題轉化為求一個區間內,恰好出現K次的權值有多少種。我們記錄樹狀陣列第k位表示ki的答案,假設v出現的位置是p1; p2; p3;    ; pk,那麼我們假設現在列舉到了pk這個位置,將pk這個位置的數字加入集合之後,p(k−K−1) + 1p(k−K)這部分割槽間內權值v出現次數就超過K了,p(k−K )+ 1p(k−K+1)這部分割槽間內權值v出現的次數恰好達K,所以我們將樹狀陣列中的p(k−K−1) + 1

p(k−K)內的值全部減1p(k−K) +1p(k−K+1)內的值全部加1。查詢的時候是查詢左端點前數的和。總體複雜度O(N lgN)的。

      關鍵是想到變成線性序列,剩下也就簡單了。

這是核心核心又核心的程式碼,應該好好體會的。

         做題情況。剛剛表揚了自己,就又在陣列大小上面栽了跟斗,這裡建的是雙向邊,我卻只開了滿足單向的陣列。不懂的是為什麼杭電總是報的是WA呢?

還有就是初始化。L和c是必須重新初始化的。為什麼? 因為上一case的L和c,會影響這個case。為什麼?比如你上一次讓L[i]有了一個值,然而這一次卻還沒訪問到,那再dfs裡面豈不是永遠都訪問不到?目測WA了有個把小時。

         還有就是這題的作法十分令我佩服。但是我估計著這也是一個固定作法。如果以後再碰到此型別,應該就是水題一枚了。

         還有兩種作法。一種是線段樹,一種是用map。

         線段樹的作法是插入刪去一條線段。作法和之前的flowers很像,區間更新,單點查詢。

樹狀陣列核心程式碼:

for (int i=1; i<=n; i++){
            int v = val[i]; pl[v].push_back(i);
            int g = pl[v].size()-1;
            if (g >= k){
                if (g > k){
                    insert(pl[v][g-k-1]+1,-1);
                    insert(pl[v][g-k]+1,1);
                }
                insert(pl[v][g-k]+1, 1);
                insert(pl[v][g-k+1]+1, -1);
            }
            while (qn[t].y==i){
                ans[qn[t].id]=query(qn[t].x);//左端點向前的和
                t++;
            }
        }

線段樹核心程式碼:

for (int i = 1; i <= n; ++i) {  
            // 線段樹第j個數表示[j, i]間出現k次的數的個數   
            int num = a[i];  
            vv[num].push_back(i);  
            int size = vv[num].size();  
            if (size >= k) {  
                if (size > k) {  
                    // 1 ~ vv[num][size-k-1]都減1,就是刪除這條線段   
                    update(1, n, 1, 1, vv[num][size-k-1], -1);  
                    // vv[num][size-k-1]+1 ~ vv[num][size-k]都加1   ,插入
                    update(1, n, 1, vv[num][size-k-1] + 1, vv[num][size-k], 1);  
                } else {  
                    // 加1   ,插入
                    update(1, n, 1, 1, vv[num][size-k], 1);  
                }  
            }  
            while (Q[idx].r == i) {  
                ans[Q[idx].id] = query(1, n, 1, Q[idx].l);  
                idx++;  
            }  
        } 
樹狀陣列:
/*
Pro: 0

Sol:

date:
*/
#include <cstdio>
#include <iostream>
#include <cstring>
#include <map>
#include <iostream>
#include <vector>
#include <algorithm>
#define maxn 100010
using namespace std;
int n,k,t,a[maxn],Q,head[maxn],esub,linr[maxn],L[maxn],R[maxn],li,ans[maxn];
int c[maxn];
vector < int > pl[maxn];//用來記錄值出現的位置
bool cmpx(int x, int y){
    return a[x] < a[y];
}
void dis(){
    int tmp = -1,r[maxn];
    for(int i = 1; i <= n; i ++)
        r[i] = i;
    sort(r + 1, r + 1 + n, cmpx);
    int prev = a[r[1]] - 1;//必須要有一個prev將原來的值記錄下來,不然原來的值就儲存不下來了
    for(int i = 1; i <= n; i ++)
        if(prev != a[r[i]]) prev = a[r[i]],a[r[i]] = ++ tmp;
        else { a[r[i]] = tmp;}

    for(int i = 0; i <= tmp; i ++){
        pl[i].clear();//初始化很重要
        pl[i].push_back(0); //假設每個值都在第0個位置出現
    }
}
struct query{
    int L,R,id;
    bool operator < (const query& cmp) const{
        return R < cmp.R;
    }
}q[maxn];
struct Edge{
    int v,nxt;
}edge[maxn << 1 ];//
void init(){
    esub = 0;   li = 0;
    memset(head,-1,sizeof(head));
    memset(c,0,sizeof(c));
    memset(L,0,sizeof(L));
}
void add(int u, int v){
    edge[esub].v = v;
    edge[esub].nxt = head[u];
    head[u] = esub ++;
}
void dfs(int rt){
    L[rt] = ++li;   linr[li] = a[rt];
    for(int j = head[rt]; j != -1; j = edge[j].nxt){
        if(!L[edge[j].v])   dfs(edge[j].v);
    }
    R[rt] = li;
}
void modify(int pos, int val){
    while(pos <= n){
        c[pos] += val;
        pos += (pos & (-pos) );
    }
}
int getsum(int pos){
    int sum = 0;
    while(pos){
        sum += c[pos];
        pos -= (pos & (-pos) );
    }return sum;
}
int main(){
    scanf("%d",&t);
    for(int ca = 1; ca <= t; ca ++){
        printf("Case #%d:\n",ca);
        scanf("%d%d",&n,&k);
        init();//
        for(int i = 1; i <= n; i ++){
            scanf("%d",&a[i]);
        }
        dis();//從小到大離散化,可以不用從小到大,只要相同的依舊相同,不同的依舊不同就行

        for(int i = 1; i < n; i ++){
            int a,b;
            scanf("%d%d",&a,&b);
            add(a,b);
            add(b,a);
        }
        dfs(1); //產生一個線性序列
        scanf("%d",&Q);
        for(int i = 1; i <= Q; i ++){
            int x;
            scanf("%d",&x);
            q[i].L = L[x];
            q[i].R = R[x];
            q[i].id = i;
        }   sort(q + 1, q + 1 + n);
        int qsub = 1;
        for(int i = 1; i <= n; i ++){
            int v = linr[i];  pl[v].push_back(i);//記錄linr[i]出現的位置
            int g = pl[v].size() - 1;
            if(g >= k){
                if(g > k){
                    modify(pl[v][g - k - 1] + 1, -1);
                    modify(pl[v][g - k] + 1, 1);
                }
                modify(pl[v][g - k] + 1, 1);
                modify(pl[v][g - k + 1] + 1, -1);
            }
            while( q[qsub].R == i){
                ans[q[qsub].id] = getsum(q[qsub].L);
                qsub ++;
            }
        }
        for(int i = 1; i <= Q; i ++)
            printf("%d\n",ans[i]);

        if(ca < t) puts("");
    }
    return 0;
}

線段樹:(貼的)

#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <map>

using namespace std;

const int maxn = 100010;
int t, n, k, q;
int w[maxn], a[maxn];
int L[maxn], R[maxn];   // 詢問區間的左、右端點 
int ans[maxn], id;
vector<int> vt[maxn];   // 臨接表 
vector<int> vv[maxn];
bool vis[maxn];
map<int, int> mp;
struct Query {
    int l, r, id;
}Q[maxn];
// for segment tree:
int add[maxn<<2];

bool cmp(Query q1, Query q2)
{
    return q1.r < q2.r;
}

void dfs(int x)
{   // 將樹形結構變成線性結構 
    vis[x] = true;
    L[x] = id;
    a[id] = w[x];
    int size = vt[x].size();
    for (int i = 0; i < size; ++i) {
        if (!vis[vt[x][i]]) {
            id++;
            dfs(vt[x][i]);
        }
    }
    R[x] = id;
}

void pushDown(int rt)
{
    if (add[rt]) {
        add[rt<<1] += add[rt];
        add[rt<<1|1] += add[rt];
        add[rt] = 0;
    }
    return ;
}

void build(int l, int r, int rt)
{
    add[rt] = 0;
    if (l == r) return ;
    int m = (l + r) >> 1;
    build(l, m, rt << 1);
    build(m + 1, r, rt << 1 | 1);
}

void update(int l, int r, int rt, int L, int R, int c)
{
    if (L <= l && R >= r) {
        add[rt] += c;
        return ;
    }
    pushDown(rt);
    int m = (l + r) >> 1;
    if (L <= m) {
        update(l, m, rt << 1, L, R, c);
    }
    if (R > m) {
        update(m + 1, r, rt << 1 | 1, L, R, c);
    }
}

int query(int l, int r, int rt, int p)
{
    if (l == r) {
        return add[rt];
    }
    pushDown(rt);
    int m = (l + r) >> 1;
    if (p <= m) {
        return query(l, m, rt << 1, p);
    } else {
        return query(m + 1, r, rt << 1 | 1, p);
    }
}

int main()
{
    scanf("%d", &t);
    for (int cas = 1; cas <= t; ++cas) {
        scanf("%d%d", &n, &k);
        mp.clear();
        id = 1;
        for (int i = 1; i <= n; ++i) {
            scanf("%d", &w[i]);
            // 離散化 
            if (mp[w[i]] == 0) {
                mp[w[i]] = id++;
            }
            w[i] = mp[w[i]];
        }
        int u, v;
        for (int i = 0; i < maxn; ++i) {
            vt[i].clear();
            vv[i].clear();
        }
        for (int i = 1; i < n; ++i) {
            scanf("%d%d", &u, &v);
            vt[u].push_back(v);
            vt[v].push_back(u);
        }
        memset(vis, false, sizeof(vis));
        id = 1;
        dfs(1);
        scanf("%d", &q);
        for (int i = 0; i < q; ++i) {
            scanf("%d", &u);
            Q[i].id = i;
            Q[i].l = L[u];
            Q[i].r = R[u];
        }
        sort(Q, Q + q, cmp);
        build(1, n, 1);
        int idx = 0;
        for (int i = 1; i <= n; ++i) {
            // 線段樹第j個數表示[j, i]間出現k次的數的個數 
            int num = a[i];
            vv[num].push_back(i);
            int size = vv[num].size();
            if (size >= k) {
                if (size > k) {
                    // 1 ~ vv[num][size-k-1]都減1 
                    update(1, n, 1, 1, vv[num][size-k-1], -1);
                    // vv[num][size-k-1]+1 ~ vv[num][size-k]都加1 
                    update(1, n, 1, vv[num][size-k-1] + 1, vv[num][size-k], 1);
                } else {
                    // 加1 
                    update(1, n, 1, 1, vv[num][size-k], 1);
                }
            }
            while (Q[idx].r == i) {
                ans[Q[idx].id] = query(1, n, 1, Q[idx].l);
                idx++;
            }
        }
        if (cas != 1) {
            printf("\n");
        }
        printf("Case #%d:\n", cas);
        for (int i = 0; i < q; ++i) {
            printf("%d\n", ans[i]);
        }
    }
    return 0;
}

用map寫的(貼的):
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<map>
#include<queue>
#include<cmath>
#define LL long long
//c++提交的話 需要手動開棧(看解析上的) g++不用 但時間花費長
using namespace std;
const int N=100005;
int head[N];//鄰接表 表頭
struct node
{
    int j;
    int next;
} side[N*2];
int son[N];//兒子節點個數
int f[N];//儲存父節點
int id[N];//某個節點對應相關資料存在了哪個map裡面
//由於合併的關係 剛開始 i和id[i]對應相等 但一定的合併後就變了
int ans[N];//儲存答案 離線儲存要用
map<int ,int>str[N];
map<int ,int>:: iterator it,it1;
queue<int>qu;
int K;
void build(int x,int i)
{
    side[i].next=head[x];
    head[x]=i;
}
void dfs(int x,int pre)//用dfs求的本節點的父節點 和兒子節點數  和將葉子節點加入佇列
{
    int t=head[x];
    f[x]=pre;
    while(t!=-1)
    {
        if(side[t].j!=pre)
        {
            dfs(side[t].j,x);
            ++son[x];
        }
        t=side[t].next;
    }
    if(son[x]==0)
    {
        qu.push(x);
    }
}
void Add(int I,int i,int j)//將 map j 合併到i當中 將答案ans[I]進行更新
{
    for(it=str[j].begin(); it!=str[j].end(); ++it) //it 遍歷map j 的元素
    {
        it1=str[i].find(it->first);//到i 中查詢
        if(it1==str[i].end())//未找到
        {
            str[i][it->first]=it->second;//加入
            if(it->second==K)//如果正好等於K 則答案加一
                ++(ans[I]);
            continue;
        }
        if(it1->second==K)//如果找到 本來在i中 這個數出現次數正好為K
        //在加上一個非0數則不等於K 了所以答案個數減1
            --(ans[I]);
        if((it1->second+=it->second)==K)//如果加上正好為K 答案個數加1 這不會和上一個衝突
            ++(ans[I]);
    }
}
int main()
{
    int T;
    scanf("%d",&T);
    for(int cas=1; cas<=T; ++cas)
    {
        memset(head,-1,sizeof(head));
        memset(ans,0,sizeof(ans));
        memset(son,0,sizeof(son));//各種初始化
        while(!qu.empty())
            qu.pop();
        int n,q;
        scanf("%d %d",&n,&K);
        for(int i=1; i<=n; ++i)
        {
            int a;
            str[i].clear();//清空
            id[i]=i;//本來每個節點 的map 就是自己對應的
            scanf("%d",&a);
            str[i][a]=1;//先都加入 本節點的數
            if(K==1)//如果K正好為1 則答案也更新
                ++ans[i];
        }
        int x,y;
        for(int i=1; i<n; ++i) //輸入邊 建樹
        {
            scanf("%d %d",&x,&y);
            side[i].j=y;
            build(x,i);
            side[i+n].j=x;
            build(y,i+n);
        }
        dfs(1,0);//搜一個
        while(!qu.empty())
        {
            int l=qu.front();//取元素
            qu.pop();
            int fa=f[l];//fa為父親節點
            --son[fa];//父親節點的可以更新的兒子節點減少1
            if(son[fa]==0)//這是最後一個兒子節點 這是將fa加入佇列 千萬不能第一次更新就加入
            {
                //否則會出現父節點在兒子節點前面的情況 就錯了(自己輸在這裡了)  見上面資料
                qu.push(fa);
            }
            int li=id[l];//找到對應的map
            int fai=id[fa];
            if(str[fai].size()>str[li].size())//比較大小
            {//小的往大的裡面加
                Add(fa,fai,li);
            }
            else{
                ans[fa]=ans[l];//如果需要往l上合併 則先等於l的ans 在下面更新後ans[fa]
                //始終儲存當前對應map裡面的答案
                id[fa]=li;//更改對應的map
                Add(fa,li,fai);//更新
            }
        }
        scanf("%d",&q);
        printf("Case #%d:\n",cas);
        while(q--)
        {
            scanf("%d",&x);
            printf("%d\n",ans[x]);
        }
        if(cas<T)
            printf("\n");
    }
    return 0;
}