(叨逼叨)基于Flink实现粒子群算法

824 阅读8分钟

@[toc]

粒子群算法简介

目的

解决对目标函数的规划问题,例如 求取函数 F(x1,x2,x3) 中 X在某一个范围内函数的最小值,或者最大值。

原理

粒子群算法通过模拟鸟类种群的迁移,自适应地去找出在迭代范围内符合要求的位置,体现在函数里面就是,通过粒子群算法可以在一定的精度内找出在多维变量X的取值范围内,满足函数 F 可取得最大值/最小值的 X 向量,也就是x1 x2 x3...的取值。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-sTtUhk6c-1635777673284)(C:\Users\31395\AppData\Roaming\Typora\typora-user-images\image-20211101164412441.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-5790SI4I-1635777673286)(C:\Users\31395\AppData\Roaming\Typora\typora-user-images\image-20211101163147849.png)]

本文目标

基于Flink 分布式流处理引擎实现对粒子群算法的实现(默认求取目标函数最小值)

基本环境搭建

Java 1.8 (yyds)

maven 3.6x

Flink 1.10.1

网上有无教程,不知道,反正我没找到!

配置如下:



<dependencies>
    <dependency>
        <groupId>org.apache.flink</groupId>
        <artifactId>flink-streaming-java_2.12</artifactId>
        <version>1.10.1</version>
    </dependency>

    <dependency>
        <groupId>org.apache.flink</groupId>
        <artifactId>flink-java</artifactId>
        <version>1.10.1</version>
    </dependency>

    <dependency>
        <groupId>org.slf4j</groupId>
        <artifactId>slf4j-log4j12</artifactId>
        <version>1.7.21</version>
        <scope>test</scope>
    </dependency>
    
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <version>1.18.4</version>
        <scope>provided</scope>
    </dependency>


    <dependency>
        <groupId>log4j</groupId>
        <artifactId>log4j</artifactId>
        <version>1.2.17</version>
    </dependency>

    <dependency>
        <groupId> org.apache.cassandra</groupId>
        <artifactId>cassandra-all</artifactId>
        <version>0.8.1</version>

        <exclusions>
            <exclusion>
                <groupId>org.slf4j</groupId>
                <artifactId>slf4j-log4j12</artifactId>
            </exclusion>
            <exclusion>
                <groupId>log4j</groupId>
                <artifactId>log4j</artifactId>
            </exclusion>
        </exclusions>

    </dependency>

</dependencies>


项目结构V0.5测试版

V0.5测试版本:具备基本功能,暂时未开放使用可复用的接口。(不是不能用,只是没有工程化)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-jBvGokDT-1635777673288)(C:\Users\31395\AppData\Roaming\Typora\typora-user-images\image-20211101213631649.png)]

整个项目的结构一目了然。首先Test包,这个老规矩了,是个测试包,没啥用,忽略即可!

设计思想

由于Flink是一个流处理引擎所以我们这边的操作就很好办了,我们可以直接摒弃以前使用矩阵的方式来进行运算。而是采取更加简便和直观的方式来进行表示和运算,那就是我们直接对鸟类进行模拟。我们定义一个鸟类,之后使用这玩意在我们的流里面不断地进行处理!

并且直接使用Bird类的好处就是我们可以直接使用Bird记录自己的最佳位置,也就是个体最优,这样一来就不需要系统再对个体最优位置进行整体的计算了。Bird类的定义

package com.java.PSO.StreamPso;

import com.java.PSO.ConfigPso.ConfigPso;
import jdk.nashorn.internal.objects.annotations.Constructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;

import java.lang.reflect.Array;
import java.util.ArrayList;

@Data
@ToString
@NoArgsConstructor
public class Bird implements Cloneable {
    //大鸟的编号
    private Integer id;
    private ArrayList<Double> Pbest;
    private ArrayList<Double> Gbest;

    private Double Functionresult;

    private Double LFunctionresult;

    private ArrayList<Double> Xpostion;
    private ArrayList<Double> Vpersent;
    private Integer InterTimes;


    public Bird(Integer id, ArrayList<Double> pbest, ArrayList<Double> gbest, Double functionresult, Double LFunctionresult, ArrayList<Double> xpostion, ArrayList<Double> vpersent, Integer interTimes) {
        this.id = id;
        this.Pbest = pbest;
        this.Gbest = gbest;
        this.Functionresult = functionresult;
        this.LFunctionresult = LFunctionresult;
        this.InterTimes = interTimes;
        this.setXpostion(xpostion);
        this.setVpersent(vpersent);
    }

    public void setXpostion(ArrayList<Double> xpostion) {
        //越界处理
        int index = 0;
        for (Double aDouble : xpostion) {
            if(aDouble > ConfigPso.X_up)
                xpostion.set(index,ConfigPso.X_up);
            else if (aDouble < ConfigPso.X_down)
                xpostion.set(index,ConfigPso.X_down);
            index++;
        }

        Xpostion = xpostion;
    }

    public void setVpersent(ArrayList<Double> vpersent) {
        int index = 0;
        for (Double aDouble : vpersent) {
            if(aDouble > ConfigPso.V_max)
                vpersent.set(index,ConfigPso.V_max);
            else if (aDouble < ConfigPso.V_min)
                vpersent.set(index,ConfigPso.V_min);
            index++;
        }
        Vpersent = vpersent;
    }

    @Override
    protected Object clone() throws CloneNotSupportedException {
        return super.clone();
    }
}

ConfigPso

ConfigPso这个毫无疑问就是配置选项。

配置的需求说明如下:

package com.java.PSO.ConfigPso;


public  class ConfigPso {
    //关于粒子群算法的相关参数设置
    /**
     *X(i+1) = X(i) + V(i+1)
     * V(i+1) = w*V(i) +c1*r1*(Pbest-X(i)) + c2*r2*(Gbest-X(i))
     * r1,r2为随机数【0,1】这边不设置
     */


    public static final Double C1 = 2.0;
    public static final Double C2 = 2.0;
    public static final Double w = 0.4;

    public static final Double X_down = -2.0;
    public static final Double X_up = 2.0;

    public static final Double V_min = -4.0;
    public static final Double V_max = 4.0;

    public static final Integer PopulationNumber = 2; //种群个数
    public static final Integer IterationsNumber = 20;//迭代次数不能为0

    public static final Integer ParamesNumber = 1;
    

}

这个对标我们前面的数学公式。

Function

这个包适用于存放目标函数的,例如我们需要优化

F(x) = x^2 (X为一维矩阵,矩阵长度表示维度)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-AMRU9Rup-1635777673290)(C:\Users\31395\AppData\Roaming\Typora\typora-user-images\image-20211101214152728.png)]

FunctionMake 就是一个工厂类

演示代码从上到下如下:

package com.java.PSO.Function.FunctionImp;

import java.util.ArrayList;

public interface FunctionsImpl {
     Double FourFunction(ArrayList<Double> parames);
}




package com.java.PSO.Function;

import com.java.PSO.Function.FunctionImp.FunctionsImpl;

import java.util.ArrayList;

public class FunctionMake {
    static FunctionsImpl functions=new Functions();
    public static Double FourFunction(ArrayList<Double> List){
        Double rest = functions.FourFunction(List);
        return rest;

    }
}




package com.java.PSO.Function;

import com.java.PSO.Function.FunctionImp.FunctionsImpl;

import java.util.ArrayList;

public class Functions implements FunctionsImpl {

    @Override
    public Double FourFunction(ArrayList<Double> parames) {

        //测试函数,寻找最小值,x 假设都在 [5,-5] vmax = [-10,10] w=0.4 c1=c2=2默认初始
        Double res = 0.0;
        int index = 0;
        for (Object parame : parames) {
            res = res + Math.pow((Double) parames.get(index),2);
            index ++;
        }
        return res;

    }
}

注意这里面有很多方法都是静态的,原因很简单后面调用需要用到,而且为了方便调用也是使用静态方法好一点。

StreamPso

这个下面有一个子包

Core

这个包里面存放的就是我们这个算法的核心,也就是

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-jhGyYVAq-1635777673292)(C:\Users\31395\AppData\Roaming\Typora\typora-user-images\image-20211101163147849.

不过这里实现起来还是简单的。毕竟能够想出来的玩意其实都是简单的,只是信息差导致你以为很难。

package com.java.PSO.StreamPso.Core;
import com.java.PSO.ConfigPso.ConfigPso;
import com.java.PSO.StreamPso.Bird;

import java.util.ArrayList;
import java.util.Random;

public class Core {

    static Random random = new Random();

    public static ArrayList<Double> UpdateSpeed(Bird bird){
        ArrayList<Double> CurrentSpeed = bird.getVpersent();
        //更新速度,传入大鸟,会自动更新大鸟的速度,同时返回更新后的速度向量
        Double fai1 = ConfigPso.C1 * random.nextDouble(); //c1*r1
        Double fai2 = ConfigPso.C2 * random.nextDouble(); //c2*r2
        int index = 0;
        for (Double aDouble : CurrentSpeed) {

            aDouble = ConfigPso.w * aDouble + fai1*(bird.getPbest().get(index) - bird.getXpostion().get(index))
                    + fai2*(bird.getGbest().get(index) - bird.getXpostion().get(index));
            CurrentSpeed.set(index,aDouble);

            index ++ ;

        }
        //完成对速度的更新
        bird.setVpersent(CurrentSpeed);

        return CurrentSpeed;
    }

    public static ArrayList<Double> UpdatePosition(Bird bird){
        //更新位置,传入大鸟,会自动更新大鸟的位置,同时返回更新后的位置的向量
        int index = 0;
        ArrayList<Double> CurrentXposition = bird.getXpostion();
        for (Double aDouble : CurrentXposition) {

//            System.out.println(aDouble+"<--->"+bird.getVpersent().get(index));
            aDouble = aDouble+bird.getVpersent().get(index);
            CurrentXposition.set(index,aDouble);
            index++;
        }
        //完成对位置的更新

        bird.setXpostion(CurrentXposition);

        return CurrentXposition;
    }

    public static Bird UpDataBird(Bird bird){
        //返回Bird,负责对前面的方法进行调度。只需要调用这一个方法就可以实现位置和速度更新
        //先更新速度然后才能够更新位置
        //由于每一个个体过来都会需要执行一下算子,所以每一次在执行的时候fai1,fai2都是不同的
        //也就是每一个在每一轮当中的fai都是不同的,有可能会提高拟真度。
        UpdateSpeed(bird);
        UpdatePosition(bird);
        return bird;
    }

}

Dostream

在这前面还有一个BirdFactory这个是一个鸟类的工厂类,也就是产生鸟类,值得一提的是这里使用的是clone。但这个不是重点实现起来也很简单,所以这里就不展示了,权当代码补全留给各位读者。

那么关于这个里面涉及到了很多的基本算子,同时如何实现个体与全局的最优排序。

个体最优

这个实现起来就是比较简单的,由于Flink是流处理,所以我们来一个处理一个只需要记录前面的状态就好了。

    static class MinMapsP implements MapFunction<Bird,Bird>{
        @Override
        public Bird map(Bird bird) throws Exception{
            //此时状态由Bird自己进行管理,Lfunctionresult记录的就是t-1次的个体最优的值,我们这边是找最小的的函数值
         if(bird.getFunctionresult()<bird.getLFunctionresult()){
             bird.setPbest(bird.getXpostion());
             //更新最优值
             bird.setLFunctionresult(bird.getFunctionresult());
         }
            return bird;
        }
    }

全局最优

这个的话我们需要一个全局状态记录

也就是这个

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-cu08xNuh-1635777673293)(C:\Users\31395\AppData\Roaming\Typora\typora-user-images\image-20211101215825609.png)]

之后进行状态记录

    static class MinMapsG implements MapFunction<Bird,Bird>{
        //这个是通用的不存在初始化例外使用的情况

        @Override
        public Bird map(Bird bird) throws Exception {
            //状态流,状态由系统记录
            if(bird1!=null){
                if( bird.getFunctionresult()> bird1.getFunctionresult())
                    bird.setGbest(bird1.getXpostion());
                else {
                    bird.setGbest(bird.getXpostion());
                    bird1=bird;
                }
            }
            else{
                bird1 = bird;
                bird.setGbest(bird.getXpostion());
            }
            return bird;
        }


    }


但是这里有个问题我想你也注意到了,我们求到了全局最优,但是我们还需要对每一个Bird的进行记录,也就是告诉Bird谁是最Best的(全局)之后进入计算,当然你也可以选择直接使用bird1原因就是bird1记录的就是全局最优Pbest的个体(具有Pbest但是不代表它是Best)不过虽然这个是个方案,但是Flink是个多线程的,so You Know 这个方案直接实施时不行的,你还是需要在算子里去做,但是这里我直接复用MinMapsG。

全局调用代码

这个就是Dostream里面的代码也是主要的代码。

package com.java.PSO.StreamPso;

import com.java.PSO.ConfigPso.ConfigPso;
import com.java.PSO.Function.FunctionImp.FunctionsImpl;
import com.java.PSO.Function.FunctionMake;
import com.java.PSO.Function.Functions;
import com.java.PSO.StreamPso.Core.Core;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.streaming.api.datastream.*;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.co.CoFlatMapFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.util.Collector;
import org.junit.Test;

import java.util.ArrayList;
import java.util.Random;

public class DoSteam {

    static Bird bird1;

    public static void main(String[] args) throws Exception {

        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        DataStreamSource<Bird> BirdInitStream = env.addSource(new InitBirds());
        KeyedStream<Bird, Integer> Birdtimeing = BirdInitStream.keyBy(Bird::getInterTimes);

        //进行初始化,获取全局最优,获取全局最优需要调用两次MinMapsG这个算子
        //由于是基于流处理,不使用窗口所以必须使用状态流进行全局最优筛选,第一次调用只是选择出全局最优
        //第二次调用是为了给所有的个体赋值,和个体的最优处理
        SingleOutputStreamOperator<Bird> map = Birdtimeing.map(new MinMapsG());
        KeyedStream<Bird, Tuple> id = map.keyBy("id");
        SingleOutputStreamOperator<Bird> map1 = id.map(new MinMapsG());
        SingleOutputStreamOperator<Bird> RealStream = map1.map(new MinMapsPinitial());
//        RealStream.print("init");

        //完成初始化后的数据流,到这里开始进行循环
        IterativeStream<Bird> iterateStream = RealStream.iterate();
        SingleOutputStreamOperator<Bird> IterationBody = iterateStream.keyBy(Bird::getInterTimes) //分组
                .map(new MinMapsG()) //首次寻早最优解
                .keyBy("id") //再次分组两个原因
                .map(new MinMapsG()) // 再次统计最优解,为全局的位置最优解
                .map(new MinMapsP())// 循环处理当中的个体最优解决
                .map(new CalculationPso());//这一步是进行粒子群的运算,也是比较重要的一环

        //需要进入循环的条件
        SingleOutputStreamOperator<Bird> IterationFlag = IterationBody.filter(new FilterFunction<Bird>() {
            @Override
            public boolean filter(Bird bird) throws Exception {
                return bird.getInterTimes() < ConfigPso.IterationsNumber;
            }
        });

        iterateStream.closeWith(IterationFlag);

        SingleOutputStreamOperator<Bird> Outstream = IterationBody.filter(new FilterFunction<Bird>() {
            @Override
            public boolean filter(Bird bird) throws Exception {
                return bird.getInterTimes() >= ConfigPso.IterationsNumber;
            }
        });
//        Outstream.print("1-->");
        //到这一步的话我们的程序已经进行了最后一次的运行,但是此时的是没有进行排序的,所以需要进行最后一次排序
        //这里由于只输出一个,所以这里打算直接开个技术窗口,然后输出最值!
        SingleOutputStreamOperator<Bird> MinBrid = Outstream.countWindowAll(ConfigPso.PopulationNumber).min("Functionresult");
        MinBrid.print("The best bird");


        env.execute();

    }


    static class CalculationPso implements MapFunction<Bird,Bird>{

        @Override
        public Bird map(Bird bird) throws Exception {
            /**
             * @Huterox
             * @Time:2021-11-1
             * 目标,通过Core的Update实现对大鸟(粒子)的位置和速度更新
             * 之后通过更新后的位置计算出目标函数的值,进行设置,前面的算子再进行一个新的轮回
             * 更新粒子迭代次数
             */

            Core.UpDataBird(bird);
            bird.setFunctionresult(FunctionMake.FourFunction(bird.getXpostion()));


            bird.setInterTimes(bird.getInterTimes()+1);
            return bird;
        }
    }


    static class MinMapsP implements MapFunction<Bird,Bird>{
        @Override
        public Bird map(Bird bird) throws Exception{
            //此时状态由Bird自己进行管理,Lfunctionresult记录的就是t-1次的个体最优的值,我们这边是找最小的的函数值
         if(bird.getFunctionresult()<bird.getLFunctionresult()){
             bird.setPbest(bird.getXpostion());
             //更新最优值
             bird.setLFunctionresult(bird.getFunctionresult());
         }
            return bird;
        }
    }


    static class MinMapsPinitial implements MapFunction<Bird,Bird>{

        // 计算个体最优的都是无序的数据流,系统不好记录同时为了性能,所以个体状态由个体自己记录
        @Override
        public Bird map(Bird bird) throws Exception {

            //本次进行初始化
            //为了减少条件判读,所以直接把个体最优的算子进行拆分
            bird.setPbest(bird.getXpostion());
            bird.setLFunctionresult(bird.getFunctionresult());

            return  bird;
        }


    }


    static class MinMapsG implements MapFunction<Bird,Bird>{
        //这个是通用的不存在初始化例外使用的情况

        @Override
        public Bird map(Bird bird) throws Exception {
            //状态流,状态由系统记录
            if(bird1!=null){
                if( bird.getFunctionresult()> bird1.getFunctionresult())
                    bird.setGbest(bird1.getXpostion());
                else {
                    bird.setGbest(bird.getXpostion());
                    bird1=bird;
                }
            }
            else{
                bird1 = bird;
                bird.setGbest(bird.getXpostion());
            }
            return bird;
        }


    }


    static class InitBirds implements SourceFunction<Bird>{


        @Override
        public void run(SourceContext<Bird> ctx) throws Exception {

            for(int i=1;i<=ConfigPso.PopulationNumber; i++) {
                Bird bird = BirdFactory.MakeBird(i);

                Double functionresult = FunctionMake.FourFunction(bird.getXpostion());
                bird.setFunctionresult(functionresult);

                bird.setInterTimes(0);//表示正在初始化
                ctx.collect(bird);
            }

        }

        @Override
        public void cancel() {

        }
    }
}


测试

我们直接拿到前面的那个配置文件测试也就时F(x) = x^2

这个函数的最小值,迭代20次看结果

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-g7rGmB8z-1635777673293)(C:\Users\31395\AppData\Roaming\Typora\typora-user-images\image-20211101220547535.png)]

可以看到20次后出现了不错的效果,数值逼近0

由于其他的函数测试需要调节相关的配置参数,这里不作演示了。