C++ 踩坑:折叠表达式,Fold expression

300 阅读4分钟

C++ Varadic Template and Folder Expression 可变模板和折叠表达式

Parameter pack, 形参包

A template parameter pack is a template parameter that accepts zero or more template arguments (non-types, types, or templates). A function parameter pack is a function parameter that accepts zero or more function arguments.

A template with at least one parameter pack is called a variadic template.

Variadic function, 变长形参函数

一个简单的例子:

auto f(auto... args) {
    /** code **/
}

f();
f(1, 2, 3, 4);
f(1, "abc", []{});

函数 f 可以接受任意类型和个数的形参。函数形参类型也可以显式声明为模板类型

template<typename... T>
auto tf(T... args) {
    /** code **/ 
}

tf();
tf(1, 2, 3);
tf(1, "abc", []{});

限定符 const, & 也是可以被接受的:

auto f(const auto&... args) {
    /** code **/
}

template<typename... T>
auto tf(const T&... args) {
    /** code **/ 
}

计算变长参数的尺寸

可以通过 sizeof...(args) 来计算变长参数的尺寸,例如:

auto f(auto... args) {
    std::clog << sizeof...(args) << "\n";
}

template<typename... T>
auto tf(T... args) {
    std::clog << sizeof...(args) << "\n";
}

Fold expression 折叠表达式 (Since C++17)

通过二元运算符将形参包展开。

(  pack op  ... )(1)unary right fold
( ...  op pack  )(2)unary left fold
(  pack op  ...  op init  )(3)binary right fold
(  init op  ...  op pack  )(4)binary left fold

(1): (E op ...): 展开 (E1 op (... op (EN-1 op EN)))

(2): (... op E): 展开 (((E1 op E2) op ...) op EN)

(3): (E op ... op I): 展开 (E1 op (... op (EN-1 op (EN op I))))

(3): (I op ... op E): 展开 ((((I op E1) op E2) op ...) op EN)

其中 op 表示二元运算符,包括:

+ - * / % ^ & | = < > << >> += -= *= /= %= ^= &= |= <<= >>= == != <= >= && || , .* ->*

在 binary folder 中, 两个 op 运算符必须一样。示例:

template<typename... T>
auto unary_rigth_fold_minus(T... args) {
    return (args - ...);
}

template<typename... T>
auto unary_left_fold_minus(T... args) {
    return (... - args);
}

template<typename... T>
auto binary_rigth_fold_minus(T... args) {
    return (args - ... - 100);
}

template<typename... T>
auto binary_left_fold_minus(T... args) {
    return (100 - ... - args);
}

int main()
{
    unary_rigth_fold_minus(1, 2, 3);      // (1 - (2 - 3)) = 2
    unary_left_fold_minus(1, 2, 3);       // ((1 - 2) - 3) = -4;
    binary_rigth_fold_minus(1, 2, 3);     // (1 - (2 - 3)) - 100 = -98
    binary_left_fold_minus(1, 2, 3);      // 100 - (((1 - 2) - 3)) = 94;
}

完美转发,Perfect forwarding

在使用模板的折叠表达式时,对模板参数进行完美转发

template<typename... T>
auto fold_expression_perfect_forward_print(T&&... args) {
    return (callToFoo(std::forward<T>(args)), ...);
}

空形参包

只有三个运算符接受空参数包 &&, ||, ,,如下:

  • &&, 返回 true
  • ||, 返回 false
  • ,, 返回 void
auto fold_and_empty(auto... args){
    return (args && ...);
}

auto fold_or_empty(auto... args){
    return (args || ...);
}

auto fold_comma_empty(auto... args){
    return (args, ...);
}

int main()
{
    assert(fold_and_empty() == true);
    assert(fold_or_empty() == false);
    static_assert(std::is_same<decltype(fold_comma_empty()), std::decay<void>::type>::value);
}

折叠表达式中的逗号运算符

在使用逗号运算符时,一般按照从左往右的计算顺序执行,其返回值为逗号运算符最后一条语句的结果。如下:

    int a = (1, 2, 3);              // a = 3
    auto b = (1.2, "hi", 1 + 2);    // b = 3
    auto c = (b = b - 1, b++, b);   // c = 3

因此,逗号运算符的折叠表达式展开式的返回结果也遵循这种规则。

auto f(auto... args) {
    return (args, ...);
}

auto a = f(1);           // a = 1
auto b = f("hi", 2);      // b = 2
auto c = f(0.2, 2, 'A');   // c = 'A'

参数包的展开方式

以打印参数表为例子

递归

使用递归求解时,必须要定义一个递归的终止条件的重载函数, 否则将导致无限递归

void print_v1() {
    std::cout << "\n" << std::endl;
}

template<typename T, typename... Args>
void print_v1(T t, Args... args)
{
    std::cout << t << " ";
    print_v1(args...);
}

print_v1(1, 2, 3); // 1 2 3 

无终止条件导致的无限递归错误

template<typename... Args>
void print_v1_inf(Args... args)
{
    print_v1_inf(args...);
}

print_v1_inf(1, 2, 3); // error, infinite recursion

使用 sizeof... 作为终止条件判断

这里,有人会尝试使用 sizeof...(args) 判断可变参数的个数,从而终止递归,如下:

template<typename T, typename... Args>
void print_v3(T t, Args... args)
{
    std::cout << t << " ";
    if (sizeof...(args) > 0)
        print_v3(args...);
}

print_v3(1, 2, 3); // compile error

编译器依然报错:

 <source>: In instantiation of 'void print_v3(T, Args ...) [with T = int; Args = {}]':
 <source>:128:17:   recursively required from 'void print_v3(T, Args ...) [with T = int; Args = {int}]'
 <source>:128:17:   required from 'void print_v3(T, Args ...) [with T = int; Args = {int, int}]'
 <source>:133:13:   required from here
 <source>:128:17: error: no matching function for call to 'print_v3() '
  128 |         print_v3(args...) ;
      |         ~~~~~~~~^~~~~~~~~ 
 <source>:124:6: note: candidate: 'template<class T, class ... Args> void print_v3(T, Args ...)'
  124 | void print_v3(T t, Args... args)
      |      ^~~~~~~~ 
 <source>:124:6: note:  template argument deduction/substitution failed:
 <source>:128:17: note:  candidate expects at least 1 argument, 0 provided
  128 |         print_v3(args...) ;

因为编译器在编译器会实例化可变参数模板的所有重载函数版本,包括空参数版本。而重载函数没有空参数版本,从而导致编译错误。

void print_v3<int, int, int>(int, int, int);
void print_v3<int, int>(int, int);
void print_v3<int>(int);
void print_v3(); // error, no candidate

那有没有可能通过 sizeof... 来实现呢, 答案是可以的。我们可以使用 constexpr 来限制编译器在编译期的行为,如下:

template<typename T, typename... Args>
void print_v3(T t, Args... args)
{
    std::cout << t << " ";
    if constexpr (sizeof...(args) > 0)
        print_v3(args...);
}

print_v3(1, 2, 3); // ok, 1 2 3

初始化列表

在使用递归展开式,编译器将会在编译时生成不同参数个数的重载表达式,但是有时候我们不需要那些重载版本,比如上述打印的函数,我们只关注当前传入的某个参数。因此可以借用初始化列表进行折叠展开, 如下。

其中的逗号表达式主要是为了能够使初始化列表里的函数可以被执行。根据逗号表达式规则,每次展开后按从左到右的执行顺序,返回的结果是 0, 因此 res 保存的信息是一个尺寸为参数个数(展开次数),全为 0 的一维数组。

template<typename T>
void print_v5(T t) {
    std::cout << t << " ";
}

template<typename... Args>
void print_v5(Args... args)
{
    // option 1, array
    // auto res = { (print_v5(args), 0)... };
    
    // option 2, initializer_list
    auto res = std::initializer_list<int>{ (print_v5(args), 0)... };
}

print_v5(1, 2, 3); // ok, 1 2 3

逗号表达式

除了递归和初始化列表展开可变参数列表,我们还可以通过逗号表达式来实现参数列表的展开。

template<typename... Args>
void print_v4(Args... args)
{
    ((std::cout << args << " "), ...);
}

print_v4(1, 2, 3); // ok, 1 2 3

在 C++20 中,auto 关键字在模板编程中又进一步加强,逗号表达式的展开还可以简化成:

void print_v2(auto... args){
   ((std::cout << args << " "), ...);
}

print_v2(1, 2, 3);  // ok, 1 2 3 

应用示例 (基于 auto 模板)

打印

void print(auto... args){
   ((std::cout << args << " "), ...);
}

print(1, 1.1, "abc"); // 1 1.1 abc

表达式展开后如下

void print(int a, float b, char c){
    (std::cout << a << ",", std::cout << b << ",", std::cout << c << ",");
}

求和

auto sum(auto... args){
    return (args + ...);
}

sum(1, 2, 3.3); // ok, 6.3

但是上述方法对于空表达式 sum() 存在问题,编译器将报错:

error: fold of empty expansion over operator+

对于这种情况,有如下方式解决

auto sum(auto... args){
    return (args + ...);
}

auto sum_v1(auto... args){
    auto s = 0;
    ((s += args), ...);
    return s;
}

auto sum_v2(auto... args){
    return (args + ... + 0);
}

auto sum_v3() { return 0; }

auto sum_v3(auto first, auto... args) {
    return first + sum_v2(args...);
}

int main()
{
    // std::cout << sum() << "\n"; // error
    std::cout << sum_v1() << "\n"; // ok
    std::cout << sum_v2() << "\n"; // ok
}

排序

由之前逗号运算符的折叠表达式展开的规律,既然可以每次逗号运算符的返回值,我们是否可以通过这个方法来实现排序呢?

template<typename T, typename... Args>
auto sort(const T& t, const Args&... args)
{
    T temp = t;
    std::initializer_list<T>{ (temp = std::min(temp, args))... };
    return temp;
}

int main(void){
    std::cout << sort(2, 1, 3, 5, 4) << "\n"; // 1
}

Ref