small pixel drawing of a pufferfish cascade

internal/agent/http.go

package agent

import (
	"encoding/json"
	"fmt"
	"io"
	"net"
	"net/http"
	"strings"
	"time"
	"unicode"

	"git.j3s.sh/cascade/internal/lib"
)

// MethodNotAllowedError is returned by a handler when the request's HTTP
// method is not accepted by the endpoint. Method is the rejected method.
// Allow presumably carries the methods the endpoint does accept (e.g. for an
// Allow response header) — NOTE(review): it is not used by Error.
type MethodNotAllowedError struct {
	Method string
	Allow  []string
}

// Error reports which method was rejected.
func (e MethodNotAllowedError) Error() string {
	// All operands are strings, so Sprint inserts no extra spaces.
	return fmt.Sprint("method ", e.Method, " not allowed")
}

// BadRequestError is returned by a handler when the request's parameters or
// payload are not valid. Reason describes what was wrong.
type BadRequestError struct {
	Reason string
}

// Error prefixes Reason with "Bad request: ".
// NOTE(review): Go convention prefers lowercase error strings; the text is
// kept as-is because callers may surface it verbatim.
func (e BadRequestError) Error() string {
	return fmt.Sprint("Bad request: ", e.Reason)
}

// NotFoundError is returned by a handler when the requested resource does not
// exist. Reason is used verbatim as the error message.
type NotFoundError struct {
	Reason string
}

// Error returns Reason unchanged.
func (nf NotFoundError) Error() string { return nf.Reason }

// CodeWithPayloadError lets a handler respond with a non-200 status code while
// supplying its own payload instead of a plain-text error body. Reason is the
// payload/message, StatusCode the HTTP status, and ContentType presumably the
// Content-Type of the payload (NOTE(review): not consumed in this file).
type CodeWithPayloadError struct {
	Reason      string
	StatusCode  int
	ContentType string
}

// Error returns Reason unchanged; StatusCode and ContentType are not included.
func (cp CodeWithPayloadError) Error() string {
	msg := cp.Reason
	return msg
}

// ForbiddenError is returned by a handler when access to the requested
// resource is not permitted. It carries no data; its message is fixed.
type ForbiddenError struct{}

// Error always reports "Access is restricted".
func (ForbiddenError) Error() string {
	const msg = "Access is restricted"
	return msg
}

// HTTPHandlers provides an HTTP API for an agent.
// It holds a pointer to the agent (not a copy) because the handlers call
// some of its functions and access some of its data.
type HTTPHandlers struct {
	// h is not referenced by the code in this file.
	// NOTE(review): presumably meant to hold the result of handler(); confirm.
	h     http.Handler
	// agent is the owning agent; the handlers use, e.g., its logger.
	agent *Agent
}

// handler builds the agent's HTTP handler. It registers the index page at "/"
// and the API endpoints on a ServeMux. Each endpoint is wrapped so that its
// method, path label and latency are logged after it runs. The mux is then
// wrapped so that requests whose URL path contains non-printable characters
// are rejected before routing.
// In agent code we only ever call this once.
func (s *HTTPHandlers) handler() http.Handler {
	mux := http.NewServeMux()

	// handleFuncMetrics registers handler under pattern, wrapped so that a
	// log line with the request's method, path label and latency is emitted
	// after the handler returns.
	handleFuncMetrics := func(pattern string, handler http.HandlerFunc) {
		// Build a log label from the pattern: drop the leading slash and turn
		// the remaining slashes into underscores. A trailing slash becomes a
		// trailing underscore, which keeps /v1/query distinct from the
		// subtree pattern /v1/query/ (used for /v1/query/<query id>).
		pathLabel := strings.ReplaceAll(strings.TrimPrefix(pattern, "/"), "/", "_")

		mux.HandleFunc(pattern, func(resp http.ResponseWriter, req *http.Request) {
			start := time.Now()
			handler(resp, req)

			// NOTE(review): this fires for every request; Warn looks too loud
			// for routine metrics — consider a lower level.
			s.agent.logger.Warn("request metrics", "method", req.Method, "path", pathLabel, "latency", time.Since(start))
		})
	}

	// The index is registered directly, without the metrics wrapper.
	mux.HandleFunc("/", s.Index)

	endpoints := map[string]func(resp http.ResponseWriter, req *http.Request){
		"/v1/agent/members": s.agentMembers,
	}
	for pattern, fn := range endpoints {
		handleFuncMetrics(pattern, fn)
	}

	// Ban URLs with non-printable characters before they reach the mux.
	return printablePathCheckHandler(mux, nil)
}

// HandlerInput configures printablePathCheckHandler.
type HandlerInput struct {
	// ErrStatus is the HTTP status written when a request's URL path contains
	// a non-printable character. Zero means http.StatusBadRequest.
	ErrStatus int
}

// printablePathCheckHandler wraps next so that any request whose URL path
// contains a rune for which unicode.IsPrint reports false is answered with
// only a status code and never reaches next. The status is input.ErrStatus,
// or http.StatusBadRequest when input is nil or its ErrStatus is zero.
//
// input is read once, when the handler is built, and is never modified, so
// callers may reuse or later change their HandlerInput safely.
// A nil request is ignored. A nil next makes accepted requests a no-op.
func printablePathCheckHandler(next http.Handler, input *HandlerInput) http.Handler {
	// Snapshot the status instead of writing defaults back into the caller's
	// struct and re-reading the pointer on every request.
	errStatus := http.StatusBadRequest
	if input != nil && input.ErrStatus != 0 {
		errStatus = input.ErrStatus
	}

	isNonPrintable := func(c rune) bool { return !unicode.IsPrint(c) }

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r == nil {
			return
		}

		if strings.IndexFunc(r.URL.Path, isNonPrintable) != -1 {
			w.WriteHeader(errStatus)
			return
		}

		if next != nil {
			next.ServeHTTP(w, r)
		}
	})
}

// wrap is used to wrap functions to make them more convenient
// func (s *HTTPHandlers) wrap(handler http.HandlerFunc) http.HandlerFunc {
// 	httpLogger := s.agent.logger
// 	return func(resp http.ResponseWriter, req *http.Request) {
// 		setHeaders(resp, s.agent.Config.HTTPResponseHeaders)
//
// 		addAllowHeader := func(methods []string) {
// 			resp.Header().Add("Allow", strings.Join(methods, ","))
// 		}
//
// 		logURL := req.URL.String()
//
// 		handleErr := func(err error) {
// 			httpLogger.Error("Request error", err,
// 				"method", req.Method,
// 				"url", logURL,
// 				"from", req.RemoteAddr,
// 				"error", err,
// 			)
// 			switch {
// 			default:
// 				resp.WriteHeader(http.StatusInternalServerError)
// 				fmt.Fprint(resp, err.Error())
// 			}
// 		}
//
// 		start := time.Now()
// 		defer func() {
// 			httpLogger.Debug("Request finished",
// 				"method", req.Method,
// 				"url", logURL,
// 				"from", req.RemoteAddr,
// 				"latency", time.Since(start).String(),
// 			)
// 		}()
//
// 		contentType := "application/json"
// 		httpCode := http.StatusOK
// 		var buf []byte
// 		if contentType == "application/json" {
// 			buf, err = s.marshalJSON(req, handler(resp, req))
// 			if err != nil {
// 				handleErr(err)
// 				return
// 			}
// 		} else {
// 			if strings.HasPrefix(contentType, "text/") {
// 				if val, ok := obj.(string); ok {
// 					buf = []byte(val)
// 				}
// 			}
// 		}
// 		resp.Header().Set("Content-Type", contentType)
// 		resp.WriteHeader(httpCode)
// 		resp.Write(buf)
// 	}
// }

// marshalJSON encodes obj with encoding/json's Marshal. req is accepted but
// not currently used. On failure it returns a nil slice and the encoder's
// error.
func (s *HTTPHandlers) marshalJSON(req *http.Request, obj interface{}) ([]byte, error) {
	encoded, err := json.Marshal(obj)
	if err == nil {
		return encoded, nil
	}
	return nil, err
}

// Index serves the root path "/" with a short plain-text banner that
// identifies the server. Because it is registered on "/", it also receives
// every path no other pattern matched; those get a bare 404.
func (s *HTTPHandlers) Index(resp http.ResponseWriter, req *http.Request) {
	// This endpoint is not wrapped by anything that applies the configured
	// response headers; applying them here is currently disabled:
	// setHeaders(resp, s.agent.Config.HTTPResponseHeaders)

	switch req.URL.Path {
	case "/":
		// There is no UI, so at least tell the caller what is listening.
		fmt.Fprint(resp, "cascade agent\n")
	default:
		resp.WriteHeader(http.StatusNotFound)
	}
}

// decodeBody decodes the JSON read from body into out by delegating to
// lib.DecodeJSON, and returns whatever error that helper reports.
func decodeBody(body io.Reader, out interface{}) error {
	err := lib.DecodeJSON(body, out)
	return err
}

// setHeaders writes every entry of headers onto resp, replacing any values
// already present for the same field. Field names are canonicalized first
// (e.g. "x-foo" becomes "X-Foo"). A nil or empty map is a no-op.
func setHeaders(resp http.ResponseWriter, headers map[string]string) {
	for name := range headers {
		canonical := http.CanonicalHeaderKey(name)
		resp.Header().Set(canonical, headers[name])
	}
}

// serveHandlerWithHeaders returns a handler that first applies headers to the
// response via setHeaders and then delegates to h. Because the headers are set
// before h runs, h can still override any of them.
func serveHandlerWithHeaders(h http.Handler, headers map[string]string) http.HandlerFunc {
	wrapped := func(resp http.ResponseWriter, req *http.Request) {
		setHeaders(resp, headers)
		h.ServeHTTP(resp, req)
	}
	return wrapped
}

// sourceAddrFromRequest returns the client IP address for req in canonical
// string form, or "" if none can be determined.
//
// If the first (left-most) entry of the X-Forwarded-For header parses as an
// IP, that address is returned. Otherwise the host part of req.RemoteAddr is
// used when it is a literal IP.
//
// NOTE(review): X-Forwarded-For is client-controlled unless a trusted proxy
// overwrites it, so the result can be spoofed. Do not rely on it for access
// control without confirming the deployment sits behind such a proxy.
func sourceAddrFromRequest(req *http.Request) string {
	// Only the first entry is consulted, so there is no need to split the
	// whole header. An absent header yields "", which fails to parse below.
	first := req.Header.Get("X-Forwarded-For")
	if i := strings.IndexByte(first, ','); i >= 0 {
		first = first[:i]
	}
	if ip := net.ParseIP(strings.TrimSpace(first)); ip != nil {
		return ip.String()
	}

	host, _, err := net.SplitHostPort(req.RemoteAddr)
	if err != nil {
		return ""
	}
	if ip := net.ParseIP(host); ip != nil {
		return ip.String()
	}
	return ""
}

// parseFilter overwrites *filter with the request's "filter" query parameter
// when that parameter is present and non-empty. Otherwise *filter is left
// unchanged.
func (s *HTTPHandlers) parseFilter(req *http.Request, filter *string) {
	other := req.URL.Query().Get("filter")
	if other == "" {
		return
	}
	*filter = other
}