diff --git a/cmd/src/batch_common.go b/cmd/src/batch_common.go index 59b462ba96..8a53fb1852 100644 --- a/cmd/src/batch_common.go +++ b/cmd/src/batch_common.go @@ -416,10 +416,14 @@ func executeBatchSpec(ctx context.Context, opts executeBatchSpecOpts) (err error execUI.DeterminingWorkspaces() workspaces, repos, err := svc.ResolveWorkspacesForBatchSpec(ctx, batchSpec, opts.flags.allowUnsupported, opts.flags.allowIgnored) if err != nil { - if repoSet, ok := err.(batches.UnsupportedRepoSet); ok { - execUI.DeterminingWorkspacesSuccess(len(workspaces), len(repos), repoSet, nil) - } else if repoSet, ok := err.(batches.IgnoredRepoSet); ok { - execUI.DeterminingWorkspacesSuccess(len(workspaces), len(repos), nil, repoSet) + var ( + unsupportedRepoSet batches.UnsupportedRepoSet + ignoredRepoSet batches.IgnoredRepoSet + ) + if errors.As(err, &unsupportedRepoSet) { + execUI.DeterminingWorkspacesSuccess(len(workspaces), len(repos), unsupportedRepoSet, nil) + } else if errors.As(err, &ignoredRepoSet) { + execUI.DeterminingWorkspacesSuccess(len(workspaces), len(repos), nil, ignoredRepoSet) } else { return errors.Wrap(err, "resolving repositories") } diff --git a/cmd/src/batch_repositories.go b/cmd/src/batch_repositories.go index 97a7ea6cfe..340b04dd22 100644 --- a/cmd/src/batch_repositories.go +++ b/cmd/src/batch_repositories.go @@ -108,9 +108,13 @@ Examples: _, repos, err := svc.ResolveWorkspacesForBatchSpec(ctx, spec, allowUnsupported, allowIgnored) if err != nil { - if _, ok := err.(batches.UnsupportedRepoSet); ok { + var ( + unsupportedRepoSet batches.UnsupportedRepoSet + ignoredRepoSet batches.IgnoredRepoSet + ) + if errors.As(err, &unsupportedRepoSet) { // This is fine, we just ignore those in the output. - } else if _, ok := err.(batches.IgnoredRepoSet); ok { + } else if errors.As(err, &ignoredRepoSet) { // This is fine, we just ignore those in the output. } else { return errors.Wrap(err, "resolving repositories") diff --git a/cmd/src/code_intel_upload.go b/cmd/src/code_intel_upload.go index d5e686babc..3e6d5d4b2e 100644 --- a/cmd/src/code_intel_upload.go +++ b/cmd/src/code_intel_upload.go @@ -231,7 +231,8 @@ func handleUploadError(accessToken string, err error) error { // for a 401 or 403 ErrUnexpectedStatusCode. Returns nil if none is found. func findAuthError(err error) *ErrUnexpectedStatusCode { // Check if it's a multi-error and scan all children. - if multi, ok := err.(errors.MultiError); ok { + var multi errors.MultiError + if errors.As(err, &multi) { for _, e := range multi.Errors() { if found := findAuthError(e); found != nil { return found diff --git a/cmd/src/run_migration_compat.go b/cmd/src/run_migration_compat.go index 6671b07da3..c4ebbf30eb 100644 --- a/cmd/src/run_migration_compat.go +++ b/cmd/src/run_migration_compat.go @@ -76,7 +76,8 @@ func runMigrated() (int, error) { if errors.HasType[*cmderrors.UsageError](err) { return 2, nil } - if e, ok := err.(*cmderrors.ExitCodeError); ok { + var e *cmderrors.ExitCodeError + if errors.As(err, &e) { if e.HasError() { return e.Code(), e } @@ -104,14 +105,16 @@ func runLegacy(cmd *command, flagSet *flag.FlagSet) (int, error) { // Execute the subcommand. if err := cmd.handler(flagSet.Args()[1:]); err != nil { - if _, ok := err.(*cmderrors.UsageError); ok { + var usageErr *cmderrors.UsageError + if errors.As(err, &usageErr) { log.Printf("error: %s\n\n", err) cmd.flagSet.SetOutput(os.Stderr) flag.CommandLine.SetOutput(os.Stderr) cmd.flagSet.Usage() return 2, nil } - if e, ok := err.(*cmderrors.ExitCodeError); ok { + var e *cmderrors.ExitCodeError + if errors.As(err, &e) { if e.HasError() { log.Println(e) } diff --git a/internal/batches/executor/executor.go b/internal/batches/executor/executor.go index 45fb7307c2..61215d9cd1 100644 --- a/internal/batches/executor/executor.go +++ b/internal/batches/executor/executor.go @@ -38,7 +38,8 @@ func (e TaskExecutionErr) Error() string { } func (e TaskExecutionErr) StatusText() string { - if stepErr, ok := e.Err.(stepFailedErr); ok { + var stepErr stepFailedErr + if errors.As(e.Err, &stepErr) { return stepErr.SingleLineError() } return e.Err.Error() diff --git a/internal/batches/executor/run_steps.go b/internal/batches/executor/run_steps.go index 67ccbfcb0b..7e93d74f87 100644 --- a/internal/batches/executor/run_steps.go +++ b/internal/batches/executor/run_steps.go @@ -745,8 +745,9 @@ func (e *errTimeoutReached) Error() string { } func reachedTimeout(cmdCtx context.Context, err error) bool { - if ee, ok := errors.Cause(err).(*exec.ExitError); ok { - if ee.String() == "signal: killed" && cmdCtx.Err() == context.DeadlineExceeded { + var ee *exec.ExitError + if errors.As(err, &ee) { + if ee.String() == "signal: killed" && errors.Is(cmdCtx.Err(), context.DeadlineExceeded) { return true } } diff --git a/internal/batches/ui/tui.go b/internal/batches/ui/tui.go index 405918909e..9375d361b8 100644 --- a/internal/batches/ui/tui.go +++ b/internal/batches/ui/tui.go @@ -317,7 +317,8 @@ func prettyPrintBatchUnlicensedError(out *output.Output, maxUnlicensedCS int, er // Pull apart the error to see if it's a licensing error: if so, we should // display a friendlier and more actionable message than the usual GraphQL // error output. - if gerrs, ok := err.(api.GraphQlErrors); ok { + var gerrs api.GraphQlErrors + if errors.As(err, &gerrs) { // A licensing error should be the sole error returned, so we'll only // pretty print if there's one error. if len(gerrs) == 1 { @@ -357,7 +358,8 @@ func prettyPrintBatchUnlicensedError(out *output.Output, maxUnlicensedCS int, er func printExecutionError(out *output.Output, err error) { // exitCodeError shouldn't generate any specific output, since it indicates // that this was done deeper in the call stack. - if _, ok := err.(*cmderrors.ExitCodeError); ok { + var exitCodeErr *cmderrors.ExitCodeError + if errors.As(err, &exitCodeErr) { return } @@ -373,10 +375,11 @@ func printExecutionError(out *output.Output, err error) { } for _, e := range errs { - if taskErr, ok := e.(executor.TaskExecutionErr); ok { + var taskErr executor.TaskExecutionErr + if errors.As(e, &taskErr) { block.Write(formatTaskExecutionErr(taskErr)) } else { - if err == context.Canceled { + if errors.Is(err, context.Canceled) { block.Writef("%sAborting", output.StyleBold) } else { block.Writef("%s%s", output.StyleBold, e.Error()) @@ -389,8 +392,13 @@ func printExecutionError(out *output.Output, err error) { } } - switch err := err.(type) { - case parallel.Errors, errors.MultiError, api.GraphQlErrors: + var ( + parErrs parallel.Errors + multiErr errors.MultiError + gqlErrs api.GraphQlErrors + ) + switch { + case errors.As(err, &parErrs), errors.As(err, &multiErr), errors.As(err, &gqlErrs): writeErrs(flattenErrs(err)) default: @@ -405,31 +413,37 @@ func printExecutionError(out *output.Output, err error) { } func flattenErrs(err error) (result []error) { - switch errs := err.(type) { - case parallel.Errors: - for _, e := range errs { + var ( + parErrs parallel.Errors + multiErr errors.MultiError + gqlErrs api.GraphQlErrors + ) + switch { + case errors.As(err, &parErrs): + for _, e := range parErrs { result = append(result, flattenErrs(e)...) } - case errors.MultiError: - for _, e := range errs.Errors() { + case errors.As(err, &multiErr): + for _, e := range multiErr.Errors() { result = append(result, flattenErrs(e)...) } - case api.GraphQlErrors: - for _, e := range errs { + case errors.As(err, &gqlErrs): + for _, e := range gqlErrs { result = append(result, flattenErrs(e)...) } default: - result = append(result, errs) + result = append(result, err) } return result } func formatTaskExecutionErr(err executor.TaskExecutionErr) string { - if ee, ok := errors.Cause(err).(*exec.ExitError); ok && ee.String() == "signal: killed" { + var ee *exec.ExitError + if errors.As(err, &ee) && ee.String() == "signal: killed" { return fmt.Sprintf( "%s%s%s: killed by interrupt signal", output.StyleBold, diff --git a/internal/batches/workspace/git.go b/internal/batches/workspace/git.go index 20c20d1b07..b4262093c5 100644 --- a/internal/batches/workspace/git.go +++ b/internal/batches/workspace/git.go @@ -29,7 +29,8 @@ func runGitCmd(ctx context.Context, dir string, args ...string) ([]byte, error) cmd.Dir = dir out, err := cmd.Output() if err != nil { - if exitErr, ok := err.(*exec.ExitError); ok { + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { return out, errors.Wrapf(err, "'git %s' failed: %s", strings.Join(args, " "), string(exitErr.Stderr)) } return out, errors.Wrapf(err, "'git %s' failed: %s", strings.Join(args, " "), string(out)) diff --git a/internal/secrets/keyring.go b/internal/secrets/keyring.go index e9bd8fbeb1..f44b9aa0d0 100644 --- a/internal/secrets/keyring.go +++ b/internal/secrets/keyring.go @@ -67,7 +67,7 @@ func (k *keyringStore) Get(key string) ([]byte, error) { return withContext(k.ctx, func() ([]byte, error) { secret, err := keyring.Get(k.serviceName, key) if err != nil { - if err == keyring.ErrNotFound { + if errors.Is(err, keyring.ErrNotFound) { return nil, ErrSecretNotFound } return nil, errors.Wrap(err, "getting item from keyring")