1. 程式人生 > >【模板】矩陣加速

【模板】矩陣加速

求解 print 什麽 就會 復雜 [1] 忽略 版本 cin

矩陣加速,專門用來解決一些遞推的關系,其原理和矩陣運算的法則有關

由於矩陣的乘法有交換律和結合律,所以我們可以通過矩陣快速冪來快速求解遞推關系,一般時間復雜度是O(nlogn)。

矩陣快速冪很簡單,寫一下模板就會了,但是推導單位矩陣是個難題。

一般地,我們推導單位矩陣時,有這幾個步驟。

  1. 確定遞推初始條件
  2. 確定遞推式子的系數
  3. 確定單位矩陣的大小,一般來說和遞推式涉及到的已知量個數有關,必須考慮系數為0的原量,不能忽略
  4. 填單位矩陣,註意一下幾點:
    1. 第一行才是生成新數據的力量。
    2. 所以後面的行只需要把自己表示出來就好了。(為什麽是用“行”來表示已知量?根據矩陣運算法則)

一般容易犯的錯誤:

  • 遞推關系寫錯了
  • 單位矩陣錯誤了
  • 矩陣乘法寫錯了
  • 答案輸出錯了,沒有找到到底應該輸出哪一個
  • 矩陣的指數確定錯了
  • %錯了
  • 沒開long long導致乘法爆了

下貼斐波那契數列矩陣加速版本

#include<iostream>
#include<algorithm>
#include<cstring>
using namespace std;
#define RP(t,a,b) for(ll t=(a),edd=(b);t<=edd;t++)
const int mod=1e9+7;
typedef unsigned long long ll;
const int n=2;
ll k;
struct mtx {
    ll data[3][3];
    ll* operator [](int x){return data[x];}
    mtx() {
        memset(data,0,sizeof data);
    }
    mtx operator *(const mtx& x) {
        mtx temp;
        RP(t,0,n-1)
        RP(i,0,n-1)
        RP(k,0,n-1)
        temp.data[t][i]=(temp.data[t][i]+(data[t][k]*x.data[k][i])%mod)%mod;
        return temp;
    }
    mtx operator *=(const mtx& x) {
        return (*this)=(*this)*x;
    }
    mtx operator ^(const ll& p) {
        ll b=p;
        mtx ans,base;
        ans.unis();
        base=(*this);
        while(b) {
            if(b&1)
                ans*=base;
            base*=base;
            b>>=1;
        }
        return ans;
    }
    mtx operator ^=(const ll& p) {
        return (*this)=(*this)^p;
    }
    void unis() {
        for(ll t=0;t<n;t++)
            data[t][t]=1;
    }
    void print() {
        for(ll t=0; t<n; t++) {
            for(ll ti=0; ti<n; ti++)
                cout<<data[t][ti]<<‘ ‘;
            cout<<"\n";
        }
        return;
    }
};
int T;
int main() {
    //      freopen("in.in","r",stdin);
    cin>>k;
    mtx qaq;
    qaq.data[0][1]=qaq.data[1][0]=qaq.data[1][1]=1;
    qaq^=k-1;
    ll asa=qaq.data[0][0]%mod+qaq.data[1][0]%mod;
    cout<<asa%mod<<endl;
    return 0;
}

然後又附luogu矩陣加速板子的代碼

#include<iostream>
#include<algorithm>
#include<cstring>
using namespace std;
#define RP(t,a,b) for(ll t=(a),edd=(b);t<=edd;t++)
const int mod=1e9+7;
typedef long long ll;
const int n=3;
int k;
struct mtx {
    ll data[3][3];
    mtx() {
        memset(data,0,sizeof data);
    }
    mtx operator *(const mtx& x) {
        mtx temp;
        RP(t,0,n-1)
        RP(i,0,n-1)
        RP(k,0,n-1)
        temp.data[t][i]=(temp.data[t][i]+(data[t][k]*x.data[k][i])%mod)%mod;
        return temp;
    }
    mtx operator *=(const mtx& x) {
        return (*this)=(*this)*x;
    }
    mtx operator ^(const ll& p) {
        ll b=p;
        mtx ans,base;
        ans.unis();
        base=(*this);
        while(b) {
            if(b&1)
                ans*=base;
            base*=base;
            b>>=1;
        }
        return ans;
    }
    mtx operator ^=(const ll& p) {
        return (*this)=(*this)^p;
    }
    void unis() {
        for(ll t=0;t<n;t++)
            data[t][t]=1;
    }
    void print() {
        for(ll t=0; t<n; t++) {
            for(ll ti=0; ti<n; ti++)
                cout<<data[t][ti]<<‘ ‘;
            cout<<"\n";
        }
        return;
    }
};
int T;
int main() {
    //  freopen("in.in","r",stdin);
    cin>>T;
    while(T--){
        cin>>k;
        mtx qaq;
        mtx temp;
        qaq.data[0][0]=qaq.data[0][2]=qaq.data[1][0]=qaq.data[2][1]=1;
        qaq^=k;
        int pos=0;
        cout<<qaq.data[1][0]<<endl;
    }
    return 0;
}
#include<iostream>
#include<algorithm>
#include<cstring>
using namespace std;
#define RP(t,a,b) for(ll t=(a),edd=(b);t<=edd;t++)
const int mod=1e9+7;
typedef long long ll;
const int n=3;
int k;
struct mtx {
    ll data[3][3];
    mtx() {
        memset(data,0,sizeof data);
    }
    mtx operator *(const mtx& x) {
        mtx temp;
        RP(t,0,n-1)
        RP(i,0,n-1)
        RP(k,0,n-1)
        temp.data[t][i]=(temp.data[t][i]+(data[t][k]*x.data[k][i])%mod)%mod;
        return temp;
    }
    mtx operator *=(const mtx& x) {
        return (*this)=(*this)*x;
    }
    mtx operator ^(const ll& p) {
        ll b=p;
        mtx ans,base;
        ans.unis();
        base=(*this);
        while(b) {
            if(b&1)
                ans*=base;
            base*=base;
            b>>=1;
        }
        return ans;
    }
    mtx operator ^=(const ll& p) {
        return (*this)=(*this)^p;
    }
    void unis() {
        for(ll t=0;t<n;t++)
            data[t][t]=1;
    }
    void print() {
        for(ll t=0; t<n; t++) {
            for(ll ti=0; ti<n; ti++)
                cout<<data[t][ti]<<‘ ‘;
            cout<<"\n";
        }
        return;
    }
};
int T;
int main() {
    //  freopen("in.in","r",stdin);
    cin>>T;
    while(T--){
        cin>>k;
        mtx qaq;
        mtx temp;
        qaq.data[0][0]=qaq.data[0][2]=qaq.data[1][0]=qaq.data[2][1]=1;
        qaq^=k;
        int pos=0;
        cout<<qaq.data[1][0]<<endl;
    }
    return 0;
}

唉我太弱了orz

【模板】矩陣加速