1. 程式人生 > >Matlab中特征向量間距離矩陣的並行mex程序

Matlab中特征向量間距離矩陣的並行mex程序

ng- 編譯 res threads else val article col 博文

在matlab中, 有n個向量(m維)的矩陣Mat(n, m)

要計算任兩個向量間的距離, 即距離矩陣, 可使用以下的並行算法以加速:


#include <iostream>
#include <mex.h>
#include <matrix.h>
#include <thread>

using namespace std;

//提前定義線程數
const int nThreads = 4;
//全局變量
int rows, cols, nrow, nw;
double *inVals, *outVals;

//線程運行體定義
void calc(int start, int end) {
    double sum, tmp;
    int no, i, j;

    //計算指定區間
    for(no = start; no < end; no++) {
        //第i輸入向量
        i = outVals[no + nrow] - 1;    //C索引下標
        //第j輸入向量
        j = outVals[no + 2 * nrow] - 1;    //C索引下標
        //計算兩輸入向量間的距離
        sum = 0;
        for(int k = 0; k < cols; k++)
        {
            tmp = (inVals[i + k * rows] - inVals[j + k * rows]);
            sum += (tmp * tmp);
        }
        outVals[no + 2 * nrow] = sum;
    }
}

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
    if (nrhs != 1) {
        mexPrintf("Usage: adjmat(double_features[n_rows * m_cols_features])\n");
    }
    
    //指針指向輸入數據
    inVals = mxGetPr(prhs[0]);
    
    //輸入矩陣的行數和列數
    rows = mxGetM(prhs[0]);
    cols = mxGetN(prhs[0]);
    
    //結果的行數nrow
    //結果的列數nw=(i, j, distance)
    nrow = (rows * rows - rows)/2, nw = 3;
    
    //分配結果內存
    nlhs = 1;
    plhs[0] = mxCreateDoubleMatrix(nrow, nw, mxREAL); 
    outVals = mxGetPr(plhs[0]);
    
    //在結果中分配i和j的組合
    int curL = 0;
    for(int i = 0; i < rows - 1; i++)
        for(int j = i + 1; j < rows; j++) {
            outVals[curL] = i + 1;                  //符合Matlab索引下標規範
            outVals[curL + nrow] = j + 1;    //符合Matlab索引下標規範
            curL++;
        }
    
    //按線程數分配計算區間
    int seg = nrow / nThreads;

    //線程數組
    thread threads[nThreads];
    //分配每一個線程的計算區間,避免沖突
    for(int i = 0; i < nThreads; i++) {
        if (i == nThreads - 1)
            threads[i] = thread(calc, i * seg, nrow);
        else
            threads[i] = thread(calc, i * seg, (i + 1) * seg);
    }
    //等待全部線程結束
    for (int i = 0; i < nThreads; i++){
        threads[i].join();
    }
}

編譯: (註意:看上一篇博文,怎樣設置matlab支持C++ 11標準)

mex adjmat.cpp


Matlab中簡單測試:

tic; x = rand(5000, 50);

adjmat(x);

toc


筆記本測試時間約:

0.57s


Matlab中特征向量間距離矩陣的並行mex程序