Adding auth function to controller
This commit is contained in:
parent
b009a89de9
commit
7b70cb5f2c
2 changed files with 50 additions and 0 deletions
|
|
@ -2,13 +2,18 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var token string
|
||||||
|
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
ntfy *Ntfy
|
ntfy *Ntfy
|
||||||
|
|
@ -27,6 +32,45 @@ func getStaticFile(relPath string, contentType string, w http.ResponseWriter) {
|
||||||
w.Write(file)
|
w.Write(file)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SetupAuth() {
|
||||||
|
var username string
|
||||||
|
var password string
|
||||||
|
for _, entry := range os.Environ() {
|
||||||
|
split := strings.Split(entry, "=")
|
||||||
|
switch split[0] {
|
||||||
|
case "USERNAME":
|
||||||
|
username = split[1]
|
||||||
|
case "PASSWORD":
|
||||||
|
password = split[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
if username != "" && password != "" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if username == "" || password == "" {
|
||||||
|
log.Fatal("no authorization details")
|
||||||
|
}
|
||||||
|
|
||||||
|
sb := new(strings.Builder)
|
||||||
|
encoder := base64.NewEncoder(base64.StdEncoding, sb)
|
||||||
|
encoder.Write(fmt.Appendf(nil, "%s:%s", username, password))
|
||||||
|
encoder.Close()
|
||||||
|
|
||||||
|
token = sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func isAuthorized(r *http.Request) bool {
|
||||||
|
auth := r.Header.Get("Authorization")
|
||||||
|
|
||||||
|
if auth == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return auth == token
|
||||||
|
}
|
||||||
|
|
||||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
fmt.Printf("%s - [%s] (%s) %s\n", time.Now().Format(time.RFC3339), r.RemoteAddr, r.Method, r.URL)
|
fmt.Printf("%s - [%s] (%s) %s\n", time.Now().Format(time.RFC3339), r.RemoteAddr, r.Method, r.URL)
|
||||||
|
|
||||||
|
|
@ -50,6 +94,11 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
getStaticFile("../client/style.css", "text/css", w)
|
getStaticFile("../client/style.css", "text/css", w)
|
||||||
|
|
||||||
case r.Method == "GET" && r.URL.Path == "/api/rsvps":
|
case r.Method == "GET" && r.URL.Path == "/api/rsvps":
|
||||||
|
if !isAuthorized(r) {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
rsvps, err := GetRsvps(h.db)
|
rsvps, err := GetRsvps(h.db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ func main() {
|
||||||
db := SetupDatabase()
|
db := SetupDatabase()
|
||||||
SetupRsvpsTable(db)
|
SetupRsvpsTable(db)
|
||||||
|
|
||||||
|
SetupAuth()
|
||||||
hnd := &Handler{db, ntfy}
|
hnd := &Handler{db, ntfy}
|
||||||
http.ListenAndServe(":8000", hnd)
|
http.ListenAndServe(":8000", hnd)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue