small pixel drawing of a pufferfish cascade

internal/cli/agent.go

package cli

import (
	"flag"
	"fmt"
	"log"
	"os"
	"os/signal"
	"strings"
	"syscall"
	"time"

	"git.j3s.sh/cascade/internal/agent"
)

// agentCommand holds the parsed command-line state for `cascade agent`.
type agentCommand struct {
	// gracefulTimeout bounds how long we wait for a graceful cluster leave
	// before giving up. It should be considered together with serf's
	// LeavePropagateDelay config, which also affects how long a leave takes.
	gracefulTimeout time.Duration

	// Raw flag values, in declaration order: -bind-dns, -bind-http,
	// -bind-serf, -join, -node. An empty string means the flag was not
	// given and the agent's default config value is kept.
	flagBindDNS, flagBindHTTP, flagBindSerf, flagJoin, flagNode string
}

// Usage prints the help text for `cascade agent` to stdout. Init installs it
// as the flag set's usage function, so the flag package also calls it on
// -h/-help and on flag parse errors.
func (c *agentCommand) Usage() {
	fmt.Print(
		`usage: cascade agent [flags]

  this command starts the cascade agent, which is responsible
  for basically everything, including service registration,
  health checking, cluster membership, and hosting the API.

flags:
  -bind-dns=<addr>
    address the DNS server binds to (default = 127.0.0.1:8600)

  -bind-http=<addr>
    address the http api/interface binds to (default = 127.0.0.1:8500)

  -bind-serf=<addr>
    address the serf agent binds to (default = 0.0.0.0:8301)

  -join=<addrs>
    comma-separated addresses of agents to join at start time (default = none)

  -node=<name>
    name of this node, must be globally unique (default = hostname)
`)
}

// Init sets the default graceful-leave timeout and parses the `cascade agent`
// flags from args into c. It returns a copy of the populated command.
//
// It exits the process on bad input:
//   - -h/-help: usage has already been printed, so it exits 0.
//   - a parse error: the flag package has already printed the error and the
//     usage text, so it exits 1 without printing the error a second time.
//   - leftover positional arguments (e.g. "node=foo" missing its dash): these
//     would otherwise be silently ignored, so it reports them and exits 1.
func (c *agentCommand) Init(args []string) agentCommand {
	c.gracefulTimeout = 10 * time.Second

	flags := flag.NewFlagSet("", flag.ContinueOnError)
	flags.Usage = c.Usage
	flags.StringVar(&c.flagBindDNS, "bind-dns", "", "")
	flags.StringVar(&c.flagBindHTTP, "bind-http", "", "")
	flags.StringVar(&c.flagBindSerf, "bind-serf", "", "")
	flags.StringVar(&c.flagJoin, "join", "", "")
	flags.StringVar(&c.flagNode, "node", "", "")

	if err := flags.Parse(args); err != nil {
		// FlagSet.Parse returns the ErrHelp sentinel unwrapped.
		if err == flag.ErrHelp {
			os.Exit(0)
		}
		os.Exit(1)
	}
	if flags.NArg() > 0 {
		fmt.Fprintf(os.Stderr, "unexpected arguments: %q\n", flags.Args())
		os.Exit(1)
	}

	return *c
}

// RunAgent is the entry point for `cascade agent`. It parses args, builds the
// agent config, starts the agent, performs the startup join, and then blocks
// until the agent shuts down or a termination signal has been handled.
// Any failure is printed to stderr and the process exits with status 1.
func RunAgent(args []string) {
	if err := runAgent(args); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
}

// runAgent does the work of RunAgent but reports failures as an error instead
// of calling os.Exit. os.Exit does not run deferred calls, so exiting from in
// here would skip the deferred agent Shutdown.
func runAgent(args []string) error {
	c := agentCommand{}
	c.Init(args)

	config, err := c.getAgentConfig()
	if err != nil {
		return err
	}
	// Named a rather than agent so the agent package is not shadowed.
	a := agent.New(config)

	fmt.Printf("-> starting cascade agent\n")
	fmt.Printf("    node '%s'\n", config.NodeName)
	if len(config.StartJoin) > 0 {
		fmt.Printf("    join '%s'\n", strings.Join(config.StartJoin, ","))
	}
	fmt.Printf("    bind addrs:\n")
	fmt.Printf("        dns  '%s'\n", config.DNSBindAddr)
	fmt.Printf("        http '%s'\n", config.HTTPBindAddr)
	fmt.Printf("        serf '%s'\n", config.SerfBindAddr)
	fmt.Printf("\n-> logs\n")

	if err := a.Start(); err != nil {
		return err
	}
	defer a.Shutdown()

	if err := c.startupJoin(a); err != nil {
		return err
	}
	return c.handleSignals(a)
}

// handleSignals blocks until either the agent shuts down on its own or an
// interrupt/SIGTERM arrives. On a signal it attempts a graceful cluster leave
// and waits for whichever comes first:
//   - the leave finishing: returns nil, or the leave error if it failed;
//   - c.gracefulTimeout elapsing: returns a timeout error;
//   - a second signal: returns an error so the caller can exit right away.
//
// It returns nil immediately if the agent's shutdown channel fires first.
func (c agentCommand) handleSignals(a *agent.Agent) error {
	signalCh := make(chan os.Signal, 4)
	signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM)
	defer signal.Stop(signalCh)

	var sig os.Signal
	select {
	case sig = <-signalCh:
	case <-a.ShutdownCh():
		// Agent is already shut down; nothing to leave.
		return nil
	}
	fmt.Fprintf(os.Stderr, "caught signal %s\n", sig)

	// Only os.Interrupt and SIGTERM are subscribed above, and both warrant a
	// graceful leave, so there is no non-graceful path to take here.
	fmt.Fprintln(os.Stderr, "shutting down gracefully")

	// Buffered so the goroutine can always deliver its result and exit, even
	// if we have already returned on timeout or a second signal.
	leaveCh := make(chan error, 1)
	go func() {
		leaveCh <- a.Leave()
	}()

	select {
	case s := <-signalCh:
		return fmt.Errorf("caught second signal %s, aborting graceful leave", s)
	case <-time.After(c.gracefulTimeout):
		return fmt.Errorf("leave timed out after %s", c.gracefulTimeout)
	case err := <-leaveCh:
		if err != nil {
			return fmt.Errorf("graceful leave: %w", err)
		}
		return nil
	}
}

// startupJoin asks the agent to join the addresses in a.Config.StartJoin.
// It is a no-op when the list is empty. It logs the count returned by Join
// when that count is positive. A join failure is returned wrapped with the
// attempted addresses, so the message printed at exit says what was tried.
func (c agentCommand) startupJoin(a *agent.Agent) error {
	addrs := a.Config.StartJoin
	if len(addrs) == 0 {
		return nil
	}

	n, err := a.Join(addrs)
	if err != nil {
		return fmt.Errorf("joining %s: %w", strings.Join(addrs, ","), err)
	}
	if n > 0 {
		log.Printf("issue join request nodes=%d\n", n)
	}

	return nil
}

// getAgentConfig starts from agent.DefaultConfig() and overrides it with any
// flags the user set. An empty flag value means "keep the default". It does a
// little input validation along the way.
//
// NOTE(review): the bind-address flags are handed to parseFlagAddress together
// with the current config field; this presumably updates that field in place
// (e.g. through a pointer) — confirm against parseFlagAddress.
func (c agentCommand) getAgentConfig() (*agent.Config, error) {
	config := agent.DefaultConfig()

	if c.flagBindDNS != "" {
		if err := parseFlagAddress(c.flagBindDNS, config.DNSBindAddr); err != nil {
			return nil, fmt.Errorf("-bind-dns: %w", err)
		}
	}
	if c.flagBindHTTP != "" {
		if err := parseFlagAddress(c.flagBindHTTP, config.HTTPBindAddr); err != nil {
			return nil, fmt.Errorf("-bind-http: %w", err)
		}
	}
	if c.flagBindSerf != "" {
		if err := parseFlagAddress(c.flagBindSerf, config.SerfBindAddr); err != nil {
			return nil, fmt.Errorf("-bind-serf: %w", err)
		}
	}
	if c.flagJoin != "" {
		addrs := splitJoinAddrs(c.flagJoin)
		if len(addrs) == 0 {
			return nil, fmt.Errorf("-join: no addresses in %q", c.flagJoin)
		}
		config.StartJoin = addrs
	}

	if c.flagNode != "" {
		config.NodeName = c.flagNode
	}
	return config, nil
}

// splitJoinAddrs splits a comma-separated -join value into addresses. It trims
// surrounding whitespace and drops empty entries, so "a, b," yields
// ["a" "b"] rather than ["a" " b" ""].
func splitJoinAddrs(s string) []string {
	var addrs []string
	for _, part := range strings.Split(s, ",") {
		if addr := strings.TrimSpace(part); addr != "" {
			addrs = append(addrs, addr)
		}
	}
	return addrs
}