Newcoder 148 I.Rikka with Zombies(樹形DP)
阿新 • • 發佈:2018-12-12
Description
給出一棵個節點的樹,再給出只殭屍的位置以及能力,第條邊可以等概率建立起高度為的圍牆,第只殭屍可以越過任何小於的圍牆,問至少存在一個位置安全(即沒有殭屍可以到達這個位置)的概率
Input
第一行一整數表示用例組數,每組用例首先輸入兩個整數表示點數和殭屍數量,之後行每行輸入四個整數表示第條樹邊為,可以建立圍牆高度範圍為,最後行每行兩個整數表示第個點處有一個能力為的殭屍
Output
輸出至少存在一個位置安全的概率,結果模
Sample Input
2 4 2 1 2 1 2 2 3 1 2 1 4 1 2 1 2 3 2 5 2 1 2 1 10 2 3 2 9 1 4 3 12 2 5 4 6 1 7 5 5
Sample Output
374341633 888437475
Solution
考慮所有位置都不安全的方案數,把殭屍按能力升序排,以表示以為根的子樹全部不安全,且子樹中可以到達的最強殭屍編號為,以此考慮的子樹,對於當前考慮的兒子,假設邊被殭屍通過的方案數為,不被通過的方案數為,那麼有三種情況:
1.被幹掉了,
2.被其子樹內弱於的殭屍幹掉,此時顯然殭屍不會在子樹中,為使幹掉的最強殭屍不超過,之間需要阻礙殭屍通過,故有
3.被其子樹內強於的殭屍幹掉,此時不能讓這些強於的殭屍通過之間的邊去幹掉,故之間需要阻礙這些更強的殭屍通過,故有
字首和優化一下,第二部分從弱殭屍到強殭屍考慮,第三部分從強殭屍到弱殭屍考慮即可,答案即為 時間複雜度
Code
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
typedef long long ll;
typedef pair<int,int>P;
#define maxn 2005
#define mod 998244353
int mul(int x,int y)
{
ll z=1ll*x*y;
return z-z/mod*mod;
}
int add(int x,int y)
{
x+=y;
if(x>=mod)x-=mod;
return x;
}
int Pow(int x,int y)
{
int ans=1;
while(y)
{
if(y&1)ans=mul(ans,x);
x=mul(x,x);
y>>=1;
}
return ans;
}
#define id second
#define val first
P a[maxn];
vector<P>g[maxn];
int T,n,m,l[maxn],r[maxn],dp[maxn][maxn],vis[maxn][maxn],temp[maxn];
void dfs(int u,int fa)
{
int pos=1;
for(int i=1;i<=m;i++)
if(a[i].id==u)
{
vis[u][i]=1;
pos=i;
}
else vis[u][i]=0;
for(int i=pos;i<=m;i++)dp[u][i]=1;
for(int i=0;i<g[u].size();i++)
{
int v=g[u][i].first,t=g[u][i].second;
if(v==fa)continue;
dfs(v,u);
for(int j=1;j<=m;j++)temp[j]=dp[u][j],dp[u][j]=0;
int res=0;
for(int j=1;j<=m;j++)
{
int x=max(0,min(a[j].val-1,r[t])-l[t]+1),y=r[t]-l[t]+1-x;
dp[u][j]=add(dp[u][j],mul(x,mul(temp[j],dp[v][j])));
if(vis[v][j])res=add(res,dp[v][j]);
else dp[u][j]=add(dp[u][j],mul(mul(y,res),temp[j]));
}
res=0;
for(int j=m;j>=1;j--)
{
int x=max(0,min(a[j].val-1,r[t])-l[t]+1),y=r[t]-l[t]+1-x;
if(vis[v][j])res=add(res,mul(dp[v][j],y));
else dp[u][j]=add(dp[u][j],mul(res,temp[j]));
}
for(int j=1;j<=m;j++)vis[u][j]|=vis[v][j];
}
}
int main()
{
scanf("%d",&T);
while(T--)
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)g[i].clear();
memset(dp,0,sizeof(dp));
int ans=1;
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d%d%d",&u,&v,&l[i],&r[i]);
ans=mul(ans,r[i]-l[i]+1);
g[u].push_back(P(v,i)),g[v].push_back(P(u,i));
}
for(int i=1;i<=m;i++)scanf("%d%d",&a[i].id,&a[i].val);
sort(a+1,a+m+1);
dfs(1,0);
int res=0;
for(int i=1;i<=m;i++)res=add(res,dp[1][i]);
res=add(ans,mod-res);
printf("%d\n",mul(res,Pow(ans,mod-2)));
}
return 0;
}