small pixel drawing of a pufferfish zoa

shell/shell.go

package shell

import (
	"context"
	"fmt"
	"log"
	"os"
	"os/exec"
	"path/filepath"
	"strings"

	"j3s.sh/zoa/env"
	"j3s.sh/zoa/utils"

	"mvdan.cc/sh/v3/interp"
	"mvdan.cc/sh/v3/syntax"
)

// TODO: remember the path of the last script that ran (a package-level
// lastScriptPath string) so output can print a header whenever the script
// name changes from run to run.

var (
	// r is the shared shell runner, configured once by init. RunScript
	// executes top-level scripts on it directly; zoa-script sub-scripts run
	// on subshells derived from it (see runScriptInSubshell).
	r *interp.Runner
)

// init builds the package-level runner r before any script runs.
//
// The runner gets a nil stdin, and both stdout and stderr of interpreted
// commands go to this process's stdout (presumably to keep them interleaved
// in one stream). Its environment comes from env.GenerateEnv. Every command
// is routed through CallHandler.
//
// Any failure here is fatal: the package cannot run scripts without r.
func init() {
	envs, err := env.GenerateEnv()
	if err != nil {
		log.Fatal(err)
	}

	// plain assignment (not :=) so the package-level r is set rather than
	// shadowed by a new local
	r, err = interp.New(
		interp.StdIO(nil, os.Stdout, os.Stdout),
		interp.CallHandler(CallHandler),
		interp.Env(envs),
	)
	if err != nil {
		log.Fatal(err)
	}
}

// CallHandler is installed on the runner by init. It sees every command
// after expansion, before it executes, and returns the argv the runner should
// actually run. It implements zoa's two built-in commands and echoes every
// other command:
//
//   - zoa-script NAME runs ZoaRoot/scripts/NAME in a subshell of the package
//     runner, then substitutes a no-op for itself.
//   - zoa-file SRC DST [CMD...] copies ZoaRoot/files/SRC to DST when their
//     contents differ. If DST changed and CMD was given, CMD runs in place of
//     the zoa-file command and is echoed; otherwise a no-op runs. Example:
//     zoa-file nginx /etc/nginx/conf.d/jesse.conf systemctl nginx reload
//
// Any other command is printed as "$ cmd args..." and returned unchanged.
// Missing built-in arguments, a sub-script that fails to load or run, and
// copy failures terminate the process via log.Fatal.
func CallHandler(ctx context.Context, args []string) ([]string, error) {
	if len(args) == 0 {
		return args, nil
	}

	switch args[0] {
	case "zoa-script":
		if len(args) < 2 {
			log.Fatal("zoa-script requires 1 argument")
		}
		subScript := filepath.Join(utils.ZoaRoot, "scripts", args[1])
		// TODO: figure out how to get scripts to echo their names
		// when we resume their execution
		utils.BluePrintln("-> " + subScript)
		if err := runScriptInSubshell(ctx, subScript); err != nil {
			log.Fatal(err)
		}
		// the runner needs a non-empty argv to execute; "true" is a no-op,
		// and returning early means it is not echoed
		return []string{"true"}, nil

	case "zoa-file":
		if len(args) <= 2 {
			log.Fatal("zoa-file requires 2+ arguments")
		}
		src, dst := args[1], args[2]
		optionalCmd := args[3:] // empty when no follow-up command was given
		fmt.Printf("$ zoa-file %s %s\n", src, dst)
		dstChanged, err := zoaCopy(filepath.Join(utils.ZoaRoot, "files", src), dst)
		if err != nil {
			log.Fatal(err)
		}
		if !dstChanged || len(optionalCmd) == 0 {
			// silent no-op, as for zoa-script
			return []string{"true"}, nil
		}
		// dst changed and a follow-up was given: run it in place of
		// zoa-file, echoing it below like any other command
		args = optionalCmd
	}

	fmt.Printf("$ %s\n", strings.Join(args, " "))
	return args, nil
}

func runScriptInSubshell(ctx context.Context, script string) error {
	r := r.Subshell()
	f, err := parseFile(script)
	if err != nil {
		return err
	}
	r.Run(ctx, f)
	return nil
}

// RunScript prints the zoa banner, then parses the script at path s and runs
// it on the package-level runner r.
//
// It does not return an error. A parse failure prints "error in <s>: <err>"
// and a run failure prints the error; both go to stdout and then exit the
// process with status 1.
func RunScript(s string) {
	utils.BluePrintln("(✿◠‿◠) zoa")

	file, parseErr := parseFile(s)
	if parseErr != nil {
		fmt.Printf("error in %s: %s\n", s, parseErr)
		os.Exit(1)
	}

	if runErr := r.Run(context.TODO(), file); runErr != nil {
		fmt.Println(runErr)
		os.Exit(1)
	}
}

// parseFile opens filename and parses it as a shell script.
//
// filename is also passed to the parser as the script's name, so parse
// errors identify the file and the returned File's Name is set.
// NOTE(review): the interp runner appears to use File.Name when running the
// file (e.g. for $0); confirm no script depends on the old empty name.
//
// On error the caller must not use the returned *syntax.File; it is nil when
// the file cannot be opened.
func parseFile(filename string) (*syntax.File, error) {
	f, err := os.Open(filename)
	if err != nil {
		return nil, err
	}
	defer f.Close()
	return syntax.NewParser().Parse(f, filename)
}

// zoaCopy copies the file at path src to dst, but only when their contents
// differ as judged by utils.ChecksumFile. It reports whether dst was
// written. src is used exactly as given; CallHandler passes a path under
// ZoaRoot/files.
//
// If dst cannot be checksummed (for example because it does not exist yet),
// it is treated as different and written. When the contents already match,
// "file unchanged" is printed and nothing is written.
// NOTE(review): file permissions are whatever utils.Copy applies; earlier
// docs said 0666 before umask. Confirm against utils.Copy.
func zoaCopy(src string, dst string) (bool, error) {
	srcChk, err := utils.ChecksumFile(src)
	if err != nil {
		// the source lives in the zoa tree and should always exist
		return false, fmt.Errorf("checksumming %s: %w", src, err)
	}

	// dst may be missing or unreadable for many reasons; a blank checksum
	// forces a mismatch so it gets (re)written
	dstChk, err := utils.ChecksumFile(dst)
	if err != nil {
		dstChk = ""
	}

	if srcChk == dstChk {
		fmt.Println("file unchanged")
		return false, nil
	}

	// TODO: pass the file through a templating engine
	// - expand vars
	// - loop support?

	if err := utils.Copy(src, dst); err != nil {
		return false, fmt.Errorf("copying %s to %s: %w", src, dst, err)
	}
	return true, nil
}

// SSH runs command on every host in hostlist using the local ssh binary,
// one host at a time and in order. ssh's stdout and stderr are passed
// through to this process.
//
// It stops at the first host whose ssh invocation fails and returns that
// error, naming the host. It also returns an error, before contacting any
// host, if ssh is not on $PATH.
func SSH(command string, hostlist []string) error {
	return runOnHosts("ssh", hostlist, func(host string) []string {
		return []string{host, command}
	})
}

// SCP copies file to every host in hostlist using "scp file host:". That
// target places the file in scp's default remote directory, normally the
// login directory. Hosts are processed in order, stopping at the first
// failure. The errors are the same as for SSH.
func SCP(file string, hostlist []string) error {
	return runOnHosts("scp", hostlist, func(host string) []string {
		return []string{file, host + ":"}
	})
}

// SCPDir recursively copies dir to the same path on every host in hostlist
// using "scp -r dir host:dir". Hosts are processed in order, stopping at the
// first failure. The errors are the same as for SSH.
// NOTE(review): if dir already exists remotely, scp -r places the copy
// inside it (dir/<base of dir>); confirm this is intended.
func SCPDir(dir string, hostlist []string) error {
	return runOnHosts("scp", hostlist, func(host string) []string {
		return []string{"-r", dir, host + ":" + dir}
	})
}

// runOnHosts runs prog once per host, in hostlist order, with the arguments
// built by argsFor(host), streaming its stdout and stderr to this process.
// It returns the exec.LookPath error if prog cannot be found (checked once,
// before any host). Otherwise it returns the first failed run, wrapped with
// prog and the host.
func runOnHosts(prog string, hostlist []string, argsFor func(host string) []string) error {
	if _, err := exec.LookPath(prog); err != nil {
		return err
	}
	for _, host := range hostlist {
		cmd := exec.Command(prog, argsFor(host)...)
		cmd.Stdout = os.Stdout
		// without this, exec discards stderr and ssh/scp failures are opaque
		cmd.Stderr = os.Stderr
		if err := cmd.Run(); err != nil {
			return fmt.Errorf("%s %s: %w", prog, host, err)
		}
	}
	return nil
}