small pixel drawing of a pufferfish zoa

main.go

package main

import (
	"bufio"
	"bytes"
	"context"
	"crypto/sha256"
	"fmt"
	"io"
	"log"
	"os"
	"path/filepath"
	"strings"
	"syscall"

	"mvdan.cc/sh/v3/expand"
	"mvdan.cc/sh/v3/interp"
	"mvdan.cc/sh/v3/syntax"
)

// ctx is the context every shell statement is run under.
// rootDir is the directory holding the "main" entrypoint and the scripts/
// directory that zoa-script and zoa-file resolve names against.
// NOTE(review): rootDir is hardcoded to "test/"; confirm whether zoa should
// instead look in the working directory.
var (
	ctx     = context.Background()
	rootDir = "test/"
)

// main runs the zoa entrypoint script, rootDir/main, one statement at a time
// through the embedded shell interpreter, using the environment built by
// generateEnv.
//
// NOTE(review): an earlier comment began describing a "zoa <repo> <branch>"
// invocation, but it was unfinished and no argument handling exists here yet.
func main() {
	// Running unprivileged is allowed, but warned about.
	if os.Geteuid() != 0 {
		fmt.Println("you are running zoa as a non-root user. this is not advised")
		fmt.Println("but you do you")
	}

	// TODO: this writer is responsible for the random stdout;
	// maybe save the stdout for debug mode somehow.
	runner, err := interp.New(interp.StdIO(nil, os.Stdout, os.Stderr))
	if err != nil {
		log.Fatal(err)
	}

	// Every command sees zoa's generated environment.
	// For debugging: fmt.Printf("%+v", runner.Env)
	if runner.Env, err = generateEnv(); err != nil {
		log.Fatal(err)
	}

	runCommands(filepath.Join(rootDir, "main"), runner)
}

// generateEnv builds the fixed environment every zoa command runs with.
//
// Variables exposed:
//   - PATH: hardcoded to the FHS binary directories (see below)
//   - OS, RELEASE, ARCH: the sysname, release and machine fields of uname(2)
//   - HOSTNAME: from os.Hostname
//   - OS_RELEASE_ID, OS_RELEASE_VERSION_ID: scraped by getOSRelease; set but
//     empty when no os-release file exists on this system
//
// It returns an error if uname, the hostname lookup, or reading an existing
// os-release file fails.
func generateEnv() (expand.Environ, error) {
	// syscall.Uname is backed by POSIX uname(2).
	// NOTE(review): in Go's standard syscall package, Uname appears to be
	// provided on Linux only (golang.org/x/sys/unix offers it more widely);
	// confirm before relying on this on other *nix systems.
	var uname syscall.Utsname
	if err := syscall.Uname(&uname); err != nil {
		return nil, fmt.Errorf("uname: %w", err)
	}

	// $PATH is annoyingly non-standard, so it is hardcoded to the binary
	// paths described in the FHS. In most cases this will "just work" for
	// people; special cases may need adjustment in the future, but it is
	// deliberately not overridable. Users can always call their special
	// binaries by their full paths.
	const fhsPath = "/usr/local/sbin:/usr/local/bin:/sbin:/bin:/usr/sbin:/usr/bin"

	// Standards are extremely annoying about hostnames: nothing guarantees
	// that the hostname set by sethostname(2) equals uname's nodename (some
	// systems allow a 256-byte hostname but an 8-byte nodename), although
	// they are the same on Linux. In practice they have been identical in
	// every case observed, so only HOSTNAME is exposed, for simplicity. It
	// comes from os.Hostname, which may itself call uname. If this ever
	// becomes an issue, revisit it.
	hostname, err := os.Hostname()
	if err != nil {
		return nil, fmt.Errorf("hostname: %w", err)
	}

	// !OS_RELEASE_* VARS ARE NOT STANDARDS-BACKED!
	// They are scraped from os-release and may or may not exist depending on
	// the distro, so they are not reliable _at all_. They are useful for
	// identifying specific Linux distros or their versions. Scripts relying
	// on them should check for them with `test -z` first; there be no
	// standards here.
	osRelease, err := getOSRelease()
	if err != nil {
		if !os.IsNotExist(err) {
			return nil, fmt.Errorf("os-release: %w", err)
		}
		// No os-release file on this system: export the vars empty instead
		// of aborting, which is what the `test -z` advice above relies on.
		osRelease = OSRelease{}
	}

	return expand.ListEnviron(
		envString("PATH", fhsPath), // normie stuff
		// uname-derived vars
		envString("OS", charsToString(uname.Sysname[:])),
		envString("HOSTNAME", hostname),
		envString("RELEASE", charsToString(uname.Release[:])),
		envString("ARCH", charsToString(uname.Machine[:])),
		// os-release-derived vars
		envString("OS_RELEASE_ID", osRelease.ID),
		envString("OS_RELEASE_VERSION_ID", osRelease.VersionID),
	), nil
}

// envString formats key and value as a single "KEY=value" environment entry,
// the form passed to expand.ListEnviron. Neither part is escaped or validated.
func envString(key string, value string) string {
	return key + "=" + value
}

// OSRelease holds the subset of os-release fields zoa exposes to scripts.
//
// I want to keep this list as small as possible, since this data is
// unreliable across distros.
// design principle: only 1 way to do common things
type OSRelease struct {
	ID        string // distro identifier, e.g. "arch"
	VersionID string // distro version, e.g. "22.04" on Ubuntu; optional, so may be empty
}

// osReleasePaths lists the os-release locations in priority order. Per
// os-release(5), /etc/os-release takes precedence, and /usr/lib/os-release is
// only consulted when the former does not exist.
var osReleasePaths = []string{"/etc/os-release", "/usr/lib/os-release"}

// getOSRelease parses the first existing file in osReleasePaths into an
// OSRelease.
//
// If none of the files exist, the returned error is the unwrapped error from
// os.Open, so os.IsNotExist(err) reports true. Any other open or read error
// is returned as-is, and the OSRelease is then the zero value.
func getOSRelease() (OSRelease, error) {
	var lastErr error = os.ErrNotExist
	for _, path := range osReleasePaths {
		osr, err := parseOSReleaseFile(path)
		if os.IsNotExist(err) {
			// Missing is expected; try the next candidate.
			lastErr = err
			continue
		}
		return osr, err
	}
	return OSRelease{}, lastErr
}

// parseOSReleaseFile reads one os-release file and extracts ID and VERSION_ID.
func parseOSReleaseFile(path string) (OSRelease, error) {
	f, err := os.Open(path)
	if err != nil {
		return OSRelease{}, err
	}
	defer f.Close()

	var osr OSRelease
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		key, value, ok := strings.Cut(scanner.Text(), "=")
		if !ok {
			// Lines without an assignment (blank lines, most comments).
			continue
		}
		switch key {
		case "ID":
			osr.ID = unquoteOSReleaseValue(value)
		case "VERSION_ID":
			osr.VersionID = unquoteOSReleaseValue(value)
		}
	}
	if err := scanner.Err(); err != nil {
		return OSRelease{}, fmt.Errorf("reading %s: %w", path, err)
	}
	return osr, nil
}

// unquoteOSReleaseValue strips one matching pair of surrounding double or
// single quotes; os-release allows either. Shell escapes inside the value are
// not interpreted.
func unquoteOSReleaseValue(v string) string {
	if len(v) >= 2 && (v[0] == '"' || v[0] == '\'') && v[len(v)-1] == v[0] {
		return v[1 : len(v)-1]
	}
	return v
}

// lastScriptPath remembers which script produced the most recent output, so
// the "-> <script>" header is only printed when the script changes.
var lastScriptPath string

// runCommands parses the script at scriptPath and runs it on r one top-level
// statement at a time, echoing each statement as "$ <stmt>".
//
// Two directives are handled by zoa itself instead of the shell:
//   - zoa-script <name>: recursively runs rootDir/scripts/<name> on r.
//   - zoa-file <src> <dst> [cmd...]: installs rootDir/scripts/<src> at <dst>
//     via zoaCopy and, only if <dst> changed, runs the optional cmd.
//
// A parse error, a failing statement, or a zoa-file failure terminates the
// process with status 1.
func runCommands(scriptPath string, r *interp.Runner) {
	script, err := parseFile(scriptPath)
	if err != nil {
		fmt.Println("error in " + scriptPath)
		fmt.Println(err)
		os.Exit(1)
	}

	for _, stmt := range script.Stmts {
		cmdName := commandName(stmt)
		command, args, _ := strings.Cut(cmdName, " ")

		switch command {
		case "zoa-script":
			// recursion detected!! :3
			// The sub-script prints its own header for its own commands.
			runCommands(filepath.Join(rootDir, "scripts", args), r)
		case "zoa-file":
			printScriptHeader(scriptPath)
			runZoaFile(args, r)
			// Fall through to the next statement; previously this returned
			// and silently skipped the rest of the script.
		default:
			printScriptHeader(scriptPath)
			fmt.Printf("$ %s\n", cmdName)
			if err := r.Run(ctx, stmt); err != nil {
				// ignore err here bc it's just the status code
				os.Exit(1)
			}
		}
	}
}

// printScriptHeader prints "-> scriptPath" in blue if the previous output
// came from a different script.
func printScriptHeader(scriptPath string) {
	if scriptPath == lastScriptPath {
		return
	}
	bluePrintln("-> " + scriptPath)
	lastScriptPath = scriptPath
}

// runZoaFile implements the zoa-file directive. args is everything after
// "zoa-file", e.g. "nginx /etc/nginx/conf.d/jesse.conf systemctl nginx reload".
// Arguments are split on single spaces, so quoted arguments containing spaces
// are not supported.
func runZoaFile(args string, r *interp.Runner) {
	parts := strings.Split(args, " ")
	if len(parts) < 2 {
		log.Fatal("zoa-file requires 2+ arguments")
	}
	src, dst := parts[0], parts[1]
	optionalCmd := strings.Join(parts[2:], " ")

	fmt.Printf("$ zoa-file %s %s\n", src, dst)
	dstChanged, err := zoaCopy(src, dst)
	if err != nil {
		log.Fatal(err)
	}

	// The follow-up command only runs when dst was actually rewritten.
	if optionalCmd == "" || !dstChanged {
		return
	}
	f, err := syntax.NewParser().Parse(strings.NewReader(optionalCmd), "")
	if err != nil {
		log.Fatal(err)
	}
	for _, stmt := range f.Stmts {
		runCommand(ctx, stmt, r)
	}
}

// zoaCopy installs the file rootDir/scripts/src at dst, skipping the write
// when both files already have the same SHA-256 checksum.
//
// It returns true if dst was (re)written, and false if it was already
// identical to the source. A dst that is missing or cannot be read is treated
// as out of date. The write goes through Copy, which creates new files with
// mode 0666 (before umask).
func zoaCopy(src string, dst string) (bool, error) {
	src = filepath.Join(rootDir, "scripts", src)
	srcChk, err := checksumFile(src)
	if err != nil {
		// The source file should always exist, so this is a real error.
		return false, err
	}

	// dst may not exist for a million reasons. An empty checksum can never
	// equal a real one, so any read failure forces a copy.
	dstChk, err := checksumFile(dst)
	if err != nil {
		dstChk = ""
	}

	if srcChk == dstChk {
		fmt.Println("file unchanged")
		return false, nil
	}

	// TODO: pass the file through a templating engine
	// - expand vars
	// - loop support?
	if err := Copy(src, dst); err != nil {
		return false, fmt.Errorf("copying %s to %s: %w", src, dst, err)
	}
	return true, nil
}

// Copy writes the contents of src to dst. If dst does not exist, it is
// created with mode 0666 (before umask); otherwise it is truncated first.
// src is opened before dst is touched, so a missing src leaves dst alone.
// It is currently only used by zoaCopy.
func Copy(src, dst string) error {
	source, err := os.Open(src)
	if err != nil {
		return err
	}
	defer source.Close()

	target, err := os.Create(dst)
	if err != nil {
		return err
	}
	// Safety net for the early-return path. On success, the explicit Close
	// below runs first, and this second Close's error is discarded.
	defer target.Close()

	if _, err = io.Copy(target, source); err != nil {
		return err
	}
	// Close explicitly so that a failed flush on close reaches the caller.
	return target.Close()
}

// checksumFile returns the lowercase hex SHA-256 digest of the named file's
// contents. The file is streamed through the hash rather than read into
// memory at once.
func checksumFile(file string) (string, error) {
	fh, err := os.Open(file)
	if err != nil {
		return "", err
	}
	defer fh.Close()

	hasher := sha256.New()
	_, err = io.Copy(hasher, fh)
	if err != nil {
		return "", err
	}
	return fmt.Sprintf("%x", hasher.Sum(nil)), nil
}

// runCommand echoes s as "$ <stmt>" and then runs it on r under context c.
// If the statement fails, the process exits with status 1; the error itself
// is not printed.
func runCommand(c context.Context, s *syntax.Stmt, r *interp.Runner) {
	fmt.Printf("$ %s\n", commandName(s))
	if runErr := r.Run(c, s); runErr != nil {
		os.Exit(1)
	}
}

// commandName renders a parsed statement back into shell source text using a
// default syntax.Printer. The result is used both for echoing commands and
// for recognizing zoa-* directives by their first word.
func commandName(statement *syntax.Stmt) string {
	var buf bytes.Buffer
	// Print's error is deliberately ignored. A *syntax.Stmt is a supported
	// node, and writes to a bytes.Buffer do not fail.
	_ = syntax.NewPrinter().Print(&buf, statement)
	return buf.String()
}

// parseFile opens and parses the shell script at filename.
//
// The filename is handed to the parser, so syntax errors are reported with a
// "file:line:col" position instead of a bare "line:col". When err is non-nil,
// the returned *syntax.File must not be used; it is nil if the file could
// not be opened.
func parseFile(filename string) (*syntax.File, error) {
	f, err := os.Open(filename)
	if err != nil {
		return nil, err
	}
	defer f.Close()
	return syntax.NewParser().Parse(f, filename)
}

// bluePrintln prints s to stdout in a blue foreground (ANSI SGR code 34),
// resets attributes (SGR 0), and then prints a newline.
func bluePrintln(s string) {
	const blue = 34
	painted := fmt.Sprintf("\x1b[%dm%s\x1b[0m", blue, s)
	fmt.Println(painted)
}

// runScript takes a file & runs individual
// commands from that file, prepending the decorator
// and returning the first error
// func runScript(file *syntax.File) error {
// fmt.Printf("%s%s\n", decorator, output)
// return nil
// }

// charsToString converts a NUL-terminated C char array, such as the fields of
// syscall.Utsname, into a Go string. Conversion stops at the first NUL; if
// there is none, the whole array is used.
//
// It is generic over int8 and uint8 because the element type of
// syscall.Utsname's fields differs between architectures (int8 on
// linux/amd64, but uint8 on some other Linux ports such as linux/arm), so
// a []int8-only signature does not compile everywhere.
func charsToString[T ~int8 | ~uint8](arr []T) string {
	b := make([]byte, 0, len(arr))
	for _, c := range arr {
		if c == 0 {
			break
		}
		// For int8, negative values map to their two's-complement byte,
		// preserving the raw bytes of non-ASCII names.
		b = append(b, byte(c))
	}
	return string(b)
}