1. 程式人生 > >C. 痛苦的 01 矩陣 (推公式,樹狀陣列維護)

C. 痛苦的 01 矩陣 (推公式,樹狀陣列維護)

現有一個 n×n 的 01 矩陣 M。

定義 cost(i,j) 為:把第 i 行和第 j 列全部變成 1 最少需要改動多少個元素。

定義矩陣的痛苦值 pain(M) 為:

pain(M)=(∑i=1n∑j=1n(cost(i,j))2)mod(109+7)

要求求出初始矩陣的痛苦值和每次修改操作之後的痛苦值。

Input

第一行三個正整數 n,k,q (2≤n≤2⋅105, 1≤k≤min(n2,2⋅105), 0≤q≤2⋅105)。k 表示這個矩陣中有 k 個 1。q 表示修改操作次數。

接下來 k 行,每行兩個正整數 xi, yi (1≤xi,yi≤n),表示有一個 1 在第 xi 行,第 yi 列。保證所有 (xi,yi) 各不相同。

接下來 q 行,每行兩個正整數 ui, vi (1≤ui,vi≤n),表示修改第 ui 行,第 vi 列。如果該位置原先為 0,則改為 1;如果該位置原先為 1,則改為 0。

Output

輸出 q+1 行,依次為所有修改發生前的痛苦值,和每次修改操作後的痛苦值。

Examples

Input

3 4 9
1 1
1 2
2 3
3 1
3 3
1 2
1 3
2 2
2 2
2 1
3 1
1 1
2 3

Output

73
48
75
52
29
52
33
52
77
104
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;

#define rep(i,a,b) for(int i=a;i<b;++i)
#define per(i,a,b) for(int i=b-1;i>=a;--i)
#define lowbit(x)  (x&(-x))

const int mod=1e9+7;

const int N=2e5+10;
LL tr_r2[N],tr_c2[N],tr_r[N],tr_c[N];

LL n;
void update(LL tr[],int x,LL val)
{
	val%=mod;
    while(x<=n) {
        tr[x]=(tr[x]+val)%mod;
        x+=lowbit(x);
    }
}

LL query(LL tr[],LL x)
{
    LL res=0;
    while(x>0) {
        res=(res+tr[x])%mod;
		if(res<0)res+=mod;
		x-=lowbit(x);
    }
    return res;
}

LL r[N],c[N];

set<int> st[N];
LL s;

LL solve(LL n)
{

    LL ans1=query(tr_r2,n),ans2=query(tr_c2,n);

    LL ans3=query(tr_r,n), ans4=query(tr_c,n);

    //printf("ans1:%lld ans2:%lld ans3:%lld ans4:%lld\n",ans1,ans2,ans3,ans4);

    ans1=ans1*(n-2)%mod;
    ans2=ans2*(n-2)%mod;
    ans3=ans3*ans4%mod;
    ans3=ans3*2%mod;

    ans1=(((ans1+ans2)%mod+s)%mod+ans3)%mod;
    return ans1;
}

void change(int x,int y,LL v)
{
    update(tr_r2,x,-r[x]*r[x]);
    update(tr_r2,x,(r[x]+v)*(r[x]+v));

    update(tr_c2,y,-c[y]*c[y]);
    update(tr_c2,y,(c[y]+v)*(c[y]+v));

    update(tr_r,x,-r[x]);
    update(tr_r,x,r[x]+v);

    update(tr_c,y,-c[y]);
    update(tr_c,y,c[y]+v);

    r[x]=r[x]+v; //if(r[x]>=mod)r[x]-=mod; if(r[x]<=-mod)r[x]+=mod;
    c[y]=c[y]+v; //if(c[y]>=mod)c[y]-=mod; if(c[y]<=-mod)c[y]+=mod;
   // printf("y:%d c[y]:%lld\n\n",y,c[y]);
    s=s+v; s%=mod;//if(s>=mod)s-=mod; if(s<=-mod)s+=mod;
}
/*
3 4 9
1 1
1 2
2 3
3 1

3 3
1 2
1 3

2 2
2 2
2 1

3 1
1 1
2 3
*/

int main()
{
    LL K,Q;
    scanf("%lld %lld %lld",&n,&K,&Q);

    s=n*n%mod;
    for(int i=1; i<=n; i++) {
        update(tr_r2,i,n*n);
        update(tr_c2,i,n*n);
        update(tr_r,i,n);
        update(tr_c,i,n);
        c[i]=n;r[i]=n;
    }

    rep(i,0,K) {
        int x,y;
        scanf("%d %d",&x,&y);
        st[x].insert(y);
		change(x,y,-1);
    }
   // printf("s:%lld\n",s);
	//rep(i,1,n+1)printf("i:%d %lld %lld\n",i,r[i],c[i]);

    LL ans=solve(n);

    printf("%lld\n",ans);

    rep(i,0,Q) {
        int x,y;
        scanf("%d %d",&x,&y);
		if(st[x].count(y)){
			change(x,y,1);
			st[x].erase(y);
		}else{
			change(x,y,-1);
			st[x].insert(y);
		}
		ans=solve(n);
		printf("%lld\n",ans);
    }
    return 0;
}