实现std::function
std::function介绍
std::function是C++11 引入的一个通用函数包装器,它提供了一种类型安全的方式来存储、复制和调用任何可调用对象(函数、成员函数、lambda 表达式、函数对象等)。在学习std::function时,我非常好奇它是如何将各种不同类型的可调用对象使用统一且安全的方式存储的,而不需要关心它们的具体类型。探索其底层原理时发现了std::function使用了类型擦除的技术保障这一点。
以下是std::function的使用示例:
// 1. 包装普通函数
int add(int a, int b) { return a + b; }
std::function<int(int, int)> func1 = add;
std::cout << func1(2, 3) << std::endl; // 输出:5
// 2. 包装lambda表达式
std::function<void(int)> func2 = [](int x) {
std::cout << "x = " << x << std::endl;
};
func2(42); // 输出:x = 42
// 3. 包装成员函数
class Calculator {
public:
int multiply(int a, int b) { return a * b; }
};
Calculator calc;
std::function<int(int, int)> func3 = std::bind(&Calculator::multiply, &calc,
std::placeholders::_1, std::placeholders::_2);
std::cout << func3(4, 5) << std::endl; // 输出:20
// 4. 作为回调函数
class Button {
std::function<void()> onClick;
public:
void setOnClick(std::function<void()> callback) {
onClick = callback;
}
void click() {
if (onClick) onClick();
}
};
std::function实现
先看std::function的使用示例std::function<返回值类型(参数类型列表)> func;,由于其中返回值类型和参数类型列表是不确定的,所以我们这里应该使用模板类:
template <class _Ret, class ..._Args>
struct Function<_Ret(_Args...)>
这就是Function类的模板定义,注意我们在_Ret(_Args...)的_Args后面要加...标志着将可变参数模板的参数包展开。此外,为了防止用户初始化Function时出现错误,我们还需要特化一个Function的类模板定义,在用户初始化语法错误时进行报错:
template <class _FnSig>
struct Function{
// 只在使用了不符合 Ret(Args...) 模式的 FnSig 时会进入此特化,导致报错
static_assert(!std::is_same_v<_FnSig, _FnSig>, "not a valid function signature");
};
其中static_assert是一个编译期断言,如果编译器发现用户声明的Function容器不符合语法,在编译器就会出现报错,std::is_same_v<_FnSig, _FnSig> 这个表达式实际上永远为 true,所以加上!后就永远为false,但是因为 _FnSig 是模板参数,编译器必须等到模板实例化时才能确定这个值,这样就实现了"延迟"到模板实例化时才报错的效果,这种设计确保了用户只能使用正确的函数签名格式来实例化 Function 模板。
下面给出一些报错示例:
// 这些会导致编译错误:
Function<int>; // 错误:不是有效的函数签名
Function<std::string>; // 错误:不是有效的函数签名
Function<void>; // 错误:不是有效的函数签名
// 这些是正确的用法:
Function<int(int)>; // 正确:接受一个int参数,返回int的函数
Function<void(string)>; // 正确:接受一个string参数,返回void的函数
Function<double(int, string)>; // 正确:接受int和string参数,返回double的函数
回到 Function<_Ret(_Args...)>中的实现,在这个模板类中,我们定义了一个嵌套类,或者说内部类--_FuncBase,这个类的作用是实现类型擦除,提供一个统一的接口,使得不同类型的可调用对象(函数、lambda、成员函数等)都可以被统一处理。使得std::function可以存储任意类型的可调用对象,而不需要知道具体类型。
struct _FuncBase{
virtual _Ret _M_call(_Args ...__args) = 0; // 类型擦除后的统一接口
virtual std::unique_ptr<_FuncBase> _M_clone() const = 0; // 原型模式,克隆当前函数对象
virtual std::type_info const &_M_type() const = 0; // 获得函数对象类型信息
virtual ~_FuncBase() = default;
};
下面开始详细解释这个类:
-
这个抽象基类定义了统一的接口,供不同的函数对象实现,具体的函数实现是在下面会介绍的派生类
_FuncImpl中完成的,虚函数允许我们通过基类指针调用派生类的实现,这样就能在运行时动态地处理不同类型的可调用对象。 -
_M_call函数是统一的函数调用接口,其中_Ret是返回值类型,_Args是可变函数参数列表,它使得不同类型的可调用对象都能通过统一的接口被调用。 -
_M_clone函数是克隆接口,原型模式,用于复制函数对象,它返回一个智能指针,指向新创建的副本,后面的const参数表示,它不会修改当前对象的内容,这个函数支持 std::function 的拷贝语义,使得函数对象可以被安全地复制。// 假设我们有一个 std::function 对象 std::function<int(int)> func1 = [](int x) { return x * 2; }; // 当我们进行拷贝时 std::function<int(int)> func2 = func1; // 这里会调用 _M_clone // 内部发生的过程大致如下: // 1. 原始对象 func1 中存储了一个 _FuncImpl<lambda类型> 对象 // 2. 当执行拷贝时,会调用 _M_clone // 3. _M_clone 会创建一个新的 _FuncImpl<lambda类型> 对象 // 4. 新对象会复制原始 lambda 的状态 // 实际调用过程: // func1._M_base->_M_clone() 会返回一个新的 unique_ptr<_FuncBase> // 这个新指针指向一个复制的函数对象 // 验证两个对象是独立的 func1 = [](int x) { return x * 3; }; // 修改 func1 std::cout << func1(2) << std::endl; // 输出 6 std::cout << func2(2) << std::endl; // 输出 4,说明 func2 没有被修改由此可见,
_M_clone实现了深拷贝,确保每个std::function对象都有自己独立的函数对象副本,支持std::function对象的拷贝操作和赋值操作。如果没有_M_clone的话,我们就无法安全的复制std::function对象,因为直接复制指针会导致多个std::function共享函数对象,这可能导致意外的行为,比如一个对象修改了函数,其他对象也会受到影响,在对象销毁时也可能导致重复释放内存。 -
_M_type获取实际函数对象的类型信息,返回 type_info 的引用,用于类型识别,帮助用户了解当前存储的是什么类型的可调用对象。 -
基类的析构函数一般都定义为虚函数,以确保正确释放资源,通过基类指针删除派生类对象时能正确调用派生类的析构函数,底层由虚表实现。
接下来介绍Function中核心的类_FuncImpl,这个类是_FuncBase 的具体实现类,它负责存储和调用实际的函数对象。作为模板类,可以为每种类型的可调用对象生成具体的实现,实现基类 _FuncBase 定义的所有虚函数。
template<class _Fn>
struct _FuncImpl : _FuncBase{
// FuncImpl 会被实例化多次,每个不同的仿函数类都产生一次实例化,存储的实际函数对象
_Fn _M_f;
template <class ..._CArgs>
explicit _FuncImpl(std::in_place_t, _CArgs &&...__args) : _M_f(std::forward<_CArgs>(__args)...){}
_Ret _M_call(_Args ...__args) override{
// 完美转发所有参数给构造时保存的仿函数对象:
return _M_f(std::forward<_Args>(__args)...);
}
std::unique_ptr<_FuncBase> _M_clone() const override{
return std::make_unique<_FuncImpl>(std::in_place, _M_f);
}
std::type_info const &_M_type() const override {
return typeid(_Fn);
}
};
下面详细解释一下这个构造函数。使用了可变模板参数template <class ..._CArgs>,可以接受任意数量和类型的参数,这些参数是用来构建可调用对象的,而不是调用函数时传递的参数。
std::in_place_t是一个标记类型,它的主要作用是区分构造函数,防止构造函数重载时的歧义。如果不使用这个参数,就会出现一下歧义:
// 假设没有 std::in_place_t,我们可能会这样写:
template <class ..._CArgs>
_FuncImpl(_CArgs &&...__args)
: _M_f(std::forward<_CArgs>(__args)...) {}
// 这样会导致问题:
_FuncImpl f1(42); // 是构造 _M_f(42) 还是其他含义?
_FuncImpl f2(f1); // 是拷贝构造还是构造 _M_f(f1)?
但是在使用std::in_place_t之后:
template <class ..._CArgs>
explicit _FuncImpl(std::in_place_t, _CArgs &&...__args)
: _M_f(std::forward<_CArgs>(__args)...) {}
// 现在调用必须显式指定 in_place:
_FuncImpl f1(std::in_place, 42); // 明确表示要构造 _M_f(42)
使用这个参数除了可以明确构造意图,避免歧义;还可以允许在容器中直接构造对象,避免额外的拷贝。
构造函数中的_CArgs &&...__args这个参数是一个可变参数列表的万能引用,用于接下来调用函数的完美转发--参数会被转发给 _M_f 的构造函数。其中...表示展开可变参数列表的参数包。
_M_f(std::forward<_CArgs>(__args)...)使用了C++中的完美转发特性,它是 C++ 中用于保持参数值类别(左值/右值)的重要特性。具体来说,_M_f 是存储实际函数对象的成员变量,它可能是函数指针、lambda 表达式、成员函数等任何可调用对象。std::forward<_CArgs>(__args)... 中的 std::forward 是一个模板函数,它通过 static_cast 将参数转换为原始的值类别,确保参数在传递过程中保持其左值或右值特性。
_CArgs &&...__args 是可变参数模板,允许构造函数接受任意数量和类型的参数。当这个构造函数被调用时,比如 _FuncImpl(std::in_place, 42, "hello"),参数会被完美转发给 _M_f 的构造函数,保持其原始的值类别。其中42是右值,而"hello"是左值。这种设计使得 _FuncImpl 能够高效地构造和存储各种类型的可调用对象,同时避免不必要的拷贝,支持移动语义,提高性能。
接下来是在_FuncImpl中实现的重新基类虚函数方法_M_call,这个函数是统一调用接口的,所以在当前类中需要把函数对象在这个函数中传参调用,而返回值和参数列表正是Function类中的template <class _Ret, class ..._Args>模板参数。该函数中的_M_f(std::forward<_Args>(__args)...);这句代码是将函数调用的参数完美转发给构造时保存的仿函数对象,让其进行函数调用,并且返回结果值。
剩下两个函数中值得一提的点就是_M_clone函数中的这个返回值std::make_unique<_FuncImpl>(std::in_place, _M_f);,这个操作调用了之前我们解释过的构造函数,其中传入参数_M_f对应其中的可变参数列表_CArgs &&...__args,因此在构造函数中将_M_f传递给这个可变参数列表时,相对于调用该函数对象的拷贝构造函数,而不是有参构造函数。
讲解完这两个内部类之后,我们再回过头来看主类Function的实现,Function类中定义了这个对象std::unique_ptr<_FuncBase> _M_base;,这个对象是存储和管理实际的函数对象,通过基类指针 _FuncBase* 实现类型擦除。具体来说_M_base指向了一个_FuncImpl对象(_FuncBase 的派生类),这个对象存储了实际的函数实现。
同时,unique_ptr 的独占所有权语义确保了函数对象不会被多个std::function对象共享,每个std::function对象都有自己独立的函数对象副本。这种设计使得std::function能够安全地管理各种类型的可调用对象,支持拷贝和移动操作,同时保证资源的正确释放。例如,当我们创建一个std::function对象时,_M_base 会被初始化为指向新创建的 _FuncImpl 对象,这个对象存储了实际的函数实现;当std::function对象被拷贝时,会通过_M_clone创建新的 _FuncImpl 对象;当 std::function 对象被移动时,_M_base 的所有权会被转移,确保资源不会被重复释放。
下面给出Function类中public成员的代码:
Function() = default; // _M_base 初始化为 nullptr
Function(std::nullptr_t) noexcept : Function() {}
// 此处 enable_if_t 的作用:阻止 Function 从不可调用的对象中初始化
// 另外标准要求 Function 还需要函数对象额外支持拷贝(用于 _M_clone)
template <class _Fn, class = std::enable_if_t<std::is_invocable_r_v<_Ret, std::decay_t<_Fn>, _Args...>
&& std::is_copy_constructible_v<_Fn>
&& !std::is_same_v<std::decay_t<_Fn>, Function<_Ret(_Args...)>> >>
Function(_Fn &&__f) // 没有 explicit,允许 lambda 表达式隐式转换成 Function
: _M_base(std::make_unique<_FuncImpl<std::decay_t<_Fn>>>(std::in_place, std::forward<_Fn>(__f)))
{}
Function(Function &&) = default;
Function &operator=(Function &&) = default;
Function(Function const &__that): _M_base(__that._M_base ? __that._M_base->_M_clone() : nullptr){}
Function &operator=(Function const &__that){
if(__that._M_base){
_M_base = __that._M_base->_M_clone();
}else{
_M_base = nullptr;
}
return *this;
}
explicit operator bool() const noexcept{
return _M_base != nullptr;
}
bool operator==(std::nullptr_t) const noexcept{
return _M_base == nullptr;
}
bool operator!=(std::nullptr_t) const noexcept {
return _M_base != nullptr;
}
_Ret operator()(_Args ...__args) const {
if (!_M_base) [[unlikely]]
throw std::bad_function_call();
// 完美转发所有参数,这样即使 Args 中具有引用,也能不产生额外的拷贝开销
return _M_base->_M_call(std::forward<_Args>(__args)...);
}
std::type_info const &target_type() const noexcept {
return _M_base ? _M_base->_M_type() : typeid(void);
}
template <class _Fn>
_Fn *target() const noexcept {
return _M_base && typeid(_Fn) == _M_base->_M_type() ? std::addressof(static_cast<_FuncImpl<_Fn> *>(_M_base.get())->_M_f) : nullptr;
}
void swap(Function &__that) const noexcept {
_M_base.swap(__that._M_base);
}
这段代码中有三个部分需要额外讲解,模板的构造函数,重载()运算符和target()函数。我们首先看带有模板的构造函数,
首先看构造函数这一复杂的模板参数:
template <class _Fn, class = std::enable_if_t<std::is_invocable_r_v<_Ret, std::decay_t<_Fn>, _Args...>
&& std::is_copy_constructible_v<_Fn>
&& !std::is_same_v<std::decay_t<_Fn>, Function<_Ret(_Args...)>> >>
这段模板参数使用了 SFINAE来限制Function构造函数的类型,它通过std::enable_if_t和类型萃取器实现。_Fn是函数实例化的模板参数,后面这个默认的模板参数通过std::enable_if_t控制模板的启用条件。条件包括三个部分:std::is_invocable_r_v<_Ret, std::decay_t<_Fn>, _Args...>检查_Fn是否可以接收_Args...作为参数并且返回_Ret类型。
std::is_copy_constructible_v<_Fn> 确保_Fn可以被拷贝构造,因为我们在前面实现的克隆函数就是利用了函数的可拷贝特性,!std::is_same_v<std::decay_t<_Fn>, Function<_Ret(_Args...)>> 防止 Function 对象被自身类型构造。具体来说,如果我们允许 Function 对象被自身类型构造,如果有如下代码
std::function<int(int)> f1;
std::function<int(int)> f2(f1);
这会导致 f2 内部存储一个_FuncImpl<Function<int(int)>>对象,我们应该存储_FuncImpl<int(int)>,如果不这样写的话就会导致不必要的嵌套。std::decay_t<_Fn> 用于移除_Fn的引用和 cv 限定符,确保类型比较的准确性。
接下来是重载括号运算符的函数:
_Ret operator()(_Args ...__args) const {
if (!_M_base) [[unlikely]]
throw std::bad_function_call();
// 完美转发所有参数,这样即使 Args 中具有引用,也能不产生额外的拷贝开销
return _M_base->_M_call(std::forward<_Args>(__args)...);
}
这个函数是std::function的调用运算符重载,它允许std::function对象像普通函数一样被调用。_Ret operator()(_Args ...__args) const 定义了一个返回类型为 _Ret、接受可变参数_Args...的 const 成员函数,const 表示这个函数不会修改std::function对象的状态。
函数首先检查_M_base是否为空(即std::function对象是否存储了有效的函数对象),如果为空,则抛出std::bad_function_call异常,表示尝试调用一个空的std::function对象。[[unlikely]] 是一个属性,提示编译器这个条件(_M_base 为空)不太可能发生,可以优化分支预测。如果 _M_base 不为空,函数会通过 _M_base->_M_call(std::forward<_Args>(__args)...) 调用存储的函数对象,std::forward 用于完美转发参数,保持参数的值类别(左值/右值),避免不必要的拷贝开销。例如,如果_Args中包含引用类型,std::forward 会确保参数以引用方式传递,而不是拷贝。
最后看到target()函数:
template <class _Fn>
_Fn *target() const noexcept {
return _M_base && typeid(_Fn) == _M_base->_M_type() ? std::addressof(static_cast<_FuncImpl<_Fn> *>(_M_base.get())->_M_f) : nullptr;
}
这个target()函数是std::function的一个模板成员函数,用于获取存储的函数对象的指针。函数首先检查_M_base是否为空,并且存储的函数对象类型是否与_Fn匹配(通过 typeid(_Fn) == _M_base->_M_type()),如果条件满足,则返回存储的函数对象的地址,否则返回 nullptr。std::addressof 用于获取对象的地址,即使对象重载了 operator&,也能正确获取地址。static_cast<_FuncImpl<_Fn> *>(_M_base.get()) 将 _M_base 转换为_FuncImpl<_Fn>*类型,然后通过 _M_f 成员访问存储的函数对象。这个函数的主要作用是允许用户获取 std::function 存储的具体函数对象的指针,例如,如果 std::function 存储了一个 lambda 表达式,用户可以通过 target<lambda类型>() 获取这个 lambda 的指针,进行进一步的操作或类型检查。
下面给出完整代码和测试代码:
//function.hh
#include <stdio.h>
#include <iostream>
#include <functional>
#include <memory>
#include <typeinfo>
#include <utility>
#include <type_traits>
template <class _FnSig>
struct Function{
// 只在使用了不符合 Ret(Args...) 模式的 FnSig 时会进入此特化,导致报错
// 此处表达式始终为 false,仅为避免编译期就报错,才让其依赖模板参数
static_assert(!std::is_same_v<_FnSig, _FnSig>, "not a valid function signature");
};
template <class _Ret, class ..._Args>
struct Function<_Ret(_Args...)>{
private:
struct _FuncBase{
virtual _Ret _M_call(_Args ...__args) = 0; // 类型擦除后的统一接口
virtual std::unique_ptr<_FuncBase> _M_clone() const = 0; // 原型模式,克隆当前函数对象
virtual std::type_info const &_M_type() const = 0; // 获得函数对象类型信息
virtual ~_FuncBase() = default; // 应对_Fn可能有非平凡析构的情况
};
template<class _Fn>
struct _FuncImpl : _FuncBase{
// FuncImpl 会被实例化多次,每个不同的仿函数类都产生一次实例化
_Fn _M_f;
template <class ..._CArgs>
explicit _FuncImpl(std::in_place_t, _CArgs &&...__args) : _M_f(std::forward<_CArgs>(__args)...){}
_Ret _M_call(_Args ...__args) override{
// 完美转发所有参数给构造时保存的仿函数对象:
return _M_f(std::forward<_Args>(__args)...);
}
std::unique_ptr<_FuncBase> _M_clone() const override{
return std::make_unique<_FuncImpl>(std::in_place, _M_f);
}
std::type_info const &_M_type() const override {
return typeid(_Fn);
}
};
std::unique_ptr<_FuncBase> _M_base; // 使用智能指针管理仿函数对象
public:
Function() = default; // _M_base 初始化为 nullptr
Function(std::nullptr_t) noexcept : Function() {}
// 此处 enable_if_t 的作用:阻止 Function 从不可调用的对象中初始化
// 另外标准要求 Function 还需要函数对象额外支持拷贝(用于 _M_clone)
template <class _Fn, class = std::enable_if_t<std::is_invocable_r_v<_Ret, std::decay_t<_Fn>, _Args...>
&& std::is_copy_constructible_v<_Fn>
&& !std::is_same_v<std::decay_t<_Fn>, Function<_Ret(_Args...)>> >>
Function(_Fn &&__f) // 没有 explicit,允许 lambda 表达式隐式转换成 Function
: _M_base(std::make_unique<_FuncImpl<std::decay_t<_Fn>>>(std::in_place, std::forward<_Fn>(__f)))
{}
Function(Function &&) = default;
Function &operator=(Function &&) = default;
Function(Function const &__that): _M_base(__that._M_base ? __that._M_base->_M_clone() : nullptr){}
Function &operator=(Function const &__that){
if(__that._M_base){
_M_base = __that._M_base->_M_clone();
}else{
_M_base = nullptr;
}
return *this;
}
explicit operator bool() const noexcept{
return _M_base != nullptr;
}
bool operator==(std::nullptr_t) const noexcept{
return _M_base == nullptr;
}
bool operator!=(std::nullptr_t) const noexcept {
return _M_base != nullptr;
}
_Ret operator()(_Args ...__args) const {
if (!_M_base) [[unlikely]]
throw std::bad_function_call();
// 完美转发所有参数,这样即使 Args 中具有引用,也能不产生额外的拷贝开销
return _M_base->_M_call(std::forward<_Args>(__args)...);
}
std::type_info const &target_type() const noexcept {
return _M_base ? _M_base->_M_type() : typeid(void);
}
template <class _Fn>
_Fn *target() const noexcept {
return _M_base && typeid(_Fn) == _M_base->_M_type() ? std::addressof(static_cast<_FuncImpl<_Fn> *>(_M_base.get())->_M_f) : nullptr;
}
void swap(Function &__that) const noexcept {
_M_base.swap(__that._M_base);
}
};
// test.cpp
#include <iostream>
#include <cassert>
#include "function.hh"
// 测试函数
void test_func_hello() {
std::cout << "Hello" << std::endl;
}
// 测试类
class Multiplier {
public:
Multiplier(int factor) : factor_(factor) {}
int operator()(int x) const {
return x * factor_;
}
private:
int factor_;
};
// 测试函数
void repeattwice(Function<void()> const &func) {
func();
func();
}
void test_basic_functionality() {
std::cout << "测试基本功能..." << std::endl;
// 测试普通函数
repeattwice(test_func_hello);
// 测试lambda表达式
Function<int(int)> f2 = [](int x) { return x * x; };
assert(f2(4) == 16);
// 测试函数对象
Function<int(int)> f3 = Multiplier(3);
assert(f3(4) == 12);
std::cout << "基本功能测试通过!" << std::endl;
}
void test_copy_and_move() {
std::cout << "\n测试拷贝和移动..." << std::endl;
Function<int(int)> original = [](int x) { return x + 1; };
// 测试拷贝构造
Function<int(int)> copy = original;
assert(copy(5) == 6);
// 测试移动构造
Function<int(int)> moved = std::move(original);
assert(moved(5) == 6);
// 测试拷贝赋值
Function<int(int)> assigned;
assigned = copy;
assert(assigned(5) == 6);
std::cout << "拷贝和移动测试通过!" << std::endl;
}
void test_null_checks() {
std::cout << "\n测试空值检查..." << std::endl;
Function<int(int)> empty;
assert(!empty);
assert(empty == nullptr);
Function<int(int)> non_empty = [](int x) { return x; };
assert(non_empty);
assert(non_empty != nullptr);
std::cout << "空值检查测试通过!" << std::endl;
}
void test_target_type() {
std::cout << "\n测试类型信息..." << std::endl;
Multiplier mult(2);
Function<int(int)> f = mult;
assert(f.target_type() == typeid(Multiplier));
assert(f.target<Multiplier>() != nullptr);
assert(f.target<int>() == nullptr);
std::cout << "类型信息测试通过!" << std::endl;
}
void test_exception_handling() {
std::cout << "\n测试异常处理..." << std::endl;
Function<int(int)> empty;
try {
empty(5);
assert(false && "应该抛出异常");
} catch (const std::bad_function_call&) {
std::cout << "异常处理测试通过!" << std::endl;
}
}
int main() {
test_basic_functionality();
test_copy_and_move();
test_null_checks();
test_target_type();
test_exception_handling();
std::cout << "\n所有测试通过!" << std::endl;
return 0;
}
测试结果: