diff --git a/server/controller.go b/server/controller.go index b4b0b88..150989f 100644 --- a/server/controller.go +++ b/server/controller.go @@ -2,13 +2,18 @@ package main import ( "database/sql" + "encoding/base64" "fmt" + "log" "net/http" "os" "strconv" + "strings" "time" ) +var token string + type Handler struct { db *sql.DB ntfy *Ntfy @@ -27,6 +32,45 @@ func getStaticFile(relPath string, contentType string, w http.ResponseWriter) { w.Write(file) } +func SetupAuth() { + var username string + var password string + for _, entry := range os.Environ() { + split := strings.Split(entry, "=") + switch split[0] { + case "USERNAME": + username = split[1] + case "PASSWORD": + password = split[1] + } + + if username != "" && password != "" { + break + } + } + + if username == "" || password == "" { + log.Fatal("no authorization details") + } + + sb := new(strings.Builder) + encoder := base64.NewEncoder(base64.StdEncoding, sb) + encoder.Write(fmt.Appendf(nil, "%s:%s", username, password)) + encoder.Close() + + token = sb.String() +} + +func isAuthorized(r *http.Request) bool { + auth := r.Header.Get("Authorization") + + if auth == "" { + return false + } + + return auth == token +} + func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { fmt.Printf("%s - [%s] (%s) %s\n", time.Now().Format(time.RFC3339), r.RemoteAddr, r.Method, r.URL) @@ -50,6 +94,11 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { getStaticFile("../client/style.css", "text/css", w) case r.Method == "GET" && r.URL.Path == "/api/rsvps": + if !isAuthorized(r) { + w.WriteHeader(http.StatusUnauthorized) + return + } + rsvps, err := GetRsvps(h.db) if err != nil { w.WriteHeader(http.StatusInternalServerError) diff --git a/server/main.go b/server/main.go index 3c2be00..cea68ce 100644 --- a/server/main.go +++ b/server/main.go @@ -12,6 +12,7 @@ func main() { db := SetupDatabase() SetupRsvpsTable(db) + SetupAuth() hnd := &Handler{db, ntfy} http.ListenAndServe(":8000", hnd) }