使用 Java 8 流计算加权平均值

2022-09-03 13:02:10

我如何计算 a 的加权平均值,其中 Integer 值是要平均的 Double 值的权重。例如:地图具有以下元素:Map<Double, Integer>

  1. (0.7, 100) // 值为 0.7,权重为 100
  2. (0.5, 200)
  3. (0.3, 300)
  4. (0.0, 400)

我希望使用Java 8流应用以下公式,但不确定如何一起计算分子和分母并同时保留它。如何在这里使用减少?

enter image description here


答案 1

您可以为此任务创建自己的收集器:

static <T> Collector<T,?,Double> averagingWeighted(ToDoubleFunction<T> valueFunction, ToIntFunction<T> weightFunction) {
    class Box {
        double num = 0;
        long denom = 0;
    }
    return Collector.of(
             Box::new,
             (b, e) -> { 
                 b.num += valueFunction.applyAsDouble(e) * weightFunction.applyAsInt(e); 
                 b.denom += weightFunction.applyAsInt(e);
             },
             (b1, b2) -> { b1.num += b2.num; b1.denom += b2.denom; return b1; },
             b -> b.num / b.denom
           );
}

此自定义收集器将两个函数作为参数:一个是返回要用于给定流元素的值(作为 ToDoubleFunction)的函数,另一个是返回权重(作为 ToIntFunction)的函数。它使用在收集过程中存储分子和分母的帮助器局部类。每次接受条目时,分子都会随着值与其权重的相乘而增加,并且分母随权重增加。然后,完成者将两者的除法作为 返回 。Double

示例用法如下:

Map<Double,Integer> map = new HashMap<>();
map.put(0.7, 100);
map.put(0.5, 200);

double weightedAverage =
  map.entrySet().stream().collect(averagingWeighted(Map.Entry::getKey, Map.Entry::getValue));

答案 2

您可以使用此过程计算地图的加权平均值。请注意,映射条目的键应包含值,映射条目的值应包含权重。

     /**
     * Calculates the weighted average of a map.
     *
     * @throws ArithmeticException If divide by zero happens
     * @param map A map of values and weights
     * @return The weighted average of the map
     */
    static Double calculateWeightedAverage(Map<Double, Integer> map) throws ArithmeticException {
        double num = 0;
        double denom = 0;
        for (Map.Entry<Double, Integer> entry : map.entrySet()) {
            num += entry.getKey() * entry.getValue();
            denom += entry.getValue();
        }

        return num / denom;
    }

您可以查看其单元测试以查看用例。

     /**
     * Tests our method to calculate the weighted average.
     */
    @Test
    public void testAveragingWeighted() {
        Map<Double, Integer> map = new HashMap<>();
        map.put(0.7, 100);
        map.put(0.5, 200);
        Double weightedAverage = calculateWeightedAverage(map);
        Assert.assertTrue(weightedAverage.equals(0.5666666666666667));
    }

单元测试需要以下导入:

import org.junit.Assert;
import org.junit.Test;

您需要为代码导入以下内容:

import java.util.HashMap;
import java.util.Map;

我希望它有帮助。


推荐