算法:如何实现带权重的随机选择?

3,599 阅读3分钟

引子:我们都知道,很多公司一般都会在年中/年终进行绩效考核,这里我们假设绩效得分从高到低分别是A、B、C、D四个等级。然后,因为绩效考核成绩直接关系到一个员工的升职加薪和年终奖发放标准(比如:A等级发放3月月薪;B等级发放2.5月月薪;C等级发放2月月薪;D等级没有年终奖,并且面临被辞退风险),所以公司高管对于绩效考核定了一个潜规则,直接暗中提前划分了每个绩效等级占有的比例(比如:A 10%;B 20%;C 65%;D 5%)。那么现在问题来了,在一个项目组内除了极个别特别优秀的员工,其他人干的活都差不多,给谁高谁低都会有意见,假如这时候你作为项目组领导,你应该如何进行绩效考核才能尽可能平衡高管和组员的意见呢?

没错,既然不好选择,那就写代码实现一个带权重的随机选择算法吧,运气好就可以抽到好绩效,运气不好就得一个低绩效。算法公开透明,绩效好坏全凭运气,想必组员们都可以接受吧,大概(◔◡◔)

算法实现

在写具体代码之前,我们可以先想一想如何实现这个带权重的随机选择。

还是以上面那个引子举例:A出现的概率为10%,B出现的概率为20%,C出现的概率为65%,D出现的概率为5%。我们我们将之换算成累进制概率值,就是下面这样:

带权重的随机选择

算法思路已经明确了,算法逻辑就好写了,下面我就给大家一份我写的算法示例,以供参考。

import org.junit.Test;

import java.math.BigDecimal;
import java.text.MessageFormat;
import java.util.Random;

/**
 * 带权重的随机选择
 *
 * @author zifangsky
 * @date 2020/5/15
 * @since 1.0.0
 */
public class Problem_004_Weight_Random {

    /**
     * 测试代码
     */
    @Test
    public void testMethods(){
        Item[] items = new Item[]{new Item("A", 0.1),
                new Item("B", 0.2),
                new Item("C", 0.65),
                new Item("D", 0.05),};

        WeightRandom weightRandom = new WeightRandom(items);
        for(int i = 0; i < 10; i++){
            System.out.println(MessageFormat.format("员工{0}的绩效得分:{1}",
                    (i + 1),weightRandom.nextItem()));
        }
    }


    /**
     * 带权重的随机选择
     */
    static class WeightRandom{
        /**
         * 选项数组
         */
        private Item[] options;

        /**
         * 权重的临界值
         */
        private BigDecimal[] criticalWeight;

        private Random rnd;

        public WeightRandom(Item[] options) {
            if(options == null || options.length < 1){
                throw new IllegalArgumentException("选项数组存在异常!");
            }
            this.options = options;
            this.rnd = new Random();
            //初始化
            this.init();
        }

        /**
         * 随机函数
         */
        public String nextItem(){
            double randomValue = this.rnd.nextDouble();
            //查找随机值所在区间
            int index = this.searchIndex(randomValue);

            return this.options[index].getName();
        }

        /**
         * 查找随机值所在区间
         */
        private int searchIndex(double randomValue){
            BigDecimal rndValue = new BigDecimal(randomValue);
            int high = this.criticalWeight.length - 1;
            int low = 0;
            int median = (high + low) / 2;

            BigDecimal medianValue = null;
            while (median != low && median != high){
                medianValue = this.criticalWeight[median];

                if(rndValue.compareTo(medianValue) == 0){
                    return median;
                }else if(rndValue.compareTo(medianValue) > 0){
                    low = median;
                    median = (high + low) / 2;
                }else{
                    high = median;
                    median = (high + low) / 2;
                }
            }

            return median;
        }

        /**
         * 初始化
         */
        private void init(){
            //总权重
            BigDecimal sumWeights = BigDecimal.ZERO;
            //权重的临界值
            this.criticalWeight = new BigDecimal[this.options.length + 1];

            //1. 计算总权重
            for(Item item : this.options){
                sumWeights = sumWeights.add(new BigDecimal(item.getWeight()));
            }

            //2. 计算每个选项的临界值
            BigDecimal tmpSum = BigDecimal.ZERO;
            this.criticalWeight[0] = tmpSum;
            for(int i = 0; i < this.options.length; i++){
                tmpSum = tmpSum.add(new BigDecimal(this.options[i].getWeight()));
                this.criticalWeight[i + 1] = tmpSum.divide(sumWeights, 2, BigDecimal.ROUND_HALF_UP);
            }
        }
    }

    /**
     * 需要随机的item
     */
    static class Item{
        /**
         * 名称
         */
        private String name;
        /**
         * 权重
         */
        private double weight;

        public Item(String name, double weight) {
            this.name = name;
            this.weight = weight;
        }

        public String getName() {
            return name;
        }

        public double getWeight() {
            return weight;
        }
    }

}

示例代码输出如下:

员工1的绩效得分:C
员工2的绩效得分:B
员工3的绩效得分:B
员工4的绩效得分:C
员工5的绩效得分:C
员工6的绩效得分:B
员工7的绩效得分:C
员工8的绩效得分:C
员工9的绩效得分:B
员工10的绩效得分:C