1. 程式人生 > >[BZOJ5292][Bjoi2018]治療之雨(期望DP+高斯消元)

[BZOJ5292][Bjoi2018]治療之雨(期望DP+高斯消元)

Address

Solution

首先,一個顯然的 DP 狀態: f[i]f[i] 表示第一個數當前為 ii ,將其變成 00 的期望步數。 邊界當然是 f[0]=0f[0]=0 。 討論一波轉移: 設 P(i,x)P(i,x) 表示當第一個數為 ii 時, kk 輪減操作讓第一個數減少 xx 的概率。 這樣轉移就很顯然了: 當 i<ni<n 時: f[i]=1+1m+1j=0i+1P(i+1,j)×f[ij+1]+mm+1j=0iP(i,j)×f[ij]f[i]=1+\frac 1{m+1}\sum_{j=0}^{i+1}P(i+1,j)\times f[i-j+1]+\frac m{m+1}\sum_{j=0}^iP(i,j)\times f[i-j]

i=ni=n 時: f[i]=1+j=0iP(i,j)×f[ij]f[i]=1+\sum_{j=0}^iP(i,j)\times f[i-j] 要解決兩個小問題: (1) P(i,j)P(i,j) 的值。 分下類: 當 j<ij<i 時,相當於在 kk 次操作中選出 jj 次操作對第一個數進行,剩下的 kjk-j 次操作對剩下的 mm 個數進行。 所以: P
(i,j)={Ckj×mkj(m+1)kj<i1k=0i1P(i,k)j=iP(i,j)=\begin{cases}\frac{C_k^j\times m^{k-j}}{(m+1)^k}&j<i\\1-\sum_{k=0}^{i-1}P(i,k)&j=i\end{cases}
特別地,如果 k<jk<jP(i,j)=0P(i,j)=0 。 (2) 轉移的後效性。 把每個 f[i
]f[i]
當作一個未知變數,使用高斯消元解方程。 但這樣複雜度是 O(Tn3)O(Tn^3) 的。 發現係數矩陣長這個樣子: [X0000000XXX00000XXXX0000XXXXX000XXXXXX00XXXXXXX0XXXXXXXXXXXXXXXXXX]\begin{bmatrix}X&0&0&0&0&0&0&\dots&0\\X&X&X&0&0&0&0&\dots&0\\X&X&X&X&0&0&0&\dots&0\\X&X&X&X&X&0&0&\dots&0\\X&X&X&X&X&X&0&\dots&0\\X&X&X&X&X&X&X&\dots&0\\\vdots&\vdots&\vdots&\vdots&\vdots&\vdots&\vdots&\ddots&\vdots\\X&X&X&X&X&X&X&X&X\\X&X&X&X&X&X&X&X&X\end{bmatrix} 從第一列到第 n+1n+1 列分別表示 f[0]f[0]f[n]f[n] ,第一行到第 n+1n+1 行分別表示 f[0]f[0]f[n]f[n] 的轉移。 這矩陣已經非常接近於下三角矩陣。 我們只需要從最後一行開始網上,對於第 iii>2i>2 )行,只需要用第 ii 行去消第 ii 行使得第 ii 行第 i+1i+1 列為 00 即可。 這樣係數矩陣就變成了下三角矩陣,從 f[0]f[0] 開始一一代入即可。 注:如果出現了除以 00 的情況則方程組無解,輸出 1-1 。 時間複雜度 O(Tn2)O(Tn^2)

Code

#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define For(i, a, b) for (i = a; i <= b; i++)
#define Rof(i, a, b) for (i = a; i >= b; i--)

inline int read()
{
	int res = 0; bool bo = 0; char c;
	while (((c = getchar()) < '0' || c > '9') && c != '-');
	if (c == '-') bo = 1; else res = c - 48;
	while ((c = getchar()) >= '0' && c <= '9')
		res = (res << 3) + (res << 1) + (c - 48);
	return bo ? ~res + 1 : res;
}

template <class T>
T Min(T a, T b) {return a < b ? a : b;}

const int N = 1505, ZZQ = 1e9 + 7;

int n, p, m, k, inv[N], f[N][N], pw[N], C[N], a[N];

int qpow(int a, int b)
{
	int res = 1;
	while (b)
	{
		if (b & 1) res = 1ll * res * a % ZZQ;
		a = 1ll * a * a % ZZQ;
		b >>= 1;
	}
	return res;
}

void work()
{
	int i, j, alls, orz, tmp, rp;
	n = read(); p = read(); m = read(); k = read();
	orz = qpow(m + 1, ZZQ - 2);
	alls = qpow(qpow(m + 1, k)