概述
目前,GO 语言想要本地调用(即不通过微服务进行远程调用)XGBoost 模型,可以使用以下两个开源库:
但是,我试了下,这两个库对于一些新的版本产生的 XGBoost 模型支持并不好,一载入模型就会报错 Out Of Memory。于是,我们采用了第二个方法,就是通过 GO 调用 XGBoost 的 C 库的方式实现。
获取 XGBoost 的 C 库
XGBoost 的官方开源库位于:xgboost。需要注意的是,这个库是基于 Apache 2.0 开源协议的。
首先,我们需要克隆这个库。由于 XGBoost 库使用了 submodules,所以,在克隆的时候需要注意把依赖的其它仓库也克隆下来。
git clone --recursive https://github.com/dmlc/xgboost
Windows 平台的用户,使用 git bash 的时候,在克隆了 xgboost 之后,还需要额外执行下面两条命令:
git submodule init
git submodule update
Linux 下编译 so 库
Linux 系统下,需要先安装 CMake,然后,再执行以下命令编译 xgboost:
cd xgboost
mkdir build
cd build
cmake ..
make -j $(nproc)
编译完成后,我们可以在 lib 目录下找到编译好的 libxgboost.so 文件。
Windows 下编译 dll 库
Windows 下,需要安装 CMake 和 Visual Studio 编译库,然后,执行以下命令:
mkdir buildcd build
cmake .. -G"Visual Studio 14 2015 Win64"
# for VS15: cmake .. -G"Visual Studio 15 2017" -A x64
# for VS16: cmake .. -G"Visual Studio 16 2019" -A x64
cmake --build . --config Release
编译完成后,我们可以在 lib 目录下找到编译好的 xgboost.dll 文件。
XGBoost 的 API 用法
XGBoost 的 C 的 API 接口定义在 include/xgboost/c_api.h 文件中。
XGBoost 的函数的返回值都是表示函数是否调用成功,0 表示调用成功。
由于我们在 GO 服务中,仅仅使用 xgboost 的预测功能,所以,我们只关注预测的时候需要用到的 API,其它的高级特性我们暂时忽略。
以下是我们本次需要使用到的函数。
XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch);
XGB_DLL 是一个宏,用来标记这个函数是一个 C 风格的导出函数。这个函数要求传入三个整形值的指针,函数会给这三个参数填上对应的版本信息,即 major.minor.patch,我使用的库的版本是 1.6.0。
XGB_DLL const char *XGBGetLastError(void);
由于所有的 xgboost 函数调用的返回值都是一个错误码,想要知道对应的错误信息,则需要通过调用这个函数来返回。如果没有发生错误,则返回结果是一个空串。
XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[],
bst_ulong len,
BoosterHandle *out);
这个函数用来创建一个 xgboost 实例,创建的实例指针通过 out 参数返回。默认情况下,dmats 可以填 NULL,len 可以填 0。
XGB_DLL int XGBoosterFree(BoosterHandle handle);
这个函数用来释放一个 xgboost 实例。
XGB_DLL int XGBoosterLoadModel(BoosterHandle handle,
const char *fname);
这个函数可以用来加载一个模型文件。模型文件支持二进制格式和 json 格式。
XGB_DLL int XGDMatrixCreateFromMat(const float *data,
bst_ulong nrow,
bst_ulong ncol,
float missing,
DMatrixHandle *out);
这个函数可以从一个数组来创建一个 DMatrix 实例,实例的指针通过 out 返回。传入的 data 是一个 C 的二维数组,如果需要用 GO 来模拟 C 的二维数组,则需要通过一个一维数组(行优先)来模拟。nrow 和 ncol 指定了这个二维数组的行数和列数。missing 指定缺失值,一般都是用 nan。
XGB_DLL int XGDMatrixFree(DMatrixHandle handle);
这个函数用来释放一个 DMatrix 实例,以免内存泄漏。
XGB_DLL int XGBoosterPredict(BoosterHandle handle,
DMatrixHandle dmat,
int option_mask,
unsigned ntree_limit,
int training,
bst_ulong *out_len,
const float **out_result);
这个函数实现预测,handle 指的是 xgboost 实例,通过 XGBoosterCreate 创建,dmat 指的是 DMatrix 实例,通过 XGDMatrixCreateFromMat 创建。option_mask,ntree_limit,training 默认都可以填 0。预测的结果是一个一维数组。out_len 返回了这个一维数组的长度,out_result 中存储了返回的一维数组,也就是预测结果。
Windows 下的调用方式
因为 Windows 和 Linux 两个系统下的动态库形式不一样,所以,调用 C 库的方式也有所不同。
GO 语言可以通过代码文件名的后缀来实现不同系统下编译不同的源码。即 xxxx_windows.go 只会在 Windows 下编译,而 xxxx_linux.go 只会在 Linux 下编译。
首先,我们需要用过 syscall.NewLazyDLL 这个函数加载一个 dll 动态库,比如:
xgbLib := syscall.NewLazyDLL("xgboost.dll")
随后,我们需要使用 lazyDLL.NewProc 方法来定位一个远程调用函数。需要注意的是,这里定位函数的方法是通过函数名,而函数名指的是 DLL 的导出符号表里的函数名。在 C 语言中,函数名就是 DLL 的导出符号,而在 C++ 中,由于为了实现了函数重载而进行了命名粉碎,函数名并不是其导出符号,所以,这种方法使用 C++ 的库的时候会比较麻烦。不过,因为我们使用的是 C 的 xgboost 库,所以,没有这种麻烦,我们可以直接通过函数名来定位。比如:
procXGBoostVersion = xgbLibrary.NewProc("XGBoostVersion")
procXGBoosterCreate = xgbLibrary.NewProc("XGBoosterCreate")
procXGBoosterLoadModel = xgbLibrary.NewProc("XGBoosterLoadModel")
procXGDMatrixCreateFromMat = xgbLibrary.NewProc("XGDMatrixCreateFromMat")
procXGBGetLastError = xgbLibrary.NewProc("XGBGetLastError")
procXGBoosterPredict = xgbLibrary.NewProc("XGBoosterPredict")
procXGDMatrixFree = xgbLibrary.NewProc("XGDMatrixFree")
procXGBoosterFree = xgbLibrary.NewProc("XGBoosterFree")
最后,可以通过 Call 方法实现远程函数的调用。因为远程调用函数的参数都是 C 类型的,所以,在调用前需要进行 GO 类型对 C 类型的转换。又由于 Call 方法传递的参数都要求是 uintptr 类型的,所以,在最后都需要统一转成这个类型。
我们以获取版本号为例,看一下具体的实现流程:
xgbLib := syscall.NewLazyDLL("xgboost.dll")
procXGBoostVersion = xgbLibrary.NewProc("XGBoostVersion")
var major, minor, patch int32
// 传入参数要求是 C 的 int*,所以,需要通过 uintptr(unsafe.Pointer(&major)) 来实现
_, _, err := procXGBoostVersion.Call(uintptr(unsafe.Pointer(&major)), uintptr(unsafe.Pointer(&minor)), uintptr(unsafe.Pointer(&patch)))
if err != nil && strings.Index(err.Error(), "success") == -1 {
panic(err)
}
// 此时,我们就获取到了版本号,输出即可
fmt.Printf("XGBoost version is %d.%d.%d\n", major, minor, patch)
需要注意的一点是,Call 函数成功调用了,返回的 error 也不是 nil,而是一个调用成功的提示。所以,判断是否出错不能通过 err != nil 来判断。
由于 xgboost 的函数都是返回错误码的形式,所以,我们可以对其进行一个统一的安全调用封装,即:
func safeCall(proc *syscall.LazyProc, args ...uintptr) (err error) {
if proc == nil {
return errors.New("remote function is not initialized")
}
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic: %v", r)
}
}()
ret, _, err := proc.Call(args...)
if ret != 0 {
errMsg := GetLastError()
return fmt.Errorf("remote function `%s` returns %d, msg: %s", proc.Name, ret, errMsg)
}
if err != nil && strings.Index(err.Error(), "success") == -1 {
return fmt.Errorf("remote function `%s` failed: %v", proc.Name, err.Error())
}
return nil
}
那么,剩下的就是参数如何转换了。数值类型,可以直接用 uintptr(val) 的方式转换,指针类型,可以借助 unsafe.Pointer 来转换。比如 uintptr(unsafe.Pointer(&val))。
把 GO 语言的字符串转换成 const char*,则可以通过下面这种方法实现:
func strPtrParam(s string) uintptr {
bytes, err := syscall.BytePtrFromString(s)
if err != nil {
return nullptr
}
return uintptr(unsafe.Pointer(bytes))
}
而反过来,要把一个 const char* 转换成 GO 的字符串,则可以通过下面这种方法:
func ConvertGoString(ret uintptr) string {
p := (*byte)(unsafe.Pointer(ret))
if p == nil {
return ""
}
msg := make([]byte, 0)
for *p != 0 {
msg = append(msg, *p)
ret += unsafe.Sizeof(byte(0))
p = (*byte)(unsafe.Pointer(ret))
}
return string(msg)
}
以上,便完成了调用,本代码在 Windows 下测试通过。
Linux 下的调用方式
Linux 下可以直接用 CGO 的方式调用。即首先,引入 xgboost 的头文件和库文件,我们的例子下,头文件放在当前目录,命名为 c_api.h,库文件放在同级目录 lib 下,命名为 libxgboost.so。
那么,我们在代码中加入以下 import 代码:
/*
#cgo CFLAGS: -I.
#cgo LDFLAGS: -L../lib -lxgboost -Wl,-rpath,lib
#include "c_api.h"
*/
import "C"
注意,import "C" 上面的注释也是代码的一部分,如果这个注释和 import "C" 之间空了一行的话,会编译报错。上面的 #cgo CFLAGS: -I. 指定了 c_api.h 所在的目录,. 则表示在当前目录。#cgo LDFLAGS: -L../lib -lxgboost -Wl,-rpath,lib 则指定了 so 库的链接选项,-L../lib 指定了 so 库所在的路径,-lxgboost 指定了 so 库的名称。比如这里指定了名称叫 xgboost,则编译器就会去寻找 libxgboost.so 这个文件。
引入 CGO 之后,就可以直接像 C 语言那样直接调用库中的函数了,比如,获取版本号,就可以这么调用:
func Version() (major, minor, patch int32) {
var cMajor, cMinor, cPatch C.int
C.XGBoostVersion(&cMajor, &cMinor, &cPatch)
return int32(cMajor), int32(cMinor), int32(cPatch)
}
所有定义在 c_api.h 中类型都可以直接使用,比如:
var boosterHandle C.BoosterHandle
ret := C.XGBoosterCreate((*C.DMatrixHandle)(nil), 0, &boosterHandle)
if ret != 0 {
err = fmt.Errorf("XGBoosterCreate returns %d, error is %s", ret, GetLastError())
}
GO 字符串转 C 字符串,可以通过 C.CString(str) 实现,而 C 字符串转 GO 字符串,则可以通过 C.GoString(cstr) 实现。
以上,便完成了调用,本代码在 Debian 系统下测试通过。