二叉树左右值新增删除节点工具类

176 阅读8分钟

该工具类用于在采用左右值编码(即嵌套集合模型 Nested Set,也称预排序遍历树)存储的树结构中新增和删除节点时,同步更新各节点的左右值。注意:虽然文中称为“二叉树”,但该模型并不限制每个节点的子节点数量。

image.png


目前个人测试通过,代码分为两个类:
BinaryTreeUtil:用于操作二叉树结构节点的新增删除和查询。
BinaryTreeConfig:用于配置树结构,例如指定左值字段名称、顶层ID等,可按需自行扩展配置项。


设计思想:

  • 1:根据一个配置类,来指定二叉树的结构,左值字段,右值字段,id字段等;
  • 2:工具类使用反射机制,对需要添加的对象进行字段的获取,校验,获取配置类中定义的字段;
  • 3:工具类通过 MyBatis-Plus 的 mapper 进行数据查询,并借助 QueryWrapper 拼接查询条件;
  • 4:全程不绑定具体的实体类,使用 T、R、P 等泛型表示,字段的读取和赋值均通过反射完成;

话不多说,上代码:

BinaryTreeConfig:

import lombok.Data;
import lombok.experimental.Accessors;

import java.io.Serializable;

/**
 * Configuration describing how a left/right-value (nested set) tree is mapped onto an entity:
 * the names of the id, parent-id, left-value and right-value fields, and the parent-id value that
 * marks top-level (root) elements.
 *
 * @param <T> type of the id values
 * @author a-xin
 * @date 2024/5/27 10:36
 */
@Data
@Accessors(chain = true)
public class BinaryTreeConfig<T> implements Serializable {

    private static final long serialVersionUID = -6306806910289247396L;

    /**
     * Name of the primary-key field.
     */
    private String id = "id";

    /**
     * Name of the parent-id field.
     */
    private String pId = "pId";

    /**
     * Parent-id value carried by top-level elements. It must be set, otherwise an error is raised
     * when elements are added.
     */
    private T topId;

    /**
     * Name of the left-value field.
     */
    private String lft = "lft";

    /**
     * Name of the right-value field.
     */
    private String rgt = "rgt";

    /**
     * Returns the default configuration for {@code String} ids, whose top id is {@code "-1"}.
     *
     * @return default configuration with the default field names and top id {@code "-1"}
     */
    public static BinaryTreeConfig<String> getDefaultConfig() {
        return of("-1");
    }

    /**
     * Creates a configuration with the default field names ({@code id}, {@code pId}, {@code lft},
     * {@code rgt}) and the given top id. Use this for id types other than {@code String},
     * e.g. {@code BinaryTreeConfig.of(-1L)}.
     *
     * @param topId parent-id value of top-level elements; must not be {@code null}
     * @param <T>   type of the id values
     * @return a new configuration
     * @throws NullPointerException if {@code topId} is {@code null}
     */
    public static <T> BinaryTreeConfig<T> of(T topId) {
        BinaryTreeConfig<T> config = new BinaryTreeConfig<>();
        config.setTopId(java.util.Objects.requireNonNull(topId, "topId"));
        return config;
    }

}

BinaryTreeUtil:

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.text.CharSequenceUtil;
import cn.hutool.core.util.ObjectUtil;
import com.axin229913.constant.util.AssertUtil;
import com.axin229913.exception.enums.CommonExceptionEnum;
import com.axin229913.exception.exceptions.CommonException;
import com.axin229913.mybatis.base.SuperMapper;
import com.axin229913.mybatis.binaryTree.config.BinaryTreeConfig;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.core.toolkit.IdWorker;
import lombok.extern.slf4j.Slf4j;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.stereotype.Component;

import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.*;

/**
 * Utility that maintains the left/right values of a tree stored with the left/right-value (nested set)
 * model: all descendants of a node have left and right values strictly between the node's own left and
 * right values.
 *
 * <p>Entities are accessed through reflection using the field names configured in {@link BinaryTreeConfig};
 * persistence goes through the supplied mapper.
 *
 * <p>NOTE(review): {@link #addElement} and {@link #deleteElement} issue several separate writes and take no
 * locks. Callers should run them inside one transaction and avoid concurrent modifications of the same tree,
 * otherwise the left/right values can become inconsistent.
 *
 * @author a-xin
 * @date 2024/5/24 16:58
 */
@Slf4j
@Component
public class BinaryTreeUtil {

    /**
     * Returns all descendants of an element.
     *
     * @param id     element id
     * @param mapper mapper used to query the data
     * @param config tree configuration
     * @param <M>    mapper type
     * @param <R>    entity type of the mapper
     * @param <P>    id type
     * @return all descendants, excluding the element itself
     */
    public static <M extends SuperMapper<R>, R, P> List<R> getAllChild(@NonNull P id,
                                                                       @NonNull M mapper,
                                                                       @NonNull BinaryTreeConfig<P> config) {
        return getAllChild(id, mapper, config, new HashMap<>());
    }

    /**
     * Returns all descendants of an element.
     *
     * @param id     element id
     * @param mapper mapper used to query the data
     * @param config tree configuration
     * @param params extra equality conditions (column name -> value) added to the descendant query
     * @param <M>    mapper type
     * @param <R>    entity type of the mapper
     * @param <P>    id type
     * @return all descendants, excluding the element itself
     */
    public static <M extends SuperMapper<R>, R, P> List<R> getAllChild(@NonNull P id,
                                                                       @NonNull M mapper,
                                                                       @NonNull BinaryTreeConfig<P> config,
                                                                       @Nullable Map<String, Object> params) {
        Map<String, Object> elementFieldMap = getElementField(selectRequiredById(mapper, config, id));
        checkConfigField(elementFieldMap, config);

        // descendants lie strictly inside (lft, rgt) of the element
        QueryWrapper<R> queryWrapper = new QueryWrapper<R>()
                .lt(config.getRgt(), elementFieldMap.get(config.getRgt()))
                .gt(config.getLft(), elementFieldMap.get(config.getLft()));
        applyParams(queryWrapper, params);
        return mapper.selectList(queryWrapper);
    }

    /**
     * Returns all ancestors of an element.
     *
     * @param id     element id
     * @param mapper mapper used to query the data
     * @param config tree configuration
     * @param <M>    mapper type
     * @param <R>    entity type of the mapper
     * @param <P>    id type
     * @return all ancestors, excluding the element itself
     */
    public static <M extends SuperMapper<R>, R, P> List<R> getAllParent(@NonNull P id,
                                                                        @NonNull M mapper,
                                                                        @NonNull BinaryTreeConfig<P> config) {
        return getAllParent(id, mapper, config, new HashMap<>());
    }

    /**
     * Returns all ancestors of an element.
     *
     * @param id     element id
     * @param mapper mapper used to query the data
     * @param config tree configuration
     * @param params extra equality conditions (column name -> value) added to the ancestor query
     * @param <M>    mapper type
     * @param <R>    entity type of the mapper
     * @param <P>    id type
     * @return all ancestors, excluding the element itself
     */
    public static <M extends SuperMapper<R>, R, P> List<R> getAllParent(@NonNull P id,
                                                                        @NonNull M mapper,
                                                                        @NonNull BinaryTreeConfig<P> config,
                                                                        @Nullable Map<String, Object> params) {
        Map<String, Object> elementFieldMap = getElementField(selectRequiredById(mapper, config, id));
        checkConfigField(elementFieldMap, config);

        // ancestors strictly enclose (lft, rgt) of the element
        QueryWrapper<R> queryWrapper = new QueryWrapper<R>()
                .gt(config.getRgt(), elementFieldMap.get(config.getRgt()))
                .lt(config.getLft(), elementFieldMap.get(config.getLft()));
        applyParams(queryWrapper, params);
        return mapper.selectList(queryWrapper);
    }

    /**
     * Returns the depth of an element in the tree.
     *
     * @param id     element id
     * @param mapper mapper used to query the data
     * @param config tree configuration
     * @param <M>    mapper type
     * @param <R>    entity type of the mapper
     * @param <P>    id type
     * @return depth of the element
     */
    public static <M extends SuperMapper<R>, R, P> Long getLevel(@NonNull P id,
                                                                 @NonNull M mapper,
                                                                 @NonNull BinaryTreeConfig<P> config) {
        return getLevel(id, mapper, config, new HashMap<>());
    }

    /**
     * Returns the depth of an element in the tree: the number of elements whose [lft, rgt] range encloses
     * (or equals) the element's range. If {@code id} equals the configured top id, {@code 1} is returned
     * without querying.
     *
     * @param id     element id
     * @param mapper mapper used to query the data
     * @param config tree configuration
     * @param params extra equality conditions (column name -> value) added to the count query
     * @param <M>    mapper type
     * @param <R>    entity type of the mapper
     * @param <P>    id type
     * @return depth of the element
     */
    public static <M extends SuperMapper<R>, R, P> Long getLevel(@NonNull P id,
                                                                 @NonNull M mapper,
                                                                 @NonNull BinaryTreeConfig<P> config,
                                                                 @Nullable Map<String, Object> params) {
        if (ObjectUtil.equal(id, config.getTopId())) {
            return 1L;
        }

        Map<String, Object> elementFieldMap = getElementField(selectRequiredById(mapper, config, id));
        checkConfigField(elementFieldMap, config);

        // counts the element itself plus every ancestor enclosing it
        QueryWrapper<R> queryWrapper = new QueryWrapper<R>()
                .ge(config.getRgt(), elementFieldMap.get(config.getRgt()))
                .le(config.getLft(), elementFieldMap.get(config.getLft()));
        applyParams(queryWrapper, params);
        return mapper.selectCount(queryWrapper);
    }

    /**
     * Adds an element to the tree.
     *
     * @param mapper  mapper used to read and write the data
     * @param element element to add; must carry its parent id
     * @param config  tree configuration
     * @param <M>     mapper type
     * @param <R>     entity type
     * @param <P>     id type
     * @return the inserted element
     */
    public static <M extends SuperMapper<R>, R, P> R addElement(@NonNull M mapper,
                                                                @NonNull R element,
                                                                @NonNull BinaryTreeConfig<P> config) {
        return addElement(mapper, element, config, new HashMap<>());
    }

    /**
     * Adds an element to the tree as the last child of its parent, shifting the left/right values of the
     * affected elements. If the parent id equals the configured top id, the element is inserted as a root
     * via {@link #addTopElement}. The element's id is overwritten with a newly generated one.
     *
     * @param mapper  mapper used to read and write the data
     * @param element element to add; must carry its parent id
     * @param config  tree configuration
     * @param params  extra equality conditions (column name -> value) added to the parent and shift queries
     * @param <M>     mapper type
     * @param <R>     entity type
     * @param <P>     id type
     * @return the inserted element
     */
    public static <M extends SuperMapper<R>, R, P> R addElement(@NonNull M mapper,
                                                                @NonNull R element,
                                                                @NonNull BinaryTreeConfig<P> config,
                                                                @Nullable Map<String, Object> params) {

        Map<String, Object> elementFieldMap = getElementField(element);

        String idField = config.getId();
        String lftField = config.getLft();
        String rgtField = config.getRgt();

        Object pId = elementFieldMap.get(config.getPId());

        // a parent id equal to the top id marks a new root element
        if (ObjectUtil.equal(pId, config.getTopId())) {
            return addTopElement(mapper, element, config);
        }

        checkConfigField(elementFieldMap, config);

        QueryWrapper<R> parentQueryWrapper = new QueryWrapper<R>().eq(idField, pId);
        applyParams(parentQueryWrapper, params);

        List<R> pIdElementList = mapper.selectList(parentQueryWrapper);
        CommonExceptionEnum.NORMAL_EXCEPTION.exceptionIf(CollUtil.isEmpty(pIdElementList) || pIdElementList.size() > 1,
                "The parent element is abnormal!");

        R parentElement = pIdElementList.get(0);
        Map<String, Object> parentElementFieldMap = getElementField(parentElement);
        checkConfigField(parentElementFieldMap, config);
        int parentRgt = toInt(parentElementFieldMap, rgtField);

        // The new element becomes the parent's last child: it takes over the parent's current right value,
        // i.e. [parentRgt, parentRgt + 1], whether or not the parent already has children.
        setProperty(element, lftField, parentRgt);
        setProperty(element, rgtField, parentRgt + 1);
        setProperty(element, idField, IdWorker.getIdStr());
        setProperty(parentElement, rgtField, parentRgt + 2);

        // Every other element with a right value beyond the parent's old right value makes room for 2 values.
        // (lft > parentRgt implies rgt > parentRgt, so filtering on the right value alone is enough; the parent
        // itself has rgt == parentRgt in the database and is therefore not selected.)
        QueryWrapper<R> queryWrapper = new QueryWrapper<R>().gt(rgtField, parentRgt);
        applyParams(queryWrapper, params);
        List<R> updateList = shiftElements(mapper.selectList(queryWrapper), parentRgt, 2, lftField, rgtField);

        // update the left/right values of the other elements
        if (!updateList.isEmpty()) {
            mapper.updateBatch(updateList);
        }
        // update the parent's right value
        mapper.updateById(parentElement);
        // insert the new element
        mapper.insert(element);

        return element;
    }

    /**
     * Deletes a leaf element from the tree.
     *
     * @param mapper  mapper used to read and write the data
     * @param element element to delete
     * @param config  tree configuration
     * @param <M>     mapper type
     * @param <R>     entity type
     * @param <P>     id type
     */
    public static <M extends SuperMapper<R>, R, P> void deleteElement(@NonNull M mapper,
                                                                      @NonNull R element,
                                                                      @NonNull BinaryTreeConfig<P> config) {
        deleteElement(mapper, element, config, new HashMap<>());
    }

    /**
     * Deletes a leaf element from the tree and closes the gap it leaves in the left/right values.
     * Elements that still have children are rejected before anything is deleted.
     *
     * <p>The element's left/right values are read from the passed object, so it should reflect the current
     * database state.
     *
     * @param mapper  mapper used to read and write the data
     * @param element element to delete
     * @param config  tree configuration
     * @param params  extra equality conditions (column name -> value) added to the shift query
     * @param <M>     mapper type
     * @param <R>     entity type
     * @param <P>     id type
     */
    public static <M extends SuperMapper<R>, R, P> void deleteElement(@NonNull M mapper,
                                                                      @NonNull R element,
                                                                      @NonNull BinaryTreeConfig<P> config,
                                                                      @Nullable Map<String, Object> params) {

        Map<String, Object> elementFieldMap = getElementField(element);

        String lftField = config.getLft();
        String rgtField = config.getRgt();

        checkConfigField(elementFieldMap, config);

        int rgt = toInt(elementFieldMap, rgtField);
        int lft = toInt(elementFieldMap, lftField);

        // validate BEFORE deleting, so a rejected request leaves the tree untouched
        CommonExceptionEnum.NORMAL_EXCEPTION.exceptionIf(rgt - lft > 1,
                "There are subclasses below the node and cannot be deleted!");

        mapper.deleteById(element);

        // every element whose right value lies beyond the deleted leaf closes the 2-value gap
        QueryWrapper<R> queryWrapper = new QueryWrapper<R>().gt(rgtField, rgt);
        applyParams(queryWrapper, params);
        List<R> updateList = shiftElements(mapper.selectList(queryWrapper), rgt, -2, lftField, rgtField);

        // update the left/right values of the other elements
        if (!updateList.isEmpty()) {
            mapper.updateBatch(updateList);
        }
    }


    /**
     * Inserts a top-level (root) element: its parent id is set to the configured top id, a new id is generated,
     * and its left/right values are set to 1 and 2.
     *
     * @param mapper  mapper used to write the data
     * @param element element to add
     * @param config  tree configuration; its top id must be set
     * @param <M>     mapper type
     * @param <R>     entity type
     * @param <P>     id type
     * @return the inserted element
     */
    public static <M extends SuperMapper<R>, R, P> R addTopElement(@NonNull M mapper,
                                                                   @NonNull R element,
                                                                   @NonNull BinaryTreeConfig<P> config) {
        CommonExceptionEnum.NORMAL_EXCEPTION.exceptionIf(config.getTopId() == null,
                "The topId of the binary tree config must be set!");

        setProperty(element, config.getPId(), config.getTopId());
        setProperty(element, config.getId(), IdWorker.getIdStr());
        setProperty(element, config.getLft(), 1);
        setProperty(element, config.getRgt(), 2);

        // insert the new element
        mapper.insert(element);
        return element;
    }

    /**
     * Looks up an element by id and fails if it does not exist.
     */
    private static <R> R selectRequiredById(SuperMapper<R> mapper, BinaryTreeConfig<?> config, Object id) {
        R element = mapper.selectOne(new QueryWrapper<R>().eq(config.getId(), id));
        CommonExceptionEnum.NORMAL_EXCEPTION.exceptionIf(element == null,
                "The element [{}] does not exist!", id);
        return element;
    }

    /**
     * Adds each entry of {@code params} as an equality condition; a null or empty map adds nothing.
     */
    private static <R> void applyParams(QueryWrapper<R> queryWrapper, @Nullable Map<String, Object> params) {
        if (params != null && !params.isEmpty()) {
            params.forEach(queryWrapper::eq);
        }
    }

    /**
     * Shifts the left/right values of the candidates around {@code pivot} by {@code delta}, in memory only.
     * Candidates entirely to the right of the pivot move both values; candidates enclosing the pivot move only
     * their right value; candidates with a right value not beyond the pivot are left alone.
     *
     * @return the candidates that were modified and must be persisted
     */
    private static <R> List<R> shiftElements(@Nullable List<R> candidates, int pivot, int delta,
                                             String lftField, String rgtField) {
        List<R> updateList = new ArrayList<>();
        if (candidates == null) {
            return updateList;
        }
        for (R candidate : candidates) {
            Map<String, Object> candidateFieldMap = getElementField(candidate);
            int candidateLft = toInt(candidateFieldMap, lftField);
            int candidateRgt = toInt(candidateFieldMap, rgtField);

            if (candidateRgt <= pivot) {
                continue;
            }
            if (candidateLft > pivot) {
                // entirely to the right: both values move
                setProperty(candidate, lftField, candidateLft + delta);
                setProperty(candidate, rgtField, candidateRgt + delta);
                updateList.add(candidate);
            } else if (candidateLft < pivot) {
                // encloses the pivot (an ancestor): only the right value moves
                setProperty(candidate, rgtField, candidateRgt + delta);
                updateList.add(candidate);
            }
        }
        return updateList;
    }

    /**
     * Reads a left/right value from a field map as an int, failing clearly when it is missing.
     */
    private static int toInt(Map<String, Object> fieldMap, String fieldName) {
        Object value = fieldMap.get(fieldName);
        if (value == null) {
            throw new CommonException("The [" + fieldName + "] value of the element is empty!");
        }
        if (value instanceof Number) {
            return ((Number) value).intValue();
        }
        return Integer.parseInt(String.valueOf(value));
    }

    /**
     * Assigns {@code value} to {@code target.fieldName} through the public setter {@code setFieldName}, whose
     * parameter type is taken from the declared field type. The value is converted to that type when it is a
     * String/Integer/Long mismatch (e.g. a generated String id assigned to a Long id field).
     */
    private static void setProperty(Object target, String fieldName, @Nullable Object value) {
        Field field = findField(target.getClass(), fieldName);
        if (field == null) {
            throw new CommonException("The [" + fieldName + "] field does not exist on " + target.getClass().getName());
        }
        Class<?> fieldType = field.getType();
        try {
            Method setter = target.getClass().getMethod("set" + CharSequenceUtil.upperFirst(fieldName), fieldType);
            // kept from the original: allows invoking setters declared on non-public entity classes
            setter.setAccessible(Boolean.TRUE);
            setter.invoke(target, convertValue(value, fieldType));
        } catch (InvocationTargetException e) {
            // getMessage() of InvocationTargetException is null; report the setter's own exception instead
            throw new CommonException(String.valueOf(e.getTargetException()));
        } catch (NoSuchMethodException | IllegalAccessException e) {
            throw new CommonException(e.getMessage());
        }
    }

    /**
     * Converts {@code value} to {@code targetType} for the String/Integer/Long (and int/long) cases;
     * any other combination is returned unchanged.
     */
    @Nullable
    private static Object convertValue(@Nullable Object value, Class<?> targetType) {
        if (value == null || targetType.isInstance(value)) {
            return value;
        }
        if (targetType == String.class) {
            return String.valueOf(value);
        }
        if (targetType == Integer.class || targetType == int.class) {
            return value instanceof Number ? ((Number) value).intValue() : Integer.valueOf(String.valueOf(value));
        }
        if (targetType == Long.class || targetType == long.class) {
            return value instanceof Number ? ((Number) value).longValue() : Long.valueOf(String.valueOf(value));
        }
        return value;
    }

    /**
     * Finds a non-static field by name on {@code type} or any of its superclasses (subclass first).
     *
     * @return the field, or {@code null} if none exists
     */
    @Nullable
    private static Field findField(Class<?> type, String fieldName) {
        for (Class<?> current = type; current != null && current != Object.class; current = current.getSuperclass()) {
            for (Field field : current.getDeclaredFields()) {
                if (!Modifier.isStatic(field.getModifiers()) && field.getName().equals(fieldName)) {
                    return field;
                }
            }
        }
        return null;
    }

    /**
     * Reads the non-static fields of an object, including those inherited from superclasses (e.g. an id on a
     * base entity), into a name -> value map. A subclass field hides a superclass field of the same name.
     *
     * @param element object to read
     * @param <R>     object type
     * @return field names mapped to their current values
     */
    private static <R> Map<String, Object> getElementField(@NonNull R element) {
        Map<String, Object> elementFieldMap = new HashMap<>();
        try {
            for (Class<?> current = element.getClass(); current != null && current != Object.class; current = current.getSuperclass()) {
                for (Field field : current.getDeclaredFields()) {
                    if (Modifier.isStatic(field.getModifiers()) || elementFieldMap.containsKey(field.getName())) {
                        continue;
                    }
                    field.setAccessible(Boolean.TRUE);
                    elementFieldMap.put(field.getName(), field.get(element));
                }
            }
            return elementFieldMap;
        } catch (IllegalAccessException e) {
            throw new IllegalStateException("Cannot read the fields of " + element.getClass().getName(), e);
        }
    }

    /**
     * Verifies that the element declares every field named in the configuration (id, parent id, left, right).
     *
     * @param elementFieldMap fields of the element
     * @param config          field-name configuration
     */
    private static <P> void checkConfigField(@NonNull Map<String, Object> elementFieldMap,
                                             @NonNull BinaryTreeConfig<P> config) {
        String id = config.getId();
        String pId = config.getPId();
        String lft = config.getLft();
        String rgt = config.getRgt();

        CommonExceptionEnum.NORMAL_EXCEPTION.exceptionIf(!elementFieldMap.containsKey(id),
                "The [{}] field does not exist for the element to be added!", id);
        CommonExceptionEnum.NORMAL_EXCEPTION.exceptionIf(!elementFieldMap.containsKey(pId),
                "The [{}] field does not exist for the element to be added!", pId);
        CommonExceptionEnum.NORMAL_EXCEPTION.exceptionIf(!elementFieldMap.containsKey(lft),
                "The [{}] field does not exist for the element to be added!", lft);
        CommonExceptionEnum.NORMAL_EXCEPTION.exceptionIf(!elementFieldMap.containsKey(rgt),
                "The [{}] field does not exist for the element to be added!", rgt);

    }

}

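代码中的 SuperMapper 可以替换成自己定义的 mapper 类:继承 MyBatis-Plus 的 BaseMapper,并额外提供批量更新方法 updateBatch 即可(BaseMapper 本身不包含该方法),这里不贴代码。另外,新增和删除节点会分多次更新多行数据的左右值,建议在调用方的事务中执行(例如 @Transactional),并避免对同一棵树并发修改。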

操作指南:

1:代码: image.png

2:数据库: image.png

3:进行数据添加: image.png

4:数据库查看:

image.png

如有问题,请联系:QQxin7045