diff options
| -rw-r--r-- | main.go | 47 | ||||
| -rw-r--r-- | main_aws.go | 37 | ||||
| -rw-r--r-- | resp.go | 50 | ||||
| -rw-r--r-- | resp_test.go | 81 |
4 files changed, 146 insertions, 69 deletions
@@ -7,13 +7,12 @@ package main import ( "context" - "fmt" + "flag" "log" "net" "net/http" "os" "os/signal" - "path" "time" "golang.org/x/sys/unix" @@ -21,41 +20,27 @@ import ( "go.awhk.org/gosdd" ) -func redirect(resp http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodGet { - resp.Header().Set("Allow", http.MethodGet) - resp.WriteHeader(http.StatusMethodNotAllowed) - return - } - - pkg := path.Join(req.Host, req.URL.Path) - if req.URL.Query().Get("go-get") != "1" { - resp.Header().Set("Location", "https://pkg.go.dev/"+pkg) - resp.WriteHeader(http.StatusFound) - return - } - resp.Header().Set("Content-Type", "text/html; charset=utf-8") - resp.WriteHeader(http.StatusOK) - if _, err := fmt.Fprint(resp, GetBody(pkg)); err != nil { - log.Println("fmt.Fprint:", err) - } -} +var ( + addr = flag.String("addr", "localhost:8080", "address to listen on") + from = flag.String("from", "", "package prefix to remove") + to = flag.String("to", "", "repository prefix to add") + vcs = flag.String("vcs", "git", "version control system to signal") +) func main() { - mux := http.NewServeMux() - mux.HandleFunc("/", redirect) - srv := http.Server{Handler: mux} + flag.Parse() done := make(chan os.Signal, 1) signal.Notify(done, os.Interrupt, unix.SIGTERM) + srv := http.Server{Handler: &redirector{*from, *to, *vcs}} go func() { ln, err := listenSD() if err != nil { log.Fatalln("listenSD:", err) } if ln == nil { - ln = listenEnv() + ln = listenFlag() } if err := srv.Serve(ln); err != nil && err != http.ErrServerClosed { log.Fatalln("server.ListenAndServe:", err) @@ -70,19 +55,15 @@ func main() { } } -func listenEnv() net.Listener { - addr := os.Getenv("ADDR") - if addr == "" || addr[0] != '/' { - if addr == "" { - addr = ":8080" - } - ln, err := net.Listen("tcp", ":8080") +func listenFlag() net.Listener { + if (*addr)[0] != '/' { + ln, err := net.Listen("tcp", *addr) if err != nil { log.Fatalln("net.Listen:", err) } return ln } - ln, err := net.Listen("unix", addr) + ln, err := net.Listen("unix", *addr) if err != nil { log.Fatalln("net.Listen:", err) } diff --git a/main_aws.go b/main_aws.go index 120d6de..4f6f16e 100644 --- a/main_aws.go +++ b/main_aws.go @@ -8,27 +8,38 @@ package main import ( "context" "net/http" + "os" "path" + "strings" "github.com/aws/aws-lambda-go/events" "github.com/aws/aws-lambda-go/lambda" ) +var ( + from = os.Getenv("FROM") + to = os.Getenv("TO") + vcs = os.Getenv("VCS") + redir = &redirector{from, to, vcs} +) + func redirect(ctx context.Context, req events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) { - var ( - pkg = path.Join(req.Headers["Host"], req.Path) - resp = events.APIGatewayProxyResponse{ - Body: GetBody(pkg), - Headers: map[string]string{"Content-Type": "text/html; charset=utf-8"}, - } - ) - if v, ok := req.QueryStringParameters["go-get"]; ok && v == "1" { - resp.StatusCode = http.StatusOK - } else { - resp.Headers["Location"] = "https://pkg.go.dev/" + pkg - resp.StatusCode = http.StatusFound + pkg := path.Join(req.Headers["Host"], req.Path) + if v, ok := req.QueryStringParameters["go-get"]; !ok || v != "1" { + return events.APIGatewayProxyResponse{ + Headers: map[string]string{"Location": "https://pkg.go.dev/" + pkg}, + StatusCode: http.StatusFound, + }, nil + } + var buf strings.Builder + if err := body.Execute(&buf, bodyData{pkg, redir.getRepo(pkg), vcs}); err != nil { + return events.APIGatewayProxyResponse{}, err } - return resp, nil + return events.APIGatewayProxyResponse{ + Body: buf.String(), + Headers: map[string]string{"Content-Type": "text/html; charset=utf-8"}, + StatusCode: http.StatusOK, + }, nil } func main() { @@ -4,22 +4,48 @@ package main import ( - "fmt" - "os" + "log" + "net/http" + "path" "strings" + "text/template" ) -func GetBody(pkg string) string { - dest := GetDest(os.Getenv("PREFIX"), os.Getenv("DEST"), pkg) - return fmt.Sprintf(`<!doctype html> -<meta name="go-import" content="%s %s %s"> +var body = template.Must(template.New("").Parse(`<!doctype html> +<meta name="go-import" content="{{.Package}} {{.VCS}} {{.Repository}}"> <title>go-import-redirect</title> -`, pkg, os.Getenv("VCS"), dest) +`)) + +type bodyData struct{ Package, Repository, VCS string } + +type redirector struct{ from, to, vcs string } + +var _ http.Handler = &redirector{} + +func (h *redirector) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodGet { + w.Header().Set("Allow", http.MethodGet) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + pkg := path.Join(req.Host, req.URL.Path) + if req.URL.Query().Get("go-get") != "1" { + w.Header().Set("Location", "https://pkg.go.dev/"+pkg) + w.WriteHeader(http.StatusFound) + return + } + dest := h.getRepo(pkg) + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + if err := body.Execute(w, bodyData{pkg, dest, h.vcs}); err != nil { + log.Println(err) + } } -func GetDest(srcPrefix, destPrefix, pkg string) string { - srcPrefix = strings.TrimRight(srcPrefix, "/") - destPrefix = strings.TrimRight(destPrefix, "/") - path := strings.TrimLeft(strings.TrimPrefix(pkg, srcPrefix), "/") - return destPrefix + "/" + strings.Split(path, "/")[0] +func (h *redirector) getRepo(pkg string) string { + from := strings.TrimRight(h.from, "/") + to := strings.TrimRight(h.to, "/") + path := strings.TrimLeft(strings.TrimPrefix(pkg, from), "/") + return to + "/" + strings.Split(path, "/")[0] } diff --git a/resp_test.go b/resp_test.go index 4afbd92..434d985 100644 --- a/resp_test.go +++ b/resp_test.go @@ -3,19 +3,78 @@ package main -import "testing" +import ( + "io" + "net/http" + "net/http/httptest" + "testing" +) -func TestGetDest(t *testing.T) { - cs := []struct{ srcPrefix, destPrefix, pkg, expected string }{ - {"src.example.com/x/", "https://example.com/git/", "src.example.com/x/foo", "https://example.com/git/foo"}, - {"src.example.com/x/", "https://example.com/git/", "src.example.com/x/foo/bar", "https://example.com/git/foo"}, - {"src.example.com/x", "https://example.com/git", "src.example.com/x/foo", "https://example.com/git/foo"}, - {"src.example.com/x", "https://example.com/git", "src.example.com/x/foo/bar", "https://example.com/git/foo"}, +func TestRedirector_ServeHTTP(t *testing.T) { + r := &redirector{"src.example.com/x", "https://example.com/git", "git"} + + t.Run("GoVisit", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "https://src.example.com/foo?go-get=1", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + if http.StatusOK != resp.StatusCode { + t.Errorf("expected %d, got %d", http.StatusFound, resp.StatusCode) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Error(err) + t.FailNow() + } + expected := `<!doctype html> +<meta name="go-import" content="src.example.com/foo git https://example.com/git/src.example.com"> +<title>go-import-redirect</title> +` + if string(body) != expected { + t.Errorf("expected\n---\n%s\n---\ngot\n---\n%s\n---", expected, string(body)) + } + if hdr := resp.Header.Get("Location"); hdr != "" { + t.Error("expected empty Location header") + } + }) + + t.Run("UserVisit", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "https://src.example.com/foo", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + resp := w.Result() + if http.StatusFound != resp.StatusCode { + t.Errorf("expected %d, got %d", http.StatusFound, resp.StatusCode) + } + if resp.ContentLength > 0 { + t.Error("expected empty body") + } + if hdr := resp.Header.Get("Location"); hdr != "https://pkg.go.dev/src.example.com/foo" { + t.Errorf("expected %q, got %q", "https://pkg.go.dev/src.example.com/foo", hdr) + } + }) +} + +func TestRedirector_getRepo(t *testing.T) { + r := &redirector{"src.example.com/x/", "https://example.com/git/", "git"} + for _, tc := range []struct{ pkg, expected string }{ + {"src.example.com/x/foo", "https://example.com/git/foo"}, + {"src.example.com/x/foo/bar", "https://example.com/git/foo"}, + } { + if actual := r.getRepo(tc.pkg); actual != tc.expected { + t.Errorf("expected %q, got %q", tc.expected, actual) + } } - for _, c := range cs { - actual := GetDest(c.srcPrefix, c.destPrefix, c.pkg) - if actual != c.expected { - t.Errorf("expected %s, got %s", c.expected, actual) + + r = &redirector{"src.example.com/x", "https://example.com/git", "git"} + for _, tc := range []struct{ pkg, expected string }{ + {"src.example.com/x/foo", "https://example.com/git/foo"}, + {"src.example.com/x/foo/bar", "https://example.com/git/foo"}, + } { + if actual := r.getRepo(tc.pkg); actual != tc.expected { + t.Errorf("expected %q, got %q", tc.expected, actual) } } } |
