简介
在这篇文章中,我将向你展示我们如何构建一个JPA关联获取验证器,以断定JPA和Hibernate关联是否使用连接或二次查询来获取。
虽然Hibernate没有提供内置的支持,以编程方式检查实体关联获取行为,但API非常灵活,允许我们自定义,这样我们就可以实现这个非简单的要求。
领域模型
让我们假设我们有以下Post,PostComment, 和PostCommentDetails 实体。

Post 父实体看起来如下。
@Entity(name = "Post")
@Table(name = "post")
public class Post {
@Id
private Long id;
private String title;
//Getters and setters omitted for brevity
}
接下来,我们定义PostComment 子实体,像这样。
@Entity(name = "PostComment")
@Table(name = "post_comment")
public class PostComment {
@Id
private Long id;
@ManyToOne
private Post post;
private String review;
//Getters and setters omitted for brevity
}
请注意,post 关联使用@ManyToOne 关联提供的默认获取策略,即臭名昭著的FetchType.EAGER 策略,该策略是导致很多性能问题的原因,在这篇文章中已经解释过。
而PostCommentDetails 子实体定义了一个与PostComment 父实体的一对一的子关联。同样,comment 关联使用默认的FetchType.EAGER 读取策略。
@Entity(name = "PostCommentDetails")
@Table(name = "post_comment_details")
public class PostCommentDetails {
@Id
private Long id;
@OneToOne
@MapsId
@OnDelete(action = OnDeleteAction.CASCADE)
private PostComment comment;
private int votes;
//Getters and setters omitted for brevity
}
FetchType.EAGER策略的问题
所以,我们有两个使用FetchType.EAGER 反模式的关联。因此,当执行下面的JPQL查询时。
List<PostCommentDetails> commentDetailsList = entityManager.createQuery("""
select pcd
from PostCommentDetails pcd
order by pcd.id
""",
PostCommentDetails.class)
.getResultList();
Hibernate会执行以下3个SQL查询。
SELECT
pce.comment_id AS comment_2_2_,
pce.votes AS votes1_2_
FROM
post_comment_details pce
ORDER BY
pce.comment_id
SELECT
pc.id AS id1_1_0_,
pc.post_id AS post_id3_1_0_,
pc.review AS review2_1_0_,
p.id AS id1_0_1_,
p.title AS title2_0_1_
FROM
post_comment pc
LEFT OUTER JOIN
post p ON pc.post_id=p.id
WHERE
pc.id = 1
SELECT
pc.id AS id1_1_0_,
pc.post_id AS post_id3_1_0_,
pc.review AS review2_1_0_,
p.id AS id1_0_1_,
p.title AS title2_0_1_
FROM
post_comment pc
LEFT OUTER JOIN
post p ON pc.post_id=p.id
WHERE
pc.id = 2
这是一个典型的N+1查询问题。然而,不仅执行了额外的二级查询来获取PostComment 关联,而且这些查询还使用JOIN来获取关联的Post 实体。
除非你想用一个查询来加载整个数据库,否则最好避免使用
FetchType.EAGER反模式。
所以,让我们看看是否能以编程方式检测这些额外的二级查询和JOIN。
Hibernate统计数据检测二次查询
正如我在这篇文章中所解释的,Hibernate不仅可以收集统计信息,而且我们甚至可以定制被收集的数据。
例如,我们可以使用下面的SessionStatistics 工具来监控每个会话被取走了多少个实体。
public class SessionStatistics extends StatisticsImpl {
private static final ThreadLocal<Map<Class, AtomicInteger>>
entityFetchCountContext = new ThreadLocal<>();
public SessionStatistics(
SessionFactoryImplementor sessionFactory) {
super(sessionFactory);
}
@Override
public void openSession() {
entityFetchCountContext.set(new LinkedHashMap<>());
super.openSession();
}
@Override
public void fetchEntity(
String entityName) {
Map<Class, AtomicInteger> entityFetchCountMap = entityFetchCountContext
.get();
entityFetchCountMap
.computeIfAbsent(
ReflectionUtils.getClass(entityName),
clazz -> new AtomicInteger()
)
.incrementAndGet();
super.fetchEntity(entityName);
}
@Override
public void closeSession() {
entityFetchCountContext.remove();
super.closeSession();
}
public static int getEntityFetchCount(
String entityClassName) {
return getEntityFetchCount(
ReflectionUtils.getClass(entityClassName)
);
}
public static int getEntityFetchCount(
Class entityClass) {
AtomicInteger entityFetchCount = entityFetchCountContext.get()
.get(entityClass);
return entityFetchCount != null ? entityFetchCount.get() : 0;
}
public static class Factory implements StatisticsFactory {
public static final Factory INSTANCE = new Factory();
@Override
public StatisticsImplementor buildStatistics(
SessionFactoryImplementor sessionFactory) {
return new SessionStatistics(sessionFactory);
}
}
}
SessionStatistics 类扩展了默认的HibernateStatisticsImpl 类并重写了以下方法。
openSession- 这个回调方法在第一次创建Hibernate 时被调用。我们使用这个回调方法来初始化包含实体获取注册表的 存储。SessionThreadLocalfetchEntity- 每当使用二级查询从数据库中获取一个实体时,这个回调就会被调用。并且我们使用这个回调方法来增加实体获取的计数器。closeSession- 当Hibernate 被关闭时,这个回调方法被调用。在我们的案例中,这时我们需要重置 存储。SessionThreadLocal
getEntityFetchCount 方法将允许我们检查对于一个给定的实体类,有多少实体实例被从数据库中取走了。
Factory 嵌套类实现了StatisticsFactory 接口,并实现了buildStatistics 方法,该方法在启动时被SessionFactory 所调用。
为了配置Hibernate来使用自定义的SessionStatistics ,我们必须提供以下两个设置。
properties.put(
AvailableSettings.GENERATE_STATISTICS,
Boolean.TRUE.toString()
);
properties.put(
StatisticsInitiator.STATS_BUILDER,
SessionStatistics.Factory.INSTANCE
);
第一个是激活Hibernate的统计机制,第二个是告诉Hibernate使用自定义的StatisticsFactory 。
那么,让我们看看它的实际效果吧!
assertEquals(0, SessionStatistics.getEntityFetchCount(PostCommentDetails.class));
assertEquals(0, SessionStatistics.getEntityFetchCount(PostComment.class));
assertEquals(0, SessionStatistics.getEntityFetchCount(Post.class));
List<PostCommentDetails> commentDetailsList = entityManager.createQuery("""
select pcd
from PostCommentDetails pcd
order by pcd.id
""",
PostCommentDetails.class)
.getResultList();
assertEquals(2, commentDetailsList.size());
assertEquals(0, SessionStatistics.getEntityFetchCount(PostCommentDetails.class));
assertEquals(2, SessionStatistics.getEntityFetchCount(PostComment.class));
assertEquals(0, SessionStatistics.getEntityFetchCount(Post.class));
所以,SessionStatistics 只能帮助我们确定额外的二级查询,但对于因为FetchType.EAGER 关联而执行的额外 JOIN,它不起作用。
Hibernate事件监听器同时检测二次查询和额外的JOINs
幸运的是,Hibernate是非常可定制的,因为在内部,它是建立在观察者模式之上的。
每个实体动作都会产生一个事件,由事件监听器来处理,我们可以利用这个机制来监控实体的获取行为。
当一个实体被直接使用find 方法或通过查询来获取时,一个LoadEvent 将被触发。LoadEvent ,首先由LoadEventListener 和PostLoadEventListener Hibernate事件处理程序来处理。
虽然Hibernate为所有实体事件提供了默认的事件处理程序,但我们也可以使用Integrator ,预先添加或追加我们自己的监听器,比如下面这个。
public class AssociationFetchingEventListenerIntegrator
implements Integrator {
public static final AssociationFetchingEventListenerIntegrator INSTANCE =
new AssociationFetchingEventListenerIntegrator();
@Override
public void integrate(
Metadata metadata,
SessionFactoryImplementor sessionFactory,
SessionFactoryServiceRegistry serviceRegistry) {
final EventListenerRegistry eventListenerRegistry =
serviceRegistry.getService(EventListenerRegistry.class);
eventListenerRegistry.prependListeners(
EventType.LOAD,
AssociationFetchPreLoadEventListener.INSTANCE
);
eventListenerRegistry.appendListeners(
EventType.LOAD,
AssociationFetchLoadEventListener.INSTANCE
);
eventListenerRegistry.appendListeners(
EventType.POST_LOAD,
AssociationFetchPostLoadEventListener.INSTANCE
);
}
@Override
public void disintegrate(
SessionFactoryImplementor sessionFactory,
SessionFactoryServiceRegistry serviceRegistry) {
}
}
我们的AssociationFetchingEventListenerIntegrator 注册了三个额外的事件监听器。
- 一个
AssociationFetchPreLoadEventListener,在默认的Hibernate之前执行。LoadEventListener - 一个在默认的Hibernate之后执行的
AssociationFetchLoadEventListenerLoadEventListener - 和一个在默认的Hibernate之后执行的
AssociationFetchPostLoadEventListener。PostLoadEventListener
为了指示Hibernate使用我们的自定义AssociationFetchingEventListenerIntegrator ,以注册额外的事件监听器,我们只需要设置hibernate.integrator_provider 配置属性。
properties.put(
"hibernate.integrator_provider",
(IntegratorProvider) () -> Collections.singletonList(
AssociationFetchingEventListenerIntegrator.INSTANCE
)
);
AssociationFetchPreLoadEventListener 实现了LoadEventListener 接口,看起来像这样。
public class AssociationFetchPreLoadEventListener
implements LoadEventListener {
public static final AssociationFetchPreLoadEventListener INSTANCE =
new AssociationFetchPreLoadEventListener();
@Override
public void onLoad(
LoadEvent event,
LoadType loadType) {
AssociationFetch.Context
.get(event.getSession())
.preLoad(event);
}
}
AssociationFetchLoadEventListener 也实现了LoadEventListener 接口,看起来如下。
public class AssociationFetchLoadEventListener
implements LoadEventListener {
public static final AssociationFetchLoadEventListener INSTANCE =
new AssociationFetchLoadEventListener();
@Override
public void onLoad(
LoadEvent event,
LoadType loadType) {
AssociationFetch.Context
.get(event.getSession())
.load(event);
}
}
而且,AssociationFetchPostLoadEventListener 实现了PostLoadEventListener 接口,看起来像这样。
public class AssociationFetchPostLoadEventListener
implements PostLoadEventListener {
public static final AssociationFetchPostLoadEventListener INSTANCE =
new AssociationFetchPostLoadEventListener();
@Override
public void onPostLoad(
PostLoadEvent event) {
AssociationFetch.Context
.get(event.getSession())
.postLoad(event);
}
}
请注意,所有的实体获取监控逻辑都被封装在以下AssociationFetch 类中。
public class AssociationFetch {
private final Object entity;
public AssociationFetch(Object entity) {
this.entity = entity;
}
public Object getEntity() {
return entity;
}
public static class Context implements Serializable {
public static final String SESSION_PROPERTY_KEY = "ASSOCIATION_FETCH_LIST";
private Map<String, Integer> entityFetchCountByClassNameMap =
new LinkedHashMap<>();
private Set<EntityIdentifier> joinedFetchedEntities =
new LinkedHashSet<>();
private Set<EntityIdentifier> secondaryFetchedEntities =
new LinkedHashSet<>();
private Map<EntityIdentifier, Object> loadedEntities =
new LinkedHashMap<>();
public List<AssociationFetch> getAssociationFetches() {
List<AssociationFetch> associationFetches = new ArrayList<>();
for(Map.Entry<EntityIdentifier, Object> loadedEntityMapEntry :
loadedEntities.entrySet()) {
EntityIdentifier entityIdentifier = loadedEntityMapEntry.getKey();
Object entity = loadedEntityMapEntry.getValue();
if(joinedFetchedEntities.contains(entityIdentifier) ||
secondaryFetchedEntities.contains(entityIdentifier)) {
associationFetches.add(new AssociationFetch(entity));
}
}
return associationFetches;
}
public List<AssociationFetch> getJoinedAssociationFetches() {
List<AssociationFetch> associationFetches = new ArrayList<>();
for(Map.Entry<EntityIdentifier, Object> loadedEntityMapEntry :
loadedEntities.entrySet()) {
EntityIdentifier entityIdentifier = loadedEntityMapEntry.getKey();
Object entity = loadedEntityMapEntry.getValue();
if(joinedFetchedEntities.contains(entityIdentifier)) {
associationFetches.add(new AssociationFetch(entity));
}
}
return associationFetches;
}
public List<AssociationFetch> getSecondaryAssociationFetches() {
List<AssociationFetch> associationFetches = new ArrayList<>();
for(Map.Entry<EntityIdentifier, Object> loadedEntityMapEntry :
loadedEntities.entrySet()) {
EntityIdentifier entityIdentifier = loadedEntityMapEntry.getKey();
Object entity = loadedEntityMapEntry.getValue();
if(secondaryFetchedEntities.contains(entityIdentifier)) {
associationFetches.add(new AssociationFetch(entity));
}
}
return associationFetches;
}
public Map<Class, List<Object>> getAssociationFetchEntityMap() {
return getAssociationFetches()
.stream()
.map(AssociationFetch::getEntity)
.collect(groupingBy(Object::getClass));
}
public void preLoad(LoadEvent loadEvent) {
String entityClassName = loadEvent.getEntityClassName();
entityFetchCountByClassNameMap.put(
entityClassName,
SessionStatistics.getEntityFetchCount(
entityClassName
)
);
}
public void load(LoadEvent loadEvent) {
String entityClassName = loadEvent.getEntityClassName();
int previousFetchCount = entityFetchCountByClassNameMap.get(
entityClassName
);
int currentFetchCount = SessionStatistics.getEntityFetchCount(
entityClassName
);
EntityIdentifier entityIdentifier = new EntityIdentifier(
ReflectionUtils.getClass(loadEvent.getEntityClassName()),
loadEvent.getEntityId()
);
if (loadEvent.isAssociationFetch()) {
if (currentFetchCount == previousFetchCount) {
joinedFetchedEntities.add(entityIdentifier);
} else if (currentFetchCount > previousFetchCount){
secondaryFetchedEntities.add(entityIdentifier);
}
}
}
public void postLoad(PostLoadEvent postLoadEvent) {
loadedEntities.put(
new EntityIdentifier(
postLoadEvent.getEntity().getClass(),
postLoadEvent.getId()
),
postLoadEvent.getEntity()
);
}
public static Context get(Session session) {
Context context = (Context) session.getProperties()
.get(SESSION_PROPERTY_KEY);
if (context == null) {
context = new Context();
session.setProperty(SESSION_PROPERTY_KEY, context);
}
return context;
}
public static Context get(EntityManager entityManager) {
return get(entityManager.unwrap(Session.class));
}
}
private static class EntityIdentifier {
private final Class entityClass;
private final Serializable entityId;
public EntityIdentifier(Class entityClass, Serializable entityId) {
this.entityClass = entityClass;
this.entityId = entityId;
}
public Class getEntityClass() {
return entityClass;
}
public Serializable getEntityId() {
return entityId;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (!(o instanceof EntityIdentifier)) return false;
EntityIdentifier that = (EntityIdentifier) o;
return Objects.equals(getEntityClass(), that.getEntityClass()) &&
Objects.equals(getEntityId(), that.getEntityId());
}
@Override
public int hashCode() {
return Objects.hash(getEntityClass(), getEntityId());
}
}
}
而且,就是这样!
测试时间
那么,让我们看看这个新工具是如何工作的。当运行本文开始时使用的同一个查询时,我们可以看到,我们现在可以捕获所有在执行JPQL查询时进行的关联获取。
AssociationFetch.Context context = AssociationFetch.Context.get(
entityManager
);
assertTrue(context.getAssociationFetches().isEmpty());
List<PostCommentDetails> commentDetailsList = entityManager.createQuery("""
select pcd
from PostCommentDetails pcd
order by pcd.id
""",
PostCommentDetails.class)
.getResultList();
assertEquals(3, context.getAssociationFetches().size());
assertEquals(2, context.getSecondaryAssociationFetches().size());
assertEquals(1, context.getJoinedAssociationFetches().size());
Map<Class, List<Object>> associationFetchMap = context
.getAssociationFetchEntityMap();
assertEquals(2, associationFetchMap.size());
for (PostCommentDetails commentDetails : commentDetailsList) {
assertTrue(
associationFetchMap.get(PostComment.class)
.contains(commentDetails.getComment())
);
assertTrue(
associationFetchMap.get(Post.class)
.contains(commentDetails.getComment().getPost())
);
}
该工具告诉我们,又有3个实体被该查询取走。
- 2个
PostComment实体使用了两个二级查询 - 一个
Post实体,它是通过二次查询的JOIN子句获取的。
如果我们重写之前的查询,对所有这3个关联使用JOIN FETCH来代替。
AssociationFetch.Context context = AssociationFetch.Context.get(
entityManager
);
assertTrue(context.getAssociationFetches().isEmpty());
List<PostCommentDetails> commentDetailsList = entityManager.createQuery("""
select pcd
from PostCommentDetails pcd
join fetch pcd.comment pc
join fetch pc.post
order by pcd.id
""",
PostCommentDetails.class)
.getResultList();
assertEquals(3, context.getJoinedAssociationFetches().size());
assertTrue(context.getSecondaryAssociationFetches().isEmpty());
我们可以看到,这次确实没有执行二级SQL查询,而这3个关联是使用JOIN子句获取的。
很酷,对吗?
结论
使用Hibernate ORM可以很好地构建一个JPA关联获取验证器,因为API提供了许多扩展点。
如果你喜欢这个JPA关联获取验证器工具,那么你一定会喜欢Hypersistence Optizier,它承诺提供数十种检查和验证,这样你就可以从Spring Boot或Jakarta EE应用程序中获得最大收益。