Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 79 additions & 3 deletions internal/shellhook/shellhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,81 @@ import (
"strings"
)

// Shell identifies a shell type.
type Shell int

const (
ShellUnknown Shell = iota
ShellPowerShell
ShellBash
ShellZsh
)

// detectShellEnv is the detection function variable, replaceable in tests.
var detectShellEnv = detectShell

// detectShell returns the active shell based on environment heuristics.
func detectShell() Shell {
if os.Getenv("PSModulePath") != "" {
return ShellPowerShell
}
sh := os.Getenv("SHELL")
if strings.Contains(sh, "zsh") {
return ShellZsh
}
if strings.Contains(sh, "bash") {
return ShellBash
}
return ShellUnknown
}

// primaryProfileIndex returns the index into candidates that corresponds to the
// detected shell's profile, or -1 if none matches.
func primaryProfileIndex(shell Shell, candidates []string) int {
for i, c := range candidates {
switch shell {
case ShellPowerShell:
if strings.HasSuffix(c, ".ps1") {
return i
}
case ShellZsh:
if strings.HasSuffix(c, ".zshrc") {
return i
}
case ShellBash:
if strings.HasSuffix(c, ".bashrc") {
return i
}
}
}
return -1
}

// ensureFileExists creates the file (and parent directories) if it does not
// already exist. Returns nil if the file already exists.
func ensureFileExists(path string) error {
if _, err := os.Stat(path); err == nil {
return nil
}
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
return err
}
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return err
}
return f.Close()
}

// Result describes what happened to a single profile file.
type Result struct {
Path string
Updated bool // false = already present (Add) or line not found (Remove)
}

// currentOS is the GOOS value used by candidateProfiles, replaceable in tests.
var currentOS = runtime.GOOS

// homeDir is a variable so tests can substitute a temp directory.
var homeDir = os.UserHomeDir

Expand All @@ -26,7 +95,7 @@ func candidateProfiles() ([]string, error) {
if err != nil {
return nil, err
}
return profilesForOS(runtime.GOOS, home), nil
return profilesForOS(currentOS, home), nil
}

// profilesForOS returns candidate profile paths for a given GOOS and home
Expand All @@ -49,13 +118,20 @@ func profilesForOS(goos, home string) []string {
}

// Add appends comment and line to every existing profile that does not
// already contain line as a complete line. Returns one Result per
// candidate profile found on disk.
// already contain line as a complete line. The detected shell's profile
// is created if it does not exist. Returns one Result per candidate
// profile found on disk.
func Add(line, comment string) ([]Result, error) {
candidates, err := candidateProfiles()
if err != nil {
return nil, err
}
shell := detectShellEnv()
if idx := primaryProfileIndex(shell, candidates); idx >= 0 {
if err := ensureFileExists(candidates[idx]); err != nil {
return nil, fmt.Errorf("creating profile %s: %w", candidates[idx], err)
}
}
return addToProfiles(line, comment, candidates)
}

Expand Down
153 changes: 153 additions & 0 deletions internal/shellhook/shellhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,3 +336,156 @@ func TestRemove_WindowsPowerShellPath(t *testing.T) {
t.Error("hook should be removed from PowerShell profile")
}
}

// --- Shell detection tests ---

func TestDetectShell_PowerShell(t *testing.T) {
orig := detectShellEnv
defer func() { detectShellEnv = orig }()

detectShellEnv = func() Shell { return ShellPowerShell }
if detectShellEnv() != ShellPowerShell {
t.Error("expected PowerShell")
}
}

func TestDetectShell_Zsh(t *testing.T) {
orig := detectShellEnv
defer func() { detectShellEnv = orig }()

detectShellEnv = func() Shell { return ShellZsh }
if detectShellEnv() != ShellZsh {
t.Error("expected Zsh")
}
}

func TestDetectShell_Bash(t *testing.T) {
orig := detectShellEnv
defer func() { detectShellEnv = orig }()

detectShellEnv = func() Shell { return ShellBash }
if detectShellEnv() != ShellBash {
t.Error("expected Bash")
}
}

func TestPrimaryProfileIndex(t *testing.T) {
candidates := []string{
"/home/user/Documents/PowerShell/Microsoft.PowerShell_profile.ps1",
"/home/user/.bashrc",
"/home/user/.zshrc",
}
if idx := primaryProfileIndex(ShellPowerShell, candidates); idx != 0 {
t.Errorf("PowerShell: want 0, got %d", idx)
}
if idx := primaryProfileIndex(ShellBash, candidates); idx != 1 {
t.Errorf("Bash: want 1, got %d", idx)
}
if idx := primaryProfileIndex(ShellZsh, candidates); idx != 2 {
t.Errorf("Zsh: want 2, got %d", idx)
}
if idx := primaryProfileIndex(ShellUnknown, candidates); idx != -1 {
t.Errorf("Unknown: want -1, got %d", idx)
}
}

func TestEnsureFileExists_CreatesFileAndDirs(t *testing.T) {
dir := t.TempDir()
target := filepath.Join(dir, "sub", "dir", "profile.ps1")

if err := ensureFileExists(target); err != nil {
t.Fatalf("unexpected error: %v", err)
}
info, err := os.Stat(target)
if err != nil {
t.Fatalf("file should exist: %v", err)
}
if info.Size() != 0 {
t.Errorf("newly created file should be empty, got %d bytes", info.Size())
}
}

func TestEnsureFileExists_ExistingFileUnchanged(t *testing.T) {
dir := t.TempDir()
target := filepath.Join(dir, ".zshrc")
if err := os.WriteFile(target, []byte("existing content\n"), 0644); err != nil {
t.Fatal(err)
}

if err := ensureFileExists(target); err != nil {
t.Fatalf("unexpected error: %v", err)
}
data, _ := os.ReadFile(target)
if string(data) != "existing content\n" {
t.Error("existing file should not be modified")
}
}

func TestAdd_CreatesProfileForDetectedShell(t *testing.T) {
dir := t.TempDir()
psPath := filepath.Join(dir, "Documents", "PowerShell", "Microsoft.PowerShell_profile.ps1")

orig := detectShellEnv
defer func() { detectShellEnv = orig }()
detectShellEnv = func() Shell { return ShellPowerShell }

origOS := currentOS
defer func() { currentOS = origOS }()
currentOS = "windows"

// Profile doesn't exist yet — Add should create it
origHome := homeDir
defer func() { homeDir = origHome }()
homeDir = func() (string, error) { return dir, nil }

results, err := Add("sap-devs tip", "# SAP developer tips")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

found := false
for _, r := range results {
if r.Path == psPath && r.Updated {
found = true
}
}
if !found {
t.Errorf("expected PowerShell profile to be created and updated, got %+v", results)
}
data, _ := os.ReadFile(psPath)
if !strings.Contains(string(data), "sap-devs tip") {
t.Error("hook should be present in newly created PowerShell profile")
}
}

func TestAdd_DetectedShellZsh_CreatesZshrc(t *testing.T) {
dir := t.TempDir()
zshrc := filepath.Join(dir, ".zshrc")

orig := detectShellEnv
defer func() { detectShellEnv = orig }()
detectShellEnv = func() Shell { return ShellZsh }

origOS := currentOS
defer func() { currentOS = origOS }()
currentOS = "linux"

origHome := homeDir
defer func() { homeDir = origHome }()
homeDir = func() (string, error) { return dir, nil }

results, err := Add("sap-devs tip", "# SAP developer tips")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

found := false
for _, r := range results {
if r.Path == zshrc && r.Updated {
found = true
}
}
if !found {
t.Errorf("expected .zshrc to be created and updated, got %+v", results)
}
}
Loading