1. 程式人生 > >基於多程序的網路聊天程式

基於多程序的網路聊天程式

參考:linux高效能伺服器程式設計,作者:遊雙

程式簡介:該程式用了共享記憶體來實現程序間的同步,由於只是同時讀取共享記憶體,所以沒有用到鎖。該程式的功能是伺服器監聽網路連線,當有一個客戶端連線時,伺服器建立一個子程序處理該連線。每個子程序只負責自己的客戶端以及和父程序通訊。當子程序從客戶端讀取資料後,把資料放到共享記憶體上,每個子程序在共享記憶體上有自己的一段空間,因此不會出現同時寫。放上去後通知父程序,說:共享記憶體上有新資料到達了,然後父程序通知其他子程序,去到該位置讀取資料,把資料傳送到自己的客戶端,實現了群聊的效果。該程式對於多程序程式設計的初學者是個不錯的例子,寫下來是為了讓自己熟悉一下。

伺服器程式碼:編譯的時候需要加上 -lrt選項

#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <assert.h>
#include <stdio.h>
#include <unistd.h>
#include <errno.h>
#include <string.h>
#include <fcntl.h>
#include <stdlib.h>
#include <sys/epoll.h>
#include <signal.h>
#include <sys/wait.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <fcntl.h>

#define USER_LIMIT 5
#define BUFFER_SIZE 1024
#define FD_LIMIT 65535
#define MAX_EVENT_NUMBER 1024
#define PROCESS_LIMIT 65536

/* 處理一個客戶端連線的必要資料 */
struct client_data			
{
	sockaddr_in address;
	int connfd;				/* 客戶端的fd  */
	pid_t pid;				/* 處理這個連線的子程序的pid */
	int pipefd[2];			/* 和父程序通訊用的管道 */
};

int sig_pipefd[2];//當有訊號發生時,用於父程序自己的通訊
char* share_mem;
int user_count = 0; //當前客戶的數量
client_data* users = 0 ; 
int* sub_process = 0;
static const char* shm_name = "/my_share_memory";
int maxevents = 100;
bool stop_child = false;

void setnonblock(int fd)
{
	int flag = fcntl(fd,F_GETFL);
	assert(flag != -1);
	fcntl(fd,F_SETFL,flag | O_NONBLOCK);
}
void addfd(int epollfd,int fd)
{
	epoll_event ee;
	ee.data.fd = fd;
	ee.events = EPOLLIN | EPOLLET;
	epoll_ctl(epollfd,EPOLL_CTL_ADD,fd,&ee);
	setnonblock(fd);
}
void sig_handler(int sig)
{
	int save_errno = errno;
	int msg = sig;
	send(sig_pipefd[1],(char*)&msg,1,0);
	errno = save_errno;//恢復錯誤值
}
void child_sig_handler(int sig)
{
	stop_child = true;
}
void addsig(int sig,void (*handler)(int),bool restart = true)
{
	struct sigaction sa;
	memset(&sa,'\0',sizeof(sa));
	sa.sa_handler = handler;
	if(restart)sa.sa_flags |= SA_RESTART;
	sigfillset(&sa.sa_mask);//作用?
	assert(sigaction(sig,&sa,NULL) != -1);
}
/* 子程序的處理函式,idx表示該子程序處理的客戶端連線的編號,users表示所有客戶端連線資料的陣列,share_mem表示共享記憶體的起始地址 */
int run_child(int idx,client_data* users,char* share_mem)
{
	int connfd = users[idx].connfd;
	int pipefd = users[idx].pipefd[1];
	int epollfd = epoll_create(100);//子程序的事件處理函式
	assert(epollfd != -1);
	addfd(epollfd,connfd);//與客戶端通訊
	addfd(epollfd,pipefd);//與父程序通訊
	addsig(SIGTERM,child_sig_handler,false);
	epoll_event events[maxevents];

	int ret;
	while(!stop_child)
	{
		int number = epoll_wait(epollfd,events,maxevents,-1);
		if(number < 0 && errno != EINTR)
		{
			printf("epoll error\n");
			break;
		}
		int i;
		for(i = 0;i < number;i++)
		{
			int sockfd = events[i].data.fd;
			if(sockfd == connfd && (events[i].events & EPOLLIN))//客戶端發來資料
			{
				memset(share_mem+idx*BUFFER_SIZE,'\0',BUFFER_SIZE);
				/* 將客戶端資料讀取到對應的讀快取中,該讀快取是共享記憶體的一段 */
				ret = recv(sockfd,share_mem+idx*BUFFER_SIZE,BUFFER_SIZE-1,0);
				if(ret < 0 && errno != EAGAIN)
				{
					printf("recv error\n");
					stop_child =  true;
				}
				else if(ret == 0)
				{
					printf("client close\n");
					stop_child = true;
				}
				else 
				{
					send(pipefd,(char*)&idx,sizeof(idx),0);//告訴父程序,“我”收到資料了
				}
			}
			/* 父程序通知"我"將第client個客戶端的資料傳送到我負責的客戶端 */
			else if(sockfd == pipefd && (events[i].events & EPOLLIN))
			{
				int client = 0;
				ret = recv(sockfd,(char*)&client,sizeof(client),0);
				if(ret < 0 && errno != EAGAIN)stop_child = true;
				else if(ret == 0) stop_child = true;
				else
				{
					send(connfd,share_mem+client*BUFFER_SIZE,BUFFER_SIZE,0);
				}
			}
		}
	}
	close(connfd);
	close(pipefd);
	close(epollfd);
	return 0;
}
int main(int argc,char* argv[])
{
	if(argc != 3)
	{
		printf("usage %s server_ip server_port \n",basename(argv[0]));
		return -1;
	}
	sockaddr_in server;
	server.sin_family = AF_INET;
	inet_pton(AF_INET,argv[1],&server.sin_addr);
	server.sin_port = htons(atoi(argv[2]));
	int listenfd = socket(AF_INET,SOCK_STREAM,0);
	assert(listenfd != -1);
	int opt = 1;
	int ret = setsockopt(listenfd,SOL_SOCKET,SO_REUSEADDR,&opt,sizeof(opt));
	assert(ret == 0);
	ret = bind(listenfd,(const sockaddr*)&server,sizeof(server));
	assert(ret != -1);
	ret = listen(listenfd,100);
	assert(ret != -1);

	/* 初始化連線池 */
	user_count = 0;
	users = new client_data[USER_LIMIT+1];
	sub_process = new int[PROCESS_LIMIT];
	int i;
	for(i = 0; i < PROCESS_LIMIT;++i)
	{
		sub_process[i] = -1;
	}

	/* epoll的初始化 */
	int epollfd = epoll_create(100);
	assert(epollfd != -1);
	addfd(epollfd,listenfd);//監聽網路連線埠

	ret = socketpair(AF_UNIX,SOCK_STREAM,0,sig_pipefd);//當有訊號發生時,用於父程序自己的通訊
	assert(ret != -1);
	setnonblock(sig_pipefd[1]);//UNIX域套接字的0號埠用於訊號處理函式
	addfd(epollfd,sig_pipefd[0]);//主程序監聽UNIX域套接字的1號埠

	/*  設定訊號處理函式 */ 
	addsig(SIGCHLD,sig_handler);
	addsig(SIGPIPE,SIG_IGN);
	addsig(SIGINT,sig_handler);
	addsig(SIGTERM,sig_handler);

	/* 建立共享記憶體,用於所有客戶socket連線的讀快取 */
	int shmfd = shm_open(shm_name,O_CREAT|O_RDWR,0666);
	assert(shmfd != -1);
	ret = ftruncate(shmfd,USER_LIMIT*BUFFER_SIZE);//設定shmfd的大小
	assert(ret != -1);
	share_mem = (char*)mmap(NULL,USER_LIMIT*BUFFER_SIZE,PROT_READ|PROT_WRITE,MAP_SHARED,shmfd,0);
	assert(share_mem != MAP_FAILED);
	close(shmfd);

	/* 進入epoll事件迴圈 */
	bool stop_server = false;
	bool terminate  = false;
	epoll_event events[maxevents];
	while(!stop_server)
	{
		int number = epoll_wait(epollfd,events,maxevents,-1);
		if(number < 0 && errno != EINTR)
		{
			printf("epoll error\n");
			break;
		}
		for(i = 0;i < number;i++)
		{
			int sockfd = events[i].data.fd;
			/* 新的客戶連線 */
			if(sockfd == listenfd)
			{
				sockaddr_in client;
				socklen_t clilen = sizeof(client);
				int connfd =  accept(listenfd,(struct sockaddr*)&client,&clilen);
				if(connfd < 0)
				{
					printf("accept error\n");
					continue;
				}
				if(user_count >= USER_LIMIT)
				{
					const char* info = "to many users\n";
					printf("%s\n",info);
					send(connfd,info,strlen(info),0);
					close(connfd);
					continue;
				}

				/* 儲存第user_count 個客戶連線的相關資料  */
				users[user_count].address = client;
				users[user_count].connfd = connfd;
				ret = socketpair(AF_UNIX,SOCK_STREAM,0,users[user_count].pipefd);
				assert( ret != -1);

				pid_t pid = fork();
				if(pid < 0)
				{
					close(connfd);
					continue;
				}
				if(pid == 0)//子程序
				{
					close(sig_pipefd[0]);
					close(sig_pipefd[1]);
					close(users[user_count].pipefd[0]);//子程序關閉0埠
					close(listenfd);
					close(epollfd);
					run_child(user_count,users,share_mem);//子程序的處理函式
					munmap((void*)share_mem,USER_LIMIT*BUFFER_SIZE);
					exit(0);
				}
				else //父程序
				{
					close(users[user_count].pipefd[1]);//父程序關閉1埠
					close(connfd);
					addfd(epollfd,users[user_count].pipefd[0]);
					users[user_count].pid = pid;
					sub_process[pid] = user_count;
					user_count ++;
				}
			}
			/* 處理訊號事件 */
			else if(sockfd ==sig_pipefd[0] &&  events[i].events & EPOLLIN)
			{
				int sig;
				char signals[1024];
				ret = recv(sockfd,signals,sizeof(signals),0);
				if(ret < 0 && ret != EAGAIN)
				{
					printf("recv error\n");
					continue;
				}
				if(ret == 0)continue;
				for(i = 0; i < ret; ++ i)
				{
					switch(signals[i])
					{
						case SIGCHLD : //子程序關閉
						{
							pid_t pid;
							int status;
							while((pid = waitpid(-1,&status,WNOHANG)) > 0)
							{
								/* 用子程序的pid獲取被關閉的客戶端連線的編號 */
								int del_user = sub_process[pid];
								sub_process[del_user] = -1;
								/* 清楚第del_user個客戶連線使用的相關資料 */
								epoll_ctl(epollfd,EPOLL_CTL_DEL,users[del_user].pipefd[0],0);
								close(users[del_user].pipefd[0]);
								/* 把最後一個客戶連線的資訊移動到該位置,用於保證0~user_count-1直接的連線都是活著的 */
								users[del_user] = users[--user_count];
								sub_process[users[del_user].pid] = del_user;
							}
							if(terminate && user_count == 0) stop_server = true;
							break;
						}
						case SIGINT :
						case SIGTERM : //結束伺服器程序
						{
							printf("kill all the child now\n");
							for(i = 0 ;i < user_count;++i)
							{
								pid_t pid = users[i].pid;
								kill(pid,SIGTERM);
							}
							terminate = true;//此處不是stop_sever是為了等待所有子程序結束後再結束
							break;
						}
						default : break;
					}
				}
			}
			/* 某個子程序向父程序寫入了資料 */
			else if(events[i].events & EPOLLIN)
			{
				int child;
				ret = recv(sockfd,(char*)&child,sizeof(child),0);
				if(ret < 0 && errno != EAGAIN) continue;
				else if(ret == 0)continue;
				printf("read data from child accross pipe\n");
				for(i = 0 ;i < user_count;i++)
				{
					if(i != child)
					{
						printf("send data to child accross pipe\n");
						send(users[i].pipefd[0],(char*)&child,sizeof(child),0);
					}
				}
			}
		}
	}
	close(listenfd);
	close(epollfd);
	close(sig_pipefd[0]);
	close(sig_pipefd[1]);
	shm_unlink(shm_name);
	delete[] users;
	delete[] sub_process;
	return 0;
}
客戶端程式碼(比較簡單,就沒有給出註釋):
#define _GNU_SOURCE 1
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <assert.h>
#include <stdio.h>
#include <unistd.h>
#include <string.h>
#include <stdlib.h>
#include <poll.h>
#include <fcntl.h>

#define BUFFER_SIZE 64

int main( int argc, char* argv[] )
{
    if( argc <= 2 )
    {
        printf( "usage: %s ip_address port_number\n", basename( argv[0] ) );
        return 1;
    }
    const char* ip = argv[1];
    int port = atoi( argv[2] );

    struct sockaddr_in server_address;
    bzero( &server_address, sizeof( server_address ) );
    server_address.sin_family = AF_INET;
    inet_pton( AF_INET, ip, &server_address.sin_addr );
    server_address.sin_port = htons( port );

    int sockfd = socket( PF_INET, SOCK_STREAM, 0 );
    assert( sockfd >= 0 );
    if ( connect( sockfd, ( struct sockaddr* )&server_address, sizeof( server_address ) ) < 0 )
    {
        printf( "connection failed\n" );
        close( sockfd );
        return 1;
    }

    pollfd fds[2];
    fds[0].fd = 0;
    fds[0].events = POLLIN;
    fds[0].revents = 0;
    fds[1].fd = sockfd;
    fds[1].events = POLLIN | POLLRDHUP;
    fds[1].revents = 0;
    char read_buf[BUFFER_SIZE];
    int pipefd[2];
    int ret = pipe( pipefd );
    assert( ret != -1 );

    while( 1 )
    {
        ret = poll( fds, 2, -1 );
        if( ret < 0 )
        {
            printf( "poll failure\n" );
            break;
        }

        if( fds[1].revents & POLLRDHUP )
        {
            printf( "server close the connection\n" );
            break;
        }
        else if( fds[1].revents & POLLIN )
        {
            memset( read_buf, '\0', BUFFER_SIZE );
            int len = recv( fds[1].fd, read_buf, BUFFER_SIZE-1, 0 );
			int i;
		    for(i = 0;i<len;i++)printf("%c",read_buf[i]);
        }

        if( fds[0].revents & POLLIN )
        {
            ret = splice( 0, NULL, pipefd[1], NULL, 32768, SPLICE_F_MORE | SPLICE_F_MOVE );
            ret = splice( pipefd[0], NULL, sockfd, NULL, 32768, SPLICE_F_MORE | SPLICE_F_MOVE );
        }
    }
    
    close( sockfd );
    return 0;
}