MATLAB C++ MEX 文件函数 实现原理与使用技巧详解

527 阅读17分钟

要在MATLAB中调用C++代码,最好的方法就是使用MEX文件函数。MATLAB官方给MEX文件函数提供的文档可以说是极其简略,很多原理和技巧都需要自行摸索。本文假设你已经阅读了官方提供的文档,文档中已经给出的知识本文不会再详述。

作者使用的开发环境是Windows系统下的 Visual Studio 和 Visual Studio Code,MSVC编译器,因此有些知识还需要参考微软的文档,本文也不会赘述,需要读者自行查阅了解。如果你使用其它环境亦可供参考。

除此之外,你还需要对C++的语言特性和编译链接流程有一定了解,并能在遇到未知时有查阅标准文档或其它教程的能力。本文不会赘述C++相关知识。

MEX文件函数实现原理

网上很多教程都只讲了MEX文件函数怎么编写、编译、使用,对其实现原理基本没有提及。这种一知半解的开发者,就会写出很多包含匪夷所思的bug的代码,或遇到难以解决的性能问题。所以我们先从底层原理开始说起。

如果你见过编译好的MEX文件函数,它其实就是一个.mexw64文件。本质上,它也不是什么MATLAB内部自定义的神秘文件格式,而是一个标准的 Win32 DLL ——只有扩展名的区别而已。如果你不了解 Win32 DLL,请参阅微软文档,本文不做详述。

既然是DLL,就必须有导出函数。用VSCode打开MATLAB安装目录\extern\include,搜索MEXFUNCTION_LINKAGE,可以找到cppmex\detail\mexFunctionAdapterImpl.hpp中定义了3个导出函数 mexCreateMexFunction, mexDestroyMexFunction和mexFunctionAdapter。而MEXFUNCTION_LINKAGE则被定义为extern "C" __declspec(dllexport),这是标准的 Win32 DLL 导出函数声明。除了这三个函数之外,再也没有其它函数被导出了。这就意味着,MATLAB只能通过这三个函数与用户编写的MEX文件函数交换信息。

这三个函数代码不长,我们可以一行一行看看它们究竟干了什么:

//这个函数在MATLAB首次加载本MEX文件函数的时候调用
MEXFUNCTION_LINKAGE
void* mexCreateMexFunction(void (*callbackErrHandler)(const char*, const char*)) {
    try {
    
        matlab::mex::Function *mexFunc = mexCreatorUtil<MexFunction>();
        //构造MexFunction对象。MATLAB文档中要求用户定义一个MexFunction类型,这里就是这个类型被实例化的地方。实例化后,指针被返回给MATLAB。
        
        return mexFunc;
    } catch(...) {
    
        mexHandleException(callbackErrHandler);
        //如果MexFunction构造失败,向MATLAB返回错误信息。
        
        return nullptr;
    }
}
//这个函数在MATLAB会话终止或执行 clear mex 命令时调用
MEXFUNCTION_LINKAGE
void mexDestroyMexFunction(void* mexFunc,
                           void (*callbackErrHandler)(const char*, const char*)) {
    matlab::mex::Function* mexFunction = reinterpret_cast<matlab::mex::Function*>(mexFunc);
    try {
    
        mexDestructorUtil(mexFunction);
        //析构MexFunction对象
        
    } catch(...) {
        mexHandleException(callbackErrHandler);
        return;
    }
}
//这个函数在每次调用MEX文件函数时调用
MEXFUNCTION_LINKAGE
void mexFunctionAdapter(int nlhs_,
                        int nlhs,
                        int nrhs,
                        void* vrhs[],
                        void* mFun,
                        void (*callbackOutput)(int, void**),
                        void (*callbackErrHandler)(const char*, const char*)) {

    matlab::mex::Function* mexFunction = reinterpret_cast<matlab::mex::Function*>(mFun);
    //这里获取的就是mexCreateMexFunction返回的MexFunction指针

    std::vector<matlab::data::Array> edi_prhs;
    edi_prhs.reserve(nrhs);
    implToArray(nrhs, vrhs, edi_prhs);
    //将ABI转译为C++标准类型
    std::vector<matlab::data::Array> edi_plhs(nlhs);
    matlab::mex::ArgumentList outputs(edi_plhs.begin(), edi_plhs.end(), nlhs_);
    matlab::mex::ArgumentList inputs(edi_prhs.begin(), edi_prhs.end(), nrhs);

    try {
    
        (*mexFunction)(outputs, inputs);
        //调用operator()
        
    } catch(...) {
        mexHandleException(callbackErrHandler);
        return;
    }

    arrayToImplOutput(nlhs, edi_plhs, callbackOutput);
    //将C++标准类型转回ABI
}

综上,这三个导出函数本质上干的就是MexFunction的构造、调用和析构三件事,涵盖了一个C++类型的标准生命周期,只不过额外做了一些ABI和异常处理的工作。因此,MATLAB文档里只要求用户定义MexFunction即可。由于C++会自动添加构造和析构函数,所以一般来说用户只需要定义operator()。但文档没说的是,用户仍可以自定义构造和析构。只不过构造函数必须是无参的,析构函数必须是虚析构。这就方便我们在首次调用时分配持久资源,以及在清理阶段释放资源等一次性操作。

由此我们可以得出一个结论,编译MEX文件函数并不需要执行MATLAB提供的编译命令。你可以在任何一款编译器上编译出 Win32 DLL,只要它至少能正确导出以上三个C风格ABI的函数,就可以将其扩展名改为mexw64,然后就可以被MATLAB当作MEX文件函数来调用了。你甚至还可以让它导出更多函数以供你解决方案中的其它项目使用。但需要注意,链接时至少要附加依赖extern\lib\win64\microsoft下的libMatlabDataArray.lib和libmex.lib两个静态库。如果你使用了其它本文未提及的高级功能,可能还需要附加该目录下的更多静态库。

此外,细心的读者可能会发现,这三个函数都是外部链接的,并且定义在头文件而非源文件中!这就意味着,如果该头文件被多个编译单元包含,将会导致符号重定义错误。实际上,MATLAB提供的所有C++代码都只有头文件,其中不乏类似这样的外部链接函数定义。只要有任何一个这样的头文件被多个编译单元包含,就会导致链接时发生符号重定义。不得不说这无疑是MATLAB的一个重大设计缺陷,并且2024年了都没有丝毫要改邪归正的意思。开发者只有两个选择,要么控制#include的使用以确保每个符号只存在于单个编译单元中;要么修改这些头文件,将其改为inline或static(但此方法不适用于DLL导出函数)或者仅声明,而将定义移到自己的源文件中。作者更倾向于后者(但如果你没读过本文,只根据粗浅的官方文档和其它教程,恐怕根本想不到后一个方法,只能乖乖控制#include)

MATLAB与C++信息交换

MATLAB官方文档中最大的篇幅都是在讲述各种MATLAB数据类型如何与C++标准类型互相转换。但是它主要是以非常有限的代码示例来讲解,没有详述每个类型的实现原理。

MexFunction::operator()的所有输入输出都通过matlab::mex::ArgumentList传递进来,它本质上是一个包含多个matlab::data::Arraystd::vector。Array则是所有MATLAB数据类型的通用包装器。这样做是为了解决MATLAB和C++对类型处理的一个重要区别:MATLAB类型是动态的,可以在运行时改变,而在编码时是未知的。但C++类型是静态的,必须在编码时已知,不能运行时改变。因此,MATLAB数据类型不能直接投射为C++类型,只能在运行时转换。Array的运行时类型可以通过getType方法获取一个matlab::data::ArrayType枚举,这个枚举值标识了这个Array对应的MATLAB类型:

//include\MatlabDataArray\ArrayType.hpp
namespace matlab {
    namespace data {
        enum class ArrayType : int {
            LOGICAL,
            CHAR,
            MATLAB_STRING,           
            DOUBLE,
            SINGLE,
            INT8,
            UINT8,
            INT16,
            UINT16,
            INT32,
            UINT32,
            INT64,
            UINT64,
            COMPLEX_DOUBLE,
            COMPLEX_SINGLE,
            COMPLEX_INT8,
            COMPLEX_UINT8,
            COMPLEX_INT16,
            COMPLEX_UINT16,
            COMPLEX_INT32,
            COMPLEX_UINT32,
            COMPLEX_INT64,
            COMPLEX_UINT64,
            CELL,
            STRUCT,
            OBJECT,
            VALUE_OBJECT,
            HANDLE_OBJECT_REF,
            ENUM,
            SPARSE_LOGICAL,
            SPARSE_DOUBLE,
            SPARSE_COMPLEX_DOUBLE,
            UNKNOWN,
        };
    }
}

除了稀疏矩阵需要使用SparseArray模板外,所有Array类型都可以转换为对应类型的TypedArray模板。TypedArray模板采用了C++静态类型参数,在从Array转换时会判断运行时类型是否匹配静态类型,不匹配则出错。在MatlabDataArray\GetArrayType.hpp中可以查到C++静态类型与MATLAB动态类型的匹配关系:

namespace matlab {
    namespace data {

        class Array;
        class Struct;
        class Enumeration;
        class Object;
        template<typename T> class SparseArray;
        
        struct GetCellType { static const ArrayType type = ArrayType::CELL; };
        struct GetStringType { static const ArrayType type = ArrayType::UNKNOWN; };
        
        template<typename T> struct GetArrayType;
        template<> struct GetArrayType<bool> { static const ArrayType type = ArrayType::LOGICAL; };
        template<> struct GetArrayType<CHAR16_T> { static const ArrayType type = ArrayType::CHAR; };

        template<> struct GetArrayType<MATLABString> { static const ArrayType type = ArrayType::MATLAB_STRING; };
        template<> struct GetArrayType<double> { static const ArrayType type = ArrayType::DOUBLE; };
        template<> struct GetArrayType<float> { static const ArrayType type = ArrayType::SINGLE; };

        template<> struct GetArrayType<int8_t> { static const ArrayType type = ArrayType::INT8; };
        template<> struct GetArrayType<int16_t> { static const ArrayType type = ArrayType::INT16; };
        template<> struct GetArrayType<int32_t> { static const ArrayType type = ArrayType::INT32; };
        template<> struct GetArrayType<int64_t> { static const ArrayType type = ArrayType::INT64; };

        template<> struct GetArrayType<uint8_t> { static const ArrayType type = ArrayType::UINT8; };
        template<> struct GetArrayType<uint16_t> { static const ArrayType type = ArrayType::UINT16; };
        template<> struct GetArrayType<uint32_t> { static const ArrayType type = ArrayType::UINT32; };
        template<> struct GetArrayType<uint64_t> { static const ArrayType type = ArrayType::UINT64; };
#if !defined(__linux__) && !defined(_WIN32)
        template<> struct GetArrayType<unsigned long> { static const ArrayType type = ArrayType::UINT64; };
#endif
        template<> struct GetArrayType<std::complex<int8_t>> { static const ArrayType type = ArrayType::COMPLEX_INT8; };
        template<> struct GetArrayType<std::complex<int16_t>> { static const ArrayType type = ArrayType::COMPLEX_INT16; };
        template<> struct GetArrayType<std::complex<int32_t>> { static const ArrayType type = ArrayType::COMPLEX_INT32; };
        template<> struct GetArrayType<std::complex<int64_t>> { static const ArrayType type = ArrayType::COMPLEX_INT64; };

        template<> struct GetArrayType<std::complex<uint8_t>> { static const ArrayType type = ArrayType::COMPLEX_UINT8; };
        template<> struct GetArrayType<std::complex<uint16_t>> { static const ArrayType type = ArrayType::COMPLEX_UINT16; };
        template<> struct GetArrayType<std::complex<uint32_t>> { static const ArrayType type = ArrayType::COMPLEX_UINT32; };
        template<> struct GetArrayType<std::complex<uint64_t>> { static const ArrayType type = ArrayType::COMPLEX_UINT64; };
        template<> struct GetArrayType<std::complex<double>> { static const ArrayType type = ArrayType::COMPLEX_DOUBLE; };
        template<> struct GetArrayType<std::complex<float>> { static const ArrayType type = ArrayType::COMPLEX_SINGLE; };

        template<> struct GetArrayType<Array> { static const ArrayType type = ArrayType::CELL; };
        template<> struct GetArrayType<Struct> { static const ArrayType type = ArrayType::STRUCT; };
        template<> struct GetArrayType<Enumeration> { static const ArrayType type = ArrayType::ENUM; };
        template<> struct GetArrayType<Object> { static const ArrayType type = ArrayType::OBJECT; };
        template<> struct GetArrayType<SparseArray<bool>> { static const ArrayType type = ArrayType::SPARSE_LOGICAL;  };
        template<> struct GetArrayType<SparseArray<double>> { static const ArrayType type = ArrayType::SPARSE_DOUBLE; };
        template<> struct GetArrayType<SparseArray<std::complex<double>>> { static const ArrayType type = ArrayType::SPARSE_COMPLEX_DOUBLE; };

        template<typename T> struct GetSparseArrayType;
        template<> struct GetSparseArrayType<bool> { static const ArrayType type = ArrayType::SPARSE_LOGICAL;  };
        template<> struct GetSparseArrayType<double> { static const ArrayType type = ArrayType::SPARSE_DOUBLE; };
        template<> struct GetSparseArrayType<std::complex<double>> { static const ArrayType type = ArrayType::SPARSE_COMPLEX_DOUBLE; };


    }
}

可以看到,并非所有的MATLAB类型都可以转为C++类型。对于可以转换的类型,除了平凡的逻辑、算术和复数类型外,还有几个特殊的要点:

  • MATLAB char 类型对应的是 C++ CHAR16_T 类型,string对应的则是MATLABString。MATLABString是matlab::data::optional<matlab::data::String>的别名,而String则是std::basic_string<char16_t>的别名。这意味着,MATLAB字符串是UTF16编码的,而不是C++中常用的UTF8。在Windows上,C++ char16_t 和 wchar_t 本质上是一样的(C++标准认为它们是不同的类型,但可以reinterpret_cast)。所以处理MATLAB字符串时,需要用宽字符或u16版本的函数和类型,必要时还得转UTF8。注意,无论是MATLAB还是C++标准库,都不提供任何UTF8和UTF16字符串编码转换的方法(即使表面上看上去像是提供了,实际测试就会发现并非如此)。例如,虽然CharArray提供了toAscii方法,但它只是单纯截取了每个char16_t的低8位,并不是正统的编码转换,仅适用于纯ASCII字符串。如果字符串包含非ASCII字符,则不能用任何MATLAB提供的方法进行转换。一般来说,需要使用Windows提供的 Win32 API 或者其它第三方C++库实现编码转换。
  • MATLAB cell 数组的元素是 C++ Array,这其实很好理解,因为cell本身就是可以包含任何类型的MATLAB类型包装器,跟 C++ Array 的设计功能完全重合。
  • 所有的MATLAB结构体都是相同的 C++ Struct 类型,所有的MATLAB枚举都是相同的 C++ Enumeration,所有的MATLAB类对象都是相同的 C++ Object。因为这些类型都是用户自定义的,C++不能识别MATLAB的用户自定义类型。因此允许对这些类型在C++领域执行的操作是非常有限的。这也意味着,一般来说我们应避免将这些类型传递给MEX文件函数,而是尽量转成其它更方便操作的类型。例外是Struct,Struct在C++中提供了方便的字符串索引方法,类似于CellArray那样,可以作为万能的Array容器。

MATLAB 引用和避免拷贝等性能改进技巧

出于性能考虑,MATLAB尽可能通过指针和引用与C++交换信息。Array本身并不拥有数组中的实际数据,而是通过引用计数和写入时复制机制访问数据:

  • C++std::shared_ptr引用计数。当Array和不同的子类之间发生拷贝构造或转换构造时,引用计数就会加一。使用C++引用或移动构造可以避免拷贝。但即使发生了拷贝构造,也只是 C++ shared_ptr 的拷贝,不会真正拷贝数据。
  • 写入时复制。调用非const版本的operator[]以及使用release释放缓冲区时,都会发生真正的数据复制。当然,这种复制仍然是浅拷贝,只会拷贝本层Array,不会拷贝元胞或结构体中的下层Array。

因此,C++开发者应当审慎考虑数据使用方法,尽可能使用引用而避免拷贝。除此之外,还需要考虑隐式转换的开销:有些写法表面上减少了转换步骤,实际上由于中间存在隐式转换,未必能减少性能开销。通常来说,应该按照以下优先顺序使用数据,只有当上级无法实现目的时再使用下级:

  1. C++标准常量引用。这种用法几乎没有开销,性能最高。但是,它不允许任何类型转换。即使是隐式转换也会实际上构造新对象,这种情况下引用的使用不能带来任何性能收益。一个典型的例子就是对数组元素的访问,由于operator[]返回的并非C++标准引用,在转换到C++标准引用时就会带来额外开销。而且由于常量要求,不能修改数据。
  2. MATLAB引用。MATLAB为C++定义了多种特定条件下使用的引用类型,各有不同的使用方法,细节很难一一枚举。一般来说,对Array及子类使用operator[]可以获得MATLAB引用类型。对于非const的POD类型,还可以转换为C++标准引用。POD类型的常量引用通常没有意义,因为其开销不低于直接拷贝。对于非POD类型,例如从CellArray中提取Array时,无法使用C++标准引用,通常只能使用MATLAB引用。注意,始终确保调用const版本的operator[]和迭代器来获取MATLAB引用。非const版本的使用会造成很大的开销。
  3. C++移动构造。这种操作只会发生指针拷贝并构造新对象,不涉及引用计数和数据拷贝。但是,MATLAB数组只允许整体移动,而不能单独移走其中的一个元素。对POD类型,移动和拷贝开销相同,无意义。对复杂类型,只能使用MATLAB引用的operator转换,shared_ptr的拷贝和引用计数的增加无法避免。
  4. 拷贝构造。如上述,拷贝构造并不拷贝真实数据,而是拷贝Array内置的shared_ptr并增加引用计数。只有在对新对象进行修改时才会发生实际的数据拷贝。
  5. 非 const operator[]和迭代器。对Array及其子类使用非const版本的operator[]和迭代器视为对数组的修改,即使没有实际的修改发生。这将导致实际数据被拷贝。MATLAB内部当然有优化策略,可以使得对同一个引用的多次迭代不发生重复拷贝。但是,它总归是要检查是否需要拷贝的,所以开销一定会大于const版本。
  6. release。release意味着将MATLAB托管的数组完全交给C++处理,这几乎一定会导致发生拷贝。但它的优点是一次性拷贝完之后就不用再检查是否需要拷贝了,所以相比于循环使用非 const operator[] 和迭代器来说,release有时反而性能更高。之所以将release放在最后,是因为对于单个元素的访问release是性能最差的,但对多个元素的循环访问情境下release往往是更优解。

如果你理解了上述规则,可能会发现,MATLAB似乎有种怪癖,就是死死捂着它自己所托管的数据不撒手,你只能看不能摸,想摸就得先拷贝。这其实不能怪MATLAB不信任开发者,而可能与MATLAB内部对数组的优化机制有关,例如延迟计算:当使用zeros创建全0数组后,内存中可能只存在这个数组的抽象定义,而不实际存在这样一段全0的内存,只在必要的时候才会将其展开为真正的内存中的数组。但如果要将这个数组交给C++处理,C++是不能理解“全零数组”这个抽象定义的。所以如果C++需要只读访问,尚可以直接给它返回一个0;但如果C++需要修改或接管这个数组,那就只能展开成真正的全零数组再交给C++了。我们上述的拷贝只是一个抽象意义上的拷贝,实际上可能是内存中不实际存在的抽象数组定义展开成真正的数组。

异常处理

mexFunctionAdapterImpl.hpp中定义了MEX默认的异常处理机制:

inline void mexHandleException(void (*callbackErrHandler)(const char*, const char*)) {
    try {
        throw;
    } catch(const matlab::engine::MATLABException& ex) {
        callbackErrHandler(ex.what(), ex.getMessageID().c_str());
    } catch(const matlab::engine::Exception& ex) {
        callbackErrHandler(ex.what(), "");
    } catch(const matlab::Exception& ex) {
        callbackErrHandler(ex.what(), "");
    } catch(const std::exception& ex) {
        callbackErrHandler(ex.what(), "");
    } catch(...) {
        callbackErrHandler("Unknown exception thrown.", "");
    }
}

这个函数为所有C++标准异常提供了托底的处理机制。但是它存在几个不足之处:

  • 只有列出的几种异常类型能够返回有效的详细信息显示在MATLAB中,其它异常只能返回一个未知。
  • 只能捕获C++标准异常,不能捕获 Windows SEH 异常(如内存访问越界、堆栈损坏、重复析构等)
  • 不能捕获违反noexcept约定抛出的异常。这个问题无法解决,因为C++标准规定违反noexcept约定就必须终止进程,所以这一定会直接导致MATLAB进程被强制退出。

因此,建议的做法是,开发者应自行捕获一些特定类型的异常,转换成matlab::engine::MATLABException后再次抛出,以便将重要的异常信息在MATLAB中呈现。对于SEH异常,则必须使用MSVC特定的__try__except语法,并且有函数内不能展开对象、不能同时使用C++标准异常处理等限制,详见微软文档。一般来说,应当将捕获的SEH异常重新throw为不需要展开的平凡类型,再在上级函数中做处理。

clear mex

MATLAB的clear mex命令会立即卸载MEX文件函数。这不仅导致MexFunction被析构,还会释放所有MEX构造的对象的内存。这种操作非常危险,因为它假定所有对象都是平凡类型,而不会去依次执行每个对象的析构函数,这可能导致对象未被正确释放。此外,如果之前构造的对象以指针或其它等效的索引句柄等的形式返回给了MATLAB保存在工作区变量中,这些指针将变为无效的野指针。如果将这些指针再次传递给被重新载入的MEX,将导致未定义行为,包括但不限于数据损坏、意外结果、MATLAB进程崩溃等。因此,开发者必须考虑到clear mex命令可能会对对象生命周期管理带来的干扰。

幸运的是,可以通过实现MexFunction的虚析构方法来解决这个问题。如前述,clear mex在释放内存之前,会先通过mexDestroyMexFunction导出函数调用MexFunction的析构。因此,应当在这个析构函数中妥善清理之前创建的所有C++对象。一种常见的做法是,每当需要向MATLAB返回指针时,将指针和/或其析构方法同时记录在一个全局容器中。这样在MexFunction析构时,就可以遍历这个全局容器,以正确释放所有对象。同时,当接受MATLAB传来的指针参数时,也应当检查指针是否存在于容器中:如果存在,则认为是有效的指针并执行处理;如不存在,说明是野指针,应当抛出异常。

示例代码

作者编写了一些体现上述内容的示例代码,放在GitHub存储库中,欢迎交流学习。