1. 程式人生 > >LOJ6356 四色燈 容斥 dp

LOJ6356 四色燈 容斥 dp

題目連結

題意:
你有 n n 個燈,每個燈有四種可能的顏色,一開始都是第一種顏色,有 m m 種操作,每種操作是一個 x

i x_i ,表示把 x i x_i 的倍數全都變成下一個顏色,第四次變化後會變回第一次的顏色,問你從這 m
m
種操作中隨機選出一個集合的操作去進行,期望有多少個顏色為初始的第一種顏色的燈。 n < = 1 e 9 ,
m < = 20 n<=1e9,m<=20
,答案對998244353取模。

題解:
首先感覺這個題可能會 2 m 2^m 列舉集合,於是這樣我們可以轉化成求總方案數,最後再除以一個 2 m 2^m 就可以了。我們可能會跟著直覺往dp或者容斥方向去想,但是這個題確實不好想。先%%%y_immortal大佬,他向我推薦的這個題,並且寫了全網第一篇題解,我也是跟他學的。

首先我們應該可以想到,一個燈最後還是初始顏色的條件是它被變化顏色的次數是4的倍數。我們發現一個數字它被變化多次的條件是它是兩個操作的lcm的倍數,那麼我們先設 f [ S ] f[S] 表示 [ 1 , n ] [1,n] 中有多少個數是集合 S S 的公倍數,但是我們會發現,這樣考慮是會有重複的,就是小集合裡的數的公倍數可能會在它的超集裡再次被計算。所以我們想定義一個 g [ S ] g[S] ,表示在 [ 1 , n ] [1,n] 中有多少個數是 S S 中元素的lcm的倍數,並且不是任何一個超集的lcm的倍數的個數。 g [ S ] g[S] 需要用容斥去算。

但是就算是算出來了剛才那些,還是會複雜度爆炸啊,於是這個題正解的做法是考慮把元素個數相同的 f f g g 一起算出來。對於 f f 比較好算,再原來的基礎上算的時候記錄一下當前集合有多少個元素,然後加到對應元素數的地方就行了。重點是求這個 g g 。首先,我們先把 g g 的初始值設為 f f 的初始值,然後考慮容斥,容斥的思路還是減去超集中的答案。我們用一個 m 2 m^2 的複雜度來計算每一個 g [ i ] g[i] 的答案,計算的方法是對於當前的 i i ,列舉所有比它大的集合,考慮從當前的答案中減去是當前集合lcm的倍數同時大集合lcm的倍數的數,過程中要乘一個組合數。求出 g g 之後就可以計算答案了,答案就是 i = 0 m g [ i ] j = 0 & j % 4 = 0 m C i j 2 m i \sum_{i=0}^mg[i]*\sum_{j=0\&j\%4=0}^mC_{i}^j*2^{m-i} 。原因是你在一個 i i 個操作裡選任意4的倍數個操作都會是合法的。

感覺還是很神仙的一道題。

程式碼:

#include <bits/stdc++.h>
using namespace std;

int n,m,a[30];
long long f[30],g[30],c[51][51],ans;
const long long mod=998244353;
inline long long ksm(long long x,long long y)
{
	long long res=1;
	while(y)
	{
		if(y&1)
		res=res*x%mod;
		x=x*x%mod;
		y>>=1;
	}
	return res;
} 
inline long long gcd(long long x,long long y)
{
	return y?gcd(y,x%y):x;
}
int main()
{
	scanf("%d%d",&n,&m);
	for(int i=1;i<=m;++i)
	scanf("%d",&a[i]);
	for(int i=0;i<=50;++i)
	c[i][0]=1;
	for(int i=1;i<=50;++i)
	{
		for(int j=1;j<=i;++j)
		c[i][j]=(c[i-1][j-1]+c[i-1][j])%mod;
	}
	int mx=(1<<m);
	for(int i=0;i<mx;++i)
	{
		int ji=0;
		long long lcm=1;
		for(int j=1;j<=m;++j)
		{
			if(i&(1<<(j-1)))
			{
				lcm=lcm*a[j]/gcd(lcm,a[j]);
				if(lcm>n)
				break;
				++ji;
			}
		}
		if(lcm>n)
		continue;
		f[ji]=(f[ji]+n/lcm)%mod;
	}	
	for(int i=0;i<=m;++i)
	g[i]=f[i];
	for(int i=m;i>=0;--i)
	{
		for(int j=i+1;j<=m;++j)
		g[i]=(g[i]-g[j]*c[j][i]%mod+mod)%mod;
	}
	for(int i=0;i<=m;++i)
	ans=(ans+g[i]*((c[i][0]+c[i][4]+c[i][8]+c[i][12]+c[i][16]+c[i][20])%mod)%mod*ksm(2,m-i)%mod)%mod;
	ans=ans*ksm(ksm(2,m),mod-2)%mod;
	printf("%lld\n",ans);
	return 0;
}