@@ -814,6 +814,71 @@ func TestReadingIsRefusedOnErrorWhenLocalAddrRequestedFirst(t *testing.T) {
814814 }
815815}
816816
817+ func TestSkipProxyProtocolPolicy (t * testing.T ) {
818+ l , err := net .Listen ("tcp" , "127.0.0.1:0" )
819+ if err != nil {
820+ t .Fatalf ("err: %v" , err )
821+ }
822+
823+ policyFunc := func (upstream net.Addr ) (Policy , error ) { return SKIP , nil }
824+
825+ timeout := time .Minute
826+ pl := & Listener {
827+ Listener : l ,
828+ Policy : policyFunc ,
829+ ReadHeaderTimeout : timeout ,
830+ }
831+
832+ ticker := time .NewTicker (timeout )
833+ done := make (chan bool )
834+ defer func () {
835+ close (done )
836+ ticker .Stop ()
837+ }()
838+
839+ go func () {
840+ for {
841+ select {
842+ case <- done :
843+ return
844+ case <- ticker .C :
845+ t .Fatalf ("Timeout waiting for traffic" )
846+ }
847+ }
848+ }()
849+
850+ ping := []byte ("ping" )
851+ go func () {
852+ conn , err := net .Dial ("tcp" , pl .Addr ().String ())
853+ if err != nil {
854+ t .Fatalf ("err: %v" , err )
855+ }
856+ defer conn .Close ()
857+ conn .Write (ping )
858+ }()
859+
860+ conn , err := pl .Accept ()
861+ if err != nil {
862+ t .Fatalf ("err: %v" , err )
863+ }
864+ defer conn .Close ()
865+
866+ _ , ok := conn .(* net.TCPConn )
867+ if ! ok {
868+ t .Fatal ("err: should be a tcp connection" )
869+ }
870+ _ = conn .LocalAddr ()
871+ recv := make ([]byte , 4 )
872+ _ , err = conn .Read (recv )
873+ if err != nil {
874+ t .Fatalf ("Unexpected read error: %v" , err )
875+ }
876+
877+ if ! bytes .Equal (ping , recv ) {
878+ t .Fatalf ("Unexpected %s data while expected %s" , recv , ping )
879+ }
880+ }
881+
817882func Test_ConnectionCasts (t * testing.T ) {
818883 l , err := net .Listen ("tcp" , "127.0.0.1:0" )
819884 if err != nil {
0 commit comments