diff --git a/internal/shellhook/shellhook.go b/internal/shellhook/shellhook.go index b703553..13c9714 100644 --- a/internal/shellhook/shellhook.go +++ b/internal/shellhook/shellhook.go @@ -10,12 +10,81 @@ import ( "strings" ) +// Shell identifies a shell type. +type Shell int + +const ( + ShellUnknown Shell = iota + ShellPowerShell + ShellBash + ShellZsh +) + +// detectShellEnv is the detection function variable, replaceable in tests. +var detectShellEnv = detectShell + +// detectShell returns the active shell based on environment heuristics. +func detectShell() Shell { + if os.Getenv("PSModulePath") != "" { + return ShellPowerShell + } + sh := os.Getenv("SHELL") + if strings.Contains(sh, "zsh") { + return ShellZsh + } + if strings.Contains(sh, "bash") { + return ShellBash + } + return ShellUnknown +} + +// primaryProfileIndex returns the index into candidates that corresponds to the +// detected shell's profile, or -1 if none matches. +func primaryProfileIndex(shell Shell, candidates []string) int { + for i, c := range candidates { + switch shell { + case ShellPowerShell: + if strings.HasSuffix(c, ".ps1") { + return i + } + case ShellZsh: + if strings.HasSuffix(c, ".zshrc") { + return i + } + case ShellBash: + if strings.HasSuffix(c, ".bashrc") { + return i + } + } + } + return -1 +} + +// ensureFileExists creates the file (and parent directories) if it does not +// already exist. Returns nil if the file already exists. +func ensureFileExists(path string) error { + if _, err := os.Stat(path); err == nil { + return nil + } + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + return err + } + f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return err + } + return f.Close() +} + // Result describes what happened to a single profile file. type Result struct { Path string Updated bool // false = already present (Add) or line not found (Remove) } +// currentOS is the GOOS value used by candidateProfiles, replaceable in tests. +var currentOS = runtime.GOOS + // homeDir is a variable so tests can substitute a temp directory. var homeDir = os.UserHomeDir @@ -26,7 +95,7 @@ func candidateProfiles() ([]string, error) { if err != nil { return nil, err } - return profilesForOS(runtime.GOOS, home), nil + return profilesForOS(currentOS, home), nil } // profilesForOS returns candidate profile paths for a given GOOS and home @@ -49,13 +118,20 @@ func profilesForOS(goos, home string) []string { } // Add appends comment and line to every existing profile that does not -// already contain line as a complete line. Returns one Result per -// candidate profile found on disk. +// already contain line as a complete line. The detected shell's profile +// is created if it does not exist. Returns one Result per candidate +// profile found on disk. func Add(line, comment string) ([]Result, error) { candidates, err := candidateProfiles() if err != nil { return nil, err } + shell := detectShellEnv() + if idx := primaryProfileIndex(shell, candidates); idx >= 0 { + if err := ensureFileExists(candidates[idx]); err != nil { + return nil, fmt.Errorf("creating profile %s: %w", candidates[idx], err) + } + } return addToProfiles(line, comment, candidates) } diff --git a/internal/shellhook/shellhook_test.go b/internal/shellhook/shellhook_test.go index 20c3f62..8581fae 100644 --- a/internal/shellhook/shellhook_test.go +++ b/internal/shellhook/shellhook_test.go @@ -336,3 +336,156 @@ func TestRemove_WindowsPowerShellPath(t *testing.T) { t.Error("hook should be removed from PowerShell profile") } } + +// --- Shell detection tests --- + +func TestDetectShell_PowerShell(t *testing.T) { + orig := detectShellEnv + defer func() { detectShellEnv = orig }() + + detectShellEnv = func() Shell { return ShellPowerShell } + if detectShellEnv() != ShellPowerShell { + t.Error("expected PowerShell") + } +} + +func TestDetectShell_Zsh(t *testing.T) { + orig := detectShellEnv + defer func() { detectShellEnv = orig }() + + detectShellEnv = func() Shell { return ShellZsh } + if detectShellEnv() != ShellZsh { + t.Error("expected Zsh") + } +} + +func TestDetectShell_Bash(t *testing.T) { + orig := detectShellEnv + defer func() { detectShellEnv = orig }() + + detectShellEnv = func() Shell { return ShellBash } + if detectShellEnv() != ShellBash { + t.Error("expected Bash") + } +} + +func TestPrimaryProfileIndex(t *testing.T) { + candidates := []string{ + "/home/user/Documents/PowerShell/Microsoft.PowerShell_profile.ps1", + "/home/user/.bashrc", + "/home/user/.zshrc", + } + if idx := primaryProfileIndex(ShellPowerShell, candidates); idx != 0 { + t.Errorf("PowerShell: want 0, got %d", idx) + } + if idx := primaryProfileIndex(ShellBash, candidates); idx != 1 { + t.Errorf("Bash: want 1, got %d", idx) + } + if idx := primaryProfileIndex(ShellZsh, candidates); idx != 2 { + t.Errorf("Zsh: want 2, got %d", idx) + } + if idx := primaryProfileIndex(ShellUnknown, candidates); idx != -1 { + t.Errorf("Unknown: want -1, got %d", idx) + } +} + +func TestEnsureFileExists_CreatesFileAndDirs(t *testing.T) { + dir := t.TempDir() + target := filepath.Join(dir, "sub", "dir", "profile.ps1") + + if err := ensureFileExists(target); err != nil { + t.Fatalf("unexpected error: %v", err) + } + info, err := os.Stat(target) + if err != nil { + t.Fatalf("file should exist: %v", err) + } + if info.Size() != 0 { + t.Errorf("newly created file should be empty, got %d bytes", info.Size()) + } +} + +func TestEnsureFileExists_ExistingFileUnchanged(t *testing.T) { + dir := t.TempDir() + target := filepath.Join(dir, ".zshrc") + if err := os.WriteFile(target, []byte("existing content\n"), 0644); err != nil { + t.Fatal(err) + } + + if err := ensureFileExists(target); err != nil { + t.Fatalf("unexpected error: %v", err) + } + data, _ := os.ReadFile(target) + if string(data) != "existing content\n" { + t.Error("existing file should not be modified") + } +} + +func TestAdd_CreatesProfileForDetectedShell(t *testing.T) { + dir := t.TempDir() + psPath := filepath.Join(dir, "Documents", "PowerShell", "Microsoft.PowerShell_profile.ps1") + + orig := detectShellEnv + defer func() { detectShellEnv = orig }() + detectShellEnv = func() Shell { return ShellPowerShell } + + origOS := currentOS + defer func() { currentOS = origOS }() + currentOS = "windows" + + // Profile doesn't exist yet — Add should create it + origHome := homeDir + defer func() { homeDir = origHome }() + homeDir = func() (string, error) { return dir, nil } + + results, err := Add("sap-devs tip", "# SAP developer tips") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + found := false + for _, r := range results { + if r.Path == psPath && r.Updated { + found = true + } + } + if !found { + t.Errorf("expected PowerShell profile to be created and updated, got %+v", results) + } + data, _ := os.ReadFile(psPath) + if !strings.Contains(string(data), "sap-devs tip") { + t.Error("hook should be present in newly created PowerShell profile") + } +} + +func TestAdd_DetectedShellZsh_CreatesZshrc(t *testing.T) { + dir := t.TempDir() + zshrc := filepath.Join(dir, ".zshrc") + + orig := detectShellEnv + defer func() { detectShellEnv = orig }() + detectShellEnv = func() Shell { return ShellZsh } + + origOS := currentOS + defer func() { currentOS = origOS }() + currentOS = "linux" + + origHome := homeDir + defer func() { homeDir = origHome }() + homeDir = func() (string, error) { return dir, nil } + + results, err := Add("sap-devs tip", "# SAP developer tips") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + found := false + for _, r := range results { + if r.Path == zshrc && r.Updated { + found = true + } + } + if !found { + t.Errorf("expected .zshrc to be created and updated, got %+v", results) + } +}