aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--main.go47
-rw-r--r--main_aws.go37
-rw-r--r--resp.go50
-rw-r--r--resp_test.go81
4 files changed, 146 insertions, 69 deletions
diff --git a/main.go b/main.go
index 8c2bfb1..71d40de 100644
--- a/main.go
+++ b/main.go
@@ -7,13 +7,12 @@ package main
import (
"context"
- "fmt"
+ "flag"
"log"
"net"
"net/http"
"os"
"os/signal"
- "path"
"time"
"golang.org/x/sys/unix"
@@ -21,41 +20,27 @@ import (
"go.awhk.org/gosdd"
)
-func redirect(resp http.ResponseWriter, req *http.Request) {
- if req.Method != http.MethodGet {
- resp.Header().Set("Allow", http.MethodGet)
- resp.WriteHeader(http.StatusMethodNotAllowed)
- return
- }
-
- pkg := path.Join(req.Host, req.URL.Path)
- if req.URL.Query().Get("go-get") != "1" {
- resp.Header().Set("Location", "https://pkg.go.dev/"+pkg)
- resp.WriteHeader(http.StatusFound)
- return
- }
- resp.Header().Set("Content-Type", "text/html; charset=utf-8")
- resp.WriteHeader(http.StatusOK)
- if _, err := fmt.Fprint(resp, GetBody(pkg)); err != nil {
- log.Println("fmt.Fprint:", err)
- }
-}
+var (
+ addr = flag.String("addr", "localhost:8080", "address to listen on")
+ from = flag.String("from", "", "package prefix to remove")
+ to = flag.String("to", "", "repository prefix to add")
+ vcs = flag.String("vcs", "git", "version control system to signal")
+)
func main() {
- mux := http.NewServeMux()
- mux.HandleFunc("/", redirect)
- srv := http.Server{Handler: mux}
+ flag.Parse()
done := make(chan os.Signal, 1)
signal.Notify(done, os.Interrupt, unix.SIGTERM)
+ srv := http.Server{Handler: &redirector{*from, *to, *vcs}}
go func() {
ln, err := listenSD()
if err != nil {
log.Fatalln("listenSD:", err)
}
if ln == nil {
- ln = listenEnv()
+ ln = listenFlag()
}
if err := srv.Serve(ln); err != nil && err != http.ErrServerClosed {
log.Fatalln("server.ListenAndServe:", err)
@@ -70,19 +55,15 @@ func main() {
}
}
-func listenEnv() net.Listener {
- addr := os.Getenv("ADDR")
- if addr == "" || addr[0] != '/' {
- if addr == "" {
- addr = ":8080"
- }
- ln, err := net.Listen("tcp", ":8080")
+func listenFlag() net.Listener {
+ if (*addr)[0] != '/' {
+ ln, err := net.Listen("tcp", *addr)
if err != nil {
log.Fatalln("net.Listen:", err)
}
return ln
}
- ln, err := net.Listen("unix", addr)
+ ln, err := net.Listen("unix", *addr)
if err != nil {
log.Fatalln("net.Listen:", err)
}
diff --git a/main_aws.go b/main_aws.go
index 120d6de..4f6f16e 100644
--- a/main_aws.go
+++ b/main_aws.go
@@ -8,27 +8,38 @@ package main
import (
"context"
"net/http"
+ "os"
"path"
+ "strings"
"github.com/aws/aws-lambda-go/events"
"github.com/aws/aws-lambda-go/lambda"
)
+var (
+ from = os.Getenv("FROM")
+ to = os.Getenv("TO")
+ vcs = os.Getenv("VCS")
+ redir = &redirector{from, to, vcs}
+)
+
func redirect(ctx context.Context, req events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) {
- var (
- pkg = path.Join(req.Headers["Host"], req.Path)
- resp = events.APIGatewayProxyResponse{
- Body: GetBody(pkg),
- Headers: map[string]string{"Content-Type": "text/html; charset=utf-8"},
- }
- )
- if v, ok := req.QueryStringParameters["go-get"]; ok && v == "1" {
- resp.StatusCode = http.StatusOK
- } else {
- resp.Headers["Location"] = "https://pkg.go.dev/" + pkg
- resp.StatusCode = http.StatusFound
+ pkg := path.Join(req.Headers["Host"], req.Path)
+ if v, ok := req.QueryStringParameters["go-get"]; !ok || v != "1" {
+ return events.APIGatewayProxyResponse{
+ Headers: map[string]string{"Location": "https://pkg.go.dev/" + pkg},
+ StatusCode: http.StatusFound,
+ }, nil
+ }
+ var buf strings.Builder
+ if err := body.Execute(&buf, bodyData{pkg, redir.getRepo(pkg), vcs}); err != nil {
+ return events.APIGatewayProxyResponse{}, err
}
- return resp, nil
+ return events.APIGatewayProxyResponse{
+ Body: buf.String(),
+ Headers: map[string]string{"Content-Type": "text/html; charset=utf-8"},
+ StatusCode: http.StatusOK,
+ }, nil
}
func main() {
diff --git a/resp.go b/resp.go
index f90c486..408dcd3 100644
--- a/resp.go
+++ b/resp.go
@@ -4,22 +4,48 @@
package main
import (
- "fmt"
- "os"
+ "log"
+ "net/http"
+ "path"
"strings"
+ "text/template"
)
-func GetBody(pkg string) string {
- dest := GetDest(os.Getenv("PREFIX"), os.Getenv("DEST"), pkg)
- return fmt.Sprintf(`<!doctype html>
-<meta name="go-import" content="%s %s %s">
+var body = template.Must(template.New("").Parse(`<!doctype html>
+<meta name="go-import" content="{{.Package}} {{.VCS}} {{.Repository}}">
<title>go-import-redirect</title>
-`, pkg, os.Getenv("VCS"), dest)
+`))
+
+type bodyData struct{ Package, Repository, VCS string }
+
+type redirector struct{ from, to, vcs string }
+
+var _ http.Handler = &redirector{}
+
+func (h *redirector) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+ if req.Method != http.MethodGet {
+ w.Header().Set("Allow", http.MethodGet)
+ w.WriteHeader(http.StatusMethodNotAllowed)
+ return
+ }
+
+ pkg := path.Join(req.Host, req.URL.Path)
+ if req.URL.Query().Get("go-get") != "1" {
+ w.Header().Set("Location", "https://pkg.go.dev/"+pkg)
+ w.WriteHeader(http.StatusFound)
+ return
+ }
+ dest := h.getRepo(pkg)
+ w.Header().Set("Content-Type", "text/html; charset=utf-8")
+ w.WriteHeader(http.StatusOK)
+ if err := body.Execute(w, bodyData{pkg, dest, h.vcs}); err != nil {
+ log.Println(err)
+ }
}
-func GetDest(srcPrefix, destPrefix, pkg string) string {
- srcPrefix = strings.TrimRight(srcPrefix, "/")
- destPrefix = strings.TrimRight(destPrefix, "/")
- path := strings.TrimLeft(strings.TrimPrefix(pkg, srcPrefix), "/")
- return destPrefix + "/" + strings.Split(path, "/")[0]
+func (h *redirector) getRepo(pkg string) string {
+ from := strings.TrimRight(h.from, "/")
+ to := strings.TrimRight(h.to, "/")
+ path := strings.TrimLeft(strings.TrimPrefix(pkg, from), "/")
+ return to + "/" + strings.Split(path, "/")[0]
}
diff --git a/resp_test.go b/resp_test.go
index 4afbd92..434d985 100644
--- a/resp_test.go
+++ b/resp_test.go
@@ -3,19 +3,78 @@
package main
-import "testing"
+import (
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+)
-func TestGetDest(t *testing.T) {
- cs := []struct{ srcPrefix, destPrefix, pkg, expected string }{
- {"src.example.com/x/", "https://example.com/git/", "src.example.com/x/foo", "https://example.com/git/foo"},
- {"src.example.com/x/", "https://example.com/git/", "src.example.com/x/foo/bar", "https://example.com/git/foo"},
- {"src.example.com/x", "https://example.com/git", "src.example.com/x/foo", "https://example.com/git/foo"},
- {"src.example.com/x", "https://example.com/git", "src.example.com/x/foo/bar", "https://example.com/git/foo"},
+func TestRedirector_ServeHTTP(t *testing.T) {
+ r := &redirector{"src.example.com/x", "https://example.com/git", "git"}
+
+ t.Run("GoVisit", func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, "https://src.example.com/foo?go-get=1", nil)
+ w := httptest.NewRecorder()
+ r.ServeHTTP(w, req)
+
+ resp := w.Result()
+ if http.StatusOK != resp.StatusCode {
+ t.Errorf("expected %d, got %d", http.StatusFound, resp.StatusCode)
+ }
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Error(err)
+ t.FailNow()
+ }
+ expected := `<!doctype html>
+<meta name="go-import" content="src.example.com/foo git https://example.com/git/src.example.com">
+<title>go-import-redirect</title>
+`
+ if string(body) != expected {
+ t.Errorf("expected\n---\n%s\n---\ngot\n---\n%s\n---", expected, string(body))
+ }
+ if hdr := resp.Header.Get("Location"); hdr != "" {
+ t.Error("expected empty Location header")
+ }
+ })
+
+ t.Run("UserVisit", func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, "https://src.example.com/foo", nil)
+ w := httptest.NewRecorder()
+ r.ServeHTTP(w, req)
+
+ resp := w.Result()
+ if http.StatusFound != resp.StatusCode {
+ t.Errorf("expected %d, got %d", http.StatusFound, resp.StatusCode)
+ }
+ if resp.ContentLength > 0 {
+ t.Error("expected empty body")
+ }
+ if hdr := resp.Header.Get("Location"); hdr != "https://pkg.go.dev/src.example.com/foo" {
+ t.Errorf("expected %q, got %q", "https://pkg.go.dev/src.example.com/foo", hdr)
+ }
+ })
+}
+
+func TestRedirector_getRepo(t *testing.T) {
+ r := &redirector{"src.example.com/x/", "https://example.com/git/", "git"}
+ for _, tc := range []struct{ pkg, expected string }{
+ {"src.example.com/x/foo", "https://example.com/git/foo"},
+ {"src.example.com/x/foo/bar", "https://example.com/git/foo"},
+ } {
+ if actual := r.getRepo(tc.pkg); actual != tc.expected {
+ t.Errorf("expected %q, got %q", tc.expected, actual)
+ }
}
- for _, c := range cs {
- actual := GetDest(c.srcPrefix, c.destPrefix, c.pkg)
- if actual != c.expected {
- t.Errorf("expected %s, got %s", c.expected, actual)
+
+ r = &redirector{"src.example.com/x", "https://example.com/git", "git"}
+ for _, tc := range []struct{ pkg, expected string }{
+ {"src.example.com/x/foo", "https://example.com/git/foo"},
+ {"src.example.com/x/foo/bar", "https://example.com/git/foo"},
+ } {
+ if actual := r.getRepo(tc.pkg); actual != tc.expected {
+ t.Errorf("expected %q, got %q", tc.expected, actual)
}
}
}