changeset 24:0a0ea2b34787

Propagate contexts everywhere
author Lewin Bormann <lbo@spheniscida.de>
date Sat, 10 Dec 2016 10:23:30 +0100
parents ee729eaf611c
children 74f0036928e2
files handler_todo.go handlers.go http.go pull.go webhook.go
diffstat 5 files changed, 42 insertions(+), 22 deletions(-) [+]
line wrap: on
line diff
--- a/handler_todo.go	Sat Dec 10 10:03:15 2016 +0100
+++ b/handler_todo.go	Sat Dec 10 10:23:30 2016 +0100
@@ -1,8 +1,13 @@
 package main
 
-import "log"
+import (
+	"context"
+	"log"
 
-func todoButtonTest(msg message) (replyContent, error) {
+	_ "bitbucket.org/dermesser/goe_bot/sql"
+)
+
+func todoButtonTest(ctx context.Context, msg message) (replyContent, error) {
 	buttonRows := [][]inlineKeyboardButton{}
 
 	for _, b := range []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15"} {
@@ -12,7 +17,7 @@
 	return replyContent{text: "Please select", buttons: inlineKeyboardMarkup{buttonRows}}, nil
 }
 
-func todoDoneHandler(token string, cbq callbackQuery) (replyContent, error) {
+func todoDoneHandler(ctx context.Context, token string, cbq callbackQuery) (replyContent, error) {
 	log.Println("Received callback for", token)
 
 	reply := replyContent{text: "Verarbeitet!"}
--- a/handlers.go	Sat Dec 10 10:03:15 2016 +0100
+++ b/handlers.go	Sat Dec 10 10:23:30 2016 +0100
@@ -2,6 +2,7 @@
 
 import (
 	"bytes"
+	"context"
 	"errors"
 	"fmt"
 	"log"
@@ -22,12 +23,12 @@
 )
 
 type handler struct {
-	fn   (func(message) (replyContent, error))
+	fn   (func(context.Context, message) (replyContent, error))
 	desc string
 }
 
 // A handler taking the callback data. `data` only contains the token, not the callback type.
-type callbackHandler (func(data string, cbq callbackQuery) (replyContent, error))
+type callbackHandler (func(ctx context.Context, data string, cbq callbackQuery) (replyContent, error))
 
 var (
 	handlers = map[string]handler{
@@ -50,19 +51,19 @@
 	buttons inlineKeyboardMarkup
 }
 
-func echoHandler(msg message) (replyContent, error) {
+func echoHandler(ctx context.Context, msg message) (replyContent, error) {
 	return replyContent{text: msg.Text}, nil
 }
 
-func missingHandler(msg message) (replyContent, error) {
+func missingHandler(ctx context.Context, msg message) (replyContent, error) {
 	return replyContent{text: "_tut mir leid, das kann ich noch nicht :(_"}, nil
 }
 
-func unknownCommandHandler(msg message) (replyContent, error) {
+func unknownCommandHandler(ctx context.Context, msg message) (replyContent, error) {
 	return replyContent{text: "_tut mir leid, das verstehe ich nicht (*/help* für Hilfe)_"}, nil
 }
 
-func badFormatHandler(msg message) (replyContent, error) {
+func badFormatHandler(ctx context.Context, msg message) (replyContent, error) {
 	return replyContent{text: "_bitte benutze einen Befehl, der mit '/' beginnt!_"}, nil
 }
 
@@ -76,13 +77,13 @@
 	return replyContent{text: b.String()}, nil
 }
 
-func statusHandler(message) (replyContent, error) {
+func statusHandler(ctx context.Context, msg message) (replyContent, error) {
 	return replyContent{text: srvStatus.String()}, nil
 }
 
 // Dispatch an incoming update to the right handler. dispatch() only selects between messages and callbacks; dispatchMessage
 // and dispatchCallback select the actual handlers.
-func dispatch(upd update) (sendMessage, error) {
+func dispatch(ctx context.Context, upd update) (sendMessage, error) {
 	var rp replyContent
 	var err error
 	var chatID uint64
@@ -90,11 +91,11 @@
 	if upd.Message.Message_ID > 0 {
 		srvStatus.commands++
 		chatID = upd.Message.Chat.ID
-		rp, err = dispatchMessage(upd.Message)
+		rp, err = dispatchMessage(ctx, upd.Message)
 	} else if upd.Callback_Query.ID != "" {
 		srvStatus.callbacks++
 		chatID = upd.Callback_Query.Message.Chat.ID
-		rp, err = dispatchCallback(upd.Callback_Query)
+		rp, err = dispatchCallback(ctx, upd.Callback_Query)
 	} else {
 		reply := sendMessage{
 			Chat_ID:    upd.Message.Chat.ID,
@@ -125,11 +126,11 @@
 }
 
 // Dispatches an incoming callbackQuery. The data field has the format callback_type:token.
-func dispatchCallback(cbq callbackQuery) (replyContent, error) {
+func dispatchCallback(ctx context.Context, cbq callbackQuery) (replyContent, error) {
 	parts := strings.Split(cbq.Data, ":") // callback:token
 
 	if handler, ok := callbackHandlers[parts[0]]; ok {
-		rp, err := handler(parts[1], cbq)
+		rp, err := handler(ctx, parts[1], cbq)
 
 		return rp, err
 	} else {
@@ -138,7 +139,7 @@
 	}
 }
 
-func dispatchMessage(msg message) (replyContent, error) {
+func dispatchMessage(ctx context.Context, msg message) (replyContent, error) {
 	var rp replyContent
 	var err error
 
@@ -156,14 +157,14 @@
 		if cmd == helpCmd { // special case
 			rp, err = helpHandler(msg)
 		} else if h, ok := handlers[cmd]; ok {
-			rp, err = h.fn(msg)
+			rp, err = h.fn(ctx, msg)
 		} else {
 			log.Println("Unknown command", cmd)
-			rp, err = unknownCommandHandler(msg)
+			rp, err = unknownCommandHandler(ctx, msg)
 		}
 	} else {
 		log.Println("Bad format:", msg.Text)
-		rp, err = badFormatHandler(msg)
+		rp, err = badFormatHandler(ctx, msg)
 	}
 
 	return rp, err
--- a/http.go	Sat Dec 10 10:03:15 2016 +0100
+++ b/http.go	Sat Dec 10 10:23:30 2016 +0100
@@ -1,6 +1,7 @@
 package main
 
 import (
+	"context"
 	"fmt"
 	"io/ioutil"
 	"log"
@@ -42,6 +43,9 @@
 
 // Handles simple text posted to a debug handler (/debug).
 func debugHandler(rp http.ResponseWriter, rq *http.Request) {
+	ctx, cancel := context.WithTimeout(rq.Context(), time.Duration(*flagDeadline)*time.Second)
+	defer cancel()
+
 	body, err := ioutil.ReadAll(rq.Body)
 
 	if err != nil {
@@ -56,7 +60,7 @@
 
 	srvStatus.commands++
 
-	reply, err := dispatch(update{Message: msg, Update_ID: 12345})
+	reply, err := dispatch(ctx, update{Message: msg, Update_ID: 12345})
 
 	if err != nil {
 		srvStatus.errors++
--- a/pull.go	Sat Dec 10 10:03:15 2016 +0100
+++ b/pull.go	Sat Dec 10 10:23:30 2016 +0100
@@ -2,15 +2,20 @@
 
 import (
 	"bytes"
+	"context"
 	"encoding/json"
 	"io/ioutil"
 	"log"
 	"net/http"
+	"time"
 )
 
 // After the webhook is called, it should process the request, respond with http.StatusOK (if appropriate),
 // and then call the sendMessage method.
 func pullUpdates(off uint64) (uint64, error) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*flagDeadline)*time.Second)
+	defer cancel()
+
 	gU := getUpdates{Offset: off, Timeout: uint64(*flagDeadline)}
 	body, err := json.Marshal(gU)
 
@@ -45,7 +50,7 @@
 	}
 
 	for u := range upds.Result {
-		sM, err := dispatch(upds.Result[u])
+		sM, err := dispatch(ctx, upds.Result[u])
 
 		if err != nil {
 			log.Println(err)
--- a/webhook.go	Sat Dec 10 10:03:15 2016 +0100
+++ b/webhook.go	Sat Dec 10 10:23:30 2016 +0100
@@ -2,12 +2,14 @@
 
 import (
 	"bytes"
+	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"io/ioutil"
 	"log"
 	"net/http"
+	"time"
 )
 
 // Register webhook
@@ -88,6 +90,9 @@
 // Webhook handler (push method; /hook)
 // Receives `update`s as JSON
 func updatesHandler(rp http.ResponseWriter, rq *http.Request) {
+	ctx, cancel := context.WithTimeout(rq.Context(), time.Duration(*flagDeadline)*time.Second)
+	defer cancel()
+
 	body, err := ioutil.ReadAll(rq.Body)
 
 	if err != nil {
@@ -112,7 +117,7 @@
 		return
 	}
 
-	responseMsg, err := dispatch(upd)
+	responseMsg, err := dispatch(ctx, upd)
 
 	if err != nil {
 		rp.WriteHeader(http.StatusInternalServerError)