如何在 GO 服务中本地调用 XGBoost 模型

1,995 阅读7分钟

概述

目前,GO 语言想要本地调用(即不通过微服务进行远程调用)XGBoost 模型,可以使用以下两个开源库:

但是,我试了下,这两个库对于一些新的版本产生的 XGBoost 模型支持并不好,一载入模型就会报错 Out Of Memory。于是,我们采用了第二个方法,就是通过 GO 调用 XGBoost 的 C 库的方式实现。

获取 XGBoost 的 C 库

XGBoost 的官方开源库位于:xgboost。需要注意的是,这个库是基于 Apache 2.0 开源协议的。

首先,我们需要克隆这个库。由于 XGBoost 库使用了 submodules,所以,在克隆的时候需要注意把依赖的其它仓库也克隆下来。

git clone --recursive https://github.com/dmlc/xgboost

Windows 平台的用户,使用 git bash 的时候,在克隆了 xgboost 之后,还需要额外执行下面两条命令:

git submodule init
git submodule update

Linux 下编译 so 库

Linux 系统下,需要先安装 CMake,然后,再执行以下命令编译 xgboost:

cd xgboost
mkdir build
cd build
cmake ..
make -j $(nproc)

编译完成后,我们可以在 lib 目录下找到编译好的 libxgboost.so 文件。

Windows 下编译 dll 库

Windows 下,需要安装 CMake 和 Visual Studio 编译库,然后,执行以下命令:

mkdir buildcd build
cmake .. -G"Visual Studio 14 2015 Win64"
 # for VS15: cmake .. -G"Visual Studio 15 2017" -A x64
 # for VS16: cmake .. -G"Visual Studio 16 2019" -A x64
cmake --build . --config Release

编译完成后,我们可以在 lib 目录下找到编译好的 xgboost.dll 文件。

XGBoost 的 API 用法

XGBoost 的 C 的 API 接口定义在 include/xgboost/c_api.h 文件中。

XGBoost 的函数的返回值都是表示函数是否调用成功,0 表示调用成功。

由于我们在 GO 服务中,仅仅使用 xgboost 的预测功能,所以,我们只关注预测的时候需要用到的 API,其它的高级特性我们暂时忽略。

以下是我们本次需要使用到的函数。

XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch);

XGB_DLL 是一个宏,用来标记这个函数是一个 C 风格的导出函数。这个函数要求传入三个整形值的指针,函数会给这三个参数填上对应的版本信息,即 major.minor.patch,我使用的库的版本是 1.6.0

XGB_DLL const char *XGBGetLastError(void);

由于所有的 xgboost 函数调用的返回值都是一个错误码,想要知道对应的错误信息,则需要通过调用这个函数来返回。如果没有发生错误,则返回结果是一个空串。

XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[],
                            bst_ulong len,
                            BoosterHandle *out);

这个函数用来创建一个 xgboost 实例,创建的实例指针通过 out 参数返回。默认情况下,dmats 可以填 NULLlen 可以填 0。

XGB_DLL int XGBoosterFree(BoosterHandle handle);

这个函数用来释放一个 xgboost 实例。

XGB_DLL int XGBoosterLoadModel(BoosterHandle handle,
                               const char *fname);

这个函数可以用来加载一个模型文件。模型文件支持二进制格式和 json 格式。

XGB_DLL int XGDMatrixCreateFromMat(const float *data,
                                   bst_ulong nrow,
                                   bst_ulong ncol,
                                   float missing,
                                   DMatrixHandle *out);

这个函数可以从一个数组来创建一个 DMatrix 实例,实例的指针通过 out 返回。传入的 data 是一个 C 的二维数组,如果需要用 GO 来模拟 C 的二维数组,则需要通过一个一维数组(行优先)来模拟。nrowncol 指定了这个二维数组的行数和列数。missing 指定缺失值,一般都是用 nan

XGB_DLL int XGDMatrixFree(DMatrixHandle handle);

这个函数用来释放一个 DMatrix 实例,以免内存泄漏。

XGB_DLL int XGBoosterPredict(BoosterHandle handle,
                             DMatrixHandle dmat,
                             int option_mask,
                             unsigned ntree_limit,
                             int training,
                             bst_ulong *out_len,
                             const float **out_result);

这个函数实现预测,handle 指的是 xgboost 实例,通过 XGBoosterCreate 创建,dmat 指的是 DMatrix 实例,通过 XGDMatrixCreateFromMat 创建。option_maskntree_limittraining 默认都可以填 0。预测的结果是一个一维数组。out_len 返回了这个一维数组的长度,out_result 中存储了返回的一维数组,也就是预测结果。

Windows 下的调用方式

因为 Windows 和 Linux 两个系统下的动态库形式不一样,所以,调用 C 库的方式也有所不同。

GO 语言可以通过代码文件名的后缀来实现不同系统下编译不同的源码。即 xxxx_windows.go 只会在 Windows 下编译,而 xxxx_linux.go 只会在 Linux 下编译。

首先,我们需要用过 syscall.NewLazyDLL 这个函数加载一个 dll 动态库,比如:

xgbLib := syscall.NewLazyDLL("xgboost.dll")

随后,我们需要使用 lazyDLL.NewProc 方法来定位一个远程调用函数。需要注意的是,这里定位函数的方法是通过函数名,而函数名指的是 DLL 的导出符号表里的函数名。在 C 语言中,函数名就是 DLL 的导出符号,而在 C++ 中,由于为了实现了函数重载而进行了命名粉碎,函数名并不是其导出符号,所以,这种方法使用 C++ 的库的时候会比较麻烦。不过,因为我们使用的是 C 的 xgboost 库,所以,没有这种麻烦,我们可以直接通过函数名来定位。比如:

procXGBoostVersion = xgbLibrary.NewProc("XGBoostVersion")
procXGBoosterCreate = xgbLibrary.NewProc("XGBoosterCreate")
procXGBoosterLoadModel = xgbLibrary.NewProc("XGBoosterLoadModel")
procXGDMatrixCreateFromMat = xgbLibrary.NewProc("XGDMatrixCreateFromMat")
procXGBGetLastError = xgbLibrary.NewProc("XGBGetLastError")
procXGBoosterPredict = xgbLibrary.NewProc("XGBoosterPredict")
procXGDMatrixFree = xgbLibrary.NewProc("XGDMatrixFree")
procXGBoosterFree = xgbLibrary.NewProc("XGBoosterFree")

最后,可以通过 Call 方法实现远程函数的调用。因为远程调用函数的参数都是 C 类型的,所以,在调用前需要进行 GO 类型对 C 类型的转换。又由于 Call 方法传递的参数都要求是 uintptr 类型的,所以,在最后都需要统一转成这个类型。

我们以获取版本号为例,看一下具体的实现流程:

xgbLib := syscall.NewLazyDLL("xgboost.dll")
procXGBoostVersion = xgbLibrary.NewProc("XGBoostVersion")
var major, minor, patch int32

// 传入参数要求是 C 的 int*,所以,需要通过 uintptr(unsafe.Pointer(&major)) 来实现
_, _, err := procXGBoostVersion.Call(uintptr(unsafe.Pointer(&major)), uintptr(unsafe.Pointer(&minor)), uintptr(unsafe.Pointer(&patch)))
if err != nil && strings.Index(err.Error(), "success") == -1 {
   panic(err)
}

// 此时,我们就获取到了版本号,输出即可
fmt.Printf("XGBoost version is %d.%d.%d\n", major, minor, patch)

需要注意的一点是,Call 函数成功调用了,返回的 error 也不是 nil,而是一个调用成功的提示。所以,判断是否出错不能通过 err != nil 来判断。

由于 xgboost 的函数都是返回错误码的形式,所以,我们可以对其进行一个统一的安全调用封装,即:

func safeCall(proc *syscall.LazyProc, args ...uintptr) (err error) {
   if proc == nil {
      return errors.New("remote function is not initialized")
   }

   defer func() {
      if r := recover(); r != nil {
         err = fmt.Errorf("panic: %v", r)
      }
   }()

   ret, _, err := proc.Call(args...)
   if ret != 0 {
      errMsg := GetLastError()
      return fmt.Errorf("remote function `%s` returns %d, msg: %s", proc.Name, ret, errMsg)
   }

   if err != nil && strings.Index(err.Error(), "success") == -1 {
      return fmt.Errorf("remote function `%s` failed: %v", proc.Name, err.Error())
   }

   return nil
}

那么,剩下的就是参数如何转换了。数值类型,可以直接用 uintptr(val) 的方式转换,指针类型,可以借助 unsafe.Pointer 来转换。比如 uintptr(unsafe.Pointer(&val))

把 GO 语言的字符串转换成 const char*,则可以通过下面这种方法实现:

func strPtrParam(s string) uintptr {
   bytes, err := syscall.BytePtrFromString(s)
   if err != nil {
      return nullptr
   }

   return uintptr(unsafe.Pointer(bytes))
}

而反过来,要把一个 const char* 转换成 GO 的字符串,则可以通过下面这种方法:

func ConvertGoString(ret uintptr) string {
   p := (*byte)(unsafe.Pointer(ret))
   if p == nil {
      return ""
   }

   msg := make([]byte, 0)
   for *p != 0 {
      msg = append(msg, *p)
      ret += unsafe.Sizeof(byte(0))
      p = (*byte)(unsafe.Pointer(ret))
   }

   return string(msg)
}

以上,便完成了调用,本代码在 Windows 下测试通过。

Linux 下的调用方式

Linux 下可以直接用 CGO 的方式调用。即首先,引入 xgboost 的头文件和库文件,我们的例子下,头文件放在当前目录,命名为 c_api.h,库文件放在同级目录 lib 下,命名为 libxgboost.so

那么,我们在代码中加入以下 import 代码:

/*
#cgo CFLAGS: -I.
#cgo LDFLAGS: -L../lib -lxgboost -Wl,-rpath,lib
#include "c_api.h"
*/
import "C"

注意,import "C" 上面的注释也是代码的一部分,如果这个注释和 import "C" 之间空了一行的话,会编译报错。上面的 #cgo CFLAGS: -I. 指定了 c_api.h 所在的目录,. 则表示在当前目录。#cgo LDFLAGS: -L../lib -lxgboost -Wl,-rpath,lib 则指定了 so 库的链接选项,-L../lib 指定了 so 库所在的路径,-lxgboost 指定了 so 库的名称。比如这里指定了名称叫 xgboost,则编译器就会去寻找 libxgboost.so 这个文件。

引入 CGO 之后,就可以直接像 C 语言那样直接调用库中的函数了,比如,获取版本号,就可以这么调用:

func Version() (major, minor, patch int32) {
   var cMajor, cMinor, cPatch C.int
   C.XGBoostVersion(&cMajor, &cMinor, &cPatch)
   return int32(cMajor), int32(cMinor), int32(cPatch)
}

所有定义在 c_api.h 中类型都可以直接使用,比如:

var boosterHandle C.BoosterHandle
ret := C.XGBoosterCreate((*C.DMatrixHandle)(nil), 0, &boosterHandle)
if ret != 0 {
   err = fmt.Errorf("XGBoosterCreate returns %d, error is %s", ret, GetLastError())
}

GO 字符串转 C 字符串,可以通过 C.CString(str) 实现,而 C 字符串转 GO 字符串,则可以通过 C.GoString(cstr) 实现。

以上,便完成了调用,本代码在 Debian 系统下测试通过。