1. 程式人生 > >LibreOJ #2541.「PKUWC 2018」獵人殺 分治NTT+容斥原理

LibreOJ #2541.「PKUWC 2018」獵人殺 分治NTT+容斥原理

題意

獵人殺是一款風靡一時的遊戲“狼人殺”的民間版本,他的規則是這樣的:
一開始有n個獵人,第i個獵人有仇恨度wi,每個獵人只有一個固定的技能:死亡後必須開一槍,且被射中的人也會死亡。
然而向誰開槍也是有講究的,假設當前還活著的獵人有[i1...im],那麼有wikj=1mwij概率是向獵人ik開槍。
一開始第一槍由你打響,目標的選擇方法和獵人一樣,由於開槍導致的連鎖反應,所有獵人最終都會死亡,現在1號獵人想知道它是最後一個死的的概率。
答案對998244353取模。
wi>0,

wi105

分析

比較有意思的一道題。
考慮容斥,列舉哪些人一定在第一個人後面選,那麼容斥係數就是(1)
關鍵在於怎麼求其他人被選的概率。
A表示n個人wi的和,S表示選出的人的wi的和,那麼係數後面的概率就是

i=0(1S+w1A)iw1A=w1S+w1
為什麼這樣是對的呢?
考慮這樣一個問題,假設現在有n個白球,每次會隨機選任意一個,選出一個球后就把它扔掉,問第i次選到某一個白球的概率是多少。
這個問題顯然等價於,每次選出一個白球后,我不扔掉它,而是給它打上一個標記。若某次選出了一個被打了標記的球,則將其放回去,問第i個被標記的球是某一個的概率是多少。
那麼現在我就可以把原來的問題看成把一個人打標記而不是殺死一個人,這樣抽中每個人的概率就是不變的。然後一個到正無窮的迴圈,剛好可以算出每一種情況的概率的和,所以就是正確的。
剩下的問題就變成了如何求
S
,只要用分治NTT搞一搞就好了。

程式碼

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>

typedef long long LL;

const int N=100005;
const int MOD=998244353;

int n,w[N],a[20][N*4],rev[N*4],L,s[N];

int ksm(int x,int y)
{
    int ans=1;
    while
(y) { if (y&1) ans=(LL)ans*x%MOD; x=(LL)x*x%MOD;y>>=1; } return ans; } void NTT(int *a,int f) { for (int i=0;i<L;i++) if (i<rev[i]) std::swap(a[i],a[rev[i]]); for (int i=1;i<L;i<<=1) { int wn=ksm(3,f==1?(MOD-1)/i/2:MOD-1-(MOD-1)/i/2); for (int j=0;j<L;j+=(i<<1)) { int w=1; for (int k=0;k<i;k++) { int u=a[j+k],v=(LL)a[j+k+i]*w%MOD; a[j+k]=(u+v)%MOD;a[j+k+i]=(u+MOD-v)%MOD; w=(LL)w*wn%MOD; } } } int ny=ksm(L,MOD-2); if (f==-1) for (int i=0;i<L;i++) a[i]=(LL)a[i]*ny%MOD; } void solve(int l,int r,int d) { if (l==r) { for (int i=1;i<w[l];i++) a[d][i]=0; a[d][0]=1;a[d][w[l]]=MOD-1; return; } int mid=(l+r)/2; solve(l,mid,d+1); for (int i=0;i<=s[mid]-s[l-1];i++) a[d][i]=a[d+1][i]; solve(mid+1,r,d+1); int lg=0; for (L=1;L<=s[r]-s[l-1];L<<=1,lg++); for (int i=0;i<L;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(lg-1)); for (int i=s[mid]-s[l-1]+1;i<L;i++) a[d][i]=0; for (int i=s[r]-s[mid]+1;i<L;i++) a[d+1][i]=0; NTT(a[d],1);NTT(a[d+1],1); for (int i=0;i<L;i++) a[d][i]=(LL)a[d][i]*a[d+1][i]%MOD; NTT(a[d],-1); } int main() { scanf("%d",&n); for (int i=1;i<=n;i++) scanf("%d",&w[i]),s[i]=s[i-1]+w[i]; solve(2,n,0); int ans=0; for (int i=0;i<=s[n]-s[1];i++) (ans+=(LL)w[1]*ksm(i+w[1],MOD-2)%MOD*a[0][i]%MOD)%=MOD; printf("%d",ans); return 0; }