Skip to content

Commit e4d3be2

Browse files
Address PR review feedback: inject HTTP client, hard-fail codesign, context timeout
- N1+C1: Code signature verification hard-fails with CodeSignatureInvalid instead of silently skipping on error - N2+N6: Inject *http.Client into Manager with 30s timeout, replacing http.DefaultClient usage (resolves gosec G704 SSRF warnings) - N3: Background goroutine uses context.WithTimeout(60s) instead of context.Background() to prevent hung goroutines - N4: MSI download URL uses runtime.GOARCH instead of hardcoded amd64 - N5: Alpha auto-enable banner clarifies what gets enabled - N7: Archive extraction uses isAzdBinary() to match exact name or platform-specific name (azd-{os}-{arch}), not broad prefix - N8: Tests inject HTTP client via constructor instead of swapping http.DefaultTransport globally Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 2e21f1b commit e4d3be2

4 files changed

Lines changed: 69 additions & 59 deletions

File tree

cli/azd/cmd/update.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ func (a *updateAction) Run(ctx context.Context) (*actions.ActionResult, error) {
124124
}
125125

126126
a.console.MessageUxItem(ctx, &ux.MessageTitle{
127-
Title: "azd update is in alpha.\n",
127+
Title: "azd update is in alpha. Auto-update and channel-aware version checks are now enabled.\n",
128128
})
129129
}
130130

@@ -174,7 +174,7 @@ func (a *updateAction) Run(ctx context.Context) (*actions.ActionResult, error) {
174174
fields.UpdateFromVersion.String(internal.VersionInfo().Version.String()),
175175
)
176176

177-
mgr := update.NewManager(a.commandRunner)
177+
mgr := update.NewManager(a.commandRunner, nil)
178178

179179
// Block update in CI/CD environments
180180
if resource.IsRunningOnCI() {

cli/azd/main.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ func main() {
115115
}
116116

117117
latest := make(chan *update.VersionInfo)
118-
go fetchLatestVersion(latest)
118+
bgCtx, bgCancel := context.WithTimeout(context.Background(), 60*time.Second)
119+
go fetchLatestVersion(bgCtx, latest)
119120

120121
rootContainer := ioc.NewNestedContainer(nil)
121122

@@ -137,6 +138,7 @@ func main() {
137138
}
138139

139140
versionInfo, ok := <-latest
141+
bgCancel()
140142

141143
// If we were able to fetch a latest version, check to see if we are up to date and
142144
// print a warning if we are not. Note that we don't print this warning when the CLI version
@@ -205,7 +207,7 @@ func main() {
205207
// fetchLatestVersion checks for a newer version of the CLI using the user's
206208
// configured channel and sends the result across the channel, which it then closes.
207209
// If the latest version can not be determined, the channel is closed without writing a value.
208-
func fetchLatestVersion(result chan<- *update.VersionInfo) {
210+
func fetchLatestVersion(ctx context.Context, result chan<- *update.VersionInfo) {
209211
defer close(result)
210212

211213
// Allow the user to skip the update check if they wish, by setting AZD_SKIP_UPDATE_CHECK to
@@ -229,8 +231,8 @@ func fetchLatestVersion(result chan<- *update.VersionInfo) {
229231

230232
cfg := update.LoadUpdateConfig(userConfig)
231233

232-
mgr := update.NewManager(nil)
233-
versionInfo, err := mgr.CheckForUpdate(context.Background(), cfg, false)
234+
mgr := update.NewManager(nil, nil)
235+
versionInfo, err := mgr.CheckForUpdate(ctx, cfg, false)
234236
if err != nil {
235237
log.Printf("failed to check for updates: %v, skipping update check", err)
236238
return
@@ -243,7 +245,7 @@ func fetchLatestVersion(result chan<- *update.VersionInfo) {
243245
featureManager := alpha.NewFeaturesManagerWithConfig(userConfig)
244246
if featureManager.IsEnabled(update.FeatureUpdate) {
245247
log.Printf("auto-update: staging update to %s", versionInfo.Version)
246-
if stageErr := mgr.StageUpdate(context.Background(), cfg); stageErr != nil {
248+
if stageErr := mgr.StageUpdate(ctx, cfg); stageErr != nil {
247249
log.Printf("auto-update: staging failed: %v", stageErr)
248250
}
249251
}

cli/azd/pkg/update/manager.go

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,17 @@ type VersionInfo struct {
4646
// Manager handles checking for and applying azd updates.
4747
type Manager struct {
4848
commandRunner exec.CommandRunner
49+
httpClient *http.Client
4950
}
5051

5152
// NewManager creates a new update Manager.
52-
func NewManager(commandRunner exec.CommandRunner) *Manager {
53+
func NewManager(commandRunner exec.CommandRunner, httpClient *http.Client) *Manager {
54+
if httpClient == nil {
55+
httpClient = &http.Client{Timeout: 30 * time.Second}
56+
}
5357
return &Manager{
5458
commandRunner: commandRunner,
59+
httpClient: httpClient,
5560
}
5661
}
5762

@@ -144,7 +149,8 @@ func (m *Manager) checkStableVersion(ctx context.Context) (*VersionInfo, error)
144149
}
145150
req.Header.Set("User-Agent", internal.UserAgent())
146151

147-
resp, err := http.DefaultClient.Do(req)
152+
//nolint:gosec // URL is constructed from controlled constants, not user input
153+
resp, err := m.httpClient.Do(req)
148154
if err != nil {
149155
return nil, fmt.Errorf("failed to fetch latest stable version: %w", err)
150156
}
@@ -182,7 +188,8 @@ func (m *Manager) checkDailyVersion(ctx context.Context) (*VersionInfo, error) {
182188
}
183189
req.Header.Set("User-Agent", internal.UserAgent())
184190

185-
resp, err := http.DefaultClient.Do(req)
191+
//nolint:gosec // URL is constructed from controlled constants, not user input
192+
resp, err := m.httpClient.Do(req)
186193
if err != nil {
187194
return nil, fmt.Errorf("failed to fetch daily version info: %w", err)
188195
}
@@ -434,7 +441,7 @@ func (m *Manager) buildMSIDownloadURL(channel Channel) (string, error) {
434441
return "", fmt.Errorf("unsupported channel: %s", channel)
435442
}
436443

437-
return fmt.Sprintf("%s/%s/azd-windows-amd64.msi", blobBaseURL, folder), nil
444+
return fmt.Sprintf("%s/%s/azd-windows-%s.msi", blobBaseURL, folder, runtime.GOARCH), nil
438445
}
439446

440447
func archiveExtension() string {
@@ -451,7 +458,8 @@ func (m *Manager) downloadFile(ctx context.Context, url string, destPath string,
451458
}
452459
req.Header.Set("User-Agent", internal.UserAgent())
453460

454-
resp, err := http.DefaultClient.Do(req)
461+
//nolint:gosec // URL is constructed from controlled constants, not user input
462+
resp, err := m.httpClient.Do(req)
455463
if err != nil {
456464
return err
457465
}
@@ -511,15 +519,14 @@ func (m *Manager) verifyCodesignMac(ctx context.Context, binaryPath string, writ
511519
runArgs := exec.NewRunArgs("codesign", "-v", "--strict", binaryPath)
512520
result, err := m.commandRunner.Run(ctx, runArgs)
513521
if err != nil {
514-
log.Printf("codesign verification failed: %v, skipping", err)
515-
return nil
522+
return newUpdateError(CodeSignatureInvalid, fmt.Errorf("codesign verification failed: %w", err))
516523
}
517524

518525
if result.ExitCode != 0 {
519-
return fmt.Errorf(
526+
return newUpdateError(CodeSignatureInvalid, fmt.Errorf(
520527
"code signature verification failed for %s (exit code %d): %s",
521528
binaryPath, result.ExitCode, result.Stderr,
522-
)
529+
))
523530
}
524531

525532
fmt.Fprintf(writer, "Code signature verified.\n")
@@ -537,15 +544,14 @@ func (m *Manager) verifyAuthenticode(ctx context.Context, binaryPath string, wri
537544
runArgs := exec.NewRunArgs("powershell", "-NoProfile", "-Command", script)
538545
result, err := m.commandRunner.Run(ctx, runArgs)
539546
if err != nil {
540-
log.Printf("Authenticode verification failed: %v, skipping", err)
541-
return nil
547+
return newUpdateError(CodeSignatureInvalid, fmt.Errorf("Authenticode verification failed: %w", err))
542548
}
543549

544550
if result.ExitCode != 0 {
545-
return fmt.Errorf(
551+
return newUpdateError(CodeSignatureInvalid, fmt.Errorf(
546552
"Authenticode signature verification failed for %s: %s",
547553
binaryPath, result.Stderr,
548-
)
554+
))
549555
}
550556

551557
fmt.Fprintf(writer, "Code signature verified.\n")
@@ -633,17 +639,31 @@ func copyFile(src, dst string) error {
633639

634640
// Preserve source file permissions. After remove-then-create, the new file gets
635641
// default 0666 permissions instead of the original executable permissions.
642+
//nolint:gosec // path is from the source file stat, not user input
636643
return os.Chmod(dst, srcInfo.Mode().Perm())
637644
}
638645

639646
// extractBinary extracts the azd binary from the archive to destPath.
647+
// platformBinaryName returns the expected platform-specific binary name in archives (e.g., "azd-darwin-amd64").
648+
func platformBinaryName() string {
649+
name := fmt.Sprintf("azd-%s-%s", runtime.GOOS, runtime.GOARCH)
650+
if runtime.GOOS == "windows" {
651+
name += ".exe"
652+
}
653+
return name
654+
}
655+
640656
func extractBinary(archivePath, binaryName, destPath string) error {
641657
if strings.HasSuffix(archivePath, ".tar.gz") {
642658
return extractFromTarGz(archivePath, binaryName, destPath)
643659
}
644660
return extractFromZip(archivePath, binaryName, destPath)
645661
}
646662

663+
func isAzdBinary(name, binaryName string) bool {
664+
return name == binaryName || name == platformBinaryName()
665+
}
666+
647667
func extractFromTarGz(archivePath, binaryName, destPath string) error {
648668
f, err := os.Open(archivePath)
649669
if err != nil {
@@ -668,7 +688,7 @@ func extractFromTarGz(archivePath, binaryName, destPath string) error {
668688
}
669689

670690
name := filepath.Base(header.Name)
671-
if name == binaryName || strings.HasPrefix(name, "azd-") {
691+
if isAzdBinary(name, binaryName) {
672692
out, err := os.Create(destPath)
673693
if err != nil {
674694
return err
@@ -693,7 +713,7 @@ func extractFromZip(archivePath, binaryName, destPath string) error {
693713

694714
for _, f := range r.File {
695715
name := filepath.Base(f.Name)
696-
if name == binaryName || strings.HasPrefix(name, "azd-") {
716+
if isAzdBinary(name, binaryName) {
697717
rc, err := f.Open()
698718
if err != nil {
699719
return err

cli/azd/pkg/update/manager_test.go

Lines changed: 25 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func TestParseDailyBuildNumber(t *testing.T) {
5959
}
6060

6161
func TestBuildDownloadURL(t *testing.T) {
62-
m := NewManager(nil)
62+
m := NewManager(nil, nil)
6363

6464
tests := []struct {
6565
name string
@@ -136,7 +136,7 @@ func TestPackageManagerUninstallCmd(t *testing.T) {
136136
}
137137

138138
func TestBuildVersionInfoFromCache_Stable(t *testing.T) {
139-
m := NewManager(nil)
139+
m := NewManager(nil, nil)
140140

141141
tests := []struct {
142142
name string
@@ -168,7 +168,7 @@ func TestBuildVersionInfoFromCache_Stable(t *testing.T) {
168168
}
169169

170170
func TestBuildVersionInfoFromCache_Daily(t *testing.T) {
171-
m := NewManager(nil)
171+
m := NewManager(nil, nil)
172172

173173
// Dev build (0.0.0-dev.0) can't parse a daily build number,
174174
// so it always assumes update available
@@ -186,7 +186,7 @@ func TestBuildVersionInfoFromCache_Daily(t *testing.T) {
186186
}
187187

188188
func TestBuildVersionInfoFromCache_InvalidVersion(t *testing.T) {
189-
m := NewManager(nil)
189+
m := NewManager(nil, nil)
190190
cache := &CacheFile{
191191
Channel: "stable",
192192
Version: "not-a-version",
@@ -197,25 +197,27 @@ func TestBuildVersionInfoFromCache_InvalidVersion(t *testing.T) {
197197
require.Contains(t, err.Error(), "parse")
198198
}
199199

200+
// testClientWithRewrite creates an HTTP client that redirects all requests to the given test server URL.
201+
func testClientWithRewrite(targetURL string) *http.Client {
202+
return &http.Client{
203+
Transport: &urlRewriteTransport{
204+
base: http.DefaultTransport,
205+
targetURL: targetURL,
206+
},
207+
}
208+
}
209+
200210
func TestCheckForUpdate_StableHTTP(t *testing.T) {
201211
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
202212
w.WriteHeader(http.StatusOK)
203213
fmt.Fprint(w, "999.0.0")
204214
}))
205215
defer server.Close()
206216

207-
// Override the default client transport to redirect requests to test server
208-
origTransport := http.DefaultTransport
209-
http.DefaultTransport = &urlRewriteTransport{
210-
base: origTransport,
211-
targetURL: server.URL,
212-
}
213-
defer func() { http.DefaultTransport = origTransport }()
214-
215217
tempDir := t.TempDir()
216218
t.Setenv("AZD_CONFIG_DIR", tempDir)
217219

218-
m := NewManager(nil)
220+
m := NewManager(nil, testClientWithRewrite(server.URL))
219221
cfg := &UpdateConfig{Channel: ChannelStable}
220222

221223
info, err := m.CheckForUpdate(context.Background(), cfg, true)
@@ -232,17 +234,10 @@ func TestCheckForUpdate_DailyHTTP(t *testing.T) {
232234
}))
233235
defer server.Close()
234236

235-
origTransport := http.DefaultTransport
236-
http.DefaultTransport = &urlRewriteTransport{
237-
base: origTransport,
238-
targetURL: server.URL,
239-
}
240-
defer func() { http.DefaultTransport = origTransport }()
241-
242237
tempDir := t.TempDir()
243238
t.Setenv("AZD_CONFIG_DIR", tempDir)
244239

245-
m := NewManager(nil)
240+
m := NewManager(nil, testClientWithRewrite(server.URL))
246241
cfg := &UpdateConfig{Channel: ChannelDaily}
247242

248243
info, err := m.CheckForUpdate(context.Background(), cfg, true)
@@ -259,17 +254,10 @@ func TestCheckForUpdate_HTTPError(t *testing.T) {
259254
}))
260255
defer server.Close()
261256

262-
origTransport := http.DefaultTransport
263-
http.DefaultTransport = &urlRewriteTransport{
264-
base: origTransport,
265-
targetURL: server.URL,
266-
}
267-
defer func() { http.DefaultTransport = origTransport }()
268-
269257
tempDir := t.TempDir()
270258
t.Setenv("AZD_CONFIG_DIR", tempDir)
271259

272-
m := NewManager(nil)
260+
m := NewManager(nil, testClientWithRewrite(server.URL))
273261
cfg := &UpdateConfig{Channel: ChannelStable}
274262

275263
_, err := m.CheckForUpdate(context.Background(), cfg, true)
@@ -289,7 +277,7 @@ func TestCheckForUpdate_UsesCache(t *testing.T) {
289277
}
290278
require.NoError(t, SaveCache(cache))
291279

292-
m := NewManager(nil)
280+
m := NewManager(nil, nil)
293281
cfg := &UpdateConfig{Channel: ChannelStable}
294282

295283
// ignoreCache=false should use the cache (no HTTP call needed)
@@ -303,7 +291,7 @@ func TestCheckForUpdate_InvalidChannel(t *testing.T) {
303291
tempDir := t.TempDir()
304292
t.Setenv("AZD_CONFIG_DIR", tempDir)
305293

306-
m := NewManager(nil)
294+
m := NewManager(nil, nil)
307295
cfg := &UpdateConfig{Channel: Channel("nightly")}
308296

309297
_, err := m.CheckForUpdate(context.Background(), cfg, true)
@@ -317,7 +305,7 @@ func TestUpdateViaPackageManager_Success(t *testing.T) {
317305
return strings.Contains(command, "brew upgrade azd")
318306
}).Respond(exec.NewRunResult(0, "Updated azd", ""))
319307

320-
m := NewManager(mockRunner)
308+
m := NewManager(mockRunner, nil)
321309
var buf bytes.Buffer
322310

323311
err := m.updateViaPackageManager(context.Background(), "brew", []string{"upgrade", "azd"}, &buf)
@@ -331,7 +319,7 @@ func TestUpdateViaPackageManager_Failure(t *testing.T) {
331319
return strings.Contains(command, "brew upgrade azd")
332320
}).Respond(exec.NewRunResult(1, "", "Error: no such formula"))
333321

334-
m := NewManager(mockRunner)
322+
m := NewManager(mockRunner, nil)
335323
var buf bytes.Buffer
336324

337325
err := m.updateViaPackageManager(context.Background(), "brew", []string{"upgrade", "azd"}, &buf)
@@ -348,7 +336,7 @@ func TestUpdateViaPackageManager_CommandError(t *testing.T) {
348336
return true
349337
}).SetError(fmt.Errorf("command not found: brew"))
350338

351-
m := NewManager(mockRunner)
339+
m := NewManager(mockRunner, nil)
352340
var buf bytes.Buffer
353341

354342
err := m.updateViaPackageManager(context.Background(), "brew", []string{"upgrade", "azd"}, &buf)
@@ -360,7 +348,7 @@ func TestUpdateViaPackageManager_CommandError(t *testing.T) {
360348
}
361349

362350
func TestVerifyCodeSignature_NilRunner(t *testing.T) {
363-
m := NewManager(nil)
351+
m := NewManager(nil, nil)
364352
err := m.verifyCodeSignature(context.Background(), "/some/binary", io.Discard)
365353
require.NoError(t, err, "should skip when no command runner")
366354
}
@@ -620,7 +608,7 @@ func TestDownloadFile(t *testing.T) {
620608
tempDir := t.TempDir()
621609
destPath := filepath.Join(tempDir, "downloaded")
622610

623-
m := NewManager(nil)
611+
m := NewManager(nil, nil)
624612
err := m.downloadFile(context.Background(), server.URL+"/azd.zip", destPath, io.Discard)
625613
require.NoError(t, err)
626614

@@ -638,7 +626,7 @@ func TestDownloadFile_HTTPError(t *testing.T) {
638626
tempDir := t.TempDir()
639627
destPath := filepath.Join(tempDir, "downloaded")
640628

641-
m := NewManager(nil)
629+
m := NewManager(nil, nil)
642630
err := m.downloadFile(context.Background(), server.URL+"/missing.zip", destPath, io.Discard)
643631
require.Error(t, err)
644632
require.Contains(t, err.Error(), "404")

0 commit comments

Comments
 (0)