package archive

import (
	"archive/tar"
	"bytes"
	"io"
	"os"
	"os/exec"
	"path/filepath"
	"strings"
	"testing"

	"github.com/klauspost/compress/zstd"
)

// TestBuildSourceArchive_RoundTrip commits a plugin.mod and a .gitignore into
// a fresh repository, leaves an ignored, untracked ignored.log on disk, and
// checks that the archive produced by BuildSourceArchive decodes as a
// zstd-compressed tar containing plugin.mod but not ignored.log.
func TestBuildSourceArchive_RoundTrip(t *testing.T) {
	dir := t.TempDir()
	runGitArchive(t, dir, "init", "-q")

	writeArchiveFixture(t, dir, "plugin.mod", pluginModWithVersion("0.1.0"))
	writeArchiveFixture(t, dir, "ignored.log", "nope")
	writeArchiveFixture(t, dir, ".gitignore", "ignored.log\n")

	// ignored.log is deliberately left out of the index.
	runGitArchive(t, dir, "add", "plugin.mod", ".gitignore")
	runGitArchive(t, dir, "commit", "-qm", "init")

	zstdBytes, err := BuildSourceArchive(dir)
	if err != nil {
		t.Fatalf("BuildSourceArchive: %v", err)
	}
	if len(zstdBytes) == 0 {
		t.Fatal("empty archive")
	}

	got := readArchive(t, zstdBytes)
	if _, ok := got["plugin.mod"]; !ok {
		t.Errorf("expected plugin.mod in archive, got %v", keys(got))
	}
	if _, ok := got["ignored.log"]; ok {
		t.Errorf("ignored.log should not be in archive (gitignored + untracked)")
	}
}

// keys returns the keys of m. The order follows Go map iteration and is
// therefore unspecified; it is only used to make failure messages readable.
func keys(m map[string]string) []string {
	out := make([]string, 0, len(m))
	for k := range m {
		out = append(out, k)
	}
	return out
}

// TestBuildSourceArchive_DirtyTreeShipsWorkingCopy commits plugin.mod at
// version 0.1.0, rewrites it on disk to 0.1.1 without committing, and checks
// that the archive carries the uncommitted 0.1.1 contents while the file on
// disk is left exactly as the test wrote it.
func TestBuildSourceArchive_DirtyTreeShipsWorkingCopy(t *testing.T) {
	dir := t.TempDir()
	runGitArchive(t, dir, "init", "-q")

	modPath := filepath.Join(dir, "plugin.mod")
	writeArchiveFixture(t, dir, "plugin.mod", pluginModWithVersion("0.1.0"))
	runGitArchive(t, dir, "add", "plugin.mod")
	runGitArchive(t, dir, "commit", "-qm", "init")

	// Dirty the tracked file after the commit.
	dirtyContents := []byte(pluginModWithVersion("0.1.1"))
	writeArchiveFixture(t, dir, "plugin.mod", string(dirtyContents))

	zstdBytes, err := BuildSourceArchive(dir)
	if err != nil {
		t.Fatalf("BuildSourceArchive: %v", err)
	}

	got := readArchive(t, zstdBytes)
	contents, ok := got["plugin.mod"]
	if !ok {
		t.Fatalf("expected plugin.mod in archive, got %v", keys(got))
	}
	if !strings.Contains(contents, `version="0.1.1"`) {
		t.Errorf("archived plugin.mod should have dirty version 0.1.1, got: %q", contents)
	}
	if strings.Contains(contents, `version="0.1.0"`) {
		t.Errorf("archived plugin.mod should NOT have HEAD version 0.1.0, got: %q", contents)
	}

	// Working tree should be unchanged after stash-create.
	postContents, err := os.ReadFile(modPath)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(postContents, dirtyContents) {
		t.Errorf("working tree mutated after BuildSourceArchive\nwant: %q\ngot: %q", string(dirtyContents), string(postContents))
	}
}

// TestBuildSourceArchive_DirtyTreeOmitsUntracked dirties a committed
// plugin.mod, drops an untracked extra.txt next to it, and checks that
// extra.txt does not appear in the archive.
func TestBuildSourceArchive_DirtyTreeOmitsUntracked(t *testing.T) {
	dir := t.TempDir()
	runGitArchive(t, dir, "init", "-q")

	writeArchiveFixture(t, dir, "plugin.mod", pluginModWithVersion("0.1.0"))
	runGitArchive(t, dir, "add", "plugin.mod")
	runGitArchive(t, dir, "commit", "-qm", "init")

	// Dirty the tracked file.
	writeArchiveFixture(t, dir, "plugin.mod", pluginModWithVersion("0.1.1"))
	// Add an untracked file (no git add).
	writeArchiveFixture(t, dir, "extra.txt", "not tracked")

	zstdBytes, err := BuildSourceArchive(dir)
	if err != nil {
		t.Fatalf("BuildSourceArchive: %v", err)
	}

	got := readArchive(t, zstdBytes)
	if _, ok := got["extra.txt"]; ok {
		t.Errorf("untracked extra.txt should not be in archive, got %v", keys(got))
	}
}

// pluginModWithVersion returns the plugin.mod fixture body used by these
// tests, with the given value in its version field.
func pluginModWithVersion(version string) string {
	return "[plugin]\nname=\"x\"\nscope=\"@s\"\nversion=\"" + version + "\"\n"
}

// writeArchiveFixture writes contents to dir/name with mode 0o644, replacing
// any existing file, and fails the test immediately if the write fails.
func writeArchiveFixture(t *testing.T, dir, name, contents string) {
	t.Helper()
	if err := os.WriteFile(filepath.Join(dir, name), []byte(contents), 0o644); err != nil {
		t.Fatal(err)
	}
}

// runGitArchive runs `git args...` with dir as the working directory and
// fails the test, including git's combined output, if the command errors.
// The author and committer identity are supplied through the environment so
// that `git commit` does not depend on a configured user.name/user.email.
func runGitArchive(t *testing.T, dir string, args ...string) {
	t.Helper()
	cmd := exec.Command("git", args...)
	cmd.Dir = dir
	cmd.Env = append(os.Environ(),
		"GIT_AUTHOR_NAME=t",
		"GIT_AUTHOR_EMAIL=t@t",
		"GIT_COMMITTER_NAME=t",
		"GIT_COMMITTER_EMAIL=t@t",
	)
	out, err := cmd.CombinedOutput()
	if err != nil {
		t.Fatalf("git %v: %v\n%s", args, err, out)
	}
}

// readArchive decodes zstdBytes as a zstd-compressed tar stream and returns
// a map from each entry's header name to its full contents. Any decode error
// fails the test. If an entry name repeats, the later entry wins.
func readArchive(t *testing.T, zstdBytes []byte) map[string]string {
	t.Helper()
	dec, err := zstd.NewReader(bytes.NewReader(zstdBytes))
	if err != nil {
		t.Fatal(err)
	}
	defer dec.Close()

	tr := tar.NewReader(dec)
	got := map[string]string{}
	for {
		hdr, err := tr.Next()
		if err == io.EOF {
			// End of the tar stream.
			break
		}
		if err != nil {
			t.Fatal(err)
		}
		buf, err := io.ReadAll(tr)
		if err != nil {
			t.Fatal(err)
		}
		got[hdr.Name] = string(buf)
	}
	return got
}