1. 程式人生 > >2829 (斜率優化dp)

2829 (斜率優化dp)

思路:還是先列出普通的dp方程,dp[i][j]=min(dp[k][j-1]+s(i,k+1))  表示到前i個站點,用了j次爆炸,得到的最小值,其中s(i,k+1)表示從k+1到i的連乘積。

直接做必然是n^3 因為需要列舉k。考慮到dp[i][j],當j為定值,它隨著i單調遞增的,所以考慮使用斜率優化來做。

在分析之前先處理一下任意區間的連乘積問題:

有兩種解決方案: 1. 記錄字首和和 字首平方和 (具體原因分析前幾項就行)

                              2. 記錄字首和 和字首區間乘積 (意思是每新加入一個元素,都會增加 a[i]*(sum[i-1])項,sum[i-1]代表字首和)

下面就是一般的套路:

設y<k<i (還是那句話,考慮單調性不考慮定義域就是在刷流氓)

假定:dp[k][j-1] +c[i]- c[k] -sum[k]*(sum[i]-sum[k]) 優於 dp[y][j-1] +c[i] -c[y] -sum[y]*(sum[i] - sum[y]) 

其中 c[i]代表解決方案二中的 字首區間 乘積。

那麼必然滿足:dp[k][j-1] +c[i]- c[k] -sum[k]*(sum[i]-sum[k]) < dp[y][j-1] +c[i] -c[y] -sum[y]*(sum[i] - sum[y])  (因為我們求的是最小值)

化簡得:(dp[k][j-1] - c[k] +sum[k]*sum[k] -( dp[y][j-1] -c[y] +sum[y]*sum[y] ) )  / (sum[k] -sum[y])  < sum[i] 

觀察到sum[i]是單調遞增的,而前面的式子可以看做是平面上兩點構成的斜率,也就是說我們的斜率在不斷上升,而我們需要求出最小值。所以取到最值的點必然是下凸的。(這是從座標圖上來分析的)

從數值的角度分析:當我們已經知道 k比 y優的時候,y就已經完全沒有用處了,同時,當我們考慮將新點i加入時,也要考慮維護一個下凸值。也就是說 如果 y<k<i 如何 ky 的斜率大於 ik的斜率,那麼k點必然是沒有用的,具體分析可以參考前面的一篇部落格。

下面就來分析這題的具體實現:首先這是一個二維的dp,我們對不同的j需要維護不同的單調佇列 ,因為我們之前的分析,都是建立在j不變的情況下討論的。

下面的程式碼是採用記錄字首區間乘積實現的(記錄字首平方和大概是乘0.5導致精度問題,所以wa了,思想是沒問題的)

程式碼1:字首區間乘積(已AC)

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll inf=1e18;
const int maxn=1e3+7;
const int mod=1e9+7;
ll sum[maxn],sp[maxn];
ll dp[maxn][maxn];
int a[maxn];
int q[maxn];
int n,m;
ll solve(int k,int y,int j) //y<k,且k優於y
{
    if(sum[k]==sum[y])
    {
        if(dp[k][j-1]-sp[k]<dp[y][j-1]-sp[y])
        {
            return -1;
        }
        else
        {
            return inf;
        }
    }
    return ((dp[k][j-1]-sp[k]+sum[k]*sum[k])-(dp[y][j-1]-sp[y]+sum[y]*sum[y]))/
            (sum[k]-sum[y]);
}
int main()
{
    #ifndef ONLINE_JUDGE
        freopen("in.txt","r",stdin);
        freopen("out.txt","w",stdout);
    #endif
    while(scanf("%d%d",&n,&m)!=EOF&&(n+m))
    {
        sum[0]=sp[0]=0;
        for(int i=1;i<=n;i++)
        {
            scanf("%d",&a[i]);
            sum[i]=sum[i-1]+a[i];
            sp[i]=sp[i-1]+a[i]*sum[i-1];
        }
        int head,tail;
        // for(int j=1;j<=m;j++)
        // {
        //     dp[1][j]=a[1];
        // }
        for(int i=1;i<=n;i++)
        {
            dp[i][0]=sp[i];
        }
        for(int j=1;j<=m;j++)
        {
            head=tail=0;
            q[0]=j;
            for(int i=j+1;i<=n;i++)
            {
                while(head<tail&&solve(q[head+1],q[head],j)<sum[i])
                {
                    head++;
                }
                dp[i][j]=dp[q[head]][j-1]+sp[i]-sp[q[head]]-sum[q[head]]*(sum[i]-sum[q[head]]);
                while(head<tail&&solve(q[tail],q[tail-1],j)>solve(i,q[tail],j))
                {
                    tail--;
                } 
                q[++tail]=i;
            }
        }
        printf("%lld\n",dp[n][m]);
    }
    return 0;
}

程式碼2:字首平方和(WA了)

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll inf=1e18;
const int maxn=1e3+7;
const int mod=1e9+7;
ll sum[maxn],sp[maxn];
ll dp[maxn][maxn];
int a[maxn];
int q[maxn];
int n,m;
ll solve(int k,int y,int j) //y<k,且k優於y
{
    if(sum[k]==sum[y])
    {
        if(dp[k][j-1]+0.5*sp[k]<dp[y][j-1]+0.5*sp[y])
        {
            return -1;
        }
        else
        {
            return inf;
        }
    }
    return (dp[k][j-1]+(sum[k]*sum[k]+sp[k])/2-(dp[y][j-1]+(sum[y]*sum[y]+sp[y])/2))/
            (sum[k]-sum[y]);
}
int main()
{
    #ifndef ONLINE_JUDGE
        freopen("in.txt","r",stdin);
        freopen("out.txt","w",stdout);
    #endif
    while(scanf("%d%d",&n,&m)!=EOF&&(n+m))
    {
        sum[0]=sp[0]=0;
        for(int i=1;i<=n;i++)
        {
            scanf("%d",&a[i]);
            sum[i]=sum[i-1]+a[i];
            sp[i]=a[i]*a[i]+sp[i-1];
        }
        int head,tail;
        // for(int j=1;j<=m;j++)
        // {
        //     dp[1][j]=a[1];
        // }
        for(int i=1;i<=n;i++)
        {
            dp[i][0]=sum[i]*sum[i]-sp[i];
        }
        for(int j=1;j<=m;j++)
        {
            head=tail=0;
            q[0]=j;
            for(int i=j+1;i<=n;i++)
            {
                while(head<tail&&solve(q[head+1],q[head],j)<sum[i])
                {
                    head++;
                }
                dp[i][j]=dp[q[head]][j-1]+((sum[i]-sum[q[head]])*(sum[i]-sum[q[head]])-(sp[i]-sp[q[head]]))/2;
                while(head<tail&&solve(q[tail],q[tail-1],j)>solve(i,q[tail],j))
                {
                    tail--;
                } 
                q[++tail]=i;
            }
        }
        printf("%lld\n",dp[n][m]);
    }
    return 0;
}