small pixel drawing of a pufferfish cascade

command/agent/agent.go

package agent

import (
	"flag"
	"fmt"
	"log"
	"net"
	"os"
	"os/signal"
	"strconv"
	"strings"
	"syscall"
	"time"

	"git.j3s.sh/cascade/agent"
)

// gracefulTimeout is the upper bound handleSignals waits for a graceful
// cluster leave to finish before giving up and returning an error.
//
// NOTE: this value interacts with serf's LeavePropagateDelay setting; keep
// the two in mind together when tuning either one.
const gracefulTimeout time.Duration = 10 * time.Second

// usage is the help text for the `cascade agent` subcommand. Run installs it
// as the flag set's Usage func, so the flag package prints it for -h/-help
// and whenever flag parsing fails.
const usage = `cascade agent [options]

  this command starts the cascade agent, which is responsible
  for basically everything, including service registration,
  health checking, cluster membership, and hosting the API.

options:
  -bind-dns=<addr>
    address the DNS server binds to (default = 127.0.0.1:8600)

  -bind-http=<addr>
    address the http api/interface binds to (default = 127.0.0.1:8500)

  -bind-serf=<addr>
    address the serf agent binds to (default = 0.0.0.0:8301)

  -join=<addrs>
    comma-separated addresses of agents to join at start time (default = nil)

  -node=<name>
    name of this node, must be globally unique (default = hostname)
`

// Flags holds the raw string values of the agent subcommand's flags as
// registered in Run. An empty string means the flag was not supplied, in
// which case getAgentConfig keeps the value from agent.DefaultConfig.
type Flags struct {
	// bindDNS, bindHTTP, bindSerf: -bind-dns, -bind-http, -bind-serf
	// join: -join (comma-separated list)
	// node: -node
	bindDNS, bindHTTP, bindSerf, join, node string
}

// agentFlags is filled in by Run's flag parsing and read by getAgentConfig.
var agentFlags Flags

// Run is the entry point for `cascade agent`. It parses args into
// agentFlags, builds the agent configuration, prints a startup summary,
// starts the agent, performs any startup joins, and then blocks in
// handleSignals until a terminating signal arrives or the agent shuts down.
//
// It exits the process with status 0 when -h/-help is requested and with
// status 1 on any other error. Once the agent has started, every exit path
// shuts it down first.
func Run(args []string) {
	flags := flag.NewFlagSet("agent", flag.ContinueOnError)
	// Print, not Printf: usage is plain text, not a format string.
	flags.Usage = func() { fmt.Print(usage) }
	flags.StringVar(&agentFlags.bindDNS, "bind-dns", "", "")
	flags.StringVar(&agentFlags.bindHTTP, "bind-http", "", "")
	flags.StringVar(&agentFlags.bindSerf, "bind-serf", "", "")
	flags.StringVar(&agentFlags.join, "join", "", "")
	flags.StringVar(&agentFlags.node, "node", "", "")
	if err := flags.Parse(args); err != nil {
		// An explicit help request is not a failure; this matches what
		// flag.ExitOnError does for ErrHelp.
		if err == flag.ErrHelp {
			os.Exit(0)
		}
		os.Exit(1)
	}

	config, err := getAgentConfig()
	if err != nil {
		fmt.Println(err)
		os.Exit(1)
	}
	// Named a (not agent) so the agent package is not shadowed.
	a := agent.New(config)

	fmt.Printf("-> starting cascade agent\n")
	fmt.Printf("    node '%s'\n", config.NodeName)
	if len(config.StartJoin) > 0 {
		fmt.Printf("    join '%s'\n", config.StartJoin)
	}
	fmt.Printf("    bind addrs:\n")
	fmt.Printf("        dns  '%s'\n", config.DNSBindAddr)
	fmt.Printf("        http '%s'\n", config.HTTPBindAddr)
	fmt.Printf("        serf '%s'\n", config.SerfBindAddr)
	fmt.Printf("\n-> logs\n")

	if err := a.Start(); err != nil {
		fmt.Println(err)
		os.Exit(1)
	}

	// os.Exit does not run deferred calls, so failure paths after Start
	// must shut the agent down explicitly before exiting. The defer covers
	// the normal return path.
	fail := func(err error) {
		fmt.Println(err)
		a.Shutdown()
		os.Exit(1)
	}
	defer a.Shutdown()

	if err := startupJoin(a); err != nil {
		fail(err)
	}
	if err := handleSignals(a); err != nil {
		fail(err)
	}
}

// handleSignals blocks until the process receives SIGINT or SIGTERM, or the
// agent's ShutdownCh fires.
//
// Both subscribed signals trigger a graceful leave: a.Leave runs in a
// goroutine, and handleSignals waits up to gracefulTimeout for it.
//
// It returns nil when:
//   - the agent was already shut down, or
//   - the leave completes successfully.
//
// It returns an error when:
//   - Leave fails,
//   - the timeout expires, or
//   - a second signal arrives while waiting (the operator asking to stop
//     waiting for the leave).
func handleSignals(a *agent.Agent) error {
	signalCh := make(chan os.Signal, 4)
	signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM)
	// Restore default signal handling once we return, so a further signal
	// while the caller shuts the agent down terminates the process instead
	// of being silently queued on signalCh.
	defer signal.Stop(signalCh)

	var sig os.Signal
	select {
	case sig = <-signalCh:
	case <-a.ShutdownCh():
		// Agent is already shut down; nothing to leave.
		return nil
	}
	fmt.Fprintf(os.Stderr, "caught signal %s\n", sig)

	// Only os.Interrupt and SIGTERM are subscribed above, and both warrant a
	// graceful leave, so there is no "leave with zero grace" path here.
	fmt.Println("shutting down gracefully")

	// Buffered so the goroutine can always deliver its result and exit,
	// even if we have already stopped waiting (timeout or second signal).
	leaveCh := make(chan error, 1)
	go func() { leaveCh <- a.Leave() }()

	select {
	case <-signalCh:
		return fmt.Errorf("caught second signal, abandoning graceful leave")
	case <-time.After(gracefulTimeout):
		return fmt.Errorf("leave timed out")
	case err := <-leaveCh:
		if err != nil {
			return fmt.Errorf("graceful leave failed: %w", err)
		}
		return nil
	}
}

// startupJoin asks the agent to join the addresses in a.Config.StartJoin.
// It is a no-op when that list is empty. An error from Join is returned
// as-is. On success, the count Join reports is logged when it is positive.
func startupJoin(a *agent.Agent) error {
	addrs := a.Config.StartJoin
	if len(addrs) == 0 {
		return nil
	}

	n, err := a.Join(addrs)
	switch {
	case err != nil:
		return err
	case n > 0:
		log.Printf("issue join request nodes=%d\n", n)
	}
	return nil
}

// getAgentConfig starts from agent.DefaultConfig and overrides it with any
// flags the user set in agentFlags, doing a little input validation along
// the way.
//
// Bind addresses:
//   - Each is parsed by parseFlagAddress, which updates the corresponding
//     default address in place.
//   - A parse error is returned annotated with the offending flag name.
//
// -join:
//   - The value is split on commas.
//   - Surrounding whitespace is trimmed and empty entries are dropped, so
//     "a, b," yields ["a", "b"].
func getAgentConfig() (*agent.Config, error) {
	config := agent.DefaultConfig()

	if agentFlags.bindDNS != "" {
		if err := parseFlagAddress(agentFlags.bindDNS, config.DNSBindAddr); err != nil {
			return nil, fmt.Errorf("invalid -bind-dns: %w", err)
		}
	}
	if agentFlags.bindHTTP != "" {
		if err := parseFlagAddress(agentFlags.bindHTTP, config.HTTPBindAddr); err != nil {
			return nil, fmt.Errorf("invalid -bind-http: %w", err)
		}
	}
	if agentFlags.bindSerf != "" {
		if err := parseFlagAddress(agentFlags.bindSerf, config.SerfBindAddr); err != nil {
			return nil, fmt.Errorf("invalid -bind-serf: %w", err)
		}
	}
	if agentFlags.join != "" {
		// TODO: validate each join address further.
		var addrs []string
		for _, addr := range strings.Split(agentFlags.join, ",") {
			if addr = strings.TrimSpace(addr); addr != "" {
				addrs = append(addrs, addr)
			}
		}
		config.StartJoin = addrs
	}

	if agentFlags.node != "" {
		config.NodeName = agentFlags.node
	}
	return config, nil
}

// parseFlagAddress parses a flag value of the form "ip", "ip:port",
// "[ipv6]" or "[ipv6]:port" and writes the result into tcpAddr.
//
// Fields written:
//   - The IP is always overwritten.
//   - The port is overwritten only when hostPort includes one, so a bare IP
//     keeps tcpAddr's existing (default) port.
//
// Validation:
//   - Only literal IP addresses are accepted, not hostnames.
//   - The port must be a decimal number in the range 0-65535.
//   - tcpAddr is left untouched when an error is returned.
func parseFlagAddress(hostPort string, tcpAddr *net.TCPAddr) error {
	host, portStr, err := net.SplitHostPort(hostPort)
	if err != nil {
		// SplitHostPort rejects values without a port. Fall back to
		// treating the whole value (minus one pair of enclosing brackets)
		// as the host, so users can supply just an address. A bare IPv6
		// literal such as "::1" is reported as "too many colons" rather
		// than a missing port, so also accept anything that parses as an IP.
		bare := hostPort
		if strings.HasPrefix(bare, "[") && strings.HasSuffix(bare, "]") {
			bare = bare[1 : len(bare)-1]
		}
		if !strings.Contains(err.Error(), "missing port in address") && net.ParseIP(bare) == nil {
			return fmt.Errorf("Error parsing address: %v", err)
		}
		host, portStr = bare, ""
	}

	if host == "" {
		return fmt.Errorf("Error parsing blank address")
	}
	ip := net.ParseIP(host)
	if ip == nil {
		// Report the user's input; ip is nil here and would print as "<nil>".
		return fmt.Errorf("Error parsing address %q: not a valid IP address", host)
	}

	if portStr != "" {
		port, err := strconv.Atoi(portStr)
		if err != nil {
			return fmt.Errorf("Error parsing port: %s", err)
		}
		if port < 0 || port > 65535 {
			return fmt.Errorf("Error parsing port: %d out of range 0-65535", port)
		}
		tcpAddr.Port = port
	}
	tcpAddr.IP = ip

	return nil
}