bzoj4518 [Sdoi2016]征途(斜率優化dp)
分析:
斜率優化dp
很多人做斜率優化的時候喜歡畫出斜率
我偏向畫柿子
題目就是把若干個元素分成m份,
每一份的價值是該組中的元素之和
使得m組數的方差最小
平均數:x=(sum[1]+sum[2]+..+sum[m])/m //sum是每一組的價值
x=(Σa[i])/m
方差:s=((sum[1]-x)^2+(sum[2]-x)^2+…+(sum[m]-x)^2)/m
先想狀態轉移方程:
設 f[i][j]表示到第i個點,分成了j段
f[i][k]=f[j][k-1]+(sum[i]-sum[j]-x)^2
sum是字首和
這樣的話就簡單了,這就是一個斜率優化的模板
我們已經說過了:
x=(Σa[i])/m
s=((sum[1]-x)^2+(sum[2]-x)^2+…+(sum[m]-x)^2)/m
最後答案是s*m^2
帶入s的表示式得:
ans=Σ(sum[i]-sum[j])^2*m-sum[n]^2
只要把狀態轉移方程f[i][k]=f[j][k-1]+(sum[i]-sum[j]-x)^2
變成f[i][k]=f[j][k-1]+(sum[i]-sum[j])^2 即可
之後就是畫柿子
斜率優化
我們假設k < j < i
如果j的決策比k的決策要好
則有
左邊那一大坨是一個斜率的形式,
我們可以用ta來優化了
設g[j][k]=(那一大坨式子)
g[j][k] < sum[i]
表示j比k更優
現在關鍵來了
設k < j < i
如果g[i][j] < g[j][k],那麼j點永遠不可能成為最優轉移點
解釋一下
我們假設g[i][j] < sum[i],那麼也就是說i要比j優,排除j
若g[i][j]>=sum[i] 也就是說j比i要優
但是g[i][k]>g[i][j],說明k比j還要優
接下來看看怎麼找最優解
設k < j < i
我們排除了g[i][j] < g[j][k]
則整個有效點集呈現一種上凸性質,即k <=> j的斜率要大於j <=> i的斜率
做法可以總結如下:
1.用一個單調佇列來維護解集
2.假設佇列中從頭到尾已經有元素a,b,c
那麼當d要入隊的時候,我們維護佇列的上凸性質,
即如果g[d][c] < g[c][b],那麼就將c點刪除
直到找到g[d][x]>=g[x][y]為止,並將d點加入在該位置中
3.求解時候,從隊頭開始,如果已有元素a,b,c,
當i點要求解時,如果g[b][a] < sum[i],那麼說明b點比a點更優,a點可以排除,於是a出隊
直到g[x][y]>=sum[i]
當前點就從y轉移
最終答案:
min(f[n])*m-sum[n]^2
tip
我為什麼要做斜率優化!!!
我討厭式子
這是一個二維的方程,所以我們需要另一個數組記錄上一層的狀態,輔助dp
在計算斜率的轉移的時候我們都要用上一層的狀態
我和學姐對式子的處理方式不一樣,
我超虛的,然而我一A了!!!
這裡寫程式碼片
#include<cstdio>
#include<iostream>
#include<cstring>
#include<cmath>
#define ll long long
using namespace std;
const ll INF=1e16;
int n,m;
ll a[3010];
ll f[3010],g[3010],q[3010],tou,wei,x=0;
ll sqr(ll x)
{
return x*x;
}
double get(int j,int k)
{
return (double)(g[j]+sqr(a[j])-g[k]-sqr(a[k]))/(double)(2*(a[j]-a[k]));
}
void doit()
{
int i,j;
ll ans=INF;
tou=wei=1;
for (int i=1;i<=n;i++) g[i]=INF; //記錄上一層狀態
g[0]=0;
for (i=1;i<=m;i++)
{
tou=wei=0;
for (j=1;j<=n;j++)
{
while (tou<wei&&get(q[tou+1],q[tou])<a[j]) tou++;
f[j]=g[q[tou]]+sqr(a[j]-a[q[tou]]);
while (tou<wei&&get(j,q[wei])<get(q[wei],q[wei-1])) wei--;
q[++wei]=j;
}
ans=min(ans,f[n]);
for (int j=1;j<=n;j++) g[j]=f[j];
}
printf("%lld",m*ans-sqr(a[n]));
}
int main()
{
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++) scanf("%lld",&a[i]),a[i]+=a[i-1];
doit();
return 0;
}