1. 程式人生 > >基本演算法之分治法

基本演算法之分治法

合併排序

合併排序的時間複雜度為:O(nlogn),最壞情況下的鍵值比較次數接近於任何基於比較的排序演算法的理論上能夠達到的最小次數,主要缺點是該演算法需要線性的額外空間。

#include "stdafx.h"

#include<iostream>
using namespace std;

void Merge(int *a,int *b,int left,int middle,int right)    {
    
    int i=left,j=middle+1,k=left;
    while(i<=middle&&j<=right)    {
        if(a[i]<=a[j])    
            b[k++]=a[i++];
        else
            b[k++]=a[j++];
    }
    if(i>middle)    {
    
        for(int p=j;p<=right;++p)
            b[k++]=a[p];
    }

    else    {
        
        for(int p=i;p<=middle;++p)
            b[k++]=a[p];
    }
}


void CopyArray(int *a,int *b,int left,int right)    {

    for(int i=left;i<=right;++i)
        a[i]=b[i];

}

void MergeSort(int *a,int left,int right)    {

    if(left<right)    {
        int middle=(left+right)/2;
        int b[10];
        MergeSort(a,left,middle);
        MergeSort(a,middle+1,right);
        //合併左右兩部分
        Merge(a,b,left,middle,right);
        CopyArray(a,b,left,right);
    }
}

int _tmain(int argc, _TCHAR* argv[])
{
    
    int num=0;
    int a[10];
    cout<<"將要輸入的數的個數:"<<endl;
    cin>>num;
    cout<<"依次輸入每一個數:"<<endl;
    for(int i=0;i<num;++i)
        cin>>a[i];
    MergeSort(a,0,num-1);
    cout<<endl;
    for(int i=0;i<num;++i)
        cout<<a[i]<<"   ";
    system("pause");
    return 0;

}

待拓展:K-way merge algorithm

快速排序

時間複雜度為O(nlogn)

#include "stdafx.h"
#include<iostream>
using namespace std;

int Partion(int *a,int left,int right)    {
    
    int pivot=a[left];
    while(left<right)    {
        while(left<right&&a[right]>=pivot)    {
            --right;
        }
        if(left<right)    {
            a[left++]=a[right];
        }
        while(left<right&&a[left]<=pivot)    {
            ++left;
        }
        if(left<right)
            a[right--]=a[left];
    }
    a[left]=pivot;
    return left;
}

void QuickSort(int *a,int left,int right)    {
        
    if(left<right)    {
        
        int povitpos=Partion(a,left,right);
        QuickSort(a,left,povitpos-1);
        QuickSort(a,povitpos+1,right);
    }
}


int _tmain(int argc, _TCHAR* argv[])
{
    
    int a[5];
    cout<<"please input five digits:"<<endl;
    for(int i=0;i<5;++i)
        cin>>a[i];
    QuickSort(a,0,4);
        for(int i=0;i<5;++i)
        cout<<a[i]<<"  ";
        
    system("pause");
    return 0;
}

待擴充套件:

  • 中樞的選擇
  • 當子問題更小時,改用插入排序
  • 不用遞迴方法

折半查詢的遞迴演算法和非遞迴演算法

時間複雜度為logn

#include "stdafx.h"
#include<iostream>
using namespace std;
//遞迴演算法
int BinarySearch(int *a,int key,int left,int right)    {
    
    if(left<=right)    {
    
        int middle=(left+right)/2;
        
        if(a[middle]==key)
            return 1;
        
        else if(a[middle]<key)
            return BinarySearch(a,key,middle+1,right);
        
        else if(a[middle]>key)
            return BinarySearch(a,key,left,middle-1);
    }
   else
       return -1;
}

//非遞迴演算法

int BinarySearchNotRecursion(int *a,int key,int left,int right)    {

    while(left<=right)    {
    
        int middle=(left+right)/2;
        if(key==a[middle]) return 1;
        else if(key<a[middle])
            right=middle-1;
        else if(key>a[middle])
            left=middle+1;
    }
    return -1;


}

int _tmain(int argc, _TCHAR* argv[])
{
    int a[5];
    int key;

    cout<<"請輸入5個有序的數字:"<<endl;
    for(int i=0;i<5;++i)
        cin>>a[i];
    cout<<"輸入要查詢的數字:"<<endl;
    cin>>key;
    //cout<<BinarySearch(a,key,0,4);
    cout<<BinarySearchNotRecursion(a,key,0,4);
    system("pause");
    return 0;
}

大整數相乘

該演算法是為了減少相乘的次數,對於不是很大的整數,該演算法的執行時間很可能比經典演算法長,因為它是一個遞迴演算法,但實驗顯示,從大於600位的整數開始,分治法的效能超越了筆算演算法的效能。    其演算法時間複雜度為n的平方;

#include "stdafx.h"
#include<iostream>
using namespace std;
// 假設輸入的a和b都是非負的,且N 為2的冪

int BigIntMultiplication(int a,int b,int N)    {
    
    if(a==0||b==0)
        return 0;
    else if(N==1)
        return a*b;
    else {
    
    int a1,a2,b1,b2;
    
    a1=a/pow(10.0,N/2);
    a2=a-a1*pow(10.0,N/2);
    b1=b/pow(10.0,N/2);
    b2=b-b1*pow(10.0,N/2);

    int c0,c1,c2;
    c1=BigIntMultiplication(a1,b1,N/2);
    c2=BigIntMultiplication(a2,b2,N/2);
    c0=(BigIntMultiplication((a1+a2),(b1+b2),N/2))-(c1+c2);     //  實際上c0=a1*b2+a2*b1

    return c1*pow(10.0,N)+c0*pow(10.0,N/2)+c2;
    }

}

int _tmain(int argc, _TCHAR* argv[])
{
    cout<<"請輸入大整數的位數(必須為2的冪)及兩個大整數:"<<endl;
    int a,b,N;
    cin>>N>>a>>b;
    cout<<"結果為:"<<BigIntMultiplication(a,b,N);
    system("pause");
    return 0;
}

strassen矩陣乘法

思路: 兩個n階方正相乘a*b=c(假設n為2的冪,若n不是2的冪可以通過給方正新增0來達到效果),可以分別將這兩個方正拆成4個n/2階的小方陣,分別為a11、a12、a21、a22、b11、b12、b21、b22,然後分別通過將這8個小方陣相乘、相加、或是相減來分別求出c11、c12、c21、c22,具體的公式如下:

m1=(a11+a22)*(b11+b22),m2=(a21+a22)*b11,m3=a11*(b12-b22),m4=a22*(b21-b11),m4=a22*(b21-b11),m5=(a11+a12)*b22,m6=(a21-a11)*(b11+b12)

m7=(a12-a22)*(b21+b22);  

c11=m1+m4+m7-m5,c12=m3+m5,c21=m2+m4,c22=m1+m3+m6-m2  ;

通常的方法計算方陣乘積要計算nE3個乘法,而strassen方法有所提高,需要nElog7次乘法計算。

原始碼如下:

#include "stdafx.h"
#include<iostream>
using namespace std;

//#define N 8;
const int N=8;

//方陣資料
void InputMatrix(int n,int a[][N])    {
    
    for(int i=0;i<n;++i)    {
        cout<<"請輸入第"<<i+1<<"行資料"<<endl;
        for(int j=0;j<n;++j)
            cin>>a[i][j];
    }

}

//方陣資料
void OutputMatrix(int n,int a[][N])    {

        for(int i=0;i<n;++i)    {
            for(int j=0;j<n;++j)
                cout<<a[i][j]<<" ";
            cout<<endl;
    }

}


//方陣相加
void MatrixAdd(int n,int a[][N],int b[][N],int c[][N])    {
    for(int i=0;i<n;++i)
        for(int j=0;j<n;++j)    {
        
        c[i][j]=a[i][j]+b[i][j];
        }
}

//方陣相減
void MatrixSub(int n,int a[][N],int b[][N],int c[][N])    {
    for(int i=0;i<n;++i)
        for(int j=0;j<n;++j)    {
        
        c[i][j]=a[i][j]-b[i][j];
        }
}

//2階矩陣的乘法
void MatrixMultiplication(int a[][N],int b[][N],int c[][N])    {

    for(int i=0;i<2;++i)
        for(int j=0;j<2;++j)    {
            c[i][j]=0;
            for(int k=0;k<2;++k)
                c[i][j]+=a[i][k]*b[k][j];
        }

}

/* 遞迴法   strassen矩陣乘法*/

void StrassenMatrixMultiplication(int n,int a[][N],int b[][N],int c[][N])    {
    
    int a11[N][N],a12[N][N],a21[N][N],a22[N][N];
    int b11[N][N],b12[N][N],b21[N][N],b22[N][N];

    int c11[N][N],c12[N][N],c21[N][N],c22[N][N];
    int M1[N][N],M2[N][N],M3[N][N],M4[N][N],M5[N][N],M6[N][N],M7[N][N];

    int a_temp1[N][N],a_temp2[N][N],b_temp1[N][N],b_temp2[N][N],M_temp1[N][N],M_temp2[N][N];

    for(int i=0;i<n/2;++i)
        for(int j=0;j<n/2;++j)    {
        
            a11[i][j]=a[i][j];
            a12[i][j]=a[i][j+n/2];
            a21[i][j]=a[i+n/2][j];
            a22[i][j]=a[i+n/2][j+n/2];


            b11[i][j]=b[i][j];
            b12[i][j]=b[i][j+n/2];
            b21[i][j]=b[i+n/2][j];
            b22[i][j]=b[i+n/2][j+n/2];
    }
    if(n==2)    {
        MatrixMultiplication(a,b,c);
        
    }

    else    {

        MatrixAdd(n/2,a11,a22,a_temp1);
        MatrixAdd(n/2,b11,b22,b_temp1);
        StrassenMatrixMultiplication(n/2,a_temp1,b_temp2,M1);    //  M1=(a11+a22)*(b11+b22)

        MatrixAdd(n/2,a21,a22,a_temp1);
        StrassenMatrixMultiplication(n/2,a_temp1,b11,M2);       //  M2=(a21+a22)*b11

        MatrixSub(n/2,b12,b22,b_temp1);
        StrassenMatrixMultiplication(n/2,a11,b_temp1,M3);        //  M3=a11*(b12-b22)

        MatrixAdd(n/2,b21,b11,b_temp1);
        StrassenMatrixMultiplication(n/2,a22,b_temp1,M4);    //   M4=a22*(b21-b11)
        
        MatrixAdd(n/2,a12,a11,a_temp1);
        StrassenMatrixMultiplication(n/2,a_temp1,b22,M5);    //  M5=(a11+a12)*b22

        MatrixSub(n/2,a21,a11,a_temp1);
        MatrixAdd(n/2,b11,b12,b_temp1);
        StrassenMatrixMultiplication(n/2,a_temp1,b_temp1,M6);     // M6=(a21-a11)*(b11+b12)

        MatrixSub(n/2,a12,a22,a_temp1);
        MatrixAdd(n/2,b21,b22,b_temp1);
        StrassenMatrixMultiplication(n/2,a_temp1,b_temp1,M7);    //M7=(a12-a22)*(b21+b22)


        MatrixAdd(n/2,M1,M4,M_temp1);
        MatrixSub(n/2,M7,M5,M_temp2);
        MatrixAdd(n/2,M_temp1,M_temp2,c11);      // c11=M1+M4-M5+M7

        MatrixAdd(n/2,M3,M5,c12);           //  c12=M3+M5

        MatrixAdd(n/2,M2,M4,c21);        // c21=M2+M4


        MatrixAdd(n/2,M1,M3,M_temp1);
        MatrixSub(n/2,M6,M2,M_temp2);
        MatrixAdd(n/2,M_temp1,M_temp2,c22);    //c22=M1+M3-M2+M6
        
        for(int i=0;i<n/2;++i)
            for(int j=0;j<n/2;++j)    {
                c[i][j]=c11[i][j];
                c[i][j+n/2]=c12[i][j];
                c[i+n/2][j]=c21[i][j];
                c[i+n/2][j+n/2]=c22[i][j];
            }
         }

}

int _tmain(int argc, _TCHAR* argv[])
{
    
    int a[N][N],b[N][N],c[N][N];
    int n=0;
    
    cout<<"請輸入矩陣的階數(<8):"<<endl;
    cin>>n;
    cout<<"輸入第一個矩陣:"<<endl;
    InputMatrix(n,a);

    cout<<"輸入第2個矩陣:"<<endl;
    InputMatrix(n,b);

    StrassenMatrixMultiplication(n,a,b,c);
    
    OutputMatrix(n,c);
    system("pause");
    return 0;
}

分治法求最近對

思路:時間複雜度為nlog(n),先按照每個點的x座標用快速排序法進行排序,然後再遞迴求出左右兩側的最近對,還要求出中點附近位於分別位於左右兩側的最近對,為了處理

所有可能的點的個數,遞迴的出口時當點的數量為2或者3時。具體的程式碼如下:

#include "stdafx.h"
#include<iostream>
using namespace std;

typedef struct Point{
    float x;
    float y;
};

typedef struct Closepair{
    Point a;
    Point b;
    float dis;

};

//   依據點的X座標對點進行排序,快排
int X_PartitionsPointArray(Point *a,int left,int right)    {
    Point pivot=a[left];
    while(left<right)    {
    
        while(left<right&&a[right].x>=pivot.x)
            --right;
        if(left<right)
            a[left++]=a[right];
        
        while(left<right&&a[left].x<=pivot.x)
            ++left;
        if(left<right)
            a[right--]=a[left];
    }
    a[left]=pivot;
    return left;
}

void SortPointArray(Point *a,int left,int right)    {
    if(left<right)    {
    
        int par=X_PartitionsPointArray(a,left,right);
        SortPointArray(a,left,par-1);
        SortPointArray(a,par+1,right);
    }

}

float DistancePointPair(Point a,Point b)    {
    
    return sqrt((a.x-b.x)*(a.x-b.x)+(a.y-b.y)*(a.y-b.y));

}


void ClosedPair(Point *a,int left,int right,Closepair &result)    {
    
    if((right-left+1)==1)
        cout<<"錯誤,點的數量必須大於1"<<endl;
    
    else if((right-left+1)==2)    {
    
        result.a=a[left];
        result.b=a[right];
        result.dis=DistancePointPair(a[left],a[right]);
    
    }
    
    
    else if((right-left+1)==3)    {
        float temp1=DistancePointPair(a[left],a[left+1]),temp2=DistancePointPair(a[right],a[left+1]),temp3=DistancePointPair(a[left],a[right]);

        if(temp1<temp2)    {
            result.a=a[left];
            result.b=a[left+1];
            result.dis=temp1;
            
        }
        else{
            result.a=a[right];
            result.b=a[left+1];
            result.dis=temp2;
            
        }
            
        if(temp3<result.dis){
            result.a=a[left];
            result.b=a[right];
            result.dis=temp3;
            
        }
    }


    else    {
        
        SortPointArray(a,left,right);
        

        Closepair left_result;
        Closepair right_result;

        ClosedPair(a,left,(left+right)/2,left_result);
        ClosedPair(a,(left+right)/2+1,right,right_result);

        Closepair all_result;
        
        if(left_result.dis<right_result.dis)
            all_result=left_result;
        else
            all_result=right_result;

        float distance=all_result.dis;

        int middle=(left+right)/2;
        
        for(int i=middle;i>=left;--i)    {
            
            if((a[middle+1].x-a[i].x)>distance)
                break;
            else    {
                for(int j=middle+1;j<=right;++j)    {
                    
                
                    if((a[j].x-a[middle].x)>distance)
                        break;
                    else    {
                    
                        float tempdis=DistancePointPair(a[j],a[i]);
                        if(tempdis<all_result.dis)    {
                            all_result.dis=tempdis;
                            all_result.a=a[i];
                            all_result.b=a[j];
                        }
                    
                    }
                
                }
            }
        }
        result=all_result;
    }

}

int _tmain(int argc, _TCHAR* argv[])
{
    Point a[10];
    Closepair result;
    int num;
    cout<<"要輸入的點數(<10):"<<endl;
    cin>>num;

    cout<<"依次輸入每個點的x座標和y座標:"<<endl;
    for(int i=0;i<num;++i)
        cin>>a[i].x>>a[i].y;

    ClosedPair(a,0,num-1,result);

    cout<<"最近點為:("<<result.a.x<<","<<result.a.y<<")--->"<<"("<<result.b.x<<","<<result.b.y<<")    最近距離為:"<<result.dis<<endl;
    system("pause");
    return 0;
}

分治法求凸包

#include "stdafx.h"
#include<iostream>
using namespace std;
const int N=20;

typedef struct Point{
    float x;
    float y;
};

typedef struct PointArray{
    Point plist[N];
    int num;
};

typedef struct Line{
    Point a;
    Point b;
};

typedef struct LineArray{
    Line llist[N];
    int num;

};


int X_PartitionsPointArray(Point *a,int left,int right)    {
    Point pivot=a[left];
    while(left<right)    {
    
        while(left<right&&a[right].x>=pivot.x)
            --right;
        if(left<right)
            a[left++]=a[right];
        
        while(left<right&&a[left].x<=pivot.x)
            ++left;
        if(left<right)
            a[right--]=a[left];
    }
    a[left]=pivot;
    return left;
}

void SortPointArray(Point *a,int left,int right)    {
    if(left<right)    {
    
        int par=X_PartitionsPointArray(a,left,right);
        SortPointArray(a,left,par-1);
        SortPointArray(a,par+1,right);
    }

}

// 求有向面積
float GetArea(Point p1,Point p2,Point p3)    {
      
    return (p1.x * p2.y + p3.x * p1.y + p2.x * p3.y -
             p3.x * p2.y - p2.x * p1.y - p1.x * p3.y);

}

void QuickHull(PointArray &a,LineArray &b)    {

    if(a.num==0||a.num==1)
        return;
    b.num=0;
    PointArray left,right;
    left.num=right.num=0;   //初始化

    SortPointArray(a.plist,0,a.num-1);   //按X座標排序

    Point mostleft=a.plist[0];
    Point mostright=a.plist[a.num-1];

    float area=0;

    for(int i=1;i<a.num-1;++i)    {
        
        Point temp=a.plist[i];
        area=GetArea(mostleft,a.plist[i],mostright);
        if(area>0)    {
            left.plist[left.num++]=temp;
        }
        else if(area>0)    {
            right.plist[right.num++]=temp;    
        }
    }

    RecurtionHull(mostleft,mostright,left,b);
    RecurtionHull(mostright,mostleft,right,b);

}

void RecurtionHull(Point mostleft,Point mostright,PointArray a,LineArray &b )    {

        float area=0;
        float maxarea=0;
        Point pmax;
    
    
    
    if(a.num==0)    {
        
        b.llist[b.num++].a=mostleft;
        b.llist[b.num++].a=mostright;
    }


    else    {


        for(int i=0;i<a.num;++i)    {
            Point temp=a.plist[i];
            area=GetArea(mostleft,mostright,temp);
            if(area>maxarea)    {
                maxarea=area;
                pmax=temp;
            
            }
        }
    }

    PointArray a_left,a_right;
    a_left.num=a_right.num=0;
    for(int i=0;i<a.num;++i)    {
        Point temp=a.plist[i];
        if (GetArea(mostleft,pmax,temp)>0) {
            a_left.plist[a_left.num++]=temp;
            
        }
        else if(GetArea(mostleft,pmax,temp)<0)    {
            a_right.plist[a_right.num++]=temp;
        }
    }
    

    RecurtionHull(mostleft,pmax,a_left,b);
    RecurtionHull(pmax,mostright,a_right,b);
}