1. 程式人生 > >bzoj4518 [Sdoi2016]征途(斜率優化dp)

bzoj4518 [Sdoi2016]征途(斜率優化dp)

分析:
斜率優化dp
很多人做斜率優化的時候喜歡畫出斜率
我偏向畫柿子

題目就是把若干個元素分成m份,
每一份的價值是該組中的元素之和
使得m組數的方差最小

平均數:x=(sum[1]+sum[2]+..+sum[m])/m //sum是每一組的價值
x=(Σa[i])/m
方差:s=((sum[1]-x)^2+(sum[2]-x)^2+…+(sum[m]-x)^2)/m

先想狀態轉移方程:
設 f[i][j]表示到第i個點,分成了j段
f[i][k]=f[j][k-1]+(sum[i]-sum[j]-x)^2
sum是字首和

這樣的話就簡單了,這就是一個斜率優化的模板

我們已經說過了:
x=(Σa[i])/m
s=((sum[1]-x)^2+(sum[2]-x)^2+…+(sum[m]-x)^2)/m
最後答案是s*m^2

帶入s的表示式得:
這裡寫圖片描述

ans=Σ(sum[i]-sum[j])^2*m-sum[n]^2

只要把狀態轉移方程f[i][k]=f[j][k-1]+(sum[i]-sum[j]-x)^2
變成f[i][k]=f[j][k-1]+(sum[i]-sum[j])^2 即可

之後就是畫柿子

斜率優化

我們假設k < j < i
如果j的決策比k的決策要好
則有
這裡寫圖片描述

左邊那一大坨是一個斜率的形式,
我們可以用ta來優化了

設g[j][k]=(那一大坨式子)

g[j][k] < sum[i]
表示j比k更優

現在關鍵來了
設k < j < i
如果g[i][j] < g[j][k],那麼j點永遠不可能成為最優轉移點

解釋一下
我們假設g[i][j] < sum[i],那麼也就是說i要比j優,排除j

若g[i][j]>=sum[i] 也就是說j比i要優
但是g[i][k]>g[i][j],說明k比j還要優

接下來看看怎麼找最優解
設k < j < i
我們排除了g[i][j] < g[j][k]
則整個有效點集呈現一種上凸性質,即k <=> j的斜率要大於j <=> i的斜率

做法可以總結如下:
1.用一個單調佇列來維護解集
2.假設佇列中從頭到尾已經有元素a,b,c
那麼當d要入隊的時候,我們維護佇列的上凸性質,
即如果g[d][c] < g[c][b],那麼就將c點刪除
直到找到g[d][x]>=g[x][y]為止,並將d點加入在該位置中
3.求解時候,從隊頭開始,如果已有元素a,b,c,
當i點要求解時,如果g[b][a] < sum[i],那麼說明b點比a點更優,a點可以排除,於是a出隊
直到g[x][y]>=sum[i]
當前點就從y轉移

最終答案:

min(f[n])*m-sum[n]^2

tip

我為什麼要做斜率優化!!!
我討厭式子

這是一個二維的方程,所以我們需要另一個數組記錄上一層的狀態,輔助dp
在計算斜率的轉移的時候我們都要用上一層的狀態

我和學姐對式子的處理方式不一樣,
我超虛的,然而我一A了!!!

這裡寫程式碼片
#include<cstdio>
#include<iostream>
#include<cstring>
#include<cmath>
#define ll long long

using namespace std;

const ll INF=1e16;
int n,m; 
ll a[3010];
ll f[3010],g[3010],q[3010],tou,wei,x=0;

ll sqr(ll x)
{
    return x*x;
}

double get(int j,int k)
{
    return (double)(g[j]+sqr(a[j])-g[k]-sqr(a[k]))/(double)(2*(a[j]-a[k]));
}

void doit()
{
    int i,j;
    ll ans=INF;
    tou=wei=1;
    for (int i=1;i<=n;i++) g[i]=INF;   //記錄上一層狀態 
    g[0]=0;
    for (i=1;i<=m;i++)
    {
        tou=wei=0;
        for (j=1;j<=n;j++)
        {
            while (tou<wei&&get(q[tou+1],q[tou])<a[j]) tou++;
            f[j]=g[q[tou]]+sqr(a[j]-a[q[tou]]);
            while (tou<wei&&get(j,q[wei])<get(q[wei],q[wei-1])) wei--;
            q[++wei]=j;
        }
        ans=min(ans,f[n]);
        for (int j=1;j<=n;j++) g[j]=f[j];
    }
    printf("%lld",m*ans-sqr(a[n]));
}

int main()
{
    scanf("%d%d",&n,&m);
    for (int i=1;i<=n;i++) scanf("%lld",&a[i]),a[i]+=a[i-1];
    doit();
    return 0;
}