Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func parseVersion1(reader *bufio.Reader) (*Header, error) {
for {
b, err := reader.ReadByte()
if err != nil {
return nil, fmt.Errorf(ErrCantReadVersion1Header.Error()+": %v", err)
return nil, fmt.Errorf("%w: %w", ErrCantReadVersion1Header, err)
}
buf = append(buf, b)
if b == '\n' {
Expand Down Expand Up @@ -216,7 +216,10 @@ func (header *Header) formatVersion1() ([]byte, error) {

func parseV1PortNumber(portStr string) (int, error) {
port, err := strconv.Atoi(portStr)
if err != nil || port < 0 || port > 65535 {
if err != nil {
return 0, fmt.Errorf("%w: %w", ErrInvalidPortNumber, err)
}
if port < 0 || port > 65535 {
return 0, ErrInvalidPortNumber
}
return port, nil
Expand All @@ -225,7 +228,7 @@ func parseV1PortNumber(portStr string) (int, error) {
func parseV1IPAddress(protocol AddressFamilyAndProtocol, addrStr string) (net.IP, error) {
addr, err := netip.ParseAddr(addrStr)
if err != nil {
return nil, ErrInvalidAddress
return nil, fmt.Errorf("%w: %w", ErrInvalidAddress, err)
}

switch protocol {
Expand Down
2 changes: 1 addition & 1 deletion v1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ var invalidParseV1Tests = []struct {
func TestReadV1Invalid(t *testing.T) {
for _, tt := range invalidParseV1Tests {
t.Run(tt.desc, func(t *testing.T) {
if _, err := Read(tt.reader); err != tt.expectedError {
if _, err := Read(tt.reader); !errors.Is(err, tt.expectedError) {
t.Fatalf("expected %s, actual %v", tt.expectedError, err)
}
})
Expand Down
15 changes: 8 additions & 7 deletions v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
"net"
Expand Down Expand Up @@ -73,7 +74,7 @@ func parseVersion2(reader *bufio.Reader) (header *Header, err error) {
// Skip first 12 bytes (signature)
for range 12 {
if _, err = reader.ReadByte(); err != nil {
return nil, ErrCantReadProtocolVersionAndCommand
return nil, fmt.Errorf("%w: %w", ErrCantReadProtocolVersionAndCommand, err)
}
}

Expand All @@ -83,7 +84,7 @@ func parseVersion2(reader *bufio.Reader) (header *Header, err error) {
// Read the 13th byte, protocol version and command
b13, err := reader.ReadByte()
if err != nil {
return nil, ErrCantReadProtocolVersionAndCommand
return nil, fmt.Errorf("%w: %w", ErrCantReadProtocolVersionAndCommand, err)
}
header.Command = ProtocolVersionAndCommand(b13)
if _, ok := supportedCommand[header.Command]; !ok {
Expand All @@ -93,7 +94,7 @@ func parseVersion2(reader *bufio.Reader) (header *Header, err error) {
// Read the 14th byte, address family and protocol
b14, err := reader.ReadByte()
if err != nil {
return nil, ErrCantReadAddressFamilyAndProtocol
return nil, fmt.Errorf("%w: %w", ErrCantReadAddressFamilyAndProtocol, err)
}
header.TransportProtocol = AddressFamilyAndProtocol(b14)
// UNSPEC is only supported when LOCAL is set.
Expand All @@ -104,7 +105,7 @@ func parseVersion2(reader *bufio.Reader) (header *Header, err error) {
// Make sure there are bytes available as specified in length
var length uint16
if err := binary.Read(reader, binary.BigEndian, &length); err != nil {
return nil, ErrCantReadLength
return nil, fmt.Errorf("%w: %w", ErrCantReadLength, err)
}
if !header.validateLength(length) {
return nil, ErrInvalidLength
Expand All @@ -130,21 +131,21 @@ func parseVersion2(reader *bufio.Reader) (header *Header, err error) {
if header.TransportProtocol.IsIPv4() {
var addr _addr4
if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
return nil, ErrInvalidAddress
return nil, fmt.Errorf("%w: %w", ErrInvalidAddress, err)
}
header.SourceAddr = newIPAddr(header.TransportProtocol, addr.Src[:], addr.SrcPort)
header.DestinationAddr = newIPAddr(header.TransportProtocol, addr.Dst[:], addr.DstPort)
} else if header.TransportProtocol.IsIPv6() {
var addr _addr6
if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
return nil, ErrInvalidAddress
return nil, fmt.Errorf("%w: %w", ErrInvalidAddress, err)
}
header.SourceAddr = newIPAddr(header.TransportProtocol, addr.Src[:], addr.SrcPort)
header.DestinationAddr = newIPAddr(header.TransportProtocol, addr.Dst[:], addr.DstPort)
} else if header.TransportProtocol.IsUnix() {
var addr _addrUnix
if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
return nil, ErrInvalidAddress
return nil, fmt.Errorf("%w: %w", ErrInvalidAddress, err)
}

network := "unix"
Expand Down
3 changes: 2 additions & 1 deletion v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"bytes"
iorand "crypto/rand"
"encoding/binary"
"errors"
"reflect"
"strings"
"testing"
Expand Down Expand Up @@ -164,7 +165,7 @@ var invalidParseV2Tests = []struct {
func TestParseV2Invalid(t *testing.T) {
for _, tt := range invalidParseV2Tests {
t.Run(tt.desc, func(t *testing.T) {
if _, err := Read(tt.reader); err != tt.expectedError {
if _, err := Read(tt.reader); !errors.Is(err, tt.expectedError) {
t.Fatalf("expected %v, actual %v", tt.expectedError, err)
}
})
Expand Down
Loading