small pixel drawing of a pufferfish zoa

shell/shell.go

package shell

import (
	"context"
	"fmt"
	"log"
	"os"
	"os/exec"
	"path/filepath"
	"strings"

	"j3s.sh/zoa/color"
	"j3s.sh/zoa/utils"

	"mvdan.cc/sh/v3/interp"
	"mvdan.cc/sh/v3/syntax"
)

// lastScriptPath would record the path of the previously run
// script, so that a change of script between runs can be
// detected and the new script's name pretty-printed.
// var lastScriptPath string

// r represents our global runner - we use a global here
// because we're very bad at coding, and globals are ez
// 😎
// var r *interp.Runner

// we use this to detect the working dir of
// the primary script the user called
// var zoaRoot string

// Session describes a single script run: which shell script to execute and
// which interpreter to execute it with. See Session.Run.
type Session struct {
	// Runner is the shell interpreter the script's statements are run on.
	Runner *interp.Runner
	// ScriptPath is the path of the shell script that Run parses and
	// executes. It is also stored under the "script" key of the context
	// handed to each statement.
	ScriptPath string
	// SubShell, when true, makes Run execute the script in a subshell copy
	// of Runner (obtained via Runner.Subshell()) instead of in Runner itself.
	SubShell bool
}

// Run parses the script at s.ScriptPath and executes it one top-level
// statement at a time.
//
// The script's base name is printed as a header first. Every statement runs
// with a context carrying the script path under the "script" key (read by
// CallHandler) and that statement's position under the "position" key.
// A parse failure is fatal (log.Fatal). An error from an individual statement
// is reported via color.ZoaYell, prefixed with "<path>:<pos>:", and execution
// continues with the next statement. When s.SubShell is set the statements
// run in a subshell of s.Runner rather than in s.Runner itself.
func (s *Session) Run() {
	color.BluePrintln("./" + filepath.Base(s.ScriptPath))

	f, err := parseFile(s.ScriptPath)
	if err != nil {
		log.Fatal(err)
	}

	runner := s.Runner
	if s.SubShell {
		runner = s.Runner.Subshell()
	}

	// NOTE(review): plain string keys are shared with CallHandler; a private
	// key type would be safer, but it must be changed in both places at once.
	scriptCtx := context.WithValue(context.TODO(), "script", s.ScriptPath)
	for _, stmt := range f.Stmts {
		pos := stmt.Pos()
		// Derive from scriptCtx, not from the previous statement's context:
		// chaining onto the previous one would grow the context by one node
		// per statement for the whole script.
		stmtCtx := context.WithValue(scriptCtx, "position", pos)
		if err := runner.Run(stmtCtx, stmt); err != nil {
			// we yell but keep plowing forward
			// twirling, twirling
			color.ZoaYell(fmt.Sprintf("%s:%s: %s", s.ScriptPath, pos, err.Error()))
		}
	}
}

// CallHandler has the shape of an mvdan.cc/sh interp call handler and is
// presumably installed on the Runner that Session.Run uses. For each command
// it receives, it prints the command line (prefixed with the current script
// path) and implements zoa's built-in commands:
//
//   - zoa-script <script>: currently a stub; it validates that an argument was
//     given and then runs `true` in place of the call.
//   - zoa-file <source> <destination> [command...]: copies
//     <script dir>/files/<source> to <destination> via zoaCopy. If the
//     destination changed and a command was given, that command runs in place
//     of the zoa-file call; otherwise `true` runs.
//
// Any other command is returned unchanged so the interpreter runs it as usual.
// The script path is read from ctx under the "script" key (set by
// Session.Run). An error is returned, rather than panicking, if args is empty
// or the key is missing, and also if zoa-file fails to copy.
func CallHandler(ctx context.Context, args []string) ([]string, error) {
	if len(args) == 0 {
		// returning an empty slice is not allowed here, so report it instead
		return args, fmt.Errorf("zoa: empty command")
	}
	script, ok := ctx.Value("script").(string)
	if !ok {
		return args, fmt.Errorf("zoa: no script path in context")
	}
	command := args[0]

	fmt.Println()
	// "shell tip" commands should go here
	if command == "echo" {
		color.ZoaSay(`hewo! please prefer printf over echo where possible :3
example: printf '%s\n' "hello world!"`)
	}
	fmt.Printf("%s $ %s\n", script, strings.Join(args, " "))

	switch command {
	case "zoa-script":
		if len(args) <= 1 {
			return args, fmt.Errorf("zoa-script requires 1 argument: the script to run")
		}
		// TODO: actually run the sub-script (the earlier draft ran
		// filepath.Join("scripts", args[1]) through a nested session).
		// args has to have _something_ in it or we panic at calltime
		return []string{"true"}, nil
	case "zoa-file":
		if len(args) <= 2 {
			return args, fmt.Errorf("zoa-file requires 2+ arguments: zoa-file <source> <destination> <optional command>")
		}
		src := filepath.Join(filepath.Dir(script), "files", args[1])
		dst := args[2]
		// everything after <destination> is the optional follow-up command
		cmd := args[3:]
		dstChanged, err := zoaCopy(src, dst)
		if err != nil {
			// report through the normal error path (Run yells and moves on)
			// instead of killing the whole process from inside the handler
			return args, fmt.Errorf("zoa-file: %w", err)
		}
		// if the destination changed and we have
		// a command specified, we def want to run it
		if len(cmd) > 0 && dstChanged {
			return cmd, nil
		}
		return []string{"true"}, nil
	}

	return args, nil
}

// parseFile opens filename and parses its contents as a shell program.
//
// filename is handed to the parser as the source name, so parse errors
// include the file they came from instead of only a line:column position.
// If the file cannot be opened, an empty (non-nil) *syntax.File is returned
// together with the error. If parsing fails, the parser's result is returned
// as-is alongside the error. Callers should not use the result when err != nil.
func parseFile(filename string) (*syntax.File, error) {
	f, err := os.Open(filename)
	if err != nil {
		return &syntax.File{}, err
	}
	defer f.Close()
	return syntax.NewParser().Parse(f, filename)
}

// zoaCopy copies a file from zoaRoot/scripts/src to dst.
// if the dst was changed, zoaCopy will return true,
// otherwise it will return false
// zoaCopy defaults to 0666 for permissions (before umask)
// zoaCopy copies the file at src to dst unless dst already has the same
// checksum as src. src is a complete path (CallHandler resolves it relative
// to the script's "files" directory before calling).
//
// It returns true if dst was written and false if it was already up to date.
// In the up-to-date case it prints "<dst> was unchanged". A dst that cannot be
// checksummed (e.g. it does not exist yet) is treated as out of date. An error
// is returned if src cannot be checksummed or the copy itself fails.
// NOTE(review): the resulting file mode is whatever utils.Copy applies; an
// earlier comment claimed 0666 before umask — confirm against utils.Copy.
func zoaCopy(src string, dst string) (bool, error) {
	srcSum, err := utils.ChecksumFile(src)
	if err != nil {
		// source file should always exist, return error
		return false, fmt.Errorf("checksumming source %s: %w", src, err)
	}
	// dst may not exist for a million reasons; a blank checksum forces a
	// mismatch and therefore a copy
	dstSum, err := utils.ChecksumFile(dst)
	if err != nil {
		dstSum = ""
	}

	if srcSum == dstSum {
		fmt.Println(dst + " was unchanged")
		return false, nil
	}
	// TODO: pass the file through a templating engine
	// - expand vars
	// - loop support?

	if err := utils.Copy(src, dst); err != nil {
		return false, fmt.Errorf("copying %s to %s: %w", src, dst, err)
	}
	return true, nil
}

// The remote-execution helpers below (SSH, SCP, SCPDir) are rough early code; handle with care.

// SSH runs command on every host in hostlist using the system ssh client.
// Hosts are contacted one at a time, in order.
//
// The remote output is streamed to this process's stdout and stderr. SSH stops
// at the first host whose ssh invocation fails and returns that error,
// annotated with the host name; later hosts are not contacted. If no ssh
// binary is found on PATH, an error is returned (previously the process
// exited). An empty hostlist is a no-op.
func SSH(command string, hostlist []string) error {
	if _, err := exec.LookPath("ssh"); err != nil {
		return err
	}
	for _, server := range hostlist {
		cmd := exec.Command("ssh", server, command)
		cmd.Stdout = os.Stdout
		// without this, ssh's own diagnostics (auth failures etc.) are lost
		cmd.Stderr = os.Stderr
		if err := cmd.Run(); err != nil {
			return fmt.Errorf("ssh %s: %w", server, err)
		}
	}
	return nil
}

// SCP copies file to every host in hostlist using the system scp client.
// Hosts are handled one at a time, in order. The destination is "host:" with
// an empty remote path, which scp resolves to the remote user's home
// directory.
//
// scp's output is streamed to this process's stdout and stderr. SCP stops at
// the first host whose copy fails and returns that error, annotated with the
// host name; later hosts are skipped. If no scp binary is found on PATH, an
// error is returned (previously the process exited). An empty hostlist is a
// no-op.
func SCP(file string, hostlist []string) error {
	if _, err := exec.LookPath("scp"); err != nil {
		return err
	}
	for _, server := range hostlist {
		cmd := exec.Command("scp", file, server+":")
		cmd.Stdout = os.Stdout
		// without this, scp's own diagnostics are lost
		cmd.Stderr = os.Stderr
		if err := cmd.Run(); err != nil {
			return fmt.Errorf("scp %s to %s: %w", file, server, err)
		}
	}
	return nil
}

// SCPDir recursively copies dir (scp -r) to every host in hostlist. On each
// host the destination is the same path string, "host:<dir>". A relative dir
// is therefore resolved by scp against the remote user's home directory.
// Hosts are handled one at a time, in order.
//
// scp's output is streamed to this process's stdout and stderr. SCPDir stops
// at the first host whose copy fails and returns that error, annotated with
// the host name; later hosts are skipped. If no scp binary is found on PATH,
// an error is returned (previously the process exited). An empty hostlist is
// a no-op.
func SCPDir(dir string, hostlist []string) error {
	if _, err := exec.LookPath("scp"); err != nil {
		return err
	}
	for _, server := range hostlist {
		cmd := exec.Command("scp", "-r", dir, server+":"+dir)
		cmd.Stdout = os.Stdout
		// without this, scp's own diagnostics are lost
		cmd.Stderr = os.Stderr
		if err := cmd.Run(); err != nil {
			return fmt.Errorf("scp -r %s to %s: %w", dir, server, err)
		}
	}
	return nil
}