diff --git a/header_test.go b/header_test.go index 0312a99..424f1ed 100644 --- a/header_test.go +++ b/header_test.go @@ -13,19 +13,22 @@ import ( // Stuff to be used in both versions tests. const ( - NO_PROTOCOL = "There is no spoon" - IP4_ADDR = "127.0.0.1" - IP6_ADDR = "::1" - PORT = 65533 - INVALID_PORT = 99999 + NO_PROTOCOL = "There is no spoon" + IP4_ADDR = "127.0.0.1" + IP6_ADDR = "::1" + IP6_COMPAT_ADDR = "0:0:0:0:0:ffff:7f00:1" + PORT = 65533 + INVALID_PORT = 99999 ) var ( - v4ip = net.ParseIP(IP4_ADDR).To4() - v6ip = net.ParseIP(IP6_ADDR).To16() + v4ip = net.ParseIP(IP4_ADDR).To4() + v6ip = net.ParseIP(IP6_ADDR).To16() + v6CompatIP = net.ParseIP(IP4_ADDR).To16() - v4addr net.Addr = &net.TCPAddr{IP: v4ip, Port: PORT} - v6addr net.Addr = &net.TCPAddr{IP: v6ip, Port: PORT} + v4addr net.Addr = &net.TCPAddr{IP: v4ip, Port: PORT} + v6addr net.Addr = &net.TCPAddr{IP: v6ip, Port: PORT} + v6CompatAddr net.Addr = &net.TCPAddr{IP: v6CompatIP, Port: PORT} v4UDPAddr net.Addr = &net.UDPAddr{IP: v4ip, Port: PORT} v6UDPAddr net.Addr = &net.UDPAddr{IP: v6ip, Port: PORT} diff --git a/v1.go b/v1.go index 49d23c6..08efe9a 100644 --- a/v1.go +++ b/v1.go @@ -114,11 +114,16 @@ func parseV1PortNumber(portStr string) (int, error) { return port, nil } -func parseV1IPAddress(protocol AddressFamilyAndProtocol, addrStr string) (addr net.IP, err error) { - addr = net.ParseIP(addrStr) - tryV4 := addr.To4() - if (protocol == TCPv4 && tryV4 == nil) || (protocol == TCPv6 && tryV4 != nil) { - err = ErrInvalidAddress +func parseV1IPAddress(protocol AddressFamilyAndProtocol, addrStr string) (net.IP, error) { + ip := net.ParseIP(addrStr) + switch protocol { + case TCPv4: + ip = ip.To4() + case TCPv6: + ip = ip.To16() } - return + if ip == nil { + return nil, ErrInvalidAddress + } + return ip, nil } diff --git a/v1_test.go b/v1_test.go index 9b72a79..9b8083b 100644 --- a/v1_test.go +++ b/v1_test.go @@ -12,9 +12,11 @@ var ( TCP4AddressesAndPorts = strings.Join([]string{IP4_ADDR, IP4_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator) TCP4AddressesAndInvalidPorts = strings.Join([]string{IP4_ADDR, IP4_ADDR, strconv.Itoa(INVALID_PORT), strconv.Itoa(INVALID_PORT)}, separator) TCP6AddressesAndPorts = strings.Join([]string{IP6_ADDR, IP6_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator) + TCP6CompatAddressesAndPorts = strings.Join([]string{IP6_COMPAT_ADDR, IP6_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator) - fixtureTCP4V1 = "PROXY TCP4 " + TCP4AddressesAndPorts + crlf + "GET /" - fixtureTCP6V1 = "PROXY TCP6 " + TCP6AddressesAndPorts + crlf + "GET /" + fixtureTCP4V1 = "PROXY TCP4 " + TCP4AddressesAndPorts + crlf + "GET /" + fixtureTCP6V1 = "PROXY TCP6 " + TCP6AddressesAndPorts + crlf + "GET /" + fixtureTCP6CompatV1 = "PROXY TCP6 " + TCP6CompatAddressesAndPorts + crlf + "GET /" ) var invalidParseV1Tests = []struct { @@ -54,7 +56,7 @@ var invalidParseV1Tests = []struct { func TestReadV1Invalid(t *testing.T) { for _, tt := range invalidParseV1Tests { if _, err := Read(tt.reader); err != tt.expectedError { - t.Fatalf("TestReadV1Invalid: expected %s, actual %s", tt.expectedError, err.Error()) + t.Fatalf("TestReadV1Invalid: expected %s, actual %s", tt.expectedError, err) } } } @@ -83,6 +85,16 @@ var validParseAndWriteV1Tests = []struct { DestinationAddr: v6addr, }, }, + { + bufio.NewReader(strings.NewReader(fixtureTCP6CompatV1)), + &Header{ + Version: 1, + Command: PROXY, + TransportProtocol: TCPv6, + SourceAddr: v6CompatAddr, + DestinationAddr: v6addr, + }, + }, } func TestParseV1Valid(t *testing.T) {