Learning Flink + Kafka from the Official Docs: Implementing Fraud Detection

Requirements

First we need a continuous data stream that generates lots of random transactions. We then scan these transactions for the following patterns:

  • Basic: for any two consecutive transactions on the same account, if the first amount is very small and the second is very large, we treat it as fraud;
  • Advanced: on top of the above, add a time window: if the interval between the two transactions exceeds a threshold, it is no longer treated as fraud.

Based on these requirements, our design has several parts:

  • A data producer that generates records and pushes them to Kafka;
  • A data consumer that reads from Kafka with Flink and performs detection and logging;
  • A main program that starts the two tasks, one producing and one consuming.

Preparation

ZooKeeper and Kafka can be downloaded from their official sites; after downloading, configure and start them. The configuration steps are omitted here. Instead, here is a start/stop script (say we save it as zk_kafka_ops.sh):

#!/bin/bash

KAFKA_HOME="/Users/liwan/software/kafka_2.13-2.6.0"
ZOOKEEPER_HOME="/Users/liwan/software/apache-zookeeper-3.6.2-bin"

start() {
  echo "starting..."
  cur_path=$(pwd)
  # start zookeeper
  cd $ZOOKEEPER_HOME
  cmd1="bin/zkServer.sh start"
  eval $cmd1

  # start kafka
  cd $KAFKA_HOME
  cmd2="bin/kafka-server-start.sh config/server.properties"
  eval $cmd2
  cd $cur_path
}

stop() {
  echo "stoping..."
  cur_path=$(pwd)
  # stop kafka
  cd $KAFKA_HOME
  cmd2="bin/kafka-server-stop.sh"
  eval $cmd2

  # stop zookeeper
  cd $ZOOKEEPER_HOME
  cmd1="bin/zkServer.sh stop"
  eval $cmd1
  cd $cur_path
}

if [ "$1" = "stop" ]; then
  stop
else
  start
fi

Now ./zk_kafka_ops.sh start and ./zk_kafka_ops.sh stop will start and stop the ZooKeeper and Kafka services respectively.
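The demo below writes to a topic named test_transaction_topic. Depending on the broker's auto-create setting the topic may be created on first use; if not, a command along these lines creates it manually (the partition and replication counts are just this demo's assumptions):

bin/kafka-topics.sh --create --topic test_transaction_topic \
  --bootstrap-server localhost:9092 --partitions 1 --replication-factor 1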

Required Maven Dependencies

<dependencies>
    <dependency>
        <groupId>org.apache.kafka</groupId>
        <artifactId>kafka-clients</artifactId>
        <version>2.6.0</version>
    </dependency>
    <dependency>
        <groupId>com.fasterxml.jackson.core</groupId>
        <artifactId>jackson-databind</artifactId>
        <version>2.9.5</version>
    </dependency>
    <dependency>
        <groupId>org.apache.flink</groupId>
        <!-- the universal Kafka connector, which matches the 2.x broker and provides the FlinkKafkaConsumer used below -->
        <artifactId>flink-connector-kafka_2.12</artifactId>
        <version>1.11.2</version>
    </dependency>
    <dependency>
        <groupId>org.apache.flink</groupId>
        <artifactId>flink-streaming-java_2.12</artifactId>
        <version>1.11.2</version>
        <scope>provided</scope>
    </dependency>
    <dependency>
        <groupId>org.apache.flink</groupId>
        <artifactId>flink-clients_2.12</artifactId>
        <version>1.11.2</version>
    </dependency>
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <version>1.18.16</version>
        <scope>provided</scope>
    </dependency>
    <dependency>
        <groupId>com.alibaba</groupId>
        <artifactId>fastjson</artifactId>
        <version>1.2.75</version>
    </dependency>
</dependencies>

Defining the Transaction Record Entity: TransactionRecord

This entity class is used on both the producing and the consuming side, together with its serialization and deserialization classes:

package com.wanli.fraud;

import lombok.Getter;
import lombok.Setter;

/**
 * Entity class shared by the producer and the consumer
 *
 * @author liwan
 * @version 1.0.0
 * @since 2020/11/19 13:49
 */
@Getter
@Setter
public class TransactionRecord {
  private long accountId;
  private long transactionTime = System.currentTimeMillis();
  private double amount;
  private String accountName;
}
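For reference, fastjson (used below) turns such a record into a small JSON object on the wire; a produced message might look like this (the values are invented for illustration):

{"accountId":3,"accountName":"u3","amount":612.37,"transactionTime":1605764362000}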

The serializer/deserializer class TransactionRecordSerializer is defined as follows:

package com.wanli.fraud;

import com.alibaba.fastjson.JSON;
import org.apache.flink.api.common.serialization.DeserializationSchema;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.kafka.common.serialization.Serializer;

import java.io.IOException;
import java.nio.charset.StandardCharsets;

/**
 * Both the serialization and the deserialization methods are defined here
 *
 * @author liwan
 * @version 1.0.0
 * @since 2020/11/19 14:04
 */
public class TransactionRecordSerializer implements DeserializationSchema<TransactionRecord>,
    Serializer<TransactionRecord> {

  @Override
  public void open(InitializationContext context) throws Exception {
  }

  @Override
  public TransactionRecord deserialize(byte[] message) throws IOException {
    String s = new String(message, StandardCharsets.UTF_8);
    TransactionRecord record = JSON.parseObject(s, TransactionRecord.class);

    // System.out.println("----- deserialize, userid=" + record.getAccountId());

    return record;
  }

  /**
   * Called after each record has been deserialized, to decide whether the stream has ended
   *
   * @param nextElement
   * @return
   */
  @Override
  public boolean isEndOfStream(TransactionRecord nextElement) {
    // System.out.println("----- detect if end, userid=" + nextElement.getAccountId());
    return false;
  }

  @Override
  public TypeInformation<TransactionRecord> getProducedType() {
    return TypeInformation.of(TransactionRecord.class);
  }

  @Override
  public byte[] serialize(String topic, TransactionRecord data) {
    String s = JSON.toJSONString(data);
    return s.getBytes(StandardCharsets.UTF_8);
  }
}

Producing Data with Kafka

First, let's look at the Kafka configuration:

package com.wanli.fraud;

/**
 * Kafka configuration
 *
 * @author liwan
 * @version 1.0.0
 * @since 2020/11/19 14:19
 */
public class KafkaConfiguration {
  public static final String KAFKA_TOPIC = "test_transaction_topic";
  public static final String KAFKA_ADDRESS = "localhost:9092";
}

All Kafka-related configuration used later is read from this class.

The Kafka Data Producer: TransactionDataProducer

package com.wanli.fraud;

import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.Producer;
import org.apache.kafka.clients.producer.ProducerConfig;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.clients.producer.RecordMetadata;
import org.apache.kafka.common.serialization.StringSerializer;

import java.util.Properties;
import java.util.Random;
import java.util.concurrent.Future;

/**
 * Producer that generates records and pushes them to Kafka
 *
 * @author liwan
 * @version 1.0.0
 * @since 2020/11/19 14:22
 */
public class TransactionDataProducer implements Runnable {
  private Producer<String, TransactionRecord> producer;

  public TransactionDataProducer() {
    Properties config = new Properties();
    config.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, KafkaConfiguration.KAFKA_ADDRESS);
    config.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class);
    config.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, TransactionRecordSerializer.class); // value serializer
    config.put(ProducerConfig.ACKS_CONFIG, "all");
    config.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG,
        KafkaConfiguration.KAFKA_ADDRESS + "_" + Thread.currentThread().getId());

    producer = new KafkaProducer<>(config);
    producer.initTransactions(); // initialize Kafka transactions
  }

  /**
   * Produces data: one random record per second, 1000 records in total
   */
  @Override
  public void run() {
    try {
      for (int i = 0; i < 1000; i++) {
        Thread.sleep(1000);
        TransactionRecord record = getRecord();
        System.out.println("--userId:" + record.getAccountId() + " amount:" + record.getAmount());

        String key = record.getAccountName();

        producer.beginTransaction();
        Future<RecordMetadata> result = producer.send(new ProducerRecord<>(KafkaConfiguration.KAFKA_TOPIC, key, record));
        producer.commitTransaction();
      }
    } catch (Exception e) {
      e.printStackTrace();
    }
  }

  public TransactionRecord getRecord() {
    long[] userIds = new long[] {1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L};
    String[] userNames = new String[] {"u1", "u2", "u3", "u4", "u5", "u6", "u7", "u8", "u9", "u10"};

    Random random = new Random(System.currentTimeMillis());
    int r = random.nextInt(10);
    double amount = random.nextDouble() * 1000;

    TransactionRecord record = new TransactionRecord();
    record.setAccountId(userIds[r]);
    record.setAccountName(userNames[r]);
    record.setAmount(amount);
    return record;
  }
}

Once the thread is running, it keeps producing data.
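To verify that records are actually flowing, you can watch the topic with the console consumer that ships with Kafka; since the producer writes transactionally, read_committed hides any aborted messages (the flag has existed since Kafka 0.11):

bin/kafka-console-consumer.sh --bootstrap-server localhost:9092 \
  --topic test_transaction_topic --isolation-level read_committed --from-beginning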

Defining the Main Flink Pipeline: FraudDetectionJob

The main pipeline is introduced first; the classes it depends on are covered one by one afterwards:

package com.wanli.fraud;

import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

/**
 * Defines the main Flink pipeline
 *
 * @author liwan
 * @version 1.0.0
 * @since 2020/11/19 14:44
 */
public class FraudDetectionJob implements Runnable {
  StreamExecutionEnvironment env;

  public FraudDetectionJob() {
    // initialize the Flink execution environment
    env = StreamExecutionEnvironment.getExecutionEnvironment();
  }

  @Override
  public void run() {
    try {
      // set the data source
      DataStream<TransactionRecord> transactions = env.addSource(new TransactionSource())
          .name("transactions");

      DataStream<Alert> alerts = transactions.keyBy(TransactionRecord::getAccountId) // group the records by account id
          .process(new FraudDetectorNoTimer()) // detector without a timer; use FraudDetector for the timer-based version
          .name("fraud-detector");

      alerts.addSink(new AlertSink()) // print the alerts in the sink stage
          .name("send-alerts");

      env.execute("FraudDetectionJob"); // launch the job
    } catch (Exception e) {
      e.printStackTrace();
    }
  }
}

The flow works as follows:

  • Step 1: set the data source to TransactionSource, introduced below; it is really a wrapper around a Kafka source;
  • Step 2: group the transaction records by account: transactions.keyBy(TransactionRecord::getAccountId);
  • Step 3: monitor the data within each group and run the fraud detection: process(new FraudDetectorNoTimer()). Note that this is the basic demo, i.e. without a time-window constraint. Each detection produced by this step is turned into an Alert object;
  • Step 4: print the fraud-detection results in the sink stage: addSink(new AlertSink()).

So the overall pipeline is: read the stream --> group by account --> detect within each group --> output.

Defining the Data Source: TransactionSource

Here we wrap the data-source definition ourselves, including the deserialization:

package com.wanli.fraud;

import org.apache.flink.api.common.serialization.DeserializationSchema;
import org.apache.flink.streaming.connectors.kafka.FlinkKafkaConsumer;
import org.apache.kafka.clients.consumer.ConsumerConfig;

import java.util.Properties;

/**
 * Kafka data source definition
 *
 * @author liwan
 * @version 1.0.0
 * @since 2020/11/19 13:53
 */
public class TransactionSource extends FlinkKafkaConsumer<TransactionRecord> {

  public TransactionSource() {
    this(KafkaConfiguration.KAFKA_TOPIC,
        new TransactionRecordSerializer(), // the deserialization schema
        new Properties() {{
          put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, KafkaConfiguration.KAFKA_ADDRESS);
          put(ConsumerConfig.GROUP_ID_CONFIG, KafkaConfiguration.KAFKA_TOPIC + "-" + Thread.currentThread().getId());
        }});
  }

  public TransactionSource(String topic, DeserializationSchema<TransactionRecord> valueDeserializer, Properties props) {
    super(topic, valueDeserializer, props);
  }
}
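Because the producer above writes transactionally, it is worth pinning the consumer's isolation level so that only committed records are read (Kafka defaults to read_uncommitted); one extra line in the Properties block above achieves this:

put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, "read_committed");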

The Core Detector: FraudDetectorNoTimer

The detector class extends KeyedProcessFunction:

  • its open method is called during initialization;
  • its processElement method is called whenever a record arrives;
  • if a timer has been set, its onTimer method is called when the timer fires (explained later).

With this mechanism we can implement stateful processing: the function keeps its state across records and maintains it as needed.

package com.wanli.fraud;

import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.util.Collector;

/**
 * Timer-less fraud detector: as soon as two consecutive transactions are seen
 * where the first is below SMALL_AMOUNT and the second is above LARGE_AMOUNT, an alert is raised
 *
 * @author liwan
 * @version 1.0.0
 * @since 2020/11/19 14:37
 */
public class FraudDetectorNoTimer extends KeyedProcessFunction<Long, TransactionRecord, Alert> {
  private static final long serialVersionUID = 1L;

  private static final double SMALL_AMOUNT = 10.00;
  private static final double LARGE_AMOUNT = 500.00;

  // state kept for the stream; it holds the last small-amount transaction
  private transient ValueState<TransactionRecord> lastSmallTransaction;

  @Override
  public void open(Configuration parameters) {
    ValueStateDescriptor<TransactionRecord> des1 = new ValueStateDescriptor<>("des1", TransactionRecord.class);
    lastSmallTransaction = getRuntimeContext().getState(des1);
  }

  @Override
  public void processElement(TransactionRecord record, Context context, Collector<Alert> collector)
      throws Exception {
    if (record.getAmount() <= SMALL_AMOUNT) { // a small transaction: remember it in state
      lastSmallTransaction.update(record);
      return;
    }

    // the previous transaction was not a small one, so simply ignore this record
    if (lastSmallTransaction.value() == null) return;

    // the previous transaction was small and this one is large: collect an alert
    if (record.getAmount() > LARGE_AMOUNT) {
      Alert alert = new Alert();

      alert.setRecord1(lastSmallTransaction.value());
      alert.setRecord2(record);

      collector.collect(alert);
    }

    // clear the state
    clearState();
  }

  public void clearState() {
    lastSmallTransaction.clear();
  }
}

The Alert class it depends on holds the two consecutive transactions:

package com.wanli.fraud;

import lombok.Getter;
import lombok.Setter;

/**
 * Alert record; it holds the two transactions that triggered the alert
 *
 * @author liwan
 * @version 1.0.0
 * @since 2020/11/19 14:49
 */
@Getter
@Setter
public class Alert {
  private TransactionRecord record1;
  private TransactionRecord record2;
}
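As a side note, keyed logic like FraudDetectorNoTimer can be exercised without Kafka at all via Flink's operator test harness. This is a minimal sketch, assuming the test-jar artifacts of flink-streaming-java and flink-runtime are on the classpath; the class name and input values are made up for illustration:

package com.wanli.fraud;

import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.streaming.api.operators.KeyedProcessOperator;
import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;

public class FraudDetectorHarnessDemo {

  public static void main(String[] args) throws Exception {
    // wrap the detector in a keyed operator and a test harness
    KeyedOneInputStreamOperatorTestHarness<Long, TransactionRecord, Alert> harness =
        new KeyedOneInputStreamOperatorTestHarness<>(
            new KeyedProcessOperator<>(new FraudDetectorNoTimer()),
            TransactionRecord::getAccountId, Types.LONG);
    harness.open();

    TransactionRecord small = new TransactionRecord(); // hypothetical small transaction
    small.setAccountId(1L);
    small.setAmount(5.0);
    harness.processElement(small, 0L);

    TransactionRecord large = new TransactionRecord(); // hypothetical large transaction
    large.setAccountId(1L);
    large.setAmount(600.0);
    harness.processElement(large, 1L);

    // the output queue should now hold exactly one StreamRecord<Alert>
    System.out.println("emitted: " + harness.getOutput().size());

    harness.close();
  }
}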

Printing Alerts in the Sink Stage: AlertSink

The sink class implements the SinkFunction interface:

package com.wanli.fraud;

import org.apache.flink.streaming.api.functions.sink.SinkFunction;

/**
 * Alert output
 *
 * @author liwan
 * @version 1.0.0
 * @since 2020/11/19 14:50
 */
public class AlertSink implements SinkFunction<Alert> {

  @Override
  public void invoke(Alert alert, Context context) throws Exception {
    TransactionRecord r1 = alert.getRecord1();
    TransactionRecord r2 = alert.getRecord2();

    System.out.println(
        "ALERT: userId=" + r1.getAccountId() +
            ", first amount: " + r1.getAmount() +
            ", second amount: " + r2.getAmount() +
            ", interval: " + (r2.getTransactionTime() - r1.getTransactionTime()) / 1000.0F + " s");
  }
}
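When a pair is detected, the console shows a line like this (the numbers are invented for illustration):

ALERT: userId=3, first amount: 4.21, second amount: 612.37, interval: 12.0 s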

Finally, the Main Program

package com.wanli.fraud;

/**
 * Entry point
 *
 * @author liwan
 * @version 1.0.0
 * @since 2020/11/19 14:15
 */
public class Main {

  public static void main(String[] args) throws Exception {
    startProducer(); // start producing data to Kafka
    startDetection(); // consume and monitor the data with Flink
  }

  public static void startProducer() throws Exception {
    new Thread(new TransactionDataProducer()).start();
  }

  public static void startDetection() throws Exception {
    new Thread(new FraudDetectionJob()).start();
  }
}
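One caveat when running Main locally: flink-streaming-java is declared with provided scope in the pom, so a plain runtime classpath will not contain it. In an IDE, tick the option that adds provided-scope dependencies to the run classpath; with Maven, assuming the exec-maven-plugin, something along these lines should work, since classpathScope=compile also pulls in provided dependencies:

mvn compile exec:java -Dexec.mainClass=com.wanli.fraud.Main -Dexec.classpathScope=compile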

Going Further: the Timer-Based Version

The advanced requirement above says that only a small transaction followed by a large one within a certain time window counts as fraud, so let's adapt the FraudDetectorNoTimer class:

package com.wanli.fraud;

import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.util.Collector;

/**
 * Timer-based fraud detector: within a window of ONE_MINUTE milliseconds,
 * two consecutive transactions where the first is below SMALL_AMOUNT and the second is above LARGE_AMOUNT raise an alert
 *
 * @author liwan
 * @version 1.0.0
 * @since 2020/11/19 14:37
 */
public class FraudDetector extends KeyedProcessFunction<Long, TransactionRecord, Alert> {
  private static final long serialVersionUID = 1L;

  private static final double SMALL_AMOUNT = 50.00;
  private static final double LARGE_AMOUNT = 500.00;
  private static final long ONE_MINUTE = 60L * 1000; // only pairs within one minute count

  // state kept for the stream; it holds the last small-amount transaction
  private transient ValueState<TransactionRecord> lastSmallTransaction;
  private transient ValueState<Long> lastTime;

  @Override
  public void open(Configuration parameters) {
    ValueStateDescriptor<TransactionRecord> des1 = new ValueStateDescriptor<>("des1", TransactionRecord.class);
    lastSmallTransaction = getRuntimeContext().getState(des1);

    ValueStateDescriptor<Long> des2 = new ValueStateDescriptor<>("des2", Long.class);
    lastTime = getRuntimeContext().getState(des2);
  }

  @Override
  public void processElement(TransactionRecord record, Context context, Collector<Alert> collector)
      throws Exception {
    if (record.getAmount() <= SMALL_AMOUNT) {
      lastSmallTransaction.update(record);

      // remove any existing timer, in case several small transactions arrive within a short period
      if (lastTime.value() != null) {
        System.out.println("clearing timer (1), userid=" + record.getAccountId());
        clearTimer(context);
      }

      long timer = context.timerService().currentProcessingTime() + ONE_MINUTE;
      lastTime.update(timer);

      // register a fresh timer
      System.out.println("registering timer, userid=" + record.getAccountId());
      context.timerService().registerProcessingTimeTimer(timer); // onTimer fires automatically when the timer expires

      return;
    }

    if (lastSmallTransaction.value() == null) return;

    if (record.getAmount() > LARGE_AMOUNT) {
      Alert alert = new Alert();

      alert.setRecord1(lastSmallTransaction.value());
      alert.setRecord2(record);

      collector.collect(alert);
    }

    if (lastTime.value() != null) {
      System.out.println("清除定时器-2 userid=" + record.getAccountId());
      clearTimer(context); // 清除定时器
    }

    clearState();
  }

  /**
   * Runs when the timer fires; only events within ONE_MINUTE count, so the state is dropped here
   *
   * @param timestamp
   * @param ctx
   * @param out
   * @throws Exception
   */
  @Override
  public void onTimer(long timestamp, OnTimerContext ctx, Collector<Alert> out) throws Exception {
    System.out.println("<<<定时任务>>>");
    clearState();
  }

  public void clearState() {
    lastSmallTransaction.clear();
    lastTime.clear();
  }

  /**
   * Remove the registered timer
   *
   * @param context
   * @throws Exception
   */
  public void clearTimer(Context context) throws Exception {
    long t = lastTime.value();
    context.timerService().deleteProcessingTimeTimer(t);
  }
}

Now, whenever a small transaction arrives, a timer is registered that calls onTimer one minute later, and onTimer clears the state, so a pair of transactions spanning more than one minute is never flagged as fraud. For example, a transaction of amount 5 at t=0 followed by one of amount 600 at t=30s raises an alert, while the same large transaction arriving at t=70s does not, because the timer already cleared the state at t=60s.