How to build a JPA Association Fetching Validator

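Introduction

In this article, I'm going to show you how we can build a JPA Association Fetching Validator that asserts whether JPA and Hibernate associations are fetched using JOINs or secondary queries.

While Hibernate does not provide built-in support for checking the entity association fetching behavior programmatically, its API is flexible enough to let us customize it so that we can implement this non-trivial requirement.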

Domain Model

Let's assume we have the following Post, PostComment, and PostCommentDetails entities:

[Figure: JPA Association Fetching Validator entities]

The Post parent entity looks as follows:

@Entity(name = "Post")
@Table(name = "post")
public class Post {

    @Id
    private Long id;

    private String title;

    //Getters and setters omitted for brevity
}

Next, we define the PostComment child entity, like this:

@Entity(name = "PostComment")
@Table(name = "post_comment")
public class PostComment {

    @Id
    private Long id;

    @ManyToOne
    private Post post;

    private String review;

    //Getters and setters omitted for brevity
}

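Notice that the post association uses the default fetching strategy provided by the @ManyToOne association, the infamous FetchType.EAGER strategy, which is responsible for a lot of performance issues, as explained in another article.

The PostCommentDetails child entity defines a one-to-one child association to the PostComment parent entity. Again, the comment association uses the default FetchType.EAGER fetching strategy.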

@Entity(name = "PostCommentDetails")
@Table(name = "post_comment_details")
public class PostCommentDetails {

    @Id
    private Long id;

    @OneToOne
    @MapsId
    @OnDelete(action = OnDeleteAction.CASCADE)
    private PostComment comment;

    private int votes;

    //Getters and setters omitted for brevity
}

The problem with the FetchType.EAGER strategy

So, we have two associations using the FetchType.EAGER anti-pattern. Therefore, when executing the following JPQL query:

List<PostCommentDetails> commentDetailsList = entityManager.createQuery("""
    select pcd
    from PostCommentDetails pcd
    order by pcd.id
    """,
    PostCommentDetails.class)
.getResultList();

Hibernate executes the following 3 SQL queries:

SELECT 
    pce.comment_id AS comment_2_2_,
    pce.votes AS votes1_2_
FROM 
    post_comment_details pce
ORDER BY 
    pce.comment_id

SELECT 
    pc.id AS id1_1_0_,
    pc.post_id AS post_id3_1_0_,
    pc.review AS review2_1_0_,
    p.id AS id1_0_1_,
    p.title AS title2_0_1_
FROM 
    post_comment pc
LEFT OUTER JOIN 
    post p ON pc.post_id=p.id
WHERE 
    pc.id = 1
    
SELECT 
    pc.id AS id1_1_0_,
    pc.post_id AS post_id3_1_0_,
    pc.review AS review2_1_0_,
    p.id AS id1_0_1_,
    p.title AS title2_0_1_
FROM 
    post_comment pc
LEFT OUTER JOIN 
    post p ON pc.post_id=p.id
WHERE 
    pc.id = 2

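This is a classic N+1 query issue. However, not only are extra secondary queries executed to fetch the PostComment associations, but these queries also use JOINs to fetch the associated Post entity.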
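To make the statement count concrete, here is a minimal plain-Java sketch that only simulates the statements listed above. It does not involve Hibernate at all, and the NPlusOneSimulation class and its statementsFor method are hypothetical names introduced for illustration. The point it shows is that one root query plus one secondary query per returned row gives N+1 statements.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical simulation, not Hibernate code: it only lists the statements
// an eagerly fetched association would trigger for the given row identifiers.
class NPlusOneSimulation {

    static List<String> statementsFor(List<Long> detailIds) {
        List<String> statements = new ArrayList<>();

        // 1 root query for the PostCommentDetails rows
        statements.add(
            "SELECT ... FROM post_comment_details ORDER BY comment_id"
        );

        // N secondary queries, one per row, each joining the post table
        for (Long id : detailIds) {
            statements.add(
                "SELECT ... FROM post_comment pc " +
                "LEFT OUTER JOIN post p ON pc.post_id = p.id " +
                "WHERE pc.id = " + id
            );
        }
        return statements;
    }

    public static void main(String[] args) {
        statementsFor(List.of(1L, 2L)).forEach(System.out::println);
    }
}
```

With two PostCommentDetails rows, the list holds three statements, matching the three SQL queries shown earlier. With 100 rows it would hold 101.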

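Unless you want to load the entire database with a single query, it's best to avoid the FetchType.EAGER anti-pattern.

So, let's see if we can detect these extra secondary queries and JOINs programmatically.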

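Hibernate Statistics to detect secondary queries

As I explained in another article, not only can Hibernate collect statistical information, but we can even customize the data that gets collected.

For instance, we could monitor how many entities have been fetched per Session using the following SessionStatistics utility: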

public class SessionStatistics extends StatisticsImpl {

    private static final ThreadLocal<Map<Class, AtomicInteger>> 
        entityFetchCountContext = new ThreadLocal<>();

    public SessionStatistics(
            SessionFactoryImplementor sessionFactory) {
        super(sessionFactory);
    }

    @Override
    public void openSession() {
        entityFetchCountContext.set(new LinkedHashMap<>());
        
        super.openSession();
    }

    @Override
    public void fetchEntity(
            String entityName) {
        Map<Class, AtomicInteger> entityFetchCountMap = entityFetchCountContext
            .get();
        
        entityFetchCountMap
            .computeIfAbsent(
                ReflectionUtils.getClass(entityName), 
                clazz -> new AtomicInteger()
            )
            .incrementAndGet();
        
        super.fetchEntity(entityName);
    }

    @Override
    public void closeSession() {
        entityFetchCountContext.remove();
        
        super.closeSession();
    }

    public static int getEntityFetchCount(
            String entityClassName) {        
        return getEntityFetchCount(
            ReflectionUtils.getClass(entityClassName)
        );
    }

    public static int getEntityFetchCount(
            Class entityClass) {
        AtomicInteger entityFetchCount = entityFetchCountContext.get()
            .get(entityClass);
            
        return entityFetchCount != null ? entityFetchCount.get() : 0;
    }

    public static class Factory implements StatisticsFactory {

        public static final Factory INSTANCE = new Factory();

        @Override
        public StatisticsImplementor buildStatistics(
                SessionFactoryImplementor sessionFactory) {
            return new SessionStatistics(sessionFactory);
        }
    }
}

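The SessionStatistics class extends the default Hibernate StatisticsImpl class and overrides the following methods:

  • openSession - this callback method is called when a Hibernate Session is created for the first time. We use this callback to initialize the ThreadLocal storage that holds the entity fetching registry.
  • fetchEntity - this callback is called whenever an entity is fetched from the database using a secondary query. We use this callback method to increase the entity fetching counter.
  • closeSession - this callback method is called when a Hibernate Session is closed. In our case, this is when we need to reset the ThreadLocal storage.

The getEntityFetchCount method allows us to check how many entity instances have been fetched from the database for a given entity class.

The Factory nested class implements the StatisticsFactory interface and its buildStatistics method, which is called by the SessionFactory at bootstrap time.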
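The open, count, and reset lifecycle described above can be sketched in isolation, without any Hibernate types. The ThreadLocalFetchCounter class below is a hypothetical stand-in, not part of the original utility. It assumes entity names are plain strings, and, unlike the original getEntityFetchCount, it returns 0 when no registry is bound to the current thread.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical stand-in for the ThreadLocal bookkeeping in SessionStatistics
class ThreadLocalFetchCounter {

    private static final ThreadLocal<Map<String, AtomicInteger>> COUNTS =
        new ThreadLocal<>();

    // Mirrors openSession: bind an empty registry to the current thread
    static void open() {
        COUNTS.set(new LinkedHashMap<>());
    }

    // Mirrors fetchEntity: increment the counter of the given entity name
    static void record(String entityName) {
        Map<String, AtomicInteger> counts = COUNTS.get();
        if (counts == null) {
            throw new IllegalStateException(
                "open() was not called on this thread"
            );
        }
        counts
            .computeIfAbsent(entityName, name -> new AtomicInteger())
            .incrementAndGet();
    }

    // Mirrors getEntityFetchCount, returning 0 when no registry is bound
    static int count(String entityName) {
        Map<String, AtomicInteger> counts = COUNTS.get();
        if (counts == null) {
            return 0;
        }
        AtomicInteger counter = counts.get(entityName);
        return counter != null ? counter.get() : 0;
    }

    // Mirrors closeSession: drop the registry bound to the current thread
    static void close() {
        COUNTS.remove();
    }

    public static void main(String[] args) {
        open();
        record("PostComment");
        record("PostComment");
        System.out.println(count("PostComment"));
        System.out.println(count("Post"));
        close();
    }
}
```

Because the registry lives in a ThreadLocal, this approach assumes that a Session is opened, used, and closed on the same thread.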

To configure Hibernate to use the custom SessionStatistics, we have to provide the following two settings:

properties.put(
    AvailableSettings.GENERATE_STATISTICS,
    Boolean.TRUE.toString()
);
properties.put(
    StatisticsInitiator.STATS_BUILDER,
    SessionStatistics.Factory.INSTANCE
);

The first setting activates the Hibernate statistics mechanism, while the second one tells Hibernate to use the custom StatisticsFactory.

So, let's see it in action!

assertEquals(0, SessionStatistics.getEntityFetchCount(PostCommentDetails.class));
assertEquals(0, SessionStatistics.getEntityFetchCount(PostComment.class));
assertEquals(0, SessionStatistics.getEntityFetchCount(Post.class));

List<PostCommentDetails> commentDetailsList = entityManager.createQuery("""
    select pcd
    from PostCommentDetails pcd
    order by pcd.id
    """,
    PostCommentDetails.class)
.getResultList();

assertEquals(2, commentDetailsList.size());

assertEquals(0, SessionStatistics.getEntityFetchCount(PostCommentDetails.class));
assertEquals(2, SessionStatistics.getEntityFetchCount(PostComment.class));
assertEquals(0, SessionStatistics.getEntityFetchCount(Post.class));

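So, the SessionStatistics can only help us determine the extra secondary queries. It doesn't work for the extra JOINs that are executed because of FetchType.EAGER associations, which is why the Post fetch count is still 0.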

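Hibernate Event Listeners to detect both secondary queries and extra JOINs

Luckily for us, Hibernate is extremely customizable since, internally, it's built on top of the Observer pattern.

Every entity action generates an event that's handled by an event listener, and we can use this mechanism to monitor the entity fetching behavior.

When an entity is fetched directly using the find method or via a query, a LoadEvent is triggered. The LoadEvent is handled by the LoadEventListener and PostLoadEventListener Hibernate event handlers.

While Hibernate provides default event handlers for all entity events, we can also prepend or append our own listeners using an Integrator, like the following one: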

public class AssociationFetchingEventListenerIntegrator 
        implements Integrator {

    public static final AssociationFetchingEventListenerIntegrator INSTANCE = 
        new AssociationFetchingEventListenerIntegrator();

    @Override
    public void integrate(
            Metadata metadata,
            SessionFactoryImplementor sessionFactory,
            SessionFactoryServiceRegistry serviceRegistry) {

        final EventListenerRegistry eventListenerRegistry =
            serviceRegistry.getService(EventListenerRegistry.class);

        eventListenerRegistry.prependListeners(
            EventType.LOAD,
            AssociationFetchPreLoadEventListener.INSTANCE
        );

        eventListenerRegistry.appendListeners(
            EventType.LOAD,
            AssociationFetchLoadEventListener.INSTANCE
        );

        eventListenerRegistry.appendListeners(
            EventType.POST_LOAD,
            AssociationFetchPostLoadEventListener.INSTANCE
        );
    }

    @Override
    public void disintegrate(
            SessionFactoryImplementor sessionFactory,
            SessionFactoryServiceRegistry serviceRegistry) {
    }
}

Our AssociationFetchingEventListenerIntegrator registers three extra event listeners:

  • an AssociationFetchPreLoadEventListener that is executed before the default Hibernate LoadEventListener
  • an AssociationFetchLoadEventListener that is executed after the default Hibernate LoadEventListener
  • and an AssociationFetchPostLoadEventListener that is executed after the default Hibernate PostLoadEventListener
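The ordering in the list above is what makes the before-and-after comparison possible. The following plain-Java sketch illustrates the prepend and append semantics with a simplified listener group. ListenerOrderSketch and its methods are hypothetical and only mimic the idea behind prependListeners and appendListeners. They are not the Hibernate EventListenerRegistry API.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

// Hypothetical, simplified listener group for a single event type
class ListenerOrderSketch {

    private final List<Consumer<List<String>>> listeners = new ArrayList<>();

    // Registers a listener that runs before all existing ones
    void prepend(Consumer<List<String>> listener) {
        listeners.add(0, listener);
    }

    // Registers a listener that runs after all existing ones
    void append(Consumer<List<String>> listener) {
        listeners.add(listener);
    }

    // Fires the event and returns the order in which listeners ran
    List<String> fire() {
        List<String> log = new ArrayList<>();
        for (Consumer<List<String>> listener : listeners) {
            listener.accept(log);
        }
        return log;
    }

    static List<String> demo() {
        ListenerOrderSketch loadGroup = new ListenerOrderSketch();
        // The default handler is already registered when we integrate
        loadGroup.append(log -> log.add("default load"));
        // Our listeners wrap the default one
        loadGroup.prepend(log -> log.add("snapshot fetch count"));
        loadGroup.append(log -> log.add("compare fetch count"));
        return loadGroup.fire();
    }

    public static void main(String[] args) {
        demo().forEach(System.out::println);
    }
}
```

The prepended listener records the fetch count before the default handler loads the entity, and the appended one reads it again afterwards, so any difference can be attributed to that load.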

To instruct Hibernate to use our custom AssociationFetchingEventListenerIntegrator that registers the extra event listeners, we just have to set the hibernate.integrator_provider configuration property:

properties.put(
    "hibernate.integrator_provider", 
    (IntegratorProvider) () -> Collections.singletonList(
        AssociationFetchingEventListenerIntegrator.INSTANCE
    )
);

The AssociationFetchPreLoadEventListener implements the LoadEventListener interface and looks like this:

public class AssociationFetchPreLoadEventListener 
        implements LoadEventListener {

    public static final AssociationFetchPreLoadEventListener INSTANCE = 
        new AssociationFetchPreLoadEventListener();

    @Override
    public void onLoad(
            LoadEvent event, 
            LoadType loadType) {
        AssociationFetch.Context
            .get(event.getSession())
            .preLoad(event);
    }
}

The AssociationFetchLoadEventListener also implements the LoadEventListener interface and looks as follows:

public class AssociationFetchLoadEventListener 
        implements LoadEventListener {

    public static final AssociationFetchLoadEventListener INSTANCE = 
        new AssociationFetchLoadEventListener();

    @Override
    public void onLoad(
            LoadEvent event, 
            LoadType loadType) {
        AssociationFetch.Context
            .get(event.getSession())
            .load(event);
    }
}

And the AssociationFetchPostLoadEventListener implements the PostLoadEventListener interface and looks like this:

public class AssociationFetchPostLoadEventListener 
        implements PostLoadEventListener {

    public static final AssociationFetchPostLoadEventListener INSTANCE = 
        new AssociationFetchPostLoadEventListener();

    @Override
    public void onPostLoad(
            PostLoadEvent event) {
        AssociationFetch.Context
            .get(event.getSession())
            .postLoad(event);
    }
}

Notice that all the entity fetching monitoring logic is encapsulated in the following AssociationFetch class:

public class AssociationFetch {

    private final Object entity;

    public AssociationFetch(Object entity) {
        this.entity = entity;
    }

    public Object getEntity() {
        return entity;
    }

    public static class Context implements Serializable {
        public static final String SESSION_PROPERTY_KEY = "ASSOCIATION_FETCH_LIST";

        private Map<String, Integer> entityFetchCountByClassNameMap = 
            new LinkedHashMap<>();

        private Set<EntityIdentifier> joinedFetchedEntities = 
            new LinkedHashSet<>();

        private Set<EntityIdentifier> secondaryFetchedEntities = 
            new LinkedHashSet<>();

        private Map<EntityIdentifier, Object> loadedEntities = 
            new LinkedHashMap<>();

        public List<AssociationFetch> getAssociationFetches() {
            List<AssociationFetch> associationFetches = new ArrayList<>();

            for(Map.Entry<EntityIdentifier, Object> loadedEntityMapEntry : 
                    loadedEntities.entrySet()) {
                EntityIdentifier entityIdentifier = loadedEntityMapEntry.getKey();
                Object entity = loadedEntityMapEntry.getValue();
                
                if(joinedFetchedEntities.contains(entityIdentifier) ||
                   secondaryFetchedEntities.contains(entityIdentifier)) {
                    associationFetches.add(new AssociationFetch(entity));
                }
            }

            return associationFetches;
        }

        public List<AssociationFetch> getJoinedAssociationFetches() {
            List<AssociationFetch> associationFetches = new ArrayList<>();

            for(Map.Entry<EntityIdentifier, Object> loadedEntityMapEntry : 
                    loadedEntities.entrySet()) {
                EntityIdentifier entityIdentifier = loadedEntityMapEntry.getKey();
                Object entity = loadedEntityMapEntry.getValue();
                
                if(joinedFetchedEntities.contains(entityIdentifier)) {
                    associationFetches.add(new AssociationFetch(entity));
                }
            }

            return associationFetches;
        }

        public List<AssociationFetch> getSecondaryAssociationFetches() {
            List<AssociationFetch> associationFetches = new ArrayList<>();

            for(Map.Entry<EntityIdentifier, Object> loadedEntityMapEntry : 
                    loadedEntities.entrySet()) {
                EntityIdentifier entityIdentifier = loadedEntityMapEntry.getKey();
                Object entity = loadedEntityMapEntry.getValue();
                if(secondaryFetchedEntities.contains(entityIdentifier)) {
                    associationFetches.add(new AssociationFetch(entity));
                }
            }

            return associationFetches;
        }

        public Map<Class, List<Object>> getAssociationFetchEntityMap() {
            return getAssociationFetches()
                .stream()
                .map(AssociationFetch::getEntity)
                .collect(groupingBy(Object::getClass));
        }

        public void preLoad(LoadEvent loadEvent) {
            String entityClassName = loadEvent.getEntityClassName();
            entityFetchCountByClassNameMap.put(
                entityClassName, 
                SessionStatistics.getEntityFetchCount(
                    entityClassName
                )
            );
        }

        public void load(LoadEvent loadEvent) {
            String entityClassName = loadEvent.getEntityClassName();
            int previousFetchCount = entityFetchCountByClassNameMap.get(
                entityClassName
            );
            int currentFetchCount = SessionStatistics.getEntityFetchCount(
                entityClassName
            );

            EntityIdentifier entityIdentifier = new EntityIdentifier(
                ReflectionUtils.getClass(loadEvent.getEntityClassName()),
                loadEvent.getEntityId()
            );

            if (loadEvent.isAssociationFetch()) {
                if (currentFetchCount == previousFetchCount) {
                    joinedFetchedEntities.add(entityIdentifier);
                } else if (currentFetchCount > previousFetchCount){
                    secondaryFetchedEntities.add(entityIdentifier);
                }
            }
        }

        public void postLoad(PostLoadEvent postLoadEvent) {
            loadedEntities.put(
                new EntityIdentifier(
                    postLoadEvent.getEntity().getClass(),
                    postLoadEvent.getId()
                ),
                postLoadEvent.getEntity()
            );
        }

        public static Context get(Session session) {
            Context context = (Context) session.getProperties()
                .get(SESSION_PROPERTY_KEY);
                
            if (context == null) {
                context = new Context();
                session.setProperty(SESSION_PROPERTY_KEY, context);
            }
            return context;
        }

        public static Context get(EntityManager entityManager) {
            return get(entityManager.unwrap(Session.class));
        }
    }

    private static class EntityIdentifier {
        private final Class entityClass;

        private final Serializable entityId;

        public EntityIdentifier(Class entityClass, Serializable entityId) {
            this.entityClass = entityClass;
            this.entityId = entityId;
        }

        public Class getEntityClass() {
            return entityClass;
        }

        public Serializable getEntityId() {
            return entityId;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (!(o instanceof EntityIdentifier)) return false;
            EntityIdentifier that = (EntityIdentifier) o;
            return Objects.equals(getEntityClass(), that.getEntityClass()) && 
                   Objects.equals(getEntityId(), that.getEntityId());
        }

        @Override
        public int hashCode() {
            return Objects.hash(getEntityClass(), getEntityId());
        }
    }
}
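The core decision made by the load method above can be isolated into a small pure function. The sketch below is a simplified restatement of that branch, assuming the two fetch counts were captured before and after the default load handler ran. FetchClassifierSketch, Kind, and classify are hypothetical names that do not appear in the original class.

```java
// Hypothetical restatement of the branch inside Context.load
class FetchClassifierSketch {

    enum Kind { JOINED, SECONDARY, NOT_TRACKED }

    static Kind classify(
            boolean associationFetch,
            int previousFetchCount,
            int currentFetchCount) {
        // Only loads triggered by an association are of interest
        if (!associationFetch) {
            return Kind.NOT_TRACKED;
        }
        // Unchanged counter: no secondary query ran, so the entity
        // must have come from a JOIN in an already executed query
        if (currentFetchCount == previousFetchCount) {
            return Kind.JOINED;
        }
        // Increased counter: the statistics recorded a secondary query
        if (currentFetchCount > previousFetchCount) {
            return Kind.SECONDARY;
        }
        // A decreased counter is ignored, as in the original branch
        return Kind.NOT_TRACKED;
    }

    public static void main(String[] args) {
        System.out.println(classify(true, 0, 0));
        System.out.println(classify(true, 0, 1));
        System.out.println(classify(false, 0, 1));
    }
}
```

This is why the event listeners still depend on SessionStatistics: the statistics counter is the only signal that separates a secondary query from a JOIN.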

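And that's it!

Testing Time

So, let's see how this new utility works. When running the same query used at the beginning of this article, we can see that we can now capture all the association fetches that were done while executing the JPQL query: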

AssociationFetch.Context context = AssociationFetch.Context.get(
    entityManager
);
assertTrue(context.getAssociationFetches().isEmpty());

List<PostCommentDetails> commentDetailsList = entityManager.createQuery("""
    select pcd
    from PostCommentDetails pcd
    order by pcd.id
    """,
    PostCommentDetails.class)
.getResultList();

assertEquals(3, context.getAssociationFetches().size());
assertEquals(2, context.getSecondaryAssociationFetches().size());
assertEquals(1, context.getJoinedAssociationFetches().size());

Map<Class, List<Object>> associationFetchMap = context
    .getAssociationFetchEntityMap();
    
assertEquals(2, associationFetchMap.size());

for (PostCommentDetails commentDetails : commentDetailsList) {
    assertTrue(
        associationFetchMap.get(PostComment.class)
            .contains(commentDetails.getComment())
    );
    assertTrue(
        associationFetchMap.get(Post.class)
            .contains(commentDetails.getComment().getPost())
    );
}

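The tool tells us that 3 more entities were fetched by that query:

  • 2 PostComment entities, fetched using two secondary queries
  • one Post entity, fetched using a JOIN clause by the secondary queries

If we rewrite the previous query to use JOIN FETCH instead for all these 3 associations: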

AssociationFetch.Context context = AssociationFetch.Context.get(
    entityManager
);
assertTrue(context.getAssociationFetches().isEmpty());

List<PostCommentDetails> commentDetailsList = entityManager.createQuery("""
    select pcd
    from PostCommentDetails pcd
    join fetch pcd.comment pc
    join fetch pc.post
    order by pcd.id
    """,
    PostCommentDetails.class)
.getResultList();

assertEquals(3, context.getJoinedAssociationFetches().size());
assertTrue(context.getSecondaryAssociationFetches().isEmpty());

We can see that, indeed, no secondary SQL query is executed this time, and the 3 associations are fetched using JOIN clauses.

Cool, right?
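The assertions above could be wrapped into a reusable check that fails as soon as any secondary fetch is detected. The following is only a sketch: FetchValidatorSketch and assertNoSecondaryFetches are hypothetical names, and the method takes a plain list so that it stays independent of Hibernate. In the test above, the argument would be the result of context.getSecondaryAssociationFetches().

```java
import java.util.List;

// Hypothetical helper that turns the detection result into a hard failure
class FetchValidatorSketch {

    static void assertNoSecondaryFetches(List<?> secondaryFetches) {
        if (!secondaryFetches.isEmpty()) {
            throw new IllegalStateException(
                "Detected " + secondaryFetches.size() +
                " association(s) fetched via secondary queries"
            );
        }
    }

    public static void main(String[] args) {
        // Passes silently because nothing was fetched by a secondary query
        assertNoSecondaryFetches(List.of());
        System.out.println("no secondary fetches");
    }
}
```

Throwing an exception rather than logging is a design choice that suits integration tests, where an unexpected secondary query should break the build.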

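Conclusion

Building a JPA Association Fetching Validator can be done just fine with Hibernate ORM since its API provides many extension points.

If you like this JPA Association Fetching Validator tool, then you are going to love Hypersistence Optimizer, which promises tens of checks and validations so that you can get the most out of your Spring Boot or Jakarta EE application.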