ARC115 E

AtCoder

Problem Statement

Given is a sequence of \(N\) integers \(A_1\),\(A_2\),...,\(A_N\). Print the number, modulo \(998244353\), of sequence of \(N\) integers \(X_1\),\(X_2\),...,\(X_N\) satisfying all of the following conditions:

\(1 \leq X_i \leq A_i\)

  • \(X_i \not = X_{i+1} (1 \leq i \leq N-1)\)

Constraints

  • \(2 \leq N \leq 5 * {10}^5\)
  • \(1 \leq A_i \leq {10}^9\)

Input

Input is given from Standard Input in the following format:

\[N \\
A_1 A_2 ... A_N
\]

Output

Print the answer.

解法

題意說n個位置,每個位置上面的數字不能大於Ai,問每對相鄰的數都不相同的序列數有多少個。這種問題一看就是容斥,用所有的減去不符合的。不符合的分為至少某1個位置不符合,至少某2個位置不符合……這樣就可以用dp去做了。

dp[i][j]代表前i個元素分為j段的方案數,使得每段內的所有元素都相等。那麼答案其實就是dp[n][n]-dp[n][n-1]+dp[n][n-2]...。這個轉移方程是顯然的:

\[dp[i][j] = \sum_{k \leq i-1} dp[k][j-1] * \min_{k+1 \leq l \leq i} A_k
\]

但是這個轉移怎麼看都要\(O(N^2)\),不過好在最終的容斥式子係數只與j的奇偶性有關,於是只考慮奇偶性轉移:

\[dp[i][1] = \sum_{k \leq i-1} dp[k][0] * \min_{k+1 \leq l \leq i} A_k \\
dp[i][0] = \sum_{k \leq i-1} dp[k][1] * \min_{k+1 \leq l \leq i} A_k
\]

現在就是怎麼去做這個轉移的問題了。考慮到某個\(A_i\)的時候,\(A_i\)作為新的一段的轉移,\(A_i\)是新的一段中的最小值:

[A1 , ... Aj] [Aj+1 ... Ai ... Ak]

這裡\([A_{j+1} ... A_i ... A_k]\)是新新增上去的一段,可以發現\(l(i) \leq j \lt i\),\(i \leq k \leq r(i)\),其中\(l(i)\)是\(A_i\)左邊第一個小於等於\(A_i\)的下標,\(r(i)\)是\(A_i\)右邊第一個小於\(A_i\)的下標。這裡用了不同的符號是規定同樣大的數字,前面的更小,防止重複更新同一段。這麼規定也不會漏掉,因為每一個新增的段一定有一個\(A_i\)會被我們遍歷到。\(l(i)\)和\(r(i)\)可以通過單調棧輕鬆計算,這裡不贅述。

這樣的話,如果我們維護了當前元素左邊所有dp值的字首和,那麼我們就可以快速獲得所有滿足條件的\(j\)的dp和,然後更新到這一段的末尾可能的取值,即\(k\)的範圍:

rangeAdd(i+1, r[i], 0, preSum(l[i], i-1, 1) * A[i]);
rangeAdd(i+1, r[i], 1, preSum(l[i], i-1, 0) * A[i]);

這裡的preSum(l, r, p)是\(\sum_{l \leq i \leq r} dp[i][p]\),rangeAdd可以通過資料結構維護,這裡我們採用陣列這種快速的陣列結構來維護它:

rangeAdd(l, r, p, v) => diff[l][p] += v, diff[r+1][p] -= v;

這樣這道題就做完了,以下是AC程式碼:

#pragma GCC optimize ("Ofast,unroll-loops")
#pragma GCC optimize("no-stack-protector,fast-math") #include <bits/stdc++.h> using namespace std; constexpr int N = 5e5+7;
constexpr int M = 998244353; #define fastio ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
#define int long long
#define pii pair<int, int>
#define fi first
#define se second
#define SZ(x) ((int)(x.size())) #ifdef int
#define INF 0x3f3f3f3f3f3f3f3f
#define INF2 (int)(0xcfcfcfcfcfcfcfcf)
#else
#define INF 0x3f3f3f3f
#define INF2 0xcfcfcfcf
#endif signed main() {
fastio
int n;
cin >> n;
vector<int> a(n+1, 0);
for (int i = 1; i <= n; i++) cin >> a[i];
auto add = [&](int& x, int y) {
x = (x % M + M + y % M) % M;
};
vector<int> l(n+1, 0), r(n+1, n+1);
vector<int> st;
for (int i = 1; i <= n; i++) {
while (!st.empty() and a[st.back()] > a[i])
st.pop_back();
l[i] = st.empty() ? 0 : st.back();
st.emplace_back(i);
}
st.clear();
for (int i = n; i >= 1; i--) {
while (!st.empty() and a[st.back()] >= a[i])
st.pop_back();
r[i] = st.empty() ? n+1 : st.back();
st.emplace_back(i);
}
st.clear(); vector<array<int, 2>> dp(n+1, {0, 0});
vector<array<int, 2>> sum(n+1, {0, 0});
vector<array<int, 2>> diff(n+2, {0, 0});
dp[0][0] = 1;
sum[0][0] = dp[0][0];
auto rangeAdd = [&](int l, int r, int parity, int v) {
add(diff[l][parity], v);
add(diff[r+1][parity], -v);
};
auto preSum = [&](int l, int r, int parity) {
return (sum[r][parity] - (l ? sum[l-1][parity] : 0ll) + M) % M;
};
for (int i = 1; i <= n; i++) {
int ll = l[i], lr = i-1;
int rl = i, rr = r[i]-1; add(diff[i][0], diff[i-1][0]);
add(diff[i][1], diff[i-1][1]); if (ll <= lr and rl <= rr) {
rangeAdd(rl, rr, 0, preSum(ll, lr, 1) * a[i]);
rangeAdd(rl, rr, 1, preSum(ll, lr, 0) * a[i]);
} add(dp[i][0], diff[i][0]);
add(dp[i][1], diff[i][1]); add(sum[i][0], sum[i-1][0]);
add(sum[i][1], sum[i-1][1]);
add(sum[i][0], dp[i][0]);
add(sum[i][1], dp[i][1]);
} cout << (dp[n][n&1] - dp[n][1^(n&1)] + M)%M << "\n";
return 0;
}