1. 程式人生 > >位元組跳動冬令營網路賽 D.The Easiest One(貪心 數位DP)

位元組跳動冬令營網路賽 D.The Easiest One(貪心 數位DP)

題目連結

\(x:\ 11010011\)
\(y:\ 10011110\)
(下標是從高位往低位,依次是\(1,2,...,n\)
比如對於這兩個數,先找到最高的滿足\(x\)\(0\)\(y\)\(1\)的一位\(j\),顯然我們還要找比\(i\)高的最近的一位\(i\),滿足第\(i\)\(x\)\(1\)\(y\)\(0\)
然後我們要將\(x\)\(i\)之後的位上的\(1\)全變成\(0\),然後\(x\)-=\(1\),才能使得\(x\)\(j\)這一位為\(1\)。這樣\(x\)的第\(i\)位變為\(0\),之後全變為\(1\),顯然此時還需要的代價就是 \(n-i-\ y的後i位中1的個數\)


而之前的代價就是 \(x後i位中1的個數\)。再算上\(i\)及前\(i\)位的代價,記\(cnt(x)\)\(x\)的二進位制表示中\(1\)的個數,\(x\)變成\(y\)的總代價其實就是 \(cnt(x)-cnt(y)+n-i\)
這樣就可以數位DP了。
\(f[i][las][0/1][0/1][0/1]\)表示 當前考慮到第\(i\)位,最近的滿足\(x\)\(1\)\(y\)\(0\)的位是\(las\),是否處於上界,是否\(x\)已經大於\(y\)(要保證\(x\geq y\)),是否已統計過答案(只在最高的那位\(j\)統計答案),此時的總答案。
同樣還需要記\(g[i][las][0/1][0/1][0/1]\)
表示方案數,用來統計答案。
轉移時列舉\(x,y\)當前位填什麼就行了。
狀態數\(O(n^2)\)程式碼在這兒,博文最下面也有。


也有\(O(n)\)的解法(我用記搜寫了下,程式碼在這兒,博文最下面也有):
orz \(tangjz\)的題解後,發現\(las\)這一維很容易就省去了...
原本記\(las\)是為了計算\(n-i\)這部分貢獻,放到程式碼中就是\(ans\)+=\(g*(n-i)\)\(g\)是方案數)。
現在我們只記是否出現過\(i\)(其實是三進位制,分別表示未出現\(i,j\)、出現過\(i\)還沒出現\(j\)\(i,j\)都出現了),如果出現過\(i\)

,轉移的時候就\(ans\)+=\(g\)
這樣對於\(i\)\(g\)的貢獻還是會算\(n-i\)次。
這樣複雜度就是\(O(n)\)了。


解釋一下\(tangjz\)的程式碼。。
寫的從低位往高位的遞推,每次列舉當前要計算的位\(i\),然後列舉\(j,k\)\(j\)是上界,\(k\)是那個三進位制),據此列舉(要滿足這個狀態)第\(i\)\(x,y\)\(0\)還是\(1\)
因為是從低位往高位(要注意轉移對狀態),所以\(k=2\)是說還沒有找到\(v\)\(k=1\)是指找到了\(v\)但沒有確定\(u\)\(k=0\)是指\(u,v\)都已經確定了。
那麼看轉移,\(j\)這維和記搜寫法是一樣的...(其實哪一維和記搜都差不多)。
下面忽略\(j\)這一維,只考慮\(k\)。為了不混用,下面的\(u,v\)就是最上面\(x\)變成\(y\)這一過程中的\(i,j\),代價還是\(cnt[x]+cnt[y]+n-u\)
\(k=0\)\(f[i][0]\)可以\(f[i-1][0],f[i-1][1]\)轉移。但要從\(f[i-1][1]\)轉移就是令第\(i\)位為\(u\),也就是需滿足\(x=1,y=0\),此時既可以\(f[i-1][0]\)也可以從\(f[i-1][1]\)轉移。
\(k=1\)\(f[i][1]\)可以\(f[i-1][1],f[i-1][2]\)轉移。同理要從\(f[i-1][2]\)轉移就是令第\(i\)位為\(v\),也就是要滿足\(x=0,y=1\),且此時只能從\(f[i-1][2]\)轉移。
上面兩個轉移雖然形式像但是不同。
\(k=2\)\(f[i][2]\)只能從\(f[i-1][2]\)轉移。
\(k=1\)\(2\)時,也就是\(u\)還沒出現,此時\(ans\)+=\(g\)(每次答案都累加一次方案數),這樣就能統計\(n-u\)這個答案了。
還不懂的話看程式碼吧...感覺說不太清楚啊QAQ


\(O(n^2)\)

//0.68s 16.9MB
#include <cstdio>
#include <cstring>
#include <algorithm>
#define mp std::make_pair
#define pr std::pair<int,int>
#define mod 1000000007
#define Add(x,v) (x+=v)>=mod&&(x-=mod)
typedef long long LL;
const int N=505;

int A[N];
pr f[N][N][8];
char s[N];

pr DFS(int x,int las,int s)//s:lim,f1,f2 f1:是否已x>y f2:是否已統計過答案 
{
    if(!x) return mp(0,1);
    if(f[x][las][s].first!=-1) return f[x][las][s];
    LL res1=0,res2=0;
    int lim=s&1, f1=s>>1&1, f2=s>>2&1;
    for(int i=0; i<2; ++i)
        for(int j=0; j<2; ++j)
        {
            if(lim && i>A[x]) continue;
            if(!f1 && i<j) continue;
            pr v=DFS(x-1,i>j?x:las,(lim&&i==A[x])|((f1||i>j)<<1)|((f2||i<j)<<2));
            res1+=v.first, res2+=v.second;
            if(i) res1+=v.second;
            if(j) res1+=mod-v.second;
            if(!f2 && i<j) res1+=1ll*v.second*(las-1)%mod;
        }
    return f[x][las][s]=mp(res1%mod,res2%mod);
}

int main()
{
//  freopen("yjqaa.in","r",stdin);
//  freopen("yjqaa.out","w",stdout);

    scanf("%s",s+1); int n=strlen(s+1);
    std::reverse(s+1,s+1+n);
    for(int i=1; i<=n; ++i) A[i]=s[i]-'0';
    for(int i=1; i<=n; ++i)
        for(int j=0; j<=n; ++j)
            f[i][j][0].first=f[i][j][1].first=f[i][j][2].first=f[i][j][3].first=
            f[i][j][4].first=f[i][j][5].first=f[i][j][6].first=f[i][j][7].first=-1;
//          f[i][j][0][0][0].first=f[i][j][0][0][1].first=f[i][j][0][1][0].first=f[i][j][0][1][1].first=
//          f[i][j][1][0][0].first=f[i][j][1][0][1].first=f[i][j][1][1][0].first=f[i][j][1][1][1].first=-1;
    printf("%d\n",DFS(n,0,1).first);

    return 0;
}

\(O(n)\)

//4ms   488KB
#include <cstdio>
#include <cstring>
#include <algorithm>
#define mp std::make_pair
#define pr std::pair<int,int>
#define mod 1000000007
#define Add(x,v) (x+=v)>=mod&&(x-=mod)
typedef long long LL;
const int N=505;

int A[N];
pr f[N][2][3];
char s[N];

void operator +=(pr &x,pr y)
{
    Add(x.first,y.first), Add(x.second,y.second);
}
pr DFS(int x,int lim,int s)//s:0/1/2
{
    if(!x) return s==1?mp(0,0):mp(0,1);
    if(f[x][lim][s].first!=-1) return f[x][lim][s];
    LL res1=0,res2=0;
    const int xL=0, xR=lim?A[x]:1;
    for(int i=xL; i<=xR; ++i)
    {
        const int yL=s==1?i:0, yR=!s?i:1;
        for(int j=yL; j<=yR; ++j)
        {
            pr v;
            if(!s)
            {
                v=DFS(x-1,lim&&i==xR,0);
                if(i && !j) v+=DFS(x-1,lim&&i==xR,1);
            }
            else if(s==1)
            {
                if(!i && j) v=DFS(x-1,lim&&i==xR,2);
                else v=DFS(x-1,lim&&i==xR,1);
            }
            else v=DFS(x-1,lim&&i==xR,2);
            res1+=v.first, res2+=v.second;
//          if(i) res1+=v.second;
//          if(j) res1+=mod-v.second;
//          if(s) res1+=v.second;
            res1+=(i-j+(s>0))*v.second;
        }
    }
    return f[x][lim][s]=mp(res1%mod,res2%mod);
}

int main()
{
//  freopen("yjqaa.in","r",stdin);
//  freopen("yjqaa.out","w",stdout);

    int T; scanf("%d",&T);
    while(T--)
    {
        scanf("%s",s+1); int n=strlen(s+1);
        std::reverse(s+1,s+1+n);
        for(int i=1; i<=n; ++i) A[i]=s[i]-'0';
        for(int i=1; i<=n; ++i)
            f[i][0][0].first=f[i][0][1].first=f[i][0][2].first=
            f[i][1][0].first=f[i][1][1].first=f[i][1][2].first=-1;
        printf("%d\n",DFS(n,1,0).first);
    }

    return 0;
}