1. 程式人生 > >2018.11.05【校內模擬】規避(最短路計數)(容斥)(正難則反)

2018.11.05【校內模擬】規避(最短路計數)(容斥)(正難則反)

傳送門

解析:

首先直接統計並不好做,考慮反著做,先求出總共的方案數,然後減去相遇的方案數。

總方案數就是SSTT的最短路數量的平方(兩人分別作選擇)。

首先這是個計數類問題,先做一個最短路計數。

distSudistS_u表示SSuu的最短路長度,cntSucntS_u表示SSuu的最短路數量,distTudistT_ucntTucntT_u同理,記S,TS,T最短路長度為tottot

怎麼統計相遇? 首先我們發現兩人的相交位置一定是最短路的中點,這個中點可能是點也可能是邊,所以考慮以中點為標誌統計答案。

當一個點不在任何一條最短路上時,直接passpass。 否則若這個點uu滿足distSu==distTudistS_u==distT_u,則兩人有可能在這個點上相遇,方案數為(cntSu×cntTu)2(cntS_u\times cntT_u)^2,因為兩人要走完全程,又必須經過點uu,根據乘法原理可以輕易得出答案。

如果這個點uu滿足,存在一條邊e=<u,v>e=<u,v>distSutot/2distS_u\le tot/2

tSutot/2distTvtot/2distT_v\le tot/2,那麼兩人就可能在這條邊上相遇。那麼經過這條邊的最短路數就是cntSu×cntTvcntS_u\times cntT_v,根據乘法原理,平方一下就好了。

程式碼:

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define re register
#define gc getchar
#define pc putchar
#define cs const

inline int
getint(){ re int num; re char c; while(!isdigit(c=gc()));num=c^48; while(isdigit(c=gc()))num=(num<<1)+(num<<3)+(c^48); return num; } cs ll mod=1000000007; cs int N=100005,M=200005; int last[N],nxt[M<<1],to[M<<1],ecnt; int w[M<<1]; inline void addedge(int u,int v,int val){ nxt[++ecnt]=last[u],last[u]=ecnt,to[ecnt]=v,w[ecnt]=val; nxt[++ecnt]=last[v],last[v]=ecnt,to[ecnt]=u,w[ecnt]=val; } ll distS[N],distT[N]; ll cntS[N],cntT[N],tot,ans; bool flag; bool vis[N]; set<pair<ll,int> > q; inline void Dijkstra(ll *cs dist,ll *cs cnt,int S,int T){ memset(dist,0x3f,sizeof distS); dist[S]=0;cnt[S]=1; q.clear(); q.insert(make_pair(0,S)); memset(vis,0,sizeof vis); while(!q.empty()){ int u=q.begin()->second; q.erase(q.begin()); if(vis[u])continue; vis[u]=true; for(int re e=last[u],v=to[e];e;v=to[e=nxt[e]]){ if(dist[v]>dist[u]+w[e]){ q.erase(make_pair(dist[v],v)); dist[v]=dist[u]+w[e]; cnt[v]=0; q.insert(make_pair(dist[v],v)); } if(dist[v]==dist[u]+w[e])cnt[v]=(cnt[v]+cnt[u])%mod; } } } int S,T; int n,m; signed main(){ n=getint(); m=getint(); S=getint(); T=getint(); for(int re i=1;i<=m;++i){ int u=getint(),v=getint(); ll val=getint(); if(v!=u) addedge(u,v,val); } Dijkstra(distS,cntS,S,T); Dijkstra(distT,cntT,T,S); tot=distS[T]; ans=cntS[T]*cntT[S]%mod; for(int re u=1;u<=n;++u){ if(distS[u]+distT[u]!=tot)continue; if((!(tot&1))&&distS[u]==(tot>>1)){ ans=(ans-cntS[u]*cntT[u]%mod*cntS[u]%mod*cntT[u]%mod+mod)%mod; continue; } if(distS[u]*2>tot)continue; for(int re e=last[u],v=to[e];e;v=to[e=nxt[e]]){ if(distS[v]!=distS[u]+w[e])continue; if(distT[v]*2>=tot)continue; if(tot!=distS[u]+w[e]+distT[v])continue; ans=(ans-cntS[u]*cntT[v]%mod*cntS[u]%mod*cntT[v]%mod+mod)%mod; } } cout<<(ans%mod+mod)%mod; return 0; }