[原始碼解析] 機器學習引數伺服器 Paracel (1)-----總體架構

0x00 摘要

Paracel是豆瓣開發的一個分散式計算框架,它基於引數伺服器正規化來解決機器學習的問題:邏輯迴歸、SVD、矩陣分解(BFGS,sgd,als,cg),LDA,Lasso...。

Paracel支援資料和模型的並行,為使用者提供簡單易用的通訊介面,比mapreduce式的系統要更加靈活。Paracel同時支援非同步的訓練模式,使迭代問題收斂地更快。此外,Paracel程式的結構與序列程式十分相似,使用者可以更加專注於演算法本身,不需將精力過多放在分散式邏輯上。

因為我們之前已經用ps-lite對引數伺服器的基本功能做了介紹,所以在本文中,我們主要與ps-lite比對大的方面和一些關鍵技術點(paracel沒有開源容錯機制,是個不小的遺憾),而不會像對 ps-lite 那樣做較詳細的分析。

對於本文來說,ps-lite的主要邏輯如下:

本系列其他文章是:

[原始碼解析] 機器學習引數伺服器ps-lite 之(1) ----- PostOffice

[原始碼解析] 機器學習引數伺服器ps-lite(2) ----- 通訊模組Van

[原始碼解析] 機器學習引數伺服器ps-lite 之(3) ----- 代理人Customer

[原始碼解析]機器學習引數伺服器ps-lite(4) ----- 應用節點實現

本文在解析時候會刪除部分非主體程式碼。

0x01使用

我們首先通過原始碼提供的LR演算法看看如何使用。

1.1 配置&啟動

我們從原始碼中找到 LR 相關部分來看,以下就是一些必要配置,在其中我做了部分翻譯,需要留意的是:用一條命令可以啟動若干不同型別的例項,例項執行的都是可執行程式 lr

  1. Enter Paracel's home directory 進入Paracel工作目錄

```cd paracel;```

  1. Generate training dataset for classification 產生訓練資料集

```python ./tool/datagen.py -m classification -o training.dat -n 2500 -k 100```

  1. Set up link library path: 設定連結庫路徑

```export LD_LIBRARY_PATH=your_paracel_install_path/lib```

  1. Create a json file named cfg.json, see example in Parameters section below. 建立配置檔案

  2. Run (4 workers, local mode in the following example) 執行(4個worker,2個引數伺服器

```./prun.py -w 4 -p 2 -c cfg.json -m local your_paracel_install_path/bin/lr```

Default parameters are set in a JSON format file. For example, we create a cfg.json as below(modify your_paracel_install_path):

{

​ "training_input" : "training.dat", 訓練集

​ "test_input" : "training.dat", 驗證集

​ "predict_input" : "training.dat", label資料

​ "output" : "./lr_result/",

​ "update_file" : "your_paracel_install_path/lib/liblr_update.so",

​ "update_func" : "lr_theta_update", 更新函式

​ "method" : "ipm",

​ "rounds" : 100,

​ "alpha" : 0.001,

​ "beta" : 0.01,

​ "debug" : false

}

1.2 編譯

通過makefile我們可以看到,是把 lr_driver.cpp, lr.cpp一起編譯成為 lr 可執行檔案。把 update.cpp 編譯成庫,被伺服器載入呼叫。

add_library(lr_update SHARED update.cpp) # 引數伺服器如何更新
target_link_libraries(lr_update ${CMAKE_DL_LIBS})
install(TARGETS lr_update LIBRARY DESTINATION lib) add_library(lr_method SHARED lr.cpp) # 演算法程式碼
target_link_libraries(lr_method ${Boost_LIBRARIES} comm scheduler)
install(TARGETS lr_method LIBRARY DESTINATION lib) add_executable(lr lr_driver.cpp) # 驅動程式碼
target_link_libraries(lr
${Boost_LIBRARIES}
comm scheduler lr_method)
install(TARGETS lr RUNTIME DESTINATION bin)

1.3 梯度下降法

對於 LR,有四種 大規模深度神經網路的隨機梯度下降法 可以選擇

  • dgd: distributed gradient descent learning

  • ipm: iterative parameter mixtures learning

  • downpour: asynchrounous gradient descent learning

  • agd: slow asynchronous gradient descent learning

我們選擇 agd 演算法來學習分析:http://www.eecs.berkeley.edu/~brecht/papers/hogwildTR.pdf

1.4 驅動程式碼

首先,我們看看驅動程式碼 lr_driver.cpp,邏輯就是:

  • 配置執行環境和通訊。
  • 讀取分析引數。
  • 生成 logistic_regression,進行訓練,驗證,預測。
DEFINE_string(server_info,
"host1:7777PARACELhost2:8888",
"hosts name string of paracel-servers.\n"); DEFINE_string(cfg_file,
"",
"config json file with absolute path.\n"); int main(int argc, char *argv[])
{
// 配置執行環境和通訊
paracel::main_env comm_main_env(argc, argv);
paracel::Comm comm(MPI_COMM_WORLD); google::SetUsageMessage("[options]\n\t--server_info\n\t--cfg_file\n");
google::ParseCommandLineFlags(&argc, &argv, true); // 讀取分析引數
paracel::json_parser pt(FLAGS_cfg_file);
std::string training_input, test_input, predict_input, output, update_file, update_func, method;
try {
training_input = pt.check_parse<std::string>("training_input");
test_input = pt.check_parse<std::string>("test_input");
predict_input = pt.check_parse<std::string>("predict_input");
output = pt.parse<std::string>("output");
update_file = pt.check_parse<std::string>("update_file");
update_func = pt.parse<std::string>("update_func");
method = pt.parse<std::string>("method");
} catch (const std::invalid_argument & e) {
std::cerr << e.what();
return 1;
}
int rounds = pt.parse<int>("rounds");
double alpha = pt.parse<double>("alpha");
double beta = pt.parse<double>("beta");
bool debug = pt.parse<bool>("debug"); // 生成 logistic_regression,進行訓練,驗證,預測
paracel::alg::logistic_regression lr_solver(comm,
FLAGS_server_info,
training_input,
output,
update_file,
update_func,
method,
rounds,
alpha,
beta,
debug);
lr_solver.solve();
std::cout << "final loss: " << lr_solver.calc_loss() << std::endl;
lr_solver.test(test_input);
lr_solver.predict(predict_input);
lr_solver.dump_result(); return 0;
}

從之前的配置中我們知道更新部分是:

"update_file" : "your_paracel_install_path/lib/liblr_update.so",
"update_func" : "lr_theta_update",

所以我們從 alg/classification/logistic_regression/update.cpp 中得到更新函式如下:

具體就是合併兩個引數然後返回。這部分程式碼被編譯成庫,在server之中被載入執行。

#include <vector>
#include "proxy.hpp"
#include "paracel_types.hpp" using std::vector; extern "C" {
extern paracel::update_result lr_theta_update;
} vector<double> local_update(vector<double> a, vector<double> b) {
vector<double> r;
for(int i = 0; i < (int)a.size(); ++i) {
r.push_back(a[i] + b[i]);
}
return r;
} paracel::update_result lr_theta_update = paracel::update_proxy(local_update);

1.5 演算法程式碼

1.5.1 類定義

logistic_regression 是類定義,位於lr.hpp。logistic_regression 需要繼承 paracel::paralg 才能使用。

namespace paracel {
namespace alg { class logistic_regression: public paracel::paralg { public:
logistic_regression(paracel::Comm,
string,
string _input,
string output,
string update_file_name,
string update_func_name,
string = "ipm",
int _rounds = 1,
double _alpha = 0.002,
double _beta = 0.1,
bool _debug = false); virtual ~logistic_regression(); double lr_hypothesis(const vector<double> &); void dgd_learning(); // distributed gradient descent learning
void ipm_learning(); // by default: iterative parameter mixtures learning
void downpour_learning(); // asynchronous gradient descent learning
void agd_learning(); // slow asynchronous gradient descent learning virtual void solve(); double calc_loss();
void dump_result();
void print(const vector<double> &);
void test(const std::string &);
void predict(const std::string &); private:
void local_parser(const vector<string> &, const char);
void local_parser_pred(const vector<string> &, const char); private:
string input;
string update_file, update_func;
std::string learning_method;
int worker_id;
int rounds;
double alpha, beta;
bool debug = false;
vector<vector<double> > samples, pred_samples;
vector<double> labels;
vector<double> theta;
vector<double> loss_error;
vector<std::pair<vector<double>, double> > predv;
int kdim; // not contain 1
}; } // namespace alg
} // namespace paracel

1.5.2 主體程式碼

solve 是主體程式碼,依據不同配置選擇不同的隨機梯度下降法來訓練。

void logistic_regression::solve() {

  auto lines = paracel_load(input);
local_parser(lines);
paracel_sync(); if(learning_method == "dgd") {
dgd_learning();
} else if(learning_method == "ipm") {
ipm_learning();
} else if(learning_method == "downpour") {
downpour_learning();
} else if(learning_method == "agd") {
agd_learning();
} else {
ERROR_ABORT("method do not support");
}
paracel_sync();
}

1.5.3 Agd演算法

我們找出論文中的演算法比對:

下面程式碼和論文演算法基本一一對應,邏輯如下。

  • 首先把 theta 推送到引數伺服器;
  • 迭代訓練:
    • 從引數伺服器讀取最新的 theta;
    • 進行訓練;
    • 把計算結果推送到引數伺服器;
  • 從引數伺服器得到最新結果;
void logistic_regression::agd_learning() {
int data_sz = samples.size();
int data_dim = samples[0].size();
theta = paracel::random_double_list(data_dim);
paracel_write("theta", theta); // first push // 首先把 theta 推送到引數伺服器
vector<int> idx;
for(int i = 0; i < data_sz; ++i) {
idx.push_back(i);
}
paracel_register_bupdate(update_file, update_func);
double coff2 = 2. * beta * alpha;
vector<double> delta(data_dim); unsigned time_seed = std::chrono::system_clock::now().time_since_epoch().count();
// train loop
for(int rd = 0; rd < rounds; ++rd) {
std::shuffle(idx.begin(), idx.end(), std::default_random_engine(time_seed));
theta = paracel_read<vector<double> >("theta"); // 從引數伺服器讀取最新的 theta
vector<double> theta_old(theta); // traverse data
for(auto sample_id : idx) {
theta = paracel_read<vector<double> >("theta");
theta_old = theta;
double coff1 = alpha * (labels[sample_id] - lr_hypothesis(samples[sample_id]));
for(int i = 0; i < data_dim; ++i) {
double t = coff1 * samples[sample_id][i] - coff2 * theta[i];
theta[i] += t;
}
if(debug) {
loss_error.push_back(calc_loss());
}
for(int i = 0; i < data_dim; ++i) {
delta[i] = theta[i] - theta_old[i];
} // 把計算結果推送到引數伺服器
paracel_bupdate("theta", delta); // you could push a batch of delta into a queue to optimize
} // traverse } // rounds
theta = paracel_read<vector<double> >("theta"); // last pull // 得到最終結果
}

lr的邏輯圖如下:

+------------+                     +-------------------------------------------------+
| lr_driver | |logistic_regression |
| | | |
| +---------------------------------------> solve |
+------------+ lr_solver.solve() | + |
| | |
| | |
| | |
| +---------------------+-----------------------+ |
| | agd_learning | |
| | +-----------------------+ | |
| | | | | |
| | | v | |
| | | theta = paracel_read("theta") | |
| | | | | |
| | | | | |
| | | v | |
| | | | |
| | | delta[i] = theta[i] - theta_old[i] | |
| | | + | |
| | | | | |
| | | | | |
| | | v | |
| | | paracel_bupdate("theta", delta) | |
| | | + + | |
| | | | | | |
| | +-----------------------+ | | |
| +---------------------------------------------+ |
| | |
+-------------------------------------------------+
|
Worker |
+------------------------------------------------------------------------------------+
Server |
+---------------------+
| Server | |
| | |
| v |
| local_update |
| |
+---------------------+

1.6 小結

至此,我們知道了Paracel如何使用,實現是以driver為核心進行展開,使用者需要編寫 update函式和演算法函式。但是距離深入瞭解還差得很遠。

我們目前有幾個問題需要解決:

  • Paracel 怎麼啟動了多個worker進行訓練?
  • Paracel 怎麼啟動了引數伺服器?
  • update 函式如何被使用?

我們需要通過啟動部分來繼續研究。

0x02 啟動

如前所述./prun.py -w 4 -p 2 -c cfg.json -m local your_paracel_install_path/bin/lr是啟動命令,paracel 通過 prun.py 進入系統,所以我們分析這個指令碼。

2.1 python指令碼 prun.py

2.1.1 主體函式

下面我們省略一些非主體程式碼,比如處理引數,邏輯如下:

  • 處理引數;
  • 利用 init_starter 得到如何啟動server,worker,構建出一個相應字串;
  • 利用 subprocess.Popen 啟動server,其中server的執行程式是 bin/start_server
  • 利用 os.system 啟動 worker;
if __name__ == '__main__':
optpar = OptionParser()
# 省略處理引數
(options, args) = optpar.parse_args() nsrv = 1
nworker = 1
if options.parasrv_num:
nsrv = options.parasrv_num
if options.worker_num:
nworker = options.worker_num if not options.method_server:
options.method_server = options.method
if not options.ppn_server:
options.ppn_server = options.ppn
if not options.mem_limit_server:
options.mem_limit_server = options.mem_limit
if not options.hostfile_server:
options.hostfile_server = options.hostfile # 利用 init_starter 得到如何啟動server,worker,構建出相應字串
server_starter = init_starter(options.method_server,
str(options.mem_limit_server),
str(options.ppn_server),
options.hostfile_server,
options.server_group)
worker_starter = init_starter(options.method,
str(options.mem_limit),
str(options.ppn),
options.hostfile,
options.worker_group) #initport = random.randint(30000, 65000)
#initport = get_free_port()
initport = 11777 start_parasrv_cmd_lst = [server_starter, str(nsrv), os.path.join(PARACEL_INSTALL_PREFIX, 'bin/start_server --start_host'), socket.gethostname(), ' --init_port', str(initport)]
start_parasrv_cmd = ' '.join(start_parasrv_cmd_lst) # 利用 subprocess.Popen 啟動server,其中server的執行程式是 bin/start_server
procs = subprocess.Popen(start_parasrv_cmd, shell=True, preexec_fn=os.setpgrp) try:
serverinfo = paracelrun_cpp_proxy(nsrv, initport)
entry_cmd = ''
if args:
entry_cmd = ' '.join(args)
alg_cmd_lst = [worker_starter, str(nworker), entry_cmd, '--server_info', serverinfo, '--cfg_file', options.config]
alg_cmd = ' '.join(alg_cmd_lst) # 利用 os.system 啟動 worker
os.system(alg_cmd)
os.killpg(procs.pid, 9)
except Exception as e:
logger.exception(e)
os.killpg(procs.pid, 9)

2.1.2 starter函式

init_starter 函式會依據配置構建一個字串。其中 paracel 有三種啟動方式:

The –m_server and -m options above refer to what type of cluster you use. Paracel support mesos clusters, mpi clusters and multiprocessers in a single machine.

我們利用前面horovod文章的知識可以知道,mpirun 是可以啟動多個程序。

結合之前的命令列,./prun.py -w 4 -p 2 -c cfg.json -m local your_paracel_install_path/bin/lr,可以知道 local 就是 mpirun,所以paracel 通過 mpirun 來啟動了 4 個 lr 程序

具體程式碼如下:

def init_starter(method, mem_limit, ppn, hostfile, group):
'''Assemble commands for running paracel programs'''
starter = ''
if not hostfile:
hostfile = '~/.mpi/large.18'
if method == 'mesos':
if group:
starter = '%s/mrun -m %s -p %s -g %s -n ' % (PARACEL_INSTALL_PREFIX, mem_limit, ppn, group)
else:
starter = '%s/mrun -m %s -p %s -n ' % (PARACEL_INSTALL_PREFIX, mem_limit, ppn)
elif method == 'mpi':
starter = 'mpirun --hostfile %s -n ' % hostfile
elif method == 'local':
starter = 'mpirun -n '
else:
print 'method %s not supported.' % method
sys.exit(1)
return starter

2.2 可執行程式 start_server

前面提到,server 執行程式對應的是 bin/start_server。

我們看看其構建 src/CMakeLists.txt,於是我們可以去查詢 start_server.cpp。

add_library(comm SHARED comm.cpp) # 通訊相關庫
install(TARGETS comm LIBRARY DESTINATION lib) add_library(scheduler SHARED scheduler.cpp # 排程
install(TARGETS scheduler LIBRARY DESTINATION lib) add_library(default SHARED default.cpp) # 預設庫
install(TARGETS default LIBRARY DESTINATION lib) # 這裡可以看到start_server.cpp
add_executable(start_server start_server.cpp)
target_link_libraries(start_server ${Boost_LIBRARIES} ${CMAKE_DL_LIBS})
install(TARGETS start_server RUNTIME DESTINATION bin) add_executable(paracelrun_cpp_proxy paracelrun_cpp_proxy.cpp)
target_link_libraries(paracelrun_cpp_proxy ${Boost_LIBRARIES} ${CMAKE_DL_LIBS})
install(TARGETS paracelrun_cpp_proxy RUNTIME DESTINATION bin)

2.3 伺服器程式碼

src/start_server.cpp 是伺服器主體程式碼。

結合之前的命令列,./prun.py -w 4 -p 2 -c cfg.json -m local your_paracel_install_path/bin/lr,可以知道 local 就是 mpirun,所以paracel 通過 mpirun 來啟動了 2 個 start_server 程序,即兩個引數伺服器。

#include <gflags/gflags.h>

#include "server.hpp"

DEFINE_string(start_host, "beater7", "host name of start node\n");
DEFINE_string(init_port, "7773", "init port"); int main(int argc, char *argv[])
{
google::SetUsageMessage("[options]\n\
--start_host\tdefault: balin\n\
--init_port\n");
google::ParseCommandLineFlags(&argc, &argv, true);
paracel::init_thrds(FLAGS_start_host, FLAGS_init_port); // join inside
return 0;
}

在 include/server.hpp 檔案之中,init_thrds 函式啟動了一系列執行緒,具體邏輯如下。

  • 構建 zmq 環境;
  • 為每個執行緒建立了socket;
  • 建立伺服器處理執行緒;
  • 建立SSP執行緒;
  • 等待執行緒結束;
// init_host is the hostname of starter
void init_thrds(const paracel::str_type & init_host,
const paracel::str_type & init_port) { // 構建 zmq 環境
zmq::context_t context(2);
zmq::socket_t sock(context, ZMQ_REQ); paracel::str_type info = "tcp://" + init_host + ":" + init_port;
sock.connect(info.c_str()); char hostname[1024], freeport[1024];
size_t size = sizeof(freeport); // hostname of servers
gethostname(hostname, sizeof(hostname));
paracel::str_type ports = hostname;
ports += ":"; // create sock in every thrd 為每個執行緒建立了socket
std::vector<zmq::socket_t *> sock_pt_lst;
for(int i = 0; i < paracel::threads_num; ++i) {
zmq::socket_t *tmp;
tmp = new zmq::socket_t(context, ZMQ_REP);
sock_pt_lst.push_back(tmp);
sock_pt_lst.back()->bind("tcp://*:*");
sock_pt_lst.back()->getsockopt(ZMQ_LAST_ENDPOINT, &freeport, &size);
if(i == paracel::threads_num - 1) {
ports += local_parse_port(paracel::str_type(freeport));
} else {
ports += local_parse_port(std::move(paracel::str_type(freeport))) + ",";
}
} zmq::message_t request(ports.size());
std::memcpy((void *)request.data(), &ports[0], ports.size());
sock.send(request); zmq::message_t reply;
sock.recv(&reply); // 建立伺服器處理執行緒 thrd_exec
paracel::list_type<std::thread> threads;
for(int i = 0; i < paracel::threads_num - 1; ++i) {
threads.push_back(std::thread(thrd_exec, std::ref(*sock_pt_lst[i])));
}
// 建立ssp執行緒 thrd_exec_ssp
threads.push_back(std::thread(thrd_exec_ssp, std::ref(*sock_pt_lst.back()))); // 等待執行緒結束
for(auto & thrd : threads) {
thrd.join();
} for(int i = 0; i < paracel::threads_num; ++i) {
delete sock_pt_lst[i];
} zmq_ctx_destroy(context);
} // init_thrds

./prun.py -w 4 -p 2 -c cfg.json -m local your_paracel_install_path/bin/lr 的對應啟動邏輯圖具體如下:

 prun.py
+
|
|
| +----------------+
| +--> | start_server |
v | +----------------+
server_starter = init_starter +--> mpirun -n 2 +----+
+ | +----------------+
| | | start_server |
| | | + |
| +--> | | |
v | | |
worker_starter = init_starter +--> mpirun -n 4 | | |
+ | v |
| | init_thrds |
| | + |
| | | |
+-------+----+--+-------+ | | |
| | | | | | |
| | | | | v |
v v v v | thrd_exec |
bin/lr bin/lr bin/lr bin/lr | + |
| | |
| | |
| | |
| v |
| thrd_exec_ssp |
+----------------+

2.4 小結

目前我們知道了,worker和server都有多種啟動方式,比如用 mpi 的方式來啟動多個程序。

  • worker 端就是通過 driver.cpp 為主體,啟動多個程序。

  • server端就是通過 start_server 為主體,啟動多個程序,就是多個程序(引數伺服器)組成了一個叢集。

以上這些和ps-lite非常類似。

下面我們要分別深入這兩個角色的內部。

0x03 Server總體

通過之前ps-lite我們知道,引數伺服器大多使用 KV 儲存來儲存引數,所以我們先介紹KV儲存。

3.1 KV 儲存

在 include/kv_def.hpp 給出了server 端使用的KV儲存。

#include "paracel_types.hpp"
#include "kv.hpp" namespace paracel {
paracel::kvs<paracel::str_type, int> ssp_tbl; // 用來協助實現 SSP
paracel::kvs<paracel::str_type, paracel::str_type> tbl_store; // 主要的kv儲存
}

KV 儲存的定義在 include/kv.hpp,下面省略了部分程式碼。

可以看出來,基本功能就是維護了記憶體table,提供了set系列函式和get系列函式,其中當需要返回 value, unique 的時候,就採用hash函式處理。

template <class K, class V> struct kvs {

public:

  bool contains(const K & k) {
return kvdct.count(k);
} void set(const K & k, const V & v) {
kvdct[k] = v;
} void set_multi(const paracel::dict_type<K, V> & kvdict) {
for(auto & kv : kvdict) {
set(kv.first, kv.second);
}
} boost::optional<V> get(const K & k) {
auto fi = kvdct.find(k);
if(fi != kvdct.end()) {
return boost::optional<V>(fi->second);
} else return boost::none;
} bool get(const K & k, V & v) {
auto fi = kvdct.find(k);
if(fi != kvdct.end()) {
v = fi->second;
return true;
} else {
return false;
}
} paracel::list_type<V>
get_multi(const paracel::list_type<K> & keylst) {
paracel::list_type<V> valst;
for(auto & key : keylst) {
valst.push_back(kvdct.at(key));
}
return valst;
} void get_multi(const paracel::list_type<K> & keylst,
paracel::list_type<V> & valst) {
for(auto & key : keylst) {
valst.push_back(kvdct.at(key));
}
} void get_multi(const paracel::list_type<K> & keylst,
paracel::dict_type<K, V> & valdct) {
valdct.clear();
for(auto & key : keylst) {
auto it = kvdct.find(key);
if(it != kvdct.end()) {
valdct[key] = it->second;
}
}
} // 這裡使用了 hash 函式
// gets(key) -> value, unique
boost::optional<std::pair<V, paracel::hash_return_type> >
gets(const K & k) {
if(auto v = get(k)) {
std::pair<V, paracel::hash_return_type> ret(*v, hfunc(*v));
return boost::optional<
std::pair<V, paracel::hash_return_type>
>(ret);
} else {
return boost::none;
}
} // compare-and-set, cas(key, value, unique) -> True/False
bool cas(const K & k, const V & v, const paracel::hash_return_type & uniq) {
if(auto r = gets(k)) {
if(uniq == (*r).second) {
set(k, v);
return true;
} else {
return false;
}
} else {
kvdct[k] = v;
}
return true;
} paracel::dict_type<K, V> getall() {
return kvdct;
} private:
//std::tr1::unordered_map<K, V> kvdct;
paracel::dict_type<K, V> kvdct;
paracel::hash_type<V> hfunc;
};

3.2 服務處理邏輯

thrd_exec 執行緒實現了引數伺服器的基本處理邏輯:就是針對worker傳來的不同的命令進行相關處理(大部分就是針對KV儲存進行處理),比如:

  • 如果是 "pull" 命令,則使用 paracel::tbl_store.get(key, result) 獲取到數值,然後返回給使用者。
  • 如果是 "push" 命令,則使用 paracel::tbl_store.set(key, msg[2]) 往 KV 中插入引數;

需要注意的是,這裡使用了使用者定義的update函式,即:

  • 用了dlopen_update_lambda來對使用者設定的update函式進行生成,賦值為 update_f。
  • 當處理"update“或者"bupdate"型別請求時候,使用使用者的update函式來對kv進行處理。

下面刪除了部分非主體程式碼。

// thread entry
void thrd_exec(zmq::socket_t & sock) { paracel::packer<> pk;
update_result update_f;
filter_result pullall_special_f;
filter_result remove_special_f; // 這裡使用了dlopen_update_lambda來對使用者設定的update函式進行生成,賦值為 update_f
auto dlopen_update_lambda = [&](const paracel::str_type & fn, const paracel::str_type & fcn) {
void *handler = dlopen(fn.c_str(), RTLD_NOW | RTLD_LOCAL | RTLD_NODELETE);
auto local = dlsym(handler, fcn.c_str());
update_f = *(std::function<paracel::str_type(paracel::str_type, paracel::str_type)>*) local;
dlclose(handler);
}; // 主體邏輯
while(1) {
zmq::message_t s;
sock.recv(&s);
auto scrip = paracel::str_type(static_cast<const char *>(s.data()), s.size());
auto msg = paracel::str_split_by_word(scrip, paracel::seperator);
auto indicator = pk.unpack(msg[0]); if(indicator == "pull") { // 如果是從引數伺服器讀取引數,則直接返回
auto key = pk.unpack(msg[1]);
paracel::str_type result;
auto exist = paracel::tbl_store.get(key, result); // 讀取kv
if(!exist) {
paracel::str_type tmp = "nokey";
rep_send(sock, tmp);
} else {
rep_send(sock, result); // 返回
}
}
if(indicator == "pull_multi") { // 讀取多個引數
paracel::packer<paracel::list_type<paracel::str_type> > pk_l;
auto key_lst = pk_l.unpack(msg[1]);
auto result = paracel::tbl_store.get_multi(key_lst);
rep_pack_send(sock, result);
}
if(indicator == "pullall") { // 讀取所有引數
auto dct = paracel::tbl_store.getall();
rep_pack_send(sock, dct);
}
mutex.lock();
if(indicator == "push") { // 插入引數
auto key = pk.unpack(msg[1]);
paracel::tbl_store.set(key, msg[2]);
bool result = true;
rep_pack_send(sock, result);
}
if(indicator == "push_multi") { // 插入多個引數
paracel::packer<paracel::list_type<paracel::str_type> > pk_l;
paracel::dict_type<paracel::str_type, paracel::str_type> kv_pairs;
auto key_lst = pk_l.unpack(msg[1]);
auto val_lst = pk_l.unpack(msg[2]);
assert(key_lst.size() == val_lst.size());
for(int i = 0; i < (int)key_lst.size(); ++i) {
kv_pairs[key_lst[i]] = val_lst[i];
}
paracel::tbl_store.set_multi(kv_pairs); //插入kv
bool result = true;
rep_pack_send(sock, result);
}
if(indicator == "update" || indicator == "bupdate") { // 更新引數
if(msg.size() > 3) {
if(msg.size() != 5) {
ERROR_ABORT("invalid invoke in server end");
}
// open request func
auto file_name = pk.unpack(msg[3]);
auto func_name = pk.unpack(msg[4]);
dlopen_update_lambda(file_name, func_name);
} else {
if(!update_f) {
dlopen_update_lambda("../local/build/lib/default.so",
"default_incr_i");
}
}
auto key = pk.unpack(msg[1]);
// 這裡使用使用者的update函式來對kv進行處理
std::string result = kv_update(key, msg[2], update_f);
rep_send(sock, result);
} if(indicator == "remove") { // 刪除引數
auto key = pk.unpack(msg[1]);
auto result = paracel::tbl_store.del(key);
rep_pack_send(sock, result);
}
mutex.unlock();
} // while
} // thrd_exec

簡化如圖:

+--------------------------------------------------------------------------------------+
| thrd_exec |
| |
| +---------------------------------> while(1) |
| | + |
| | | |
| | | |
| | +----------+----------+--------+--+------+----------+---------+---------+ |
| | | | | | | | | | |
| | | | | | | | | | |
| | | | | | | | | | |
| | | | | | | | | | |
| | | | | | | | | | |
| | v v v v v v v v |
| | |
| | pull pull_multi pullall push push_multi update bupdate remove |
| | + + + + + + + + |
| | | | | | | | | | |
| | | | | | | | | | |
| | | | | | | | | | |
| | | | | | | | | | |
| | v v v v v v v v |
| | +----------+----------+--------+----+----+----------+---------+---------+ |
| | | |
| | | |
| | | |
| | | |
| +-----------------------------------------+ |
| |
+--------------------------------------------------------------------------------------+

3.3 小結

目前為止,我們可以看到,Paracel和ps-lite也很類似,伺服器維護了一個儲存,伺服器也可以處理客戶端的請求。

0x04 Worker總體

Worker 就是用來訓練演算法的程序。從前面我們瞭解,演算法需要繼承paracel::paralg才能使用引數伺服器功能。

namespace paracel {
namespace alg { class logistic_regression: public paracel::paralg { .....

paracel::paralg 就可以認為是引數伺服器的API,或者代理,我們下面就看看。

4.1 基礎功能類 Paralg

Paralg是提供Paracel主要功能的基本類,可以理解為一個演算法API類,或者對外功能API類

我們只給出其成員變數,暫時省略其函式實現。最主要幾個為:

  • int stale_cache, clock, total_iters; 同步需要
  • paracel::Comm worker_comm; 通訊類,比如 MPI 通訊
  • int nworker = 1; worker的數目
  • bool ssp_switch = false; 是否開啟 SSP 模式
  • parasrv *ps_obj; // 可以理解為是正式的引數伺服器類。
class paralg {
private: class parasrv { // 可以理解為是引數伺服器類 using l_type = paracel::list_type<paracel::kvclt>;
using dl_type = paracel::list_type<paracel::dict_type<paracel::str_type, paracel::str_type> >; public:
parasrv(paracel::str_type hosts_dct_str) {
// init dct_lst
dct_lst = paracel::get_hostnames_dict(hosts_dct_str);
// init srv_sz
srv_sz = dct_lst.size();
// init kvm
for(auto & srv : dct_lst) {
paracel::kvclt kvc(srv["host"], srv["ports"]);
kvm.push_back(std::move(kvc));
}
// init servers
for(auto i = 0; i < srv_sz; ++i) {
servers.push_back(i);
}
// init hashring
p_ring = new paracel::ring<int>(servers);
} virtual ~parasrv() {
delete p_ring;
} public:
dl_type dct_lst;
int srv_sz = 1;
l_type kvm;
paracel::list_type<int> servers; // 具體伺服器列表
paracel::ring<int> *p_ring; // hash ring }; // nested class parasrv private:
int stale_cache, clock, total_iters; // 同步需要
int clock_server = 0;
paracel::Comm worker_comm; //通訊類,比如 MPI 通訊
paracel::str_type output;
int nworker = 1;
int rounds = 1;
int limit_s = 0;
bool ssp_switch = false;
parasrv *ps_obj; // 可以理解為是正式的引數伺服器類。
paracel::dict_type<paracel::default_id_type, paracel::default_id_type> rm;
paracel::dict_type<paracel::default_id_type, paracel::default_id_type> cm;
paracel::dict_type<paracel::default_id_type, paracel::default_id_type> dm;
paracel::dict_type<paracel::default_id_type, paracel::default_id_type> col_dm;
paracel::dict_type<paracel::str_type, paracel::str_type> keymap;
paracel::dict_type<paracel::str_type, boost::any> cached_para;
paracel::update_result update_f;
int npx = 1, npy = 1;
}

4.2 派生

編寫一個Paracel程式需要對paralg基類進行子類化,並且必須重寫virtual solve方法。其中一些是SPMD iterfaces 並行介面。

我們從之前 LR 的實現可以看到需要繼承 paracel::paralg 。

class logistic_regression: public paracel::paralg

就是說,使用者的solve函式可以直接呼叫 Paralg 的函式來完成基本功能。

我們以 paracel::paracel_read 為例,可以看到是使用 parasrv.kvm 的功能,我們後續會繼續介紹 parasrv。

  template <class V>
V paracel_read(const paracel::str_type & key,
int replica_id = -1) {
if(ssp_switch) { // 如果應用ssp,應該如何處理。我們下文就將具體介紹ssp如何處理
V val;
if(clock == 0 || clock == total_iters) {
cached_para[key] = boost::any_cast<V>(ps_obj->
kvm[ps_obj->p_ring->get_server(key)].
pull<V>(key));
val = boost::any_cast<V>(cached_para[key]);
} else if(stale_cache + limit_s > clock) {
val = boost::any_cast<V>(cached_para[key]);
} else {
while(stale_cache + limit_s < clock) {
stale_cache = ps_obj->
kvm[clock_server].pull_int(paracel::str_type("server_clock"));
}
cached_para[key] = boost::any_cast<V>(ps_obj->
kvm[ps_obj->p_ring->get_server(key)].
pull<V>(key));
val = boost::any_cast<V>(cached_para[key]);
}
return val;
}
// 否則直接返回
return ps_obj->kvm[ps_obj->p_ring->get_server(key)].pull<V>(key);
}

worker邏輯如下:

+---------------------------------------------------------------------------+
| Algorithm |
| ^ +------------------------------v |
| | | |
| | | |
| | v |
| | +----------------------------+------------------------------+ |
| | | paracel_read | |
| | | | |
| | | ps_obj+>kvm[ps_obj+>p_ring+>get_server(key)].pull<V>(key) | |
| | | | |
| | +----------------------------+------------------------------+ |
| | | |
| | | |
| | | |
| | v |
| | Compute |
| | + |
| | | |
| | | |
| | v |
| | +---------------------------+-------------------------------+ |
| | | paracel_bupdate | |
| | | ps_obj->kvm[indx].bupdate | |
| | | | |
| | +---------------------------+-------------------------------+ |
| | | |
| | | |
| | | |
| | | |
| +-----<--------------------------+ |
| |
+---------------------------------------------------------------------------+

4.3 小結

Worker端的機理也類似ps-lite,通過read,pull等操作,向伺服器提出請求。

0x05 Ring Hash

在沐神論文中,Ring hash 是與資料一致性,容錯,可擴充套件等機制聯絡在一起,比如:

parameter server 在資料一致性上,使用的是傳統的一致性雜湊演算法,引數key與server node id被插入到一個hash ring中。

但可惜的是,ps-lite 沒有提供這部分程式碼,paracel 雖然有 ring hash,但也不齊全,豆瓣沒有開源容錯和一致性等部分。我們只能基於已有程式碼進行學習分析

5.1 原理

這裡只是大致講解下,有需求的同學可以去網上搜索詳細文章。

從拗口的技術術語來解釋,一致性雜湊的技術關鍵點是:按照常用的hash演算法來將對應的key雜湊到一個具有2^32次方個桶的空間中,即0 ~ (2^32)-1的數字空間。我們可以將這些數字頭尾相連,想象成一個閉合的環形。

用通俗白話來理解,這個關鍵點就是:在部署伺服器的時候,伺服器的序號空間已經配置成了一個固定的非常大的數字 1~2^32(不需要再改變)。伺服器可以分配為 1~2^32 中任一序號。這樣伺服器叢集可以固定大多數演算法規則 (因為序號空間是演算法的重要引數),這樣面對擴容等變化只有"分配規則" 需要根據實際系統容量做相應微調。從而對整體系統影響較小。

5.2 定義

ring 就是hash 環的實現類,這裡主要功能就是把 伺服器 加入到 hash ring 之中,以及從ring之中取出伺服器。

// T rep type of server name
template <class T>
class ring { public: ring(paracel::list_type<T> names) {
for(auto & name : names) {
add_server(name);
}
} ring(paracel::list_type<T> names, int cp) : replicas(cp) {
for(auto & name : names) {
add_server(name);
}
} void add_server(const T & name) {
//std::hash<paracel::str_type> hfunc;
paracel::hash_type<paracel::str_type> hfunc;
std::ostringstream tmp;
tmp << name;
auto name_str = tmp.str();
for(int i = 0; i < replicas; ++i) { //對每一個副本進行處理
std::ostringstream cvt;
cvt << i;
auto n = name_str + ":" + cvt.str();
auto key = hfunc(n); // 依據name生成一個key
srv_hashring_dct[key] = name; //新增value
srv_hashring.push_back(key); //往list新增內容
}
// sort srv_hashring
std::sort(srv_hashring.begin(), srv_hashring.end());
} void remove_server(const T & name) {
//std::hash<paracel::str_type> hfunc;
paracel::hash_type<paracel::str_type> hfunc;
std::ostringstream tmp;
tmp << name;
auto name_str = tmp.str();
for(int i = 0; i < replicas; ++i) { // 對每個副本進行處理
std::ostringstream cvt;
cvt << i;
auto n = name_str + ":" + cvt.str();
auto key = hfunc(n);// 依據name生成一個key
srv_hashring_dct.erase(key);// 刪除value
auto iter = std::find(srv_hashring.begin(), srv_hashring.end(), key);
if(iter != srv_hashring.end()) {
srv_hashring.erase(iter); // 刪除list中的內容
}
}
} // TODO: relief load of srv_hashring_dct[srv_hashring[0]]
template <class P>
T get_server(const P & skey) {
//std::hash<P> hfunc;
paracel::hash_type<P> hfunc;
auto key = hfunc(skey);// 依據name生成一個key
auto server = srv_hashring[paracel::ring_bsearch(srv_hashring, key)];//獲取server
return srv_hashring_dct[server];
} private:
int replicas = 32;
// 分別用list和dict儲存
paracel::list_type<paracel::hash_return_type> srv_hashring;
paracel::dict_type<paracel::hash_return_type, T> srv_hashring_dct;
};

5.3 使用

我們使用 paracel_read 來看,可以發現呼叫順序是

  • 先使用 ps_obj->p_ring->get_server(key) 得到本 key 對應的 引數伺服器(就是從ring hash 中提取出來某一個引數伺服器);
  • 然後從這個伺服器中獲取到本 key 對應的 value;
V paracel_read(const paracel::str_type & key,
int replica_id = -1) {
...... ps_obj->kvm[ps_obj->p_ring->get_server(key)].pull<V>(key);
}

5.4 小結

這裡是和ps-lite的不同之處,就是用ring-hash來維護資料一致性,容錯等,比如把 伺服器 加入到 hash ring 之中,以及從ring之中取出伺服器。

0x06 引數伺服器介面 parasrv

我們把目前邏輯梳理一下,綜合看看。

6.1 引數伺服器介面 parasrv 構建

如何使用ring hash,需要從 parasrv 說起。

我們知道,paralg 是基礎API類,其中在 paralg 中有如下定義 以及 構建了 ps_obj , ps_obj是一個 parasrv 型別的例項。

注:以下都是在worker端使用的型別。

// paralg 內程式碼

  parasrv *ps_obj; // 成員變數定義,引數伺服器介面

  paralg(paracel::str_type hosts_dct_str,
paracel::Comm comm,
paracel::str_type _output = "",
int _rounds = 1,
int _limit_s = 0,
bool _ssp_switch = false) : worker_comm(comm),
output(_output),
nworker(comm.get_size()),
rounds(_rounds),
limit_s(_limit_s),
ssp_switch(_ssp_switch) {
ps_obj = new parasrv(hosts_dct_str); // 構建引數伺服器,一個parasrv的例項
init_output(_output);
clock = 0;
stale_cache = 0;
clock_server = 0;
total_iters = rounds;
if(worker_comm.get_rank() == 0) {
paracel::str_type key = "worker_sz";
(ps_obj->kvm[clock_server]).
push_int(key, worker_comm.get_size()); // 初始化時鐘伺服器
}
paracel_sync(); // mpi barrier同步一下
}

6.2 引數伺服器介面 parasrv 定義

parasrv 的定義如下,其中 p_ring 就是 ring 例項,使用 p_ring = new paracel::ring<int>(servers) 來完成了構建。

其中p_ring 是 ring hash,kvm是具體的kv儲存列表。

  class parasrv {

    using l_type = paracel::list_type<paracel::kvclt>;
using dl_type = paracel::list_type<paracel::dict_type<paracel::str_type, paracel::str_type> >; public:
parasrv(paracel::str_type hosts_dct_str) {
// 初始化host資訊,srv大小,kvm,servers,ring hash
// init dct_lst
dct_lst = paracel::get_hostnames_dict(hosts_dct_str);
// init srv_sz
srv_sz = dct_lst.size();
// init kvm
for(auto & srv : dct_lst) {
paracel::kvclt kvc(srv["host"], srv["ports"]);
kvm.push_back(std::move(kvc));
}
// init servers
for(auto i = 0; i < srv_sz; ++i) {
servers.push_back(i);
}
// init hashring
p_ring = new paracel::ring<int>(servers); // 構建
} virtual ~parasrv() {
delete p_ring;
} public:
dl_type dct_lst;
int srv_sz = 1;
l_type kvm; // 具體KV儲存介面
paracel::list_type<int> servers;
paracel::ring<int> *p_ring; // ring hash }; // nested class parasrv

kvm 初始化如下:

// init kvm
for(auto & srv : dct_lst) {
paracel::kvclt kvc(srv["host"], srv["ports"]);
kvm.push_back(std::move(kvc));
}

6.3 KV儲存控制介面

kvclt 是 kv control 的抽象。

只摘取部分程式碼,就是找到對應的伺服器進行互動

namespace paracel {

struct kvclt {

public:
kvclt(paracel::str_type hostname,
paracel::str_type ports) : host(hostname), context(1) {
ports_lst = paracel::str_split(ports, ',');
conn_prefix = "tcp://" + host + ":";
} template <class V, class K>
bool pull(const K & key, V & val) { // 從引數伺服器拉取
if(p_pull_sock == nullptr) {
p_pull_sock.reset(create_req_sock(ports_lst[0]));
}
auto scrip = paste(paracel::str_type("pull"), key); // paracel::str_type
return req_send_recv(*p_pull_sock, scrip, val);
} template <class K, class V>
bool push(const K & key, const V & val) { // 往引數伺服器推送
if(p_push_sock == nullptr) {
p_push_sock.reset(create_req_sock(ports_lst[1]));
}
auto scrip = paste(paracel::str_type("push"), key, val);
bool stat;
auto r = req_send_recv(*p_push_sock, scrip, stat);
return r && stat;
} template <class V>
bool req_send_recv(zmq::socket_t & sock,
const paracel::str_type & scrip,
V & val) {
zmq::message_t req_msg(scrip.size());
std::memcpy((void *)req_msg.data(), &scrip[0], scrip.size());
sock.send(req_msg);
zmq::message_t rep_msg;
sock.recv(&rep_msg);
paracel::packer<V> pk;
if(!rep_msg.size()) {
ERROR_ABORT("paracel internal error!");
} else {
std::string data = paracel::str_type(
static_cast<char*>(rep_msg.data()),
rep_msg.size());
if(data == "nokey") return false;
val = pk.unpack(data);
}
return true;
} private:
paracel::str_type host;
paracel::list_type<paracel::str_type> ports_lst;
paracel::str_type conn_prefix;
zmq::context_t context;
std::unique_ptr<zmq::socket_t> p_contains_sock = nullptr;
std::unique_ptr<zmq::socket_t> p_pull_sock = nullptr;
std::unique_ptr<zmq::socket_t> p_pull_multi_sock = nullptr;
std::unique_ptr<zmq::socket_t> p_pullall_sock = nullptr;
std::unique_ptr<zmq::socket_t> p_push_sock = nullptr;
std::unique_ptr<zmq::socket_t> p_push_multi_sock = nullptr;
std::unique_ptr<zmq::socket_t> p_update_sock = nullptr;
std::unique_ptr<zmq::socket_t> p_bupdate_sock = nullptr;
std::unique_ptr<zmq::socket_t> p_bupdate_multi_sock = nullptr;
std::unique_ptr<zmq::socket_t> p_remove_sock = nullptr;
std::unique_ptr<zmq::socket_t> p_clear_sock = nullptr;
std::unique_ptr<zmq::socket_t> p_ssp_sock = nullptr; }; // struct kvclt } // namespace paracel

所以目前總體邏輯如下:

+------------------+                                worker         +          server
| paralg | |
| | |
| | |
| parasrv *ps_obj | |
| + | | +------------------+
| | | | | start_server |
+------------------+ | | |
| | | |
| | | |
v | | |
+------------+-----+ +------------------+ +---------+ | | thrd_exec |
| parasrv | |kvclt | | kvclt | | | |
| | | | | | | | |
| | | host | | | | | thrd_exec_ssp |
| servers | | | | | | | |
| | | ports_lst | | | | | |
| kvm +-----------> | |.....| | | | ssp_tbl |
| | | context | | | | | |
| p_ring | | | | | | | |
| + | | conn_prefix | | | | | tbl_store |
| | | | | | | | | |
+------------------+ | p_pull_sock+---+ | | | | |
| | | | | | | | |
| | p_push_sock | | | | | | |
| | + | | | | | | |
v | | | | | | | | |
+------------+------+ +------------------+ | +---------+ | | |
| ring | | | | +---+---+----------+
| | | | | ^ ^
| | | | | | |
| srv_hashring | | +-----------------------+ |
| | +------------------------------------+
| srv_hashring_dct | |
| | |
+-------------------+ +

手機如下:

0xEE 個人資訊

★★★★★★關於生活和技術的思考★★★★★★

微信公眾賬號:羅西的思考

如果您想及時得到個人撰寫文章的訊息推送,或者想看看個人推薦的技術資料,敬請關注。

0xFF 參考

PARACEL:讓分散式機器學習變得簡單

引數伺服器——分散式機器學習的新殺器