Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion infra/conf/transport_internet.go
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ func (p TransportProtocol) Build() (string, error) {
return "mkcp", nil
case "ws", "websocket":
return "websocket", nil
case "h2", "http":
case "h2", "h3", "http":
return "http", nil
case "grpc", "gun":
return "grpc", nil
Expand Down
175 changes: 119 additions & 56 deletions transport/internet/http/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"sync"
"time"

"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
c "github.com/xtls/xray-core/common/ctx"
Expand All @@ -24,6 +26,13 @@ import (
"golang.org/x/net/http2"
)

// defines the maximum time an idle TCP session can survive in the tunnel, so
// it should be consistent across HTTP versions and with other transports.
const connIdleTimeout = 300 * time.Second

// consistent with quic-go
const h3KeepalivePeriod = 10 * time.Second

type dialerConf struct {
net.Destination
*internet.MemoryStreamConfig
Expand All @@ -48,72 +57,129 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
if tlsConfigs == nil && realityConfigs == nil {
return nil, errors.New("TLS or REALITY must be enabled for http transport.").AtWarning()
}
isH3 := tlsConfigs != nil && (len(tlsConfigs.NextProtocol) == 1 && tlsConfigs.NextProtocol[0] == "h3")
if isH3 {
dest.Network = net.Network_UDP
}
sockopt := streamSettings.SocketSettings

if client, found := globalDialerMap[dialerConf{dest, streamSettings}]; found {
return client, nil
}

transport := &http2.Transport{
DialTLSContext: func(hctx context.Context, string, addr string, tlsConfig *gotls.Config) (net.Conn, error) {
rawHost, rawPort, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
if len(rawPort) == 0 {
rawPort = "443"
}
port, err := net.PortFromString(rawPort)
if err != nil {
return nil, err
}
address := net.ParseAddress(rawHost)
var transport http.RoundTripper
if isH3 {
quicConfig := &quic.Config{
MaxIdleTimeout: connIdleTimeout,

hctx = c.ContextWithID(hctx, c.IDFromContext(ctx))
hctx = session.ContextWithOutbounds(hctx, session.OutboundsFromContext(ctx))
hctx = session.ContextWithTimeoutOnly(hctx, true)
// these two are defaults of quic-go/http3. the default of quic-go (no
// http3) is different, so it is hardcoded here for clarity.
// https://github.com/quic-go/quic-go/blob/b8ea5c798155950fb5bbfdd06cad1939c9355878/http3/client.go#L36-L39
MaxIncomingStreams: -1,
KeepAlivePeriod: h3KeepalivePeriod,
}
roundTripper := &http3.RoundTripper{
QUICConfig: quicConfig,
TLSClientConfig: tlsConfigs.GetTLSConfig(tls.WithDestination(dest)),
Dial: func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
if err != nil {
return nil, err
}

pconn, err := internet.DialSystem(hctx, net.TCPDestination(address, port), sockopt)
if err != nil {
errors.LogErrorInner(ctx, err, "failed to dial to " + addr)
return nil, err
}
var udpConn net.PacketConn
var udpAddr *net.UDPAddr

if realityConfigs != nil {
return reality.UClient(pconn, realityConfigs, hctx, dest)
}
switch c := conn.(type) {
case *internet.PacketConnWrapper:
var ok bool
udpConn, ok = c.Conn.(*net.UDPConn)
if !ok {
return nil, errors.New("PacketConnWrapper does not contain a UDP connection")
}
udpAddr, err = net.ResolveUDPAddr("udp", c.Dest.String())
if err != nil {
return nil, err
}
case *net.UDPConn:
udpConn = c
udpAddr, err = net.ResolveUDPAddr("udp", c.RemoteAddr().String())
if err != nil {
return nil, err
}
default:
udpConn = &internet.FakePacketConn{c}
udpAddr, err = net.ResolveUDPAddr("udp", c.RemoteAddr().String())
if err != nil {
return nil, err
}
}

var cn tls.Interface
if fingerprint := tls.GetFingerprint(tlsConfigs.Fingerprint); fingerprint != nil {
cn = tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn)
} else {
cn = tls.Client(pconn, tlsConfig).(*tls.Conn)
}
if err := cn.HandshakeContext(ctx); err != nil {
errors.LogErrorInner(ctx, err, "failed to dial to " + addr)
return nil, err
}
if !tlsConfig.InsecureSkipVerify {
if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil {
return quic.DialEarly(ctx, udpConn, udpAddr, tlsCfg, cfg)
},
}
transport = roundTripper
} else {
transportH2 := &http2.Transport{
DialTLSContext: func(hctx context.Context, string, addr string, tlsConfig *gotls.Config) (net.Conn, error) {
rawHost, rawPort, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
if len(rawPort) == 0 {
rawPort = "443"
}
port, err := net.PortFromString(rawPort)
if err != nil {
return nil, err
}
address := net.ParseAddress(rawHost)

hctx = c.ContextWithID(hctx, c.IDFromContext(ctx))
hctx = session.ContextWithOutbounds(hctx, session.OutboundsFromContext(ctx))
hctx = session.ContextWithTimeoutOnly(hctx, true)

pconn, err := internet.DialSystem(hctx, net.TCPDestination(address, port), sockopt)
if err != nil {
errors.LogErrorInner(ctx, err, "failed to dial to " + addr)
return nil, err
}
}
negotiatedProtocol := cn.NegotiatedProtocol()
if negotiatedProtocol != http2.NextProtoTLS {
return nil, errors.New("http2: unexpected ALPN protocol " + negotiatedProtocol + "; want q" + http2.NextProtoTLS).AtError()
}
return cn, nil
},
}

if tlsConfigs != nil {
transport.TLSClientConfig = tlsConfigs.GetTLSConfig(tls.WithDestination(dest))
}

if httpSettings.IdleTimeout > 0 || httpSettings.HealthCheckTimeout > 0 {
transport.ReadIdleTimeout = time.Second * time.Duration(httpSettings.IdleTimeout)
transport.PingTimeout = time.Second * time.Duration(httpSettings.HealthCheckTimeout)

if realityConfigs != nil {
return reality.UClient(pconn, realityConfigs, hctx, dest)
}

var cn tls.Interface
if fingerprint := tls.GetFingerprint(tlsConfigs.Fingerprint); fingerprint != nil {
cn = tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn)
} else {
cn = tls.Client(pconn, tlsConfig).(*tls.Conn)
}
if err := cn.HandshakeContext(ctx); err != nil {
errors.LogErrorInner(ctx, err, "failed to dial to " + addr)
return nil, err
}
if !tlsConfig.InsecureSkipVerify {
if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil {
errors.LogErrorInner(ctx, err, "failed to dial to " + addr)
return nil, err
}
}
negotiatedProtocol := cn.NegotiatedProtocol()
if negotiatedProtocol != http2.NextProtoTLS {
return nil, errors.New("http2: unexpected ALPN protocol " + negotiatedProtocol + "; want q" + http2.NextProtoTLS).AtError()
}
return cn, nil
},
}
if tlsConfigs != nil {
transportH2.TLSClientConfig = tlsConfigs.GetTLSConfig(tls.WithDestination(dest))
}
if httpSettings.IdleTimeout > 0 || httpSettings.HealthCheckTimeout > 0 {
transportH2.ReadIdleTimeout = time.Second * time.Duration(httpSettings.IdleTimeout)
transportH2.PingTimeout = time.Second * time.Duration(httpSettings.HealthCheckTimeout)
}
transport = transportH2
}

client := &http.Client{
Expand Down Expand Up @@ -158,9 +224,6 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
Host: dest.NetAddr(),
Path: httpSettings.getNormalizedPath(),
},
Proto: "HTTP/2",
ProtoMajor: 2,
ProtoMinor: 0,
Header: httpHeaders,
}
// Disable any compression method from server.
Expand Down
78 changes: 78 additions & 0 deletions transport/internet/http/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol/tls/cert"
"github.com/xtls/xray-core/testing/servers/tcp"
"github.com/xtls/xray-core/testing/servers/udp"
"github.com/xtls/xray-core/transport/internet"
. "github.com/xtls/xray-core/transport/internet/http"
"github.com/xtls/xray-core/transport/internet/stat"
Expand Down Expand Up @@ -92,3 +93,80 @@ func TestHTTPConnection(t *testing.T) {
t.Error(r)
}
}

func TestH3Connection(t *testing.T) {
port := udp.PickPort()

listener, err := Listen(context.Background(), net.LocalHostIP, port, &internet.MemoryStreamConfig{
ProtocolName: "http",
ProtocolSettings: &Config{},
SecurityType: "tls",
SecuritySettings: &tls.Config{
NextProtocol: []string{"h3"},
Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("www.example.com")))},
},
}, func(conn stat.Connection) {
go func() {
defer conn.Close()

b := buf.New()
defer b.Release()

for {
if _, err := b.ReadFrom(conn); err != nil {
return
}
_, err := conn.Write(b.Bytes())
common.Must(err)
}
}()
})
common.Must(err)

defer listener.Close()

time.Sleep(time.Second)

dctx := context.Background()
conn, err := Dial(dctx, net.TCPDestination(net.LocalHostIP, port), &internet.MemoryStreamConfig{
ProtocolName: "http",
ProtocolSettings: &Config{},
SecurityType: "tls",
SecuritySettings: &tls.Config{
NextProtocol: []string{"h3"},
ServerName: "www.example.com",
AllowInsecure: true,
},
})
common.Must(err)
defer conn.Close()

const N = 1024
b1 := make([]byte, N)
common.Must2(rand.Read(b1))
b2 := buf.New()

nBytes, err := conn.Write(b1)
common.Must(err)
if nBytes != N {
t.Error("write: ", nBytes)
}

b2.Clear()
common.Must2(b2.ReadFullFrom(conn, N))
if r := cmp.Diff(b2.Bytes(), b1); r != "" {
t.Error(r)
}

nBytes, err = conn.Write(b1)
common.Must(err)
if nBytes != N {
t.Error("write: ", nBytes)
}

b2.Clear()
common.Must2(b2.ReadFullFrom(conn, N))
if r := cmp.Diff(b2.Bytes(), b1); r != "" {
t.Error(r)
}
}
Loading