Skip to content
288 changes: 288 additions & 0 deletions pkg/agent/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,15 @@
package agent

import (
"encoding/base64"
"encoding/json"
"os"
"os/exec"
"path/filepath"
"regexp"
"runtime"
"sort"
"strings"
"testing"

"github.com/Azure/agentbaker/pkg/agent/datamodel"
Expand Down Expand Up @@ -837,3 +845,283 @@ var _ = Describe("Test removeComments", func() {
})

})

// repoRoot locates the AgentBaker repository root without hard-coding an absolute path.
// Starting from the directory that contains this source file, it climbs one directory
// at a time and returns the first directory in which a go.mod file can be stat'ed.
//
// It panics if the caller's file path cannot be determined, or if the filesystem root
// is reached without finding go.mod.
func repoRoot() string {
	_, thisFile, _, ok := runtime.Caller(0)
	if !ok {
		panic("unable to determine test file path")
	}
	for cur := filepath.Dir(thisFile); ; cur = filepath.Dir(cur) {
		if _, statErr := os.Stat(filepath.Join(cur, "go.mod")); statErr == nil {
			return cur
		}
		// filepath.Dir is a fixed point at the filesystem root: nothing left to climb.
		if filepath.Dir(cur) == cur {
			panic("could not find repo root (go.mod)")
		}
	}
}

// TestRemoveComments_ShellPatterns tests removeComments against realistic shell script
// patterns that have historically caused issues, particularly patterns where '#' appears
// inside string literals or in non-comment contexts.
//
// TestRemoveComments_ShellPatterns validates that removeComments correctly handles
// various shell script patterns without breaking functional code.
//
// Background: removeComments is a "best-effort" comment stripper (utils.go:202) that runs on
// all CSE shell scripts before template execution. It must not mangle code that contains
// '#' characters in non-comment contexts (string literals, variable expansions, grep patterns).
func TestRemoveComments_ShellPatterns(t *testing.T) {
Comment thread
djsly marked this conversation as resolved.
Comment thread
djsly marked this conversation as resolved.
tests := []struct {
name string
input string
expected string
}{
{
name: "pure comment lines are removed",
input: strings.Join([]string{
"#!/bin/bash",
"# This is a comment",
"echo hello",
"## Another comment",
"echo world",
}, "\n"),
expected: strings.Join([]string{
"#!/bin/bash",
"echo hello",
"echo world",
}, "\n"),
},
{
name: "hash inside quoted grep pattern is preserved",
input: strings.Join([]string{
` if grep -q "^#${mod} " /proc/modules 2>/dev/null; then`,
` modprobe -r "$mod"`,
` fi`,
}, "\n"),
expected: strings.Join([]string{
` if grep -q "^#${mod} " /proc/modules 2>/dev/null; then`,
` modprobe -r "$mod"`,
` fi`,
}, "\n"),
},
Comment thread
djsly marked this conversation as resolved.
{
name: "trailing comments are trimmed but code is preserved",
input: strings.Join([]string{
` local mod="$1" # module name`,
` modprobe -r "$mod" # try to unload`,
}, "\n"),
expected: strings.Join([]string{
` local mod="$1" `,
` modprobe -r "$mod" `,
}, "\n"),
},
{
name: "shebang line is preserved",
input: "#!/bin/bash\nset -euo pipefail",
expected: "#!/bin/bash\nset -euo pipefail",
},
{
name: "hash in variable expansion is not a comment",
input: strings.Join([]string{
` local count=${#array[@]}`,
` echo "${str#prefix}"`,
` echo "${str##*/}"`,
}, "\n"),
expected: strings.Join([]string{
` local count=${#array[@]}`,
` echo "${str#prefix}"`,
` echo "${str##*/}"`,
}, "\n"),
},
{
// Documents the DOA regression from PR #8475: a line starting with "# "
// inside a multi-line printf format string gets stripped by removeComments,
// breaking the script. The fix (PR #8486) was to not emit "# " lines from
// code. This test asserts the current (known-limitation) behavior.
name: "line starting with hash-space is stripped even inside string context",
input: strings.Join([]string{
`myFunc() {`,
` local desc="$1"`,
` printf '# %s\ninstall %s /bin/false\n' "$desc" "$mod"`,
`}`,
}, "\n"),
expected: strings.Join([]string{
`myFunc() {`,
` local desc="$1"`,
` printf '`,
`}`,
}, "\n"),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := removeComments([]byte(tt.input))
if diff := cmp.Diff(tt.expected, string(result)); diff != "" {
t.Errorf("removeComments() mismatch (-want +got):\n%s", diff)
}
Comment thread
djsly marked this conversation as resolved.
})
}
}
Comment thread
djsly marked this conversation as resolved.

// TestCSEScriptRoundTrip pushes every discovered CSE shell script through the
// comment-stripping and encoding stages of the CSE assembly pipeline
// (removeComments → gzip → base64 → base64-decode → gunzip) and then checks:
//   - byte-for-byte round-trip integrity (decoded output == stripped input)
//   - bash -n syntax validity of the decoded output (catches broken scripts)
//
// These are the stages used by getBase64EncodedGzippedCustomScript() in
// pkg/agent/utils.go. Go template execution is deliberately left out because it needs
// a full NodeBootstrappingConfiguration. Comment stripping happens BEFORE template
// execution, so the stripped output must still be syntactically valid bash — a node
// cannot provision if any CSE script has a syntax error after stripping.
//
// The script list is not hard-coded: discoverCSEScripts derives it from the
// variables.go and const.go sources, so a script newly wired into the CSE pipeline is
// picked up by this test automatically.
func TestCSEScriptRoundTrip(t *testing.T) {
	scripts := discoverCSEScripts(t)
	if len(scripts) == 0 {
		t.Fatal("no CSE scripts discovered — check variables.go and const.go parsing")
	}
	t.Logf("discovered %d CSE shell scripts", len(scripts))

	// Discovered paths are relative to the repository's parts/ directory.
	partsDir := filepath.Join(repoRoot(), "parts")

	for i := range scripts {
		rel := scripts[i]
		t.Run(filepath.Base(rel), func(t *testing.T) {
			cseValidateBashSyntax(t, rel, cseRoundTrip(t, filepath.Join(partsDir, rel)))
		})
	}
}

// discoverCSEScripts returns the sorted, de-duplicated list of .sh file paths that are
// fed to getBase64EncodedGzippedCustomScript().
//
// It works purely on source text:
//  1. variables.go is scanned for getBase64EncodedGzippedCustomScript(<ident>, ...) calls
//     and the first-argument identifiers are collected.
//  2. const.go is scanned for `<ident> = "<value>"` assignments to map identifiers to
//     their string values.
//  3. Identifiers from step 1 are resolved through the map from step 2; only values
//     ending in ".sh" are kept.
//
// Identifiers that cannot be resolved in const.go are not fatal, but they are reported
// via t.Logf so that a script silently falling out of coverage is visible in the test
// output. Failure to read either source file aborts the test.
func discoverCSEScripts(t *testing.T) []string {
	t.Helper()
	agentDir := filepath.Join(repoRoot(), "pkg", "agent")

	// Step 1: extract constant names from getBase64EncodedGzippedCustomScript() calls.
	variablesBytes, err := os.ReadFile(filepath.Join(agentDir, "variables.go"))
	if err != nil {
		t.Fatalf("failed to read variables.go: %v", err)
	}

	// Match: getBase64EncodedGzippedCustomScript(constantName, config)
	callRe := regexp.MustCompile(`getBase64EncodedGzippedCustomScript\((\w+),`)
	constNames := make(map[string]struct{})
	for _, m := range callRe.FindAllStringSubmatch(string(variablesBytes), -1) {
		constNames[m[1]] = struct{}{}
	}

	// Step 2: build an identifier -> string value map from const.go.
	constBytes, err := os.ReadFile(filepath.Join(agentDir, "const.go"))
	if err != nil {
		t.Fatalf("failed to read const.go: %v", err)
	}

	// Match: constantName = "linux/cloud-init/artifacts/..."
	constRe := regexp.MustCompile(`(\w+)\s*=\s*"([^"]+)"`)
	constMap := make(map[string]string)
	for _, m := range constRe.FindAllStringSubmatch(string(constBytes), -1) {
		constMap[m[1]] = m[2]
	}

	// Step 3: resolve, keep only .sh files, and de-duplicate by path (several
	// identifiers may point at the same file).
	var scripts, unresolved []string
	seen := make(map[string]struct{})
	for name := range constNames {
		path, ok := constMap[name]
		if !ok {
			unresolved = append(unresolved, name)
			continue
		}
		if !strings.HasSuffix(path, ".sh") {
			continue
		}
		if _, dup := seen[path]; dup {
			continue
		}
		seen[path] = struct{}{}
		scripts = append(scripts, path)
	}

	if len(unresolved) > 0 {
		// Map iteration order is random; sort so the log line is stable across runs.
		sort.Strings(unresolved)
		t.Logf("skipping %d getBase64EncodedGzippedCustomScript argument(s) not resolvable in const.go: %s",
			len(unresolved), strings.Join(unresolved, ", "))
	}

	sort.Strings(scripts)
	return scripts
}

// cseRoundTrip loads the shell script at path and sends it through
// removeComments → gzip+base64 encode → base64 decode → gunzip.
//
// Any read or decode error aborts the test. A difference between the comment-stripped
// text and the decoded text is reported as a (non-fatal) test error. The decoded bytes
// are returned so the caller can run further checks on them.
func cseRoundTrip(t *testing.T, path string) []byte {
	t.Helper()

	source, readErr := os.ReadFile(path)
	if readErr != nil {
		t.Fatalf("failed to read %s: %v", path, readErr)
	}

	// The stripped text is both the pipeline input and the expected round-trip result.
	want := string(removeComments(source))

	compressed, b64Err := base64.StdEncoding.DecodeString(getBase64EncodedGzippedCustomScriptFromStr(want))
	if b64Err != nil {
		t.Fatalf("base64 decode failed: %v", b64Err)
	}

	got, gzErr := getGzipDecodedValue(compressed)
	if gzErr != nil {
		t.Fatalf("gzip decode failed: %v", gzErr)
	}

	if delta := cmp.Diff(want, string(got)); delta != "" {
		t.Errorf("round-trip mismatch (-stripped +decoded):\n%s", delta)
	}

	return got
}

// cseValidateBashSyntax runs `bash -O extglob -n` on the decoded script to catch syntax
// errors introduced by comment stripping.
//
// Parameters:
//   - script:  the script's path as discovered; used in messages and to name the temp file.
//   - decoded: the script content to check.
//
// Behavior:
//   - If decoded contains "{{" the check is skipped with a log line, since Go template
//     directives are not valid bash until the template has been executed.
//   - If bash is not on PATH the calling test is skipped.
//   - A non-zero exit from bash is reported with t.Errorf together with bash's output.
//
// The script is written into t.TempDir(), so the file is removed automatically when the
// test finishes and write/close failures are surfaced instead of being ignored.
func cseValidateBashSyntax(t *testing.T, script string, decoded []byte) {
	t.Helper()

	if strings.Contains(string(decoded), "{{") {
		t.Logf("skipping bash -n for %s (contains Go template directives)", script)
		return
	}

	bashPath, err := exec.LookPath("bash")
	if err != nil {
		t.Skip("bash not available, skipping syntax check")
	}

	// Name the temp file after the script so bash's diagnostics point at something recognizable.
	scriptPath := filepath.Join(t.TempDir(), "cse-roundtrip-"+filepath.Base(script))
	if err := os.WriteFile(scriptPath, decoded, 0o600); err != nil {
		t.Fatalf("failed to write temp file: %v", err)
	}

	// -n parses without executing; -O extglob enables extended glob syntax at parse time.
	output, err := exec.Command(bashPath, "-O", "extglob", "-n", scriptPath).CombinedOutput()
	if err != nil {
		t.Errorf("bash -n syntax check FAILED for %s after removeComments + round-trip:\n%s\n%s",
			script, string(output), err)
	}
}
Loading