1. 程式人生 > >LOJ565. 「LibreOJ Round #10」mathematican 的二進位制(NTT)

LOJ565. 「LibreOJ Round #10」mathematican 的二進位制(NTT)

題目連結

https://loj.ac/problem/565

題解

首先,若進行所有操作之後成功執行的運算元為 \(m\),最終得到的數為 \(w\),那麼發生改變的二進位制位的數量之和(即代價之和)為 \(2m - {\rm bit}(w)\)。其中,\({\rm bit}(x)\) 表示 \(x\) 在二進位制下 \(1\) 的個數。

證明:不難發現,一次操作會改變的二進位制位的總數為進位次數\(+1\),因此執行 \(m\) 次操作後改變的二進位制位的數量總和為總進位次數\(+m\)。由於一次進位會導致整個二進位制數中 \(1\) 的個數減少 \(1\),因此若最終得到的數為 \(w\)

,那麼 \(m - {\rm bit}(w)\) 即為總進位次數。故總代價為 \(2m - {\rm bit}(w)\)。不僅如此,該結論也同時告訴我們操作順序與最終答案是無關的。

這樣,本題轉化為求 \(2m - {\rm bit}(w)\) 的期望值。考慮將其拆開計算,\(2m\) 的期望值顯然為所有操作執行成功的概率和乘以 \(2\)\({\rm bit}(w)\) 的期望值可以通過期望的線性性質轉化為所有操作後每一個二進位制位上的數字最終為 \(1\) 的概率之和。這樣,我們就只需要求出最終每一個二進位制位上的數字為 \(1\) 的概率即可。

\(f_{i, j}\) 表示進行所有操作之後從小到大第 \(i\)

個二進位制位(以下簡稱第 \(i\) 位)的值被改變了 \(j\) 次的概率。由於第 \(i\) 位最終的結果可能被在第 \(i\) 位本身的操作以及所有第 \(j(j < i)\) 位上的操作所影響,我們將兩部分分開考慮,再設 \(g_{i, j}\) 表示進行完所有在第 \(i\) 位上的操作之後第 \(i\) 位的值被改變了 \(j\) 次的概率,這樣,我們就能夠用 \(g_i\)\(f_{i - 1}\) 算出 \(f_i\)

首先考慮如何計算 \(g_i\)。對於在第 \(i\) 位上的某個操作,若該操作成功執行的概率為 \(p\),那麼有 \(g_{i, j} \leftarrow g_{i, j - 1} \times p + g_{i, j} \times (1 - p)\)

,顯然該轉移可視為兩個多項式相乘,那麼最終的結果多項式即為若干個一次多項式的乘積,可用堆優化 NTT 合併。這樣,我們就能在 \(O(m \log^2 m)\) 的時間內求出整個 \(g\) 陣列。

接下來考慮算 \(f\),由於在第 \(i - 1\) 位上的每兩次改變才會導致第 \(i\) 位的一次改變,故轉移為 \(f_{i, j} = \sum_\limits{\lfloor\frac{a}{2}\rfloor + b = j} f_{i - 1, a} + g_{i, b}\)。顯然該轉移也可視為兩個多項式相乘,可用 NTT 優化。對於每個 \(i\),計算 \(f\) 陣列時我們只需要暴力做 NTT 即可。以下是來自 yww 的時間複雜度證明:

\(c_i\) 表示所有操作中在第 \(i\) 位上的運算元量(那麼有 \(\sum c_i = m\))。

\(f\) 的時間複雜度是
\[\begin{aligned}& O(\log m \times \sum_{i = 0}^n \sum_{j = 0}^i \frac{c_j}{2^{i - j}}) \\ =& O(\log m \times \sum_{i = 0}^{n} c_i \sum_{j = 0}^{i} 2^{-j}) \\ =& O(m \log m)\end{aligned}\]

上面的時間複雜度分析基於以下兩點:

  • 在第 \(j\) 位上的每 \(2^{i - j}\) 次改變才會導致第 \(i\) 位的一次改變(\(j < i\))。
  • \(\sum_{i = 0}^{+\infty} 2^{-i} = 2\)

這樣,暴力用 NTT 求 \(f\) 的時間複雜度是可接受的,因此解決整個問題的時間複雜度為 \(O(m \log^2 m)\)

程式碼

#include<bits/stdc++.h>

using namespace std;

const int N = 524288, mod = 998244353, G = 3;

void add(int& x, int y) {
  x += y;
  if (x > mod) {
    x -= mod;
  }
}

void sub(int& x, int y) {
  x -= y;
  if (x < 0) {
    x += mod;
  }
}

int mul(int x, int y) {
  return (long long) x * y % mod;
}

int qpow(int v, int p) {
  int result = 1;
  for (; p; p >>= 1, v = mul(v, v)) {
    if (p & 1) {
      result = mul(result, v);
    }
  }
  return result;
}

int n, m, a[N], b[N], c[N], rev[N], len, poly_a[N], poly_b[N];
vector<int> event[N], arr[N], f[N];

void ntt(int* c, int n, int type) {
  for (int i = 0; i < n; ++i) {
    if (i < rev[i]) {
      swap(c[i], c[rev[i]]);
    }
  }
  for (int i = 1; i < n; i <<= 1) {
    int x = qpow(G, type == 1 ? (mod - 1) / (i << 1) : mod - 1 - (mod - 1) / (i << 1));
    for (int j = 0; j < n; j += i << 1) {
      int y = 1;
      for (int k = 0; k < i; ++k, y = mul(y, x)) {
        int p = c[j + k], q = mul(y, c[i + j + k]);
        c[j + k] = (p + q) % mod;
        c[i + j + k] = (p - q + mod) % mod;
      }
    }
  }
  if (type == -1) {
    int inv = qpow(n, mod - 2);
    for (int i = 0; i < n; ++i) {
      c[i] = mul(c[i], inv);
    }
  }
}

void mul(int len_a, int len_b, int* a, int* b, int* c) {
  for (len = 0; (1 << len) <= len_a + len_b; ++len);
  int m = 1 << len;
  for (int i = 0; i < m; ++i) {
    rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << len - 1);
  }
  for (int i = 0; i < m; ++i) {
    poly_a[i] = poly_b[i] = 0;
    if (i <= len_a) {
      poly_a[i] = a[i];
    }
    if (i <= len_b) {
      poly_b[i] = b[i];
    }
  }
  ntt(poly_a, m, 1);
  ntt(poly_b, m, 1);
  for (int i = 0; i < m; ++i) {
    poly_a[i] = mul(poly_a[i], poly_b[i]);
  }
  ntt(poly_a, m, -1);
  for (int i = 0; i < m; ++i) {
    c[i] = poly_a[i];
  }
}

struct poly {
  vector<int> poly_s;

  poly () {
    poly_s.clear();
  }

  bool operator < (const poly& a) const {
    return poly_s.size() > a.poly_s.size();
  }
};

int main() {
  scanf("%d%d", &n, &m);
  n += 19;
  int ans = 0;
  for (int i = 1; i <= m; ++i) {
    int a, x, y;
    scanf("%d%d%d", &a, &x, &y);
    int p = mul(x, qpow(y, mod - 2));
    event[a].push_back(p);
    add(ans, p);
  }
  ans = (ans << 1) % mod;

  auto get_array = [&] (vector<int>& all) {
    priority_queue<poly> s;
    for (auto v : all) {
      poly new_poly;
      new_poly.poly_s.push_back((1 - v + mod) % mod);
      new_poly.poly_s.push_back(v);
      s.push(new_poly);
    }
    while (s.size() > 1) {
      poly l = s.top();
      s.pop();
      poly r = s.top();
      s.pop();
      int len_a = l.poly_s.size() - 1, len_b = r.poly_s.size() - 1;
      for (int i = 0; i <= len_a; ++i) {
        a[i] = l.poly_s[i];
      }
      for (int i = 0; i <= len_b; ++i) {
        b[i] = r.poly_s[i];
      }
      mul(len_a, len_b, a, b, c);
      poly new_poly;
      for (int i = 0; i <= len_a + len_b; ++i) {
        new_poly.poly_s.push_back(c[i]);
      }
      s.push(new_poly);
    }
    return s.top().poly_s;
  };

  for (int i = 0; i <= n; ++i) {
    if (event[i].size()) {
      arr[i] = get_array(event[i]);
    } else {
      arr[i].push_back(1);
    }
  }
  f[0] = arr[0];
  int total_len = event[0].size();
  for (int i = 1; i <= n; ++i) {
    int old = total_len >> 1;
    fill(a, a + old + 1, 0);
    for (int j = 0; j <= total_len; ++j) {
      add(a[j >> 1], f[i - 1][j]);
    }
    for (int j = 0; j <= event[i].size(); ++j) {
      b[j] = arr[i][j];
    }
    mul(old, event[i].size(), a, b, c);
    total_len = old + event[i].size();
    for (int j = 0; j <= total_len; ++j) {
      f[i].push_back(c[j]);
    }
  }
  for (int i = 0; i <= n; ++i) {
    for (int j = 1; j < f[i].size(); j += 2) {
      sub(ans, f[i][j]);
    }
  }
  printf("%d\n", ans);
  return 0;
}