diff options
| author | Grégoire Duchêne <gduchene@awhk.org> | 2019-08-29 00:21:37 +0100 |
|---|---|---|
| committer | Grégoire Duchêne <gduchene@awhk.org> | 2019-08-29 00:21:37 +0100 |
| commit | 863abc0eda83ef08be8d8885e2875de36c4d57dd (patch) | |
| tree | 9bce1685fb308e0cb1016a067926b424953c47fc | |
| parent | a7f187677256c81c53a00e3909a347504340839a (diff) | |
Reorganize flags into flag sets
| -rw-r--r-- | README.md | 4 | ||||
| -rw-r--r-- | main.go | 98 |
2 files changed, 71 insertions, 31 deletions
@@ -9,7 +9,7 @@ be useful to do PKI on things that only live on your LAN. ```shell # Generate a self-signed certificate. # This generates ~/out/my-ca.crt and ~/out/my-ca.key. -$ gencert \ +$ gencert ca \ -c US \ -o example.com \ -cn 'My CA' \ @@ -19,7 +19,7 @@ $ gencert \ # Generate a normal certificate. # This reads ~/out/my-ca.crt and ~/out/my-ca.key, and generates # ~/out/my-site.crt and ~/my-site.key. -$ gencert \ +$ gencert cert \ -ca ~/out/my-ca \ -c US \ -o example.com \ @@ -65,26 +65,55 @@ func (f *TimeFlag) Set(s string) (err error) { } var ( - caName = flag.String("ca", "", "base name for the CA files") - commonName = flag.String("cn", "", "common name") - country = flag.String("c", "", "country code") + caFlags = flag.NewFlagSet(os.Args[0]+" ca", flag.ExitOnError) + certFlags = flag.NewFlagSet(os.Args[0]+" cert", flag.ExitOnError) + + caName string + commonName string + country string dnsNames StringListFlag - duration = flag.Duration("d", 0, "certificate duration") + duration time.Duration from TimeFlag ips IPListFlag org StringListFlag - out = flag.String("out", "", "base name for the output") + out string unit StringListFlag until TimeFlag ) func init() { - flag.Var(&dnsNames, "dns", "DNS name") - flag.Var(&from, "nb", "the earliest time on which the certificate is valid") - flag.Var(&ips, "ip", "IP address") - flag.Var(&org, "o", "organization") - flag.Var(&unit, "ou", "organizational unit") - flag.Var(&until, "na", "the time past which the certificate is no longer valid") + log.SetFlags(0) + + for _, f := range []*flag.FlagSet{caFlags, certFlags} { + f.StringVar(&commonName, "cn", "", "common name") + f.StringVar(&country, "c", "", "country code") + f.DurationVar(&duration, "d", 0, "certificate duration") + f.Var(&from, "nb", "the earliest time on which the certificate is valid") + f.Var(&org, "o", "organization") + f.StringVar(&out, "out", "", "base name for the output") + f.Var(&unit, "ou", "organizational unit") + f.Var(&until, "na", "the time past which the certificate is no longer valid") + } + certFlags.StringVar(&caName, "ca", "", "base name for the CA files") + certFlags.Var(&dnsNames, "dns", "DNS name") + certFlags.Var(&ips, "ip", "IP address") + + flag.Usage = func() { + fmt.Fprintf(flag.CommandLine.Output(), `%s is a tool for generating certificates. + +Usage: + + %[1]s <command> [arguments] + +The commands are: + + ca generate a CA certificate + cert generate a regular certificate + +Use %[1]s <command> -h for help about that command. + +`, os.Args[0]) + } } func newSerial() *big.Int { @@ -107,29 +136,40 @@ func newSerial() *big.Int { func main() { flag.Parse() - log.SetFlags(0) - - if *commonName == "" { + if len(os.Args) == 1 { + flag.Usage() + os.Exit(2) + } + switch os.Args[1] { + case "ca": + caFlags.Parse(os.Args[2:]) + case "cert": + certFlags.Parse(os.Args[2:]) + default: + flag.Usage() + os.Exit(2) + } + if commonName == "" { log.Fatalln("error: -cn is required") } - if *country == "" { + if country == "" { log.Fatalln("error: -c is required") } if from.t.IsZero() { from.t = time.Now() } - if *out == "" { + if out == "" { log.Fatalln("error: -out is required") } if len(org) == 0 { log.Fatalln("error: -o is required") } if until.t.IsZero() { - if *duration == 0 { + if duration == 0 { log.Fatalln("error: -end-date is required when no -d is passed") } - until.t = from.t.Add(*duration) - } else if *duration != 0 { + until.t = from.t.Add(duration) + } else if duration != 0 { log.Println("warning: ignored -d as -end-date was passed") } if until.t.Before(from.t) { @@ -144,33 +184,33 @@ func main() { keyUsage x509.KeyUsage extKeyUsage []x509.ExtKeyUsage ) - if *caName == "" { + if os.Args[1] == "ca" { keyUsage = x509.KeyUsageCertSign } else { keyUsage = x509.KeyUsageDigitalSignature extKeyUsage = append(extKeyUsage, x509.ExtKeyUsageServerAuth) } tmpl := &x509.Certificate{ - BasicConstraintsValid: *caName == "", + BasicConstraintsValid: os.Args[1] == "ca", DNSNames: dnsNames, ExtKeyUsage: extKeyUsage, IPAddresses: ips, - IsCA: *caName == "", + IsCA: os.Args[1] == "ca", KeyUsage: keyUsage, NotBefore: from.t, NotAfter: until.t, SerialNumber: newSerial(), Subject: pkix.Name{ - CommonName: *commonName, - Country: []string{*country}, + CommonName: commonName, + Country: []string{country}, Organization: org, OrganizationalUnit: unit, }, } parentKey := key parentCert := tmpl - if *caName != "" { - buf, err := ioutil.ReadFile(fmt.Sprintf("%s.key", *caName)) + if caName != "" { + buf, err := ioutil.ReadFile(caName + ".key") if err != nil { log.Fatalln("error: could not read the CA private key:", err) } @@ -179,7 +219,7 @@ func main() { if err != nil { log.Fatalln("error: could not parse the CA private key:", err) } - buf, err = ioutil.ReadFile(fmt.Sprintf("%s.crt", *caName)) + buf, err = ioutil.ReadFile(caName + ".crt") if err != nil { log.Fatalln("error: could not read the CA certificate:", err) } @@ -194,12 +234,12 @@ func main() { log.Fatalln("error: could not generate the certificate:", err) } - keyOut, err := os.OpenFile(fmt.Sprintf("%s.key", *out), os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0600) + keyOut, err := os.OpenFile(out+".key", os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0600) if err != nil { log.Fatalln("error: could not create the private key:", err) } defer keyOut.Close() - certOut, err := os.OpenFile(fmt.Sprintf("%s.crt", *out), os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0644) + certOut, err := os.OpenFile(out+".crt", os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0644) if err != nil { log.Fatalln("error: could not create the certificate:", err) } |
