1. 程式人生 > >Meet in the middle演算法總結 (附模板及SPOJ ABCDEF、BZOJ4800、POJ 1186、BZOJ 2679 題解)

Meet in the middle演算法總結 (附模板及SPOJ ABCDEF、BZOJ4800、POJ 1186、BZOJ 2679 題解)

目錄

Meet in the Middle 總結

1.演算法模型

1.1 Meet in the Middle演算法的適用範圍

如果將搜尋的路徑看成一個有向圖,Meet in the Middle 演算法適用於求有向圖上從A到B,且長度為L的路徑數。

換句話說,Meet in the Middle 演算法適用於求經過L步變化,從A變到B需要的方案數

1.2Meet in the Middle的基本思想

對於一般的搜尋演算法,我們一般是從一個初始狀態開始搜尋,直到找到結果。

meet in the middle演算法採用從初態和終態出發各搜尋一半狀態,使得搜尋樹在中間相遇,這樣

可以顯著減少搜尋的深度。

1.3Meet in the Middle的演算法過程

  1. 從狀態A出發搜尋L1步,記錄走L1步到達狀態i的步數為count(i)
  2. 從狀態B出發搜尋L2步,如果到達狀態i,且count(i)不為0,則把count(i)累加到答案中

首先要保證L1+L2=L,一般選擇L1=L2=\(\frac{L}{2}\) ,但在某些問題中,不均勻的搜尋反而會跑得更快

通俗的說,就是從起點,終點各搜尋一半狀態,再將狀態合起來

1.4Meet in the Middle的時間複雜度分析

設A到B搜尋L步,每個搜尋樹的節點的分叉最多D個,計算出下一個節點的時間為T,每個節點的空間為S

則總狀態數為$ 2*D^{\frac{L}{2}}$

搜尋答案需要\(O(D^{\frac{L}{2}}T)\)的時間,合併答案理論上需要\(O(D^{\frac{L}{2}}S)\)的時間(在後面的程式碼實現部分我們會討論,由於寫法的不同,時間複雜度可能會略大),總時間複雜度為\(O(D^{\frac{L}{2}}(S+T))\)

2.程式碼實現

Meet in the Middle演算法有幾種實現方法,搜尋部分大致相同,合併答案部分有多種實現

我們以下面這道題為例:

例題 [SPOJ ABCDEF]

在[-30000,30000]範圍裡,給出一組整數集合S。找到滿足的六元組的總數使得:

\[\frac{ab+c}{d}-e=f\]

並且保證元組\((a,b,c,d,e,f):a,b,c,d,e,f \in S;d \neq0\)

我們將問題轉化成\(ab+c=d(e+f)\) (注意,使等式兩邊未知數個數相等或儘量均勻分佈是用meet in the middle演算法解決等式問題的常見方法)

然後我們先搜尋ab+c的所有可能結果(可以用DFS,但for迴圈更簡潔)

然後搜尋d(e+f)的所有可能結果,然後將兩步的結果合起來即可得到答案

法1: 結果合併法

**我們將兩次搜尋的結果分別存在陣列a,b裡,然後嘗試在$O(m)$或$O(m\log m)$的時間內將結果合併**

(在本部分,我們定義m為搜尋到的狀態數量)

首先我們把a,b陣列排序

對於此題來說,對於a陣列的每一個數a[i],我們在b中二分查詢與a[i]相等的數的個數 (查詢第一>=a[i]的數的位置與第一個>a[i]的數的位置,兩個位置相減即為答案
#include<iostream>
#include<cstdio>
#include<algorithm>
#define maxn 100
using namespace std;
int n,m;
long long s[maxn+5];
long long num[maxn*maxn*maxn+5];
int main(){
    int l,r;
    scanf("%d",&n);
    for(int i=1;i<=n;i++){
        scanf("%lld",&s[i]);
    }
    for(int i=1;i<=n;i++){//搜尋出a陣列
        for(int j=1;j<=n;j++){
            for(int k=1;k<=n;k++){
                num[++m]=s[i]*s[j]+s[k];
            }
        }
    }
    sort(num+1,num+1+m);
    long long ans=0;
    for(int i=1;i<=n;i++){
        if(s[i]==0) continue;
        for(int j=1;j<=n;j++){
            for(int k=1;k<=n;k++){//搜尋出b陣列,為了節約空間,此處不必儲存
                l=lower_bound(num+1,num+1+m,s[i]*(s[j]+s[k]))-num;//第一個=a[i]的數的位置
                r=upper_bound(num+1,num+1+m,s[i]*(s[j]+s[k]))-num-1;//最後一個一個=a[i]的數的位置
                if(r>=l) ans+=(r-l+1);
            }
        }
    }
    printf("%lld\n",ans);
} 

法2:雜湊表

此方法相對較簡潔,理論時間複雜度也更低,為\(O(m)\)

(因為法1常常要用到排序和二分查詢,時間複雜度要多一個log)

第一次搜尋時把答案存進雜湊表,第二次搜尋時在雜湊表中查詢即可

但是考慮到hash過程中可能會產生碰撞,導致效率降低。本人不太推薦這種方法(被卡了很多次)

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<map> 
#define maxn 100
using namespace std;
int n,m;
long long s[maxn+5];

const int c=133331;
struct myhash {
    struct node {
        long long v;
        long long cnt; 
        node* next;
    };
    node a[200005];
    void set0(){
        for(int i=0;i<maxn;i++){
            a[i].v=0;
            a[i].next=NULL;
        }
    }
    void push(long long x) {
        long long t;
        if(x<0) t=-x;
        else t=x;
        node *tmp=a[t%c].next;
        if(tmp==NULL){
            tmp=new node();
            tmp->v=x;
            tmp->cnt=1;
            tmp->next=NULL; 
            a[t%c].next=tmp;
            return;
        }
        while(tmp!=NULL){
            if(tmp->v==x){
                tmp->cnt++;
                return;
            } 
            tmp=tmp->next;
        }
        if(tmp==NULL){ 
            node *tmp=new node();
            tmp->v=x;
            tmp->cnt=1;
            tmp->next=a[t%c].next;
            a[t%c].next=tmp;
        } 
    }
    long long count(long long x) {
        long long t;
        if(x<0) t=-x;
        else t=x;
        if(a[t%c].next==NULL) return 0;
        else {
            node* tmp=a[t%c].next;
            while(tmp!=NULL) {
                if(tmp->v==x) return tmp->cnt;
                tmp=tmp->next;
            }
            return 0;
        }
    }
};
myhash num;
int main(){
    int l,r;
    scanf("%d",&n);
    for(int i=1;i<=n;i++){
        scanf("%lld",&s[i]);
    }
    num.set0();
    for(int i=1;i<=n;i++){
        for(int j=1;j<=n;j++){
            for(int k=1;k<=n;k++){
                num.push(s[i]*s[j]+s[k]);
            }
        }
    }
    long long ans=0;
    for(int i=1;i<=n;i++){
        if(s[i]==0) continue;
        for(int j=1;j<=n;j++){
            for(int k=1;k<=n;k++){ 
                ans+=num.count(s[i]*(s[j]+s[k]));
            }
        }
    }
    printf("%lld\n",ans);
} 

法3:map

使用STL map,本質和法2一樣,但是由於map是基於紅黑樹,時間複雜度為\(O(m\log m)\)

雖然map的效率看起來比雜湊表低,但map更穩定,不會出現衝突的情況。

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<map> 
#define maxn 100
using namespace std;
int n,m;
long long s[maxn+5];
map<long long,long long>num;//num[i]表示數字i出現的次數
int main(){
    int l,r;
    scanf("%d",&n);
    for(int i=1;i<=n;i++){
        scanf("%lld",&s[i]);
    }
    for(int i=1;i<=n;i++){
        for(int j=1;j<=n;j++){
            for(int k=1;k<=n;k++){
                num[s[i]*s[j]+s[k]]++;//存進map裡
            }
        }
    }
    long long ans=0;
    for(int i=1;i<=n;i++){
        if(s[i]==0) continue;
        for(int j=1;j<=n;j++){
            for(int k=1;k<=n;k++){
                ans+=num[s[i]*(s[j]+s[k])];//直接查詢
            }
        }
    }
    printf("%lld\n",ans);
} 

三種實現方法的比較:

實現方法 時間複雜度 實際執行效率 思維難度 程式碼量
序列合併 O(mlogm) 較快 較大(有些題不好合並)
散列表 O(m) 不被卡時最快,但可能會被卡 較小 大(手寫hash表)
STL map O(mlogm) 較快,但常數較大 最小

3.擴充套件運用

我們通過一些例題來了解

[BZOJ 4800] 冰球世界錦標賽

有n個物品,m塊錢,給定每個物品的價格,求買物品的方案數。

解法:先搜尋用前n/2個物品湊出的錢數,再搜尋後n/2個物品湊出的錢數

採用結果合併法,lans儲存前一半物品湊出的錢數,rans儲存後一半

對於lans[i],在rans中查詢<=m-lans[i]的數的個數

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#define maxn 45
#define maxs 2000005
using namespace std;
int n;
long long m;
long long a[maxn];
int lcnt; 
long long lans[maxs];//前一半
int rcnt; 
long long rans[maxs];//後一半
void dfs(long long *ans,int &size,int deep,int top,long long sum){
    //ans答案陣列,size結果數量,deep表示當前搜尋的物品,top表示最多搜到第幾個物品,sum表示總和
    if(sum>m) return;
    if(deep==top+1){
        ans[++size]=sum;
        return;
    }
    dfs(ans,size,deep+1,top,sum+a[deep]);
    dfs(ans,size,deep+1,top,sum);
}

int main(){
    scanf("%d %lld",&n,&m);
    for(int i=1;i<=n;i++){
        scanf("%lld",&a[i]);
    }
    sort(a+1,a+1+n);
    dfs(lans,lcnt,1,n/2,0);//搜尋[1,n/2]
    dfs(rans,rcnt,n/2+1,n,0);//搜尋[n/2+1,n]
    sort(lans+1,lans+1+lcnt);
    sort(rans+1,rans+1+rcnt);
    long long ans=0;
    for(int i=1;i<=lcnt;i++){//合併答案
        ans+=upper_bound(rans+1,rans+1+rcnt,m-lans[i])-rans-1;
        //這裡求的是<=m-lans[i]的rans的個數,即rans中第一個大於m-lans[i]的數的位置-1
    } 
    printf("%lld\n",ans);
}

[POJ 1186] 方程的解數

 已知一個n元高次方程:

  \(k_1x_1^{p_1}+k_2x_2^{p_2}+ \dots+k_nx_n^{p_n}=0\)

 其中:x1, x2,…,xn是未知數,k1,k2,…,kn是係數,p1,p2,…pn是指數。且方程中的所有數均為整數。

  假設未知數1 <= xi <= M, i=1,,,n,求這個方程的整數解的個數。

解法:將方程變形為

\(k_1x_1^{p_1}+k_2x_2^{p_2}+ \dots+k_{n/2}x_{n/2}^{p_{n/2}}=-k_{n/2+1}x_{n/2+1}^{p_{n/2+1}}-\dots-k_{n-1}x_{n-1}^{p_{n-1}}-k_nx_n^{p_n}\)

然後就可以類似例題"ABCDEF",分別搜尋左邊和右邊的所有可能取值,然後再合併

合併方法類似點分治的合併,維護兩個索引i,j

i從1~n遍歷lans

對於每個i,j從n到1迴圈,直到找到一個rans[j0]=lans[i0]

然後對於i,從i0向後找到與ans[i0]相等的所有i,並計算出個數cnt1

對於j,從j0向前找到與ans[j0]相等的所有j,並計算出個數cnt2

根據乘法原理,應將cnt1*cnt2累加進答案

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define maxn 7
#define maxs 33750005
using namespace std;
int n,m;
int k[maxn];
int p[maxn];
int x[maxn];
int lcnt; 
long long lans[maxs];
int rcnt; 
long long rans[maxs];
inline long long fast_pow(long long x,long long k){
    int ans=1;
    while(k){
        if(k&1) ans=ans*x;
        x=x*x;
        k>>=1;
    }
    return ans;
}

void dfs(long long *ans,int &size,int deep,int l,int r,long long sum){
    if(deep==r+1){
        ans[++size]=sum;
        return;
    }
    for(int i=1;i<=m;i++){
        x[deep]=i;//列舉x的可能取值
        dfs(ans,size,deep+1,l,r,sum+k[deep]*fast_pow(x[deep],p[deep]));
    }
}

int main(){
    scanf("%d %d",&n,&m);
    for(int i=1;i<=n;i++){
        scanf("%d %d",&k[i],&p[i]);
    }
    dfs(lans,lcnt,1,1,n/2,0);
    dfs(rans,rcnt,n/2+1,n/2+1,n,0);
    sort(lans+1,lans+1+lcnt);
    sort(rans+1,rans+1+rcnt);
    int j=rcnt;
    long long ans,cnt1,cnt2;
    ans=0;
    for(int i=1;i<=lcnt;i++){//合併答案
        while(lans[i]+rans[j]>0){
            j--;
        }
        if(j<=0) break;
        if(lans[i]+rans[j]!=0) continue;
        cnt1=cnt2=1;
        while(i<lcnt&&lans[i+1]==lans[i]){
            i++;
            cnt1++;
        }
        while(j>1&&rans[j-1]==rans[j]){
            j--;
            cnt2++;
        }
        ans+=cnt1*cnt2;
    }
    printf("%lld\n",ans);
} 

[BZOJ 2679] Balanced Cow Subsets

給n個數,從中任意選出一些數,使這些數能分成和相等的兩組。
求有多少種選數的方案。n≤20

解法:轉化為和上面的方程類似的模型

$ a_1x_1+a_2 x_2+ \dots a_n x_n=0 $

其中x只能取1,-1,0 (1代表放在集合1,-1代表放在集合2,0代表不選)

同樣對方程移項

不同的是,此題求的不是方程的解,而是方程的解對應集合的數量

我們採用狀壓,每次搜尋除了記錄數的和,還要記錄狀態 (選為1,不選為0)

因為這題不方便用序列合併,我們用map和vector來實現:

開一個map,狀態個數個vector

第一次搜尋,記錄下lsum,lset(lsum為和,lset為對應的集合,mp為一個STL map),用map來離散化,得到lsum對應的離散化後值,並往該值對應的vector中插入lset

第二次搜尋,如果搜到的和rsum在mp裡面出現過,則可以查到rsum對應的vector編號

遍歷這個vector,若vector中的數是lset,當前搜尋的集合為rset,則lset+rset為滿足條件的集合 (和相等且不重複)

用一個bool陣列統計該集合是否出現過即可

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<map>
#include<vector>
#define maxn 25
#define mod 999917
#define maxh 1000005
#define maxb 1050000
using namespace std;
inline int qread(){
    int x=0,sign=1;
    char c=getchar();
    while(c<'0'||c>'9'){
        if(c=='-') sign=-1;
        c=getchar();
    }
    while(c>='0'&&c<='9'){
        x=x*10+c-'0';
        c=getchar();
    }
    return x*sign;
}
int n;
long long a[maxn];
int vis[maxb];
map<long long,int>mp;
vector<int>v[maxh];
int ans=0;
int cnt=0;
void dfs1(int deep,int top,long long sum,int s) {
    if(deep==top+1) {
        if(!mp.count(sum)) mp[sum]=++cnt;
        v[mp[sum]].push_back(s);
        return;
    }
    dfs1(deep+1,top,sum+a[deep],s|(1<<(deep-1)));
    dfs1(deep+1,top,sum-a[deep],s|(1<<(deep-1)));
    dfs1(deep+1,top,sum,s);
}

void dfs2(int deep,int top,long long sum,int s) {
    if(deep==top+1) {
        if(!mp.count(sum)) return;
        int id=mp[sum];
        for(int i=0;i<v[id].size();i++){
            if(!vis[v[id][i]|s]) ans++;
            vis[v[id][i]|s]=1;
        } 
        return;
    }
    dfs2(deep+1,top,sum+a[deep],s|(1<<(deep-1)));
    dfs2(deep+1,top,sum-a[deep],s|(1<<(deep-1)));
    dfs2(deep+1,top,sum,s);
}

int main() {
    n=qread();
    for(int i=1; i<=n; i++) {
        a[i]=qread();
    }
    int m=n/2;;
    dfs1(1,m,0,0);
    dfs2(m+1,n,0,0);
    printf("%d\n",ans-1);

}

參考資料:

喬明達《搜尋問題中的meet in the middle技巧》