diff --git a/controller.go b/controller.go index 3a64d02..da17bf1 100644 --- a/controller.go +++ b/controller.go @@ -16,8 +16,9 @@ import ( var token string type Handler struct { - db *sql.DB - ntfy *Ntfy + db *sql.DB + ntfy *Ntfy + visits *VisitHandler } func getStaticFile(relPath string, contentType string, w http.ResponseWriter) { @@ -153,6 +154,7 @@ func SetupAuth() { func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { fmt.Printf("%s - [%s] (%s) %s\n", time.Now().Format(time.RFC3339), r.RemoteAddr, r.Method, r.URL) + h.visits.HandleVisit(r.RemoteAddr) switch true { case r.Method == "GET" && r.URL.Path == "/favicon.ico": diff --git a/main.go b/main.go index 2b7953b..eaf6be3 100644 --- a/main.go +++ b/main.go @@ -16,8 +16,10 @@ func main() { db := SetupDatabase() SetupRsvpsTable(db) + visits := SetupVisitHandler(&VisitHandlerArgs{baseUrl: "http://ip-api.com", rateLimitAmount: 45, rateLimitDurationSeconds: 60}) + SetupAuth() - hnd := &Handler{db, ntfy} + hnd := &Handler{db, ntfy, visits} fmt.Println("handler ready") http.ListenAndServe(":8000", hnd) diff --git a/ntfy.go b/ntfy.go index 9511b5b..ccb10b5 100644 --- a/ntfy.go +++ b/ntfy.go @@ -71,46 +71,6 @@ func SetupNtfyClient() *Ntfy { return &Ntfy{client: tp} } -func buildTitle(rsvp *Rsvp) string { - builder := new(strings.Builder) - if rsvp.Attending { - peoplePerson := "people" - if rsvp.PartySize == 1 { - peoplePerson = "person" - } - fmt.Fprintf(builder, "%d %s confirmed!", rsvp.PartySize, peoplePerson) - } else { - fmt.Fprintf(builder, "Someone can't make it") - } - return builder.String() -} - -func buildMessage(rsvp *Rsvp) string { - builder := new(strings.Builder) - if rsvp.Attending { - builder.WriteString("Here's who is coming 👇") - for _, mem := range rsvp.PartyMembers { - age := "adult" - if mem.Child { - age = "👶" - } else { - age = "🧓" - } - dp := "n/a" - if len(mem.DietaryPreferences) > 0 { - dp = mem.DietaryPreferences - } - fmt.Fprintf(builder, "\n- %s %s, %s", age, mem.Name, dp) - } - } else { - for _, mem := range rsvp.PartyMembers { - fmt.Fprintf(builder, "%s\n", mem.Name) - } - builder.WriteString("can't make it") - } - return builder.String() -} - func numberToEmoji(num int64) string { switch num { case 0: @@ -138,7 +98,47 @@ func numberToEmoji(num int64) string { } } -func buildTags(rsvp *Rsvp) []string { +func buildRsvpTitle(rsvp *Rsvp) string { + builder := new(strings.Builder) + if rsvp.Attending { + peoplePerson := "people" + if rsvp.PartySize == 1 { + peoplePerson = "person" + } + fmt.Fprintf(builder, "%d %s confirmed!", rsvp.PartySize, peoplePerson) + } else { + fmt.Fprintf(builder, "Someone can't make it") + } + return builder.String() +} + +func buildRsvpMessage(rsvp *Rsvp) string { + builder := new(strings.Builder) + if rsvp.Attending { + builder.WriteString("Here's who is coming 👇") + for _, mem := range rsvp.PartyMembers { + age := "adult" + if mem.Child { + age = "👶" + } else { + age = "🧓" + } + dp := "n/a" + if len(mem.DietaryPreferences) > 0 { + dp = mem.DietaryPreferences + } + fmt.Fprintf(builder, "\n- %s %s, %s", age, mem.Name, dp) + } + } else { + for _, mem := range rsvp.PartyMembers { + fmt.Fprintf(builder, "%s\n", mem.Name) + } + builder.WriteString("can't make it") + } + return builder.String() +} + +func buildRsvpTags(rsvp *Rsvp) []string { if rsvp.Attending { return []string{"white_check_mark", numberToEmoji(rsvp.PartySize)} } else { @@ -146,18 +146,46 @@ func buildTags(rsvp *Rsvp) []string { } } -func BuildNtfyMessage(topic string, rsvp *Rsvp) *gotfy.Message { +func BuildRsvpMessage(topic string, rsvp *Rsvp) *gotfy.Message { return &gotfy.Message{ Topic: topic, - Message: buildMessage(rsvp), - Title: buildTitle(rsvp), - Tags: buildTags(rsvp), + Message: buildRsvpMessage(rsvp), + Title: buildRsvpTitle(rsvp), + Tags: buildRsvpTags(rsvp), Priority: gotfy.Default, } } func (n *Ntfy) PublishNewRsvpNotification(rsvp *Rsvp) (string, error) { - resp, err := n.client.SendMessage(context.Background(), BuildNtfyMessage("collinenlucy_nl", rsvp)) + resp, err := n.client.SendMessage(context.Background(), BuildRsvpMessage("collinenlucy_nl", rsvp)) + + if err != nil { + return "", err + } + + return resp.ID, nil +} + +func buildVisitTitle(visit *Visit) string { + return "Someone visited!" +} + +func buildVisitMessage(visit *Visit) string { + return fmt.Sprintf("Got a visit from %s (%s)", visit.Location, visit.RemoteAddr) +} + +func BuildVisitMessage(topic string, visit *Visit) *gotfy.Message { + return &gotfy.Message{ + Topic: topic, + Message: buildVisitMessage(visit), + Title: buildVisitTitle(visit), + Tags: []string{"eyes"}, + Priority: gotfy.Default, + } +} + +func (n *Ntfy) PublishNewVisitNotification(visit *Visit) (string, error) { + resp, err := n.client.SendMessage(context.Background(), BuildVisitMessage("collinenlucy_nl", visit)) if err != nil { return "", err diff --git a/ntfy_test.go b/ntfy_test.go index e4a7798..a25e7e2 100644 --- a/ntfy_test.go +++ b/ntfy_test.go @@ -11,7 +11,7 @@ func TestBuildNtfyMessage(t *testing.T) { {Name: "test", Child: false, DietaryPreferences: ""}, }, } - msg := BuildNtfyMessage("test", rsvp) + msg := BuildRsvpMessage("test", rsvp) if msg.Topic != "test" { t.Fatal("message topic is incorrect") @@ -34,7 +34,7 @@ func TestBuildNtfyMessage(t *testing.T) { {Name: "test", Child: false, DietaryPreferences: ""}, }, } - msg = BuildNtfyMessage("test", rsvp) + msg = BuildRsvpMessage("test", rsvp) if msg.Topic != "test" { t.Fatal("message topic is incorrect") @@ -58,7 +58,7 @@ func TestBuildNtfyMessage(t *testing.T) { {Name: "test2", Child: true, DietaryPreferences: "no tobacco"}, }, } - msg = BuildNtfyMessage("test", rsvp) + msg = BuildRsvpMessage("test", rsvp) if msg.Topic != "test" { t.Fatal("message topic is incorrect") diff --git a/visits.go b/visits.go new file mode 100644 index 0000000..ca29734 --- /dev/null +++ b/visits.go @@ -0,0 +1,139 @@ +package main + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +type Visit struct { + RemoteAddr string + Location string +} + +type GeolocationResponse struct { + Country string `json:"country"` + City string `json:"city"` + RegionName string `json:"regionName"` +} + +type VisitHandler struct { + geolocationApiBaseUrl string + visitsCache map[string]int64 + pastRequests []int64 + rateLimitAmount int + rateLimitDurationSeconds int +} + +func (v *VisitHandler) buildUrl(ipAddress string) string { + return fmt.Sprintf("%s/json/%s?fields=country,regionName,city", v.geolocationApiBaseUrl, ipAddress) +} + +func (v *VisitHandler) limitRate() bool { + now := time.Now() + + idxToRemove := make([]int, 0) + for idx, ts := range v.pastRequests { + if now.Sub(time.Unix(ts, 0)).Seconds() >= float64(v.rateLimitDurationSeconds) { + idxToRemove = append(idxToRemove, idx) + } + } + + for i := len(idxToRemove) - 1; i >= 0; i-- { + idx := idxToRemove[i] + v.pastRequests = append(v.pastRequests[:idx], v.pastRequests[idx+1:]...) + } + + if len(v.pastRequests) == 0 { + v.pastRequests = append(v.pastRequests, now.Unix()) + return true + } + + retVal := false + + if len(v.pastRequests) < v.rateLimitAmount { + v.pastRequests = append(v.pastRequests, now.Unix()) + retVal = true + } + + return retVal +} + +func (v *VisitHandler) getLocation(ipAddress string) *GeolocationResponse { + if !v.limitRate() { + return nil + } + + resp, err := http.Get(v.buildUrl(ipAddress)) + if err != nil { + return nil + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil + } + + var geoResponse GeolocationResponse + err = json.Unmarshal(body, &geoResponse) + return &geoResponse +} + +func (v *VisitHandler) createVisit(ipAddress string) *Visit { + loc := v.getLocation(ipAddress) + if loc == nil { + return nil + } + + locationParts := []string{} + if loc.City != "" { + locationParts = append(locationParts, loc.City) + } + if loc.RegionName != "" { + locationParts = append(locationParts, loc.RegionName) + } + if loc.Country != "" { + locationParts = append(locationParts, loc.Country) + } + + return &Visit{ + RemoteAddr: ipAddress, + Location: strings.Join(locationParts, ", "), + } +} + +func (v *VisitHandler) HandleVisit(remoteAddr string) *Visit { + if v.visitsCache == nil { + v.visitsCache = make(map[string]int64) + } + + if v.visitsCache[remoteAddr] == 0 || time.Since(time.Unix(v.visitsCache[remoteAddr], 0)) > time.Duration(v.rateLimitDurationSeconds) { + visit := v.createVisit(remoteAddr) + if visit == nil { + return nil + } + v.visitsCache[visit.RemoteAddr] = time.Now().Unix() + return visit + } + + return nil +} + +type VisitHandlerArgs struct { + baseUrl string + rateLimitAmount int + rateLimitDurationSeconds int +} + +func SetupVisitHandler(args *VisitHandlerArgs) *VisitHandler { + return &VisitHandler{ + geolocationApiBaseUrl: args.baseUrl, + visitsCache: make(map[string]int64), + pastRequests: make([]int64, 0), + rateLimitAmount: args.rateLimitAmount, + rateLimitDurationSeconds: args.rateLimitDurationSeconds, + } +} diff --git a/visits_test.go b/visits_test.go new file mode 100644 index 0000000..0d7fbbb --- /dev/null +++ b/visits_test.go @@ -0,0 +1,177 @@ +package main + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestHandleVisit_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"country":"Netherlands","regionName":"Noord-Holland","city":"Amsterdam"}`)) + })) + defer server.Close() + + v := SetupVisitHandler(&VisitHandlerArgs{baseUrl: server.URL}) + + visit := v.HandleVisit("home.collinduncan.com") + if visit == nil { + t.FailNow() + } + + if visit.Location != "Amsterdam, Noord-Holland, Netherlands" { + t.FailNow() + } +} + +func TestHandleVisit_NoRegion(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"country":"Netherlands","city":"Amsterdam"}`)) + })) + defer server.Close() + + v := SetupVisitHandler(&VisitHandlerArgs{baseUrl: server.URL}) + + visit := v.HandleVisit("home.collinduncan.com") + if visit == nil { + t.Fatalf("Location is nil") + } + + if visit.Location != "Amsterdam, Netherlands" { + t.Fatalf("Location is %s", visit.Location) + } +} + +func TestHandleVisit_NoCountry(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"regionName":"Noord-Holland","city":"Amsterdam"}`)) + })) + defer server.Close() + + v := SetupVisitHandler(&VisitHandlerArgs{baseUrl: server.URL}) + + visit := v.HandleVisit("home.collinduncan.com") + if visit == nil { + t.FailNow() + } + + if visit.Location != "Amsterdam, Noord-Holland" { + t.FailNow() + } +} + +func TestHandleVisit_NoCity(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"country":"Netherlands","regionName":"Noord-Holland"}`)) + })) + defer server.Close() + + v := SetupVisitHandler(&VisitHandlerArgs{baseUrl: server.URL}) + + visit := v.HandleVisit("home.collinduncan.com") + if visit == nil { + t.FailNow() + } + + if visit.Location != "Noord-Holland, Netherlands" { + t.FailNow() + } +} + +func TestHandleVisit_Nothing(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer server.Close() + + v := SetupVisitHandler(&VisitHandlerArgs{baseUrl: server.URL}) + + visit := v.HandleVisit("home.collinduncan.com") + if visit == nil { + t.FailNow() + } + + if visit.Location != "" { + t.FailNow() + } +} + +func TestHandleVisit_RateLimit(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"country":"Netherlands","regionName":"Noord-Holland","city":"Amsterdam"}`)) + })) + defer server.Close() + + v := SetupVisitHandler(&VisitHandlerArgs{ + baseUrl: server.URL, + rateLimitAmount: 30, + rateLimitDurationSeconds: 1, + }) + + // pre-limit + for i := 0; i < 30; i++ { + visit := v.HandleVisit(fmt.Sprintf("home.collinduncan%d.com", i)) + if visit == nil { + t.Fatalf("Failed before rate-limiter %d", i) + } + + if visit.Location != "Amsterdam, Noord-Holland, Netherlands" { + t.Fatal("Failed before rate-limiter, location") + } + } + + visit := v.HandleVisit("home.collinduncan90.com") + if visit != nil { + t.Fatal("Failed to hit rate limiter") + } + + time.Sleep(time.Second * 2) + + visit = v.HandleVisit("home.collinduncan90.com") + if visit == nil { + t.Fatal("Failed after rate-limiter") + } + + if visit.Location != "Amsterdam, Noord-Holland, Netherlands" { + t.Fatal("Failed after rate-limiter, location") + } +} + +func TestHandleVisit_Cache(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"country":"Netherlands","regionName":"Noord-Holland","city":"Amsterdam"}`)) + })) + defer server.Close() + + v := SetupVisitHandler(&VisitHandlerArgs{baseUrl: server.URL, rateLimitDurationSeconds: 1}) + + visit := v.HandleVisit("home.collinduncan.com") + if visit == nil { + t.FailNow() + } + + if visit.Location != "Amsterdam, Noord-Holland, Netherlands" { + t.FailNow() + } + + visit = v.HandleVisit("home.collinduncan.com") + if visit != nil { + t.FailNow() + } + + time.Sleep(time.Second * 2) + + visit = v.HandleVisit("home.collinduncan.com") + if visit == nil { + t.FailNow() + } +}