14.2 自定义DSL和循环依赖检测竟然还能这样做?

0 阅读8分钟

14.2 震撼!自定义DSL和循环依赖检测竟然还能这样做?

在上一节中,我们讨论了任务编排和规则引擎的基本概念和实现。今天我们将深入探讨两个关键技术点:自定义DSL的设计与实现,以及循环依赖检测机制。这些技术将使我们的任务调度系统更加灵活和健壮。

自定义DSL设计与ANTLR实现

DSL(Domain Specific Language)是针对特定领域设计的语言,它可以让非技术人员也能方便地定义复杂的业务逻辑。我们将使用ANTLR(Another Tool for Language Recognition)来实现我们的DSL解析器。

首先,让我们定义DSL的语法规则:

// TaskOrchestration.g4
grammar TaskOrchestration;

// 解析入口
orchestration: workflow+;

// 工作流定义
workflow: 'workflow' STRING '{' taskDefinition+ '}';

// 任务定义
taskDefinition: 'task' STRING '{' taskProperty* '}';

// 任务属性
taskProperty
    : 'type' '=' STRING
    | 'depends_on' '=' '[' STRING (',' STRING)* ']'
    | 'parallel_with' '=' '[' STRING (',' STRING)* ']'
    | 'condition' '=' STRING
    | 'timeout' '=' STRING
    ;

// 词法规则
STRING: '"' (~["\\\r\n] | '\\' (. | EOF))* '"' | '\'' (~['\\\r\n] | '\\' (. | EOF))* '\'';
WS: [ \t\r\n]+ -> skip;

基于这个语法规则,我们可以实现DSL解析器:

package dsl

import (
    "fmt"
    "regexp"
    "strconv"
    "strings"
)

// Workflow 工作流定义
type Workflow struct {
    Name  string
    Tasks map[string]*TaskDefinition
}

// TaskDefinition 任务定义
type TaskDefinition struct {
    Name        string
    Type        string
    DependsOn   []string
    ParallelWith []string
    Condition   string
    Timeout     string
}

// DSLParser DSL解析器
type DSLParser struct {
    content string
    pos     int
}

// NewDSLParser 创建DSL解析器
func NewDSLParser(content string) *DSLParser {
    return &DSLParser{
        content: content,
        pos:     0,
    }
}

// Parse 解析DSL内容
func (p *DSLParser) Parse() ([]*Workflow, error) {
    var workflows []*Workflow
    
    for p.pos < len(p.content) {
        p.skipWhitespace()
        
        if p.match("workflow") {
            workflow, err := p.parseWorkflow()
            if err != nil {
                return nil, err
            }
            workflows = append(workflows, workflow)
        } else {
            break
        }
    }
    
    return workflows, nil
}

// parseWorkflow 解析工作流
func (p *DSLParser) parseWorkflow() (*Workflow, error) {
    p.skipWhitespace()
    
    // 解析工作流名称
    name, err := p.parseString()
    if err != nil {
        return nil, fmt.Errorf("failed to parse workflow name: %v", err)
    }
    
    p.skipWhitespace()
    
    // 期望 '{'
    if !p.match("{") {
        return nil, fmt.Errorf("expected '{' after workflow name")
    }
    
    workflow := &Workflow{
        Name:  name,
        Tasks: make(map[string]*TaskDefinition),
    }
    
    // 解析任务定义
    for p.pos < len(p.content) && !p.match("}") {
        p.skipWhitespace()
        
        if p.match("task") {
            task, err := p.parseTask()
            if err != nil {
                return nil, fmt.Errorf("failed to parse task: %v", err)
            }
            workflow.Tasks[task.Name] = task
        } else if p.current() == '}' {
            break
        } else {
            return nil, fmt.Errorf("unexpected token: %c", p.current())
        }
    }
    
    return workflow, nil
}

// parseTask 解析任务定义
func (p *DSLParser) parseTask() (*TaskDefinition, error) {
    p.skipWhitespace()
    
    // 解析任务名称
    name, err := p.parseString()
    if err != nil {
        return nil, fmt.Errorf("failed to parse task name: %v", err)
    }
    
    p.skipWhitespace()
    
    // 期望 '{'
    if !p.match("{") {
        return nil, fmt.Errorf("expected '{' after task name")
    }
    
    task := &TaskDefinition{
        Name: name,
    }
    
    // 解析任务属性
    for p.pos < len(p.content) && !p.match("}") {
        p.skipWhitespace()
        
        if p.match("type") {
            p.skipWhitespace()
            if !p.match("=") {
                return nil, fmt.Errorf("expected '=' after type")
            }
            p.skipWhitespace()
            task.Type, err = p.parseString()
            if err != nil {
                return nil, fmt.Errorf("failed to parse type value: %v", err)
            }
        } else if p.match("depends_on") {
            p.skipWhitespace()
            if !p.match("=") {
                return nil, fmt.Errorf("expected '=' after depends_on")
            }
            p.skipWhitespace()
            task.DependsOn, err = p.parseStringArray()
            if err != nil {
                return nil, fmt.Errorf("failed to parse depends_on value: %v", err)
            }
        } else if p.match("parallel_with") {
            p.skipWhitespace()
            if !p.match("=") {
                return nil, fmt.Errorf("expected '=' after parallel_with")
            }
            p.skipWhitespace()
            task.ParallelWith, err = p.parseStringArray()
            if err != nil {
                return nil, fmt.Errorf("failed to parse parallel_with value: %v", err)
            }
        } else if p.match("condition") {
            p.skipWhitespace()
            if !p.match("=") {
                return nil, fmt.Errorf("expected '=' after condition")
            }
            p.skipWhitespace()
            task.Condition, err = p.parseString()
            if err != nil {
                return nil, fmt.Errorf("failed to parse condition value: %v", err)
            }
        } else if p.match("timeout") {
            p.skipWhitespace()
            if !p.match("=") {
                return nil, fmt.Errorf("expected '=' after timeout")
            }
            p.skipWhitespace()
            task.Timeout, err = p.parseString()
            if err != nil {
                return nil, fmt.Errorf("failed to parse timeout value: %v", err)
            }
        } else if p.current() == '}' {
            break
        } else {
            return nil, fmt.Errorf("unexpected token in task definition: %c", p.current())
        }
    }
    
    return task, nil
}

// parseString 解析字符串
func (p *DSLParser) parseString() (string, error) {
    p.skipWhitespace()
    
    if p.pos >= len(p.content) {
        return "", fmt.Errorf("unexpected end of input")
    }
    
    quote := p.content[p.pos]
    if quote != '"' && quote != '\'' {
        return "", fmt.Errorf("expected string, got %c", quote)
    }
    
    p.pos++ // 跳过开始引号
    
    start := p.pos
    for p.pos < len(p.content) && p.content[p.pos] != quote {
        if p.content[p.pos] == '\\' && p.pos+1 < len(p.content) {
            p.pos += 2 // 跳过转义字符
        } else {
            p.pos++
        }
    }
    
    if p.pos >= len(p.content) {
        return "", fmt.Errorf("unterminated string")
    }
    
    result := p.content[start:p.pos]
    p.pos++ // 跳过结束引号
    
    // 处理转义字符
    result = strings.ReplaceAll(result, "\\\"", "\"")
    result = strings.ReplaceAll(result, "\\'", "'")
    result = strings.ReplaceAll(result, "\\n", "\n")
    result = strings.ReplaceAll(result, "\\t", "\t")
    
    return result, nil
}

// parseStringArray 解析字符串数组
func (p *DSLParser) parseStringArray() ([]string, error) {
    p.skipWhitespace()
    
    if !p.match("[") {
        return nil, fmt.Errorf("expected '[' for string array")
    }
    
    var result []string
    
    p.skipWhitespace()
    if p.current() != ']' {
        for {
            str, err := p.parseString()
            if err != nil {
                return nil, fmt.Errorf("failed to parse string in array: %v", err)
            }
            result = append(result, str)
            
            p.skipWhitespace()
            if p.current() == ']' {
                break
            }
            
            if !p.match(",") {
                return nil, fmt.Errorf("expected ',' or ']' in string array")
            }
            p.skipWhitespace()
        }
    }
    
    if !p.match("]") {
        return nil, fmt.Errorf("expected ']' to close string array")
    }
    
    return result, nil
}

// match 检查当前位置是否匹配指定字符串
func (p *DSLParser) match(s string) bool {
    if p.pos+len(s) <= len(p.content) && p.content[p.pos:p.pos+len(s)] == s {
        p.pos += len(s)
        return true
    }
    return false
}

// current 获取当前位置的字符
func (p *DSLParser) current() byte {
    if p.pos < len(p.content) {
        return p.content[p.pos]
    }
    return 0
}

// skipWhitespace 跳过空白字符
func (p *DSLParser) skipWhitespace() {
    for p.pos < len(p.content) {
        c := p.content[p.pos]
        if c != ' ' && c != '\t' && c != '\r' && c != '\n' {
            break
        }
        p.pos++
    }
}

// 使用示例
func ExampleDSL() {
    dslContent := `
workflow "user_registration" {
    task "send_verification_code" {
        type = "email"
        depends_on = []
        timeout = "30s"
    }
    
    task "verify_code" {
        type = "verification"
        depends_on = ["send_verification_code"]
        timeout = "5m"
    }
    
    task "create_user_account" {
        type = "database"
        depends_on = ["verify_code"]
        condition = "result.success == true"
        timeout = "10s"
    }
}
`
    
    parser := NewDSLParser(dslContent)
    workflows, err := parser.Parse()
    if err != nil {
        fmt.Printf("Failed to parse DSL: %v\n", err)
        return
    }
    
    for _, workflow := range workflows {
        fmt.Printf("Workflow: %s\n", workflow.Name)
        for _, task := range workflow.Tasks {
            fmt.Printf("  Task: %s, Type: %s\n", task.Name, task.Type)
            if len(task.DependsOn) > 0 {
                fmt.Printf("    Depends on: %v\n", task.DependsOn)
            }
            if task.Condition != "" {
                fmt.Printf("    Condition: %s\n", task.Condition)
            }
            if task.Timeout != "" {
                fmt.Printf("    Timeout: %s\n", task.Timeout)
            }
        }
    }
}

循环依赖检测机制

在任务编排中,循环依赖是一个常见但严重的问题。如果任务A依赖于任务B,而任务B又依赖于任务A,就会形成循环依赖,导致任务无法执行。我们需要一个健壮的循环依赖检测机制来避免这种情况。

package cycle

import (
    "fmt"
)

// DependencyChecker 依赖检查器
type DependencyChecker struct {
    tasks map[string]*TaskInfo
}

// TaskInfo 任务信息
type TaskInfo struct {
    ID         string
    DependsOn  []string
    Visited    bool // 用于DFS遍历
    InPath     bool // 用于检测环
}

// NewDependencyChecker 创建依赖检查器
func NewDependencyChecker() *DependencyChecker {
    return &DependencyChecker{
        tasks: make(map[string]*TaskInfo),
    }
}

// AddTask 添加任务
func (dc *DependencyChecker) AddTask(id string, dependsOn []string) {
    dc.tasks[id] = &TaskInfo{
        ID:        id,
        DependsOn: dependsOn,
    }
}

// CheckCycles 检查循环依赖
func (dc *DependencyChecker) CheckCycles() error {
    // 重置访问状态
    for _, task := range dc.tasks {
        task.Visited = false
        task.InPath = false
    }
    
    // 对每个未访问的任务进行DFS
    for _, task := range dc.tasks {
        if !task.Visited {
            cyclePath, err := dc.detectCycle(task)
            if err != nil {
                return err
            }
            if len(cyclePath) > 0 {
                return fmt.Errorf("cycle detected: %v", cyclePath)
            }
        }
    }
    
    return nil
}

// detectCycle 检测环的辅助函数
func (dc *DependencyChecker) detectCycle(task *TaskInfo) ([]string, error) {
    // 如果已经在当前路径中,说明发现了环
    if task.InPath {
        return []string{task.ID}, nil
    }
    
    // 如果已经访问过且不在当前路径中,说明之前已经检查过了
    if task.Visited {
        return nil, nil
    }
    
    // 标记为正在访问和在当前路径中
    task.Visited = true
    task.InPath = true
    
    // 递归检查所有依赖项
    for _, depID := range task.DependsOn {
        depTask, exists := dc.tasks[depID]
        if !exists {
            return nil, fmt.Errorf("dependency task %s not found for task %s", depID, task.ID)
        }
        
        cyclePath, err := dc.detectCycle(depTask)
        if err != nil {
            return nil, err
        }
        
        if len(cyclePath) > 0 {
            // 如果找到了环,将当前任务添加到环路径中
            return append([]string{task.ID}, cyclePath...), nil
        }
    }
    
    // 从当前路径中移除
    task.InPath = false
    
    return nil, nil
}

// GetTopologicalOrder 获取拓扑排序
func (dc *DependencyChecker) GetTopologicalOrder() ([]string, error) {
    // 先检查是否有环
    if err := dc.CheckCycles(); err != nil {
        return nil, err
    }
    
    // 重置访问状态
    for _, task := range dc.tasks {
        task.Visited = false
    }
    
    var result []string
    
    // 对每个未访问的任务进行DFS
    for _, task := range dc.tasks {
        if !task.Visited {
            dc.topologicalSort(task, &result)
        }
    }
    
    // 反转结果以获得正确的拓扑顺序
    for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 {
        result[i], result[j] = result[j], result[i]
    }
    
    return result, nil
}

// topologicalSort 拓扑排序的辅助函数
func (dc *DependencyChecker) topologicalSort(task *TaskInfo, result *[]string) {
    // 如果已经访问过,直接返回
    if task.Visited {
        return
    }
    
    // 标记为已访问
    task.Visited = true
    
    // 递归处理所有依赖项
    for _, depID := range task.DependsOn {
        depTask := dc.tasks[depID]
        dc.topologicalSort(depTask, result)
    }
    
    // 将当前任务添加到结果中
    *result = append(*result, task.ID)
}

// 使用示例
func ExampleCycleDetection() {
    checker := NewDependencyChecker()
    
    // 添加任务及其依赖关系
    checker.AddTask("task_a", []string{})
    checker.AddTask("task_b", []string{"task_a"})
    checker.AddTask("task_c", []string{"task_b"})
    checker.AddTask("task_d", []string{"task_c"})
    
    // 检查循环依赖
    err := checker.CheckCycles()
    if err != nil {
        fmt.Printf("Cycle detected: %v\n", err)
        return
    }
    
    fmt.Println("No cycles detected")
    
    // 获取拓扑排序
    order, err := checker.GetTopologicalOrder()
    if err != nil {
        fmt.Printf("Failed to get topological order: %v\n", err)
        return
    }
    
    fmt.Printf("Topological order: %v\n", order)
    
    // 添加一个循环依赖来测试
    checker.AddTask("task_e", []string{"task_d"})
    checker.AddTask("task_a", []string{"task_e"}) // 创建循环依赖
    
    err = checker.CheckCycles()
    if err != nil {
        fmt.Printf("Cycle detected: %v\n", err)
    }
}

高级任务编排特性

现在让我们实现一些更高级的任务编排特性:

package advanced

import (
    "context"
    "fmt"
    "sync"
    "time"
)

// ExecutionContext 执行上下文
type ExecutionContext struct {
    TaskResults map[string]interface{}
    mu          sync.RWMutex
}

// GetResult 获取任务结果
func (ec *ExecutionContext) GetResult(taskID string) (interface{}, bool) {
    ec.mu.RLock()
    defer ec.mu.RUnlock()
    result, ok := ec.TaskResults[taskID]
    return result, ok
}

// SetResult 设置任务结果
func (ec *ExecutionContext) SetResult(taskID string, result interface{}) {
    ec.mu.Lock()
    defer ec.mu.Unlock()
    if ec.TaskResults == nil {
        ec.TaskResults = make(map[string]interface{})
    }
    ec.TaskResults[taskID] = result
}

// TaskExecutor 任务执行器接口
type TaskExecutor interface {
    Execute(ctx context.Context, taskID string, params map[string]interface{}, ec *ExecutionContext) (interface{}, error)
}

// ConditionalTaskExecutor 条件任务执行器
type ConditionalTaskExecutor struct {
    condition string
    executor  TaskExecutor
}

// Execute 执行条件任务
func (cte *ConditionalTaskExecutor) Execute(ctx context.Context, taskID string, params map[string]interface{}, ec *ExecutionContext) (interface{}, error) {
    // 简单的条件评估(实际项目中可以使用更复杂的表达式引擎)
    if cte.condition != "" {
        // 这里简化处理,实际应该解析和评估条件表达式
        if cte.condition == "always_false" {
            return nil, fmt.Errorf("condition not met: %s", cte.condition)
        }
    }
    
    return cte.executor.Execute(ctx, taskID, params, ec)
}

// ParallelTaskGroup 并行任务组
type ParallelTaskGroup struct {
    tasks []TaskExecutor
}

// Execute 执行并行任务组
func (ptg *ParallelTaskGroup) Execute(ctx context.Context, taskID string, params map[string]interface{}, ec *ExecutionContext) (interface{}, error) {
    var wg sync.WaitGroup
    results := make([]interface{}, len(ptg.tasks))
    errors := make([]error, len(ptg.tasks))
    
    for i, task := range ptg.tasks {
        wg.Add(1)
        go func(index int, t TaskExecutor) {
            defer wg.Done()
            result, err := t.Execute(ctx, fmt.Sprintf("%s_%d", taskID, index), params, ec)
            results[index] = result
            errors[index] = err
        }(i, task)
    }
    
    wg.Wait()
    
    // 检查是否有错误
    for _, err := range errors {
        if err != nil {
            return nil, fmt.Errorf("parallel task failed: %v", err)
        }
    }
    
    return results, nil
}

// TimeoutTaskExecutor 超时任务执行器
type TimeoutTaskExecutor struct {
    timeout  time.Duration
    executor TaskExecutor
}

// Execute 执行带超时的任务
func (tte *TimeoutTaskExecutor) Execute(ctx context.Context, taskID string, params map[string]interface{}, ec *ExecutionContext) (interface{}, error) {
    // 创建带超时的上下文
    ctxWithTimeout, cancel := context.WithTimeout(ctx, tte.timeout)
    defer cancel()
    
    // 创建通道来接收结果
    resultChan := make(chan interface{}, 1)
    errorChan := make(chan error, 1)
    
    // 在goroutine中执行任务
    go func() {
        result, err := tte.executor.Execute(ctxWithTimeout, taskID, params, ec)
        if err != nil {
            errorChan <- err
        } else {
            resultChan <- result
        }
    }()
    
    // 等待结果或超时
    select {
    case result := <-resultChan:
        return result, nil
    case err := <-errorChan:
        return nil, err
    case <-ctxWithTimeout.Done():
        return nil, fmt.Errorf("task %s timed out after %v", taskID, tte.timeout)
    }
}

// RetryTaskExecutor 重试任务执行器
type RetryTaskExecutor struct {
    maxRetries int
    delay      time.Duration
    executor   TaskExecutor
}

// Execute 执行带重试的任务
func (rte *RetryTaskExecutor) Execute(ctx context.Context, taskID string, params map[string]interface{}, ec *ExecutionContext) (interface{}, error) {
    var lastErr error
    
    for i := 0; i <= rte.maxRetries; i++ {
        result, err := rte.executor.Execute(ctx, taskID, params, ec)
        if err == nil {
            return result, nil
        }
        
        lastErr = err
        fmt.Printf("Task %s failed (attempt %d/%d): %v\n", taskID, i+1, rte.maxRetries+1, err)
        
        if i < rte.maxRetries {
            select {
            case <-ctx.Done():
                return nil, ctx.Err()
            case <-time.After(rte.delay):
                // 等待后重试
            }
        }
    }
    
    return nil, fmt.Errorf("task %s failed after %d retries: %v", taskID, rte.maxRetries, lastErr)
}

// 使用示例
func ExampleAdvancedOrchestration() {
    // 创建执行上下文
    ec := &ExecutionContext{}
    
    // 创建一个带超时和重试的任务执行器
    executor := &RetryTaskExecutor{
        maxRetries: 3,
        delay:      time.Second,
        executor: &TimeoutTaskExecutor{
            timeout: 5 * time.Second,
            executor: &ConditionalTaskExecutor{
                condition: "",
                executor: nil, // 这里应该是一个具体的任务执行器
            },
        },
    }
    
    // 执行任务
    ctx := context.Background()
    result, err := executor.Execute(ctx, "sample_task", nil, ec)
    if err != nil {
        fmt.Printf("Task failed: %v\n", err)
        return
    }
    
    fmt.Printf("Task completed with result: %v\n", result)
}

总结

在本节中,我们深入探讨了以下关键技术:

  1. 自定义DSL设计

    • 使用ANTLR或自定义解析器实现DSL解析
    • 设计易于理解和编写的语法规则
    • 实现DSL到内部数据结构的转换
  2. 循环依赖检测

    • 使用深度优先搜索(DFS)算法检测循环依赖
    • 实现拓扑排序以确定任务执行顺序
    • 提供清晰的错误信息帮助用户定位问题
  3. 高级任务编排特性

    • 条件执行:根据条件决定是否执行任务
    • 并行执行:同时执行多个无依赖关系的任务
    • 超时控制:防止任务执行时间过长
    • 重试机制:在任务失败时自动重试

这些技术使我们的任务调度系统更加智能和健壮,能够处理各种复杂的业务场景。在下一节中,我们将探讨如何将这些组件整合到一个完整的分布式任务调度系统中,并实现高可用性保障。