Add address parsing logic to cmd/agent, test the shit out of it
Jes Olson j3s@c3f.net
Mon, 20 Feb 2023 15:18:53 -0800
6 files changed,
282 insertions(+),
6 deletions(-)
M
agent/agent.go
→
agent/agent.go
@@ -58,7 +58,6 @@ "os"
"sync" "time" - "github.com/hashicorp/memberlist" "github.com/hashicorp/serf/serf" "golang.org/x/exp/slog" )@@ -86,10 +85,12 @@ // New returns an agent lmao
func New(config *Config) *Agent { agent := Agent{} serfConfig := serf.DefaultConfig() - serfConfig.MemberlistConfig = memberlist.DefaultWANConfig() - // XXX: why do serf and cascade use the same event channel? eventCh := make(chan serf.Event, 1024) + agent.eventCh = eventCh + serfConfig.EventCh = eventCh + + // XXX: why do serf and cascade use the same event channel? agent.eventCh = eventCh serfConfig.EventCh = eventCh
M
agent/config.go
→
agent/config.go
@@ -1,11 +1,12 @@
package agent import ( + "fmt" "net" "os" ) -const DefaultBindPort int = 4443 +const DefaultSerfPort int = 4443 const DefaultClientPort int = 4444 func DefaultConfig() *Config {@@ -16,7 +17,7 @@ }
// TODO: figure out how to default the listeners cfg := Config{} - cfg.BindAddr = &net.TCPAddr{IP: []byte{0, 0, 0, 0}, Port: DefaultBindPort} + cfg.BindAddr = &net.TCPAddr{IP: []byte{0, 0, 0, 0}, Port: DefaultSerfPort} cfg.ClientAddr = &net.TCPAddr{IP: []byte{127, 0, 0, 1}, Port: DefaultClientPort} cfg.NodeName = hostname@@ -29,3 +30,27 @@ ClientAddr *net.TCPAddr
NodeName string StartJoin []string } + +// BindAddrParts returns the parts of the BindAddr that should be +// used to configure Serf. +func (c *Config) AddrParts(address string) (string, int, error) { + checkAddr := address + +START: + _, _, err := net.SplitHostPort(checkAddr) + if ae, ok := err.(*net.AddrError); ok && ae.Err == "missing port in address" { + checkAddr = fmt.Sprintf("%s:%d", checkAddr, DefaultSerfPort) + goto START + } + if err != nil { + return "", 0, err + } + + // Get the address + addr, err := net.ResolveTCPAddr("tcp", checkAddr) + if err != nil { + return "", 0, err + } + + return addr.IP.String(), addr.Port, nil +}
A
command/agent/agent.go
@@ -0,0 +1,162 @@
+package agent + +import ( + "fmt" + "log" + "net" + "os" + "os/signal" + "strconv" + "strings" + "syscall" + "time" + + "git.j3s.sh/cascade/agent" +) + +// gracefulTimeout controls how long we wait before forcefully terminating +// note that this value interacts with serf's LeavePropagateDelay config +const gracefulTimeout = 10 * time.Second + +func Run(args []string) { + // do flags + config := configureAgent() + agent := agent.New(config) + if err := agent.Start(); err != nil { + fmt.Println(err) + os.Exit(1) + } + defer agent.Shutdown() + // join any specified startup nodes + if err := startupJoin(agent); err != nil { + fmt.Println(err) + os.Exit(1) + } + if err := handleSignals(agent); err != nil { + fmt.Println(err) + os.Exit(1) + } +} + +// handleSignals blocks until we get an exit-causing signal +func handleSignals(agent *agent.Agent) error { + signalCh := make(chan os.Signal, 4) + signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM) + + // Wait for a signal + var sig os.Signal + select { + case s := <-signalCh: + sig = s + case <-agent.ShutdownCh(): + // Agent is already shutdown! + return nil + } + fmt.Fprintf(os.Stderr, "caught signal %s", sig) + + // Check if we should do a graceful leave + graceful := false + if sig == os.Interrupt || sig == syscall.SIGTERM { + graceful = true + } + + // Bail fast if not doing a graceful leave + if !graceful { + fmt.Fprintf(os.Stderr, "leave cluster with zero grace") + return nil + } + + // Attempt a graceful leave + gracefulCh := make(chan struct{}) + fmt.Printf("shutting down gracefully") + go func() { + if err := agent.Leave(); err != nil { + fmt.Println("Error: ", err) + return + } + close(gracefulCh) + }() + + // Wait for leave or another signal + select { + case <-signalCh: + return fmt.Errorf("idfk") + case <-time.After(gracefulTimeout): + return fmt.Errorf("leave timed out") + case <-gracefulCh: + return nil + } +} + +func startupJoin(a *agent.Agent) error { + if len(a.Config.StartJoin) == 0 { + return nil + } + + n, err := a.Join(a.Config.StartJoin) + if err != nil { + return err + } + if n > 0 { + log.Printf("issue join request nodes=%d\n", n) + } + + return nil +} + +func configureAgent() *agent.Config { + config := agent.DefaultConfig() + // CASCADE_BIND=192.168.0.15:12345 + if os.Getenv("CASCADE_BIND") != "" { + err := parseFlagAddress(os.Getenv("CASCADE_BIND"), config.BindAddr) + if err != nil { + fmt.Printf("Error parsing CASCADE_BIND: %s\n", err) + os.Exit(1) + } + } + // CASCADE_JOIN=127.0.0.1,127.0.0.5 + if os.Getenv("CASCADE_JOIN") != "" { + config.StartJoin = strings.Split(os.Getenv("CASCADE_JOIN"), ",") + } + // CASCADE_NAME=nostromo.j3s.sh + if os.Getenv("CASCADE_NAME") != "" { + config.NodeName = os.Getenv("CASCADE_NAME") + } + return config +} + +// parseFlagAddress takes a colon-delimited host:port pair as a string, parses +// out the ip (and optionally, the port), and modifies the passed TCPAddr with +// the resulting values. +func parseFlagAddress(hostPort string, tcpAddr *net.TCPAddr) error { + addr, portStr, err := net.SplitHostPort(hostPort) + if err != nil { + if !strings.Contains(err.Error(), "missing port in address") { + return fmt.Errorf("Error parsing address: %v", err) + } + + // If we get a missing port error, we try to coerce the whole hostPort + // into an address. This allows the user to supply just a host address + // instead of always requiring a host:ip pair. + addr = hostPort + } + + if addr == "" { + return fmt.Errorf("Error parsing blank address") + } + ip := net.ParseIP(addr) + if ip == nil { + return fmt.Errorf("Error parsing address %q: not a valid IP address", ip) + } + + if portStr != "" { + port, err := strconv.Atoi(portStr) + if err != nil { + return fmt.Errorf("Error parsing port: %s", err) + } + tcpAddr.Port = port + } + tcpAddr.IP = ip + + return nil +}
A
command/agent/agent_test.go
@@ -0,0 +1,89 @@
+package agent + +import ( + "strings" + "testing" + + "git.j3s.sh/cascade/agent" +) + +func TestParseFlagAddress(t *testing.T) { + type c struct { + inputAddr string + expectIP string + expectPort int + expectErrStr string + } + cases := []c{ + { + inputAddr: "127.0.0.1", + expectIP: "127.0.0.1", + expectPort: 4443, + }, + { + inputAddr: "127.0.0.1:6969", + expectIP: "127.0.0.1", + expectPort: 6969, + }, + { + inputAddr: "192.168.0.4:420", + expectIP: "192.168.0.4", + expectPort: 420, + }, + { + inputAddr: "127.0.0.3:", + expectIP: "127.0.0.3", + expectPort: 4443, + }, + // error cases + { + inputAddr: "", + expectIP: "0.0.0.0", + expectPort: 4443, + expectErrStr: "Error parsing blank address", + }, + { + inputAddr: ":1234", + expectIP: "0.0.0.0", + expectPort: 4443, + expectErrStr: "Error parsing blank address", + }, + { + inputAddr: "127.0.0.1:abcd", + expectIP: "0.0.0.0", + expectPort: 4443, + expectErrStr: "Error parsing port", + }, + { + inputAddr: "127.0.0.256:6969", + expectIP: "0.0.0.0", + expectPort: 4443, + expectErrStr: "Error parsing address", + }, + } + + for _, tc := range cases { + addr := agent.DefaultConfig().BindAddr + expectErr := tc.expectErrStr != "" + + err := parseFlagAddress(tc.inputAddr, addr) + if expectErr && err == nil { + t.Errorf("Expected error '%s', but received none", tc.expectErrStr) + } + if !expectErr && err != nil { + t.Errorf("Unexpected error: '%s'", err) + } + // errors we expect are unwrapped here + if expectErr && err != nil { + if !strings.Contains(err.Error(), tc.expectErrStr) { + t.Errorf("Expected error '%s', got '%s'", tc.expectErrStr, err) + } + } + if tc.expectIP != addr.IP.String() { + t.Errorf("Expected IP '%s', got '%s'", tc.expectIP, addr.IP.String()) + } + if tc.expectPort != addr.Port { + t.Errorf("Expected port '%d', got '%d'", tc.expectPort, addr.Port) + } + } +}