1. 程式人生 > >牛客練習賽28 D 隨風飄(dp + 字串雜湊)

牛客練習賽28 D 隨風飄(dp + 字串雜湊)

能用字串雜湊解決的問題,千萬別用字尾陣列、字典樹什麼的了……

這題有很多個詢問,每次詢問是從n箇中拿走k個字串,問拿走之後的答案。我們顯然不能把所有拿走的方案列舉一遍,所以考慮計算每一個字串的貢獻。這裡我的貢獻指第i個字串與它前面的字串的貢獻。而這個貢獻就是計算當前串與前面所有串的lcp。這裡千萬不要看到lcp就去想字尾陣列,這裡是多個串的lcp,而不是一個串的lcp,所以後綴陣列顯然是麻煩了。字典樹是一個好的選擇,但是字串雜湊顯然更簡單。對於每個串不斷的把它所有字首的雜湊值記錄下來,然後每個串的貢獻,就可以看它前綴出現的次數。

然後我們考慮dp。令dp[i][j]表示考慮前i個串,從中取走j個串的答案。那麼可以分為第i個串取走或者不被取走,有:

                  \large dp[i][j]=dp[i-1][j-1]+(dp[i-1][j]+C_{i-2}^{j}*f[i])*[i-j>2]

前面部分表示第i個串取走,後面部分表示第i個串不取走。不取走的話就是前面i-1個串取走j個的答案,加上第i個串對前面產生的貢獻。出現這種情況的次數是 \large C_{i-1}^{j} ,第i個串對前面產生的單次總貢獻是f[i],這樣合起來就是\large C_{i-1}^{j}*f[i],但是由於有j個串被刪除了所以得去掉一些。對於去掉的部分,我們再次考慮貢獻。如果第k個串被刪掉,那麼多加的答案就是\large C_{i-2}^{j-1}*lcp(i,k)那麼總的就是:

                          \large \sum_{k=1}^{i-1}C_{i-2}^{j-1}*lcp(i,k)=C_{i-2}^{j-1}*\sum_{k=1}^{i-1}lcp(i,k)=C_{i-2}^{j-1}*f[i]

所以說最後就是 \large C_{i-1}^{j}*f[i]-C_{i-2}^{j-1}*f[i]=C_{i-2}^j*f[i]。dp求解,然後對於每個詢問O(1)輸出即可。最後你會發現,答案其實就是dp[n][0]*C(n-2,j)。具體見程式碼:

#include<bits/stdc++.h>
#define LL long long
#define pb push_back
#define lb lower_bound
#define ub upper_bound
#define INF ((1LL<<31)-1)
#define PI 3.1415926535
#define sf(x) scanf("%d",&x)
#include<ext/pb_ds/tree_policy.hpp>
#include<ext/pb_ds/hash_policy.hpp>
#include<ext/pb_ds/priority_queue.hpp>
#include<ext/pb_ds/assoc_container.hpp>
#define sc(x,y,z) scanf("%d%d%d",&x,&y,&z)
#define clr(x,n) memset(x,0,sizeof(x[0])*(n+5))
#define file(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
 
using namespace std;
using namespace __gnu_pbds;
 
const int N = 4010;
const int mod1 = 100001651;
const int mod2 = 100001623;
const int mod = 1000000007;
  
LL dp[N][310],c[N][N],f[N];
gp_hash_table<LL,int> t; 
char s[3000010];
int n,q;
 
LL add(LL pre1,LL pre2,int cur)
{
    pre1=(pre1*111+cur)%mod1;
    pre2=(pre2*100007+cur)%mod2;
    return pre1<<31|pre2;
}
  
int main()
{
    sf(n); sf(q);
    for(int i=1;i<=n;i++)
    {
        LL now=0;
        scanf("%s",s);
        for(int j=0;s[j];j++)
        {
            now=add(now>>31,now&INF,s[j]-'a');
            f[i]=(f[i]+t[now])%mod2; t[now]++;
        }
    }
    for(int i=0;i<=n;i++) c[i][0]=1;
    for(int i=1;i<=n;i++)
        for(int j=1;j<=i;j++)
            c[i][j]=(c[i-1][j-1]+c[i-1][j])%mod;
    for(int i=1;i<=n;i++)
        for(int j=0;j<=min(i,300);j++)
        {
            if (j) dp[i][j]=dp[i-1][j-1];
            if (i-j>=2) 
            {
                dp[i][j]=(dp[i][j]+dp[i-1][j])%mod;
                dp[i][j]=(dp[i][j]+c[i-2][j]*f[i]%mod)%mod;
            }
        }
    while(q--)
    {
        int x; sf(x);
        printf("%lld\n",dp[n][x]);
    }
    return 0;
}