1. 程式人生 > >【知識總結】多項式全家桶(一)(NTT、加減乘除和求逆)

【知識總結】多項式全家桶(一)(NTT、加減乘除和求逆)

我這種數學一竅不通的菜雞終於開始學多項式全家桶了……

必須要會的前置技能:FFT(不會?戳我:【知識總結】快速傅立葉變換(FFT)

一、NTT

跟FFT功能差不多,只是把複數域變成了模域(計算複數係數多項式相乘變成計算在模意義下整數係數多項式相乘)。你看FFT裡的單位圓是迴圈的,模一個質數也是迴圈的嘛qwq。\(n\)次單位根\(w_n\)怎麼搞?看這裡:【BZOJ3328】PYXFIB(數學)(內含相關證明。只看與原根和單位根相關的內容即可。)

注意裸的NTT要求模數\(p\)存在原根並且\(p-1\)\(2\)的若干次冪的倍數(這個次數要大於多項式次數\(n\))。於是通常就會用著名的NTT模數:\(998244353=2^{23}\times 7\times 17+1\)

節約篇幅,程式碼先不放了。後面所有程式碼裡都有NTT模板……

二、多項式求逆

對於\(n\)次多項式\(A\),如果有多項式\(B\)滿足\(AB\equiv 1 \mod x^{n+1}\),則稱\(B\)\(A\)在模\(x^{n+1}\)意義下的逆元(和整數逆元差不多)。通常採用倍增的方法求逆元。通常都會規定多項式係數在模\(p\)的意義下。

首先,\(A\)在模\(x\)的意義下就只有一個常數項,所以此時的逆元\(B\)也只有一個常數項,就是\(A\)的常數項模\(p\)的逆元。

如果我們知道\(B_0\)\(A\)在模\(x^{\lceil\frac{n}{2}\rceil}\)

意義下的逆元,現在要求\(B\)\(A\)在模\(x^n\)意義下的逆元。根據題設,顯然有:

\[AB=1\mod x^n\]

很明顯,\(AB\)\(1\)\(n-1\)次項係數全是\(0\),所以模一個\(x\)的低於\(n\)次冪也一定是\(1\)。所以

\[AB_0=AB=1\mod x^{\lceil\frac{n}{2}\rceil}\]

那麼

\[B-B_0=0\mod x^{\lceil\frac{n}{2}\rceil}\]

兩邊和模數同時平方:

\[B^2+B_0^2-2BB_0=0\mod x^n\]

兩邊同時乘\(A\),得到(別忘了\(AB=1\mod x^n\)

):

\[B+AB_0^2-2B_0=0\mod x^n\]

然後移項,得到:

\[B=2B_0-AB_0^2\mod x^n\]

照著這個式子遞迴算就行了。

程式碼:

洛谷4238

注意程式碼裡面的\(n\)是項數不是次數。一定要把沒用的陣列清空,以及進行NTT時把多項式項數寫對。

程式碼最開始是防機慘護身符。

#include <cstdio>
#include <algorithm>
#include <cctype>
#include <cstring>
#undef i
#undef j
#undef k
#undef max
#undef min
#undef swap
#undef sort
#undef true
#undef false
#undef if
#undef for
#undef while
#define _ 0
using namespace std;

namespace zyt
{
    template<typename T>
    inline bool read(T &x)
    {
        char c;
        bool f = false;
        x = 0;
        do
            c = getchar();
        while (c != EOF && c != '-' && !isdigit(c));
        if (c == EOF)
            return false;
        if (c == '-')
            f = true, c = getchar();
        do
            x = x * 10 + c - '0', c = getchar();
        while (isdigit(c));
        if (f)
            x = -x;
        return true;
    }
    template<typename T>
    inline void write(T x)
    {
        static char buf[20];
        char *pos = buf;
        if (x < 0)
            putchar('-'), x = -x;
        do
            *pos++ = x % 10 + '0';
        while (x /= 10);
        while (pos > buf)
            putchar(*--pos);
    }
    typedef long long ll;
    const int N = 1e5 + 10, B = 17, LEN = 1 << (B + 2) | 11, p = 998244353;
    inline int power(int a, int b)
    {
        a %= p, b %= p - 1;
        int ans = 1;
        while (b)
        {
            if (b & 1)
                ans = (ll)ans * a % p;
            a = (ll)a * a % p;
            b >>= 1;
        }
        return ans;
    }
    inline int get_inv(const int a)
    {
        return power(a, p - 2);
    }
    namespace Polynomial
    {
        int omega[LEN], winv[LEN], rev[LEN];
        namespace Primitive_Root
        {
            int cnt;
            pair<int, int> prime[20];
            inline void get_prime(int n)
            {
                cnt = 0;
                for (int i = 2; i * i <= n; i++)
                {
                    if (n % i == 0)
                        prime[cnt++] = make_pair(i, 0);
                    while (n % i == 0)
                        ++prime[cnt - 1].second, n /= i;
                }
                if (n > 1)
                    prime[cnt++] = make_pair(n, 1);
            }
            inline int get_g(const int n)
            {
                get_prime(n - 1);
                for (int i = 2; i < n; i++)
                {
                    bool flag = true;
                    for (int j = 0; j < cnt && flag; j++)
                        flag &= (power(i, (n - 1) / prime[j].first) != 1);
                    if (flag)
                        return i;
                }
                return -1;
            }
        }
        void ntt(int *a, const int *w, const int n)
        {
            for (int i = 0; i < n; i++)
                if (i < rev[i])
                    swap(a[i], a[rev[i]]);
            for (int l = 1; l < n; l <<= 1)
                for (int i = 0; i < n; i += (l << 1))
                    for (int k = 0; k < l; k++)
                    {
                        int tmp = (a[i + k] - (ll)w[n / (l << 1) * k] * a[i + l + k] % p + p) % p;
                        a[i + k] = (a[i + k] + (ll)w[n / (l << 1) * k] * a[i + l + k] % p) % p;
                        a[i + l + k] = tmp;
                    }
        }
        void init(const int n, const int lg2)
        {
            static int g = 0;
            if (!g)
                g = Primitive_Root::get_g(p);
            int w = power(g, (p - 1) / n), wi = get_inv(w);
            omega[0] = winv[0] = 1;
            for (int i = 1; i < n; i++)
            {
                omega[i] = (ll)omega[i - 1] * w % p;
                winv[i] = (ll)winv[i - 1] * wi % p;
            }
            for (int i = 0; i < n; i++)
                rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (lg2 - 1)));
        }
        void inv(const int *a, int *ans, const int n)
        {
            if (n == 1)
                ans[0] = get_inv(a[0]);
            else
            {
                static int tmp[LEN];
                inv(a, ans, (n + 1) >> 1);
                int m = 1, lg2 = 0;
                while (m < (n << 1) - 1)
                    m <<= 1, ++lg2;
                memcpy(tmp, a, sizeof(int[n]));
                init(m, lg2);
                ntt(tmp, omega, m);
                ntt(ans, omega, m);
                for (int i = 0; i < m; i++)
                    ans[i] = (ans[i] * 2LL % p - (ll)tmp[i] * ans[i] % p * ans[i] % p + p) % p;
                ntt(ans, winv, m);
                int invm = get_inv(m);
                for (int i  = 0; i < m; i++)
                    ans[i] = (ll)ans[i] * invm % p;
                memset(ans + n, 0, sizeof(int[m - n]));
                memset(tmp, 0, sizeof(int[m]));
            }
        }
    }
    int a[LEN], b[LEN], n;
    int work()
    {
        read(n);
        for (int i = 0; i < n; i++)
            read(a[i]);
        Polynomial::inv(a, b, n);
        for (int i = 0; i < n; i++)
            write(b[i]), putchar(' ');
        return (0^_^0);
    }
}
int main()
{
    return zyt::work();
}

三、加減乘除

加減法:直接每項對應相加減。

乘法:這就是NTT的目的啊喂!

除法:如果不是帶餘除法直接乘逆元。下面著重介紹帶餘除法。

已知\(n\)次多項式\(F\)\(m\)次多項式\(G\),求\(n-m\)次多項式\(Q\)和多項式\(R\)\(R\)的次數\(deg_R\)小於\(m\)),滿足:

\[F=QG+R\]

(未完待續咕咕咕……