Preface
Recently a project came my way: a hospital client needed a daily health-status form for its staff, a simple form covering current address, whether the person has passed through a high-risk area, how they commute, whether their temperature is normal, and a few other routine questions. It is a public hospital with a large workforce, though, and almost everyone submits the form just before the workday starts, so there is a pronounced peak between 7 and 8 a.m. The client's final requirement: the system must handle at least 500 concurrent requests.
Getting started
When I took on the task my first reaction was to laugh it off: 500? Just 500? Surely Spring Boot can manage that. To find out, I downloaded and installed JMeter and used it for the load tests. To keep things simple, we add a single endpoint that implements a basic form-to-database flow.
Base code
First we need a Spring Boot project; generate and download one straight from the official Spring Initializr at start.spring.io/
Here I use Spring Boot 2.7.1, add the web and data dependencies, and generate the code. The package path and project name are left at their defaults for convenience.
After a short wait we add a simple endpoint and an entity, and implement the persistence logic.
Since I prefer the YAML style of configuration, application.properties is replaced with application.yml.
application.yml
spring:
datasource:
name: datasource
url: jdbc:mysql://dev.local:3306/demo?autoReconnect=true&useUnicode=true&characterEncoding=UTF-8&zeroDateTimeBehavior=convertToNull
username: root
password: yV2jJxvNs8BD
Add the MySQL driver and Lombok
build.gradle
implementation 'mysql:mysql-connector-java'
compileOnly "org.projectlombok:lombok:1.18.24"
annotationProcessor "org.projectlombok:lombok:1.18.24"
testCompileOnly "org.projectlombok:lombok:1.18.24"
testAnnotationProcessor "org.projectlombok:lombok:1.18.24"
Next, add the boilerplate code.
domain/User.java
package com.example.demo.domain;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import lombok.*;
import org.springframework.data.annotation.CreatedDate;
import org.springframework.data.annotation.LastModifiedDate;
import javax.persistence.*;
import java.util.Date;
@Getter
@Setter
@ToString
@RequiredArgsConstructor
@Builder
@AllArgsConstructor
@Entity
@Table(name = "DEMO_USER")
@JsonIgnoreProperties({"hibernateLazyInitializer", "handler"})
public class User {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
@Column(name = "ID", nullable = false, updatable = false)
private Long id;
@Column(name = "NICKNAME", length = 30)
private String nickname;
@Column(name = "USERNAME", length = 30)
private String username;
@Column(name = "PASSWORD", length = 21)
private String password;
@CreatedDate
@Temporal(TemporalType.TIMESTAMP)
@Column(updatable = false, name = "CREATED_AT")
private Date createdAt;
@LastModifiedDate
@Temporal(TemporalType.TIMESTAMP)
@Column(name = "UPDATED_AT")
private Date updatedAt;
}
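One note on the entity above: @CreatedDate and @LastModifiedDate are only populated when Spring Data auditing is enabled and the entity is registered with the auditing listener. The original post does not show this, so the following is a minimal sketch, assuming a config class named JpaAuditingConfig plus an extra @EntityListeners(AuditingEntityListener.class) annotation on User; without it the two timestamp columns simply stay null, which does not affect the load test itself.
config/JpaAuditingConfig.java
package com.example.demo.config;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.jpa.repository.config.EnableJpaAuditing;
// Hypothetical config class: enables Spring Data JPA auditing so that
// @CreatedDate / @LastModifiedDate on audited entities are filled in automatically.
@Configuration
@EnableJpaAuditing
public class JpaAuditingConfig {
}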
dao/UserDao.java
package com.example.demo.dao;
import com.example.demo.domain.User;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.stereotype.Repository;
@Repository
public interface UserDao extends JpaRepository<User, Long> {
}
service/UserService.java
package com.example.demo.service;
import com.example.demo.dao.UserDao;
import com.example.demo.domain.User;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@Service
public class UserService {
private final UserDao userDao;
public UserService(UserDao userDao) {
this.userDao = userDao;
}
public Page<User> findPage(Pageable pageable) {
return this.userDao.findAll(pageable);
}
@Transactional
public User save(User user) {
return this.userDao.save(user);
}
}
web/UserController.java
package com.example.demo.web;
import com.example.demo.domain.User;
import com.example.demo.service.UserService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.*;
@Slf4j
@Controller
public class UserController {
private final UserService userService;
public UserController(UserService userService) {
this.userService = userService;
}
@GetMapping("/users")
@ResponseBody
public Page<User> users() {
return this.userService.findPage(Pageable.ofSize(20));
}
@PostMapping("/users")
@ResponseStatus(HttpStatus.CREATED)
@ResponseBody
public User addUser(@RequestBody User user) {
return this.userService.save(user);
}
}
Start the application and hit the POST /users and GET /users endpoints.
Everything works, so now let's use JMeter to hammer the POST /users endpoint with concurrent inserts.
Running the test
The test plan (post_user.jmx) can be generated with Apifox and then executed from the JMeter command line, simulating 100,000 requests at a concurrency of 500:
jmeter -n -t post_user.jmx -l post_user.jtl -e -o post_user_ResultReport
The result is not great: TPS hovers around 300 and every request takes longer than 1 s. So what can be done? Tweak the connection pool, or something else? First, let's analyze why it is slow.
Here x (x*) stands for the client request threads and y (y*) for the database connections. In theory there are as many y* as there are x*, so the focus of optimization is to reduce the number of both x* and y*.
First, the request side. The number of concurrent requests has always been bounded by the container's worker threads, e.g. Tomcat's thread pool. But Servlet 3.0 supports asynchronous servlets, which in theory gives a non-blocking effect similar to Netty.
As for the database connections, the simplest idea is to commit all the data arriving within the same short window to the database in a single transaction, which saves a great many connections.
With that plan in mind, let's get started.
First we rework the database side. To commit requests that arrive at the same time in one transaction, the commit has to run on a separate thread, so we need to add a few classes. A toy sketch of the idea follows, and then the real classes.
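Here is the batching idea in miniature: a toy class (MiniBatcher is a made-up name, not part of the project) where callers enqueue an item and get back a future, while a single background thread collects up to batchSize items and commits them together. The real implementation below does the same thing with its own queue, multiple workers, and a JPA transaction.
package com.example.demo.batch;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.LinkedBlockingQueue;
// Toy version of the batching idea (illustration only, not the real implementation).
public class MiniBatcher<T> {
    private static class Entry<T> {
        final T value;
        final CompletableFuture<T> done = new CompletableFuture<>();
        Entry(T value) { this.value = value; }
    }
    private final LinkedBlockingQueue<Entry<T>> queue = new LinkedBlockingQueue<>();
    private final int batchSize;
    public MiniBatcher(int batchSize) {
        this.batchSize = batchSize;
        Thread worker = new Thread(this::loop, "mini-batcher");
        worker.setDaemon(true);
        worker.start();
    }
    public CompletableFuture<T> submit(T value) {
        Entry<T> e = new Entry<>(value);
        queue.add(e);
        return e.done;
    }
    private void loop() {
        while (true) {
            try {
                List<Entry<T>> batch = new ArrayList<>(batchSize);
                batch.add(queue.take());              // block until at least one item arrives
                queue.drainTo(batch, batchSize - 1);  // then grab whatever else is already waiting
                saveInOneTransaction(batch);          // stand-in for the real @Transactional saveAll
                batch.forEach(e -> e.done.complete(e.value));
            } catch (InterruptedException ex) {
                Thread.currentThread().interrupt();
                return;
            }
        }
    }
    private void saveInOneTransaction(List<Entry<T>> batch) {
        System.out.println("committing " + batch.size() + " rows in one transaction");
    }
}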
The data wrapper (cargo): Cargo.java
package com.example.demo.batch;
import lombok.Getter;
import java.util.concurrent.CompletableFuture;
@Getter
public class Cargo<T, R> {
private final CompletableFuture<R> hearthstone = new CompletableFuture<>();
private final T content;
private Cargo(T o) {
this.content = o;
}
public static <T, R> Cargo<T, R> of(T o) {
return new Cargo<>(o);
}
}
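Cargo simply pairs the submitted payload with a CompletableFuture (whimsically named hearthstone) that the saving thread completes once the batch has been persisted. A tiny illustration of that contract (demo code only; the real call sites are Worker and the reworked UserService below):
package com.example.demo.batch;
import java.util.concurrent.CompletableFuture;
// Tiny demo of the Cargo contract (illustration only).
public class CargoDemo {
    public static void main(String[] args) {
        // Producer side: wrap the payload and keep the future to hand back to the caller.
        Cargo<String, String> cargo = Cargo.of("hello");
        CompletableFuture<String> result = cargo.getHearthstone();
        // Consumer side: after the batch is persisted, complete the future
        // so the waiting request can return.
        cargo.getHearthstone().complete(cargo.getContent().toUpperCase());
        System.out.println(result.join()); // prints HELLO
    }
}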
The data-collecting queue: LinkedQueue.java
package com.example.demo.batch;
import java.util.*;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class LinkedQueue<E> extends AbstractQueue<E> implements BlockingQueue<E>, java.io.Serializable {
private static final long serialVersionUID = -4457362206741191196L;
static class Node<E> {
volatile E item;
Node<E> next;
Node<E> previous;
// Node(E x) { item = x; }  original constructor; the inner-class constructor is modified to carry next/previous links
Node(E item, Node<E> next, Node<E> previous) {
this.item = item;
this.next = next;
this.previous = previous;
}
}
private final int capacity;
private final AtomicInteger count = new AtomicInteger(0);
private transient Node<E> head;
private transient Node<E> last;
private final ReentrantLock takeLock = new ReentrantLock();
private final Condition notEmpty = takeLock.newCondition();
private final ReentrantLock putLock = new ReentrantLock();
private final Condition notFull = putLock.newCondition();
private final List<E> list = new Li();
private void signalNotEmpty() {
this.takeLock.lock();
try {
this.notEmpty.signal();
} finally {
this.takeLock.unlock();
}
}
private void signalNotFull() {
this.putLock.lock();
try {
this.notFull.signal();
} finally {
this.putLock.unlock();
}
}
private void insert(E x) {
last = last.next = new Node<>(x, null, last);
head.previous = last;
}
private E extract() {
Node<E> first = head.next;
head = first;
head.previous = last;
E x = first.item;
first.item = null;
return x;
}
public void fullyLock() {
putLock.lock();
takeLock.lock();
}
public void fullyUnlock() {
takeLock.unlock();
putLock.unlock();
}
public LinkedQueue() {
this(Integer.MAX_VALUE);
}
public LinkedQueue(int capacity) {
if (capacity <= 0) {
throw new IllegalArgumentException();
}
this.capacity = capacity;
last = head = new Node<>(null, null, null);
}
public LinkedQueue(Collection<? extends E> c) {
this(Integer.MAX_VALUE);
this.addAll(c);
}
@Override
public int size() {
return count.get();
}
@Override
public int remainingCapacity() {
return capacity - count.get();
}
@Override
public void put(E o) throws InterruptedException {
if (o == null) {
throw new NullPointerException();
}
int c;
putLock.lockInterruptibly();
try {
try {
while (count.get() == capacity) {
notFull.await();
}
} catch (InterruptedException ie) {
notFull.signal();
throw ie;
}
insert(o);
c = count.getAndIncrement();
if (c + 1 < capacity) {
notFull.signal();
}
} finally {
putLock.unlock();
}
if (c == 0) {
signalNotEmpty();
}
}
@Override
public boolean offer(E o, long timeout, TimeUnit unit) throws InterruptedException {
if (o == null) {
throw new NullPointerException();
}
long nanos = unit.toNanos(timeout);
int c;
final ReentrantLock putLock = this.putLock;
final AtomicInteger count = this.count;
putLock.lockInterruptibly();
try {
for (; ; ) {
if (count.get() < capacity) {
insert(o);
c = count.getAndIncrement();
if (c + 1 < capacity) {
notFull.signal();
}
break;
}
if (nanos <= 0) {
return false;
}
try {
nanos = notFull.awaitNanos(nanos);
} catch (InterruptedException ie) {
notFull.signal();
throw ie;
}
}
} finally {
putLock.unlock();
}
if (c == 0) {
signalNotEmpty();
}
return true;
}
@Override
public boolean offer(E o) {
if (o == null) {
throw new NullPointerException();
}
final AtomicInteger count = this.count;
if (count.get() == capacity) {
return false;
}
int c = -1;
final ReentrantLock putLock = this.putLock;
putLock.lock();
try {
if (count.get() < capacity) {
insert(o);
c = count.getAndIncrement();
if (c + 1 < capacity) {
notFull.signal();
}
}
} finally {
putLock.unlock();
}
if (c == 0) {
signalNotEmpty();
}
return c >= 0;
}
@Override
public E take() throws InterruptedException {
E x;
int c;
final AtomicInteger count = this.count;
final ReentrantLock takeLock = this.takeLock;
takeLock.lockInterruptibly();
try {
try {
while (count.get() == 0) {
notEmpty.await();
}
} catch (InterruptedException ie) {
notEmpty.signal();
throw ie;
}
x = extract();
c = count.getAndDecrement();
if (c > 1) {
notEmpty.signal();
}
} finally {
takeLock.unlock();
}
if (c == capacity) {
signalNotFull();
}
return x;
}
@Override
public E poll(long timeout, TimeUnit unit) throws InterruptedException {
E x;
int c;
long nanos = unit.toNanos(timeout);
final AtomicInteger count = this.count;
final ReentrantLock takeLock = this.takeLock;
takeLock.lockInterruptibly();
try {
for (; ; ) {
if (count.get() > 0) {
x = extract();
c = count.getAndDecrement();
if (c > 1) {
notEmpty.signal();
}
break;
}
if (nanos <= 0) {
return null;
}
try {
log.debug("等待时间:" + nanos + "\t" + TimeUnit.NANOSECONDS.toMillis(nanos));
nanos = notEmpty.awaitNanos(nanos);
log.debug("剩余时间:" + nanos + "\t" + TimeUnit.NANOSECONDS.toMillis(nanos));
} catch (InterruptedException ie) {
notEmpty.signal();
throw ie;
}
}
} finally {
takeLock.unlock();
}
if (c == capacity) {
signalNotFull();
}
return x;
}
@Override
public E poll() {
final AtomicInteger count = this.count;
if (count.get() == 0) {
return null;
}
E x = null;
int c = -1;
final ReentrantLock takeLock = this.takeLock;
takeLock.lock();
try {
if (count.get() > 0) {
x = extract();
c = count.getAndDecrement();
if (c > 1) {
notEmpty.signal();
}
}
} finally {
takeLock.unlock();
}
if (c == capacity) {
signalNotFull();
}
return x;
}
@Override
public E peek() {
if (count.get() == 0) {
return null;
}
final ReentrantLock takeLock = this.takeLock;
takeLock.lock();
try {
Node<E> first = head.next;
if (first == null) {
return null;
} else {
return first.item;
}
} finally {
takeLock.unlock();
}
}
@Override
public boolean remove(Object o) {
if (o == null) {
return false;
}
boolean removed = false;
fullyLock();
try {
Node<E> trail = head;
Node<E> p = head.next;
while (p != null) {
if (o.equals(p.item)) {
removed = true;
break;
}
trail = p;
p = p.next;
}
if (removed) {
p.item = null;
trail.next = p.next;
if (p.next != null) {
p.next.previous = trail;
}
if (count.getAndDecrement() == capacity) {
notFull.signalAll();
}
}
} finally {
fullyUnlock();
}
return removed;
}
@Override
public Object[] toArray() {
fullyLock();
try {
int size = count.get();
Object[] a = new Object[size];
int k = 0;
for (Node<E> p = head.next; p != null; p = p.next) {
a[k++] = p.item;
}
return a;
} finally {
fullyUnlock();
}
}
@Override
public <T> T[] toArray(T[] a) {
fullyLock();
try {
int size = count.get();
if (a.length < size) {
a = (T[]) java.lang.reflect.Array.newInstance(a.getClass().getComponentType(), size);
}
int k = 0;
for (Node<E> p = head.next; p != null; p = p.next) {
a[k++] = (T) p.item;
}
return a;
} finally {
fullyUnlock();
}
}
@Override
public String toString() {
fullyLock();
try {
return super.toString();
} finally {
fullyUnlock();
}
}
@Override
public void clear() {
fullyLock();
try {
head.next = null;
assert head.item == null;
last = head;
if (count.getAndSet(0) == capacity) {
notFull.signalAll();
}
} finally {
fullyUnlock();
}
}
protected void _clear() {
head.next = null;
assert head.item == null;
last = head;
if (count.getAndSet(0) == capacity) {
notFull.signalAll();
}
}
@Override
public int drainTo(Collection<? super E> c) {
if (c == null) {
throw new NullPointerException();
}
if (c == this) {
throw new IllegalArgumentException();
}
Node<E> first;
fullyLock();
try {
first = head.next;
head.next = null;
assert head.item == null;
last = head;
if (count.getAndSet(0) == capacity) {
notFull.signalAll();
}
} finally {
fullyUnlock();
}
int n = 0;
for (Node<E> p = first; p != null; p = p.next) {
c.add(p.item);
p.item = null;
++n;
}
return n;
}
@Override
public int drainTo(Collection<? super E> c, int maxElements) {
if (c == null) {
throw new NullPointerException();
}
if (c == this) {
throw new IllegalArgumentException();
}
fullyLock();
try {
int n = 0;
Node<E> p = head.next;
while (p != null && n < maxElements) {
c.add(p.item);
p.item = null;
p = p.next;
++n;
}
if (n != 0) {
head.next = p;
assert head.item == null;
if (p == null) {
last = head;
}
if (count.getAndAdd(-n) == capacity) {
notFull.signalAll();
}
}
return n;
} finally {
fullyUnlock();
}
}
@Override
@SuppressWarnings("NullableProblems")
public Iterator<E> iterator() {
return new Itr();
}
private class Itr implements Iterator<E> {
private Node<E> current;
private Node<E> lastRet;
private E currentElement;
Itr() {
final ReentrantLock putLock = LinkedQueue.this.putLock;
final ReentrantLock takeLock = LinkedQueue.this.takeLock;
putLock.lock();
takeLock.lock();
try {
current = head.next;
if (current != null) {
currentElement = current.item;
}
} finally {
takeLock.unlock();
putLock.unlock();
}
}
@Override
public boolean hasNext() {
return current != null;
}
@Override
public E next() {
final ReentrantLock putLock = LinkedQueue.this.putLock;
final ReentrantLock takeLock = LinkedQueue.this.takeLock;
putLock.lock();
takeLock.lock();
try {
if (current == null) {
throw new NoSuchElementException();
}
E x = currentElement;
lastRet = current;
current = current.next;
if (current != null) {
currentElement = current.item;
}
return x;
} finally {
takeLock.unlock();
putLock.unlock();
}
}
@Override
public void remove() {
if (lastRet == null) {
throw new IllegalStateException();
}
final ReentrantLock putLock = LinkedQueue.this.putLock;
final ReentrantLock takeLock = LinkedQueue.this.takeLock;
putLock.lock();
takeLock.lock();
try {
Node<E> node = lastRet;
lastRet = null;
Node<E> trail = head;
Node<E> p = head.next;
while (p != null && p != node) {
trail = p;
p = p.next;
}
if (p == node) {
assert p != null;
p.item = null;
trail.next = p.next;
int c = count.getAndDecrement();
if (c == capacity) {
notFull.signalAll();
}
}
} finally {
takeLock.unlock();
putLock.unlock();
}
}
}
private void writeObject(java.io.ObjectOutputStream s) throws java.io.IOException {
fullyLock();
try {
s.defaultWriteObject();
for (Node<E> p = head.next; p != null; p = p.next) {
s.writeObject(p.item);
}
s.writeObject(null);
} finally {
fullyUnlock();
}
}
private void readObject(java.io.ObjectInputStream s) throws java.io.IOException, ClassNotFoundException {
s.defaultReadObject();
count.set(0);
last = head = new Node<>(null, null, null);
for (; ; ) {
E item = (E) s.readObject();
if (item == null) {
break;
}
add(item);
}
}
public E get(int index) {
return entry(index).item;
}
private Node<E> entry(int index) {
int size = count.get();
if (index < 0 || index >= size) {
throw new IndexOutOfBoundsException("Index: " + index + ", Size: " + size);
}
Node<E> e = head;
if (index < (size >> 1)) {
for (int i = 0; i <= index; i++) {
e = e.next;
}
} else {
for (int i = size; i > index; i--) {
e = e.previous;
}
}
return e;
}
public boolean add(int index, E o) {
final AtomicInteger count = this.count;
if (count.get() == capacity) {
return false;
}
int c = -1;
final ReentrantLock putLock = this.putLock;
putLock.lock();
try {
if (count.get() < capacity) {
if (index >= size()) {
this.insert(o);
} else {
Node<E> p = entry(index);
p.previous.next = p.previous = new Node<>(o, p, p.previous);
}
c = count.getAndIncrement();
if (c + 1 < capacity) {
notFull.signal();
}
}
} finally {
putLock.unlock();
}
if (c == 0) {
signalNotEmpty();
}
return c >= 0;
}
public int indexOf(E o) {
int index = 0;
if (o == null) {
for (Node<E> e = head.next; e != head; e = e.next) {
if (e.item == null) {
return index;
}
index++;
}
} else {
for (Node<E> e = head.next; e != head; e = e.next) {
if (o.equals(e.item)) {
return index;
}
index++;
}
}
return -1;
}
public E set(int index, E element) {
Node<E> e = entry(index);
E oldVal = e.item;
e.item = element;
return oldVal;
}
public E remove(int index) {
E oldVal = entry(index).item;
boolean removed = remove(oldVal);
return removed ? oldVal : null;
}
public void poll(long timeout, TimeUnit unit, PollCallBack<E> pollCallBack) throws InterruptedException {
int c = -1;
long nanos = unit.toNanos(timeout);
final AtomicInteger count = this.count;
final ReentrantLock takeLock = this.takeLock;
takeLock.lockInterruptibly();
try {
for (; ; ) {
if (count.get() > 0) {
E x = extract();
c = count.getAndDecrement();
if (pollCallBack.CallBack(x, this.peek())) {
continue;
} else {
break;
}
}
if (nanos <= 0) {
break;
}
try {
nanos = notEmpty.awaitNanos(nanos);
} catch (InterruptedException ie) {
notEmpty.signal();
throw ie;
}
}
} finally {
takeLock.unlock();
}
if (c > 1) {
notEmpty.signal();
}
if (c == capacity) {
signalNotFull();
}
}
public interface PollCallBack<E> {
boolean CallBack(E current, E next);
}
public List<E> list() {
return list;
}
private class Li implements List<E> {
@Override
public boolean add(E o) {
return LinkedQueue.this.add(o);
}
@Override
public void add(int index, E element) {
LinkedQueue.this.add(index, element);
}
@Override
public boolean addAll(Collection<? extends E> c) {
return LinkedQueue.this.addAll(c);
}
@Override
public boolean addAll(int index, Collection<? extends E> c) {
throw new RuntimeException("null method");
}
@Override
public void clear() {
LinkedQueue.this.clear();
}
@Override
public boolean contains(Object o) {
return LinkedQueue.this.contains(o);
}
@Override
public boolean containsAll(Collection<?> c) {
return LinkedQueue.this.containsAll(c);
}
@Override
public E get(int index) {
return LinkedQueue.this.get(index);
}
@Override
public int indexOf(Object o) {
return LinkedQueue.this.indexOf((E) o);
}
@Override
public boolean isEmpty() {
return LinkedQueue.this.isEmpty();
}
@Override
public Iterator<E> iterator() {
return LinkedQueue.this.iterator();
}
@Override
public int lastIndexOf(Object o) {
throw new RuntimeException("null method");
}
@Override
public ListIterator<E> listIterator() {
throw new RuntimeException("null method");
}
@Override
public ListIterator<E> listIterator(int index) {
throw new RuntimeException("null method");
}
@Override
public boolean remove(Object o) {
return LinkedQueue.this.remove(o);
}
@Override
public E remove(int index) {
return LinkedQueue.this.remove(index);
}
@Override
public boolean removeAll(Collection<?> c) {
return LinkedQueue.this.removeAll(c);
}
@Override
public boolean retainAll(Collection<?> c) {
return LinkedQueue.this.retainAll(c);
}
@Override
public E set(int index, E element) {
return LinkedQueue.this.set(index, element);
}
@Override
public int size() {
return LinkedQueue.this.size();
}
@Override
public List<E> subList(int fromIndex, int toIndex) {
throw new RuntimeException("null method");
}
@Override
public Object[] toArray() {
return LinkedQueue.this.toArray();
}
@Override
@SuppressWarnings({"SuspiciousToArrayCall"})
public <T> T[] toArray(T[] a) {
return LinkedQueue.this.toArray(a);
}
}
}
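LinkedQueue is essentially a hand-rolled, doubly linked variant of LinkedBlockingQueue with some extra List-style access and a poll-with-callback; the Worker below only relies on add, take and poll. A quick sketch of the blocking semantics (demo code only):
package com.example.demo.batch;
import java.util.concurrent.TimeUnit;
// Quick demo of the queue's blocking behavior (illustration only).
public class LinkedQueueDemo {
    public static void main(String[] args) throws InterruptedException {
        LinkedQueue<String> queue = new LinkedQueue<>(100);
        queue.add("first");
        queue.put("second");                        // blocks only if the queue is full
        System.out.println(queue.take());           // "first"; blocks if the queue is empty
        System.out.println(queue.poll());           // "second"
        System.out.println(queue.poll(200, TimeUnit.MILLISECONDS)); // null after waiting ~200 ms
    }
}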
The batching worker: Worker.java
package com.example.demo.batch;
import lombok.extern.slf4j.Slf4j;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
@Slf4j
public class Worker<T, R> implements Runnable {
private final LinkedQueue<Cargo<T, R>> queue = new LinkedQueue<>();
private final int batchSize;
private final Consumer<List<Cargo<T, R>>> saver;
public Worker(Consumer<List<Cargo<T, R>>> saver, int batchSize) {
this.saver = saver;
this.batchSize = batchSize;
}
public CompletableFuture<R> add(T o) {
Cargo<T, R> item = Cargo.of(o);
queue.add(item);
return item.getHearthstone();
}
public List<Cargo<T, R>> getItems() {
try {
Cargo<T, R> item = queue.take();
List<Cargo<T, R>> items = new ArrayList<>(batchSize);
items.add(item);
item = queue.poll();
while (item != null) {
items.add(item);
if (items.size() >= batchSize) {
return items;
}
item = queue.poll();
}
return items;
} catch (InterruptedException e) {
log.error(e.getMessage());
throw new RuntimeException(e.getMessage());
}
}
@Override
public void run() {
do {
List<Cargo<T, R>> items = getItems();
save(items);
} while (true);
}
private void save(List<Cargo<T, R>> items) {
try {
saver.accept(items);
} catch (Exception e) {
log.error(e.getMessage(), e);
items.forEach(item -> item.getHearthstone().obtrudeException(e));
}
}
}
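Each Worker owns one LinkedQueue: add enqueues a Cargo and returns its future, while the run loop blocks in getItems until at least one item arrives, drains up to batchSize more, and hands the whole batch to the saver consumer. A minimal standalone sketch of that flow (demo code only; in the project the saver is the transactional saveAll shown later):
package com.example.demo.batch;
import java.util.concurrent.CompletableFuture;
// Standalone demo of the Worker loop (illustration only).
public class WorkerDemo {
    public static void main(String[] args) throws Exception {
        // The "saver": here it just echoes the batch size and completes each future.
        Worker<String, Integer> worker = new Worker<>(cargos -> {
            System.out.println("saving a batch of " + cargos.size() + " item(s)");
            cargos.forEach(c -> c.getHearthstone().complete(c.getContent().length()));
        }, 500);
        Thread t = new Thread(worker, "worker-demo");
        t.setDaemon(true);
        t.start();
        CompletableFuture<Integer> f = worker.add("hello");
        System.out.println("saved, result = " + f.get()); // prints 5
    }
}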
The batch-operation interface: BatchService.java
package com.example.demo.batch;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
public interface BatchService<T, R> {
/**
 * Creates a batch-submit service.
 *
 * @param saver the save callback invoked with each batch
 * @param batchSize the maximum batch size
 * @param works the number of workers
 * @param <T> the entity type
 * @param <R> the result type
 * @return a BatchService<T, R>
 */
static <T, R> BatchService<T, R> create(
Consumer<List<Cargo<T, R>>> saver, int batchSize, int works) {
return new DefaultBatchService<>(saver, batchSize, works);
}
CompletableFuture<R> submit(T entity);
}
The factory: DefaultBatchService.java
package com.example.demo.batch;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
public class DefaultBatchService<T, R> implements BatchService<T, R> {
public AtomicInteger atomicInteger = new AtomicInteger(0);
public ConcurrentHashMap<Integer, Worker<T, R>> cache = new ConcurrentHashMap<>();
private final int workerNumber;
public Executor asyncServiceExecutor() {
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
executor.setCorePoolSize(100);
executor.setMaxPoolSize(100);
executor.setQueueCapacity(999);
executor.setKeepAliveSeconds(30);
executor.setThreadNamePrefix("bath_service");
executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
executor.initialize();
return executor;
}
public DefaultBatchService(Consumer<List<Cargo<T, R>>> saver, int batchSize, int works) {
workerNumber = works;
for (int i = 0; i < workerNumber; i++) {
Worker<T, R> task = new Worker<>(saver, batchSize);
cache.put(i, task);
Executor executor = asyncServiceExecutor();
executor.execute(task);
}
}
public final int getAndIncrement() {
int current;
int next;
do {
current = this.atomicInteger.get();
next = current >= 214748364 ? 0 : current + 1;
} while (!this.atomicInteger.compareAndSet(current, next));
return next;
}
@Override
public CompletableFuture<R> submit(T entity) {
int count = getAndIncrement();
int position = count % workerNumber;
Worker<T, R> queue = cache.get(position);
return queue.add(entity);
}
}
Modify UserController.java
This declares 8 workers (queues); each transaction commits at most 500 records per batch. The 8 also means that only 8 threads handle the save logic at any one time, so only 8 database connections are ever used.
private final BatchService<User,User> batchSaveService;
public UserController(UserService userService) {
this.userService = userService;
this.batchSaveService = BatchService.create(UserController.this.userService::saveAll, 500, 8);
}
The save endpoint then submits to the batch service and returns the resulting future:
public CompletableFuture<User> addUser(@RequestBody User user) {
return this.batchSaveService.submit(user);
}
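Putting the two snippets together, the modified controller looks roughly like this (a sketch assembled from the pieces above; the @Slf4j annotation and the unchanged GET endpoint are omitted). Returning a CompletableFuture lets Spring MVC complete the request asynchronously via Servlet 3.0 async, so the Tomcat thread is released while the batch is being saved.
package com.example.demo.web;
import com.example.demo.batch.BatchService;
import com.example.demo.domain.User;
import com.example.demo.service.UserService;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.*;
import java.util.concurrent.CompletableFuture;
@Controller
public class UserController {
    private final UserService userService;
    private final BatchService<User, User> batchSaveService;
    public UserController(UserService userService) {
        this.userService = userService;
        // 8 workers, at most 500 rows per transaction
        this.batchSaveService = BatchService.create(this.userService::saveAll, 500, 8);
    }
    // The request thread returns immediately; the future is completed
    // by the worker thread once the batch has been committed.
    @PostMapping("/users")
    @ResponseStatus(HttpStatus.CREATED)
    @ResponseBody
    public CompletableFuture<User> addUser(@RequestBody User user) {
        return this.batchSaveService.submit(user);
    }
}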
Modify UserService.java, changing save to saveAll:
@Transactional
public void saveAll(List<Cargo<User, User>> cargos) {
List<User> users = cargos.stream().map(Cargo::getContent).collect(Collectors.toList());
this.userDao.saveAll(users);
for (Cargo<User, User> cargo : cargos) {
User user = cargo.getContent();
cargo.getHearthstone().complete(user);
}
}
After verifying the save logic with Postman, we run the load test again.
The numbers look a bit odd, so let's use more test data and try 500,000 requests.
At a concurrency of 500, TPS now exceeds 1200 and all requests complete within 1 s.
At this point the requirement is already met, but can we squeeze out more? I eventually came across an article on MySQL batched inserts. Batched inserts do not play well with auto-increment primary keys, so to use that approach we have to generate the primary keys ourselves; let's try that too. While we are at it, we also optimize JPA's save logic: save is really saveOrUpdate and internally calls entityInformation.isNew(entity), so we implement our own save method.
This requires removing the @GeneratedValue identity strategy from the entity's id field and altering the table so the ID column is no longer auto-increment.
dao/impl/UserDaoImpl.java
@Override
public <S extends User> S save(S entity) {
Assert.notNull(entity, "Entity must not be null.");
if(entity.getId() == null) {
entity.setId(snowflake.nextId());
}
return super.save(entity);
}
public void saveInBatch(List<User> entities) {
if (entities == null) {
throw new IllegalArgumentException("The given Iterable of entities cannot be null!");
}
entities.forEach(item -> item.setId(snowflake.nextId()));
for (User entity : entities) {
this.em.persist(entity);
}
}
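The snippet above references a snowflake ID generator and an EntityManager that are never declared, and the post does not show the rest of the class. The following is only a sketch of one plausible wiring, assuming the usual Spring Data custom-fragment convention (a custom interface that UserDao also extends, implemented by UserDaoImpl) and extending SimpleJpaRepository so that super.save is available; the field names snowflake and em match the ones used above.
package com.example.demo.dao.impl;
import cn.hutool.core.lang.Snowflake;
import cn.hutool.core.util.IdUtil;
import com.example.demo.domain.User;
import org.springframework.data.jpa.repository.support.SimpleJpaRepository;
import javax.persistence.EntityManager;
// Sketch only: the assumed custom fragment interface (e.g. "UserDaoCustom")
// would declare save(...) and saveInBatch(...); this class provides them.
public class UserDaoImpl extends SimpleJpaRepository<User, Long> {
    // hutool's snowflake generator; the worker/datacenter IDs (1, 1) are placeholders
    private final Snowflake snowflake = IdUtil.getSnowflake(1, 1);
    private final EntityManager em;
    public UserDaoImpl(EntityManager em) {
        super(User.class, em);   // makes super.save(entity) available
        this.em = em;
    }
    // the overridden save(...) and saveInBatch(...) from the snippet above go here
}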
build.gradle: add the primary-key (snowflake ID) generation dependency
implementation 'cn.hutool:hutool-all:5.8.3'
resources/application.yml
spring:
jpa:
properties:
hibernate:
jdbc:
batch_size: 500
batch_versioned_data: true
order_inserts: true
order_updates: true
generate_statistics: false
datasource:
name: datasource
url: jdbc:mysql://dev.local:3306/demo?autoReconnect=true&useUnicode=true&characterEncoding=UTF-8&zeroDateTimeBehavior=convertToNull
username: root
password: yV2jJxvNs8BD
hikari:
connection-timeout: 30000
minimum-idle: 20
maximum-pool-size: 100
auto-commit: false
idle-timeout: 600000
pool-name: DateSourceHikariCP
max-lifetime: 1800000
connection-test-query: SELECT 1
connection-init-sql: set names utf8mb4
data-source-properties:
cachePrepStmts: true
prepStmtCacheSize: 250
prepStmtCacheSqlLimit: 2048
useServerPrepStmts: true
useLocalSessionState: true
rewriteBatchedStatements: true
cacheResultSetMetadata: true
cacheServerConfiguration: true
elideSetAutoCommits: true
maintainTimeStats: false
After checking the save logic with Postman once more, we rerun the load test to see whether this helps.
It clearly does, so let's raise the test volume to 2,000,000 requests and run it again.
At a concurrency of 500, TPS now exceeds 5000 and most requests complete within 200 ms.
Results
The error rate in all of the load tests above was 0. So after all this tinkering, plain single-table inserts reach 5000 TPS with no database or connection-pool tuning at all; in theory the same result could be achieved even if the pool held only 20 database connections.
A few more screenshots of the load-test results:
Full source code: demo