七天用Go从零实现系列:Day3 前缀树路由写注释

93 阅读3分钟

七天用Go从零实现系列

Day3 前缀树路由

字典树(前缀树)

    type node struct {
    	pattern  string  // 待匹配路由 例如 /p/:lang
    	part     string  // 路由中的一部分,例如 :lang
    	children []*node // 子节点,例如 [doc, tutorial, intro]
    	isWild   bool    // 是否精确匹配,part 含有 : 或 * 时为true
    }

    // 格式化输出
    func (n *node) String() string {
    	return fmt.Sprintf("node{pattern=%s, part=%s, isWild=%t}", n.pattern, n.part, n.isWild)
    }

    // 插入节点
    func (n *node) insert(pattern string, parts []string, height int) {

    	// 如果已经匹配完了,那么将pattern赋值给该node,表示它是一个完整的url
    	// 如果是最后赋值,如果刚好最后一个匹配则是覆盖
    	// 这是递归的终止条件
    	if len(parts) == height {
    		n.pattern = pattern
    		return
    	}

    	//解析路径
    	part := parts[height]
    	//获取匹配的子节点
    	child := n.matchChild(part)
    	//空则赋值添加子节点
    	if child == nil {
    		child = &node{part: part, isWild: part[0] == ':' || part[0] == '*'}
    		n.children = append(n.children, child)
    	}
    	//递归向下面节点匹配
    	child.insert(pattern, parts, height+1)
    }

    func (n *node) search(parts []string, height int) *node {
    	// 递归终止条件,找到末尾了或者通配符
    	if len(parts) == height || strings.HasPrefix(n.part, "*") {
    		// pattern为空字符串表示它不是一个完整的url,匹配失败
    		if n.pattern == "" {
    			return nil
    		}
    		return n
    	}

    	//取路径
    	part := parts[height]
    	//获取当前节点子节点所有匹配的节点
    	children := n.matchChildren(part)

    	for _, child := range children {
    		//递归查找
    		result := child.search(parts, height+1)
    		if result != nil {
    			return result
    		}
    	}

    	return nil
    }

    // 返回这个节点下的所有节点
    func (n *node) travel(list *([]*node)) {
    	if n.pattern != "" {
    		*list = append(*list, n)
    	}
    	for _, child := range n.children {
    		child.travel(list)
    	}
    }

    // 当前第一个字节点匹配成功的节点,用于插入
    func (n *node) matchChild(part string) *node {
    	for _, child := range n.children {
    		if child.part == part || child.isWild {
    			return child
    		}
    	}
    	return nil
    }

    // 当前匹配成功的节点,用于查找
    func (n *node) matchChildren(part string) []*node {
    	nodes := make([]*node, 0)
    	for _, child := range n.children {
    		if child.part == part || child.isWild {
    			nodes = append(nodes, child)
    		}
    	}
    	return nodes
    }

路由封装

type router struct {
	roots    map[string]*node
	handlers map[string]http.HandlerFunc
}

func newRouter() *router {
	return &router{
		roots:    make(map[string]*node),
		handlers: make(map[string]http.HandlerFunc),
	}
}

// 把路径分割
func parsePattern(pattern string) []string {
	//路径分割
	vs := strings.Split(pattern, "/")
	parts := make([]string, 0)
	for _, item := range vs {
		if item != "" {
			parts = append(parts, item)
			//遇到*号开头的 下面不处理了
			if item[0] == '*' {
				break
			}
		}
	}
	return parts
}

func (r *router) addRoute(method string, pattern string, handler http.HandlerFunc) {
	//路径分割处理
	parts := parsePattern(pattern)

	//把请求类型+路径作为key
	key := method + "-" + pattern
	_, ok := r.roots[method]
	if !ok { //不存在初始化
		r.roots[method] = &node{}
	}
	//请求类型map 添加处理节点
	r.roots[method].insert(pattern, parts, 0)
	//具体的请求类型+路径 处理方法
	r.handlers[key] = handler
}


 //根据请求类型+路径 查找匹配节点和占位符参数
func (r *router) getRoute(method string, path string) (*node, map[string]string) {
	//路径分割
	searchParts := parsePattern(path)
	//占位符参数
	params := make(map[string]string)
	//请求类型
	root, ok := r.roots[method]

	if !ok {
		return nil, nil
	}
	//根据路径查找节点
	n := root.search(searchParts, 0)

	if n != nil {
		//路径分割
		parts := parsePattern(n.pattern)
		for index, part := range parts {
			//查找:和*号占位符
			if part[0] == ':' {
				//请求路径的占位符值
				params[part[1:]] = searchParts[index]
			}
			if part[0] == '*' && len(part) > 1 {
				//请求路径的占位符值
				params[part[1:]] = strings.Join(searchParts[index:], "/")
				break
			}
		}
		return n, params
	}

	return nil, nil
}

// 获取所有请求类型的节点
func (r *router) getRoutes(method string) []*node {
	root, ok := r.roots[method]
	if !ok {
		return nil
	}
	nodes := make([]*node, 0)
	root.travel(&nodes)
	return nodes
}

// 请求路径的处理函数添加
func (r *router) handle(c *Context) {
	//获取请求类型+路径的存储
	n, params := r.getRoute(c.Method, c.Path)
	if n != nil {
		//占位符赋值
		c.Params = params
		//请求处理处理方法赋值
		key := c.Method + "-" + n.pattern
		r.handlers[key](c.Writer, c.Req)
	} else {
		c.String(http.StatusNotFound, "404 NOT FOUND: %s\n", c.Path)
	}
}