1. 程式人生 > >HDU-4219 Randomization?(樹形DP+概率DP)

HDU-4219 Randomization?(樹形DP+概率DP)

題意

給定一棵 n 個節點的樹,每條邊的權值為 [ 0 , L ]

之間的隨機整數,求這棵樹兩點之間最長距離不超過 S 的概率。
1 n
1000

1 L 10
1 S 2000

思路

這種概率題以前沒碰到過,現在碰到連暴力也不會打。其實樣本點除以樣本空間是最穩的暴力。從 1 2 n 均有一條邊的特殊情況入手,比較顯然需要一條邊一條邊新增。用 d p i 儲存最大邊權為 i 的合法情況的概率(也就是說不能存在大於 S i 的另一條邊)。而樹的情況,也不就是變成了一個子樹一個子樹新增, d f s 如下:

void dfs(int u,int f)
{
    dp[u][0]=1;
    EOR(i,G,u)
    {
        int v=G.to[i];
        if(v==f)continue;
        dfs(v,u);
        memset(ad,0,sizeof(ad));
        memset(tmp,0,sizeof(tmp));
        FOR(j,0,L)
            FOR(k,0,S-j)
                ad[k+j]+=dp[v][k]/(L+1);
        FOR(j,0,S)
            FOR(k,0,S-j)
                tmp[max(j,k)]+=dp[u][j]*ad[k];
        FOR(j,0,S)dp[u][j]=tmp[j];
    }
}

a d 陣列表示新新增的子樹 v d p t m p d p u 的一個滾動(先把新的值整體賦回去)。
觀察上述程式碼,發現用新增子樹兩層迴圈時間開銷過大,考慮通過字首和優化掉一層。可以先令 j k ,然後乘上所有的 k [ 0 , m i n { j , S j } ] 的和,同理令 k 大,乘上所有 j [ 0 , m i n { k , S k } ] ,最後發現 j = k 的情況被算了兩次,減掉一次即可。

程式碼

#include<bits/stdc++.h>
#define FOR(i,x,y) for(register int i=(x);i<=(y);++i)
#define DOR(i,x,y) for(register int i=(x);i>=(y);--i)
#define N 1003
typedef long long LL;
using namespace std;
template<const int maxn,const int maxm>struct Linked_list
{
    int head[maxn],to[maxm],nxt[maxm],tot;
    void clear(){memset(head,-1,sizeof(head));tot=0;}
    void add(int u,int v){to[++tot]=v,nxt[tot]=head[u],head[u]=tot;}
    #define EOR(i,G,u) for(register int i=G.head[u];~i;i=G.nxt[i])
};
Linked_list<N,N<<1>G;
double dp[N][2003],ad[2003],sumdp[2003],sumad[2003],tmp[2003];
int n,L,S;
void dfs(int u,int f)
{
    dp[u][0]=1;
    EOR(i,G,u)
    {
        int v=G.to[i];
        if(v==f)continue;
        dfs(v,u);
        memset(tmp,0,sizeof(tmp));
        memset(ad,0,sizeof(ad));
        FOR(j,0,L)
            FOR(k,0,S-j)
                ad[k+j]+=dp[v][k]/(L+1);
        sumad[0]=ad[0];
        sumdp[0]=dp[u][0];
        FOR(j,1,S)sumad[j]=sumad[j-1]+ad[j],sumdp[j]=sumdp[j-1]+dp[u][j];
        FOR(j,0,S)
        {
            int k=min(j,S-j);
            tmp[j]+=dp[u][j]*sumad[k];
        }
        FOR(k,0,S)
        {
            int j=min(k,S-k);
            tmp[k]+=ad[k]*sumdp[j];
        }
        FOR(j,0,S/2)tmp[j]-=dp[u][j]*ad[j];
        FOR(j,0,S)dp[u][j]=tmp[j];
    }
}

int main()
{
    G.clear();
    scanf("%d%d%d",&n,&L,&S);
    FOR(i,1,n-1)
    {
        int u,v;
        scanf("%d%d",&u,&v);
        G.add(u,v);
        G.add(v,u);
    }
    dfs(1,0);
    double ans=0;
    FOR(i,0,S)ans+=dp[1][i];
    printf("%.6lf\n",ans);
    return 0;
}