small pixel drawing of a pufferfish cascade

command/agent/agent.go

package agent

import (
	"flag"
	"fmt"
	"log"
	"net"
	"os"
	"os/signal"
	"strconv"
	"strings"
	"syscall"
	"time"

	"git.j3s.sh/cascade/agent"
)

// gracefulTimeout is the upper bound on how long handleSignals waits for a
// graceful cluster leave to finish before giving up and terminating anyway.
// This value interacts with serf's LeavePropagateDelay config.
const gracefulTimeout = time.Second * 10

// usage is the help text installed as the flag set's Usage output in Run; the
// flag package shows it for -h/-help and after a flag parsing error.
// It must not contain '%' so it is safe to print verbatim.
// NOTE(review): keep this text in sync with the flags actually registered in
// Run and the environment variables read by configureAgent.
const usage = `cascade agent [options]

  this command starts the cascade agent, which is responsible
  for basically everything, including service registration,
  health checking, cluster membership, and hosting the API.

common options:

  -node=<name>
    name of this node, must be globally unique (default = hostname)

bind options:

  -api-bind=<addr>
    address the http api binds to (default = 127.0.0.1:8500)
    !!do not expose the api publicly!!

  -dns-bind=<addr>
    address the DNS resolver binds to (default = 127.0.0.1:8600)
    !!do not expose the DNS resolver publicly!!

  -serf-bind=<addr>
    address the serf gossip cluster binds to (default = 0.0.0.0:8301)

environment:

  CASCADE_BIND=<addr>
    serf bind address (an ip, optionally with :port)

  CASCADE_JOIN=<addr>[,<addr>...]
    comma-separated list of nodes to join at startup

  CASCADE_NAME=<name>
    name of this node
`

// Package-level flag destinations. serfBindAddr receives the value of the
// -serf-bind flag registered in Run.
// NOTE(review): apiBindAddr is not referenced by the code visible here; it
// presumably was meant for the -api-bind option described in usage — confirm.
var apiBindAddr, serfBindAddr string

// Run implements `cascade agent`. It parses args, builds the agent
// configuration, starts the agent, joins any startup nodes, and then blocks
// until the agent shuts down on its own or the process receives an exit signal.
//
// Configuration is layered: agent.DefaultConfig(), then the CASCADE_*
// environment variables (see configureAgent), then command-line flags, so
// flags win. Any failure is reported on stderr and the process exits with
// status 1; -h/-help prints usage and exits with status 0.
func Run(args []string) {
	// fail reports err and terminates the process. os.Exit skips deferred
	// calls, which is why the agent is shut down explicitly below instead of
	// with a defer.
	fail := func(err error) {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}

	var nodeName string
	flags := flag.NewFlagSet("agent", flag.ContinueOnError)
	// Print, not Printf: usage is literal help text, not a format string.
	flags.Usage = func() { fmt.Fprint(flags.Output(), usage) }
	flags.StringVar(&nodeName, "node", "", "")
	flags.StringVar(&serfBindAddr, "serf-bind", "", "")
	if err := flags.Parse(args); err != nil {
		// With ContinueOnError the flag package has already printed the
		// problem (if any) and the usage text; only the exit status is left.
		if err == flag.ErrHelp {
			os.Exit(0)
		}
		os.Exit(1)
	}

	config, err := configureAgent()
	if err != nil {
		fail(err)
	}
	if nodeName != "" {
		config.NodeName = nodeName
	}
	if serfBindAddr != "" {
		if err := parseFlagAddress(serfBindAddr, config.SerfBindAddr); err != nil {
			fail(fmt.Errorf("-serf-bind: %w", err))
		}
	}

	a := agent.New(config)
	if err := a.Start(); err != nil {
		fail(err)
	}
	// Join any specified startup nodes.
	if err := startupJoin(a); err != nil {
		a.Shutdown()
		fail(err)
	}
	// Block until it is time to exit; shut down even if the leave failed or
	// timed out, so the agent is never abandoned mid-run.
	err = handleSignals(a)
	a.Shutdown()
	if err != nil {
		fail(err)
	}
}

// handleSignals blocks until the process receives SIGINT or SIGTERM, or until
// the agent's ShutdownCh fires. On a signal it attempts a graceful cluster
// leave, bounded by gracefulTimeout; a second signal during the leave aborts
// the wait.
//
// It returns nil if the agent had already shut down or the leave completed,
// and a non-nil error if the leave failed, timed out, or was interrupted.
func handleSignals(a *agent.Agent) error {
	signalCh := make(chan os.Signal, 4)
	signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM)
	defer signal.Stop(signalCh)

	var sig os.Signal
	select {
	case sig = <-signalCh:
	case <-a.ShutdownCh():
		// The agent is already shut down; there is nothing to leave.
		return nil
	}
	fmt.Fprintf(os.Stderr, "caught signal %s\n", sig)

	// Only SIGINT and SIGTERM are subscribed above, and both warrant a
	// graceful leave, so there is no non-graceful path here.
	fmt.Fprintln(os.Stderr, "shutting down gracefully")

	// Buffered so the goroutine can always deliver Leave's result and exit,
	// even after we stop waiting because of a timeout or a second signal.
	leaveErrCh := make(chan error, 1)
	go func() { leaveErrCh <- a.Leave() }()

	select {
	case sig = <-signalCh:
		return fmt.Errorf("caught second signal %s, aborting graceful leave", sig)
	case <-time.After(gracefulTimeout):
		return fmt.Errorf("leave timed out after %s", gracefulTimeout)
	case err := <-leaveErrCh:
		if err != nil {
			return fmt.Errorf("leaving cluster: %w", err)
		}
		return nil
	}
}

// startupJoin asks the agent to join the nodes listed in a.Config.StartJoin.
// It is a no-op when the list is empty. A join error is returned annotated
// with the node list; otherwise the count returned by Join is logged when it
// is positive.
func startupJoin(a *agent.Agent) error {
	nodes := a.Config.StartJoin
	if len(nodes) == 0 {
		return nil
	}

	n, err := a.Join(nodes)
	if err != nil {
		return fmt.Errorf("joining startup nodes %q: %w", nodes, err)
	}
	if n > 0 {
		// log.Printf appends the trailing newline itself.
		log.Printf("issue join request nodes=%d", n)
	}
	return nil
}

// configureAgent returns agent.DefaultConfig() with overrides applied from the
// environment. Variables that are unset or empty leave the default in place:
//
//	CASCADE_BIND=192.168.0.15:12345   serf bind address (ip, optionally :port)
//	CASCADE_JOIN=127.0.0.1,127.0.0.5  comma-separated nodes to join at startup
//	CASCADE_NAME=nostromo.j3s.sh      node name
//
// An invalid CASCADE_BIND is reported as an error naming the variable.
func configureAgent() (*agent.Config, error) {
	config := agent.DefaultConfig()

	if bind := os.Getenv("CASCADE_BIND"); bind != "" {
		if err := parseFlagAddress(bind, config.SerfBindAddr); err != nil {
			return nil, fmt.Errorf("CASCADE_BIND: %w", err)
		}
	}

	if join := os.Getenv("CASCADE_JOIN"); join != "" {
		// Tolerate spaces around entries and stray commas, so "a, b," yields
		// [a b] rather than [a " b" ""].
		var nodes []string
		for _, node := range strings.Split(join, ",") {
			if node = strings.TrimSpace(node); node != "" {
				nodes = append(nodes, node)
			}
		}
		if len(nodes) > 0 {
			config.StartJoin = nodes
		}
	}

	if name := os.Getenv("CASCADE_NAME"); name != "" {
		config.NodeName = name
	}
	return config, nil
}

// parseFlagAddress parses hostPort, an IP address with an optional port, and
// stores the result in tcpAddr. Accepted forms include "10.0.0.1",
// "10.0.0.1:8301", "[::1]:8301" and a bare IPv6 address such as "::1";
// hostnames are rejected.
//
// The IP is always overwritten. The port is overwritten only when hostPort
// includes one, so tcpAddr otherwise keeps its existing (default) port.
// tcpAddr must be non-nil. It is left unmodified when an error is returned.
func parseFlagAddress(hostPort string, tcpAddr *net.TCPAddr) error {
	host, portStr, err := net.SplitHostPort(hostPort)
	if err != nil {
		// A value without a port is fine: treat the whole thing as the host,
		// so users can supply just an address. SplitHostPort also rejects
		// unbracketed IPv6 ("too many colons"), so accept any value that
		// already parses as an IP on its own.
		addrErr, ok := err.(*net.AddrError)
		missingPort := ok && addrErr.Err == "missing port in address"
		if !missingPort && net.ParseIP(hostPort) == nil {
			return fmt.Errorf("parsing address %q: %w", hostPort, err)
		}
		host, portStr = hostPort, ""
	}

	if host == "" {
		return fmt.Errorf("parsing address %q: blank host", hostPort)
	}
	ip := net.ParseIP(host)
	if ip == nil {
		return fmt.Errorf("parsing address %q: %q is not a valid IP address", hostPort, host)
	}

	port := tcpAddr.Port
	if portStr != "" {
		// Ports are 16-bit unsigned; reject negatives and values above 65535.
		p, err := strconv.ParseUint(portStr, 10, 16)
		if err != nil {
			return fmt.Errorf("parsing port %q: %w", portStr, err)
		}
		port = int(p)
	}

	// Assign only after everything validated, so errors never leave a
	// half-updated address behind.
	tcpAddr.IP = ip
	tcpAddr.Port = port
	return nil
}