1. 程式人生 > >CF193D Two Segments (線段樹+dp)(外加兩個擴充套件題)

CF193D Two Segments (線段樹+dp)(外加兩個擴充套件題)

大概算是個系列整理

(最強版是模擬賽原題))

首先,我們先來看這個題目。

QWQ一開始是毫無頭緒,除了列舉就是列舉

首先,我們可以列舉一個右端點,然後算一下當前右端點的答案

我們令\(f[l,r]\)表示\(a_l到a_r\)這些數,能夠最少劃分成幾段連續的數。

顯然,我們要求的是以每個端點為右端點,\(f值<=2的\)
QWQ那麼這個玩意應該怎麼維護+更新呢

考慮右端點移動,會造成什麼後果。

我們令新擴充套件的位置是\(r\),數是\(x\),他的前驅的位置是\(pre\),後繼的位置是\(last\)

\(比如3的前驅就是2,後繼是4\)

首先,\(f[1,r]....f[r,r]\)

都會加1,不考慮和之前的數能合併的情況下,他自己就需要一段來完成

如果\(pre\)在當前位置的前面,那麼\(f[1..r]....[pre,r]\)應該要-1,因為所有從pre之前出發的左端點,新的數可以和前驅的合併,就可以減少一段

那麼如果\(last\)前面,也是同理的。

所以,我們需要一個支援區間維護最小值,最小值個數,次小值,次小值個數,還支援區間加和減的一個數據結構

線段樹!

這裡有幾個要注意的地方就是:

1.維護的是嚴格的最小值和次小值,也就是說兩個值不能相同,所以\(up\)的時候,會有一些小細節

void up(int root)
{
 if (f[2*root].mn<f[2*root+1].mn)
 {
  f[root].mn=f[2*root].mn;
  f[root].cimn=min(f[2*root].cimn,f[2*root+1].mn);
 }
 else
 {
  if (f[2*root].mn>f[2*root+1].mn)
  {
    f[root].mn=f[2*root+1].mn;
    f[root].cimn=min(f[2*root].mn,f[2*root+1].cimn);
     }
     else
     {
      f[root].mn=min(f[2*root].mn,f[2*root+1].mn);
      f[root].cimn=min(f[2*root].cimn,f[2*root+1].cimn);
  }
 }
 if(f[root].cimn==f[root].mn) f[root].cimn=1e9; 
 f[root].sum1=f[2*root].sum1*(f[2*root].mn==f[root].mn)+f[2*root+1].sum1*(f[2*root+1].mn==f[root].mn);
 f[root].sum2=f[2*root].sum2*(f[2*root].cimn==f[root].cimn)+f[2*root+1].sum2*(f[2*root+1].cimn==f[root].cimn && f[root].cimn!=1e9);
 f[root].sum2+=f[2*root].sum1*(f[2*root].mn==f[root].cimn)+f[2*root+1].sum1*(f[2*root+1].mn==f[root].cimn);
}

2.\(query\)由於我們要求的是\(<=2\)的值的個數,所以求和的時候,要注意滿足\(mn<=2\)\(cimn==2\))

long long query(int root,int l,int r,int x,int y)
{
 if (x<=l && r<=y)
 {
  //cout<<l<<" "<<r<<endl;
  //cout<<f[root].mn<<" "<<f[root].cimn<<endl;
  return f[root].sum1*(f[root].mn<=2) + f[root].sum2*(f[root].cimn!=f[root].mn && f[root].cimn<=2);
 }
 int mid = l+r >> 1;
 long long ans=0;
 pushdown(root,l,r);
 if (x<=mid) ans=ans+query(2*root,l,mid,x,y);
 if (y>mid) ans=ans+query(2*root+1,mid+1,r,x,y);
 return ans;
}

QWQ大概就是這樣了?

具體直接看程式碼吧

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>
#include<map>
#include<set>
#define mk makr_pair
#define ll long long
#define int long long
using namespace std;
inline int read()
{
  int x=0,f=1;char ch=getchar();
  while (!isdigit(ch)) {if (ch=='-') f=-1;ch=getchar();}
  while (isdigit(ch)) {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
  return x*f;
}
const int maxn= 4e5+1e2;
struct Node
{
 int mn,cimn;
 int sum1,sum2;
};
Node f[4*maxn];
int add[4*maxn];
int n,m;
int a[maxn],b[maxn];
long long ans;
void up(int root)
{
 if (f[2*root].mn<f[2*root+1].mn)
 {
  f[root].mn=f[2*root].mn;
  f[root].cimn=min(f[2*root].cimn,f[2*root+1].mn);
 }
 else
 {
  if (f[2*root].mn>f[2*root+1].mn)
  {
    f[root].mn=f[2*root+1].mn;
    f[root].cimn=min(f[2*root].mn,f[2*root+1].cimn);
     }
     else
     {
      f[root].mn=min(f[2*root].mn,f[2*root+1].mn);
      f[root].cimn=min(f[2*root].cimn,f[2*root+1].cimn);
  }
 }
 if(f[root].cimn==f[root].mn) f[root].cimn=1e9; 
 f[root].sum1=f[2*root].sum1*(f[2*root].mn==f[root].mn)+f[2*root+1].sum1*(f[2*root+1].mn==f[root].mn);
 f[root].sum2=f[2*root].sum2*(f[2*root].cimn==f[root].cimn)+f[2*root+1].sum2*(f[2*root+1].cimn==f[root].cimn && f[root].cimn!=1e9);
 f[root].sum2+=f[2*root].sum1*(f[2*root].mn==f[root].cimn)+f[2*root+1].sum1*(f[2*root+1].mn==f[root].cimn);
}
void pushdown(int root,int l,int r)
{
 if (add[root])
 {
  add[2*root]+=add[root];
  add[2*root+1]+=add[root];
  f[2*root].mn+=add[root];
  f[2*root].cimn+=add[root];
  f[2*root+1].mn+=add[root];
  f[2*root+1].cimn+=add[root];
  add[root]=0;
 }
}
void build(int root,int l,int r)
{
 if(l==r)
 {
  f[root].sum1=1;
  f[root].sum2=0;
  f[root].cimn=1e9;
  return;
 }
 int mid = l+r >> 1;
 build(2*root,l,mid);
 build(2*root+1,mid+1,r);
 up(root);
}
void update(int root,int l,int r,int x,int y,int p)
{
 if (x<=l && r<=y)
 {
  add[root]+=p;
  f[root].mn+=p;
  f[root].cimn+=p;
  return;
 }
 int mid = l+r >> 1;
 pushdown(root,l,r);
 if(x<=mid) update(2*root,l,mid,x,y,p);
 if(y>mid) update(2*root+1,mid+1,r,x,y,p);
 up(root);
}
long long query(int root,int l,int r,int x,int y)
{
 if (x<=l && r<=y)
 {
  //cout<<l<<" "<<r<<endl;
  //cout<<f[root].mn<<" "<<f[root].cimn<<endl;
  return f[root].sum1*(f[root].mn<=2) + f[root].sum2*(f[root].cimn!=f[root].mn && f[root].cimn<=2);
 }
 int mid = l+r >> 1;
 long long ans=0;
 pushdown(root,l,r);
 if (x<=mid) ans=ans+query(2*root,l,mid,x,y);
 if (y>mid) ans=ans+query(2*root+1,mid+1,r,x,y);
 return ans;
} 
signed main()
{
  n=read();
  for (int i=1;i<=n;i++) a[i]=read();
  for (int i=1;i<=n;i++) b[a[i]]=i;
  build(1,1,n); 
 // update(1,1,n,1,3,1);
  //update(1,1,n,1,2,1);
  //cout<<query(1,1,n,1,3)<<endl;
  //return 0;
  for (int i=1;i<=n;i++)
  {
   int x = a[b[i]-1],y=a[b[i]+1];
   update(1,1,n,1,i,1);
   if (x && x<i) update(1,1,n,1,x,-1);
   if (y && y<i) update(1,1,n,1,y,-1);
   ans=ans+query(1,1,n,1,i);
   //cout<<ans<<endl;
  }  
  cout<<ans-n<<endl;
  return 0;
}

QWQ嚶嚶嚶

那麼如果換一種問法,應該怎麼辦呢?

給定一個你長度為\(n\)的序列,然後求出來有多少個區間滿足最大值減去最小值等於區間長度-1

其實和上個題目差不多了啦。

只不過我們只需要維護最小值,然後求和的時候,只需要滿足最小值等於1即可

直接給程式碼(只呈現關鍵部分的)

void up(int root)
{
 g[root].mn=min(g[2*root].mn,g[2*root+1].mn);
 g[root].ans=g[2*root].ans*(g[2*root].mn==g[root].mn)+g[2*root+1].ans*(g[2*root+1].mn==g[root].mn);  
}
void pushdown(int root,int l,int r)
{
 if (add[root])
 {
  add[2*root]+=add[root];
  add[2*root+1]+=add[root];
  g[2*root].mn+=add[root];
  g[2*root+1].mn+=add[root];
  add[root]=0;
 }
}
void build(int root,int l,int r)
{
 if (l==r)
 {
  g[root].ans=1;
  return;
 }
 int mid = l+r >> 1;
    build(2*root,l,mid);
 build(2*root+1,mid+1,r);
 up(root); 
}
void update(int root,int l,int r,int x,int y,int p)
{
 if(x<=l && r<=y)
 {
  g[root].mn+=p;
  add[root]+=p;
  return;
 }
 pushdown(root,l,r);
 int mid = l+r >> 1;
 if (x<=mid) update(2*root,l,mid,x,y,p);
 if (y>mid) update(2*root+1,mid+1,r,x,y,p);
 up(root);
}
long long query(int root,int l,int r,int x,int y)
{
 if(x<=l && r<=y)
 {
  return g[root].ans*(g[root].mn==1); 
 }
 pushdown(root,l,r);
 int mid = l+r >> 1;
 long long ans=0;
 if (x<=mid) ans=ans+query(2*root,l,mid,x,y);
 if (y>mid) ans=ans+query(2*root+1,mid+1,r,x,y);
 return ans;
}

既然都做到這個程度了,不如就更毒瘤 一點

現在給定你一顆n個點的樹,每個點都有一個編號,每條邊的長度都是1,讓你求有多少條路經滿足最大編號-最小編號等於路徑長度。

woc上樹了....那該怎麼做啊?

是不是可以考慮和序列上的相類似呢?

我們不妨對每個點維護一個\(dfn[x]\)表示這個點的\(dfs\)序。

然後依次列舉dfs序上的每個點,計算dfs序從1到當前點之前的所有點到當前點的合法路徑條數。

類比序列

對於當前點來說,首先我們要讓1到當前點之前所有的路徑都+1,然後呢。
我們考慮前驅和後繼的位置

這裡需要討論一個是否是祖先的關係(因為畫個圖就能發現,如果是祖先,那麼這個點會影響的路徑起點是1到\(dfn[x]\),不然就是他的子樹內的所有點)

後繼同樣是如此

而且在計算完每個兒子的時候,記得加上當前兒子對其他兒子的貢獻。

然後最後記得把一個點的貢獻都還原,因為我們需要計算別的答案,而當前點就會變成起點之一,那麼他作為終點的貢獻,就是要去掉的。

QWQ有一些細節寫到程式碼裡面了

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>
#include<map>
#include<set>
#define mk makr_pair
#define ll long long
#define int long long
using namespace std;
inline int read()
{
  int x=0,f=1;char ch=getchar();
  while (!isdigit(ch)) {if (ch=='-') f=-1;ch=getchar();}
  while (isdigit(ch)) {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
  return x*f;
}
const int maxn = 1e5+1e2;
const int maxm =  2*maxn;
int point[maxn],nxt[maxm],to[maxm];
int cnt,n,m;
int dfn[maxn];
int ans;
void addedge(int x,int y)
{
 nxt[++cnt]=point[x];
 to[cnt]=y;
 point[x]=cnt;
}
struct Node{
 int mn,ans;
 int len;
};
Node g[4*maxn];
int add[4*maxn];
void up(int root)
{
 g[root].mn=min(g[2*root].mn,g[2*root+1].mn);
 g[root].ans=g[2*root].ans*(g[2*root].mn==g[root].mn)+g[2*root+1].ans*(g[2*root+1].mn==g[root].mn);  
}
void pushdown(int root,int l,int r)
{
 if (add[root])
 {
  add[2*root]+=add[root];
  add[2*root+1]+=add[root];
  g[2*root].mn+=add[root];
  g[2*root+1].mn+=add[root];
  add[root]=0;
 }
}
void build(int root,int l,int r)
{
 if (l==r)
 {
  g[root].ans=1;
  return;
 }
 int mid = l+r >> 1;
    build(2*root,l,mid);
 build(2*root+1,mid+1,r);
 up(root); 
}
void update(int root,int l,int r,int x,int y,int p)
{
 if (x>y) return;
 if(x<=l && r<=y)
 {
  g[root].mn+=p;
  add[root]+=p;
  return;
 }
 pushdown(root,l,r);
 int mid = l+r >> 1;
 if (x<=mid) update(2*root,l,mid,x,y,p);
 if (y>mid) update(2*root+1,mid+1,r,x,y,p);
 up(root);
}
long long query(int root,int l,int r,int x,int y)
{
 if (x>y) return 0;
 if(x<=l && r<=y)
 {
  return g[root].ans*(g[root].mn==1); 
 }
 pushdown(root,l,r);
 int mid = l+r >> 1;
 long long ans=0;
 if (x<=mid) ans=ans+query(2*root,l,mid,x,y);
 if (y>mid) ans=ans+query(2*root+1,mid+1,r,x,y);
 return ans;
}
int deep[maxn];
int f[maxn][21];
int a[maxn],b[maxn];
int size[maxn];
int tot;
int maxdfn[maxn]; //表示已經計算過的兒子的子樹裡面的最大的dfs序 
void dfs(int x,int fa,int dep)
{
 deep[x]=dep;
 dfn[x]=++tot;
 size[x]=1;
 for (int i=point[x];i;i=nxt[i])
 {
  int p = to[i];
  if (p==fa) continue; 
  f[p][0]=x;
  dfs(p,x,dep+1);
  size[x]+=size[p];
 }
}
void init()
{
 for (int j=1;j<=20;j++)
   for (int i=1;i<=n;i++)
     f[i][j]=f[f[i][j-1]][j-1];
}
int go_up(int x,int d)
{
 for (int i=0;i<=20;i++)
 if ((1<<i)&d) x=f[x][i];
 return x;
}
bool check(int x,int fa)
{
 if (fa==0 || fa==n+1) return 0;
 if(deep[x]<=deep[fa]) return 0;
 if (go_up(x,deep[x]-deep[fa])==fa) return 1;
 else return 0;
}
int dp(int x,int fa)//我們對於每個點,計算的是 dfs序上[i,r]的合法路徑條數 
{
 maxdfn[x]=dfn[x];
    update(1,1,n,1,dfn[x],1); //首先把之前的全部+1 
 if (dfn[x-1]<dfn[x] && x!=1)
 {
    if (check(x,x-1))
      update(1,1,n,1,maxdfn[x-1],-1); //相當於這些點都是從x-1到達x,(相當於除去這個子樹外所有的dfs小於當前點的點)所以應該-1,因為可以合併 
          else
      update(1,1,n,dfn[x-1],dfn[x-1]+size[x-1]-1,-1); //如果不是祖先關係,那麼從x-1到達x的路徑,一定是從他的子樹裡面出發的 (而且子樹內的任何一個點的dfs序一定都在當前點之前) 
 }
 if (dfn[x+1]<dfn[x] && x!=n)
 {
    if (check(x,x+1))
      update(1,1,n,1,maxdfn[x+1],-1);  
          else
      update(1,1,n,dfn[x+1],dfn[x+1]+size[x+1]-1,-1);  
 }
 ans=ans+query(1,1,n,1,dfn[x]);
 for (int i=point[x];i;i=nxt[i])
 {
  int p = to[i];
  if (p==fa) continue;
     int now = dp(p,x);
     update(1,1,n,maxdfn[x]+1,now,1); //處理兒子之間的影響  (因為計算一個點的代價的時候,不僅有祖先或者是別的子樹的,還要計算兄弟的) 
     if (x!=1 && (p==x-1 || check(x-1,p))) update(1,1,n,dfn[x-1],dfn[x-1]+size[x-1]-1,-1); //如果x-1在當前的兒子裡面,那麼他那個子樹裡到後面的點的代價就可以-1(理解成能夠合併) 
     if (x!=n && (p==x+1 || check(x+1,p))) update(1,1,n,dfn[x+1],dfn[x+1]+size[x+1]-1,-1);
     maxdfn[x]=now; 
 }
 if (dfn[x-1]<dfn[x] && x!=1)
 {
    if (check(x,x-1))
      update(1,1,n,1,maxdfn[x-1],1);  
          else
      update(1,1,n,dfn[x-1],dfn[x-1]+size[x-1]-1,1);  
 }
 if (dfn[x+1]<dfn[x] && x!=n)
 {
    if (check(x,x+1))
      update(1,1,n,1,maxdfn[x+1],1);  
          else
      update(1,1,n,dfn[x+1],dfn[x+1]+size[x+1]-1,1);  
 }
 update(1,1,n,1,dfn[x]-1,-1); //還原所有的操作,因為要計算別的為1 的ans,之所以是dfn[x]-1 相當於給這個賦值為1.(至少一段) 
    return maxdfn[x];
}
signed main()
{
  n=read();
  for (int i=1;i<n;i++)
  {
   int x=read(),y=read();
   addedge(x,y);
 addedge(y,x); 
  }
  dfs(1,0,1);
  init();
  build(1,1,n); 
  for (int i=1;i<=n;i++) b[dfn[i]]=i;
  dp(1,0);
  cout<<ans;
  return 0;
}