1. 程式人生 > >資料探勘之關聯規則挖掘之Apriori演算法實現

資料探勘之關聯規則挖掘之Apriori演算法實現

演算法細節見論文:Fast Algorithm for Mining Association Rules

控制檯版本C++程式碼如下:

#include <iostream>
#include <sstream>
#include <fstream>
#include <vector>
#include <set>
#include <map>
#include <ctime>
using namespace std;

//讀取檔案獲取整個資料庫儲存在database中,fileName必須為char*型,要是用string會報錯,in()不認
bool ObtainDatabase(vector<set<int> > &database,char *fileName)
{
/*	set<int> data;
	data.insert(1);data.insert(2);data.insert(5);
	database.push_back(data);
	
	data.clear();
	data.insert(2);data.insert(4);
	database.push_back(data);
	
	data.clear();
	data.insert(2);data.insert(3);
	database.push_back(data);
	
	data.clear();
	data.insert(1);data.insert(2);data.insert(4);
	database.push_back(data);
	
	data.clear();
	data.insert(1);data.insert(3);
	database.push_back(data);
	
	data.clear();
	data.insert(2);data.insert(3);
	database.push_back(data);
	
	data.clear();
	data.insert(1);data.insert(3);
	database.push_back(data);
	
	data.clear();
	data.insert(1);data.insert(2);data.insert(3);data.insert(5);
	database.push_back(data);
	
	data.clear();
	data.insert(1);data.insert(2);data.insert(3);
	database.push_back(data);
*/	
	ifstream in(fileName);
	if(!in)
	{
		cout<<"檔案開啟失敗!"<<endl;
		return false;
	}
	
	string s="";
	unsigned int i=0;
	while(getline(in,s))
	{//讀取一行記錄
	i++;
		set<int> transaction;
		int len=s.length();
		string str="";
		for(int i=0;i<len;i++)
		{//將記錄中的數提取出來
			if(s[i]!=' ')
			{
				str+=s[i];
			}
			else if(s[i]==' '||i==len-1)
			{
				//字串轉int
				stringstream stoi(str);
				int item=0;
				stoi>>item;
				
				transaction.insert(item);
				
				str="";
			}
		}
		database.push_back(transaction);
		s="";
	}
	cout<<i<<endl; //system("pause");
	return true;

}

//遍歷一遍資料庫,建立1-項大項集
void CreateItemset(vector<set<int> >&database,vector<set<int> > &largeItemset,unsigned int minSupport,map<set<int>,int> &lm1)
{
	map<int,int> dir;
	map<int,int>::iterator dirIt;
	
	vector<set<int> >::iterator databaseIt;
	
	set<int> temp;
	set<int>::iterator tempIt;
	
	//根據資料庫建立字典,字典形式為<item,count>
	for(databaseIt=database.begin();databaseIt!=database.end();databaseIt++)
	{
		temp=*databaseIt;
		for(tempIt=temp.begin();tempIt!=temp.end();tempIt++)
		{
			int item=*tempIt;
			dirIt=dir.find(item);
			if(dirIt==dir.end())
			{//item不在字典dir中
				dir.insert(pair<int,int>(item,1));
			}
			else
			{//item在字典dir中,則將其count值加1
				(dirIt->second)++;
			}
		}
	}
	
	//從字典中選出支援度超過minSopport的item
	for(dirIt=dir.begin();dirIt!=dir.end();dirIt++)
	{
		if(dirIt->second>=minSupport)
		{
			set<int> large;
			large.insert(dirIt->first);
			largeItemset.push_back(large);
			lm1.insert(pair<set<int>,int>(large,dirIt->second));
		}
	}
	
}


//輸出大項集
void OutputLargeItemset(vector<set<int> > &largeItemset,unsigned int i)
{
	cout<<"包含 "<<largeItemset.size()<<" 項的 "<<i<<"-項大項集:"<<endl;
	
	vector<set<int> >::iterator largeItemsetIt;
	int j=0;
	for(largeItemsetIt=largeItemset.begin();largeItemsetIt!=largeItemset.end();largeItemsetIt++)
	{
		set<int> temp=*largeItemsetIt;
		cout<<"{ ";
		for(set<int>::iterator tempIt=temp.begin();tempIt!=temp.end();tempIt++)
		{
			cout<<(*tempIt)<<" ";
		}
		cout<<"}";
		j++;
		if(j%4==0)
		{
			cout<<endl;
		}
	}
	cout<<endl<<endl;
}

//連線步驟,若it1和it2符合連線條件,則把它們連線為temp,返回true,否則返回false
bool Joint(set<int> &recordI,set<int> &recordJ,set<int> &temp)
{
	if(recordI.size()!=recordJ.size())
	{//倆集合大小不一樣,立馬返回!
		return false;
	}
	set<int>::iterator it1=recordI.begin();
	set<int>::iterator it2=recordJ.begin();
	
	unsigned int size=recordI.size()-1;
	for(int i=0;i<size;i++)
	{
		if(*it1!=*it2)
		{
			return false;
		}
		temp.insert(*it1);
		it1++;
		it2++;
	}
	if(*it1==*it2)
	{
		return false;
	}
	temp.insert(*it1);
	temp.insert(*it2);
	//cout<<"連線"<<*it1<<" "<<*it2<<endl;
	return true;
}

//剪枝步驟,若temp的k-1項集有不在L[k-1]中,則剪掉,返回false,否則返回true
bool Prune(set<int> &temp,vector<set<int> > &largeTemp)
{
	unsigned int size=temp.size();
	
	//獲取temp的全部k-1項子集,並判斷每個子集是否在L[k-1]中
	for(int i=0;i<size;i++)
	{	
		set<int>::iterator tempIt=temp.begin();
		set<int> tempMinusOne;//盛放k-1項子集
		for(int j=0;j<size;j++)
		{
			if(j!=i)
			{
				tempMinusOne.insert(*tempIt);
			}
			*tempIt++;
		}
		
		//判斷tempMinusOne是否在L[k-1]中
		vector<set<int> >::iterator largeTempIt;
		bool flag=false;//temp是否被剪掉的標識
		for(largeTempIt=largeTemp.begin();largeTempIt!=largeTemp.end();largeTempIt++)
		{//對大項集集合largeTemp中的大項集*largeTempIt逐個與tempMinusOne進行比對,看相不相同,相同就會保證flag=true,否則為false
			flag=true;
			set<int> large=*largeTempIt;
			set<int>::iterator tempMinusOneIt=tempMinusOne.begin();
			for(set<int>::iterator largeIt=large.begin();largeIt!=large.end();largeIt++)
			{
				if(*largeIt!=*tempMinusOneIt)
				{
					flag=false;
					break;
				}
				tempMinusOneIt++;
			}
			if(flag==true)
			{//存在了,不用再和其它大項集比較了,浪費時間
				return true;
			}
		}
	}
	return false;
}

//利用L[k-1],通過連線和剪枝兩個步驟,生成候選集集合candidate
void AprioriGen(vector<set<int> > &largeTemp,vector<set<int> > &candidate)
{
	unsigned int largeTempSize=largeTemp.size();
	
	unsigned int sizeTemp=largeTempSize-1;
		
	vector<set<int> >::iterator largeTempIt=largeTemp.begin();
	//L[k-1]中的大項集兩兩連線,求候選集集合
	for(int i=0;i<sizeTemp;i++,largeTempIt++)
	{//system("pause");cout<<largeTempSize<<" "<<i<<endl;
		set<int> recordI=*largeTempIt;
		for(int j=i+1;j<largeTempSize;j++)
		{//cout<<j<<endl;
			set<int> recordJ=*(largeTempIt+(j-i));
			set<int> temp;
		//	cout<<"進行連線"<<endl;
			if(Joint(recordI,recordJ,temp))
			{//recordI和recordJ能連線成temp,則對temp進行剪枝
			//cout<<"連線成功,進行剪枝"<<endl; 
				if(Prune(temp,largeTemp))
				{//temp沒有被剪掉,則把它加到候選集的集合中
				if(!temp.empty())
			//	cout<<"temp不為空,沒有被剪掉,成為到候選集"<<endl;
					candidate.push_back(temp);
				}
			//	else{cout<<"被剪掉了"<<endl;} 
			}
			//else{cout<<"不符合連線條件"<<endl; } 
		}//system("pause");
	}
}


//對比資料庫中的每條交易,計算每個候選集的支援度,選出大於等於最小支援度的候選集來構成L[k]
void Subset(vector<set<int> > &database,vector<set<int> > &candidate,vector<set<int> > &largeK,unsigned int minSupport,map<set<int>,int> &lm)
{	
	
	vector<set<int> >::iterator databaseIt;
	vector<set<int> >::iterator candidateIt;
	
	for(candidateIt=candidate.begin();candidateIt!=candidate.end();candidateIt++)
	{//對於每個候選集can
		//bool cunzai=true;
	
		set<int> can=*candidateIt;
		
		//cout<<"cannnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnn"<<endl;
		
		unsigned int canCount=0;
		for(databaseIt=database.begin();databaseIt!=database.end();databaseIt++)
		{//對於資料庫中每條交易,檢視can是否在其中
			set<int> data=*databaseIt;
			
			if(can.size()>data.size())
			{
				continue;//候選集大小大於交易大小,肯定不在這個交易中
			}
			
			set<int>::iterator canIt;
			for(canIt=can.begin();canIt!=can.end();canIt++)
			{//對於can中每個項,看它是否在交易data中
				if(data.find(*canIt)==data.end())
				{
					break;
				}
				
			}
			
			if(canIt==can.end())
			{//cout<<"在"<<endl;//system("pause");
				canCount++;
                
                //cout<<canCount<<endl;
			}
		}
		if(canCount>=minSupport)
		{//canCount只要大於等於最小支援度,我們就退出迴圈,不再對該候選集進行計數了,浪費時間
			largeK.push_back(can);
			lm.insert(pair<set<int>,int>(can,canCount));
		}
	}
}

int main(int argc,char *argv[])
{
	char name[200];
	string file="";
	char *fileName="retail.dat";
	int minSupport=5000;//最小支援度
/*	
	string ctl="";
	cout<<"手動輸入檔案路徑和最小支援度(Y/N)?";
	cin>>ctl;
	if(ctl=="Y"||ctl=="y")
	{
		cout<<"請依次輸入檔案路徑和最小支援度,用空格隔開。(檔案路徑要用雙斜槓):\n";
		cin>>file>>minSupport;
		strcpy(name,file.c_str());
		fileName=name;
	}	
	*/
	
	vector<map<set<int>,int> > liss;
	
    clock_t start=clock();	
	vector<set<int> > database;//資料庫
	ObtainDatabase(database,fileName);
	
	vector<set<int> > large1;
	map<set<int>,int> lm1;
	CreateItemset(database,large1,minSupport,lm1);
	
	liss.push_back(lm1);
	
	int k=1;
	vector<set<int> > largeTemp=large1;
	while(!largeTemp.empty())
	{
		
		OutputLargeItemset(largeTemp,k);
		k++;
		
		vector<set<int> > candidate;
		AprioriGen(largeTemp,candidate);
		
		vector<set<int> > largeK;
		map<set<int>,int> lm;
		Subset(database,candidate,largeK,minSupport,lm);
		
		largeTemp=largeK;
		
		if(largeTemp.empty())
		{
			cout<<"L["<<k<<"]為空"<<endl;
		} 
		else
		{
			liss.push_back(lm);	
		}
	}		
	
	
	
	clock_t end=clock();
	cout<<"Finish!共用時:"<<(end-start)<<"ms"<<endl;
	system("pause");
}