diff options
| -rw-r--r-- | go.mod | 6 | ||||
| -rw-r--r-- | go.sum | 4 | ||||
| -rw-r--r-- | main.go | 110 |
3 files changed, 44 insertions, 76 deletions
@@ -1,3 +1,7 @@ module go.awhk.org/gencert -go 1.16 +go 1.21 + +require go.awhk.org/core v0.6.1 + +require github.com/google/go-cmp v0.5.9 // indirect @@ -0,0 +1,4 @@ +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +go.awhk.org/core v0.6.1 h1:lKkmAH/p/kSR7WAUbLzWIWU11wZNmFsrXDJWe8bfsH0= +go.awhk.org/core v0.6.1/go.mod h1:lOs71woKF5QCNNEFjaACmhEj7U6IEGAFHw0Zo1Fyh50= @@ -14,73 +14,39 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" - "errors" "flag" "fmt" "log" "math/big" "net" + "net/netip" "os" "sort" "time" -) - -type IPListFlag []net.IP - -func (f *IPListFlag) String() string { - return fmt.Sprintf("%s", []net.IP(*f)) -} - -func (f *IPListFlag) Set(s string) error { - ip := net.ParseIP(s) - if ip == nil { - return errors.New("could not parse IP") - } - *f = append(*f, ip) - return nil -} - -type StringListFlag []string - -func (f *StringListFlag) String() string { - return fmt.Sprintf("%s", []string(*f)) -} - -func (f *StringListFlag) Set(s string) error { - *f = append(*f, s) - return nil -} -type TimeFlag struct { - t time.Time -} - -func (f *TimeFlag) String() string { - return f.t.String() -} - -func (f *TimeFlag) Set(s string) (err error) { - f.t, err = time.Parse(time.RFC3339, s) - return -} + "go.awhk.org/core" +) var ( caFlags = flag.NewFlagSet(os.Args[0]+" ca", flag.ExitOnError) certFlags = flag.NewFlagSet(os.Args[0]+" cert", flag.ExitOnError) - caName string + caName = certFlags.String("ca", "", "base name for the CA files") commonName string country string - dnsNames StringListFlag + dnsNames = core.FlagSlice(certFlags, "dns", nil, "DNS name", core.ParseString, ",") duration time.Duration - from TimeFlag - ips IPListFlag + from time.Time + ips = core.FlagSlice(certFlags, "ip", nil, "IP addresses", netip.ParseAddr, ",") keyAlgo string - org StringListFlag + org []string out string - unit StringListFlag - until TimeFlag - usages = StringListFlag{"server-auth"} + unit []string + until time.Time + usages = core.FlagSlice(certFlags, "usage", []string{"server-auth"}, `how the certificate will be used: + - code-signing + - server-auth +`, core.ParseStringEnum("code-signing", "server-auth"), ",") ) func init() { @@ -94,19 +60,13 @@ func init() { - ecdsa - rsa `) - f.Var(&from, "nb", "the earliest time on which the certificate is valid") - f.Var(&org, "o", "organization") + core.FlagSliceVar(f, &org, "o", "organization", core.ParseString, ",") + core.FlagVar(f, &from, "nb", "the earliest time on which the certificate is valid", core.ParseTime) + core.FlagSliceVar(f, &org, "o", "organization", core.ParseString, ",") f.StringVar(&out, "out", "", "base name for the output") - f.Var(&unit, "ou", "organizational unit") - f.Var(&until, "na", "the time past which the certificate is no longer valid") + core.FlagSliceVar(f, &unit, "ou", "organizational unit", core.ParseString, ",") + core.FlagVar(f, &until, "na", "the time past which the certificate is no longer valid", core.ParseTime) } - certFlags.StringVar(&caName, "ca", "", "base name for the CA files") - certFlags.Var(&dnsNames, "dns", "DNS name") - certFlags.Var(&ips, "ip", "IP address") - certFlags.Var(&usages, "usage", `how the certificate will be used: - - code-signing - - server-auth -`) flag.Usage = func() { fmt.Fprintf(flag.CommandLine.Output(), `%s is a tool for generating certificates. @@ -147,12 +107,12 @@ func main() { log.Fatalln("error: -c is required") } // See RFC 6125ยง6.4.4. - if len(dnsNames) > 0 || len(ips) > 0 { - dnsNames = append(dnsNames, commonName) + if len(*dnsNames) > 0 || len(*ips) > 0 { + *dnsNames = append(*dnsNames, commonName) } - sort.Strings(dnsNames) - if from.t.IsZero() { - from.t = time.Now() + sort.Strings(*dnsNames) + if from.IsZero() { + from = time.Now() } if out == "" { log.Fatalln("error: -out is required") @@ -160,15 +120,15 @@ func main() { if len(org) == 0 { log.Fatalln("error: -o is required") } - if until.t.IsZero() { + if until.IsZero() { if duration == 0 { log.Fatalln("error: -na is required when no -d is passed") } - until.t = from.t.Add(duration) + until = from.Add(duration) } else if duration != 0 { log.Println("warning: ignored -d as -na was passed") } - if until.t.Before(from.t) { + if until.Before(from) { log.Fatalln("error: end date is before the start date") } @@ -178,13 +138,13 @@ func main() { } tmpl := &x509.Certificate{ BasicConstraintsValid: os.Args[1] == "ca", - DNSNames: dnsNames, + DNSNames: *dnsNames, ExtKeyUsage: extKeyUsage(), - IPAddresses: ips, + IPAddresses: core.SliceMap(func(ip netip.Addr) net.IP { return ip.AsSlice() }, *ips), IsCA: os.Args[1] == "ca", KeyUsage: keyUsage(), - NotBefore: from.t, - NotAfter: until.t, + NotBefore: from, + NotAfter: until, SerialNumber: newSerial(), Subject: pkix.Name{ CommonName: commonName, @@ -197,8 +157,8 @@ func main() { } parentKey := key parentCert := tmpl - if caName != "" { - buf, err := os.ReadFile(caName + ".key") + if *caName != "" { + buf, err := os.ReadFile(*caName + ".key") if err != nil { log.Fatalln("error: could not read the CA private key:", err) } @@ -210,7 +170,7 @@ func main() { if err != nil { log.Fatalln("error: could not parse the CA private key:", err) } - buf, err = os.ReadFile(caName + ".crt") + buf, err = os.ReadFile(*caName + ".crt") if err != nil { log.Fatalln("error: could not read the CA certificate:", err) } @@ -269,7 +229,7 @@ func extKeyUsage() []x509.ExtKeyUsage { return nil } s := map[string]x509.ExtKeyUsage{} - for _, e := range usages { + for _, e := range *usages { switch e { case "code-signing": s[e] = x509.ExtKeyUsageCodeSigning |
