FSM有限状态机实现

143 阅读3分钟

1.类型定义

  • State[T any]:这是一个函数类型,表示状态机中的状态。它接受一个上下文ctx、一个泛型参数args,并返回更新后的参数、下一个状态以及可能发生的错误。
  • Args:这是一个结构体,包含了服务执行所需的所有参数,如服务名称、类型以及客户端接口。

2.方法

  • validate:用于验证Args结构体的字段是否有效。
  • Execute:是服务执行的主函数,它调用Run函数来启动状态机。
  • checkService、invokeService、reportService:这三个函数分别对应状态机中的三个状态,用于检查服务、调用服务和报告服务。

3.状态机执行

  • Run[T any]:这是一个通用的状态机执行函数,它接受初始状态和参数,然后依次执行每个状态,直到没有下一个状态。

4.接口和实现

  • CheckClient、InvokeClient、ReportClient:这三个接口定义了服务执行过程中需要调用的方法。
  • Server:实现了上述三个接口,提供了具体的方法实现。

5.设计优势

  1. 模块化:通过将服务执行过程分解为多个状态,代码更加模块化,易于理解和维护。
  2. 可扩展性:由于使用了泛型和接口,可以很容易地添加新的服务类型或状态。
  3. 错误处理:每个状态都可以返回错误,使得错误处理更加集中和明确。
  4. 灵活性强:状态机的结构使得流程控制更加灵活,可以根据需要添加或修改状态。

6.代码

package main

import (
	"context"
	"errors"
	"fmt"
)

type State[T any] func(ctx context.Context, args T) (T, State[T], error)

type Args struct {
	Name   string
	Type   string
	Report ReportClient
	Check  CheckClient
	Invoke InvokeClient
}

func (a Args) validate(ctx context.Context) error {
	if a.Name == "" {
		return fmt.Errorf("name cannot be an empty string")
	}

	if a.Check == nil {
		return fmt.Errorf("check cannot be nil")
	}
	if a.Invoke == nil {
		return fmt.Errorf("invoke cannot be nil")
	}
	if a.Report == nil {
		return fmt.Errorf("report cannot be nil")
	}
	return nil
}

func Execute(ctx context.Context, args Args) error {
	if err := args.validate(ctx); err != nil {
		return err
	}
	start := checkService
	_, err := Run[Args](ctx, args, start)
	if err != nil {
		return fmt.Errorf("execute service %q: %w", args.Name, err)
	}
	return nil
}

func checkService(ctx context.Context, args Args) (Args, State[Args], error) {
	err := args.Check.CheckLocalAuth(ctx, args.Name)
	if err != nil {
		return args, nil, fmt.Errorf("the service was not center auth")
	}

	err = args.Check.CheckCenterAuth(ctx, args.Name)
	if err != nil {
		return args, nil, fmt.Errorf("the service was not local auth")
	}

	err = args.Check.CheckLimit(ctx, args.Name)
	if err != nil {
		return args, nil, fmt.Errorf("the service was not local auth")
	}
	return args, invokeService, nil
}

func invokeService(ctx context.Context, args Args) (Args, State[Args], error) {
	var err error
	switch args.Type {
	case "PQ":
		err = args.Invoke.PQ(ctx, args.Name)
	case "PIR":
		err = args.Invoke.PIR(ctx, args.Name)
	case "PSI":
		err = args.Invoke.PSI(ctx, args.Name)
	}
	if err != nil {
		return args, nil, fmt.Errorf("the service invoke error: %+v", err)
	}
	return args, reportService, nil
}

func reportService(ctx context.Context, args Args) (Args, State[Args], error) {
	err := args.Report.ReportLocal(ctx, args.Name)
	if err != nil {
		return args, nil, fmt.Errorf("the service report local error")
	}

	err = args.Report.ReportCenter(ctx, args.Name)
	if err != nil {
		return args, nil, fmt.Errorf("the service was not local auth")
	}
	return args, nil, nil
}

func Run[T any](ctx context.Context, args T, start State[T]) (T, error) {
	var err error
	current := start
	for {
		if ctx.Err() != nil {
			return args, ctx.Err()
		}
		args, current, err = current(ctx, args)
		if err != nil {
			return args, err
		}
		if current == nil {
			return args, nil
		}
	}
}

func main() {
	server := NewServer()
	args := Args{
		Name:   "xuetu-链",
		Type:   "PQ",
		Report: server,
		Check:  server,
		Invoke: server,
	}
	err := Execute(context.Background(), args)
	if err != nil {
		fmt.Println(err.Error())
	}
}

type CheckClient interface {
	CheckLocalAuth(ctx context.Context, service string) error
	CheckCenterAuth(ctx context.Context, service string) error
	CheckLimit(ctx context.Context, service string) error
}

type InvokeClient interface {
	PQ(ctx context.Context, service string) error
	PIR(ctx context.Context, service string) error
	PSI(ctx context.Context, service string) error
}

type ReportClient interface {
	ReportLocal(ctx context.Context, service string) error
	ReportCenter(ctx context.Context, service string) error
}

type Server struct {
}

func NewServer() *Server {
	return &Server{}
}

func (s Server) CheckLocalAuth(ctx context.Context, service string) error {
	fmt.Println("CheckLocalAuth")
	return nil
}

func (s Server) CheckCenterAuth(ctx context.Context, service string) error {
	fmt.Println("CheckCenterAuth")
	return nil
}

func (s Server) CheckLimit(ctx context.Context, service string) error {
	fmt.Println("CheckLimit")
	return nil
}

func (s Server) PQ(ctx context.Context, service string) error {
	fmt.Println("PQ")
	return errors.New("出错了")
}

func (s Server) PIR(ctx context.Context, service string) error {
	fmt.Println("PIR")
	return nil
}

func (s Server) PSI(ctx context.Context, service string) error {
	fmt.Println("PSI")
	return nil
}

func (s Server) ReportLocal(ctx context.Context, service string) error {
	fmt.Println("ReportLocal")
	return nil
}

func (s Server) ReportCenter(ctx context.Context, service string) error {
	fmt.Println("ReportCenter")
	return nil
}