1. 程式人生 > >prim演算法 優化前O(n²) 優化後O(n-k)

prim演算法 優化前O(n²) 優化後O(n-k)

1.演算法思想
圖採用鄰接矩陣儲存,貪心找到目前情況下能連上的權值最小的邊的另一端點,加入之,直到所有的頂點加入完畢。
2.演算法實現步驟
設圖G=(V,E),其生成樹的頂點集合為U。
(1)把v0放入U。
(2)在所有u∈U,v∈V-U的邊(u,v)∈E中找一條最小權值的邊,加入生成樹。
(3)把(2)找到的邊的v加入U集合。如果U集合已有n個元素,則結束,否則繼續執行(2)。
最後得到最小生成樹U=

#include<cstdio>
#include<cstring>
using namespace std;
#define vmax 200
int w[vmax][vmax],i,j,k,v,e;

void
prim(int v0) { bool flag[vmax]; int min,nextk,prevk; memset(flag,false,sizeof(flag)); flag[v0]=true; for (i=1;i<=v-1;i++) { min=0x7fffffff; for (k=1;k<=v;k++) if (flag[k]) for (j=1;j<=v;j++) if (!flag[j] && w[k][j]<min && w[k][j]!=0
) { min=w[k][j]; nextk=j; prevk=k; } if (min!=0x7fffffff) { flag[nextk]=true; printf("%d %d %d",prevk,nextk,min); } } } int main() { memset(w,0,sizeof(w)); scanf
("%d %d",&v,&e); for (k=1;k<=e;k++) { scanf("%d %d",&i,&j); scanf("%d",&w[i][j]); w[j][i]=w[i][j]; } prim(1); return 0; }

3.演算法的關鍵與優化
我們很容易就可以發現prim演算法的關鍵:每次如何從生成樹T到T外的所有邊中,找出一條最小邊。例如,在第k次前,生成樹T中已有k個頂點和(k-1)條邊,此時,T到T外的所有邊數為k*(n-k),當然,包括沒有邊的兩頂點我們記權值為“無窮大”的邊在內,從如此多的邊中查詢最短邊,時間複雜度為O(k(n-k)),顯然無法滿足我們的期望。
我們來看O(n-k)的方法:假定在進行第k次前已經保留著從T中到T外的每一個頂點(共n-k個)的各一條最短邊,在進行第k次時,首先從這(n-k)條最短邊中,找出一條最最短邊(它就是從T到T外的最短邊),假設為(vi,vj),此步需要進行(n-k)次比較;然後把邊(vi,vj)和頂點vj併入T中的邊集TE和頂點集U中,此時,T外只有n-(k+1)個頂點,對於其中的每個頂點vt,若(vj,vt)邊上的權值小於原來儲存的從T中到vt的最短邊的權值,則用(v,vt)修改之,否則,保持原最小邊不變。這樣就把第k次後T中到T外的每一個頂點vt的各一條最短邊都保留下來了,為第(k+1)次做好了準備。這樣,prim的總時間複雜度為O(n²)。
【樣例輸入】
6 10
1 2 10
1 5 19
1 6 21
2 3 5
2 4 6
2 6 11
3 4 6
4 5 18
4 6 14
5 6 33
【樣例輸出】
50
優化後:

#include<cstdio>
using namespace std;
#define MXN 1000
int map[MXN][MXN],cost[MXN],visit[MXN],i,j,n,m,x,y,v;

int prim()
{
    int i,j,min,mini,ans;
    ans=0;
    for (i=1;i<=n;i++)
      {
        visit[i]=false;
        cost[i]=0x7fffffff;
      }
    for (i=2;i<=n;i++)
      if (map[1][i]!=0)
        cost[i]=map[1][i];
    visit[1]=true;
    for (i=1;i<=n-1;i++)
      {
        min=0x7fffffff;
        for (j=1;j<=n;j++)
          if (!visit[j] && cost[j]<min)
            {
              min=cost[j];
              mini=j;
            }
        visit[mini]=true;
        ans+=min;
        for (j=1;j<=n;j++) 
          if (!visit[j] && map[mini][j]>0 && map[mini][j]<cost[j])
            cost[j]=map[mini][j];
      }
    return ans;
}

int main()
{
    scanf("%d %d",&n,&m);
    for (i=1;i<=m;i++)
      {
        scanf("%d %d %d",&x,&y,&v);
        if (map[x][y]==0 || map[x][y]>v)
          {
            map[x][y]=v;
            map[y][x]=v;
          }
      }
    printf("%d",prim());
    return 0;
}