你会使用JAVA流进行分组和聚合吗

243 阅读8分钟

学习使用Java Streams解决问题的简单方法,Java Streams是一个框架,它允许我们快速有效地处理大量数据。

当我们对列表中的元素进行分组时,我们可以随后聚合分组元素的字段,以执行有意义的操作,帮助我们分析数据。一些例子是加法、平均值或最大/最小值。使用Java流和收集器可以很容易地完成单个字段的聚合。文档提供了如何进行这些类型计算的简单示例。

然而,还有更复杂的聚合,如加权平均值、几何平均值。此外,可能需要同时聚合几个字段。在本文中,我们将展示一种使用Java流解决这类问题的简单方法。使用这个框架,我们可以快速有效地处理大量数据。

我们假设读者对以下内容有基本的了解Java流和效用收集者上课。

问题布局

让我们考虑一个简单的例子来展示我们想要解决的问题类型。我们将使它非常通用,这样我们可以很容易地概括它。让我们考虑一系列TaxEntry由以下代码定义的实体:

public class TaxEntry {

    private String state;
    private String city;
    private int numEntries;
    private double price;
    //Constructors, getters, hashCode, equals etc
}
```

计算给定城市的条数总数非常简单


````
Map<String, Integer> totalNumEntriesByCity = 
              taxes.stream().collect(Collectors.groupingBy(TaxEntry::getCity, 
                                                           Collectors.summingInt(TaxEntry::getNumEntries)));
````
`Collectors.groupingBy`接受两个参数:一个分类器函数进行分组,一个收集器对属于给定组的所有元素进行下游聚合。我们使用`TaxEntry::getCity`作为分类器功能。对于下游,我们使用`Collectors::summingInt`它返回一个`Collector`它合计了我们从每个分组元素中获得的税收条目的数量。

如果我们试图寻找复合分组,事情会稍微复杂一些。例如,在前面的示例中,给定状态的条目总数**和**城市。有几种方法可以做到这一点,但一种非常直接的方法是首先定义:

````
record StateCityGroup(String state, String city) {}
````

请注意,我们使用的是Java`record`,这是一种定义不可变类的简洁方法。此外,Java编译器为我们生成字段访问器方法,`hashCode`,等于,和`toString`实现。有了这些,现在的解决方案很简单:


````
Map<StateCityGroup, Integer> totalNumEntriesForStateCity = 
                    taxes.stream().collect(groupingBy(p -> new StateCityGroup(p.getState(), p.getCity()), 
                                                      Collectors.summingInt(TaxEntrySimple::getNumEntries))
                                          );
````

为`Collectors::groupingBy`我们使用lambda表达式设置分类器函数,该表达式创建一个新的`StateCityGroup`封装每个州-城市的记录。下游收集器与之前相同。

**注意:** 为了简明起见,在代码示例中,我们将假设Collectors类的所有方法都是静态导入的,因此我们不必显示它们的类限定。

事情开始变得更复杂的地方是,如果我们想同时做几个聚合。例如,查找给定州和城市的条目数和平均价格的总和。该库没有提供这个问题的简单解决方案。

为了解决这个问题,我们从前面的聚合中得到启示,定义一个记录,该记录封装了所有需要聚合的字段:


````
record TaxEntryAggregation (int totalNumEntries, double averagePrice ) {}
````

现在,我们如何同时对这两个字段进行聚合呢?总有可能执行两次流收集来分别查找每个聚合,如下面的代码所示:

````
Map<StateCityGroup, TaxEntryAggregation> aggregationByStateCity = taxes.stream().collect(
           groupingBy(p -> new StateCityGroup(p.getState(), p.getCity()),
                      collectingAndThen(Collectors.toList(), 
                                        list -> {int entries = list.stream().collect(
                                                                   summingInt(TaxEntrySimple::getNumEntries));
                                                 double priceAverage = list.stream().collect(
                                                                   averagingDouble(TaxEntrySimple::getPrice));
                                                 return new TaxEntryAggregation(entries, priceAverage);})));
````

分组如前所述,但是对于下游,我们使用`Collectors::collectingAndThen`(第3行)。这个函数有两个参数:

-   来自初始分组的下载流,我们将其转换成一个列表(使用`Collectors::toList()`在第3行)
-   Finisher函数(第4–9行),其中我们使用一个lambda表达式从前面的列表中创建两个不同的流,进行聚合并将它们组合在一个新的`TaxEntryAggregation`记录

假设我们想同时进行更多的字段聚合。我们将需要相应地增加来自下游列表的流的数量。代码变得低效、重复、不尽如人意。我们应该寻找更好的选择。

此外,问题并没有到此结束,一般来说,我们受限于可以用Collectors helper类完成的聚合类型。他们的方法summing *、average *和summaring *仅支持整型、长整型和双精度本机类型。如果我们有更复杂的类型,比如`BigInteger`或者`BigDecimal`? 

雪上加霜的是,summarizing *方法只提供了min、max、count、sum和average的汇总统计数据。如果我们想要执行更复杂的计算,如加权平均或几何平均,该怎么办?

有些人会争辩说,我们总是可以编写自定义收集器,但这需要了解收集器接口,并对流收集器流有很好的理解。使用collectors类中的utility方法提供的内置收集器更简单。在下一节中,我们将展示几个如何实现这一点的策略。

## 复杂的多重聚合:一种解决途径

让我们考虑一个简单的例子,它将突出我们在上一节中提到的挑战。假设我们有以下实体:

````
public class TaxEntry {
    private String state;
    private String city;
    private BigDecimal rate;
    private BigDecimal price;
    record StateCityGroup(String state, String city) {
    }
    //Constructors, getters, hashCode/equals etc
}
````

我们首先问,对于每个不同的州-城市对,我们如何找到条目的总数和`rate`和`price`(∑(费率*价格))。请注意,我们正在使用`BigDecimal`.

正如我们在上一节中所做的那样,我们定义了一个封装聚合的类:

````
record RatePriceAggregation(int count, BigDecimal ratePrice) {}
````

乍一看,这似乎令人惊讶,但是对于简单聚合后的分组,一个简单的解决方案是使用`Collectors::toMap`。让我们看看我们会怎么做:

````
Map<StateCityGroup, RatePriceAggregation> mapAggregation = taxes.stream().collect(
      toMap(p -> new StateCityGroup(p.getState(), p.getCity()), 
            p -> new RatePriceAggregation(1, p.getRate().multiply(p.getPrice())), 
            (u1,u2) -> new RatePriceAggregation( u1.count() + u2.count(), u1.ratePrice().add(u2.ratePrice()))
            ));
````

这`Collectors::toMap`(第2行)采用三个参数,我们执行以下实现:

-   第一个参数是生成映射键的lambda表达式。该功能创建`StateCityGroup`作为地图的钥匙。这将按照州和城市对元素进行分组(第2行)。
-   第二个参数产生地图的值。在我们的例子中,我们创建了一个`RatePriceAggregation`用计数1和rate与price的乘积初始化(第3行)。
-   最后,最后一个参数是`BinaryOperator`合并多个元素映射到同一个州/市键的情况。我们将计数和价格相加来进行聚合(第4行)。

让我们演示如何设置一些示例数据:

````
List<TaxEntry> taxes = Arrays.asList(
                          new TaxEntry("New York", "NYC", BigDecimal.valueOf(0.2), BigDecimal.valueOf(20.0)), 
                          new TaxEntry("New York", "NYC", BigDecimal.valueOf(0.4), BigDecimal.valueOf(10.0)), 
                          new TaxEntry("New York", "NYC", BigDecimal.valueOf(0.6), BigDecimal.valueOf(10.0)), 
                          new TaxEntry("Florida", "Orlando", BigDecimal.valueOf(0.3), BigDecimal.valueOf(13.0)));
                          
                         
````
从前面的代码示例中获得纽约的结果很简单:

````
System.out.println("New York: " + mapAggregation.get(new StateCityGroup("New York", "NYC")));
````

这将打印:

````
New York: RatePriceAggregation[count=3, ratePrice=14.00]
````
`这是确定多个字段和非原始数据类型(`BigDecimal`在我们的情况下)。但是,它有一个缺点,即它没有任何允许您执行额外操作的终结器。例如,你不能做任何形式的平均。

为了展示这个问题,让我们考虑一个更复杂的问题。假设我们想要找到费率-价格的加权平均值,以及每个州和城市对的所有价格的总和。具体来说,要找到加权平均值,我们需要计算属于每个州-城市对的所有条目的费率和价格的乘积之和,然后除以每种情况下的条目总数n:1/n∑(费率*价格)。

为了解决这个问题,我们开始定义一个包含聚合的记录:`

````
record TaxEntryAggregation(int count, BigDecimal weightedAveragePrice, BigDecimal totalPrice) {}
````

`有了这些,我们可以进行以下实现:`

````
Map<StateCityGroup, TaxEntryAggregation> groupByAggregation = taxes.stream().collect(
    groupingBy(p -> new StateCityGroup(p.getState(), p.getCity()), 
               mapping(p -> new TaxEntryAggregation(1, p.getRate().multiply(p.getPrice()), p.getPrice()), 
                       collectingAndThen(reducing(new TaxEntryAggregation(0, BigDecimal.ZERO, BigDecimal.ZERO),
                                                  (u1,u2) -> new TaxEntryAggregation(u1.count() + u2.count(),
                                                      u1.weightedAveragePrice().add(u2.weightedAveragePrice()), 
                                                      u1.totalPrice().add(u2.totalPrice()))
                                                  ),
                                         u -> new TaxEntryAggregation(u.count(), 
                                                 u.weightedAveragePrice().divide(BigDecimal.valueOf(u.count()),
                                                                                 2, RoundingMode.HALF_DOWN), 
                                                 u.totalPrice())
                                         )
                      )
              ));
````

我们可以看到代码有点复杂,但允许我们得到我们正在寻找的解决方案。我们将更详细地跟踪它:

-   `Collectors::groupingBy`(第二行):

    1.  对于分类函数,我们创建了一个`StateCityGroup `记录

    1.  对于下游,我们调用`Collectors::mapping`(第3行):

        -   对于第一个参数,我们应用于输入元素的映射器将分组的州-城市税收记录转换为新的`TaxEntryAggregation`将初始计数指定为1,用价格乘以比率,然后设置价格的条目(第3行)。

        -   对于下游,我们调用`Collectors::collectingAndThen`(第4行),正如我们将看到的,这将允许我们对下游收集器应用最终转换。

            -   引起`Collectors::reducing`(第4行)

                1.  创建默认值`TaxEntryAggregation `以涵盖没有下游元素的情况(第4行)。
                1.  Lambda表达式进行约简,并返回一个新的`TaxEntryAggregation`它包含字段的聚合(第5行,6 7行)

            -   执行精加工转换,使用前一次缩减中计算的计数计算平均值,并返回最终值`TaxEntryAggregation`(第9、10、11行)。

我们看到,这种实现不仅允许我们同时进行多个字段聚合,还可以在几个阶段执行复杂的计算。

这可以很容易地推广到解决更复杂的问题。方法很简单:定义一条记录,封装所有需要聚合的字段,使用`Collectors::mapping`初始化记录,然后应用`Collectors::collectingAndThen`进行归约和最终汇总。

像以前一样,我们可以得到纽约的汇总:

````
System.out.println("Finished aggregation: " + groupByAggregation.get(new StateCityGroup("New York", "NYC")));
````
```
我们得到了结果:

````
Finished aggregation: TaxEntryAggregation[count=3, weightedAveragePrice=4.67, totalPrice=40.0]
````

还值得指出的是,因为`TaxEntryAggregation`Java`record`,它是不可变的,因此可以使用流收集器的库提供的支持来并行化计算。

## 结论

我们已经展示了几个策略,通过聚合进行复杂的多字段分组,这些聚合包括具有多字段和跨字段计算的非原始数据类型。这是使用Java streams和Collectors API的记录列表,因此它为我们提供了快速有效地处理大量数据的能力。