diff --git a/cmd/go-diff/go-diff.go b/cmd/go-diff/go-diff.go index 5807060..0a04beb 100644 --- a/cmd/go-diff/go-diff.go +++ b/cmd/go-diff/go-diff.go @@ -1,6 +1,7 @@ package main import ( + "errors" "flag" "fmt" "io" @@ -45,7 +46,7 @@ func main() { if fdiff != nil { label = fmt.Sprintf("orig(%s) new(%s)", fdiff.OrigName, fdiff.NewName) } - if err == io.EOF { + if errors.Is(err, io.EOF) { break } if err != nil { diff --git a/diff/diff_test.go b/diff/diff_test.go index fc42e0b..2159d9a 100644 --- a/diff/diff_test.go +++ b/diff/diff_test.go @@ -2,7 +2,7 @@ package diff import ( "bytes" - "github.com/google/go-cmp/cmp" + "errors" "io" "io/ioutil" "path/filepath" @@ -10,6 +10,8 @@ import ( "strings" "testing" "time" + + "github.com/google/go-cmp/cmp" ) func unix(sec int64) *time.Time { @@ -975,7 +977,7 @@ func TestParseMultiFileDiffAndPrintMultiFileDiffIncludingTrailingContent(t *test if fd != nil { diffs = append(diffs, fd) } - if err == io.EOF { + if errors.Is(err, io.EOF) { break } if err != nil { diff --git a/diff/parse.go b/diff/parse.go index b73e230..229cee7 100644 --- a/diff/parse.go +++ b/diff/parse.go @@ -74,9 +74,11 @@ func (r *MultiFileDiffReader) ReadFileWithTrailingContent() (*FileDiff, string, fd, err := fr.ReadAllHeaders() if err != nil { - switch e := err.(type) { - case *ParseError: - if e.Err == ErrNoFileHeader || e.Err == ErrExtendedHeadersEOF { + var e *ParseError + var oe OverflowError + switch { + case errors.As(err, &e): + if errors.Is(e.Err, ErrNoFileHeader) || errors.Is(e.Err, ErrExtendedHeadersEOF) { // Any non-diff content preceding a valid diff is included in the // extended headers of the following diff. In this way, mixed diff / // non-diff content can be parsed. Trailing non-diff content is @@ -91,8 +93,8 @@ func (r *MultiFileDiffReader) ReadFileWithTrailingContent() (*FileDiff, string, } return nil, "", err - case OverflowError: - r.nextFileFirstLine = []byte(e) + case errors.As(err, &oe): + r.nextFileFirstLine = []byte(oe) return fd, "", nil default: @@ -114,7 +116,7 @@ func (r *MultiFileDiffReader) ReadFileWithTrailingContent() (*FileDiff, string, // need to perform the check here. hr := fr.HunksReader() line, err := r.reader.readLine() - if err != nil && err != io.EOF { + if err != nil && !errors.Is(err, io.EOF) { return fd, "", err } line = bytes.TrimSuffix(line, []byte{'\n'}) @@ -124,13 +126,12 @@ func (r *MultiFileDiffReader) ReadFileWithTrailingContent() (*FileDiff, string, r.line = fr.line r.offset = fr.offset if err != nil { - if e0, ok := err.(*ParseError); ok { - if e, ok := e0.Err.(*ErrBadHunkLine); ok { - // This just means we finished reading the hunks for the - // current file. See the ErrBadHunkLine doc for more info. - r.nextFileFirstLine = e.Line - return fd, "", nil - } + var e *ErrBadHunkLine + if errors.As(err, &e) { + // This just means we finished reading the hunks for the + // current file. See the ErrBadHunkLine doc for more info. + r.nextFileFirstLine = e.Line + return fd, "", nil } return nil, "", err } @@ -152,7 +153,7 @@ func (r *MultiFileDiffReader) ReadAllFiles() ([]*FileDiff, error) { if d != nil { ds = append(ds, d) } - if err == io.EOF { + if errors.Is(err, io.EOF) { return ds, nil } if err != nil { @@ -224,13 +225,15 @@ func (r *FileDiffReader) ReadAllHeaders() (*FileDiff, error) { fd := &FileDiff{} fd.Extended, err = r.ReadExtendedHeaders() - if pe, ok := err.(*ParseError); ok && pe.Err == ErrExtendedHeadersEOF { + var pe *ParseError + var oe OverflowError + if errors.As(err, &pe) && errors.Is(pe.Err, ErrExtendedHeadersEOF) { wasEmpty := handleEmpty(fd) if wasEmpty { return fd, nil } return fd, err - } else if _, ok := err.(OverflowError); ok { + } else if errors.As(err, &oe) { handleEmpty(fd) return fd, err } else if err != nil { @@ -305,7 +308,7 @@ func (r *FileDiffReader) readOneFileHeader(prefix []byte) (filename string, time if r.fileHeaderLine == nil { var err error line, err = r.reader.readLine() - if err == io.EOF { + if errors.Is(err, io.EOF) { return "", nil, &ParseError{r.line, r.offset, ErrNoFileHeader} } else if err != nil { return "", nil, err @@ -363,7 +366,7 @@ func (r *FileDiffReader) ReadExtendedHeaders() ([]string, error) { if r.fileHeaderLine == nil { var err error line, err = r.reader.readLine() - if err == io.EOF { + if errors.Is(err, io.EOF) { return xheaders, &ParseError{r.line, r.offset, ErrExtendedHeadersEOF} } else if err != nil { return xheaders, err @@ -660,7 +663,7 @@ func (r *HunksReader) ReadHunk() (*Hunk, error) { } else { line, err = r.reader.readLine() if err != nil { - if err == io.EOF && r.hunk != nil { + if errors.Is(err, io.EOF) && r.hunk != nil { return r.hunk, nil } return nil, err @@ -825,7 +828,7 @@ func (r *HunksReader) ReadAllHunks() ([]*Hunk, error) { linesRead := int32(0) for { hunk, err := r.ReadHunk() - if err == io.EOF { + if errors.Is(err, io.EOF) { return hunks, nil } if hunk != nil { @@ -865,6 +868,10 @@ func (e *ParseError) Error() string { return fmt.Sprintf("line %d, char %d: %s", e.Line, e.Offset, e.Err) } +// Unwrap returns the underlying error so it can be inspected with +// errors.Is and errors.As. +func (e *ParseError) Unwrap() error { return e.Err } + // ErrNoHunkHeader indicates that a unified diff hunk header was // expected but not found during parsing. var ErrNoHunkHeader = errors.New("no hunk header") diff --git a/diff/reader_util.go b/diff/reader_util.go index 3356283..c2059d4 100644 --- a/diff/reader_util.go +++ b/diff/reader_util.go @@ -84,7 +84,7 @@ func (l *lineReader) nextNextLineStartsWith(prefix string) (bool, error) { // false and ignore the error when readErr is io.EOF. func (l *lineReader) lineHasPrefix(line []byte, prefix string, readErr error) (bool, error) { if readErr != nil { - if readErr == io.EOF || readErr == bufio.ErrBufferFull { + if errors.Is(readErr, io.EOF) || errors.Is(readErr, bufio.ErrBufferFull) { return false, nil } return false, readErr @@ -100,10 +100,10 @@ func (l *lineReader) lineHasPrefix(line []byte, prefix string, readErr error) (b // will return any other errors it receives from the underlying call to ReadBytes. func readLine(r *bufio.Reader, keepCR bool) ([]byte, error) { line, err := r.ReadBytes('\n') - if err == io.EOF && len(line) == 0 { + if errors.Is(err, io.EOF) && len(line) == 0 { return nil, io.EOF } - if err != nil && err != io.EOF { + if err != nil && !errors.Is(err, io.EOF) { return nil, err } if line[len(line)-1] == '\n' { diff --git a/diff/reader_util_test.go b/diff/reader_util_test.go index 760fcb7..ea15d58 100644 --- a/diff/reader_util_test.go +++ b/diff/reader_util_test.go @@ -2,6 +2,7 @@ package diff import ( "bufio" + "errors" "io" "reflect" "strings" @@ -52,7 +53,7 @@ index 0000000..3be2928`, out := []string{} for { l, err := readLine(in, false) - if err == io.EOF { + if errors.Is(err, io.EOF) { break } if err != nil { @@ -94,11 +95,11 @@ index 0000000..3be2928 if err != nil { t.Fatal(err) } - if in.cachedNextLineErr != io.EOF { + if !errors.Is(in.cachedNextLineErr, io.EOF) { t.Fatalf("lineReader has wrong cachedNextLineErr: %s", in.cachedNextLineErr) } _, err = in.readLine() - if err != io.EOF { + if !errors.Is(err, io.EOF) { t.Fatalf("readLine did not return io.EOF: %s", err) } }