1. 程式人生 > >【洛谷11月月賽T3】【P4996】咕咕咕(組合數)

【洛谷11月月賽T3】【P4996】咕咕咕(組合數)

遲到的題解 昨天亂翻的時候感覺這道題挺有意思的

一眼看過去狀態壓縮亂搜

轉移方程大概是 設f[i]表示從0轉移到i的遺憾值之和 f[i]=sigma(f[j])+val[i]*dis[i] dis[i]=sigma(dis[j]) 其中j是i的子集,dis[i]表示從0轉移到i的方案數之和

妙啊有70分了

// luogu-judger-enable-o2
// luogu-judger-enable-o2
#include<bits/stdc++.h>
#define N 20
#define mod 998244353
#define ll long long
using namespace std;
int n,m;
ll val[1<<(N+1)],dp[1<<(N+1)],dis[1<<(N+1)];
char s[N+5];
inline ll dfs(int state)
{
    if(state==0)    return val[state];
    if(dp[state]!=-1)   return dp[state];
    dp[state]=0;
    for(int i=state&(state-1);i;i=(i-1)&state)
    {
        dp[state]=(dp[state]+dfs(i))%mod;
        dis[state]=(dis[state]+dis[i])%mod;
    }
    dp[state]=(dp[state]+dfs(0))%mod;    //因為上面運算元集沒有算到0
    dis[state]=(dis[state]+dis[0])%mod;
    dp[state]=(dp[state]+val[state]*dis[state]%mod)%mod;
    return dp[state];
}
int main()
{
    cin>>n>>m;
    for(int i=1;i<=m;i++)
    {
        scanf("%s",s+1);
        ll sum=0;
        for(int j=1;j<=n;j++)
             if(s[j]=='1') 
                sum+=(1<<(n-j));
        cin>>val[sum];
    }
    int max_state=(1<<n)-1;
    dis[0]=1;
    memset(dp,-1,sizeof(dp));
    dfs(max_state);
    cout<<dp[max_state];
    return 0;
}

然而正解是組合數

分析上面的過程你會發現:要是能夠快速的算出從0到當前狀態的方案數就好了

然後又發現和具體的狀態沒有什麼關係 之和0,1的個數有關係

記f[i]表示i個1的方案數 先預處理出來

顯然f[i]可以從子集j轉移過來 可以用到組合數的思想 即sigma(c(i,j)*f[j])

然後對於每個可能帶來歉意的狀態 根據一個乘法原理 相當於先把這n個0變為i個1 再把i個1變為n個1

#include<bits/stdc++.h>
#define N 25
#define mod 998244353
#define ll long long
using namespace std;
int n,m;
ll ans,c[N+5][N+5],f[N];
char s[N];
void init()
{
    for(int i=0;i<=N;i++)   c[i][0]=1,c[i][i]=1;
    for(int i=1;i<=N;i++)
    {
        for(int j=1;j<=i;j++)
        {
            c[i][j]=(c[i-1][j]+c[i-1][j-1])%mod;
        }
    }
    f[1]=1,f[0]=1;
    for(int i=2;i<=20;i++)
        for(int j=0;j<i;j++)
            f[i]=(f[i]+c[i][j]*f[j]%mod)%mod;
}
int main()
{
    cin>>n>>m;
    init();
    for(int i=1;i<=m;i++)
    {
        scanf("%s",s+1);
        ll cnt=0;
        for(int j=1;j<=n;j++)
             if(s[j]=='1') cnt++;
        ll x;
        cin>>x;
        ans=(ans+x*f[cnt]%mod*f[n-cnt]%mod)%mod; 
    }
    cout<<ans;
    return 0;
}