274 lines
6 KiB
Go
274 lines
6 KiB
Go
// Yet another command line parser.
|
|
package config
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"os"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
"unicode"
|
|
"unicode/utf8"
|
|
)
|
|
|
|
// Errors collects the error messages produced while parsing arguments.
type Errors struct {
	// errors holds the formatted messages in the order they were recorded.
	errors []string
}

// Check records a message when err is non-nil and does nothing otherwise.
//
// The message is built from format with arg1 and arg2 as its operands;
// the text of err itself is not included.
func (e *Errors) Check(err error, format, arg1, arg2 string) {
	if err == nil {
		return
	}
	msg := fmt.Sprintf(format, arg1, arg2)
	e.errors = append(e.errors, msg)
}

// Add unconditionally records a message built from format with arg as
// its single operand.
func (e *Errors) Add(format, arg string) {
	msg := fmt.Sprintf(format, arg)
	e.errors = append(e.errors, msg)
}
|
|
|
|
// Group describes one named set of parameters backed by a struct.
type Group struct {
	// Name is the heading printed by Help and the key used by Print.
	Name string
	// Prefix is placed in front of each snake-cased field name to form
	// the command line argument name.
	Prefix string
	// Values is the value whose fields are exposed as parameters; the
	// field walker dereferences it, so it is expected to be a pointer
	// to a struct.
	Values interface{}
}
|
|
|
|
// Config handles parsing command line arguments into structs.
|
|
//
|
|
// Groups of parameters can be stored in a struct and added as a group
|
|
// using Add(). All public, primitive fields are automatically
|
|
// exposed. The argument names are set to the snake case version of
|
|
// the prefix and variable name.
|
|
type Config struct {
|
|
groups []*Group
|
|
}
|
|
|
|
// Add a new group of parameters.
|
|
func (c *Config) Add(group, prefix string, values interface{}) {
|
|
next := &Group{Name: group, Prefix: prefix, Values: values}
|
|
c.groups = append(c.groups, next)
|
|
}
|
|
|
|
// parser converts the textual value of the argument called name and
// stores it in field, recording any problem in errors.
type parser func(field reflect.Value, name, value string, errors *Errors)
|
|
|
|
func parseString(field reflect.Value, name, value string, errors *Errors) {
|
|
field.SetString(value)
|
|
}
|
|
|
|
func parseBool(field reflect.Value, name, value string, errors *Errors) {
|
|
if value == "" {
|
|
field.SetBool(true)
|
|
} else {
|
|
parsed, err := strconv.ParseBool(value)
|
|
errors.Check(err, "Invalid boolean value %s for %s.", value, name)
|
|
field.SetBool(parsed)
|
|
}
|
|
}
|
|
|
|
func parseInt(field reflect.Value, name, value string, errors *Errors) {
|
|
parsed, err := strconv.ParseInt(value, 10, 32)
|
|
errors.Check(err, "Invalid integer value %s for %s.", value, name)
|
|
field.SetInt(parsed)
|
|
}
|
|
|
|
func parseFloat(field reflect.Value, name, value string, errors *Errors) {
|
|
parsed, err := strconv.ParseFloat(value, 32)
|
|
errors.Check(err, "Invalid floating point value %s for %s.", value, name)
|
|
field.SetFloat(parsed)
|
|
}
|
|
|
|
var parsers map[reflect.Kind]parser
|
|
|
|
func init() {
|
|
parsers = map[reflect.Kind]parser{
|
|
reflect.Bool: parseBool,
|
|
reflect.Int: parseInt,
|
|
reflect.Int32: parseInt,
|
|
reflect.Float32: parseFloat,
|
|
reflect.Float64: parseFloat,
|
|
reflect.String: parseString,
|
|
}
|
|
}
|
|
|
|
// toLower builds an argument name from prefix and the field name.
//
// prefix is copied as is. Each upper-case rune of name is lowered, and
// an underscore is written before it unless the previous rune of name
// was also upper case. With an empty prefix the first rune gets no
// leading underscore either. All other runes are copied unchanged, so
// "MaxSize" with prefix "db" becomes "db_max_size" and a run such as
// "ID" becomes "id".
func toLower(prefix, name string) string {
	var out bytes.Buffer
	out.WriteString(prefix)

	// inRun is true while the previously written rune came from an
	// upper-case letter; an empty prefix starts in that state so the
	// result never begins with an underscore.
	inRun := prefix == ""

	for _, r := range name {
		upper := unicode.IsUpper(r)
		if upper && !inRun {
			out.WriteRune('_')
		}
		if upper {
			r = unicode.ToLower(r)
		}
		out.WriteRune(r)
		inRun = upper
	}
	return out.String()
}
|
|
|
|
// visitor is called by visit once per exposed field, with the Go field
// name and the reflected field value.
type visitor func(name string, field reflect.Value)
|
|
|
|
func visit(val interface{}, visitor visitor) {
|
|
s := reflect.ValueOf(val).Elem()
|
|
tp := s.Type()
|
|
|
|
for i := 0; i < s.NumField(); i++ {
|
|
f := s.Field(i)
|
|
name := tp.Field(i).Name
|
|
first, _ := utf8.DecodeRuneInString(name)
|
|
_, hasParser := parsers[f.Kind()]
|
|
|
|
switch {
|
|
case !hasParser:
|
|
break
|
|
case !unicode.IsUpper(first):
|
|
break
|
|
default:
|
|
visitor(name, f)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Help prints human-readable help to the writer. Use nil for
|
|
// stdout.
|
|
func (c *Config) Help(w io.Writer) {
|
|
if w == nil {
|
|
w = os.Stdout
|
|
}
|
|
for i, group := range c.groups {
|
|
if i != 0 {
|
|
fmt.Fprintln(w)
|
|
}
|
|
fmt.Fprintf(w, "%s:\n", group.Name)
|
|
visit(group.Values, func(name string, field reflect.Value) {
|
|
low := toLower(group.Prefix, name)
|
|
fmt.Fprintf(w, " --%s=%v\n", low, field.Interface())
|
|
})
|
|
}
|
|
}
|
|
|
|
// Print prints the parameters as a JSON dict. Use nil for stdout.
|
|
func (c *Config) Print(w io.Writer, compact bool) {
|
|
if w == nil {
|
|
w = os.Stdout
|
|
}
|
|
eol := "\n"
|
|
indent := " "
|
|
if compact {
|
|
eol = ""
|
|
indent = ""
|
|
}
|
|
|
|
// Yay, another JSON encoder.
|
|
fmt.Fprintf(w, "{%s", eol)
|
|
for i, group := range c.groups {
|
|
fmt.Fprintf(w, "%s%q: {%s", indent, group.Name, eol)
|
|
j := 0
|
|
visit(group.Values, func(name string, field reflect.Value) {
|
|
if j != 0 {
|
|
fmt.Fprintf(w, ",%s", eol)
|
|
}
|
|
j += 1
|
|
fmt.Fprintf(w, "%s%s%q: ", indent, indent, name)
|
|
switch field.Kind() {
|
|
case reflect.String:
|
|
fmt.Fprintf(w, "%q", field.Interface())
|
|
default:
|
|
fmt.Fprintf(w, "%v", field.Interface())
|
|
}
|
|
})
|
|
if i != len(c.groups)-1 {
|
|
fmt.Fprintf(w, "%s%s},%s", eol, indent, eol)
|
|
} else {
|
|
fmt.Fprintf(w, "%s%s}%s", eol, indent, eol)
|
|
}
|
|
}
|
|
fmt.Fprintf(w, "}%s", eol)
|
|
}
|
|
|
|
// Print prints the parameters as a JSON dict. Use nil for stdout.
|
|
func (c *Config) String() string {
|
|
var b bytes.Buffer
|
|
c.Print(&b, true)
|
|
return b.String()
|
|
}
|
|
|
|
// Parse parses the string arguments and returns all errors seen.
|
|
func (c *Config) Parse(args []string) []string {
|
|
errors := &Errors{}
|
|
all := make(map[string]reflect.Value)
|
|
|
|
for _, group := range c.groups {
|
|
visit(group.Values, func(name string, field reflect.Value) {
|
|
low := toLower(group.Prefix, name)
|
|
|
|
if _, ok := all[low]; ok {
|
|
log.Fatalf("Duplicate argument %s.", low)
|
|
}
|
|
all[low] = field
|
|
})
|
|
}
|
|
|
|
for _, arg := range args {
|
|
if !strings.HasPrefix(arg, "--") {
|
|
errors.Add("Unrecognised argument %s.", arg)
|
|
continue
|
|
}
|
|
name := strings.TrimPrefix(arg, "--")
|
|
var value string
|
|
separator := strings.Index(name, "=")
|
|
|
|
if separator >= 0 {
|
|
value = name[separator+1:]
|
|
name = name[:separator]
|
|
}
|
|
|
|
field, ok := all[name]
|
|
if !ok {
|
|
errors.Add("Unrecognised argument %s.", name)
|
|
continue
|
|
}
|
|
if separator < 0 && field.Kind() != reflect.Bool {
|
|
errors.Add("%s needs an argument.", name)
|
|
}
|
|
|
|
parser := parsers[field.Kind()]
|
|
|
|
if parser == nil {
|
|
log.Fatalf("Unhandled type %s while parsing %s.", field.Kind(), name)
|
|
}
|
|
parser(field, name, value, errors)
|
|
}
|
|
|
|
return errors.errors
|
|
}
|
|
|
|
// ParseArgv parses the programs command line arguments.
|
|
func (c *Config) ParseArgv() []string {
|
|
return c.Parse(os.Args[1:])
|
|
}
|
|
|
|
// ShowErrors prints the errors in a similar format to Write. Use nil
|
|
// for stdout.
|
|
func (c *Config) ShowErrors(w io.Writer, errors []string) {
|
|
if w == nil {
|
|
w = os.Stdout
|
|
}
|
|
|
|
for _, error := range errors {
|
|
fmt.Fprintf(w, "Error: %v\n", error)
|
|
}
|
|
}
|