Spring Data Neo4j实现知识图谱的效果

134 阅读5分钟

Spring Data Neo4j实现知识图谱的效果

官网:spring.io/projects/sp…

快速集成

Docker部署neo4j:

 version: '3.8'
 ​
 services:
   neo4j:
     image: neo4j
     container_name: neo4j
     restart: always
     ports:
       - "7474:7474"
       - "7687:7687"
     volumes:
       - ./data:/data
       - ./plugins:/plugins
     environment:
       - NEO4J_apoc_export_file_enabled=true
       - NEO4J_apoc_import_file_enabled=true
       - NEO4J_apoc_import_file_use__neo4j__config=true
       - NEO4JLABS_PLUGINS=["apoc"]
       - NEO4J_dbms_security_procedures_unrestricted=apoc.*

SpringBoot中使用neo4j,添加依赖:

 <dependency>
     <groupId>org.springframework.boot</groupId>
     <artifactId>spring-boot-starter-data-neo4j</artifactId>
 </dependency>

导入neo4j配置信息:

 # Neo4j
 spring.neo4j.uri=neo4j://localhost:7687
 spring.neo4j.authentication.username=neo4j
 spring.neo4j.authentication.password=xxx

实体关系定义

  • @Node注解:这个注解标识一个类作为Neo4j数据库中的一个节点实体。在Neo4j中,节点通常表示一个实体对象,每个节点可以有属性和关系。
  • @Id:该注解标识一个字段为节点的唯一标识符,类似于数据库中的主键。Neo4j中的每个节点都可以有一个唯一的标识符,用来区分不同的节点。@GeneratedValue 表示这个ID会自动生成,通常是在插入节点时由Neo4j自动生成。
  • @Property:这个注解用于将类中的字段映射到Neo4j节点的属性。它指定了该字段在Neo4j节点中的属性名。
  • @Relationship 注解用来定义Neo4j节点之间的关系。它描述了当前节点与其他节点的关系类型和方向。

type:指定关系的类型,在这里是 scope,表示该关系的名称。

direction:指定关系的方向。Relationship.Direction.OUTGOING 表示当前节点到目标节点的方向,Relationship.Direction.INCOMING 表示目标节点到当前节点的方向。OUTGOING 表示从当前节点指向其他节点。

例子:

 /**
  * 创作一个对应 KnowledgeEntity 实体对象 -> 对应我们 Neo4j数据库中的 Node 对象
  */
 @Data
 @Node("KnowledgePoint")
 public class KnowledgePointNode {
     @GeneratedValue
     @Id
     private Long id;
 ​
     @Property("name")
     private String name;
 ​
     @Property("description")
     private String description;
 ​
     @Relationship(type = "scope", direction = Relationship.Direction.OUTGOING)
     private ScopeNode scopeNode;
 ​
     @Data
     @Node("scope")
     @AllArgsConstructor(staticName = "of")
     public static class ScopeNode {
         @Property("name")
         @Id
         private String name;
     }
 }

存储层实现

@Repository 是Spring的一个标记注解,表示该接口是一个数据访问层的接口,负责与数据库进行交互。在Spring容器中,@Repository 注解的类会被自动扫描和注册为一个Bean

Neo4jRepository是Spring Data Neo4j提供的一个接口,类似于Spring Data JPA中的JpaRepository。通过继承 Neo4jRepository,你可以直接使用其提供的一些标准方法(如保存、删除、查询等),同时也能通过自定义查询来执行复杂的Cypher查询。

@Query 注解允许你使用Cypher查询语言直接编写自定义查询,而不需要依赖Spring Data Neo4j的自动查询解析机制。

@Param 注解用于将方法参数传递给查询中的占位符(如 $name)。通过 MATCH, WHERE, RETURN 等Cypher语法,可以构建强大的查询来从Neo4j数据库中检索数据。

 @Repository
 public interface KnowledgePointRepository extends Neo4jRepository<KnowledgePointNode, Long> {
     @Query("MATCH (kp:KnowledgePoint)-[:grade]->(g:grade) WHERE g.name = $gradeName RETURN kp")
     List<KnowledgePointNode> findByGrade(@Param("gradeName") String gradeName);
 ​
     @Query("MATCH (n:KnowledgePointNode) WHERE n.name CONTAINS $name RETURN n")
     List<KnowledgePointNode> findByNameContainingCypher(String name);
 ​
     @Query("MATCH (n) <-[r]->(m) where n.name contains($name) return 10")
     List<KnowledgePointNode> findByName(String name);
 }

查询关联节点

上面的存储层我用起来不是特别方便,所以去网站找了个工具了,可以查询一个节点的相关节点以及边。

查询请求实体:

 @Data
 public class GraphQuery {
     private String nodeName;
     private int pageSize = 10;
 ​
     public String toCypher() {
         String cypher = "";
         if (StringUtils.isNotEmpty(nodeName)) {
             cypher = "MATCH (n) <-[r]->(m) where n.name contains('" + nodeName + "') return * limit " + pageSize;
         } else {
             cypher = "MATCH (n) <-[r]->(m) return * limit " + pageSize;
         }
         return cypher;
     }
 }

工具类:

功能就是去执行cypherSql,得到对应的节点和关系

 public HashMap<String, Object> getGraphNodeAndShip(String cypherSql) {
     HashMap<String, Object> mo = new HashMap<>();
     try (Session session = neo4jDriver.session()) {
         Result result = session.run(cypherSql);  // 使用 session.run 执行查询
         List<HashMap<String, Object>> ents = new ArrayList<>();
         List<HashMap<String, Object>> ships = new ArrayList<>();
         List<String> uuids = new ArrayList<>();
         List<String> shipids = new ArrayList<>();
 ​
         // 逐条遍历结果
         while (result.hasNext()) {
             Record recordItem = result.next();  // 获取每一条记录
             List<Pair<String, Value>> fields = recordItem.fields();
             for (Pair<String, Value> pair : fields) {
                 HashMap<String, Object> rships = new HashMap<>();
                 HashMap<String, Object> rss = new HashMap<>();
                 String typeName = pair.value().type().name();
 ​
                 if ("NULL".equals(typeName)) {
                 } else if ("NODE".equals(typeName)) {
                     Node neo4jNode = pair.value().asNode();
                     Map<String, Object> map = neo4jNode.asMap();
                     String uuid = String.valueOf(neo4jNode.id());
                     if (!uuids.contains(uuid)) {
                         for (Entry<String, Object> entry : map.entrySet()) {
                             String key = entry.getKey();
                             rss.put(key, entry.getValue());
                         }
                         rss.put("uuid", uuid);
                         uuids.add(uuid);
                     }
                     if (!rss.isEmpty()) {
                         ents.add(rss);
                     }
                 } else if ("RELATIONSHIP".equals(typeName)) {
                     Relationship rship = pair.value().asRelationship();
                     String uuid = String.valueOf(rship.id());
                     if (!shipids.contains(uuid)) {
                         String sourceid = String.valueOf(rship.startNodeId());
                         String targetid = String.valueOf(rship.endNodeId());
                         Map<String, Object> map = rship.asMap();
                         for (Entry<String, Object> entry : map.entrySet()) {
                             String key = entry.getKey();
                             rships.put(key, entry.getValue());
                         }
                         rships.put("uuid", uuid);
                         rships.put("sourceid", sourceid);
                         rships.put("targetid", targetid);
                         shipids.add(uuid);
                         ships.add(rships);
                     }
                 } else if ("PATH".equals(typeName)) {
                     Path path = pair.value().asPath();
                     Map<String, Object> startNodemap = path.start().asMap();
                     String startNodeuuid = String.valueOf(path.start().id());
                     if (!uuids.contains(startNodeuuid)) {
                         rss = new HashMap<String, Object>();
                         for (Entry<String, Object> entry : startNodemap.entrySet()) {
                             String key = entry.getKey();
                             rss.put(key, entry.getValue());
                         }
                         rss.put("uuid", startNodeuuid);
                         uuids.add(startNodeuuid);
                         ents.add(rss);
                     }
 ​
                     Map<String, Object> endNodemap = path.end().asMap();
                     String endNodeuuid = String.valueOf(path.end().id());
                     if (!uuids.contains(endNodeuuid)) {
                         rss = new HashMap<String, Object>();
                         for (Entry<String, Object> entry : endNodemap.entrySet()) {
                             String key = entry.getKey();
                             rss.put(key, entry.getValue());
                         }
                         rss.put("uuid", endNodeuuid);
                         uuids.add(endNodeuuid);
                         ents.add(rss);
                     }
 ​
                     for (Node next : path.nodes()) {
                         String uuid = String.valueOf(next.id());
                         if (!uuids.contains(uuid)) {
                             rss = new HashMap<String, Object>();
                             Map<String, Object> map = next.asMap();
                             for (Entry<String, Object> entry : map.entrySet()) {
                                 String key = entry.getKey();
                                 rss.put(key, entry.getValue());
                             }
                             rss.put("uuid", uuid);
                             uuids.add(uuid);
                             ents.add(rss);
                         }
                     }
 ​
                     for (Relationship next : path.relationships()) {
                         String uuid = String.valueOf(next.id());
                         if (!shipids.contains(uuid)) {
                             rships = new HashMap<String, Object>();
                             String sourceid = String.valueOf(next.startNodeId());
                             String targetid = String.valueOf(next.endNodeId());
                             Map<String, Object> map = next.asMap();
                             for (Entry<String, Object> entry : map.entrySet()) {
                                 String key = entry.getKey();
                                 rships.put(key, entry.getValue());
                             }
                             rships.put("uuid", uuid);
                             rships.put("sourceid", sourceid);
                             rships.put("targetid", targetid);
                             shipids.add(uuid);
                             ships.add(rships);
                         }
                     }
                 } else if (typeName.contains("LIST")) {
                     Iterable<Value> values = pair.value().values();
                     Value next = values.iterator().next();
                     String type = next.type().name();
                     if ("RELATIONSHIP".equals(type)) {
                         Relationship rship = next.asRelationship();
                         String uuid = String.valueOf(rship.id());
                         if (!shipids.contains(uuid)) {
                             String sourceid = String.valueOf(rship.startNodeId());
                             String targetid = String.valueOf(rship.endNodeId());
                             Map<String, Object> map = rship.asMap();
                             for (Entry<String, Object> entry : map.entrySet()) {
                                 String key = entry.getKey();
                                 rships.put(key, entry.getValue());
                             }
                             rships.put("uuid", uuid);
                             rships.put("sourceid", sourceid);
                             rships.put("targetid", targetid);
                             shipids.add(uuid);
                             ships.add(rships);
                         }
                     }
                 } else if (typeName.contains("MAP")) {
                     rss.put(pair.key(), pair.value().asMap());
                 } else {
                     rss.put(pair.key(), pair.value().toString());
                     ents.add(rss);
                 }
             }
         }
         mo.put("node", ents);
         mo.put("relationship", ships);
     } catch (Exception e) {
         throw new RuntimeException("执行Cypher查询异常", e);
     }
     return mo;
 }

前端实现

前端使用React+vis-network来实现知识图谱的效果,

官网:visjs.github.io/vis-network…

实现的效果:

image-20241128134317224

首先是展示整个图谱的容器div

 <div ref={networkContainerRef} id="mynetwork" style={{
   width: "80%",
   height: "80vh",
   border: "1px solid #444444",
   backgroundColor: '#222222'
 }}/>

我们使用networkContainerRef可以方便的去操作这个Dom

初始化的时候去加载所有的数据

const getDataFromServer = async (nodeName = "", pageSize = 10) => {
  const res = await getDomainGraph({
    nodeName: nodeName,
    pageSize: pageSize
  })
  if (res.code === 0) {
    const node = res.data.node;
    const relationship = res.data.relationship;
    console.log("node", node)
    console.log("relationship", relationship)
    return {
      vertices: node,
      edges: relationship
    }
  }
}

拿到对应的节点和边的关系

接着去执行绘图的操作:

 const drawGraph = (vertices, edges) => {
    addNode(vertices);
    addEdge(edges);
    const data = {
      nodes: graphNodes.current,
      edges: graphEdges.current
    }
    const network = new Network(networkContainerRef.current as any, data, options);
    network.on('doubleClick', (params) => doubleClick(params))
    network.on('click', (params) => leftClick(params))
    network.on('oncontext', (params) => rightClick(params))
  }

我想要实现双击节点,可以查询到新的关于这个节点的信息,相邻节点以及边的关系,那么每次添加节点的时候,都需要去做一个去重的操作

const addNode = (vertices: any[]) => {
  console.log("vertices", vertices)
  vertices?.forEach((vertex) => {
    if (!graphNodes.current.get(vertex.uuid)) {
      graphNodes.current.add({
        id: vertex.uuid,
        label: vertex.name,
        group: vertex.name,
        vertex: vertex,
      } as any);
    }
  });
};

const addEdge = (edges: any[]) => {
  edges?.forEach((edge) => {
    if (!graphEdges.current.get(edge.uuid)) {
      graphEdges.current.add({
        id: edge.uuid,
        from: edge.sourceid,
        to: edge.targetid,
        edge: edge,
      } as any);
    }
  });
};

其他的一些变量:

const options = {
  autoResize: true,
  width: '100%',
  interaction: {
    hover: true,
    navigationButtons: true,
    zoomView: true
  },
  nodes: {
    font: {
      color: "#fff",
      face: "arial",
      size: 12,
    },
    shape: 'dot',
    color: {
      background: '#00ccff',
      border: '#00ccff',
      highlight: {background: '#fb6a02', border: '#fb6a02'},
      hover: {background: '#ec3112', border: '#ec3112'}
    },
    borderWidth: 3,
    shadow: {
      enabled: true,
      color: 'rgba(1, 187, 223, 1)',
      size: 12,
    },
  },
  edges: {
    arrows: 'to',
    color: {color: "#39ADF1"},
    smooth: {
      type: 'dynamic'
    },
  },
  physics: {
    maxVelocity: 50,
    solver: 'barnesHut',
    timestep: 0.3,
    stabilization: {iterations: 150}
  }
}


  const graphNodes = useRef(new DataSet());
  const graphEdges = useRef(new DataSet());