1. 程式人生 > >[線性代數]矩陣乘法演算法實現

[線性代數]矩陣乘法演算法實現

作者zhonglihao    
演算法名矩陣乘法 Matrix Multiplication
分類線性代數
複雜度n^3
形式與資料結構C++實現 一維陣列儲存
特性指標封裝返回
具體參考出處 教科書
備註
// ConsoleApplication1.cpp : 定義控制檯應用程式的入口點。
//

#include "stdafx.h"
#include <stdlib.h> 
#include "stdio.h"

void MatrixPrint(int* arr, int row, int col);
int* MatrixMul(int* arr_A, const int row_A, const int col_A, int* arr_B, const int row_B, const int col_B);

int _tmain(int argc, _TCHAR* argv[])
{
	//統一使用一維陣列實現
	const int row_A = 3;
	const int col_A = 3;
	int Mat_A[row_A*col_A] = { 1, 1, 1, 2, 2, 2, 3, 3, 3 };
	const int row_B = 3;
	const int col_B = 4;
	int Mat_B[row_B*col_B] = { 1,2,3,4, 0,1,1,0, 2,2,2,2 };

	//列印相乘原始矩陣
	MatrixPrint(Mat_A, row_A, col_A);
	MatrixPrint(Mat_B, row_B, col_B);

	//矩陣相乘返回陣列指標
	int* arr_C = MatrixMul(Mat_A, row_A, col_A, Mat_B, row_B, col_B);
	MatrixPrint(arr_C, row_A, col_B);

	system("Pause");
	return 0;
}

//矩陣相乘方法
int* MatrixMul(int* arr_A, const int row_A, const int col_A, int* arr_B, const int row_B, const int col_B)
{
	int row_scan, col_scan, mul_scan, sum;	    //mul_scan 行列各數獨立相乘掃描
	int* arr_C;				    //輸出矩陣
	int arr_C_len = row_A * col_B;		    //輸出矩陣大小

	//判定是否符合相乘法則
	if (col_A != row_B) return NULL;

	//分配輸出陣列長度
	arr_C = (int*)malloc(arr_C_len*sizeof(int));

	//矩陣相乘
	for (row_scan = 0; row_scan < row_A; row_scan++)//矩陣A行迴圈
	{
		for (col_scan = 0; col_scan < col_B; col_scan++)//矩陣B列迴圈
		{
			for (mul_scan = 0,sum = 0; mul_scan < col_A; mul_scan++)//A列=B行各數獨立相乘
			{
				sum += arr_A[row_scan * col_A + mul_scan] * arr_B[col_scan + mul_scan * col_B];
			}
			arr_C[row_scan * col_B + col_scan] = sum;
		}
	}

	//返回指標
	return arr_C;
}


//矩陣列印方法
void MatrixPrint(int* arr, const int row, const int col)
{
	int len = row * col;
	int i,col_count;

	for (i = 0, col_count = 0; i < len; i++)
	{
		printf("%d\t", arr[i]);

		//單換行
		if (++col_count >= col)
		{
			printf("\n");
			col_count = 0;
		}
	}

	//跳空換行
	printf("\n");

	return;
}