用flag.Func自定义命令行标志的详细教程

204 阅读3分钟

在最近的Go 1.16版本中,我最喜欢的事情之一是在flag 包中增加了一个小的--但非常受欢迎的--东西:the flag.Func()函数。这使得在你的应用程序中定义和使用自定义命令行标志变得更加容易。

例如,如果你想把一个像--pause=10s 的标志直接解析成一个time.Duration 类型,或者把--urls="http://example.com http://example.org" 直接解析成一个[]string 片断,那么以前你有两个选择。你可以创建一个自定义类型来实现flag.Value 接口,或者使用第三方软件包,如 pflag.

但是现在,flag.Func() 函数为你提供了一个简单和轻量级的选择。在这篇短文中,我们将看一下几个例子,看看你如何在自己的代码中使用它。

解析自定义标志类型

为了演示它是如何工作的,让我们从我上面给出的两个例子开始,创建一个示例应用程序,它接受一个URL列表,然后将它们打印出来,在它们之间有一个停顿。与此类似:

$ go run . --pause=3s --urls="http://example.com http://example.org http://example.net"
2021/03/08 08:16:04 http://example.com
2021/03/08 08:16:07 http://example.org
2021/03/08 08:16:10 http://example.net

为了使其发挥作用,我们需要做两件事:

  • --pause 标志值从 "人类可读 "的字符串(如200ms,5s10m )转换成本地Gotime.Duration 类型。我们可以使用 time.ParseDuration()函数来完成。
  • --urls 标志中的值分割成一个片断,这样我们就可以循环浏览它们。该 strings.Fields函数很适合这项任务。

我们可以像这样将这些与flag.Func() 一起使用:

package main

import (
    "flag"
    "log"
    "strings"
    "time"
)

func main() {
    // First we need to declare variables to hold the values from the
    // command-line flags. Notice that we also need to set any defaults,
    // which will be used if the relevant flag is not provided at runtime.
    var (
        urls  []string                    // Default of the empty slice
        pause time.Duration = time.Second // Default of one second
    )

    // The flag.Func() function takes three parameters: the flag name, 
    // descriptive help text, and a function with the signature 
    // `func(string) error` which is called to process the string value 
    // from the command-line flag at runtime and assign it to the necessary 
    // variable. In this case, we use strings.Fields() to split the string 
    // based on whitespace and store the resulting slice in the urls  
    // variable that we declared above. We then return nil from the 
    // function to indicate that the flag was parsed without any errors.
    flag.Func("urls", "List of URLs to print", func(flagValue string) error {
        urls = strings.Fields(flagValue)
        return nil
    })

    // Likewise we can do the same thing to parse the pause duration. The 
    // time.ParseDuration() function may throw an error here, so we make 
    // sure to return that from our function.
    flag.Func("pause", "Duration to pause between printing URLs", func(flagValue string) error {
        var err error
        pause, err = time.ParseDuration(flagValue)
        return err
    })

    // Importantly, call flag.Parse() to trigger actual parsing of the 
    // flags.
    flag.Parse()

    // Print out the URLs, pausing between each iteration.
    for _, u := range urls {
        log.Print(u)
        time.Sleep(pause)
    }
}

如果你尝试运行这个应用程序,你应该发现这些标志被解析了,并且像你所期望的那样工作。比如说:

$ go run . --pause=500ms --urls="http://example.com http://example.org http://example.net"
2021/03/08 08:22:33 http://example.com
2021/03/08 08:22:34 http://example.org
2021/03/08 08:22:34 http://example.net

而如果你提供了一个无效的标志值,在flag.Func() 的某个函数中引发了错误,Go会自动显示相应的错误信息并退出。举例来说:

$ go run . --pause=500xx --urls="http://example.com http://example.org http://example.net"
invalid value "500xx" for flag -pause: time: unknown unit "xx" in duration "500xx"
Usage of /tmp/go-build3141872390/b001/exe/example.text:
  -pause value
        Duration to pause between printing URLs
  -urls value
        List of URLs to print
exit status 2

这里真的要指出的是,如果没有提供一个标志,相应的flag.Func() 函数根本就不会被调用。这意味着你不能flag.Func() 函数中设置一个默认值,所以试图做这样的事情是行不通的:

flag.Func("pause", "Duration to pause between printing URLs (default 1s)", func(flagValue string) error {
    // DON'T DO THIS! This function wont' be called if the flag value is "".
    if flagValue == "" {
        pause = time.Second
        return nil
    }

    var err error
    pause, err = time.ParseDuration(flagValue)
    return err
})

不过,从好的方面看,flag.Func() 函数中可以包含的代码没有任何限制,所以如果你愿意,你可以用它来做更多的事情,把URL解析成一个[]*url.URL slice,而不是一个[]string 。 像这样。

var (
    urls  []*url.URL                 
    pause time.Duration = time.Second
)

flag.Func("urls", "List of URLs to print", func(flagValue string) error {
    for _, u := range strings.Fields(flagValue) {
        parsedURL, err := url.Parse(u)
        if err != nil {
            return err
        }
        urls = append(urls, parsedURL)
    }
    return nil
})

验证标志值

flag.Func() 函数还为验证来自命令行标志的输入数据提供了一些新的机会。例如,假设你的应用程序有一个--environment 标志,你想把可能的值限制在development,stagingproduction

要做到这一点,你可以实现一个与此类似的flag.Func() 函数:

package main

import (
    "errors"
    "flag"
    "fmt"
)

func main() {
    var (
        environment string = "development"
    )

    flag.Func("environment", "Operating environment", func(flagValue string) error {
        for _, allowedValue := range []string{"development", "staging", "production"} {
            if flagValue == allowedValue {
                environment = flagValue
                return nil
            }
        }
        return errors.New(`must be one of "development", "staging" or "production"`)
    })

    flag.Parse()

    fmt.Printf("The operating environment is: %s\n", environment)
}

制作可重复使用的辅助工具

如果你发现自己在flag.Func() 函数中重复相同的代码,或者逻辑变得太复杂,可以把它分解成一个可重用的帮助器。例如,我们可以重写上面的例子,通过一个通用的enumFlag() 函数来处理我们的--environment 标志,像这样:

package main

import (
    "flag"
    "fmt"
)

func main() {
    var (
        environment string = "development"
    )

    enumFlag(&environment, "environment", []string{"development", "staging", "production"}, "Operating environment")

    flag.Parse()

    fmt.Printf("The operating environment is: %s\n", environment)
}

func enumFlag(target *string, name string, safelist []string, usage string) {
    flag.Func(name, usage, func(flagValue string) error {
        for _, allowedValue := range safelist {
            if flagValue == allowedValue {
                *target = flagValue
                return nil
            }
        }

        return fmt.Errorf("must be one of %v", safelist)
    })
}