An Unconventional Way to Optimize SpringBoot Concurrency Performance


Preface

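Recently a project came my way. A hospital client needed a daily employee health check-in: just a simple form with the usual questions, such as current address, whether the person has passed through a high-risk area, how they commute, and whether their temperature is normal. Because it is a public hospital, there are a lot of staff, and almost everyone submits the form before their shift starts, so there is a clear peak between 7:00 and 8:00. In the end, the client required the system to handle at least 500 concurrent requests.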

Getting started

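When I got this task, I just laughed. 500? How hard can 500 be? Surely SpringBoot can handle a mere 500. To measure it, I downloaded and installed JMeter and used it for all the tests below. To keep things simple, we add one basic endpoint that stores the submitted form in the database.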

Base code

First we need a SpringBoot project, generated and downloaded from the official initializer: start.spring.io/

This uses SpringBoot 2.7.1 with the Web and Data JPA dependencies. The package path and project name are left at their defaults, purely for convenience.

Once the project is generated, we add a simple endpoint and an entity, and implement the insert logic.

Since I prefer YAML configuration, I replaced application.properties with application.yml.

application.yml

spring:
  datasource:
    name: datasource
    url: jdbc:mysql://dev.local:3306/demo?autoReconnect=true&useUnicode=true&characterEncoding=UTF-8&zeroDateTimeBehavior=convertToNull
    username: root
    password: yV2jJxvNs8BD

Add the MySQL driver and Lombok:

build.gradle

implementation 'mysql:mysql-connector-java'
compileOnly "org.projectlombok:lombok:1.18.24"
annotationProcessor "org.projectlombok:lombok:1.18.24"
testCompileOnly "org.projectlombok:lombok:1.18.24"
testAnnotationProcessor "org.projectlombok:lombok:1.18.24"

Next, add the boilerplate code:

domain/User.java

package com.example.demo.domain;


import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import lombok.*;
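import org.springframework.data.annotation.CreatedDate;
import org.springframework.data.annotation.LastModifiedDate;
import org.springframework.data.jpa.domain.support.AuditingEntityListener;

import javax.persistence.*;
import java.util.Date;

@Getter
@Setter
@ToString
@RequiredArgsConstructor
@Builder
@AllArgsConstructor
@Entity
// @CreatedDate/@LastModifiedDate are only populated by this listener,
// and only when @EnableJpaAuditing is present on a configuration class.
@EntityListeners(AuditingEntityListener.class)
@Table(name = "DEMO_USER")
@JsonIgnoreProperties({"hibernateLazyInitializer", "handler"})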
public class User {
    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    @Column(name = "ID", nullable = false, updatable = false)
    private Long id;

    @Column(name = "NICKNAME", length = 30)
    private String nickname;

    @Column(name = "USERNAME", length = 30)
    private String username;

    @Column(name = "PASSWORD", length = 21)
    private String password;

    @CreatedDate
    @Temporal(TemporalType.TIMESTAMP)
    @Column(updatable = false, name = "CREATED_AT")
    private Date createdAt;

    @LastModifiedDate
    @Temporal(TemporalType.TIMESTAMP)
    @Column(name = "UPDATED_AT")
    private Date updatedAt;

}

dao/UserDao.java

package com.example.demo.dao;

import com.example.demo.domain.User;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.stereotype.Repository;

@Repository
public interface UserDao extends JpaRepository<User, Long> {
}

service/UserService.java

package com.example.demo.service;

import com.example.demo.dao.UserDao;
import com.example.demo.domain.User;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

@Service
public class UserService {

    private final UserDao userDao;

    public UserService(UserDao userDao) {
        this.userDao = userDao;
    }

    public Page<User> findPage(Pageable pageable) {
        return this.userDao.findAll(pageable);
    }

    @Transactional
    public User save(User user) {
        return this.userDao.save(user);
    }
}

web/UserController.java

package com.example.demo.web;

import com.example.demo.domain.User;
import com.example.demo.service.UserService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.*;

@Slf4j
@Controller
public class UserController {

    private final UserService userService;

    public UserController(UserService userService) {
        this.userService = userService;
    }

    @GetMapping("/users")
    @ResponseBody
    public Page<User> users() {
        return this.userService.findPage(Pageable.ofSize(20));
    }

    @PostMapping("/users")
    @ResponseStatus(HttpStatus.CREATED)
    @ResponseBody
    public User addUser(@RequestBody User user) {
        return this.userService.save(user);
    }
}

Start the application and call the POST /users and GET /users endpoints.

Both work. Now use JMeter to send concurrent inserts to POST /users.

Load testing

The post_user.jmx test plan can be generated with Apifox and then run with JMeter in command-line mode, simulating 100,000 requests at a concurrency of 500:
jmeter -n -t post_user.jmx -l post_user.jtl -e -o post_user_ResultReport

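The result is disappointing: about 300 TPS, with every request taking more than 1 s. What else can we do? Tune the connection pool? First, let's analyze why it is slow.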

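Let x* denote the client request threads and y* the database connections.

In theory, there are as many y* as there are x*: every in-flight request holds its own database connection.

So the key to the optimization is reducing the number of x* and y*.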

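First, how do we optimize request connections? So far the number of concurrent requests has been capped by the container, for example Tomcat's worker thread count. But Servlet 3.0 supports asynchronous Servlets, which in theory give a non-blocking effect similar to Netty.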
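To make the idea concrete without a servlet container, here is a minimal sketch using only java.util.concurrent (AsyncHandoffDemo, slowSave and handle are hypothetical names). The handler passes the blocking work to a separate pool and immediately returns a CompletableFuture, so the calling thread is not tied up while the work runs. Spring MVC accepts a CompletableFuture as a controller return value and writes the HTTP response when it completes, relying on Servlet async support underneath.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class AsyncHandoffDemo {
    // Hypothetical stand-in for a slow, blocking call such as a database insert.
    static String slowSave(String payload) {
        try {
            Thread.sleep(50);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return "saved:" + payload;
    }

    // Returns at once; the blocking work runs on the backend pool, not on the caller's thread.
    static CompletableFuture<String> handle(String payload, ExecutorService backend) {
        return CompletableFuture.supplyAsync(() -> slowSave(payload), backend);
    }

    // Usage: submit one request and wait for its result on this thread.
    static String demo() {
        ExecutorService backend = Executors.newFixedThreadPool(2);
        try {
            CompletableFuture<String> future = handle("u1", backend);
            return future.join();
        } finally {
            backend.shutdown();
        }
    }
}
```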

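Next, how do we optimize database connections? The simplest idea is to commit all the data that arrives within the same short time window to the database in a single transaction. That alone should save a lot of connections.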
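The arithmetic behind this idea: if n inserts arrive in the same window and one transaction takes up to batchSize of them, only ceil(n / batchSize) transactions are needed instead of n. A small hypothetical helper (BatchPartition is not part of the article's code) that performs this grouping:

```java
import java.util.ArrayList;
import java.util.List;

class BatchPartition {
    // Splits items into consecutive batches of at most batchSize elements;
    // each batch would then be written in one transaction.
    static <T> List<List<T>> partition(List<T> items, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        List<List<T>> batches = new ArrayList<>();
        for (int from = 0; from < items.size(); from += batchSize) {
            int to = Math.min(from + batchSize, items.size());
            // Copy so each batch does not depend on the source list.
            batches.add(new ArrayList<>(items.subList(from, to)));
        }
        return batches;
    }
}
```

For 1,200 pending inserts and batchSize = 500 this yields batches of 500, 500 and 200, i.e. three transactions instead of 1,200.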

With a plan in place, let's start.

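First, rework the database access. To commit requests that arrive at the same time in one transaction, the commit itself has to run on a separate thread, so we need a few new classes.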

Data wrapper (the cargo): Cargo.java

package com.example.demo.batch;

import lombok.Getter;

import java.util.concurrent.CompletableFuture;

@Getter
public class Cargo<T, R> {
  private final CompletableFuture<R> hearthstone = new CompletableFuture<>();
  private final T content;

  private Cargo(T o) {
    this.content = o;
  }

  public static <T, R> Cargo<T, R> of(T o) {
    return new Cargo<>(o);
  }
}
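Cargo pairs the submitted payload with a CompletableFuture (hearthstone): the request thread keeps the future, and the worker thread completes it after saving. A simplified, self-contained sketch of that handoff, with a hypothetical Parcel class standing in for Cargo and a trivial computation standing in for the database write:

```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.LinkedBlockingQueue;

class CargoHandoffDemo {
    // Simplified stand-in for Cargo<T, R>: the payload plus the future the caller waits on.
    static final class Parcel<T, R> {
        final T content;
        final CompletableFuture<R> result = new CompletableFuture<>();

        Parcel(T content) {
            this.content = content;
        }
    }

    static CompletableFuture<Integer> submit(String input) {
        BlockingQueue<Parcel<String, Integer>> queue = new LinkedBlockingQueue<>();
        Thread worker = new Thread(() -> {
            try {
                Parcel<String, Integer> p = queue.take();
                // The "save" step; here it only reports the payload length.
                p.result.complete(p.content.length());
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });
        worker.start();
        Parcel<String, Integer> parcel = new Parcel<>(input);
        queue.add(parcel);     // request side: enqueue the cargo...
        return parcel.result;  // ...and hand back its future without waiting
    }
}
```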

Data collector (the queue): LinkedQueue.java

package com.example.demo.batch;


import java.util.*;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

import lombok.extern.slf4j.Slf4j;

@Slf4j
public class LinkedQueue<E> extends AbstractQueue<E> implements BlockingQueue<E>, java.io.Serializable {

    private static final long serialVersionUID = -4457362206741191196L;

    static class Node<E> {
        volatile E item;
        Node<E> next;
        Node<E> previous;

        // Originally Node(E x) { item = x; }; the inner-class constructor now also sets next and previous
        Node(E item, Node<E> next, Node<E> previous) {
            this.item = item;
            this.next = next;
            this.previous = previous;
        }
    }

    private final int capacity;
    private final AtomicInteger count = new AtomicInteger(0);
    private transient Node<E> head;
    private transient Node<E> last;
    private final ReentrantLock takeLock = new ReentrantLock();
    private final Condition notEmpty = takeLock.newCondition();
    private final ReentrantLock putLock = new ReentrantLock();
    private final Condition notFull = putLock.newCondition();
    private final List<E> list = new Li();

    private void signalNotEmpty() {
        this.takeLock.lock();
        try {
            this.notEmpty.signal();
        } finally {
            this.takeLock.unlock();
        }
    }

    private void signalNotFull() {
        this.putLock.lock();
        try {
            this.notFull.signal();
        } finally {
            this.putLock.unlock();
        }
    }

    private void insert(E x) {
        last = last.next = new Node<>(x, null, last);
        head.previous = last;
    }

    private E extract() {
        Node<E> first = head.next;
        head = first;
        head.previous = last;
        E x = first.item;
        first.item = null;
        return x;
    }

    public void fullyLock() {
        putLock.lock();
        takeLock.lock();
    }

    public void fullyUnlock() {
        takeLock.unlock();
        putLock.unlock();
    }

    public LinkedQueue() {
        this(Integer.MAX_VALUE);
    }

    public LinkedQueue(int capacity) {
        if (capacity <= 0) {
            throw new IllegalArgumentException();
        }
        this.capacity = capacity;
        last = head = new Node<>(null, null, null);
    }

    public LinkedQueue(Collection<? extends E> c) {
        this(Integer.MAX_VALUE);
        this.addAll(c);
    }

    @Override
    public int size() {
        return count.get();
    }

    @Override
    public int remainingCapacity() {
        return capacity - count.get();
    }

    @Override
    public void put(E o) throws InterruptedException {
        if (o == null) {
            throw new NullPointerException();
        }
        int c;
        putLock.lockInterruptibly();
        try {
            try {
                while (count.get() == capacity) {
                    notFull.await();
                }
            } catch (InterruptedException ie) {
                notFull.signal();
                throw ie;
            }
            insert(o);
            c = count.getAndIncrement();
            if (c + 1 < capacity) {
                notFull.signal();
            }
        } finally {
            putLock.unlock();
        }
        if (c == 0) {
            signalNotEmpty();
        }
    }

    @Override
    public boolean offer(E o, long timeout, TimeUnit unit) throws InterruptedException {

        if (o == null) {
            throw new NullPointerException();
        }
        long nanos = unit.toNanos(timeout);
        int c;
        final ReentrantLock putLock = this.putLock;
        final AtomicInteger count = this.count;
        putLock.lockInterruptibly();
        try {
            for (; ; ) {
                if (count.get() < capacity) {
                    insert(o);
                    c = count.getAndIncrement();
                    if (c + 1 < capacity) {
                        notFull.signal();
                    }
                    break;
                }
                if (nanos <= 0) {
                    return false;
                }
                try {
                    nanos = notFull.awaitNanos(nanos);
                } catch (InterruptedException ie) {
                    notFull.signal();
                    throw ie;
                }
            }
        } finally {
            putLock.unlock();
        }
        if (c == 0) {
            signalNotEmpty();
        }
        return true;
    }

    @Override
    public boolean offer(E o) {
        if (o == null) {
            throw new NullPointerException();
        }
        final AtomicInteger count = this.count;
        if (count.get() == capacity) {
            return false;
        }
        int c = -1;
        final ReentrantLock putLock = this.putLock;
        putLock.lock();
        try {
            if (count.get() < capacity) {
                insert(o);
                c = count.getAndIncrement();
                if (c + 1 < capacity) {
                    notFull.signal();
                }
            }
        } finally {
            putLock.unlock();
        }
        if (c == 0) {
            signalNotEmpty();
        }
        return c >= 0;
    }

    @Override
    public E take() throws InterruptedException {
        E x;
        int c;
        final AtomicInteger count = this.count;
        final ReentrantLock takeLock = this.takeLock;
        takeLock.lockInterruptibly();
        try {
            try {
                while (count.get() == 0) {
                    notEmpty.await();
                }
            } catch (InterruptedException ie) {
                notEmpty.signal();
                throw ie;
            }

            x = extract();
            c = count.getAndDecrement();
            if (c > 1) {
                notEmpty.signal();
            }
        } finally {
            takeLock.unlock();
        }
        if (c == capacity) {
            signalNotFull();
        }
        return x;
    }

    @Override
    public E poll(long timeout, TimeUnit unit) throws InterruptedException {
        E x;
        int c;
        long nanos = unit.toNanos(timeout);
        final AtomicInteger count = this.count;
        final ReentrantLock takeLock = this.takeLock;
        takeLock.lockInterruptibly();
        try {
            for (; ; ) {
                if (count.get() > 0) {
                    x = extract();
                    c = count.getAndDecrement();
                    if (c > 1) {
                        notEmpty.signal();
                    }
                    break;
                }
                if (nanos <= 0) {
                    return null;
                }
                try {
                    log.debug("Wait time: " + nanos + "\t" + TimeUnit.NANOSECONDS.toMillis(nanos));
                    nanos = notEmpty.awaitNanos(nanos);
                    log.debug("Remaining time: " + nanos + "\t" + TimeUnit.NANOSECONDS.toMillis(nanos));
                } catch (InterruptedException ie) {
                    notEmpty.signal();
                    throw ie;
                }
            }
        } finally {
            takeLock.unlock();
        }
        if (c == capacity) {
            signalNotFull();
        }
        return x;
    }

    @Override
    public E poll() {
        final AtomicInteger count = this.count;
        if (count.get() == 0) {
            return null;
        }
        E x = null;
        int c = -1;
        final ReentrantLock takeLock = this.takeLock;
        takeLock.lock();
        try {
            if (count.get() > 0) {
                x = extract();
                c = count.getAndDecrement();
                if (c > 1) {
                    notEmpty.signal();
                }
            }
        } finally {
            takeLock.unlock();
        }
        if (c == capacity) {
            signalNotFull();
        }
        return x;
    }

    @Override
    public E peek() {
        if (count.get() == 0) {
            return null;
        }
        final ReentrantLock takeLock = this.takeLock;
        takeLock.lock();
        try {
            Node<E> first = head.next;
            if (first == null) {
                return null;
            } else {
                return first.item;
            }
        } finally {
            takeLock.unlock();
        }
    }

    @Override
    public boolean remove(Object o) {
        if (o == null) {
            return false;
        }
        boolean removed = false;
        fullyLock();
        try {
            Node<E> trail = head;
            Node<E> p = head.next;
            while (p != null) {
                if (o.equals(p.item)) {
                    removed = true;
                    break;
                }
                trail = p;
                p = p.next;
            }
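            if (removed) {
                p.item = null;
                trail.next = p.next;
                if (p.next != null) {
                    p.next.previous = trail;
                }
                // If the removed node was the tail, move the tail (and head's back link) to its predecessor.
                if (last == p) {
                    last = trail;
                    head.previous = last;
                }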
                if (count.getAndDecrement() == capacity) {
                    notFull.signalAll();
                }
            }
        } finally {
            fullyUnlock();
        }
        return removed;
    }

    @Override
    public Object[] toArray() {
        fullyLock();
        try {
            int size = count.get();
            Object[] a = new Object[size];
            int k = 0;
            for (Node<E> p = head.next; p != null; p = p.next) {
                a[k++] = p.item;
            }
            return a;
        } finally {
            fullyUnlock();
        }
    }

    @Override
    public <T> T[] toArray(T[] a) {
        fullyLock();
        try {
            int size = count.get();
            if (a.length < size) {
                a = (T[]) java.lang.reflect.Array.newInstance(a.getClass().getComponentType(), size);
            }
            int k = 0;
            for (Node<E> p = head.next; p != null; p = p.next) {
                a[k++] = (T) p.item;
            }
            return a;
        } finally {
            fullyUnlock();
        }
    }

    @Override
    public String toString() {
        fullyLock();
        try {
            return super.toString();
        } finally {
            fullyUnlock();
        }
    }

    @Override
    public void clear() {
        fullyLock();
        try {
            head.next = null;
            assert head.item == null;
            last = head;
            if (count.getAndSet(0) == capacity) {
                notFull.signalAll();
            }
        } finally {
            fullyUnlock();
        }
    }

    protected void _clear() {
        head.next = null;
        assert head.item == null;
        last = head;
        if (count.getAndSet(0) == capacity) {
            notFull.signalAll();
        }
    }

    @Override
    public int drainTo(Collection<? super E> c) {
        if (c == null) {
            throw new NullPointerException();
        }
        if (c == this) {
            throw new IllegalArgumentException();
        }
        Node<E> first;
        fullyLock();
        try {
            first = head.next;
            head.next = null;
            assert head.item == null;
            last = head;
            if (count.getAndSet(0) == capacity) {
                notFull.signalAll();
            }
        } finally {
            fullyUnlock();
        }
        int n = 0;
        for (Node<E> p = first; p != null; p = p.next) {
            c.add(p.item);
            p.item = null;
            ++n;
        }
        return n;
    }

    @Override
    public int drainTo(Collection<? super E> c, int maxElements) {
        if (c == null) {
            throw new NullPointerException();
        }
        if (c == this) {
            throw new IllegalArgumentException();
        }
        fullyLock();
        try {
            int n = 0;
            Node<E> p = head.next;
            while (p != null && n < maxElements) {
                c.add(p.item);
                p.item = null;
                p = p.next;
                ++n;
            }
            if (n != 0) {
                head.next = p;
                assert head.item == null;
                if (p == null) {
                    last = head;
                }
                if (count.getAndAdd(-n) == capacity) {
                    notFull.signalAll();
                }
            }
            return n;
        } finally {
            fullyUnlock();
        }
    }

    @Override
    @SuppressWarnings("NullableProblems")
    public Iterator<E> iterator() {
        return new Itr();
    }

    private class Itr implements Iterator<E> {
        private Node<E> current;
        private Node<E> lastRet;
        private E currentElement;

        Itr() {
            final ReentrantLock putLock = LinkedQueue.this.putLock;
            final ReentrantLock takeLock = LinkedQueue.this.takeLock;
            putLock.lock();
            takeLock.lock();
            try {
                current = head.next;
                if (current != null) {
                    currentElement = current.item;
                }
            } finally {
                takeLock.unlock();
                putLock.unlock();
            }
        }

        @Override
        public boolean hasNext() {
            return current != null;
        }

        @Override
        public E next() {
            final ReentrantLock putLock = LinkedQueue.this.putLock;
            final ReentrantLock takeLock = LinkedQueue.this.takeLock;
            putLock.lock();
            takeLock.lock();
            try {
                if (current == null) {
                    throw new NoSuchElementException();
                }
                E x = currentElement;
                lastRet = current;
                current = current.next;
                if (current != null) {
                    currentElement = current.item;
                }
                return x;
            } finally {
                takeLock.unlock();
                putLock.unlock();
            }
        }

        @Override
        public void remove() {
            if (lastRet == null) {
                throw new IllegalStateException();
            }
            final ReentrantLock putLock = LinkedQueue.this.putLock;
            final ReentrantLock takeLock = LinkedQueue.this.takeLock;
            putLock.lock();
            takeLock.lock();
            try {
                Node<E> node = lastRet;
                lastRet = null;
                Node<E> trail = head;
                Node<E> p = head.next;
                while (p != null && p != node) {
                    trail = p;
                    p = p.next;
                }
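                if (p == node) {
                    assert p != null;
                    p.item = null;
                    trail.next = p.next;
                    // Keep the back links and the tail pointer consistent after unlinking.
                    if (p.next != null) {
                        p.next.previous = trail;
                    }
                    if (last == p) {
                        last = trail;
                        head.previous = last;
                    }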
                    int c = count.getAndDecrement();
                    if (c == capacity) {
                        notFull.signalAll();
                    }
                }
            } finally {
                takeLock.unlock();
                putLock.unlock();
            }
        }
    }

    private void writeObject(java.io.ObjectOutputStream s) throws java.io.IOException {
        fullyLock();
        try {
            s.defaultWriteObject();
            for (Node<E> p = head.next; p != null; p = p.next) {
                s.writeObject(p.item);
            }
            s.writeObject(null);
        } finally {
            fullyUnlock();
        }
    }

    private void readObject(java.io.ObjectInputStream s) throws java.io.IOException, ClassNotFoundException {
        s.defaultReadObject();
        count.set(0);
        last = head = new Node<>(null, null, null);
        for (; ; ) {
            E item = (E) s.readObject();
            if (item == null) {
                break;
            }
            add(item);
        }
    }

    public E get(int index) {
        return entry(index).item;
    }

    private Node<E> entry(int index) {
        int size = count.get();
        if (index < 0 || index >= size) {
            throw new IndexOutOfBoundsException("Index: " + index + ", Size: " + size);
        }
        Node<E> e = head;
        if (index < (size >> 1)) {
            for (int i = 0; i <= index; i++) {
                e = e.next;
            }
        } else {
            for (int i = size; i > index; i--) {
                e = e.previous;
            }
        }
        return e;
    }

    public boolean add(int index, E o) {
        final AtomicInteger count = this.count;
        if (count.get() == capacity) {
            return false;
        }
        int c = -1;
        final ReentrantLock putLock = this.putLock;
        putLock.lock();
        try {
            if (count.get() < capacity) {
                if (index >= size()) {
                    this.insert(o);
                } else {
                    Node<E> p = entry(index);
                    p.previous.next = p.previous = new Node<>(o, p, p.previous);
                }
                c = count.getAndIncrement();
                if (c + 1 < capacity) {
                    notFull.signal();
                }
            }
        } finally {
            putLock.unlock();
        }
        if (c == 0) {
            signalNotEmpty();
        }
        return c >= 0;
    }

    public int indexOf(E o) {
        int index = 0;
        if (o == null) {
            for (Node<E> e = head.next; e != null; e = e.next) {
                if (e.item == null) {
                    return index;
                }
                index++;
            }
        } else {
            for (Node<E> e = head.next; e != null; e = e.next) {
                if (o.equals(e.item)) {
                    return index;
                }
                index++;
            }
        }
        return -1;
    }

    public E set(int index, E element) {
        Node<E> e = entry(index);
        E oldVal = e.item;
        e.item = element;
        return oldVal;
    }

    public E remove(int index) {
        E oldVal = entry(index).item;
        boolean removed = remove(oldVal);
        return removed ? oldVal : null;
    }

    public void poll(long timeout, TimeUnit unit, PollCallBack<E> pollCallBack) throws InterruptedException {
        int c = -1;
        long nanos = unit.toNanos(timeout);
        final AtomicInteger count = this.count;
        final ReentrantLock takeLock = this.takeLock;
        takeLock.lockInterruptibly();
        try {
            for (; ; ) {
                if (count.get() > 0) {
                    E x = extract();
                    c = count.getAndDecrement();
                    if (pollCallBack.CallBack(x, this.peek())) {
                        continue;
                    } else {
                        break;
                    }
                }
                if (nanos <= 0) {
                    break;
                }
                try {
                    nanos = notEmpty.awaitNanos(nanos);
                } catch (InterruptedException ie) {
                    notEmpty.signal();
                    throw ie;
                }
            }
        } finally {
            takeLock.unlock();
        }
        if (c > 1) {
            notEmpty.signal();
        }
        if (c == capacity) {
            signalNotFull();
        }
    }

    public interface PollCallBack<E> {

        boolean CallBack(E current, E next);
    }

    public List<E> list() {
        return list;
    }

    private class Li implements List<E> {

        @Override
        public boolean add(E o) {
            return LinkedQueue.this.add(o);
        }

        @Override
        public void add(int index, E element) {
            LinkedQueue.this.add(index, element);
        }

        @Override
        public boolean addAll(Collection<? extends E> c) {
            return LinkedQueue.this.addAll(c);
        }

        @Override
        public boolean addAll(int index, Collection<? extends E> c) {
            throw new RuntimeException("null method");
        }

        @Override
        public void clear() {
            LinkedQueue.this.clear();
        }

        @Override
        public boolean contains(Object o) {
            return LinkedQueue.this.contains(o);
        }

        @Override
        public boolean containsAll(Collection<?> c) {
            return LinkedQueue.this.containsAll(c);
        }

        @Override
        public E get(int index) {
            return LinkedQueue.this.get(index);
        }

        @Override
        public int indexOf(Object o) {
            return LinkedQueue.this.indexOf((E) o);
        }

        @Override
        public boolean isEmpty() {
            return LinkedQueue.this.isEmpty();
        }

        @Override
        public Iterator<E> iterator() {
            return LinkedQueue.this.iterator();
        }

        @Override
        public int lastIndexOf(Object o) {
            throw new RuntimeException("null method");
        }

        @Override
        public ListIterator<E> listIterator() {
            throw new RuntimeException("null method");
        }

        @Override
        public ListIterator<E> listIterator(int index) {
            throw new RuntimeException("null method");
        }

        @Override
        public boolean remove(Object o) {
            return LinkedQueue.this.remove(o);
        }

        @Override
        public E remove(int index) {
            return LinkedQueue.this.remove(index);
        }

        @Override
        public boolean removeAll(Collection<?> c) {
            return LinkedQueue.this.removeAll(c);
        }

        @Override
        public boolean retainAll(Collection<?> c) {
            return LinkedQueue.this.retainAll(c);
        }

        @Override
        public E set(int index, E element) {
            return LinkedQueue.this.set(index, element);
        }

        @Override
        public int size() {
            return LinkedQueue.this.size();
        }

        @Override
        public List<E> subList(int fromIndex, int toIndex) {
            throw new RuntimeException("null method");
        }

        @Override
        public Object[] toArray() {
            return LinkedQueue.this.toArray();
        }

        @Override
        @SuppressWarnings({"SuspiciousToArrayCall"})
        public <T> T[] toArray(T[] a) {
            return LinkedQueue.this.toArray(a);
        }
    }
}
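Worker only calls add, take and poll on this queue. As an alternative rather than the article's code, the JDK's LinkedBlockingQueue already provides the two operations Worker.getItems() needs: a blocking take() for the first item and drainTo(Collection, int) to grab everything else that is ready, up to a limit. A minimal sketch (DrainBatcher is a hypothetical name):

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;

class DrainBatcher {
    // Blocks for the first item, then takes whatever else is already queued, up to batchSize in total.
    static <E> List<E> nextBatch(BlockingQueue<E> queue, int batchSize) throws InterruptedException {
        if (batchSize < 1) {
            throw new IllegalArgumentException("batchSize must be >= 1");
        }
        List<E> batch = new ArrayList<>(batchSize);
        batch.add(queue.take());
        queue.drainTo(batch, batchSize - 1);
        return batch;
    }

    // Usage: 7 queued items with batchSize 5 come out as a batch of 5, then a batch of 2.
    static List<List<Integer>> demo() {
        BlockingQueue<Integer> queue = new LinkedBlockingQueue<>();
        for (int i = 0; i < 7; i++) {
            queue.add(i);
        }
        List<List<Integer>> batches = new ArrayList<>();
        try {
            batches.add(nextBatch(queue, 5));
            batches.add(nextBatch(queue, 5));
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return batches;
    }
}
```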

Batch committer (the worker): Worker.java

package com.example.demo.batch;

import lombok.extern.slf4j.Slf4j;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;

@Slf4j
public class Worker<T, R> implements Runnable {
    private final LinkedQueue<Cargo<T, R>> queue = new LinkedQueue<>();
    private final int batchSize;
    private final Consumer<List<Cargo<T, R>>> saver;

    public Worker(Consumer<List<Cargo<T, R>>> saver, int batchSize) {
        this.saver = saver;
        this.batchSize = batchSize;
    }

    public CompletableFuture<R> add(T o) {
        Cargo<T, R> item = Cargo.of(o);
        queue.add(item);
        return item.getHearthstone();
    }

    public List<Cargo<T, R>> getItems() {
        try {
            Cargo<T, R> item = queue.take();
            List<Cargo<T, R>> items = new ArrayList<>(batchSize);
            items.add(item);
            item = queue.poll();
            while (item != null) {
                items.add(item);

                if (items.size() >= batchSize) {
                    return items;
                }

                item = queue.poll();
            }
            return items;
        } catch (InterruptedException e) {
            log.error(e.getMessage());
            throw new RuntimeException(e.getMessage());
        }
    }

    @Override
    public void run() {
        do {
            List<Cargo<T, R>> items = getItems();
            save(items);
        } while (true);
    }

    private void save(List<Cargo<T, R>> items) {
        try {
            saver.accept(items);
        } catch (Exception e) {
            log.error(e.getMessage(), e);
            items.forEach(item -> item.getHearthstone().obtrudeException(e));
        }
    }
}

Batch service interface: BatchService.java

package com.example.demo.batch;

import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;

public interface BatchService<T, R> {

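  /**
   * Creates a batch submission service.
   *
   * @param saver the save callback that persists one batch
   * @param batchSize the maximum number of items per batch
   * @param works the number of workers
   * @param <T> the submitted entity type
   * @param <R> the result type
   * @return BatchService<T, R>
   */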
  static <T, R> BatchService<T, R> create(
      Consumer<List<Cargo<T, R>>> saver, int batchSize, int works) {
    return new DefaultBatchService<>(saver, batchSize, works);
  }

  CompletableFuture<R> submit(T entity);
}

Factory and default implementation: DefaultBatchService.java

package com.example.demo.batch;

import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;

import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;

public class DefaultBatchService<T, R> implements BatchService<T, R> {

  public AtomicInteger atomicInteger = new AtomicInteger(0);
  public ConcurrentHashMap<Integer, Worker<T, R>> cache = new ConcurrentHashMap<>();

  private final int workerNumber;

  // Each Worker.run() loops forever and permanently occupies one thread,
  // so one pool with exactly one thread per worker is enough.
  private Executor asyncServiceExecutor(int threads) {
    ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
    executor.setCorePoolSize(threads);
    executor.setMaxPoolSize(threads);
    executor.setKeepAliveSeconds(30);
    executor.setThreadNamePrefix("batch_service");
    executor.initialize();
    return executor;
  }

  public DefaultBatchService(Consumer<List<Cargo<T, R>>> saver, int batchSize, int works) {
    workerNumber = works;
    // Create the pool once instead of a new pool per worker.
    Executor executor = asyncServiceExecutor(workerNumber);
    for (int i = 0; i < workerNumber; i++) {
      Worker<T, R> task = new Worker<>(saver, batchSize);
      cache.put(i, task);
      executor.execute(task);
    }
  }

  public final int getAndIncrement() {
    int current;
    int next;
    do {
      current = this.atomicInteger.get();
      next = current >= 214748364 ? 0 : current + 1;
    } while (!this.atomicInteger.compareAndSet(current, next));
    return next;
  }

  @Override
  public CompletableFuture<R> submit(T entity) {
    int count = getAndIncrement();
    int position = count % workerNumber;
    Worker<T, R> queue = cache.get(position);
    return queue.add(entity);
  }
}
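submit() spreads entities across workers round-robin, using a CAS loop that resets the counter at 214748364. A shorter variant, offered as an alternative rather than the article's code (RoundRobin is a hypothetical name), lets the AtomicInteger overflow and uses Math.floorMod to keep the index in range:

```java
import java.util.concurrent.atomic.AtomicInteger;

class RoundRobin {
    private final AtomicInteger counter;
    private final int workers;

    RoundRobin(int workers) {
        this(workers, 0);
    }

    // The start value exists only so the overflow case can be exercised directly.
    RoundRobin(int workers, int start) {
        if (workers <= 0) {
            throw new IllegalArgumentException("workers must be positive");
        }
        this.workers = workers;
        this.counter = new AtomicInteger(start);
    }

    // getAndIncrement() wraps from Integer.MAX_VALUE to Integer.MIN_VALUE;
    // floorMod maps negative values back into [0, workers).
    int next() {
        return Math.floorMod(counter.getAndIncrement(), workers);
    }
}
```

With a power-of-two worker count such as 8 the rotation stays continuous across the overflow; with other counts it may break the strict rotation once, which does not matter for spreading load.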

Modify UserController.java

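This declares 8 queues (workers), and each transaction commits at most 500 rows. The 8 also means that at most 8 threads run the save logic at any moment, so only 8 database connections are used.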

private final BatchService<User,User> batchSaveService;

public UserController(UserService userService) {
    this.userService = userService;
    this.batchSaveService = BatchService.create(UserController.this.userService::saveAll, 500, 8);
}

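Then change the addUser endpoint so that it submits to the batch service and returns the future; Spring MVC completes the response asynchronously once the future completes:

@PostMapping("/users")
@ResponseStatus(HttpStatus.CREATED)
@ResponseBody
public CompletableFuture<User> addUser(@RequestBody User user) {
    return this.batchSaveService.submit(user);
}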

Modify UserService.java, replacing save with saveAll:

@Transactional
public void saveAll(List<Cargo<User, User>> cargos) {
    List<User> users = cargos.stream().map(Cargo::getContent).collect(Collectors.toList());

    this.userDao.saveAll(users);

    for (Cargo<User, User> cargo : cargos) {
        User user = cargo.getContent();
        cargo.getHearthstone().complete(user);
    }
}
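saveAll() completes each future inside the @Transactional method, i.e. before the transaction commits. If the commit then fails, Worker.save() catches the exception; at that point completeExceptionally() would have no effect on an already-completed future, whereas obtrudeException() overwrites the outcome, which is presumably why the Worker uses it. The difference in plain Java (ObtrudeDemo is a hypothetical name):

```java
import java.util.concurrent.CompletableFuture;

class ObtrudeDemo {
    // Returns {completeExceptionally changed a?, a failed?, b failed?}
    static boolean[] afterFailure() {
        RuntimeException commitFailure = new RuntimeException("commit failed");

        CompletableFuture<String> a = new CompletableFuture<>();
        a.complete("saved");
        // Returns false: the future was already completed, so nothing changes.
        boolean changedByComplete = a.completeExceptionally(commitFailure);

        CompletableFuture<String> b = new CompletableFuture<>();
        b.complete("saved");
        // Forcibly replaces the stored result with the exception.
        b.obtrudeException(commitFailure);

        return new boolean[] {changedByComplete, a.isCompletedExceptionally(), b.isCompletedExceptionally()};
    }
}
```

Per the CompletableFuture Javadoc, obtrudeException is intended for error recovery only: anything that already consumed the value (for example, an HTTP response that has been written) will not see the overwrite, so completing the futures after the commit returns would be the safer ordering.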

After confirming with Postman that saving works, run the load test again.

That result looks a bit odd, so let's use more test data: 500,000 requests.

Now, at a concurrency of 500, TPS exceeds 1,200 and every request completes within 1 s.

At this point the requirement is already met. But can we go further? Eventually I found this article:

dzone.com/articles/be…

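It covers MySQL batch inserts, which do not work when the primary key is auto-increment (Hibernate disables JDBC insert batching for IDENTITY-generated ids), so to use this approach you have to generate primary keys yourself. Let's try it. While we are at it, we also optimize JPA's save logic: save is really a save-or-update that internally calls entityInformation.isNew(entity), so we implement our own save method.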
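The code below uses Hutool's Snowflake for IDs. To show what such an ID contains, here is a hypothetical minimal generator, not Hutool's implementation: 41 bits of milliseconds since a custom epoch (chosen arbitrarily here as 2022-01-01 UTC), 10 bits of worker id and a 12-bit per-millisecond sequence. Because the id is assigned in the application before persist(), Hibernate does not need an IDENTITY round trip per row and can batch the INSERTs.

```java
class MiniSnowflake {
    private static final long EPOCH = 1_640_995_200_000L; // 2022-01-01T00:00:00Z, arbitrary choice
    private static final long WORKER_BITS = 10L;
    private static final long SEQUENCE_BITS = 12L;
    private static final long MAX_SEQUENCE = (1L << SEQUENCE_BITS) - 1;

    private final long workerId;
    private long lastMillis = -1L;
    private long sequence = 0L;

    MiniSnowflake(long workerId) {
        if (workerId < 0 || workerId >= (1L << WORKER_BITS)) {
            throw new IllegalArgumentException("workerId out of range");
        }
        this.workerId = workerId;
    }

    synchronized long nextId() {
        long now = System.currentTimeMillis();
        if (now < lastMillis) {
            throw new IllegalStateException("clock moved backwards");
        }
        if (now == lastMillis) {
            sequence = (sequence + 1) & MAX_SEQUENCE;
            if (sequence == 0) {
                // Sequence exhausted for this millisecond: spin until the next one.
                while (now <= lastMillis) {
                    now = System.currentTimeMillis();
                }
            }
        } else {
            sequence = 0L;
        }
        lastMillis = now;
        // Layout: timestamp | worker id | sequence
        return ((now - EPOCH) << (WORKER_BITS + SEQUENCE_BITS)) | (workerId << SEQUENCE_BITS) | sequence;
    }
}
```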

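Remove the primary-key generation strategy (@GeneratedValue) from the id field of the domain class, and change the database column so that it is no longer auto-increment.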

dao/impl/UserDaoImpl.java

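import cn.hutool.core.lang.Snowflake;
import cn.hutool.core.util.IdUtil;
import com.example.demo.domain.User;
import org.springframework.data.jpa.repository.support.SimpleJpaRepository;
import org.springframework.stereotype.Repository;
import org.springframework.util.Assert;
import javax.persistence.EntityManager;
import java.util.List;

// Assumed wiring (not shown in the original): extends SimpleJpaRepository, IDs from Hutool's Snowflake
@Repository
public class UserDaoImpl extends SimpleJpaRepository<User, Long> {
    private final EntityManager em;
    private final Snowflake snowflake = IdUtil.getSnowflake(1, 1);
    public UserDaoImpl(EntityManager em) {
        super(User.class, em);
        this.em = em;
    }

    @Override
    public <S extends User> S save(S entity) {
        Assert.notNull(entity, "Entity must not be null.");
        if (entity.getId() == null) {
            entity.setId(snowflake.nextId());
        }
        return super.save(entity);
    }
    // Call inside a transaction; persist() skips the isNew() check so Hibernate can batch the INSERTs
    public void saveInBatch(List<User> entities) {
        Assert.notNull(entities, "The given Iterable of entities cannot be null!");
        for (User entity : entities) {
            entity.setId(snowflake.nextId());
            this.em.persist(entity);
        }
    }
}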

Add the primary-key generator dependency (Hutool, which provides the Snowflake ID generator) to build.gradle:

implementation 'cn.hutool:hutool-all:5.8.3'

resources/application.yml 

spring:
  jpa:
    properties:
      hibernate:
        jdbc:
          batch_size: 500
          batch_versioned_data: true
        order_inserts: true
        order_updates: true
        generate_statistics: false
  datasource:
    name: datasource
    url: jdbc:mysql://dev.local:3306/demo?autoReconnect=true&useUnicode=true&characterEncoding=UTF-8&zeroDateTimeBehavior=convertToNull
    username: root
    password: yV2jJxvNs8BD
    hikari:
      connection-timeout: 30000
      minimum-idle: 20
      maximum-pool-size: 100
      auto-commit: false
      idle-timeout: 600000
      pool-name: DateSourceHikariCP
      max-lifetime: 1800000
      connection-test-query: SELECT 1
      connection-init-sql: set names utf8mb4
      data-source-properties:
        cachePrepStmts: true
        prepStmtCacheSize: 250
        prepStmtCacheSqlLimit: 2048
        useServerPrepStmts: true
        useLocalSessionState: true
        rewriteBatchedStatements: true
        cacheResultSetMetadata: true
        cacheServerConfiguration: true
        elideSetAutoCommits: true
        maintainTimeStats: false
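The key driver setting here is rewriteBatchedStatements: with it enabled, MySQL Connector/J can send a JDBC batch of identical single-row INSERTs as one multi-row INSERT. The helper below is illustration only (MultiRowInsert is a hypothetical name): it just builds SQL of that shape to show the idea, while the driver performs the real rewrite internally and its exact SQL may differ.

```java
import java.util.Collections;
import java.util.StringJoiner;

class MultiRowInsert {
    // Builds "INSERT INTO t (c1, c2) VALUES (?, ?), (?, ?), ..." for the given row count.
    static String build(String table, String[] columns, int rows) {
        if (columns.length == 0 || rows < 1) {
            throw new IllegalArgumentException("need at least one column and one row");
        }
        String placeholders = "(" + String.join(", ", Collections.nCopies(columns.length, "?")) + ")";
        StringJoiner values = new StringJoiner(", ");
        for (int i = 0; i < rows; i++) {
            values.add(placeholders);
        }
        return "INSERT INTO " + table + " (" + String.join(", ", columns) + ") VALUES " + values;
    }
}
```

Sending 500 rows as one statement instead of 500 statements is what lets hibernate.jdbc.batch_size: 500 pay off on MySQL.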

After verifying with Postman that saving still works, run the load test again to see whether it helps.

It clearly helps. Now raise the test volume to 2,000,000 requests and run it once more.

Now, at a concurrency of 500, TPS exceeds 5,000, and most requests complete within 200 ms.

Results

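All of the load tests above had a 0% error rate. After all this tinkering, single-table inserts reach 5,000 TPS without any database or connection-pool tuning. In theory, the same result should be achievable even with a pool of only 20 connections.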

Here are a few more load-test results:

Full code: demo