diff --git a/app/dns/nameserver_quic.go b/app/dns/nameserver_quic.go index 4c7ac0329cac..bd0da1700929 100644 --- a/app/dns/nameserver_quic.go +++ b/app/dns/nameserver_quic.go @@ -8,7 +8,7 @@ import ( "sync" "time" - "github.com/quic-go/quic-go" + "github.com/apernet/quic-go" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/log" diff --git a/common/net/port.go b/common/net/port.go index d4a6514c1d5c..26f5e3e25105 100644 --- a/common/net/port.go +++ b/common/net/port.go @@ -87,6 +87,16 @@ func PortListFromProto(l *PortList) MemoryPortList { return mpl } +func (l *PortList) Ports() []uint32 { + var ports []uint32 + for _, r := range l.Range { + for i := uint32(r.From); i <= uint32(r.To); i++ { + ports = append(ports, i) + } + } + return ports +} + func (mpl MemoryPortList) Contains(port Port) bool { for _, pr := range mpl { if pr.Contains(port) { diff --git a/common/protocol/quic/sniff.go b/common/protocol/quic/sniff.go index 0691bad62bc4..5b29d6ffada4 100644 --- a/common/protocol/quic/sniff.go +++ b/common/protocol/quic/sniff.go @@ -7,7 +7,7 @@ import ( "encoding/binary" "io" - "github.com/quic-go/quic-go/quicvarint" + "github.com/apernet/quic-go/quicvarint" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" diff --git a/go.mod b/go.mod index c52fc6d695b2..f6bfcb794244 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/xtls/xray-core go 1.25.5 require ( + github.com/apernet/quic-go v0.57.2-0.20260111184307-eec823306178 github.com/cloudflare/circl v1.6.2 github.com/ghodss/yaml v1.0.1-0.20220118164431-d8423dcdf344 github.com/golang/mock v1.7.0-rc.1 @@ -11,7 +12,6 @@ require ( github.com/miekg/dns v1.1.70 github.com/pelletier/go-toml v1.9.5 github.com/pires/go-proxyproto v0.8.1 - github.com/quic-go/quic-go v0.58.0 github.com/refraction-networking/utls v1.8.1 github.com/sagernet/sing v0.5.1 github.com/sagernet/sing-shadowsocks v0.2.7 @@ -22,6 +22,7 @@ require ( github.com/xtls/reality v0.0.0-20251014195629-e4eec4520535 go4.org/netipx v0.0.0-20231129151722-fdeea329fbba golang.org/x/crypto v0.47.0 + golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/net v0.49.0 golang.org/x/sync v0.19.0 golang.org/x/sys v0.40.0 diff --git a/go.sum b/go.sum index 5f8367d2cb1b..0f8e1428df4c 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI= github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= +github.com/apernet/quic-go v0.57.2-0.20260111184307-eec823306178 h1:bSq8n+gX4oO/qnM3MKf4kroW75n+phO9Qp6nigJKZ1E= +github.com/apernet/quic-go v0.57.2-0.20260111184307-eec823306178/go.mod h1:N1WIjPphkqs4efXWuyDNQ6OjjIK04vM3h+bEgwV+eVU= github.com/cloudflare/circl v1.6.2 h1:hL7VBpHHKzrV5WTfHCaBsgx/HGbBYlgrwvNXEVDYYsQ= github.com/cloudflare/circl v1.6.2/go.mod h1:2eXP6Qfat4O/Yhh8BznvKnJ+uzEoTQ6jVKJRn81BiS4= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= @@ -50,8 +52,6 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8= github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII= -github.com/quic-go/quic-go v0.58.0 h1:ggY2pvZaVdB9EyojxL1p+5mptkuHyX5MOSv4dgWF4Ug= -github.com/quic-go/quic-go v0.58.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU= github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo= github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 h1:f/FNXud6gA3MNr8meMVVGxhp+QBTqY91tM8HjEuMjGg= @@ -97,6 +97,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= +golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= +golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= diff --git a/infra/conf/hysteria.go b/infra/conf/hysteria.go new file mode 100644 index 000000000000..6512d1050d5f --- /dev/null +++ b/infra/conf/hysteria.go @@ -0,0 +1,23 @@ +package conf + +import ( + "github.com/xtls/xray-core/common/protocol" + "github.com/xtls/xray-core/proxy/hysteria" + "google.golang.org/protobuf/proto" +) + +type HysteriaClientConfig struct { + Address *Address `json:"address"` + Port uint16 `json:"port"` +} + +func (c *HysteriaClientConfig) Build() (proto.Message, error) { + config := new(hysteria.ClientConfig) + + config.Server = &protocol.ServerEndpoint{ + Address: c.Address.Build(), + Port: uint32(c.Port), + } + + return config, nil +} diff --git a/infra/conf/transport_internet.go b/infra/conf/transport_internet.go index 2544c016185e..742c7df5a940 100644 --- a/infra/conf/transport_internet.go +++ b/infra/conf/transport_internet.go @@ -16,7 +16,9 @@ import ( "github.com/xtls/xray-core/common/platform/filesystem" "github.com/xtls/xray-core/common/serial" "github.com/xtls/xray-core/transport/internet" + "github.com/xtls/xray-core/transport/internet/finalmask/salamander" "github.com/xtls/xray-core/transport/internet/httpupgrade" + "github.com/xtls/xray-core/transport/internet/hysteria" "github.com/xtls/xray-core/transport/internet/kcp" "github.com/xtls/xray-core/transport/internet/reality" "github.com/xtls/xray-core/transport/internet/splithttp" @@ -332,6 +334,161 @@ func (c *SplitHTTPConfig) Build() (proto.Message, error) { return config, nil } +const ( + Byte = 1 + Kilobyte = 1024 * Byte + Megabyte = 1024 * Kilobyte + Gigabyte = 1024 * Megabyte + Terabyte = 1024 * Gigabyte +) + +type Bandwidth string + +func (b Bandwidth) Bps() (uint64, error) { + s := strings.TrimSpace(strings.ToLower(string(b))) + if s == "" { + return 0, nil + } + + idx := len(s) + for i, c := range s { + if (c < '0' || c > '9') && c != '.' { + idx = i + break + } + } + + numStr := s[:idx] + unit := strings.TrimSpace(s[idx:]) + + val, err := strconv.ParseFloat(numStr, 64) + if err != nil { + return 0, err + } + + mul := uint64(1) + switch unit { + case "", "b", "bps": + mul = Byte + case "k", "kb", "kbps": + mul = Kilobyte + case "m", "mb", "mbps": + mul = Megabyte + case "g", "gb", "gbps": + mul = Gigabyte + case "t", "tb", "tbps": + mul = Terabyte + default: + return 0, errors.New("unsupported unit: " + unit) + } + + return uint64(val*float64(mul)) / 8, nil +} + +type UdpHop struct { + PortList json.RawMessage `json:"port"` + Interval int64 `json:"interval"` +} + +type HysteriaConfig struct { + Version int32 `json:"version"` + Auth string `json:"auth"` + Up Bandwidth `json:"up"` + Down Bandwidth `json:"down"` + UdpHop UdpHop `json:"udphop"` + + InitStreamReceiveWindow uint64 `json:"initStreamReceiveWindow"` + MaxStreamReceiveWindow uint64 `json:"maxStreamReceiveWindow"` + InitConnectionReceiveWindow uint64 `json:"initConnectionReceiveWindow"` + MaxConnectionReceiveWindow uint64 `json:"maxConnectionReceiveWindow"` + MaxIdleTimeout int64 `json:"maxIdleTimeout"` + KeepAlivePeriod int64 `json:"keepAlivePeriod"` + DisablePathMTUDiscovery bool `json:"disablePathMTUDiscovery"` +} + +func (c *HysteriaConfig) Build() (proto.Message, error) { + if c.Version != 2 { + return nil, errors.New("version != 2") + } + up, err := c.Up.Bps() + if err != nil { + return nil, err + } + down, err := c.Down.Bps() + if err != nil { + return nil, err + } + var hop *PortList + if err := json.Unmarshal(c.UdpHop.PortList, &hop); err != nil { + hop = &PortList{} + } + + if up > 0 && up < 65536 { + return nil, errors.New("Up must be at least 65536 Bps") + } + if down > 0 && down < 65536 { + return nil, errors.New("Down must be at least 65536 Bps") + } + if c.UdpHop.Interval != 0 && c.UdpHop.Interval < 5 { + return nil, errors.New("Interval must be at least 5") + } + + if c.InitStreamReceiveWindow > 0 && c.InitStreamReceiveWindow < 16384 { + return nil, errors.New("InitStreamReceiveWindow must be at least 16384") + } + if c.MaxStreamReceiveWindow > 0 && c.MaxStreamReceiveWindow < 16384 { + return nil, errors.New("MaxStreamReceiveWindow must be at least 16384") + } + if c.InitConnectionReceiveWindow > 0 && c.InitConnectionReceiveWindow < 16384 { + return nil, errors.New("InitConnectionReceiveWindow must be at least 16384") + } + if c.MaxConnectionReceiveWindow > 0 && c.MaxConnectionReceiveWindow < 16384 { + return nil, errors.New("MaxConnectionReceiveWindow must be at least 16384") + } + if c.MaxIdleTimeout != 0 && (c.MaxIdleTimeout < 4 || c.MaxIdleTimeout > 120) { + return nil, errors.New("MaxIdleTimeout must be between 4 and 120") + } + if c.KeepAlivePeriod != 0 && (c.KeepAlivePeriod < 2 || c.KeepAlivePeriod > 60) { + return nil, errors.New("KeepAlivePeriod must be between 2 and 60") + } + + config := &hysteria.Config{} + config.Version = int32(c.Version) + config.Auth = c.Auth + config.Up = up + config.Down = down + config.Ports = hop.Build().Ports() + config.Interval = c.UdpHop.Interval + config.InitStreamReceiveWindow = c.InitStreamReceiveWindow + config.MaxStreamReceiveWindow = c.MaxStreamReceiveWindow + config.InitConnReceiveWindow = c.InitConnectionReceiveWindow + config.MaxConnReceiveWindow = c.MaxConnectionReceiveWindow + config.MaxIdleTimeout = c.MaxIdleTimeout + config.KeepAlivePeriod = c.KeepAlivePeriod + config.DisablePathMtuDiscovery = c.DisablePathMTUDiscovery + + if config.InitStreamReceiveWindow == 0 { + config.InitStreamReceiveWindow = 8388608 + } + if config.MaxStreamReceiveWindow == 0 { + config.MaxStreamReceiveWindow = 8388608 + } + if config.InitConnReceiveWindow == 0 { + config.InitConnReceiveWindow = 8388608 * 5 / 2 + } + if config.MaxConnReceiveWindow == 0 { + config.MaxConnReceiveWindow = 8388608 * 5 / 2 + } + if config.MaxIdleTimeout == 0 { + config.MaxIdleTimeout = 30 + } + // if config.KeepAlivePeriod == 0 { + // config.KeepAlivePeriod = 10 + // } + + return config, nil +} + func readFileOrString(f string, s []string) ([]byte, error) { if len(f) > 0 { return filesystem.ReadCert(f) @@ -746,6 +903,8 @@ func (p TransportProtocol) Build() (string, error) { return "", errors.PrintRemovedFeatureError("HTTP transport (without header padding, etc.)", "XHTTP stream-one H2 & H3") case "quic": return "", errors.PrintRemovedFeatureError("QUIC transport (without web service, etc.)", "XHTTP stream-one H3") + case "hysteria": + return "hysteria", nil default: return "", errors.New("Config: unknown transport protocol: ", p) } @@ -928,11 +1087,54 @@ func (c *SocketConfig) Build() (*internet.SocketConfig, error) { }, nil } +var ( + udpmaskLoader = NewJSONConfigLoader(ConfigCreatorCache{ + "salamander": func() interface{} { return new(Salamander) }, + }, "type", "settings") +) + +type Salamander struct { + Password string `json:"password"` +} + +func (c *Salamander) Build() (proto.Message, error) { + config := &salamander.Config{} + config.Password = c.Password + return config, nil +} + +type FinalMask struct { + Type string `json:"type"` + Settings *json.RawMessage `json:"settings"` +} + +func (c *FinalMask) Build(tcpmaskLoader bool) (proto.Message, error) { + loader := udpmaskLoader + if tcpmaskLoader { + return nil, errors.New("") + } + + settings := []byte("{}") + if c.Settings != nil { + settings = ([]byte)(*c.Settings) + } + rawConfig, err := loader.LoadWithID(settings, c.Type) + if err != nil { + return nil, err + } + ts, err := rawConfig.(Buildable).Build() + if err != nil { + return nil, err + } + return ts, nil +} + type StreamConfig struct { Address *Address `json:"address"` Port uint16 `json:"port"` Network *TransportProtocol `json:"network"` Security string `json:"security"` + Udpmasks []*FinalMask `json:"udpmasks"` TLSSettings *TLSConfig `json:"tlsSettings"` REALITYSettings *REALITYConfig `json:"realitySettings"` RAWSettings *TCPConfig `json:"rawSettings"` @@ -943,6 +1145,7 @@ type StreamConfig struct { GRPCSettings *GRPCConfig `json:"grpcSettings"` WSSettings *WebSocketConfig `json:"wsSettings"` HTTPUPGRADESettings *HttpUpgradeConfig `json:"httpupgradeSettings"` + HysteriaSettings *HysteriaConfig `json:"hysteriaSettings"` SocketSettings *SocketConfig `json:"sockopt"` } @@ -962,6 +1165,7 @@ func (c *StreamConfig) Build() (*internet.StreamConfig, error) { } config.ProtocolName = protocol } + switch strings.ToLower(c.Security) { case "", "none": case "tls": @@ -995,6 +1199,7 @@ func (c *StreamConfig) Build() (*internet.StreamConfig, error) { default: return nil, errors.New(`Unknown security "` + c.Security + `".`) } + if c.RAWSettings != nil { c.TCPSettings = c.RAWSettings } @@ -1061,6 +1266,16 @@ func (c *StreamConfig) Build() (*internet.StreamConfig, error) { Settings: serial.ToTypedMessage(hs), }) } + if c.HysteriaSettings != nil { + hs, err := c.HysteriaSettings.Build() + if err != nil { + return nil, errors.New("Failed to build Hysteria config.").Base(err) + } + config.TransportSettings = append(config.TransportSettings, &internet.TransportConfig{ + ProtocolName: "hysteria", + Settings: serial.ToTypedMessage(hs), + }) + } if c.SocketSettings != nil { ss, err := c.SocketSettings.Build() if err != nil { @@ -1068,6 +1283,15 @@ func (c *StreamConfig) Build() (*internet.StreamConfig, error) { } config.SocketSettings = ss } + + for _, mask := range c.Udpmasks { + u, err := mask.Build(false) + if err != nil { + return nil, errors.New("failed to build mask with type ", mask.Type).Base(err) + } + config.Udpmasks = append(config.Udpmasks, serial.ToTypedMessage(u)) + } + return config, nil } diff --git a/infra/conf/xray.go b/infra/conf/xray.go index eff6b8a91c4f..9e5c1394ab0d 100644 --- a/infra/conf/xray.go +++ b/infra/conf/xray.go @@ -43,6 +43,7 @@ var ( "vless": func() interface{} { return new(VLessOutboundConfig) }, "vmess": func() interface{} { return new(VMessOutboundConfig) }, "trojan": func() interface{} { return new(TrojanClientConfig) }, + "hysteria": func() interface{} { return new(HysteriaClientConfig) }, "dns": func() interface{} { return new(DNSOutboundConfig) }, "wireguard": func() interface{} { return &WireGuardConfig{IsClient: true} }, }, "protocol", "settings") @@ -117,13 +118,13 @@ func (m *MuxConfig) Build() (*proxyman.MultiplexingConfig, error) { } type InboundDetourConfig struct { - Protocol string `json:"protocol"` - PortList *PortList `json:"port"` - ListenOn *Address `json:"listen"` - Settings *json.RawMessage `json:"settings"` - Tag string `json:"tag"` - StreamSetting *StreamConfig `json:"streamSettings"` - SniffingConfig *SniffingConfig `json:"sniffing"` + Protocol string `json:"protocol"` + PortList *PortList `json:"port"` + ListenOn *Address `json:"listen"` + Settings *json.RawMessage `json:"settings"` + Tag string `json:"tag"` + StreamSetting *StreamConfig `json:"streamSettings"` + SniffingConfig *SniffingConfig `json:"sniffing"` } // Build implements Buildable. diff --git a/proxy/hysteria/client.go b/proxy/hysteria/client.go new file mode 100644 index 000000000000..4e98544358d9 --- /dev/null +++ b/proxy/hysteria/client.go @@ -0,0 +1,263 @@ +package hysteria + +import ( + "context" + go_errors "errors" + "io" + "math/rand" + + "github.com/apernet/quic-go" + "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/buf" + "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/protocol" + "github.com/xtls/xray-core/common/session" + "github.com/xtls/xray-core/common/signal" + "github.com/xtls/xray-core/common/task" + "github.com/xtls/xray-core/core" + "github.com/xtls/xray-core/features/policy" + "github.com/xtls/xray-core/transport" + "github.com/xtls/xray-core/transport/internet" + "github.com/xtls/xray-core/transport/internet/hysteria" + "github.com/xtls/xray-core/transport/internet/stat" +) + +type Client struct { + server *protocol.ServerSpec + policyManager policy.Manager +} + +func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { + if config.Server == nil { + return nil, errors.New(`no target server found`) + } + server, err := protocol.NewServerSpecFromPB(config.Server) + if err != nil { + return nil, errors.New("failed to get server spec").Base(err) + } + + v := core.MustFromContext(ctx) + client := &Client{ + server: server, + policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), + } + return client, nil +} + +func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { + outbounds := session.OutboundsFromContext(ctx) + ob := outbounds[len(outbounds)-1] + if !ob.Target.IsValid() { + return errors.New("target not specified") + } + ob.Name = "hysteria" + ob.CanSpliceCopy = 3 + target := ob.Target + + conn, err := dialer.Dial(ctx, c.server.Destination) + if err != nil { + return errors.New("failed to find an available destination").AtWarning().Base(err) + } + defer conn.Close() + errors.LogInfo(ctx, "tunneling request to ", target, " via ", target.Network, ":", c.server.Destination.NetAddr()) + + var newCtx context.Context + var newCancel context.CancelFunc + if session.TimeoutOnlyFromContext(ctx) { + newCtx, newCancel = context.WithCancel(context.Background()) + } + + sessionPolicy := c.policyManager.ForLevel(0) + ctx, cancel := context.WithCancel(ctx) + timer := signal.CancelAfterInactivity(ctx, func() { + cancel() + if newCancel != nil { + newCancel() + } + }, sessionPolicy.Timeouts.ConnectionIdle) + + if newCtx != nil { + ctx = newCtx + } + + if target.Network == net.Network_TCP { + requestDone := func() error { + defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) + bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn)) + err := WriteTCPRequest(bufferedWriter, target.NetAddr()) + if err != nil { + return errors.New("failed to write request").Base(err) + } + if err := bufferedWriter.SetBuffered(false); err != nil { + return err + } + return buf.Copy(link.Reader, bufferedWriter, buf.UpdateActivity(timer)) + } + + responseDone := func() error { + defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) + ok, msg, err := ReadTCPResponse(conn) + if err != nil { + return err + } + if !ok { + return errors.New(msg) + } + return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) + } + + responseDoneAndCloseWriter := task.OnSuccess(responseDone, task.Close(link.Writer)) + if err := task.Run(ctx, requestDone, responseDoneAndCloseWriter); err != nil { + return errors.New("connection ends").Base(err) + } + + return nil + } + + if target.Network == net.Network_UDP { + iConn := stat.TryUnwrapStatsConn(conn) + _, ok := iConn.(*hysteria.InterUdpConn) + if !ok { + return errors.New("udp requires hysteria udp transport") + } + + requestDone := func() error { + defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) + + writer := &UDPWriter{ + Writer: conn, + buf: make([]byte, MaxUDPSize), + addr: target.NetAddr(), + } + + if err := buf.Copy(link.Reader, writer, buf.UpdateActivity(timer)); err != nil { + return errors.New("failed to transport all UDP request").Base(err) + } + return nil + } + + responseDone := func() error { + defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) + + reader := &UDPReader{ + Reader: conn, + df: &Defragger{}, + } + + if err := buf.Copy(reader, link.Writer, buf.UpdateActivity(timer)); err != nil { + return errors.New("failed to transport all UDP response").Base(err) + } + return nil + } + + responseDoneAndCloseWriter := task.OnSuccess(responseDone, task.Close(link.Writer)) + if err := task.Run(ctx, requestDone, responseDoneAndCloseWriter); err != nil { + return errors.New("connection ends").Base(err) + } + + return nil + } + + return nil +} + +func init() { + common.Must(common.RegisterConfig((*ClientConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { + return NewClient(ctx, config.(*ClientConfig)) + })) +} + +type UDPWriter struct { + Writer io.Writer + buf []byte + addr string +} + +func (w *UDPWriter) sendMsg(msg *UDPMessage) error { + msgN := msg.Serialize(w.buf) + if msgN < 0 { + // Message larger than buffer, silent drop + return nil + } + _, err := w.Writer.Write(w.buf[:msgN]) + return err +} + +func (w *UDPWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { + for { + mb2, b := buf.SplitFirst(mb) + mb = mb2 + if b == nil { + break + } + addr := w.addr + if b.UDP != nil { + addr = b.UDP.NetAddr() + } + msg := &UDPMessage{ + SessionID: 0, + PacketID: 0, + FragID: 0, + FragCount: 1, + Addr: addr, + Data: b.Bytes(), + } + if err := w.sendMsg(msg); err != nil { + var errTooLarge *quic.DatagramTooLargeError + if go_errors.As(err, &errTooLarge) { + msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1 + fMsgs := FragUDPMessage(msg, int(errTooLarge.MaxDatagramPayloadSize)) + for _, fMsg := range fMsgs { + err := w.sendMsg(&fMsg) + if err != nil { + b.Release() + buf.ReleaseMulti(mb) + return err + } + } + } else { + b.Release() + buf.ReleaseMulti(mb) + return err + } + } + b.Release() + } + return nil +} + +type UDPReader struct { + Reader io.Reader + df *Defragger +} + +func (r *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) { + for { + b := buf.New() + _, err := b.ReadFrom(r.Reader) + if err != nil { + b.Release() + return nil, err + } + + msg, err := ParseUDPMessage(b.Bytes()) + if err != nil { + b.Release() + continue + } + + dfMsg := r.df.Feed(msg) + if dfMsg == nil { + continue + } + + dest, _ := net.ParseDestination("udp:" + dfMsg.Addr) + + buffer := buf.New() + buffer.Write(dfMsg.Data) + buffer.UDP = &dest + + return buf.MultiBuffer{buffer}, nil + } +} diff --git a/proxy/hysteria/config.go b/proxy/hysteria/config.go new file mode 100644 index 000000000000..2650d856aec8 --- /dev/null +++ b/proxy/hysteria/config.go @@ -0,0 +1,10 @@ +package hysteria + +import ( + "github.com/xtls/xray-core/transport/internet/hysteria/padding" +) + +var ( + tcpRequestPadding = padding.Padding{Min: 64, Max: 512} + // tcpResponsePadding = padding.Padding{Min: 128, Max: 1024} +) diff --git a/proxy/hysteria/config.pb.go b/proxy/hysteria/config.pb.go new file mode 100644 index 000000000000..d58022cc2d63 --- /dev/null +++ b/proxy/hysteria/config.pb.go @@ -0,0 +1,126 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.10 +// protoc v6.33.1 +// source: proxy/hysteria/config.proto + +package hysteria + +import ( + protocol "github.com/xtls/xray-core/common/protocol" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type ClientConfig struct { + state protoimpl.MessageState `protogen:"open.v1"` + Server *protocol.ServerEndpoint `protobuf:"bytes,1,opt,name=server,proto3" json:"server,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ClientConfig) Reset() { + *x = ClientConfig{} + mi := &file_proxy_hysteria_config_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ClientConfig) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ClientConfig) ProtoMessage() {} + +func (x *ClientConfig) ProtoReflect() protoreflect.Message { + mi := &file_proxy_hysteria_config_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ClientConfig.ProtoReflect.Descriptor instead. +func (*ClientConfig) Descriptor() ([]byte, []int) { + return file_proxy_hysteria_config_proto_rawDescGZIP(), []int{0} +} + +func (x *ClientConfig) GetServer() *protocol.ServerEndpoint { + if x != nil { + return x.Server + } + return nil +} + +var File_proxy_hysteria_config_proto protoreflect.FileDescriptor + +const file_proxy_hysteria_config_proto_rawDesc = "" + + "\n" + + "\x1bproxy/hysteria/config.proto\x12\x13xray.proxy.hysteria\x1a!common/protocol/server_spec.proto\"L\n" + + "\fClientConfig\x12<\n" + + "\x06server\x18\x01 \x01(\v2$.xray.common.protocol.ServerEndpointR\x06serverB[\n" + + "\x17com.xray.proxy.hysteriaP\x01Z(github.com/xtls/xray-core/proxy/hysteria\xaa\x02\x13Xray.Proxy.Hysteriab\x06proto3" + +var ( + file_proxy_hysteria_config_proto_rawDescOnce sync.Once + file_proxy_hysteria_config_proto_rawDescData []byte +) + +func file_proxy_hysteria_config_proto_rawDescGZIP() []byte { + file_proxy_hysteria_config_proto_rawDescOnce.Do(func() { + file_proxy_hysteria_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_proxy_hysteria_config_proto_rawDesc), len(file_proxy_hysteria_config_proto_rawDesc))) + }) + return file_proxy_hysteria_config_proto_rawDescData +} + +var file_proxy_hysteria_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_proxy_hysteria_config_proto_goTypes = []any{ + (*ClientConfig)(nil), // 0: xray.proxy.hysteria.ClientConfig + (*protocol.ServerEndpoint)(nil), // 1: xray.common.protocol.ServerEndpoint +} +var file_proxy_hysteria_config_proto_depIdxs = []int32{ + 1, // 0: xray.proxy.hysteria.ClientConfig.server:type_name -> xray.common.protocol.ServerEndpoint + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_proxy_hysteria_config_proto_init() } +func file_proxy_hysteria_config_proto_init() { + if File_proxy_hysteria_config_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_proxy_hysteria_config_proto_rawDesc), len(file_proxy_hysteria_config_proto_rawDesc)), + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_proxy_hysteria_config_proto_goTypes, + DependencyIndexes: file_proxy_hysteria_config_proto_depIdxs, + MessageInfos: file_proxy_hysteria_config_proto_msgTypes, + }.Build() + File_proxy_hysteria_config_proto = out.File + file_proxy_hysteria_config_proto_goTypes = nil + file_proxy_hysteria_config_proto_depIdxs = nil +} diff --git a/proxy/hysteria/config.proto b/proxy/hysteria/config.proto new file mode 100644 index 000000000000..c54c0ead0cd0 --- /dev/null +++ b/proxy/hysteria/config.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +package xray.proxy.hysteria; +option csharp_namespace = "Xray.Proxy.Hysteria"; +option go_package = "github.com/xtls/xray-core/proxy/hysteria"; +option java_package = "com.xray.proxy.hysteria"; +option java_multiple_files = true; + +import "common/protocol/server_spec.proto"; + +message ClientConfig { + xray.common.protocol.ServerEndpoint server = 1; +} diff --git a/proxy/hysteria/frag.go b/proxy/hysteria/frag.go new file mode 100644 index 000000000000..64a6b0e1b0a4 --- /dev/null +++ b/proxy/hysteria/frag.go @@ -0,0 +1,73 @@ +package hysteria + +func FragUDPMessage(m *UDPMessage, maxSize int) []UDPMessage { + if m.Size() <= maxSize { + return []UDPMessage{*m} + } + fullPayload := m.Data + maxPayloadSize := maxSize - m.HeaderSize() + off := 0 + fragID := uint8(0) + fragCount := uint8((len(fullPayload) + maxPayloadSize - 1) / maxPayloadSize) // round up + frags := make([]UDPMessage, fragCount) + for off < len(fullPayload) { + payloadSize := len(fullPayload) - off + if payloadSize > maxPayloadSize { + payloadSize = maxPayloadSize + } + frag := *m + frag.FragID = fragID + frag.FragCount = fragCount + frag.Data = fullPayload[off : off+payloadSize] + frags[fragID] = frag + off += payloadSize + fragID++ + } + return frags +} + +// Defragger handles the defragmentation of UDP messages. +// The current implementation can only handle one packet ID at a time. +// If another packet arrives before a packet has received all fragments +// in their entirety, any previous state is discarded. +type Defragger struct { + pktID uint16 + frags []*UDPMessage + count uint8 + size int // data size +} + +func (d *Defragger) Feed(m *UDPMessage) *UDPMessage { + if m.FragCount <= 1 { + return m + } + if m.FragID >= m.FragCount { + // wtf is this? + return nil + } + if m.PacketID != d.pktID || m.FragCount != uint8(len(d.frags)) { + // new message, clear previous state + d.pktID = m.PacketID + d.frags = make([]*UDPMessage, m.FragCount) + d.frags[m.FragID] = m + d.count = 1 + d.size = len(m.Data) + } else if d.frags[m.FragID] == nil { + d.frags[m.FragID] = m + d.count++ + d.size += len(m.Data) + if int(d.count) == len(d.frags) { + // all fragments received, assemble + data := make([]byte, d.size) + off := 0 + for _, frag := range d.frags { + off += copy(data[off:], frag.Data) + } + m.Data = data + m.FragID = 0 + m.FragCount = 1 + return m + } + } + return nil +} diff --git a/proxy/hysteria/protocol.go b/proxy/hysteria/protocol.go new file mode 100644 index 000000000000..ee4834a01db8 --- /dev/null +++ b/proxy/hysteria/protocol.go @@ -0,0 +1,204 @@ +package hysteria + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/apernet/quic-go/quicvarint" + "github.com/xtls/xray-core/common/errors" +) + +const ( + FrameTypeTCPRequest = 0x401 + + // Max length values are for preventing DoS attacks + + MaxAddressLength = 2048 + MaxMessageLength = 2048 + MaxPaddingLength = 4096 + + MaxUDPSize = 4096 + + maxVarInt1 = 63 + maxVarInt2 = 16383 + maxVarInt4 = 1073741823 + maxVarInt8 = 4611686018427387903 +) + +// TCPRequest format: +// 0x401 (QUIC varint) +// Address length (QUIC varint) +// Address (bytes) +// Padding length (QUIC varint) +// Padding (bytes) + +func WriteTCPRequest(w io.Writer, addr string) error { + padding := tcpRequestPadding.String() + paddingLen := len(padding) + addrLen := len(addr) + sz := int(quicvarint.Len(FrameTypeTCPRequest)) + + int(quicvarint.Len(uint64(addrLen))) + addrLen + + int(quicvarint.Len(uint64(paddingLen))) + paddingLen + buf := make([]byte, sz) + i := varintPut(buf, FrameTypeTCPRequest) + i += varintPut(buf[i:], uint64(addrLen)) + i += copy(buf[i:], addr) + i += varintPut(buf[i:], uint64(paddingLen)) + copy(buf[i:], padding) + _, err := w.Write(buf) + return err +} + +// TCPResponse format: +// Status (byte, 0=ok, 1=error) +// Message length (QUIC varint) +// Message (bytes) +// Padding length (QUIC varint) +// Padding (bytes) + +func ReadTCPResponse(r io.Reader) (bool, string, error) { + var status [1]byte + if _, err := io.ReadFull(r, status[:]); err != nil { + return false, "", err + } + bReader := quicvarint.NewReader(r) + msgLen, err := quicvarint.Read(bReader) + if err != nil { + return false, "", err + } + if msgLen > MaxMessageLength { + return false, "", errors.New("invalid message length") + } + var msgBuf []byte + // No message is fine + if msgLen > 0 { + msgBuf = make([]byte, msgLen) + _, err = io.ReadFull(r, msgBuf) + if err != nil { + return false, "", err + } + } + paddingLen, err := quicvarint.Read(bReader) + if err != nil { + return false, "", err + } + if paddingLen > MaxPaddingLength { + return false, "", errors.New("invalid padding length") + } + if paddingLen > 0 { + _, err = io.CopyN(io.Discard, r, int64(paddingLen)) + if err != nil { + return false, "", err + } + } + return status[0] == 0, string(msgBuf), nil +} + +// UDPMessage format: +// Session ID (uint32 BE) +// Packet ID (uint16 BE) +// Fragment ID (uint8) +// Fragment count (uint8) +// Address length (QUIC varint) +// Address (bytes) +// Data... + +type UDPMessage struct { + SessionID uint32 // 4 + PacketID uint16 // 2 + FragID uint8 // 1 + FragCount uint8 // 1 + Addr string // varint + bytes + Data []byte +} + +func (m *UDPMessage) HeaderSize() int { + lAddr := len(m.Addr) + return 4 + 2 + 1 + 1 + int(quicvarint.Len(uint64(lAddr))) + lAddr +} + +func (m *UDPMessage) Size() int { + return m.HeaderSize() + len(m.Data) +} + +func (m *UDPMessage) Serialize(buf []byte) int { + // Make sure the buffer is big enough + if len(buf) < m.Size() { + return -1 + } + // binary.BigEndian.PutUint32(buf, m.SessionID) + binary.BigEndian.PutUint16(buf[4:], m.PacketID) + buf[6] = m.FragID + buf[7] = m.FragCount + i := varintPut(buf[8:], uint64(len(m.Addr))) + i += copy(buf[8+i:], m.Addr) + i += copy(buf[8+i:], m.Data) + return 8 + i +} + +func ParseUDPMessage(msg []byte) (*UDPMessage, error) { + m := &UDPMessage{} + buf := bytes.NewBuffer(msg) + if err := binary.Read(buf, binary.BigEndian, &m.SessionID); err != nil { + return nil, err + } + if err := binary.Read(buf, binary.BigEndian, &m.PacketID); err != nil { + return nil, err + } + if err := binary.Read(buf, binary.BigEndian, &m.FragID); err != nil { + return nil, err + } + if err := binary.Read(buf, binary.BigEndian, &m.FragCount); err != nil { + return nil, err + } + lAddr, err := quicvarint.Read(buf) + if err != nil { + return nil, err + } + if lAddr == 0 || lAddr > MaxMessageLength { + return nil, errors.New("invalid address length") + } + bs := buf.Bytes() + if len(bs) <= int(lAddr) { + // We use <= instead of < here as we expect at least one byte of data after the address + return nil, errors.New("invalid message length") + } + m.Addr = string(bs[:lAddr]) + m.Data = bs[lAddr:] + return m, nil +} + +// varintPut is like quicvarint.Append, but instead of appending to a slice, +// it writes to a fixed-size buffer. Returns the number of bytes written. +func varintPut(b []byte, i uint64) int { + if i <= maxVarInt1 { + b[0] = uint8(i) + return 1 + } + if i <= maxVarInt2 { + b[0] = uint8(i>>8) | 0x40 + b[1] = uint8(i) + return 2 + } + if i <= maxVarInt4 { + b[0] = uint8(i>>24) | 0x80 + b[1] = uint8(i >> 16) + b[2] = uint8(i >> 8) + b[3] = uint8(i) + return 4 + } + if i <= maxVarInt8 { + b[0] = uint8(i>>56) | 0xc0 + b[1] = uint8(i >> 48) + b[2] = uint8(i >> 40) + b[3] = uint8(i >> 32) + b[4] = uint8(i >> 24) + b[5] = uint8(i >> 16) + b[6] = uint8(i >> 8) + b[7] = uint8(i) + return 8 + } + panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i)) +} diff --git a/transport/internet/config.go b/transport/internet/config.go index bd6dfac4fe39..0b6f93245f66 100644 --- a/transport/internet/config.go +++ b/transport/internet/config.go @@ -90,7 +90,7 @@ func (c *StreamConfig) GetEffectiveSecuritySettings() (interface{}, error) { } func (c *StreamConfig) HasSecuritySettings() bool { - return len(c.SecurityType) > 0 + return len(c.SecuritySettings) > 0 } func (c *ProxyConfig) HasTag() bool { @@ -130,7 +130,7 @@ func (s DomainStrategy) FallbackIP6() bool { } func (s DomainStrategy) GetDynamicStrategy(addrFamily net.AddressFamily) DomainStrategy { - if addrFamily.IsDomain(){ + if addrFamily.IsDomain() { return s } switch s { diff --git a/transport/internet/config.pb.go b/transport/internet/config.pb.go index 41ce3294c21b..89f83a1272d8 100644 --- a/transport/internet/config.pb.go +++ b/transport/internet/config.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.35.1 -// protoc v5.28.2 +// protoc-gen-go v1.36.10 +// protoc v6.33.1 // source: transport/internet/config.proto package internet @@ -13,6 +13,7 @@ import ( protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" sync "sync" + unsafe "unsafe" ) const ( @@ -209,14 +210,13 @@ func (SocketConfig_TProxyMode) EnumDescriptor() ([]byte, []int) { } type TransportConfig struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - + state protoimpl.MessageState `protogen:"open.v1"` // Transport protocol name. ProtocolName string `protobuf:"bytes,3,opt,name=protocol_name,json=protocolName,proto3" json:"protocol_name,omitempty"` // Specific transport protocol settings. - Settings *serial.TypedMessage `protobuf:"bytes,2,opt,name=settings,proto3" json:"settings,omitempty"` + Settings *serial.TypedMessage `protobuf:"bytes,2,opt,name=settings,proto3" json:"settings,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *TransportConfig) Reset() { @@ -264,12 +264,9 @@ func (x *TransportConfig) GetSettings() *serial.TypedMessage { } type StreamConfig struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - Address *net.IPOrDomain `protobuf:"bytes,8,opt,name=address,proto3" json:"address,omitempty"` - Port uint32 `protobuf:"varint,9,opt,name=port,proto3" json:"port,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + Address *net.IPOrDomain `protobuf:"bytes,8,opt,name=address,proto3" json:"address,omitempty"` + Port uint32 `protobuf:"varint,9,opt,name=port,proto3" json:"port,omitempty"` // Effective network. ProtocolName string `protobuf:"bytes,5,opt,name=protocol_name,json=protocolName,proto3" json:"protocol_name,omitempty"` TransportSettings []*TransportConfig `protobuf:"bytes,2,rep,name=transport_settings,json=transportSettings,proto3" json:"transport_settings,omitempty"` @@ -277,7 +274,11 @@ type StreamConfig struct { SecurityType string `protobuf:"bytes,3,opt,name=security_type,json=securityType,proto3" json:"security_type,omitempty"` // Transport security settings. They can be either TLS or REALITY. SecuritySettings []*serial.TypedMessage `protobuf:"bytes,4,rep,name=security_settings,json=securitySettings,proto3" json:"security_settings,omitempty"` + Udpmasks []*serial.TypedMessage `protobuf:"bytes,10,rep,name=udpmasks,proto3" json:"udpmasks,omitempty"` + Tcpmasks []*serial.TypedMessage `protobuf:"bytes,11,rep,name=tcpmasks,proto3" json:"tcpmasks,omitempty"` SocketSettings *SocketConfig `protobuf:"bytes,6,opt,name=socket_settings,json=socketSettings,proto3" json:"socket_settings,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *StreamConfig) Reset() { @@ -352,6 +353,20 @@ func (x *StreamConfig) GetSecuritySettings() []*serial.TypedMessage { return nil } +func (x *StreamConfig) GetUdpmasks() []*serial.TypedMessage { + if x != nil { + return x.Udpmasks + } + return nil +} + +func (x *StreamConfig) GetTcpmasks() []*serial.TypedMessage { + if x != nil { + return x.Tcpmasks + } + return nil +} + func (x *StreamConfig) GetSocketSettings() *SocketConfig { if x != nil { return x.SocketSettings @@ -360,12 +375,11 @@ func (x *StreamConfig) GetSocketSettings() *SocketConfig { } type ProxyConfig struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - Tag string `protobuf:"bytes,1,opt,name=tag,proto3" json:"tag,omitempty"` - TransportLayerProxy bool `protobuf:"varint,2,opt,name=transportLayerProxy,proto3" json:"transportLayerProxy,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + Tag string `protobuf:"bytes,1,opt,name=tag,proto3" json:"tag,omitempty"` + TransportLayerProxy bool `protobuf:"varint,2,opt,name=transportLayerProxy,proto3" json:"transportLayerProxy,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *ProxyConfig) Reset() { @@ -413,16 +427,15 @@ func (x *ProxyConfig) GetTransportLayerProxy() bool { } type CustomSockopt struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + System string `protobuf:"bytes,1,opt,name=system,proto3" json:"system,omitempty"` + Network string `protobuf:"bytes,2,opt,name=network,proto3" json:"network,omitempty"` + Level string `protobuf:"bytes,3,opt,name=level,proto3" json:"level,omitempty"` + Opt string `protobuf:"bytes,4,opt,name=opt,proto3" json:"opt,omitempty"` + Value string `protobuf:"bytes,5,opt,name=value,proto3" json:"value,omitempty"` + Type string `protobuf:"bytes,6,opt,name=type,proto3" json:"type,omitempty"` unknownFields protoimpl.UnknownFields - - System string `protobuf:"bytes,1,opt,name=system,proto3" json:"system,omitempty"` - Network string `protobuf:"bytes,2,opt,name=network,proto3" json:"network,omitempty"` - Level string `protobuf:"bytes,3,opt,name=level,proto3" json:"level,omitempty"` - Opt string `protobuf:"bytes,4,opt,name=opt,proto3" json:"opt,omitempty"` - Value string `protobuf:"bytes,5,opt,name=value,proto3" json:"value,omitempty"` - Type string `protobuf:"bytes,6,opt,name=type,proto3" json:"type,omitempty"` + sizeCache protoimpl.SizeCache } func (x *CustomSockopt) Reset() { @@ -499,10 +512,7 @@ func (x *CustomSockopt) GetType() string { // SocketConfig is options to be applied on network sockets. type SocketConfig struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - + state protoimpl.MessageState `protogen:"open.v1"` // Mark of the connection. If non-zero, the value will be set to SO_MARK. Mark int32 `protobuf:"varint,1,opt,name=mark,proto3" json:"mark,omitempty"` // TFO is the state of TFO settings. @@ -531,6 +541,8 @@ type SocketConfig struct { AddressPortStrategy AddressPortStrategy `protobuf:"varint,21,opt,name=address_port_strategy,json=addressPortStrategy,proto3,enum=xray.transport.internet.AddressPortStrategy" json:"address_port_strategy,omitempty"` HappyEyeballs *HappyEyeballsConfig `protobuf:"bytes,22,opt,name=happy_eyeballs,json=happyEyeballs,proto3" json:"happy_eyeballs,omitempty"` TrustedXForwardedFor []string `protobuf:"bytes,23,rep,name=trusted_x_forwarded_for,json=trustedXForwardedFor,proto3" json:"trusted_x_forwarded_for,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *SocketConfig) Reset() { @@ -725,14 +737,13 @@ func (x *SocketConfig) GetTrustedXForwardedFor() []string { } type HappyEyeballsConfig struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - PrioritizeIpv6 bool `protobuf:"varint,1,opt,name=prioritize_ipv6,json=prioritizeIpv6,proto3" json:"prioritize_ipv6,omitempty"` - Interleave uint32 `protobuf:"varint,2,opt,name=interleave,proto3" json:"interleave,omitempty"` - TryDelayMs uint64 `protobuf:"varint,3,opt,name=try_delayMs,json=tryDelayMs,proto3" json:"try_delayMs,omitempty"` - MaxConcurrentTry uint32 `protobuf:"varint,4,opt,name=max_concurrent_try,json=maxConcurrentTry,proto3" json:"max_concurrent_try,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + PrioritizeIpv6 bool `protobuf:"varint,1,opt,name=prioritize_ipv6,json=prioritizeIpv6,proto3" json:"prioritize_ipv6,omitempty"` + Interleave uint32 `protobuf:"varint,2,opt,name=interleave,proto3" json:"interleave,omitempty"` + TryDelayMs uint64 `protobuf:"varint,3,opt,name=try_delayMs,json=tryDelayMs,proto3" json:"try_delayMs,omitempty"` + MaxConcurrentTry uint32 `protobuf:"varint,4,opt,name=max_concurrent_try,json=maxConcurrentTry,proto3" json:"max_concurrent_try,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *HappyEyeballsConfig) Reset() { @@ -795,184 +806,106 @@ func (x *HappyEyeballsConfig) GetMaxConcurrentTry() uint32 { var File_transport_internet_config_proto protoreflect.FileDescriptor -var file_transport_internet_config_proto_rawDesc = []byte{ - 0x0a, 0x1f, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x69, 0x6e, 0x74, 0x65, - 0x72, 0x6e, 0x65, 0x74, 0x2f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x12, 0x17, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, - 0x74, 0x2e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x1a, 0x21, 0x63, 0x6f, 0x6d, 0x6d, - 0x6f, 0x6e, 0x2f, 0x73, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x2f, 0x74, 0x79, 0x70, 0x65, 0x64, 0x5f, - 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x18, 0x63, - 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2f, 0x6e, 0x65, 0x74, 0x2f, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, - 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x74, 0x0a, 0x0f, 0x54, 0x72, 0x61, 0x6e, 0x73, - 0x70, 0x6f, 0x72, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x23, 0x0a, 0x0d, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x0c, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x4e, 0x61, 0x6d, 0x65, 0x12, - 0x3c, 0x0a, 0x08, 0x73, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x20, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2e, - 0x73, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, - 0x61, 0x67, 0x65, 0x52, 0x08, 0x73, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x22, 0x9b, 0x03, - 0x0a, 0x0c, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x35, - 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x1b, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2e, 0x6e, 0x65, - 0x74, 0x2e, 0x49, 0x50, 0x4f, 0x72, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x52, 0x07, 0x61, 0x64, - 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x09, 0x20, - 0x01, 0x28, 0x0d, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x23, 0x0a, 0x0d, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x0c, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x57, - 0x0a, 0x12, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x5f, 0x73, 0x65, 0x74, 0x74, - 0x69, 0x6e, 0x67, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x28, 0x2e, 0x78, 0x72, 0x61, - 0x79, 0x2e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x69, 0x6e, 0x74, 0x65, - 0x72, 0x6e, 0x65, 0x74, 0x2e, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x43, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x52, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x53, - 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x12, 0x23, 0x0a, 0x0d, 0x73, 0x65, 0x63, 0x75, 0x72, - 0x69, 0x74, 0x79, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, - 0x73, 0x65, 0x63, 0x75, 0x72, 0x69, 0x74, 0x79, 0x54, 0x79, 0x70, 0x65, 0x12, 0x4d, 0x0a, 0x11, - 0x73, 0x65, 0x63, 0x75, 0x72, 0x69, 0x74, 0x79, 0x5f, 0x73, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, - 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x20, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x63, - 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2e, 0x73, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x2e, 0x54, 0x79, 0x70, - 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x52, 0x10, 0x73, 0x65, 0x63, 0x75, 0x72, - 0x69, 0x74, 0x79, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x12, 0x4e, 0x0a, 0x0f, 0x73, - 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x5f, 0x73, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x18, 0x06, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x25, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x74, 0x72, 0x61, 0x6e, - 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x2e, 0x53, - 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x73, 0x6f, 0x63, - 0x6b, 0x65, 0x74, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x22, 0x51, 0x0a, 0x0b, 0x50, - 0x72, 0x6f, 0x78, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10, 0x0a, 0x03, 0x74, 0x61, - 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x74, 0x61, 0x67, 0x12, 0x30, 0x0a, 0x13, - 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x4c, 0x61, 0x79, 0x65, 0x72, 0x50, 0x72, - 0x6f, 0x78, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x74, 0x72, 0x61, 0x6e, 0x73, - 0x70, 0x6f, 0x72, 0x74, 0x4c, 0x61, 0x79, 0x65, 0x72, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x22, 0x93, - 0x01, 0x0a, 0x0d, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x53, 0x6f, 0x63, 0x6b, 0x6f, 0x70, 0x74, - 0x12, 0x16, 0x0a, 0x06, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x06, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x12, 0x18, 0x0a, 0x07, 0x6e, 0x65, 0x74, 0x77, - 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6e, 0x65, 0x74, 0x77, 0x6f, - 0x72, 0x6b, 0x12, 0x14, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x10, 0x0a, 0x03, 0x6f, 0x70, 0x74, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6f, 0x70, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, - 0x6c, 0x75, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, - 0x12, 0x12, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, - 0x74, 0x79, 0x70, 0x65, 0x22, 0x89, 0x09, 0x0a, 0x0c, 0x53, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x6d, 0x61, 0x72, 0x6b, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x05, 0x52, 0x04, 0x6d, 0x61, 0x72, 0x6b, 0x12, 0x10, 0x0a, 0x03, 0x74, 0x66, 0x6f, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x74, 0x66, 0x6f, 0x12, 0x48, 0x0a, 0x06, 0x74, - 0x70, 0x72, 0x6f, 0x78, 0x79, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x30, 0x2e, 0x78, 0x72, - 0x61, 0x79, 0x2e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x69, 0x6e, 0x74, - 0x65, 0x72, 0x6e, 0x65, 0x74, 0x2e, 0x53, 0x6f, 0x63, 0x6b, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x2e, 0x54, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x6f, 0x64, 0x65, 0x52, 0x06, 0x74, - 0x70, 0x72, 0x6f, 0x78, 0x79, 0x12, 0x41, 0x0a, 0x1d, 0x72, 0x65, 0x63, 0x65, 0x69, 0x76, 0x65, - 0x5f, 0x6f, 0x72, 0x69, 0x67, 0x69, 0x6e, 0x61, 0x6c, 0x5f, 0x64, 0x65, 0x73, 0x74, 0x5f, 0x61, - 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1a, 0x72, 0x65, - 0x63, 0x65, 0x69, 0x76, 0x65, 0x4f, 0x72, 0x69, 0x67, 0x69, 0x6e, 0x61, 0x6c, 0x44, 0x65, 0x73, - 0x74, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x21, 0x0a, 0x0c, 0x62, 0x69, 0x6e, 0x64, - 0x5f, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0b, - 0x62, 0x69, 0x6e, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x1b, 0x0a, 0x09, 0x62, - 0x69, 0x6e, 0x64, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, - 0x62, 0x69, 0x6e, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x15, 0x61, 0x63, 0x63, 0x65, - 0x70, 0x74, 0x5f, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x5f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, - 0x6c, 0x18, 0x07, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x61, 0x63, 0x63, 0x65, 0x70, 0x74, 0x50, - 0x72, 0x6f, 0x78, 0x79, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x50, 0x0a, 0x0f, - 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x5f, 0x73, 0x74, 0x72, 0x61, 0x74, 0x65, 0x67, 0x79, 0x18, - 0x08, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x27, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x74, 0x72, 0x61, - 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x2e, - 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74, 0x65, 0x67, 0x79, 0x52, 0x0e, - 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74, 0x65, 0x67, 0x79, 0x12, 0x21, - 0x0a, 0x0c, 0x64, 0x69, 0x61, 0x6c, 0x65, 0x72, 0x5f, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x18, 0x09, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x69, 0x61, 0x6c, 0x65, 0x72, 0x50, 0x72, 0x6f, 0x78, - 0x79, 0x12, 0x35, 0x0a, 0x17, 0x74, 0x63, 0x70, 0x5f, 0x6b, 0x65, 0x65, 0x70, 0x5f, 0x61, 0x6c, - 0x69, 0x76, 0x65, 0x5f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x18, 0x0a, 0x20, 0x01, - 0x28, 0x05, 0x52, 0x14, 0x74, 0x63, 0x70, 0x4b, 0x65, 0x65, 0x70, 0x41, 0x6c, 0x69, 0x76, 0x65, - 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, 0x2d, 0x0a, 0x13, 0x74, 0x63, 0x70, 0x5f, - 0x6b, 0x65, 0x65, 0x70, 0x5f, 0x61, 0x6c, 0x69, 0x76, 0x65, 0x5f, 0x69, 0x64, 0x6c, 0x65, 0x18, - 0x0b, 0x20, 0x01, 0x28, 0x05, 0x52, 0x10, 0x74, 0x63, 0x70, 0x4b, 0x65, 0x65, 0x70, 0x41, 0x6c, - 0x69, 0x76, 0x65, 0x49, 0x64, 0x6c, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x74, 0x63, 0x70, 0x5f, 0x63, - 0x6f, 0x6e, 0x67, 0x65, 0x73, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x0d, 0x74, 0x63, 0x70, 0x43, 0x6f, 0x6e, 0x67, 0x65, 0x73, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1c, - 0x0a, 0x09, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x18, 0x0d, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x09, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12, 0x16, 0x0a, 0x06, - 0x76, 0x36, 0x6f, 0x6e, 0x6c, 0x79, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x76, 0x36, - 0x6f, 0x6e, 0x6c, 0x79, 0x12, 0x28, 0x0a, 0x10, 0x74, 0x63, 0x70, 0x5f, 0x77, 0x69, 0x6e, 0x64, - 0x6f, 0x77, 0x5f, 0x63, 0x6c, 0x61, 0x6d, 0x70, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0e, - 0x74, 0x63, 0x70, 0x57, 0x69, 0x6e, 0x64, 0x6f, 0x77, 0x43, 0x6c, 0x61, 0x6d, 0x70, 0x12, 0x28, - 0x0a, 0x10, 0x74, 0x63, 0x70, 0x5f, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x6f, - 0x75, 0x74, 0x18, 0x10, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0e, 0x74, 0x63, 0x70, 0x55, 0x73, 0x65, - 0x72, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x12, 0x1e, 0x0a, 0x0b, 0x74, 0x63, 0x70, 0x5f, - 0x6d, 0x61, 0x78, 0x5f, 0x73, 0x65, 0x67, 0x18, 0x11, 0x20, 0x01, 0x28, 0x05, 0x52, 0x09, 0x74, - 0x63, 0x70, 0x4d, 0x61, 0x78, 0x53, 0x65, 0x67, 0x12, 0x1c, 0x0a, 0x09, 0x70, 0x65, 0x6e, 0x65, - 0x74, 0x72, 0x61, 0x74, 0x65, 0x18, 0x12, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x70, 0x65, 0x6e, - 0x65, 0x74, 0x72, 0x61, 0x74, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x74, 0x63, 0x70, 0x5f, 0x6d, 0x70, - 0x74, 0x63, 0x70, 0x18, 0x13, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x74, 0x63, 0x70, 0x4d, 0x70, - 0x74, 0x63, 0x70, 0x12, 0x4c, 0x0a, 0x0d, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x53, 0x6f, 0x63, - 0x6b, 0x6f, 0x70, 0x74, 0x18, 0x14, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x26, 0x2e, 0x78, 0x72, 0x61, - 0x79, 0x2e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x69, 0x6e, 0x74, 0x65, - 0x72, 0x6e, 0x65, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x53, 0x6f, 0x63, 0x6b, 0x6f, - 0x70, 0x74, 0x52, 0x0d, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x53, 0x6f, 0x63, 0x6b, 0x6f, 0x70, - 0x74, 0x12, 0x60, 0x0a, 0x15, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x5f, 0x70, 0x6f, 0x72, - 0x74, 0x5f, 0x73, 0x74, 0x72, 0x61, 0x74, 0x65, 0x67, 0x79, 0x18, 0x15, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x2c, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, - 0x74, 0x2e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x2e, 0x41, 0x64, 0x64, 0x72, 0x65, - 0x73, 0x73, 0x50, 0x6f, 0x72, 0x74, 0x53, 0x74, 0x72, 0x61, 0x74, 0x65, 0x67, 0x79, 0x52, 0x13, - 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x50, 0x6f, 0x72, 0x74, 0x53, 0x74, 0x72, 0x61, 0x74, - 0x65, 0x67, 0x79, 0x12, 0x53, 0x0a, 0x0e, 0x68, 0x61, 0x70, 0x70, 0x79, 0x5f, 0x65, 0x79, 0x65, - 0x62, 0x61, 0x6c, 0x6c, 0x73, 0x18, 0x16, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x2c, 0x2e, 0x78, 0x72, - 0x61, 0x79, 0x2e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x69, 0x6e, 0x74, - 0x65, 0x72, 0x6e, 0x65, 0x74, 0x2e, 0x48, 0x61, 0x70, 0x70, 0x79, 0x45, 0x79, 0x65, 0x62, 0x61, - 0x6c, 0x6c, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0d, 0x68, 0x61, 0x70, 0x70, 0x79, - 0x45, 0x79, 0x65, 0x62, 0x61, 0x6c, 0x6c, 0x73, 0x12, 0x35, 0x0a, 0x17, 0x74, 0x72, 0x75, 0x73, - 0x74, 0x65, 0x64, 0x5f, 0x78, 0x5f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x64, 0x5f, - 0x66, 0x6f, 0x72, 0x18, 0x17, 0x20, 0x03, 0x28, 0x09, 0x52, 0x14, 0x74, 0x72, 0x75, 0x73, 0x74, - 0x65, 0x64, 0x58, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x64, 0x46, 0x6f, 0x72, 0x22, - 0x2f, 0x0a, 0x0a, 0x54, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x07, 0x0a, - 0x03, 0x4f, 0x66, 0x66, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x54, 0x50, 0x72, 0x6f, 0x78, 0x79, - 0x10, 0x01, 0x12, 0x0c, 0x0a, 0x08, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x10, 0x02, - 0x22, 0xad, 0x01, 0x0a, 0x13, 0x48, 0x61, 0x70, 0x70, 0x79, 0x45, 0x79, 0x65, 0x62, 0x61, 0x6c, - 0x6c, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x27, 0x0a, 0x0f, 0x70, 0x72, 0x69, 0x6f, - 0x72, 0x69, 0x74, 0x69, 0x7a, 0x65, 0x5f, 0x69, 0x70, 0x76, 0x36, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x0e, 0x70, 0x72, 0x69, 0x6f, 0x72, 0x69, 0x74, 0x69, 0x7a, 0x65, 0x49, 0x70, 0x76, - 0x36, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6c, 0x65, 0x61, 0x76, 0x65, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6c, 0x65, 0x61, 0x76, - 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x74, 0x72, 0x79, 0x5f, 0x64, 0x65, 0x6c, 0x61, 0x79, 0x4d, 0x73, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x04, 0x52, 0x0a, 0x74, 0x72, 0x79, 0x44, 0x65, 0x6c, 0x61, 0x79, - 0x4d, 0x73, 0x12, 0x2c, 0x0a, 0x12, 0x6d, 0x61, 0x78, 0x5f, 0x63, 0x6f, 0x6e, 0x63, 0x75, 0x72, - 0x72, 0x65, 0x6e, 0x74, 0x5f, 0x74, 0x72, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x10, - 0x6d, 0x61, 0x78, 0x43, 0x6f, 0x6e, 0x63, 0x75, 0x72, 0x72, 0x65, 0x6e, 0x74, 0x54, 0x72, 0x79, - 0x2a, 0xa9, 0x01, 0x0a, 0x0e, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x53, 0x74, 0x72, 0x61, 0x74, - 0x65, 0x67, 0x79, 0x12, 0x09, 0x0a, 0x05, 0x41, 0x53, 0x5f, 0x49, 0x53, 0x10, 0x00, 0x12, 0x0a, - 0x0a, 0x06, 0x55, 0x53, 0x45, 0x5f, 0x49, 0x50, 0x10, 0x01, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x53, - 0x45, 0x5f, 0x49, 0x50, 0x34, 0x10, 0x02, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x53, 0x45, 0x5f, 0x49, - 0x50, 0x36, 0x10, 0x03, 0x12, 0x0c, 0x0a, 0x08, 0x55, 0x53, 0x45, 0x5f, 0x49, 0x50, 0x34, 0x36, - 0x10, 0x04, 0x12, 0x0c, 0x0a, 0x08, 0x55, 0x53, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x34, 0x10, 0x05, - 0x12, 0x0c, 0x0a, 0x08, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x10, 0x06, 0x12, 0x0d, - 0x0a, 0x09, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x34, 0x10, 0x07, 0x12, 0x0d, 0x0a, - 0x09, 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x10, 0x08, 0x12, 0x0e, 0x0a, 0x0a, - 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x34, 0x36, 0x10, 0x09, 0x12, 0x0e, 0x0a, 0x0a, - 0x46, 0x4f, 0x52, 0x43, 0x45, 0x5f, 0x49, 0x50, 0x36, 0x34, 0x10, 0x0a, 0x2a, 0x97, 0x01, 0x0a, - 0x13, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x50, 0x6f, 0x72, 0x74, 0x53, 0x74, 0x72, 0x61, - 0x74, 0x65, 0x67, 0x79, 0x12, 0x08, 0x0a, 0x04, 0x4e, 0x6f, 0x6e, 0x65, 0x10, 0x00, 0x12, 0x0f, - 0x0a, 0x0b, 0x53, 0x72, 0x76, 0x50, 0x6f, 0x72, 0x74, 0x4f, 0x6e, 0x6c, 0x79, 0x10, 0x01, 0x12, - 0x12, 0x0a, 0x0e, 0x53, 0x72, 0x76, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x4f, 0x6e, 0x6c, - 0x79, 0x10, 0x02, 0x12, 0x15, 0x0a, 0x11, 0x53, 0x72, 0x76, 0x50, 0x6f, 0x72, 0x74, 0x41, 0x6e, - 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x10, 0x03, 0x12, 0x0f, 0x0a, 0x0b, 0x54, 0x78, - 0x74, 0x50, 0x6f, 0x72, 0x74, 0x4f, 0x6e, 0x6c, 0x79, 0x10, 0x04, 0x12, 0x12, 0x0a, 0x0e, 0x54, - 0x78, 0x74, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x4f, 0x6e, 0x6c, 0x79, 0x10, 0x05, 0x12, - 0x15, 0x0a, 0x11, 0x54, 0x78, 0x74, 0x50, 0x6f, 0x72, 0x74, 0x41, 0x6e, 0x64, 0x41, 0x64, 0x64, - 0x72, 0x65, 0x73, 0x73, 0x10, 0x06, 0x42, 0x67, 0x0a, 0x1b, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, - 0x61, 0x79, 0x2e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x69, 0x6e, 0x74, - 0x65, 0x72, 0x6e, 0x65, 0x74, 0x50, 0x01, 0x5a, 0x2c, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, - 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73, 0x2f, 0x78, 0x72, 0x61, 0x79, 0x2d, 0x63, 0x6f, - 0x72, 0x65, 0x2f, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x69, 0x6e, 0x74, - 0x65, 0x72, 0x6e, 0x65, 0x74, 0xaa, 0x02, 0x17, 0x58, 0x72, 0x61, 0x79, 0x2e, 0x54, 0x72, 0x61, - 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x62, - 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, -} +const file_transport_internet_config_proto_rawDesc = "" + + "\n" + + "\x1ftransport/internet/config.proto\x12\x17xray.transport.internet\x1a!common/serial/typed_message.proto\x1a\x18common/net/address.proto\"t\n" + + "\x0fTransportConfig\x12#\n" + + "\rprotocol_name\x18\x03 \x01(\tR\fprotocolName\x12<\n" + + "\bsettings\x18\x02 \x01(\v2 .xray.common.serial.TypedMessageR\bsettings\"\x97\x04\n" + + "\fStreamConfig\x125\n" + + "\aaddress\x18\b \x01(\v2\x1b.xray.common.net.IPOrDomainR\aaddress\x12\x12\n" + + "\x04port\x18\t \x01(\rR\x04port\x12#\n" + + "\rprotocol_name\x18\x05 \x01(\tR\fprotocolName\x12W\n" + + "\x12transport_settings\x18\x02 \x03(\v2(.xray.transport.internet.TransportConfigR\x11transportSettings\x12#\n" + + "\rsecurity_type\x18\x03 \x01(\tR\fsecurityType\x12M\n" + + "\x11security_settings\x18\x04 \x03(\v2 .xray.common.serial.TypedMessageR\x10securitySettings\x12<\n" + + "\budpmasks\x18\n" + + " \x03(\v2 .xray.common.serial.TypedMessageR\budpmasks\x12<\n" + + "\btcpmasks\x18\v \x03(\v2 .xray.common.serial.TypedMessageR\btcpmasks\x12N\n" + + "\x0fsocket_settings\x18\x06 \x01(\v2%.xray.transport.internet.SocketConfigR\x0esocketSettings\"Q\n" + + "\vProxyConfig\x12\x10\n" + + "\x03tag\x18\x01 \x01(\tR\x03tag\x120\n" + + "\x13transportLayerProxy\x18\x02 \x01(\bR\x13transportLayerProxy\"\x93\x01\n" + + "\rCustomSockopt\x12\x16\n" + + "\x06system\x18\x01 \x01(\tR\x06system\x12\x18\n" + + "\anetwork\x18\x02 \x01(\tR\anetwork\x12\x14\n" + + "\x05level\x18\x03 \x01(\tR\x05level\x12\x10\n" + + "\x03opt\x18\x04 \x01(\tR\x03opt\x12\x14\n" + + "\x05value\x18\x05 \x01(\tR\x05value\x12\x12\n" + + "\x04type\x18\x06 \x01(\tR\x04type\"\x89\t\n" + + "\fSocketConfig\x12\x12\n" + + "\x04mark\x18\x01 \x01(\x05R\x04mark\x12\x10\n" + + "\x03tfo\x18\x02 \x01(\x05R\x03tfo\x12H\n" + + "\x06tproxy\x18\x03 \x01(\x0e20.xray.transport.internet.SocketConfig.TProxyModeR\x06tproxy\x12A\n" + + "\x1dreceive_original_dest_address\x18\x04 \x01(\bR\x1areceiveOriginalDestAddress\x12!\n" + + "\fbind_address\x18\x05 \x01(\fR\vbindAddress\x12\x1b\n" + + "\tbind_port\x18\x06 \x01(\rR\bbindPort\x122\n" + + "\x15accept_proxy_protocol\x18\a \x01(\bR\x13acceptProxyProtocol\x12P\n" + + "\x0fdomain_strategy\x18\b \x01(\x0e2'.xray.transport.internet.DomainStrategyR\x0edomainStrategy\x12!\n" + + "\fdialer_proxy\x18\t \x01(\tR\vdialerProxy\x125\n" + + "\x17tcp_keep_alive_interval\x18\n" + + " \x01(\x05R\x14tcpKeepAliveInterval\x12-\n" + + "\x13tcp_keep_alive_idle\x18\v \x01(\x05R\x10tcpKeepAliveIdle\x12%\n" + + "\x0etcp_congestion\x18\f \x01(\tR\rtcpCongestion\x12\x1c\n" + + "\tinterface\x18\r \x01(\tR\tinterface\x12\x16\n" + + "\x06v6only\x18\x0e \x01(\bR\x06v6only\x12(\n" + + "\x10tcp_window_clamp\x18\x0f \x01(\x05R\x0etcpWindowClamp\x12(\n" + + "\x10tcp_user_timeout\x18\x10 \x01(\x05R\x0etcpUserTimeout\x12\x1e\n" + + "\vtcp_max_seg\x18\x11 \x01(\x05R\ttcpMaxSeg\x12\x1c\n" + + "\tpenetrate\x18\x12 \x01(\bR\tpenetrate\x12\x1b\n" + + "\ttcp_mptcp\x18\x13 \x01(\bR\btcpMptcp\x12L\n" + + "\rcustomSockopt\x18\x14 \x03(\v2&.xray.transport.internet.CustomSockoptR\rcustomSockopt\x12`\n" + + "\x15address_port_strategy\x18\x15 \x01(\x0e2,.xray.transport.internet.AddressPortStrategyR\x13addressPortStrategy\x12S\n" + + "\x0ehappy_eyeballs\x18\x16 \x01(\v2,.xray.transport.internet.HappyEyeballsConfigR\rhappyEyeballs\x125\n" + + "\x17trusted_x_forwarded_for\x18\x17 \x03(\tR\x14trustedXForwardedFor\"/\n" + + "\n" + + "TProxyMode\x12\a\n" + + "\x03Off\x10\x00\x12\n" + + "\n" + + "\x06TProxy\x10\x01\x12\f\n" + + "\bRedirect\x10\x02\"\xad\x01\n" + + "\x13HappyEyeballsConfig\x12'\n" + + "\x0fprioritize_ipv6\x18\x01 \x01(\bR\x0eprioritizeIpv6\x12\x1e\n" + + "\n" + + "interleave\x18\x02 \x01(\rR\n" + + "interleave\x12\x1f\n" + + "\vtry_delayMs\x18\x03 \x01(\x04R\n" + + "tryDelayMs\x12,\n" + + "\x12max_concurrent_try\x18\x04 \x01(\rR\x10maxConcurrentTry*\xa9\x01\n" + + "\x0eDomainStrategy\x12\t\n" + + "\x05AS_IS\x10\x00\x12\n" + + "\n" + + "\x06USE_IP\x10\x01\x12\v\n" + + "\aUSE_IP4\x10\x02\x12\v\n" + + "\aUSE_IP6\x10\x03\x12\f\n" + + "\bUSE_IP46\x10\x04\x12\f\n" + + "\bUSE_IP64\x10\x05\x12\f\n" + + "\bFORCE_IP\x10\x06\x12\r\n" + + "\tFORCE_IP4\x10\a\x12\r\n" + + "\tFORCE_IP6\x10\b\x12\x0e\n" + + "\n" + + "FORCE_IP46\x10\t\x12\x0e\n" + + "\n" + + "FORCE_IP64\x10\n" + + "*\x97\x01\n" + + "\x13AddressPortStrategy\x12\b\n" + + "\x04None\x10\x00\x12\x0f\n" + + "\vSrvPortOnly\x10\x01\x12\x12\n" + + "\x0eSrvAddressOnly\x10\x02\x12\x15\n" + + "\x11SrvPortAndAddress\x10\x03\x12\x0f\n" + + "\vTxtPortOnly\x10\x04\x12\x12\n" + + "\x0eTxtAddressOnly\x10\x05\x12\x15\n" + + "\x11TxtPortAndAddress\x10\x06Bg\n" + + "\x1bcom.xray.transport.internetP\x01Z,github.com/xtls/xray-core/transport/internet\xaa\x02\x17Xray.Transport.Internetb\x06proto3" var ( file_transport_internet_config_proto_rawDescOnce sync.Once - file_transport_internet_config_proto_rawDescData = file_transport_internet_config_proto_rawDesc + file_transport_internet_config_proto_rawDescData []byte ) func file_transport_internet_config_proto_rawDescGZIP() []byte { file_transport_internet_config_proto_rawDescOnce.Do(func() { - file_transport_internet_config_proto_rawDescData = protoimpl.X.CompressGZIP(file_transport_internet_config_proto_rawDescData) + file_transport_internet_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_config_proto_rawDesc), len(file_transport_internet_config_proto_rawDesc))) }) return file_transport_internet_config_proto_rawDescData } @@ -997,17 +930,19 @@ var file_transport_internet_config_proto_depIdxs = []int32{ 10, // 1: xray.transport.internet.StreamConfig.address:type_name -> xray.common.net.IPOrDomain 3, // 2: xray.transport.internet.StreamConfig.transport_settings:type_name -> xray.transport.internet.TransportConfig 9, // 3: xray.transport.internet.StreamConfig.security_settings:type_name -> xray.common.serial.TypedMessage - 7, // 4: xray.transport.internet.StreamConfig.socket_settings:type_name -> xray.transport.internet.SocketConfig - 2, // 5: xray.transport.internet.SocketConfig.tproxy:type_name -> xray.transport.internet.SocketConfig.TProxyMode - 0, // 6: xray.transport.internet.SocketConfig.domain_strategy:type_name -> xray.transport.internet.DomainStrategy - 6, // 7: xray.transport.internet.SocketConfig.customSockopt:type_name -> xray.transport.internet.CustomSockopt - 1, // 8: xray.transport.internet.SocketConfig.address_port_strategy:type_name -> xray.transport.internet.AddressPortStrategy - 8, // 9: xray.transport.internet.SocketConfig.happy_eyeballs:type_name -> xray.transport.internet.HappyEyeballsConfig - 10, // [10:10] is the sub-list for method output_type - 10, // [10:10] is the sub-list for method input_type - 10, // [10:10] is the sub-list for extension type_name - 10, // [10:10] is the sub-list for extension extendee - 0, // [0:10] is the sub-list for field type_name + 9, // 4: xray.transport.internet.StreamConfig.udpmasks:type_name -> xray.common.serial.TypedMessage + 9, // 5: xray.transport.internet.StreamConfig.tcpmasks:type_name -> xray.common.serial.TypedMessage + 7, // 6: xray.transport.internet.StreamConfig.socket_settings:type_name -> xray.transport.internet.SocketConfig + 2, // 7: xray.transport.internet.SocketConfig.tproxy:type_name -> xray.transport.internet.SocketConfig.TProxyMode + 0, // 8: xray.transport.internet.SocketConfig.domain_strategy:type_name -> xray.transport.internet.DomainStrategy + 6, // 9: xray.transport.internet.SocketConfig.customSockopt:type_name -> xray.transport.internet.CustomSockopt + 1, // 10: xray.transport.internet.SocketConfig.address_port_strategy:type_name -> xray.transport.internet.AddressPortStrategy + 8, // 11: xray.transport.internet.SocketConfig.happy_eyeballs:type_name -> xray.transport.internet.HappyEyeballsConfig + 12, // [12:12] is the sub-list for method output_type + 12, // [12:12] is the sub-list for method input_type + 12, // [12:12] is the sub-list for extension type_name + 12, // [12:12] is the sub-list for extension extendee + 0, // [0:12] is the sub-list for field type_name } func init() { file_transport_internet_config_proto_init() } @@ -1019,7 +954,7 @@ func file_transport_internet_config_proto_init() { out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_transport_internet_config_proto_rawDesc, + RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_config_proto_rawDesc), len(file_transport_internet_config_proto_rawDesc)), NumEnums: 3, NumMessages: 6, NumExtensions: 0, @@ -1031,7 +966,6 @@ func file_transport_internet_config_proto_init() { MessageInfos: file_transport_internet_config_proto_msgTypes, }.Build() File_transport_internet_config_proto = out.File - file_transport_internet_config_proto_rawDesc = nil file_transport_internet_config_proto_goTypes = nil file_transport_internet_config_proto_depIdxs = nil } diff --git a/transport/internet/config.proto b/transport/internet/config.proto index e1655fdd6f5c..8b1bb23e9fec 100644 --- a/transport/internet/config.proto +++ b/transport/internet/config.proto @@ -56,6 +56,9 @@ message StreamConfig { // Transport security settings. They can be either TLS or REALITY. repeated xray.common.serial.TypedMessage security_settings = 4; + repeated xray.common.serial.TypedMessage udpmasks = 10; + repeated xray.common.serial.TypedMessage tcpmasks = 11; + SocketConfig socket_settings = 6; } diff --git a/transport/internet/finalmask/finalmask.go b/transport/internet/finalmask/finalmask.go new file mode 100644 index 000000000000..7ce4d4f329a3 --- /dev/null +++ b/transport/internet/finalmask/finalmask.go @@ -0,0 +1,127 @@ +package finalmask + +import ( + "net" +) + +type Udpmask interface { + UDP() + + WrapConnClient(net.Conn) (net.Conn, error) + WrapConnServer(net.Conn) (net.Conn, error) + + WrapPacketConnClient(net.PacketConn) (net.PacketConn, error) + WrapPacketConnServer(net.PacketConn) (net.PacketConn, error) + + Size() int + Serialize([]byte) +} + +type UdpmaskManager struct { + udpmasks []Udpmask +} + +func NewUdpmaskManager(udpmasks []Udpmask) *UdpmaskManager { + return &UdpmaskManager{ + udpmasks: udpmasks, + } +} + +func (m *UdpmaskManager) WrapConnClient(raw net.Conn) (net.Conn, error) { + var err error + for _, mask := range m.udpmasks { + raw, err = mask.WrapConnClient(raw) + if err != nil { + return nil, err + } + } + return raw, nil +} + +func (m *UdpmaskManager) WrapConnServer(raw net.Conn) (net.Conn, error) { + var err error + for _, mask := range m.udpmasks { + raw, err = mask.WrapConnServer(raw) + if err != nil { + return nil, err + } + } + return raw, nil +} + +func (m *UdpmaskManager) WrapPacketConnClient(raw net.PacketConn) (net.PacketConn, error) { + var err error + for _, mask := range m.udpmasks { + raw, err = mask.WrapPacketConnClient(raw) + if err != nil { + return nil, err + } + } + return raw, nil +} + +func (m *UdpmaskManager) WrapPacketConnServer(raw net.PacketConn) (net.PacketConn, error) { + var err error + for _, mask := range m.udpmasks { + raw, err = mask.WrapPacketConnServer(raw) + if err != nil { + return nil, err + } + } + return raw, nil +} + +func (m *UdpmaskManager) Size() int { + size := 0 + for _, mask := range m.udpmasks { + size += mask.Size() + } + return size +} + +func (m *UdpmaskManager) Serialize(b []byte) { + index := 0 + for _, mask := range m.udpmasks { + mask.Serialize(b[index:]) + index += mask.Size() + } +} + +type Tcpmask interface { + TCP() + + WrapConnClient(net.Conn) (net.Conn, error) + WrapConnServer(net.Conn) (net.Conn, error) +} + +type TcpmaskManager struct { + tcpmasks []Tcpmask +} + +func NewTcpmaskManager(tcpmasks []Tcpmask) *TcpmaskManager { + return &TcpmaskManager{ + tcpmasks: tcpmasks, + } +} + +func (m *TcpmaskManager) WrapConnClient(raw net.Conn) (net.Conn, error) { + var err error + for _, mask := range m.tcpmasks { + raw, err = mask.WrapConnClient(raw) + if err != nil { + return nil, err + } + } + return raw, nil +} + +func (m *TcpmaskManager) WrapConnServer(raw net.Conn) (net.Conn, error) { + var err error + for _, mask := range m.tcpmasks { + raw, err = mask.WrapConnServer(raw) + if err != nil { + return nil, err + } + } + return raw, nil +} diff --git a/transport/internet/finalmask/salamander/config.go b/transport/internet/finalmask/salamander/config.go new file mode 100644 index 000000000000..aca35b4dd6e2 --- /dev/null +++ b/transport/internet/finalmask/salamander/config.go @@ -0,0 +1,42 @@ +package salamander + +import ( + "net" + + "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/transport/internet/finalmask/salamander/obfs" +) + +func (c *Config) UDP() { +} + +func (c *Config) WrapConnClient(raw net.Conn) (net.Conn, error) { + return raw, nil +} + +func (c *Config) WrapConnServer(raw net.Conn) (net.Conn, error) { + return raw, nil +} + +func (c *Config) WrapPacketConnClient(raw net.PacketConn) (net.PacketConn, error) { + ob, err := obfs.NewSalamanderObfuscator([]byte(c.Password)) + if err != nil { + return nil, errors.New("salamander err").Base(err) + } + return obfs.WrapPacketConn(raw, ob), nil +} + +func (c *Config) WrapPacketConnServer(raw net.PacketConn) (net.PacketConn, error) { + ob, err := obfs.NewSalamanderObfuscator([]byte(c.Password)) + if err != nil { + return nil, errors.New("salamander err").Base(err) + } + return obfs.WrapPacketConn(raw, ob), nil +} + +func (c *Config) Size() int { + return 0 +} + +func (c *Config) Serialize([]byte) { +} diff --git a/transport/internet/finalmask/salamander/config.pb.go b/transport/internet/finalmask/salamander/config.pb.go new file mode 100644 index 000000000000..7d572b422aea --- /dev/null +++ b/transport/internet/finalmask/salamander/config.pb.go @@ -0,0 +1,123 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.10 +// protoc v6.33.1 +// source: transport/internet/udpmask/salamander/config.proto + +package salamander + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Config struct { + state protoimpl.MessageState `protogen:"open.v1"` + Password string `protobuf:"bytes,1,opt,name=password,proto3" json:"password,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Config) Reset() { + *x = Config{} + mi := &file_transport_internet_udpmask_salamander_config_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Config) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Config) ProtoMessage() {} + +func (x *Config) ProtoReflect() protoreflect.Message { + mi := &file_transport_internet_udpmask_salamander_config_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Config.ProtoReflect.Descriptor instead. +func (*Config) Descriptor() ([]byte, []int) { + return file_transport_internet_udpmask_salamander_config_proto_rawDescGZIP(), []int{0} +} + +func (x *Config) GetPassword() string { + if x != nil { + return x.Password + } + return "" +} + +var File_transport_internet_udpmask_salamander_config_proto protoreflect.FileDescriptor + +const file_transport_internet_udpmask_salamander_config_proto_rawDesc = "" + + "\n" + + "2transport/internet/udpmask/salamander/config.proto\x12*xray.transport.internet.udpmask.salamander\"$\n" + + "\x06Config\x12\x1a\n" + + "\bpassword\x18\x01 \x01(\tR\bpasswordB\xa0\x01\n" + + ".com.xray.transport.internet.udpmask.salamanderP\x01Z?github.com/xtls/xray-core/transport/internet/udpmask/salamander\xaa\x02*Xray.Transport.Internet.Udpmask.Salamanderb\x06proto3" + +var ( + file_transport_internet_udpmask_salamander_config_proto_rawDescOnce sync.Once + file_transport_internet_udpmask_salamander_config_proto_rawDescData []byte +) + +func file_transport_internet_udpmask_salamander_config_proto_rawDescGZIP() []byte { + file_transport_internet_udpmask_salamander_config_proto_rawDescOnce.Do(func() { + file_transport_internet_udpmask_salamander_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_udpmask_salamander_config_proto_rawDesc), len(file_transport_internet_udpmask_salamander_config_proto_rawDesc))) + }) + return file_transport_internet_udpmask_salamander_config_proto_rawDescData +} + +var file_transport_internet_udpmask_salamander_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_transport_internet_udpmask_salamander_config_proto_goTypes = []any{ + (*Config)(nil), // 0: xray.transport.internet.udpmask.salamander.Config +} +var file_transport_internet_udpmask_salamander_config_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_transport_internet_udpmask_salamander_config_proto_init() } +func file_transport_internet_udpmask_salamander_config_proto_init() { + if File_transport_internet_udpmask_salamander_config_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_udpmask_salamander_config_proto_rawDesc), len(file_transport_internet_udpmask_salamander_config_proto_rawDesc)), + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_transport_internet_udpmask_salamander_config_proto_goTypes, + DependencyIndexes: file_transport_internet_udpmask_salamander_config_proto_depIdxs, + MessageInfos: file_transport_internet_udpmask_salamander_config_proto_msgTypes, + }.Build() + File_transport_internet_udpmask_salamander_config_proto = out.File + file_transport_internet_udpmask_salamander_config_proto_goTypes = nil + file_transport_internet_udpmask_salamander_config_proto_depIdxs = nil +} diff --git a/transport/internet/finalmask/salamander/config.proto b/transport/internet/finalmask/salamander/config.proto new file mode 100644 index 000000000000..bed943ab4428 --- /dev/null +++ b/transport/internet/finalmask/salamander/config.proto @@ -0,0 +1,12 @@ +syntax = "proto3"; + +package xray.transport.internet.udpmask.salamander; +option csharp_namespace = "Xray.Transport.Internet.Udpmask.Salamander"; +option go_package = "github.com/xtls/xray-core/transport/internet/udpmask/salamander"; +option java_package = "com.xray.transport.internet.udpmask.salamander"; +option java_multiple_files = true; + +message Config { + string password = 1; +} + diff --git a/transport/internet/finalmask/salamander/obfs/conn.go b/transport/internet/finalmask/salamander/obfs/conn.go new file mode 100644 index 000000000000..6b97592eba21 --- /dev/null +++ b/transport/internet/finalmask/salamander/obfs/conn.go @@ -0,0 +1,121 @@ +package obfs + +import ( + "net" + "sync" + "syscall" + "time" +) + +const udpBufferSize = 2048 // QUIC packets are at most 1500 bytes long, so 2k should be more than enough + +// Obfuscator is the interface that wraps the Obfuscate and Deobfuscate methods. +// Both methods return the number of bytes written to out. +// If a packet is not valid, the methods should return 0. +type Obfuscator interface { + Obfuscate(in, out []byte) int + Deobfuscate(in, out []byte) int +} + +var _ net.PacketConn = (*obfsPacketConn)(nil) + +type obfsPacketConn struct { + Conn net.PacketConn + Obfs Obfuscator + + readBuf []byte + readMutex sync.Mutex + writeBuf []byte + writeMutex sync.Mutex +} + +// obfsPacketConnUDP is a special case of obfsPacketConn that uses a UDPConn +// as the underlying connection. We pass additional methods to quic-go to +// enable UDP-specific optimizations. +type obfsPacketConnUDP struct { + *obfsPacketConn + UDPConn *net.UDPConn +} + +// WrapPacketConn enables obfuscation on a net.PacketConn. +// The obfuscation is transparent to the caller - the n bytes returned by +// ReadFrom and WriteTo are the number of original bytes, not after +// obfuscation/deobfuscation. +func WrapPacketConn(conn net.PacketConn, obfs Obfuscator) net.PacketConn { + opc := &obfsPacketConn{ + Conn: conn, + Obfs: obfs, + readBuf: make([]byte, udpBufferSize), + writeBuf: make([]byte, udpBufferSize), + } + if udpConn, ok := conn.(*net.UDPConn); ok { + return &obfsPacketConnUDP{ + obfsPacketConn: opc, + UDPConn: udpConn, + } + } else { + return opc + } +} + +func (c *obfsPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + for { + c.readMutex.Lock() + n, addr, err = c.Conn.ReadFrom(c.readBuf) + if n <= 0 { + c.readMutex.Unlock() + return n, addr, err + } + n = c.Obfs.Deobfuscate(c.readBuf[:n], p) + c.readMutex.Unlock() + if n > 0 || err != nil { + return n, addr, err + } + // Invalid packet, try again + } +} + +func (c *obfsPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + c.writeMutex.Lock() + nn := c.Obfs.Obfuscate(p, c.writeBuf) + _, err = c.Conn.WriteTo(c.writeBuf[:nn], addr) + c.writeMutex.Unlock() + if err == nil { + n = len(p) + } + return n, err +} + +func (c *obfsPacketConn) Close() error { + return c.Conn.Close() +} + +func (c *obfsPacketConn) LocalAddr() net.Addr { + return c.Conn.LocalAddr() +} + +func (c *obfsPacketConn) SetDeadline(t time.Time) error { + return c.Conn.SetDeadline(t) +} + +func (c *obfsPacketConn) SetReadDeadline(t time.Time) error { + return c.Conn.SetReadDeadline(t) +} + +func (c *obfsPacketConn) SetWriteDeadline(t time.Time) error { + return c.Conn.SetWriteDeadline(t) +} + +// UDP-specific methods below + +func (c *obfsPacketConnUDP) SetReadBuffer(bytes int) error { + return c.UDPConn.SetReadBuffer(bytes) +} + +func (c *obfsPacketConnUDP) SetWriteBuffer(bytes int) error { + return c.UDPConn.SetWriteBuffer(bytes) +} + +func (c *obfsPacketConnUDP) SyscallConn() (syscall.RawConn, error) { + return c.UDPConn.SyscallConn() +} diff --git a/transport/internet/finalmask/salamander/obfs/salamander.go b/transport/internet/finalmask/salamander/obfs/salamander.go new file mode 100644 index 000000000000..50a3ce26307d --- /dev/null +++ b/transport/internet/finalmask/salamander/obfs/salamander.go @@ -0,0 +1,71 @@ +package obfs + +import ( + "fmt" + "math/rand" + "sync" + "time" + + "golang.org/x/crypto/blake2b" +) + +const ( + smPSKMinLen = 4 + smSaltLen = 8 + smKeyLen = blake2b.Size256 +) + +var _ Obfuscator = (*SalamanderObfuscator)(nil) + +var ErrPSKTooShort = fmt.Errorf("PSK must be at least %d bytes", smPSKMinLen) + +// SalamanderObfuscator is an obfuscator that obfuscates each packet with +// the BLAKE2b-256 hash of a pre-shared key combined with a random salt. +// Packet format: [8-byte salt][payload] +type SalamanderObfuscator struct { + PSK []byte + RandSrc *rand.Rand + + lk sync.Mutex +} + +func NewSalamanderObfuscator(psk []byte) (*SalamanderObfuscator, error) { + if len(psk) < smPSKMinLen { + return nil, ErrPSKTooShort + } + return &SalamanderObfuscator{ + PSK: psk, + RandSrc: rand.New(rand.NewSource(time.Now().UnixNano())), + }, nil +} + +func (o *SalamanderObfuscator) Obfuscate(in, out []byte) int { + outLen := len(in) + smSaltLen + if len(out) < outLen { + return 0 + } + o.lk.Lock() + _, _ = o.RandSrc.Read(out[:smSaltLen]) + o.lk.Unlock() + key := o.key(out[:smSaltLen]) + for i, c := range in { + out[i+smSaltLen] = c ^ key[i%smKeyLen] + } + return outLen +} + +func (o *SalamanderObfuscator) Deobfuscate(in, out []byte) int { + outLen := len(in) - smSaltLen + if outLen <= 0 || len(out) < outLen { + return 0 + } + key := o.key(in[:smSaltLen]) + for i, c := range in[smSaltLen:] { + out[i] = c ^ key[i%smKeyLen] + } + return outLen +} + +func (o *SalamanderObfuscator) key(salt []byte) [smKeyLen]byte { + return blake2b.Sum256(append(o.PSK, salt...)) +} diff --git a/transport/internet/finalmask/salamander/obfs/salamander_test.go b/transport/internet/finalmask/salamander/obfs/salamander_test.go new file mode 100644 index 000000000000..85eafdcce6d4 --- /dev/null +++ b/transport/internet/finalmask/salamander/obfs/salamander_test.go @@ -0,0 +1,45 @@ +package obfs + +import ( + "crypto/rand" + "testing" + + "github.com/stretchr/testify/assert" +) + +func BenchmarkSalamanderObfuscator_Obfuscate(b *testing.B) { + o, _ := NewSalamanderObfuscator([]byte("average_password")) + in := make([]byte, 1200) + _, _ = rand.Read(in) + out := make([]byte, 2048) + b.ResetTimer() + for i := 0; i < b.N; i++ { + o.Obfuscate(in, out) + } +} + +func BenchmarkSalamanderObfuscator_Deobfuscate(b *testing.B) { + o, _ := NewSalamanderObfuscator([]byte("average_password")) + in := make([]byte, 1200) + _, _ = rand.Read(in) + out := make([]byte, 2048) + b.ResetTimer() + for i := 0; i < b.N; i++ { + o.Deobfuscate(in, out) + } +} + +func TestSalamanderObfuscator(t *testing.T) { + o, _ := NewSalamanderObfuscator([]byte("average_password")) + in := make([]byte, 1200) + oOut := make([]byte, 2048) + dOut := make([]byte, 2048) + for i := 0; i < 1000; i++ { + _, _ = rand.Read(in) + n := o.Obfuscate(in, oOut) + assert.Equal(t, len(in)+smSaltLen, n) + n = o.Deobfuscate(oOut[:n], dOut) + assert.Equal(t, len(in), n) + assert.Equal(t, in, dOut[:n]) + } +} diff --git a/transport/internet/hysteria/config.go b/transport/internet/hysteria/config.go new file mode 100644 index 000000000000..fd7a4bb4af08 --- /dev/null +++ b/transport/internet/hysteria/config.go @@ -0,0 +1,47 @@ +package hysteria + +import ( + "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/transport/internet" + "github.com/xtls/xray-core/transport/internet/hysteria/padding" +) + +const ( + closeErrCodeOK = 0x100 // HTTP3 ErrCodeNoError + closeErrCodeProtocolError = 0x101 // HTTP3 ErrCodeGeneralProtocolError + + MaxDatagramFrameSize = 1200 + + URLHost = "hysteria" + URLPath = "/auth" + + RequestHeaderAuth = "Hysteria-Auth" + ResponseHeaderUDPEnabled = "Hysteria-UDP" + CommonHeaderCCRX = "Hysteria-CC-RX" + CommonHeaderPadding = "Hysteria-Padding" + + StatusAuthOK = 233 + + udpMessageChanSize = 1024 +) + +var ( + authRequestPadding = padding.Padding{Min: 256, Max: 2048} + // authResponsePadding = padding.Padding{Min: 256, Max: 2048} +) + +type Status int + +const ( + StatusUnknown Status = iota + StatusActive + StatusInactive +) + +const protocolName = "hysteria" + +func init() { + common.Must(internet.RegisterProtocolConfigCreator(protocolName, func() interface{} { + return new(Config) + })) +} diff --git a/transport/internet/hysteria/config.pb.go b/transport/internet/hysteria/config.pb.go new file mode 100644 index 000000000000..5e453c3005bb --- /dev/null +++ b/transport/internet/hysteria/config.pb.go @@ -0,0 +1,232 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.10 +// protoc v6.33.1 +// source: transport/internet/hysteria/config.proto + +package hysteria + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Config struct { + state protoimpl.MessageState `protogen:"open.v1"` + Version int32 `protobuf:"varint,1,opt,name=version,proto3" json:"version,omitempty"` + Auth string `protobuf:"bytes,2,opt,name=auth,proto3" json:"auth,omitempty"` + Up uint64 `protobuf:"varint,3,opt,name=up,proto3" json:"up,omitempty"` + Down uint64 `protobuf:"varint,4,opt,name=down,proto3" json:"down,omitempty"` + Ports []uint32 `protobuf:"varint,5,rep,packed,name=ports,proto3" json:"ports,omitempty"` + Interval int64 `protobuf:"varint,6,opt,name=interval,proto3" json:"interval,omitempty"` + InitStreamReceiveWindow uint64 `protobuf:"varint,7,opt,name=init_stream_receive_window,json=initStreamReceiveWindow,proto3" json:"init_stream_receive_window,omitempty"` + MaxStreamReceiveWindow uint64 `protobuf:"varint,8,opt,name=max_stream_receive_window,json=maxStreamReceiveWindow,proto3" json:"max_stream_receive_window,omitempty"` + InitConnReceiveWindow uint64 `protobuf:"varint,9,opt,name=init_conn_receive_window,json=initConnReceiveWindow,proto3" json:"init_conn_receive_window,omitempty"` + MaxConnReceiveWindow uint64 `protobuf:"varint,10,opt,name=max_conn_receive_window,json=maxConnReceiveWindow,proto3" json:"max_conn_receive_window,omitempty"` + MaxIdleTimeout int64 `protobuf:"varint,11,opt,name=max_idle_timeout,json=maxIdleTimeout,proto3" json:"max_idle_timeout,omitempty"` + KeepAlivePeriod int64 `protobuf:"varint,12,opt,name=keep_alive_period,json=keepAlivePeriod,proto3" json:"keep_alive_period,omitempty"` + DisablePathMtuDiscovery bool `protobuf:"varint,13,opt,name=disable_path_mtu_discovery,json=disablePathMtuDiscovery,proto3" json:"disable_path_mtu_discovery,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Config) Reset() { + *x = Config{} + mi := &file_transport_internet_hysteria_config_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Config) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Config) ProtoMessage() {} + +func (x *Config) ProtoReflect() protoreflect.Message { + mi := &file_transport_internet_hysteria_config_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Config.ProtoReflect.Descriptor instead. +func (*Config) Descriptor() ([]byte, []int) { + return file_transport_internet_hysteria_config_proto_rawDescGZIP(), []int{0} +} + +func (x *Config) GetVersion() int32 { + if x != nil { + return x.Version + } + return 0 +} + +func (x *Config) GetAuth() string { + if x != nil { + return x.Auth + } + return "" +} + +func (x *Config) GetUp() uint64 { + if x != nil { + return x.Up + } + return 0 +} + +func (x *Config) GetDown() uint64 { + if x != nil { + return x.Down + } + return 0 +} + +func (x *Config) GetPorts() []uint32 { + if x != nil { + return x.Ports + } + return nil +} + +func (x *Config) GetInterval() int64 { + if x != nil { + return x.Interval + } + return 0 +} + +func (x *Config) GetInitStreamReceiveWindow() uint64 { + if x != nil { + return x.InitStreamReceiveWindow + } + return 0 +} + +func (x *Config) GetMaxStreamReceiveWindow() uint64 { + if x != nil { + return x.MaxStreamReceiveWindow + } + return 0 +} + +func (x *Config) GetInitConnReceiveWindow() uint64 { + if x != nil { + return x.InitConnReceiveWindow + } + return 0 +} + +func (x *Config) GetMaxConnReceiveWindow() uint64 { + if x != nil { + return x.MaxConnReceiveWindow + } + return 0 +} + +func (x *Config) GetMaxIdleTimeout() int64 { + if x != nil { + return x.MaxIdleTimeout + } + return 0 +} + +func (x *Config) GetKeepAlivePeriod() int64 { + if x != nil { + return x.KeepAlivePeriod + } + return 0 +} + +func (x *Config) GetDisablePathMtuDiscovery() bool { + if x != nil { + return x.DisablePathMtuDiscovery + } + return false +} + +var File_transport_internet_hysteria_config_proto protoreflect.FileDescriptor + +const file_transport_internet_hysteria_config_proto_rawDesc = "" + + "\n" + + "(transport/internet/hysteria/config.proto\x12 xray.transport.internet.hysteria\"\x87\x04\n" + + "\x06Config\x12\x18\n" + + "\aversion\x18\x01 \x01(\x05R\aversion\x12\x12\n" + + "\x04auth\x18\x02 \x01(\tR\x04auth\x12\x0e\n" + + "\x02up\x18\x03 \x01(\x04R\x02up\x12\x12\n" + + "\x04down\x18\x04 \x01(\x04R\x04down\x12\x14\n" + + "\x05ports\x18\x05 \x03(\rR\x05ports\x12\x1a\n" + + "\binterval\x18\x06 \x01(\x03R\binterval\x12;\n" + + "\x1ainit_stream_receive_window\x18\a \x01(\x04R\x17initStreamReceiveWindow\x129\n" + + "\x19max_stream_receive_window\x18\b \x01(\x04R\x16maxStreamReceiveWindow\x127\n" + + "\x18init_conn_receive_window\x18\t \x01(\x04R\x15initConnReceiveWindow\x125\n" + + "\x17max_conn_receive_window\x18\n" + + " \x01(\x04R\x14maxConnReceiveWindow\x12(\n" + + "\x10max_idle_timeout\x18\v \x01(\x03R\x0emaxIdleTimeout\x12*\n" + + "\x11keep_alive_period\x18\f \x01(\x03R\x0fkeepAlivePeriod\x12;\n" + + "\x1adisable_path_mtu_discovery\x18\r \x01(\bR\x17disablePathMtuDiscoveryB\x82\x01\n" + + "$com.xray.transport.internet.hysteriaP\x01Z5github.com/xtls/xray-core/transport/internet/hysteria\xaa\x02 Xray.Transport.Internet.Hysteriab\x06proto3" + +var ( + file_transport_internet_hysteria_config_proto_rawDescOnce sync.Once + file_transport_internet_hysteria_config_proto_rawDescData []byte +) + +func file_transport_internet_hysteria_config_proto_rawDescGZIP() []byte { + file_transport_internet_hysteria_config_proto_rawDescOnce.Do(func() { + file_transport_internet_hysteria_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_hysteria_config_proto_rawDesc), len(file_transport_internet_hysteria_config_proto_rawDesc))) + }) + return file_transport_internet_hysteria_config_proto_rawDescData +} + +var file_transport_internet_hysteria_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_transport_internet_hysteria_config_proto_goTypes = []any{ + (*Config)(nil), // 0: xray.transport.internet.hysteria.Config +} +var file_transport_internet_hysteria_config_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_transport_internet_hysteria_config_proto_init() } +func file_transport_internet_hysteria_config_proto_init() { + if File_transport_internet_hysteria_config_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_hysteria_config_proto_rawDesc), len(file_transport_internet_hysteria_config_proto_rawDesc)), + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_transport_internet_hysteria_config_proto_goTypes, + DependencyIndexes: file_transport_internet_hysteria_config_proto_depIdxs, + MessageInfos: file_transport_internet_hysteria_config_proto_msgTypes, + }.Build() + File_transport_internet_hysteria_config_proto = out.File + file_transport_internet_hysteria_config_proto_goTypes = nil + file_transport_internet_hysteria_config_proto_depIdxs = nil +} diff --git a/transport/internet/hysteria/config.proto b/transport/internet/hysteria/config.proto new file mode 100644 index 000000000000..221133380d05 --- /dev/null +++ b/transport/internet/hysteria/config.proto @@ -0,0 +1,25 @@ +syntax = "proto3"; + +package xray.transport.internet.hysteria; +option csharp_namespace = "Xray.Transport.Internet.Hysteria"; +option go_package = "github.com/xtls/xray-core/transport/internet/hysteria"; +option java_package = "com.xray.transport.internet.hysteria"; +option java_multiple_files = true; + +message Config { + int32 version = 1; + string auth = 2; + uint64 up = 3; + uint64 down = 4; + repeated uint32 ports = 5; + int64 interval = 6; + + uint64 init_stream_receive_window = 7; + uint64 max_stream_receive_window = 8; + uint64 init_conn_receive_window = 9; + uint64 max_conn_receive_window = 10; + int64 max_idle_timeout = 11; + int64 keep_alive_period = 12; + bool disable_path_mtu_discovery = 13; +} + diff --git a/transport/internet/hysteria/congestion/bbr/bandwidth.go b/transport/internet/hysteria/congestion/bbr/bandwidth.go new file mode 100644 index 000000000000..52deb249654a --- /dev/null +++ b/transport/internet/hysteria/congestion/bbr/bandwidth.go @@ -0,0 +1,27 @@ +package bbr + +import ( + "math" + "time" + + "github.com/apernet/quic-go/congestion" +) + +const ( + infBandwidth = Bandwidth(math.MaxUint64) +) + +// Bandwidth of a connection +type Bandwidth uint64 + +const ( + // BitsPerSecond is 1 bit per second + BitsPerSecond Bandwidth = 1 + // BytesPerSecond is 1 byte per second + BytesPerSecond = 8 * BitsPerSecond +) + +// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta +func BandwidthFromDelta(bytes congestion.ByteCount, delta time.Duration) Bandwidth { + return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond +} diff --git a/transport/internet/hysteria/congestion/bbr/bandwidth_sampler.go b/transport/internet/hysteria/congestion/bbr/bandwidth_sampler.go new file mode 100644 index 000000000000..a1e22d6fe00f --- /dev/null +++ b/transport/internet/hysteria/congestion/bbr/bandwidth_sampler.go @@ -0,0 +1,874 @@ +package bbr + +import ( + "math" + "time" + + "github.com/apernet/quic-go/congestion" +) + +const ( + infRTT = time.Duration(math.MaxInt64) + defaultConnectionStateMapQueueSize = 256 + defaultCandidatesBufferSize = 256 +) + +type roundTripCount uint64 + +// SendTimeState is a subset of ConnectionStateOnSentPacket which is returned +// to the caller when the packet is acked or lost. +type sendTimeState struct { + // Whether other states in this object is valid. + isValid bool + // Whether the sender is app limited at the time the packet was sent. + // App limited bandwidth sample might be artificially low because the sender + // did not have enough data to send in order to saturate the link. + isAppLimited bool + // Total number of sent bytes at the time the packet was sent. + // Includes the packet itself. + totalBytesSent congestion.ByteCount + // Total number of acked bytes at the time the packet was sent. + totalBytesAcked congestion.ByteCount + // Total number of lost bytes at the time the packet was sent. + totalBytesLost congestion.ByteCount + // Total number of inflight bytes at the time the packet was sent. + // Includes the packet itself. + // It should be equal to |total_bytes_sent| minus the sum of + // |total_bytes_acked|, |total_bytes_lost| and total neutered bytes. + bytesInFlight congestion.ByteCount +} + +func newSendTimeState( + isAppLimited bool, + totalBytesSent congestion.ByteCount, + totalBytesAcked congestion.ByteCount, + totalBytesLost congestion.ByteCount, + bytesInFlight congestion.ByteCount, +) *sendTimeState { + return &sendTimeState{ + isValid: true, + isAppLimited: isAppLimited, + totalBytesSent: totalBytesSent, + totalBytesAcked: totalBytesAcked, + totalBytesLost: totalBytesLost, + bytesInFlight: bytesInFlight, + } +} + +type extraAckedEvent struct { + // The excess bytes acknowlwedged in the time delta for this event. + extraAcked congestion.ByteCount + + // The bytes acknowledged and time delta from the event. + bytesAcked congestion.ByteCount + timeDelta time.Duration + // The round trip of the event. + round roundTripCount +} + +func maxExtraAckedEventFunc(a, b extraAckedEvent) int { + if a.extraAcked > b.extraAcked { + return 1 + } else if a.extraAcked < b.extraAcked { + return -1 + } + return 0 +} + +// BandwidthSample +type bandwidthSample struct { + // The bandwidth at that particular sample. Zero if no valid bandwidth sample + // is available. + bandwidth Bandwidth + // The RTT measurement at this particular sample. Zero if no RTT sample is + // available. Does not correct for delayed ack time. + rtt time.Duration + // |send_rate| is computed from the current packet being acked('P') and an + // earlier packet that is acked before P was sent. + sendRate Bandwidth + // States captured when the packet was sent. + stateAtSend sendTimeState +} + +func newBandwidthSample() *bandwidthSample { + return &bandwidthSample{ + sendRate: infBandwidth, + } +} + +// MaxAckHeightTracker is part of the BandwidthSampler. It is called after every +// ack event to keep track the degree of ack aggregation(a.k.a "ack height"). +type maxAckHeightTracker struct { + // Tracks the maximum number of bytes acked faster than the estimated + // bandwidth. + maxAckHeightFilter *WindowedFilter[extraAckedEvent, roundTripCount] + // The time this aggregation started and the number of bytes acked during it. + aggregationEpochStartTime congestion.Time + aggregationEpochBytes congestion.ByteCount + // The last sent packet number before the current aggregation epoch started. + lastSentPacketNumberBeforeEpoch congestion.PacketNumber + // The number of ack aggregation epochs ever started, including the ongoing + // one. Stats only. + numAckAggregationEpochs uint64 + ackAggregationBandwidthThreshold float64 + startNewAggregationEpochAfterFullRound bool + reduceExtraAckedOnBandwidthIncrease bool +} + +func newMaxAckHeightTracker(windowLength roundTripCount) *maxAckHeightTracker { + return &maxAckHeightTracker{ + maxAckHeightFilter: NewWindowedFilter(windowLength, maxExtraAckedEventFunc), + lastSentPacketNumberBeforeEpoch: invalidPacketNumber, + ackAggregationBandwidthThreshold: 1.0, + } +} + +func (m *maxAckHeightTracker) Get() congestion.ByteCount { + return m.maxAckHeightFilter.GetBest().extraAcked +} + +func (m *maxAckHeightTracker) Update( + bandwidthEstimate Bandwidth, + isNewMaxBandwidth bool, + roundTripCount roundTripCount, + lastSentPacketNumber congestion.PacketNumber, + lastAckedPacketNumber congestion.PacketNumber, + ackTime congestion.Time, + bytesAcked congestion.ByteCount, +) congestion.ByteCount { + forceNewEpoch := false + + if m.reduceExtraAckedOnBandwidthIncrease && isNewMaxBandwidth { + // Save and clear existing entries. + best := m.maxAckHeightFilter.GetBest() + secondBest := m.maxAckHeightFilter.GetSecondBest() + thirdBest := m.maxAckHeightFilter.GetThirdBest() + m.maxAckHeightFilter.Clear() + + // Reinsert the heights into the filter after recalculating. + expectedBytesAcked := bytesFromBandwidthAndTimeDelta(bandwidthEstimate, best.timeDelta) + if expectedBytesAcked < best.bytesAcked { + best.extraAcked = best.bytesAcked - expectedBytesAcked + m.maxAckHeightFilter.Update(best, best.round) + } + expectedBytesAcked = bytesFromBandwidthAndTimeDelta(bandwidthEstimate, secondBest.timeDelta) + if expectedBytesAcked < secondBest.bytesAcked { + secondBest.extraAcked = secondBest.bytesAcked - expectedBytesAcked + m.maxAckHeightFilter.Update(secondBest, secondBest.round) + } + expectedBytesAcked = bytesFromBandwidthAndTimeDelta(bandwidthEstimate, thirdBest.timeDelta) + if expectedBytesAcked < thirdBest.bytesAcked { + thirdBest.extraAcked = thirdBest.bytesAcked - expectedBytesAcked + m.maxAckHeightFilter.Update(thirdBest, thirdBest.round) + } + } + + // If any packet sent after the start of the epoch has been acked, start a new + // epoch. + if m.startNewAggregationEpochAfterFullRound && + m.lastSentPacketNumberBeforeEpoch != invalidPacketNumber && + lastAckedPacketNumber != invalidPacketNumber && + lastAckedPacketNumber > m.lastSentPacketNumberBeforeEpoch { + forceNewEpoch = true + } + if m.aggregationEpochStartTime.IsZero() || forceNewEpoch { + m.aggregationEpochBytes = bytesAcked + m.aggregationEpochStartTime = ackTime + m.lastSentPacketNumberBeforeEpoch = lastSentPacketNumber + m.numAckAggregationEpochs++ + return 0 + } + + // Compute how many bytes are expected to be delivered, assuming max bandwidth + // is correct. + aggregationDelta := ackTime.Sub(m.aggregationEpochStartTime) + expectedBytesAcked := bytesFromBandwidthAndTimeDelta(bandwidthEstimate, aggregationDelta) + // Reset the current aggregation epoch as soon as the ack arrival rate is less + // than or equal to the max bandwidth. + if m.aggregationEpochBytes <= congestion.ByteCount(m.ackAggregationBandwidthThreshold*float64(expectedBytesAcked)) { + // Reset to start measuring a new aggregation epoch. + m.aggregationEpochBytes = bytesAcked + m.aggregationEpochStartTime = ackTime + m.lastSentPacketNumberBeforeEpoch = lastSentPacketNumber + m.numAckAggregationEpochs++ + return 0 + } + + m.aggregationEpochBytes += bytesAcked + + // Compute how many extra bytes were delivered vs max bandwidth. + extraBytesAcked := m.aggregationEpochBytes - expectedBytesAcked + newEvent := extraAckedEvent{ + extraAcked: extraBytesAcked, + bytesAcked: m.aggregationEpochBytes, + timeDelta: aggregationDelta, + } + m.maxAckHeightFilter.Update(newEvent, roundTripCount) + return extraBytesAcked +} + +func (m *maxAckHeightTracker) SetFilterWindowLength(length roundTripCount) { + m.maxAckHeightFilter.SetWindowLength(length) +} + +func (m *maxAckHeightTracker) Reset(newHeight congestion.ByteCount, newTime roundTripCount) { + newEvent := extraAckedEvent{ + extraAcked: newHeight, + round: newTime, + } + m.maxAckHeightFilter.Reset(newEvent, newTime) +} + +func (m *maxAckHeightTracker) SetAckAggregationBandwidthThreshold(threshold float64) { + m.ackAggregationBandwidthThreshold = threshold +} + +func (m *maxAckHeightTracker) SetStartNewAggregationEpochAfterFullRound(value bool) { + m.startNewAggregationEpochAfterFullRound = value +} + +func (m *maxAckHeightTracker) SetReduceExtraAckedOnBandwidthIncrease(value bool) { + m.reduceExtraAckedOnBandwidthIncrease = value +} + +func (m *maxAckHeightTracker) AckAggregationBandwidthThreshold() float64 { + return m.ackAggregationBandwidthThreshold +} + +func (m *maxAckHeightTracker) NumAckAggregationEpochs() uint64 { + return m.numAckAggregationEpochs +} + +// AckPoint represents a point on the ack line. +type ackPoint struct { + ackTime congestion.Time + totalBytesAcked congestion.ByteCount +} + +// RecentAckPoints maintains the most recent 2 ack points at distinct times. +type recentAckPoints struct { + ackPoints [2]ackPoint +} + +func (r *recentAckPoints) Update(ackTime congestion.Time, totalBytesAcked congestion.ByteCount) { + if ackTime.Before(r.ackPoints[1].ackTime) { + r.ackPoints[1].ackTime = ackTime + } else if ackTime.After(r.ackPoints[1].ackTime) { + r.ackPoints[0] = r.ackPoints[1] + r.ackPoints[1].ackTime = ackTime + } + + r.ackPoints[1].totalBytesAcked = totalBytesAcked +} + +func (r *recentAckPoints) Clear() { + r.ackPoints[0] = ackPoint{} + r.ackPoints[1] = ackPoint{} +} + +func (r *recentAckPoints) MostRecentPoint() *ackPoint { + return &r.ackPoints[1] +} + +func (r *recentAckPoints) LessRecentPoint() *ackPoint { + if r.ackPoints[0].totalBytesAcked != 0 { + return &r.ackPoints[0] + } + + return &r.ackPoints[1] +} + +// ConnectionStateOnSentPacket represents the information about a sent packet +// and the state of the connection at the moment the packet was sent, +// specifically the information about the most recently acknowledged packet at +// that moment. +type connectionStateOnSentPacket struct { + // Time at which the packet is sent. + sentTime congestion.Time + // Size of the packet. + size congestion.ByteCount + // The value of |totalBytesSentAtLastAckedPacket| at the time the + // packet was sent. + totalBytesSentAtLastAckedPacket congestion.ByteCount + // The value of |lastAckedPacketSentTime| at the time the packet was + // sent. + lastAckedPacketSentTime congestion.Time + // The value of |lastAckedPacketAckTime| at the time the packet was + // sent. + lastAckedPacketAckTime congestion.Time + // Send time states that are returned to the congestion controller when the + // packet is acked or lost. + sendTimeState sendTimeState +} + +// Snapshot constructor. Records the current state of the bandwidth +// sampler. +// |bytes_in_flight| is the bytes in flight right after the packet is sent. +func newConnectionStateOnSentPacket( + sentTime congestion.Time, + size congestion.ByteCount, + bytesInFlight congestion.ByteCount, + sampler *bandwidthSampler, +) *connectionStateOnSentPacket { + return &connectionStateOnSentPacket{ + sentTime: sentTime, + size: size, + totalBytesSentAtLastAckedPacket: sampler.totalBytesSentAtLastAckedPacket, + lastAckedPacketSentTime: sampler.lastAckedPacketSentTime, + lastAckedPacketAckTime: sampler.lastAckedPacketAckTime, + sendTimeState: *newSendTimeState( + sampler.isAppLimited, + sampler.totalBytesSent, + sampler.totalBytesAcked, + sampler.totalBytesLost, + bytesInFlight, + ), + } +} + +// BandwidthSampler keeps track of sent and acknowledged packets and outputs a +// bandwidth sample for every packet acknowledged. The samples are taken for +// individual packets, and are not filtered; the consumer has to filter the +// bandwidth samples itself. In certain cases, the sampler will locally severely +// underestimate the bandwidth, hence a maximum filter with a size of at least +// one RTT is recommended. +// +// This class bases its samples on the slope of two curves: the number of bytes +// sent over time, and the number of bytes acknowledged as received over time. +// It produces a sample of both slopes for every packet that gets acknowledged, +// based on a slope between two points on each of the corresponding curves. Note +// that due to the packet loss, the number of bytes on each curve might get +// further and further away from each other, meaning that it is not feasible to +// compare byte values coming from different curves with each other. +// +// The obvious points for measuring slope sample are the ones corresponding to +// the packet that was just acknowledged. Let us denote them as S_1 (point at +// which the current packet was sent) and A_1 (point at which the current packet +// was acknowledged). However, taking a slope requires two points on each line, +// so estimating bandwidth requires picking a packet in the past with respect to +// which the slope is measured. +// +// For that purpose, BandwidthSampler always keeps track of the most recently +// acknowledged packet, and records it together with every outgoing packet. +// When a packet gets acknowledged (A_1), it has not only information about when +// it itself was sent (S_1), but also the information about the latest +// acknowledged packet right before it was sent (S_0 and A_0). +// +// Based on that data, send and ack rate are estimated as: +// +// send_rate = (bytes(S_1) - bytes(S_0)) / (time(S_1) - time(S_0)) +// ack_rate = (bytes(A_1) - bytes(A_0)) / (time(A_1) - time(A_0)) +// +// Here, the ack rate is intuitively the rate we want to treat as bandwidth. +// However, in certain cases (e.g. ack compression) the ack rate at a point may +// end up higher than the rate at which the data was originally sent, which is +// not indicative of the real bandwidth. Hence, we use the send rate as an upper +// bound, and the sample value is +// +// rate_sample = min(send_rate, ack_rate) +// +// An important edge case handled by the sampler is tracking the app-limited +// samples. There are multiple meaning of "app-limited" used interchangeably, +// hence it is important to understand and to be able to distinguish between +// them. +// +// Meaning 1: connection state. The connection is said to be app-limited when +// there is no outstanding data to send. This means that certain bandwidth +// samples in the future would not be an accurate indication of the link +// capacity, and it is important to inform consumer about that. Whenever +// connection becomes app-limited, the sampler is notified via OnAppLimited() +// method. +// +// Meaning 2: a phase in the bandwidth sampler. As soon as the bandwidth +// sampler becomes notified about the connection being app-limited, it enters +// app-limited phase. In that phase, all *sent* packets are marked as +// app-limited. Note that the connection itself does not have to be +// app-limited during the app-limited phase, and in fact it will not be +// (otherwise how would it send packets?). The boolean flag below indicates +// whether the sampler is in that phase. +// +// Meaning 3: a flag on the sent packet and on the sample. If a sent packet is +// sent during the app-limited phase, the resulting sample related to the +// packet will be marked as app-limited. +// +// With the terminology issue out of the way, let us consider the question of +// what kind of situation it addresses. +// +// Consider a scenario where we first send packets 1 to 20 at a regular +// bandwidth, and then immediately run out of data. After a few seconds, we send +// packets 21 to 60, and only receive ack for 21 between sending packets 40 and +// 41. In this case, when we sample bandwidth for packets 21 to 40, the S_0/A_0 +// we use to compute the slope is going to be packet 20, a few seconds apart +// from the current packet, hence the resulting estimate would be extremely low +// and not indicative of anything. Only at packet 41 the S_0/A_0 will become 21, +// meaning that the bandwidth sample would exclude the quiescence. +// +// Based on the analysis of that scenario, we implement the following rule: once +// OnAppLimited() is called, all sent packets will produce app-limited samples +// up until an ack for a packet that was sent after OnAppLimited() was called. +// Note that while the scenario above is not the only scenario when the +// connection is app-limited, the approach works in other cases too. + +type congestionEventSample struct { + // The maximum bandwidth sample from all acked packets. + // QuicBandwidth::Zero() if no samples are available. + sampleMaxBandwidth Bandwidth + // Whether |sample_max_bandwidth| is from a app-limited sample. + sampleIsAppLimited bool + // The minimum rtt sample from all acked packets. + // QuicTime::Delta::Infinite() if no samples are available. + sampleRtt time.Duration + // For each packet p in acked packets, this is the max value of INFLIGHT(p), + // where INFLIGHT(p) is the number of bytes acked while p is inflight. + sampleMaxInflight congestion.ByteCount + // The send state of the largest packet in acked_packets, unless it is + // empty. If acked_packets is empty, it's the send state of the largest + // packet in lost_packets. + lastPacketSendState sendTimeState + // The number of extra bytes acked from this ack event, compared to what is + // expected from the flow's bandwidth. Larger value means more ack + // aggregation. + extraAcked congestion.ByteCount +} + +func newCongestionEventSample() *congestionEventSample { + return &congestionEventSample{ + sampleRtt: infRTT, + } +} + +type bandwidthSampler struct { + // The total number of congestion controlled bytes sent during the connection. + totalBytesSent congestion.ByteCount + + // The total number of congestion controlled bytes which were acknowledged. + totalBytesAcked congestion.ByteCount + + // The total number of congestion controlled bytes which were lost. + totalBytesLost congestion.ByteCount + + // The total number of congestion controlled bytes which have been neutered. + totalBytesNeutered congestion.ByteCount + + // The value of |total_bytes_sent_| at the time the last acknowledged packet + // was sent. Valid only when |last_acked_packet_sent_time_| is valid. + totalBytesSentAtLastAckedPacket congestion.ByteCount + + // The time at which the last acknowledged packet was sent. Set to + // QuicTime::Zero() if no valid timestamp is available. + lastAckedPacketSentTime congestion.Time + + // The time at which the most recent packet was acknowledged. + lastAckedPacketAckTime congestion.Time + + // The most recently sent packet. + lastSentPacket congestion.PacketNumber + + // The most recently acked packet. + lastAckedPacket congestion.PacketNumber + + // Indicates whether the bandwidth sampler is currently in an app-limited + // phase. + isAppLimited bool + + // The packet that will be acknowledged after this one will cause the sampler + // to exit the app-limited phase. + endOfAppLimitedPhase congestion.PacketNumber + + // Record of the connection state at the point where each packet in flight was + // sent, indexed by the packet number. + connectionStateMap *packetNumberIndexedQueue[connectionStateOnSentPacket] + + recentAckPoints recentAckPoints + a0Candidates RingBuffer[ackPoint] + + // Maximum number of tracked packets. + maxTrackedPackets congestion.ByteCount + + maxAckHeightTracker *maxAckHeightTracker + totalBytesAckedAfterLastAckEvent congestion.ByteCount + + // True if connection option 'BSAO' is set. + overestimateAvoidance bool + + // True if connection option 'BBRB' is set. + limitMaxAckHeightTrackerBySendRate bool +} + +func newBandwidthSampler(maxAckHeightTrackerWindowLength roundTripCount) *bandwidthSampler { + b := &bandwidthSampler{ + maxAckHeightTracker: newMaxAckHeightTracker(maxAckHeightTrackerWindowLength), + connectionStateMap: newPacketNumberIndexedQueue[connectionStateOnSentPacket](defaultConnectionStateMapQueueSize), + lastSentPacket: invalidPacketNumber, + lastAckedPacket: invalidPacketNumber, + endOfAppLimitedPhase: invalidPacketNumber, + } + + b.a0Candidates.Init(defaultCandidatesBufferSize) + + return b +} + +func (b *bandwidthSampler) MaxAckHeight() congestion.ByteCount { + return b.maxAckHeightTracker.Get() +} + +func (b *bandwidthSampler) NumAckAggregationEpochs() uint64 { + return b.maxAckHeightTracker.NumAckAggregationEpochs() +} + +func (b *bandwidthSampler) SetMaxAckHeightTrackerWindowLength(length roundTripCount) { + b.maxAckHeightTracker.SetFilterWindowLength(length) +} + +func (b *bandwidthSampler) ResetMaxAckHeightTracker(newHeight congestion.ByteCount, newTime roundTripCount) { + b.maxAckHeightTracker.Reset(newHeight, newTime) +} + +func (b *bandwidthSampler) SetStartNewAggregationEpochAfterFullRound(value bool) { + b.maxAckHeightTracker.SetStartNewAggregationEpochAfterFullRound(value) +} + +func (b *bandwidthSampler) SetLimitMaxAckHeightTrackerBySendRate(value bool) { + b.limitMaxAckHeightTrackerBySendRate = value +} + +func (b *bandwidthSampler) SetReduceExtraAckedOnBandwidthIncrease(value bool) { + b.maxAckHeightTracker.SetReduceExtraAckedOnBandwidthIncrease(value) +} + +func (b *bandwidthSampler) EnableOverestimateAvoidance() { + if b.overestimateAvoidance { + return + } + + b.overestimateAvoidance = true + b.maxAckHeightTracker.SetAckAggregationBandwidthThreshold(2.0) +} + +func (b *bandwidthSampler) IsOverestimateAvoidanceEnabled() bool { + return b.overestimateAvoidance +} + +func (b *bandwidthSampler) OnPacketSent( + sentTime congestion.Time, + packetNumber congestion.PacketNumber, + bytes congestion.ByteCount, + bytesInFlight congestion.ByteCount, + isRetransmittable bool, +) { + b.lastSentPacket = packetNumber + + if !isRetransmittable { + return + } + + b.totalBytesSent += bytes + + // If there are no packets in flight, the time at which the new transmission + // opens can be treated as the A_0 point for the purpose of bandwidth + // sampling. This underestimates bandwidth to some extent, and produces some + // artificially low samples for most packets in flight, but it provides with + // samples at important points where we would not have them otherwise, most + // importantly at the beginning of the connection. + if bytesInFlight == 0 { + b.lastAckedPacketAckTime = sentTime + if b.overestimateAvoidance { + b.recentAckPoints.Clear() + b.recentAckPoints.Update(sentTime, b.totalBytesAcked) + b.a0Candidates.Clear() + b.a0Candidates.PushBack(*b.recentAckPoints.MostRecentPoint()) + } + b.totalBytesSentAtLastAckedPacket = b.totalBytesSent + + // In this situation ack compression is not a concern, set send rate to + // effectively infinite. + b.lastAckedPacketSentTime = sentTime + } + + b.connectionStateMap.Emplace(packetNumber, newConnectionStateOnSentPacket( + sentTime, + bytes, + bytesInFlight+bytes, + b, + )) +} + +func (b *bandwidthSampler) OnCongestionEvent( + ackTime congestion.Time, + ackedPackets []congestion.AckedPacketInfo, + lostPackets []congestion.LostPacketInfo, + maxBandwidth Bandwidth, + estBandwidthUpperBound Bandwidth, + roundTripCount roundTripCount, +) congestionEventSample { + eventSample := newCongestionEventSample() + + var lastLostPacketSendState sendTimeState + + for _, p := range lostPackets { + sendState := b.OnPacketLost(p.PacketNumber, p.BytesLost) + if sendState.isValid { + lastLostPacketSendState = sendState + } + } + + if len(ackedPackets) == 0 { + // Only populate send state for a loss-only event. + eventSample.lastPacketSendState = lastLostPacketSendState + return *eventSample + } + + var lastAckedPacketSendState sendTimeState + var maxSendRate Bandwidth + + for _, p := range ackedPackets { + sample := b.onPacketAcknowledged(ackTime, p.PacketNumber) + if !sample.stateAtSend.isValid { + continue + } + + lastAckedPacketSendState = sample.stateAtSend + + if sample.rtt != 0 { + eventSample.sampleRtt = min(eventSample.sampleRtt, sample.rtt) + } + if sample.bandwidth > eventSample.sampleMaxBandwidth { + eventSample.sampleMaxBandwidth = sample.bandwidth + eventSample.sampleIsAppLimited = sample.stateAtSend.isAppLimited + } + if sample.sendRate != infBandwidth { + maxSendRate = max(maxSendRate, sample.sendRate) + } + inflightSample := b.totalBytesAcked - lastAckedPacketSendState.totalBytesAcked + if inflightSample > eventSample.sampleMaxInflight { + eventSample.sampleMaxInflight = inflightSample + } + } + + if !lastLostPacketSendState.isValid { + eventSample.lastPacketSendState = lastAckedPacketSendState + } else if !lastAckedPacketSendState.isValid { + eventSample.lastPacketSendState = lastLostPacketSendState + } else { + // If two packets are inflight and an alarm is armed to lose a packet and it + // wakes up late, then the first of two in flight packets could have been + // acknowledged before the wakeup, which re-evaluates loss detection, and + // could declare the later of the two lost. + if lostPackets[len(lostPackets)-1].PacketNumber > ackedPackets[len(ackedPackets)-1].PacketNumber { + eventSample.lastPacketSendState = lastLostPacketSendState + } else { + eventSample.lastPacketSendState = lastAckedPacketSendState + } + } + + isNewMaxBandwidth := eventSample.sampleMaxBandwidth > maxBandwidth + maxBandwidth = max(maxBandwidth, eventSample.sampleMaxBandwidth) + if b.limitMaxAckHeightTrackerBySendRate { + maxBandwidth = max(maxBandwidth, maxSendRate) + } + + eventSample.extraAcked = b.onAckEventEnd(min(estBandwidthUpperBound, maxBandwidth), isNewMaxBandwidth, roundTripCount) + + return *eventSample +} + +func (b *bandwidthSampler) OnPacketLost(packetNumber congestion.PacketNumber, bytesLost congestion.ByteCount) (s sendTimeState) { + b.totalBytesLost += bytesLost + if sentPacketPointer := b.connectionStateMap.GetEntry(packetNumber); sentPacketPointer != nil { + sentPacketToSendTimeState(sentPacketPointer, &s) + } + return s +} + +func (b *bandwidthSampler) OnPacketNeutered(packetNumber congestion.PacketNumber) { + b.connectionStateMap.Remove(packetNumber, func(sentPacket connectionStateOnSentPacket) { + b.totalBytesNeutered += sentPacket.size + }) +} + +func (b *bandwidthSampler) OnAppLimited() { + b.isAppLimited = true + b.endOfAppLimitedPhase = b.lastSentPacket +} + +func (b *bandwidthSampler) RemoveObsoletePackets(leastUnacked congestion.PacketNumber) { + // A packet can become obsolete when it is removed from QuicUnackedPacketMap's + // view of inflight before it is acked or marked as lost. For example, when + // QuicSentPacketManager::RetransmitCryptoPackets retransmits a crypto packet, + // the packet is removed from QuicUnackedPacketMap's inflight, but is not + // marked as acked or lost in the BandwidthSampler. + b.connectionStateMap.RemoveUpTo(leastUnacked) +} + +func (b *bandwidthSampler) TotalBytesSent() congestion.ByteCount { + return b.totalBytesSent +} + +func (b *bandwidthSampler) TotalBytesLost() congestion.ByteCount { + return b.totalBytesLost +} + +func (b *bandwidthSampler) TotalBytesAcked() congestion.ByteCount { + return b.totalBytesAcked +} + +func (b *bandwidthSampler) TotalBytesNeutered() congestion.ByteCount { + return b.totalBytesNeutered +} + +func (b *bandwidthSampler) IsAppLimited() bool { + return b.isAppLimited +} + +func (b *bandwidthSampler) EndOfAppLimitedPhase() congestion.PacketNumber { + return b.endOfAppLimitedPhase +} + +func (b *bandwidthSampler) max_ack_height() congestion.ByteCount { + return b.maxAckHeightTracker.Get() +} + +func (b *bandwidthSampler) chooseA0Point(totalBytesAcked congestion.ByteCount, a0 *ackPoint) bool { + if b.a0Candidates.Empty() { + return false + } + + if b.a0Candidates.Len() == 1 { + *a0 = *b.a0Candidates.Front() + return true + } + + for i := 1; i < b.a0Candidates.Len(); i++ { + if b.a0Candidates.Offset(i).totalBytesAcked > totalBytesAcked { + *a0 = *b.a0Candidates.Offset(i - 1) + if i > 1 { + for j := 0; j < i-1; j++ { + b.a0Candidates.PopFront() + } + } + return true + } + } + + *a0 = *b.a0Candidates.Back() + for k := 0; k < b.a0Candidates.Len()-1; k++ { + b.a0Candidates.PopFront() + } + return true +} + +func (b *bandwidthSampler) onPacketAcknowledged(ackTime congestion.Time, packetNumber congestion.PacketNumber) bandwidthSample { + sample := newBandwidthSample() + b.lastAckedPacket = packetNumber + sentPacketPointer := b.connectionStateMap.GetEntry(packetNumber) + if sentPacketPointer == nil { + return *sample + } + + // OnPacketAcknowledgedInner + b.totalBytesAcked += sentPacketPointer.size + b.totalBytesSentAtLastAckedPacket = sentPacketPointer.sendTimeState.totalBytesSent + b.lastAckedPacketSentTime = sentPacketPointer.sentTime + b.lastAckedPacketAckTime = ackTime + if b.overestimateAvoidance { + b.recentAckPoints.Update(ackTime, b.totalBytesAcked) + } + + if b.isAppLimited { + // Exit app-limited phase in two cases: + // (1) end_of_app_limited_phase_ is not initialized, i.e., so far all + // packets are sent while there are buffered packets or pending data. + // (2) The current acked packet is after the sent packet marked as the end + // of the app limit phase. + if b.endOfAppLimitedPhase == invalidPacketNumber || + packetNumber > b.endOfAppLimitedPhase { + b.isAppLimited = false + } + } + + // There might have been no packets acknowledged at the moment when the + // current packet was sent. In that case, there is no bandwidth sample to + // make. + if sentPacketPointer.lastAckedPacketSentTime.IsZero() { + return *sample + } + + // Infinite rate indicates that the sampler is supposed to discard the + // current send rate sample and use only the ack rate. + sendRate := infBandwidth + if sentPacketPointer.sentTime.After(sentPacketPointer.lastAckedPacketSentTime) { + sendRate = BandwidthFromDelta( + sentPacketPointer.sendTimeState.totalBytesSent-sentPacketPointer.totalBytesSentAtLastAckedPacket, + sentPacketPointer.sentTime.Sub(sentPacketPointer.lastAckedPacketSentTime)) + } + + var a0 ackPoint + if b.overestimateAvoidance && b.chooseA0Point(sentPacketPointer.sendTimeState.totalBytesAcked, &a0) { + } else { + a0.ackTime = sentPacketPointer.lastAckedPacketAckTime + a0.totalBytesAcked = sentPacketPointer.sendTimeState.totalBytesAcked + } + + // During the slope calculation, ensure that ack time of the current packet is + // always larger than the time of the previous packet, otherwise division by + // zero or integer underflow can occur. + if ackTime.Sub(a0.ackTime) <= 0 { + return *sample + } + + ackRate := BandwidthFromDelta(b.totalBytesAcked-a0.totalBytesAcked, ackTime.Sub(a0.ackTime)) + + sample.bandwidth = min(sendRate, ackRate) + // Note: this sample does not account for delayed acknowledgement time. This + // means that the RTT measurements here can be artificially high, especially + // on low bandwidth connections. + sample.rtt = ackTime.Sub(sentPacketPointer.sentTime) + sample.sendRate = sendRate + sentPacketToSendTimeState(sentPacketPointer, &sample.stateAtSend) + + return *sample +} + +func (b *bandwidthSampler) onAckEventEnd( + bandwidthEstimate Bandwidth, + isNewMaxBandwidth bool, + roundTripCount roundTripCount, +) congestion.ByteCount { + newlyAckedBytes := b.totalBytesAcked - b.totalBytesAckedAfterLastAckEvent + if newlyAckedBytes == 0 { + return 0 + } + b.totalBytesAckedAfterLastAckEvent = b.totalBytesAcked + extraAcked := b.maxAckHeightTracker.Update( + bandwidthEstimate, + isNewMaxBandwidth, + roundTripCount, + b.lastSentPacket, + b.lastAckedPacket, + b.lastAckedPacketAckTime, + newlyAckedBytes) + // If |extra_acked| is zero, i.e. this ack event marks the start of a new ack + // aggregation epoch, save LessRecentPoint, which is the last ack point of the + // previous epoch, as a A0 candidate. + if b.overestimateAvoidance && extraAcked == 0 { + b.a0Candidates.PushBack(*b.recentAckPoints.LessRecentPoint()) + } + return extraAcked +} + +func sentPacketToSendTimeState(sentPacket *connectionStateOnSentPacket, sendTimeState *sendTimeState) { + *sendTimeState = sentPacket.sendTimeState + sendTimeState.isValid = true +} + +// BytesFromBandwidthAndTimeDelta calculates the bytes +// from a bandwidth(bits per second) and a time delta +func bytesFromBandwidthAndTimeDelta(bandwidth Bandwidth, delta time.Duration) congestion.ByteCount { + return (congestion.ByteCount(bandwidth) * congestion.ByteCount(delta)) / + (congestion.ByteCount(time.Second) * 8) +} + +func timeDeltaFromBytesAndBandwidth(bytes congestion.ByteCount, bandwidth Bandwidth) time.Duration { + return time.Duration(bytes*8) * time.Second / time.Duration(bandwidth) +} diff --git a/transport/internet/hysteria/congestion/bbr/bbr_sender.go b/transport/internet/hysteria/congestion/bbr/bbr_sender.go new file mode 100644 index 000000000000..93c5bf84db95 --- /dev/null +++ b/transport/internet/hysteria/congestion/bbr/bbr_sender.go @@ -0,0 +1,980 @@ +package bbr + +import ( + "fmt" + "math/rand" + "net" + "os" + "strconv" + "time" + + "github.com/apernet/quic-go/congestion" + + "github.com/xtls/xray-core/transport/internet/hysteria/congestion/common" +) + +// BbrSender implements BBR congestion control algorithm. BBR aims to estimate +// the current available Bottleneck Bandwidth and RTT (hence the name), and +// regulates the pacing rate and the size of the congestion window based on +// those signals. +// +// BBR relies on pacing in order to function properly. Do not use BBR when +// pacing is disabled. +// + +const ( + minBps = 65536 // 64 KB/s + + invalidPacketNumber = -1 + initialCongestionWindowPackets = 32 + + // Constants based on TCP defaults. + // The minimum CWND to ensure delayed acks don't reduce bandwidth measurements. + // Does not inflate the pacing rate. + defaultMinimumCongestionWindow = 4 * congestion.ByteCount(congestion.InitialPacketSize) + + // The gain used for the STARTUP, equal to 2/ln(2). + defaultHighGain = 2.885 + // The newly derived gain for STARTUP, equal to 4 * ln(2) + derivedHighGain = 2.773 + // The newly derived CWND gain for STARTUP, 2. + derivedHighCWNDGain = 2.0 + + debugEnv = "HYSTERIA_BBR_DEBUG" +) + +// The cycle of gains used during the PROBE_BW stage. +var pacingGain = [...]float64{1.25, 0.75, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0} + +const ( + // The length of the gain cycle. + gainCycleLength = len(pacingGain) + // The size of the bandwidth filter window, in round-trips. + bandwidthWindowSize = gainCycleLength + 2 + + // The time after which the current min_rtt value expires. + minRttExpiry = 10 * time.Second + // The minimum time the connection can spend in PROBE_RTT mode. + probeRttTime = 200 * time.Millisecond + // If the bandwidth does not increase by the factor of |kStartupGrowthTarget| + // within |kRoundTripsWithoutGrowthBeforeExitingStartup| rounds, the connection + // will exit the STARTUP mode. + startupGrowthTarget = 1.25 + roundTripsWithoutGrowthBeforeExitingStartup = int64(3) + + // Flag. + defaultStartupFullLossCount = 8 + quicBbr2DefaultLossThreshold = 0.02 + maxBbrBurstPackets = 10 +) + +type bbrMode int + +const ( + // Startup phase of the connection. + bbrModeStartup = iota + // After achieving the highest possible bandwidth during the startup, lower + // the pacing rate in order to drain the queue. + bbrModeDrain + // Cruising mode. + bbrModeProbeBw + // Temporarily slow down sending in order to empty the buffer and measure + // the real minimum RTT. + bbrModeProbeRtt +) + +// Indicates how the congestion control limits the amount of bytes in flight. +type bbrRecoveryState int + +const ( + // Do not limit. + bbrRecoveryStateNotInRecovery = iota + // Allow an extra outstanding byte for each byte acknowledged. + bbrRecoveryStateConservation + // Allow two extra outstanding bytes for each byte acknowledged (slow + // start). + bbrRecoveryStateGrowth +) + +type bbrSender struct { + rttStats congestion.RTTStatsProvider + clock Clock + pacer *common.Pacer + + mode bbrMode + + // Bandwidth sampler provides BBR with the bandwidth measurements at + // individual points. + sampler *bandwidthSampler + + // The number of the round trips that have occurred during the connection. + roundTripCount roundTripCount + + // The packet number of the most recently sent packet. + lastSentPacket congestion.PacketNumber + // Acknowledgement of any packet after |current_round_trip_end_| will cause + // the round trip counter to advance. + currentRoundTripEnd congestion.PacketNumber + + // Number of congestion events with some losses, in the current round. + numLossEventsInRound uint64 + + // Number of total bytes lost in the current round. + bytesLostInRound congestion.ByteCount + + // The filter that tracks the maximum bandwidth over the multiple recent + // round-trips. + maxBandwidth *WindowedFilter[Bandwidth, roundTripCount] + + // Minimum RTT estimate. Automatically expires within 10 seconds (and + // triggers PROBE_RTT mode) if no new value is sampled during that period. + minRtt time.Duration + // The time at which the current value of |min_rtt_| was assigned. + minRttTimestamp congestion.Time + + // The maximum allowed number of bytes in flight. + congestionWindow congestion.ByteCount + + // The initial value of the |congestion_window_|. + initialCongestionWindow congestion.ByteCount + + // The largest value the |congestion_window_| can achieve. + maxCongestionWindow congestion.ByteCount + + // The smallest value the |congestion_window_| can achieve. + minCongestionWindow congestion.ByteCount + + // The pacing gain applied during the STARTUP phase. + highGain float64 + + // The CWND gain applied during the STARTUP phase. + highCwndGain float64 + + // The pacing gain applied during the DRAIN phase. + drainGain float64 + + // The current pacing rate of the connection. + pacingRate Bandwidth + + // The gain currently applied to the pacing rate. + pacingGain float64 + // The gain currently applied to the congestion window. + congestionWindowGain float64 + + // The gain used for the congestion window during PROBE_BW. Latched from + // quic_bbr_cwnd_gain flag. + congestionWindowGainConstant float64 + // The number of RTTs to stay in STARTUP mode. Defaults to 3. + numStartupRtts int64 + + // Number of round-trips in PROBE_BW mode, used for determining the current + // pacing gain cycle. + cycleCurrentOffset int + // The time at which the last pacing gain cycle was started. + lastCycleStart congestion.Time + + // Indicates whether the connection has reached the full bandwidth mode. + isAtFullBandwidth bool + // Number of rounds during which there was no significant bandwidth increase. + roundsWithoutBandwidthGain int64 + // The bandwidth compared to which the increase is measured. + bandwidthAtLastRound Bandwidth + + // Set to true upon exiting quiescence. + exitingQuiescence bool + + // Time at which PROBE_RTT has to be exited. Setting it to zero indicates + // that the time is yet unknown as the number of packets in flight has not + // reached the required value. + exitProbeRttAt congestion.Time + // Indicates whether a round-trip has passed since PROBE_RTT became active. + probeRttRoundPassed bool + + // Indicates whether the most recent bandwidth sample was marked as + // app-limited. + lastSampleIsAppLimited bool + // Indicates whether any non app-limited samples have been recorded. + hasNoAppLimitedSample bool + + // Current state of recovery. + recoveryState bbrRecoveryState + // Receiving acknowledgement of a packet after |end_recovery_at_| will cause + // BBR to exit the recovery mode. A value above zero indicates at least one + // loss has been detected, so it must not be set back to zero. + endRecoveryAt congestion.PacketNumber + // A window used to limit the number of bytes in flight during loss recovery. + recoveryWindow congestion.ByteCount + // If true, consider all samples in recovery app-limited. + isAppLimitedRecovery bool // not used + + // When true, pace at 1.5x and disable packet conservation in STARTUP. + slowerStartup bool // not used + // When true, disables packet conservation in STARTUP. + rateBasedStartup bool // not used + + // When true, add the most recent ack aggregation measurement during STARTUP. + enableAckAggregationDuringStartup bool + // When true, expire the windowed ack aggregation values in STARTUP when + // bandwidth increases more than 25%. + expireAckAggregationInStartup bool + + // If true, will not exit low gain mode until bytes_in_flight drops below BDP + // or it's time for high gain mode. + drainToTarget bool + + // If true, slow down pacing rate in STARTUP when overshooting is detected. + detectOvershooting bool + // Bytes lost while detect_overshooting_ is true. + bytesLostWhileDetectingOvershooting congestion.ByteCount + // Slow down pacing rate if + // bytes_lost_while_detecting_overshooting_ * + // bytes_lost_multiplier_while_detecting_overshooting_ > IW. + bytesLostMultiplierWhileDetectingOvershooting uint8 + // When overshooting is detected, do not drop pacing_rate_ below this value / + // min_rtt. + cwndToCalculateMinPacingRate congestion.ByteCount + + // Max congestion window when adjusting network parameters. + maxCongestionWindowWithNetworkParametersAdjusted congestion.ByteCount // not used + + // Params. + maxDatagramSize congestion.ByteCount + // Recorded on packet sent. equivalent |unacked_packets_->bytes_in_flight()| + bytesInFlight congestion.ByteCount + + debug bool +} + +var _ congestion.CongestionControl = &bbrSender{} + +func NewBbrSender( + clock Clock, + initialMaxDatagramSize congestion.ByteCount, +) *bbrSender { + return newBbrSender( + clock, + initialMaxDatagramSize, + initialCongestionWindowPackets*initialMaxDatagramSize, + congestion.MaxCongestionWindowPackets*initialMaxDatagramSize, + ) +} + +func newBbrSender( + clock Clock, + initialMaxDatagramSize, + initialCongestionWindow, + initialMaxCongestionWindow congestion.ByteCount, +) *bbrSender { + debug, _ := strconv.ParseBool(os.Getenv(debugEnv)) + b := &bbrSender{ + clock: clock, + mode: bbrModeStartup, + sampler: newBandwidthSampler(roundTripCount(bandwidthWindowSize)), + lastSentPacket: invalidPacketNumber, + currentRoundTripEnd: invalidPacketNumber, + maxBandwidth: NewWindowedFilter(roundTripCount(bandwidthWindowSize), MaxFilter[Bandwidth]), + congestionWindow: initialCongestionWindow, + initialCongestionWindow: initialCongestionWindow, + maxCongestionWindow: initialMaxCongestionWindow, + minCongestionWindow: defaultMinimumCongestionWindow, + highGain: defaultHighGain, + highCwndGain: defaultHighGain, + drainGain: 1.0 / defaultHighGain, + pacingGain: 1.0, + congestionWindowGain: 1.0, + congestionWindowGainConstant: 2.0, + numStartupRtts: roundTripsWithoutGrowthBeforeExitingStartup, + recoveryState: bbrRecoveryStateNotInRecovery, + endRecoveryAt: invalidPacketNumber, + recoveryWindow: initialMaxCongestionWindow, + bytesLostMultiplierWhileDetectingOvershooting: 2, + cwndToCalculateMinPacingRate: initialCongestionWindow, + maxCongestionWindowWithNetworkParametersAdjusted: initialMaxCongestionWindow, + maxDatagramSize: initialMaxDatagramSize, + debug: debug, + } + b.pacer = common.NewPacer(b.bandwidthForPacer) + + /* + if b.tracer != nil { + b.lastState = logging.CongestionStateStartup + b.tracer.UpdatedCongestionState(logging.CongestionStateStartup) + } + */ + + b.enterStartupMode(b.clock.Now()) + b.setHighCwndGain(derivedHighCWNDGain) + + return b +} + +func (b *bbrSender) SetRTTStatsProvider(provider congestion.RTTStatsProvider) { + b.rttStats = provider +} + +// TimeUntilSend implements the SendAlgorithm interface. +func (b *bbrSender) TimeUntilSend(bytesInFlight congestion.ByteCount) congestion.Time { + return b.pacer.TimeUntilSend() +} + +// HasPacingBudget implements the SendAlgorithm interface. +func (b *bbrSender) HasPacingBudget(now congestion.Time) bool { + return b.pacer.Budget(now) >= b.maxDatagramSize +} + +// OnPacketSent implements the SendAlgorithm interface. +func (b *bbrSender) OnPacketSent( + sentTime congestion.Time, + bytesInFlight congestion.ByteCount, + packetNumber congestion.PacketNumber, + bytes congestion.ByteCount, + isRetransmittable bool, +) { + b.pacer.SentPacket(sentTime, bytes) + + b.lastSentPacket = packetNumber + b.bytesInFlight = bytesInFlight + + if bytesInFlight == 0 { + b.exitingQuiescence = true + } + + b.sampler.OnPacketSent(sentTime, packetNumber, bytes, bytesInFlight, isRetransmittable) +} + +// CanSend implements the SendAlgorithm interface. +func (b *bbrSender) CanSend(bytesInFlight congestion.ByteCount) bool { + return bytesInFlight < b.GetCongestionWindow() +} + +// MaybeExitSlowStart implements the SendAlgorithm interface. +func (b *bbrSender) MaybeExitSlowStart() { + // Do nothing +} + +// OnPacketAcked implements the SendAlgorithm interface. +func (b *bbrSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes, priorInFlight congestion.ByteCount, eventTime congestion.Time) { + // Do nothing. +} + +// OnPacketLost implements the SendAlgorithm interface. +func (b *bbrSender) OnPacketLost(number congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) { + // Do nothing. +} + +// OnRetransmissionTimeout implements the SendAlgorithm interface. +func (b *bbrSender) OnRetransmissionTimeout(packetsRetransmitted bool) { + // Do nothing. +} + +// SetMaxDatagramSize implements the SendAlgorithm interface. +func (b *bbrSender) SetMaxDatagramSize(s congestion.ByteCount) { + if s < b.maxDatagramSize { + panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", b.maxDatagramSize, s)) + } + cwndIsMinCwnd := b.congestionWindow == b.minCongestionWindow + b.maxDatagramSize = s + if cwndIsMinCwnd { + b.congestionWindow = b.minCongestionWindow + } + b.pacer.SetMaxDatagramSize(s) +} + +// InSlowStart implements the SendAlgorithmWithDebugInfos interface. +func (b *bbrSender) InSlowStart() bool { + return b.mode == bbrModeStartup +} + +// InRecovery implements the SendAlgorithmWithDebugInfos interface. +func (b *bbrSender) InRecovery() bool { + return b.recoveryState != bbrRecoveryStateNotInRecovery +} + +// GetCongestionWindow implements the SendAlgorithmWithDebugInfos interface. +func (b *bbrSender) GetCongestionWindow() congestion.ByteCount { + if b.mode == bbrModeProbeRtt { + return b.probeRttCongestionWindow() + } + + if b.InRecovery() { + return min(b.congestionWindow, b.recoveryWindow) + } + + return b.congestionWindow +} + +func (b *bbrSender) OnCongestionEvent(number congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) { + // Do nothing. +} + +func (b *bbrSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, eventTime congestion.Time, ackedPackets []congestion.AckedPacketInfo, lostPackets []congestion.LostPacketInfo) { + totalBytesAckedBefore := b.sampler.TotalBytesAcked() + totalBytesLostBefore := b.sampler.TotalBytesLost() + + var isRoundStart, minRttExpired bool + var excessAcked, bytesLost congestion.ByteCount + + // The send state of the largest packet in acked_packets, unless it is + // empty. If acked_packets is empty, it's the send state of the largest + // packet in lost_packets. + var lastPacketSendState sendTimeState + + b.maybeAppLimited(priorInFlight) + + // Update bytesInFlight + b.bytesInFlight = priorInFlight + for _, p := range ackedPackets { + b.bytesInFlight -= p.BytesAcked + } + for _, p := range lostPackets { + b.bytesInFlight -= p.BytesLost + } + + if len(ackedPackets) != 0 { + lastAckedPacket := ackedPackets[len(ackedPackets)-1].PacketNumber + isRoundStart = b.updateRoundTripCounter(lastAckedPacket) + b.updateRecoveryState(lastAckedPacket, len(lostPackets) != 0, isRoundStart) + } + + sample := b.sampler.OnCongestionEvent(eventTime, + ackedPackets, lostPackets, b.maxBandwidth.GetBest(), infBandwidth, b.roundTripCount) + if sample.lastPacketSendState.isValid { + b.lastSampleIsAppLimited = sample.lastPacketSendState.isAppLimited + b.hasNoAppLimitedSample = b.hasNoAppLimitedSample || !b.lastSampleIsAppLimited + } + // Avoid updating |max_bandwidth_| if a) this is a loss-only event, or b) all + // packets in |acked_packets| did not generate valid samples. (e.g. ack of + // ack-only packets). In both cases, sampler_.total_bytes_acked() will not + // change. + if totalBytesAckedBefore != b.sampler.TotalBytesAcked() { + if !sample.sampleIsAppLimited || sample.sampleMaxBandwidth > b.maxBandwidth.GetBest() { + b.maxBandwidth.Update(sample.sampleMaxBandwidth, b.roundTripCount) + } + } + + if sample.sampleRtt != infRTT { + minRttExpired = b.maybeUpdateMinRtt(eventTime, sample.sampleRtt) + } + bytesLost = b.sampler.TotalBytesLost() - totalBytesLostBefore + + excessAcked = sample.extraAcked + lastPacketSendState = sample.lastPacketSendState + + if len(lostPackets) != 0 { + b.numLossEventsInRound++ + b.bytesLostInRound += bytesLost + } + + // Handle logic specific to PROBE_BW mode. + if b.mode == bbrModeProbeBw { + b.updateGainCyclePhase(eventTime, priorInFlight, len(lostPackets) != 0) + } + + // Handle logic specific to STARTUP and DRAIN modes. + if isRoundStart && !b.isAtFullBandwidth { + b.checkIfFullBandwidthReached(&lastPacketSendState) + } + + b.maybeExitStartupOrDrain(eventTime) + + // Handle logic specific to PROBE_RTT. + b.maybeEnterOrExitProbeRtt(eventTime, isRoundStart, minRttExpired) + + // Calculate number of packets acked and lost. + bytesAcked := b.sampler.TotalBytesAcked() - totalBytesAckedBefore + + // After the model is updated, recalculate the pacing rate and congestion + // window. + b.calculatePacingRate(bytesLost) + b.calculateCongestionWindow(bytesAcked, excessAcked) + b.calculateRecoveryWindow(bytesAcked, bytesLost) + + // Cleanup internal state. + // This is where we clean up obsolete (acked or lost) packets from the bandwidth sampler. + // The "least unacked" should actually be FirstOutstanding, but since we are not passing + // that through OnCongestionEventEx, we will only do an estimate using acked/lost packets + // for now. Because of fast retransmission, they should differ by no more than 2 packets. + // (this is controlled by packetThreshold in quic-go's sentPacketHandler) + var leastUnacked congestion.PacketNumber + if len(ackedPackets) != 0 { + leastUnacked = ackedPackets[len(ackedPackets)-1].PacketNumber - 2 + } else { + leastUnacked = lostPackets[len(lostPackets)-1].PacketNumber + 1 + } + b.sampler.RemoveObsoletePackets(leastUnacked) + + if isRoundStart { + b.numLossEventsInRound = 0 + b.bytesLostInRound = 0 + } +} + +func (b *bbrSender) PacingRate() Bandwidth { + if b.pacingRate == 0 { + return Bandwidth(b.highGain * float64( + BandwidthFromDelta(b.initialCongestionWindow, b.getMinRtt()))) + } + + return b.pacingRate +} + +func (b *bbrSender) hasGoodBandwidthEstimateForResumption() bool { + return b.hasNonAppLimitedSample() +} + +func (b *bbrSender) hasNonAppLimitedSample() bool { + return b.hasNoAppLimitedSample +} + +// Sets the pacing gain used in STARTUP. Must be greater than 1. +func (b *bbrSender) setHighGain(highGain float64) { + b.highGain = highGain + if b.mode == bbrModeStartup { + b.pacingGain = highGain + } +} + +// Sets the CWND gain used in STARTUP. Must be greater than 1. +func (b *bbrSender) setHighCwndGain(highCwndGain float64) { + b.highCwndGain = highCwndGain + if b.mode == bbrModeStartup { + b.congestionWindowGain = highCwndGain + } +} + +// Sets the gain used in DRAIN. Must be less than 1. +func (b *bbrSender) setDrainGain(drainGain float64) { + b.drainGain = drainGain +} + +// Get the current bandwidth estimate. Note that Bandwidth is in bits per second. +func (b *bbrSender) bandwidthEstimate() Bandwidth { + return b.maxBandwidth.GetBest() +} + +func (b *bbrSender) bandwidthForPacer() congestion.ByteCount { + bps := congestion.ByteCount(float64(b.PacingRate()) / float64(BytesPerSecond)) + if bps < minBps { + // We need to make sure that the bandwidth value for pacer is never zero, + // otherwise it will go into an edge case where HasPacingBudget = false + // but TimeUntilSend is before, causing the quic-go send loop to go crazy and get stuck. + return minBps + } + return bps +} + +// Returns the current estimate of the RTT of the connection. Outside of the +// edge cases, this is minimum RTT. +func (b *bbrSender) getMinRtt() time.Duration { + if b.minRtt != 0 { + return b.minRtt + } + // min_rtt could be available if the handshake packet gets neutered then + // gets acknowledged. This could only happen for QUIC crypto where we do not + // drop keys. + minRtt := b.rttStats.MinRTT() + if minRtt == 0 { + return 100 * time.Millisecond + } else { + return minRtt + } +} + +// Computes the target congestion window using the specified gain. +func (b *bbrSender) getTargetCongestionWindow(gain float64) congestion.ByteCount { + bdp := bdpFromRttAndBandwidth(b.getMinRtt(), b.bandwidthEstimate()) + congestionWindow := congestion.ByteCount(gain * float64(bdp)) + + // BDP estimate will be zero if no bandwidth samples are available yet. + if congestionWindow == 0 { + congestionWindow = congestion.ByteCount(gain * float64(b.initialCongestionWindow)) + } + + return max(congestionWindow, b.minCongestionWindow) +} + +// The target congestion window during PROBE_RTT. +func (b *bbrSender) probeRttCongestionWindow() congestion.ByteCount { + return b.minCongestionWindow +} + +func (b *bbrSender) maybeUpdateMinRtt(now congestion.Time, sampleMinRtt time.Duration) bool { + // Do not expire min_rtt if none was ever available. + minRttExpired := b.minRtt != 0 && now.After(b.minRttTimestamp.Add(minRttExpiry)) + if minRttExpired || sampleMinRtt < b.minRtt || b.minRtt == 0 { + b.minRtt = sampleMinRtt + b.minRttTimestamp = now + } + + return minRttExpired +} + +// Enters the STARTUP mode. +func (b *bbrSender) enterStartupMode(now congestion.Time) { + b.mode = bbrModeStartup + // b.maybeTraceStateChange(logging.CongestionStateStartup) + b.pacingGain = b.highGain + b.congestionWindowGain = b.highCwndGain + + if b.debug { + b.debugPrint("Phase: STARTUP") + } +} + +// Enters the PROBE_BW mode. +func (b *bbrSender) enterProbeBandwidthMode(now congestion.Time) { + b.mode = bbrModeProbeBw + // b.maybeTraceStateChange(logging.CongestionStateProbeBw) + b.congestionWindowGain = b.congestionWindowGainConstant + + // Pick a random offset for the gain cycle out of {0, 2..7} range. 1 is + // excluded because in that case increased gain and decreased gain would not + // follow each other. + b.cycleCurrentOffset = int(rand.Int31n(congestion.PacketsPerConnectionID)) % (gainCycleLength - 1) + if b.cycleCurrentOffset >= 1 { + b.cycleCurrentOffset += 1 + } + + b.lastCycleStart = now + b.pacingGain = pacingGain[b.cycleCurrentOffset] + + if b.debug { + b.debugPrint("Phase: PROBE_BW") + } +} + +// Updates the round-trip counter if a round-trip has passed. Returns true if +// the counter has been advanced. +func (b *bbrSender) updateRoundTripCounter(lastAckedPacket congestion.PacketNumber) bool { + if b.currentRoundTripEnd == invalidPacketNumber || lastAckedPacket > b.currentRoundTripEnd { + b.roundTripCount++ + b.currentRoundTripEnd = b.lastSentPacket + return true + } + return false +} + +// Updates the current gain used in PROBE_BW mode. +func (b *bbrSender) updateGainCyclePhase(now congestion.Time, priorInFlight congestion.ByteCount, hasLosses bool) { + // In most cases, the cycle is advanced after an RTT passes. + shouldAdvanceGainCycling := now.After(b.lastCycleStart.Add(b.getMinRtt())) + // If the pacing gain is above 1.0, the connection is trying to probe the + // bandwidth by increasing the number of bytes in flight to at least + // pacing_gain * BDP. Make sure that it actually reaches the target, as long + // as there are no losses suggesting that the buffers are not able to hold + // that much. + if b.pacingGain > 1.0 && !hasLosses && priorInFlight < b.getTargetCongestionWindow(b.pacingGain) { + shouldAdvanceGainCycling = false + } + + // If pacing gain is below 1.0, the connection is trying to drain the extra + // queue which could have been incurred by probing prior to it. If the number + // of bytes in flight falls down to the estimated BDP value earlier, conclude + // that the queue has been successfully drained and exit this cycle early. + if b.pacingGain < 1.0 && b.bytesInFlight <= b.getTargetCongestionWindow(1) { + shouldAdvanceGainCycling = true + } + + if shouldAdvanceGainCycling { + b.cycleCurrentOffset = (b.cycleCurrentOffset + 1) % gainCycleLength + b.lastCycleStart = now + // Stay in low gain mode until the target BDP is hit. + // Low gain mode will be exited immediately when the target BDP is achieved. + if b.drainToTarget && b.pacingGain < 1 && + pacingGain[b.cycleCurrentOffset] == 1 && + b.bytesInFlight > b.getTargetCongestionWindow(1) { + return + } + b.pacingGain = pacingGain[b.cycleCurrentOffset] + } +} + +// Tracks for how many round-trips the bandwidth has not increased +// significantly. +func (b *bbrSender) checkIfFullBandwidthReached(lastPacketSendState *sendTimeState) { + if b.lastSampleIsAppLimited { + return + } + + target := Bandwidth(float64(b.bandwidthAtLastRound) * startupGrowthTarget) + if b.bandwidthEstimate() >= target { + b.bandwidthAtLastRound = b.bandwidthEstimate() + b.roundsWithoutBandwidthGain = 0 + if b.expireAckAggregationInStartup { + // Expire old excess delivery measurements now that bandwidth increased. + b.sampler.ResetMaxAckHeightTracker(0, b.roundTripCount) + } + return + } + + b.roundsWithoutBandwidthGain++ + if b.roundsWithoutBandwidthGain >= b.numStartupRtts || + b.shouldExitStartupDueToLoss(lastPacketSendState) { + b.isAtFullBandwidth = true + } +} + +func (b *bbrSender) maybeAppLimited(bytesInFlight congestion.ByteCount) { + if bytesInFlight < b.getTargetCongestionWindow(1) { + b.sampler.OnAppLimited() + } +} + +// Transitions from STARTUP to DRAIN and from DRAIN to PROBE_BW if +// appropriate. +func (b *bbrSender) maybeExitStartupOrDrain(now congestion.Time) { + if b.mode == bbrModeStartup && b.isAtFullBandwidth { + b.mode = bbrModeDrain + // b.maybeTraceStateChange(logging.CongestionStateDrain) + b.pacingGain = b.drainGain + b.congestionWindowGain = b.highCwndGain + + if b.debug { + b.debugPrint("Phase: DRAIN") + } + } + if b.mode == bbrModeDrain && b.bytesInFlight <= b.getTargetCongestionWindow(1) { + b.enterProbeBandwidthMode(now) + } +} + +// Decides whether to enter or exit PROBE_RTT. +func (b *bbrSender) maybeEnterOrExitProbeRtt(now congestion.Time, isRoundStart, minRttExpired bool) { + if minRttExpired && !b.exitingQuiescence && b.mode != bbrModeProbeRtt { + b.mode = bbrModeProbeRtt + // b.maybeTraceStateChange(logging.CongestionStateProbRtt) + b.pacingGain = 1.0 + // Do not decide on the time to exit PROBE_RTT until the |bytes_in_flight| + // is at the target small value. + b.exitProbeRttAt = 0 + + if b.debug { + b.debugPrint("BandwidthEstimate: %s, CongestionWindowGain: %.2f, PacingGain: %.2f, PacingRate: %s", + formatSpeed(b.bandwidthEstimate()), b.congestionWindowGain, b.pacingGain, formatSpeed(b.PacingRate())) + b.debugPrint("Phase: PROBE_RTT") + } + } + + if b.mode == bbrModeProbeRtt { + b.sampler.OnAppLimited() + // b.maybeTraceStateChange(logging.CongestionStateApplicationLimited) + + if b.exitProbeRttAt.IsZero() { + // If the window has reached the appropriate size, schedule exiting + // PROBE_RTT. The CWND during PROBE_RTT is kMinimumCongestionWindow, but + // we allow an extra packet since QUIC checks CWND before sending a + // packet. + if b.bytesInFlight < b.probeRttCongestionWindow()+congestion.MaxPacketBufferSize { + b.exitProbeRttAt = now.Add(probeRttTime) + b.probeRttRoundPassed = false + } + } else { + if isRoundStart { + b.probeRttRoundPassed = true + } + if now.Sub(b.exitProbeRttAt) >= 0 && b.probeRttRoundPassed { + b.minRttTimestamp = now + if b.debug { + b.debugPrint("MinRTT: %s", b.getMinRtt()) + } + if !b.isAtFullBandwidth { + b.enterStartupMode(now) + } else { + b.enterProbeBandwidthMode(now) + } + } + } + } + + b.exitingQuiescence = false +} + +// Determines whether BBR needs to enter, exit or advance state of the +// recovery. +func (b *bbrSender) updateRecoveryState(lastAckedPacket congestion.PacketNumber, hasLosses, isRoundStart bool) { + // Disable recovery in startup, if loss-based exit is enabled. + if !b.isAtFullBandwidth { + return + } + + // Exit recovery when there are no losses for a round. + if hasLosses { + b.endRecoveryAt = b.lastSentPacket + } + + switch b.recoveryState { + case bbrRecoveryStateNotInRecovery: + if hasLosses { + b.recoveryState = bbrRecoveryStateConservation + // This will cause the |recovery_window_| to be set to the correct + // value in CalculateRecoveryWindow(). + b.recoveryWindow = 0 + // Since the conservation phase is meant to be lasting for a whole + // round, extend the current round as if it were started right now. + b.currentRoundTripEnd = b.lastSentPacket + } + case bbrRecoveryStateConservation: + if isRoundStart { + b.recoveryState = bbrRecoveryStateGrowth + } + fallthrough + case bbrRecoveryStateGrowth: + // Exit recovery if appropriate. + if !hasLosses && lastAckedPacket > b.endRecoveryAt { + b.recoveryState = bbrRecoveryStateNotInRecovery + } + } +} + +// Determines the appropriate pacing rate for the connection. +func (b *bbrSender) calculatePacingRate(bytesLost congestion.ByteCount) { + if b.bandwidthEstimate() == 0 { + return + } + + targetRate := Bandwidth(b.pacingGain * float64(b.bandwidthEstimate())) + if b.isAtFullBandwidth { + b.pacingRate = targetRate + return + } + + // Pace at the rate of initial_window / RTT as soon as RTT measurements are + // available. + if b.pacingRate == 0 && b.rttStats.MinRTT() != 0 { + b.pacingRate = BandwidthFromDelta(b.initialCongestionWindow, b.rttStats.MinRTT()) + return + } + + if b.detectOvershooting { + b.bytesLostWhileDetectingOvershooting += bytesLost + // Check for overshooting with network parameters adjusted when pacing rate + // > target_rate and loss has been detected. + if b.pacingRate > targetRate && b.bytesLostWhileDetectingOvershooting > 0 { + if b.hasNoAppLimitedSample || + b.bytesLostWhileDetectingOvershooting*congestion.ByteCount(b.bytesLostMultiplierWhileDetectingOvershooting) > b.initialCongestionWindow { + // We are fairly sure overshoot happens if 1) there is at least one + // non app-limited bw sample or 2) half of IW gets lost. Slow pacing + // rate. + b.pacingRate = max(targetRate, BandwidthFromDelta(b.cwndToCalculateMinPacingRate, b.rttStats.MinRTT())) + b.bytesLostWhileDetectingOvershooting = 0 + b.detectOvershooting = false + } + } + } + + // Do not decrease the pacing rate during startup. + b.pacingRate = max(b.pacingRate, targetRate) +} + +// Determines the appropriate congestion window for the connection. +func (b *bbrSender) calculateCongestionWindow(bytesAcked, excessAcked congestion.ByteCount) { + if b.mode == bbrModeProbeRtt { + return + } + + targetWindow := b.getTargetCongestionWindow(b.congestionWindowGain) + if b.isAtFullBandwidth { + // Add the max recently measured ack aggregation to CWND. + targetWindow += b.sampler.MaxAckHeight() + } else if b.enableAckAggregationDuringStartup { + // Add the most recent excess acked. Because CWND never decreases in + // STARTUP, this will automatically create a very localized max filter. + targetWindow += excessAcked + } + + // Instead of immediately setting the target CWND as the new one, BBR grows + // the CWND towards |target_window| by only increasing it |bytes_acked| at a + // time. + if b.isAtFullBandwidth { + b.congestionWindow = min(targetWindow, b.congestionWindow+bytesAcked) + } else if b.congestionWindow < targetWindow || + b.sampler.TotalBytesAcked() < b.initialCongestionWindow { + // If the connection is not yet out of startup phase, do not decrease the + // window. + b.congestionWindow += bytesAcked + } + + // Enforce the limits on the congestion window. + b.congestionWindow = max(b.congestionWindow, b.minCongestionWindow) + b.congestionWindow = min(b.congestionWindow, b.maxCongestionWindow) +} + +// Determines the appropriate window that constrains the in-flight during recovery. +func (b *bbrSender) calculateRecoveryWindow(bytesAcked, bytesLost congestion.ByteCount) { + if b.recoveryState == bbrRecoveryStateNotInRecovery { + return + } + + // Set up the initial recovery window. + if b.recoveryWindow == 0 { + b.recoveryWindow = b.bytesInFlight + bytesAcked + b.recoveryWindow = max(b.minCongestionWindow, b.recoveryWindow) + return + } + + // Remove losses from the recovery window, while accounting for a potential + // integer underflow. + if b.recoveryWindow >= bytesLost { + b.recoveryWindow = b.recoveryWindow - bytesLost + } else { + b.recoveryWindow = b.maxDatagramSize + } + + // In CONSERVATION mode, just subtracting losses is sufficient. In GROWTH, + // release additional |bytes_acked| to achieve a slow-start-like behavior. + if b.recoveryState == bbrRecoveryStateGrowth { + b.recoveryWindow += bytesAcked + } + + // Always allow sending at least |bytes_acked| in response. + b.recoveryWindow = max(b.recoveryWindow, b.bytesInFlight+bytesAcked) + b.recoveryWindow = max(b.minCongestionWindow, b.recoveryWindow) +} + +// Return whether we should exit STARTUP due to excessive loss. +func (b *bbrSender) shouldExitStartupDueToLoss(lastPacketSendState *sendTimeState) bool { + if b.numLossEventsInRound < defaultStartupFullLossCount || !lastPacketSendState.isValid { + return false + } + + inflightAtSend := lastPacketSendState.bytesInFlight + + if inflightAtSend > 0 && b.bytesLostInRound > 0 { + if b.bytesLostInRound > congestion.ByteCount(float64(inflightAtSend)*quicBbr2DefaultLossThreshold) { + return true + } + return false + } + return false +} + +func (b *bbrSender) debugPrint(format string, a ...any) { + fmt.Printf("[BBRSender] [%s] %s\n", + time.Now().Format("15:04:05"), + fmt.Sprintf(format, a...)) +} + +func bdpFromRttAndBandwidth(rtt time.Duration, bandwidth Bandwidth) congestion.ByteCount { + return congestion.ByteCount(rtt) * congestion.ByteCount(bandwidth) / congestion.ByteCount(BytesPerSecond) / congestion.ByteCount(time.Second) +} + +func GetInitialPacketSize(addr net.Addr) congestion.ByteCount { + // If this is not a UDP address, we don't know anything about the MTU. + // Use the minimum size of an Initial packet as the max packet size. + if _, ok := addr.(*net.UDPAddr); ok { + return congestion.InitialPacketSize + } else { + return congestion.MinInitialPacketSize + } +} + +func formatSpeed(bw Bandwidth) string { + bwf := float64(bw) + units := []string{"bps", "Kbps", "Mbps", "Gbps"} + unitIndex := 0 + for bwf > 1000 && unitIndex < len(units)-1 { + bwf /= 1000 + unitIndex++ + } + return fmt.Sprintf("%.2f %s", bwf, units[unitIndex]) +} diff --git a/transport/internet/hysteria/congestion/bbr/clock.go b/transport/internet/hysteria/congestion/bbr/clock.go new file mode 100644 index 000000000000..fee7df6ab74f --- /dev/null +++ b/transport/internet/hysteria/congestion/bbr/clock.go @@ -0,0 +1,18 @@ +package bbr + +import "github.com/apernet/quic-go/congestion" + +// A Clock returns the current time +type Clock interface { + Now() congestion.Time +} + +// DefaultClock implements the Clock interface using the Go stdlib clock. +type DefaultClock struct{} + +var _ Clock = DefaultClock{} + +// Now gets the current time +func (DefaultClock) Now() congestion.Time { + return congestion.Now() +} diff --git a/transport/internet/hysteria/congestion/bbr/packet_number_indexed_queue.go b/transport/internet/hysteria/congestion/bbr/packet_number_indexed_queue.go new file mode 100644 index 000000000000..08b99deadf07 --- /dev/null +++ b/transport/internet/hysteria/congestion/bbr/packet_number_indexed_queue.go @@ -0,0 +1,199 @@ +package bbr + +import ( + "github.com/apernet/quic-go/congestion" +) + +// packetNumberIndexedQueue is a queue of mostly continuous numbered entries +// which supports the following operations: +// - adding elements to the end of the queue, or at some point past the end +// - removing elements in any order +// - retrieving elements +// If all elements are inserted in order, all of the operations above are +// amortized O(1) time. +// +// Internally, the data structure is a deque where each element is marked as +// present or not. The deque starts at the lowest present index. Whenever an +// element is removed, it's marked as not present, and the front of the deque is +// cleared of elements that are not present. +// +// The tail of the queue is not cleared due to the assumption of entries being +// inserted in order, though removing all elements of the queue will return it +// to its initial state. +// +// Note that this data structure is inherently hazardous, since an addition of +// just two entries will cause it to consume all of the memory available. +// Because of that, it is not a general-purpose container and should not be used +// as one. + +type entryWrapper[T any] struct { + present bool + entry T +} + +type packetNumberIndexedQueue[T any] struct { + entries RingBuffer[entryWrapper[T]] + numberOfPresentEntries int + firstPacket congestion.PacketNumber +} + +func newPacketNumberIndexedQueue[T any](size int) *packetNumberIndexedQueue[T] { + q := &packetNumberIndexedQueue[T]{ + firstPacket: invalidPacketNumber, + } + + q.entries.Init(size) + + return q +} + +// Emplace inserts data associated |packet_number| into (or past) the end of the +// queue, filling up the missing intermediate entries as necessary. Returns +// true if the element has been inserted successfully, false if it was already +// in the queue or inserted out of order. +func (p *packetNumberIndexedQueue[T]) Emplace(packetNumber congestion.PacketNumber, entry *T) bool { + if packetNumber == invalidPacketNumber || entry == nil { + return false + } + + if p.IsEmpty() { + p.entries.PushBack(entryWrapper[T]{ + present: true, + entry: *entry, + }) + p.numberOfPresentEntries = 1 + p.firstPacket = packetNumber + return true + } + + // Do not allow insertion out-of-order. + if packetNumber <= p.LastPacket() { + return false + } + + // Handle potentially missing elements. + offset := int(packetNumber - p.FirstPacket()) + if gap := offset - p.entries.Len(); gap > 0 { + for i := 0; i < gap; i++ { + p.entries.PushBack(entryWrapper[T]{}) + } + } + + p.entries.PushBack(entryWrapper[T]{ + present: true, + entry: *entry, + }) + p.numberOfPresentEntries++ + return true +} + +// GetEntry Retrieve the entry associated with the packet number. Returns the pointer +// to the entry in case of success, or nullptr if the entry does not exist. +func (p *packetNumberIndexedQueue[T]) GetEntry(packetNumber congestion.PacketNumber) *T { + ew := p.getEntryWraper(packetNumber) + if ew == nil { + return nil + } + + return &ew.entry +} + +// Remove, Same as above, but if an entry is present in the queue, also call f(entry) +// before removing it. +func (p *packetNumberIndexedQueue[T]) Remove(packetNumber congestion.PacketNumber, f func(T)) bool { + ew := p.getEntryWraper(packetNumber) + if ew == nil { + return false + } + if f != nil { + f(ew.entry) + } + ew.present = false + p.numberOfPresentEntries-- + + if packetNumber == p.FirstPacket() { + p.clearup() + } + + return true +} + +// RemoveUpTo, but not including |packet_number|. +// Unused slots in the front are also removed, which means when the function +// returns, |first_packet()| can be larger than |packet_number|. +func (p *packetNumberIndexedQueue[T]) RemoveUpTo(packetNumber congestion.PacketNumber) { + for !p.entries.Empty() && + p.firstPacket != invalidPacketNumber && + p.firstPacket < packetNumber { + if p.entries.Front().present { + p.numberOfPresentEntries-- + } + p.entries.PopFront() + p.firstPacket++ + } + p.clearup() + + return +} + +// IsEmpty return if queue is empty. +func (p *packetNumberIndexedQueue[T]) IsEmpty() bool { + return p.numberOfPresentEntries == 0 +} + +// NumberOfPresentEntries returns the number of entries in the queue. +func (p *packetNumberIndexedQueue[T]) NumberOfPresentEntries() int { + return p.numberOfPresentEntries +} + +// EntrySlotsUsed returns the number of entries allocated in the underlying deque. This is +// proportional to the memory usage of the queue. +func (p *packetNumberIndexedQueue[T]) EntrySlotsUsed() int { + return p.entries.Len() +} + +// FirstPacket returns packet number of the first entry in the queue. +func (p *packetNumberIndexedQueue[T]) FirstPacket() (packetNumber congestion.PacketNumber) { + return p.firstPacket +} + +// LastPacket returns packet number of the last entry ever inserted in the queue. Note that the +// entry in question may have already been removed. Zero if the queue is +// empty. +func (p *packetNumberIndexedQueue[T]) LastPacket() (packetNumber congestion.PacketNumber) { + if p.IsEmpty() { + return invalidPacketNumber + } + + return p.firstPacket + congestion.PacketNumber(p.entries.Len()-1) +} + +func (p *packetNumberIndexedQueue[T]) clearup() { + for !p.entries.Empty() && !p.entries.Front().present { + p.entries.PopFront() + p.firstPacket++ + } + if p.entries.Empty() { + p.firstPacket = invalidPacketNumber + } +} + +func (p *packetNumberIndexedQueue[T]) getEntryWraper(packetNumber congestion.PacketNumber) *entryWrapper[T] { + if packetNumber == invalidPacketNumber || + p.IsEmpty() || + packetNumber < p.firstPacket { + return nil + } + + offset := int(packetNumber - p.firstPacket) + if offset >= p.entries.Len() { + return nil + } + + ew := p.entries.Offset(offset) + if ew == nil || !ew.present { + return nil + } + + return ew +} diff --git a/transport/internet/hysteria/congestion/bbr/ringbuffer.go b/transport/internet/hysteria/congestion/bbr/ringbuffer.go new file mode 100644 index 000000000000..ed92d4ce0124 --- /dev/null +++ b/transport/internet/hysteria/congestion/bbr/ringbuffer.go @@ -0,0 +1,118 @@ +package bbr + +// A RingBuffer is a ring buffer. +// It acts as a heap that doesn't cause any allocations. +type RingBuffer[T any] struct { + ring []T + headPos, tailPos int + full bool +} + +// Init preallocs a buffer with a certain size. +func (r *RingBuffer[T]) Init(size int) { + r.ring = make([]T, size) +} + +// Len returns the number of elements in the ring buffer. +func (r *RingBuffer[T]) Len() int { + if r.full { + return len(r.ring) + } + if r.tailPos >= r.headPos { + return r.tailPos - r.headPos + } + return r.tailPos - r.headPos + len(r.ring) +} + +// Empty says if the ring buffer is empty. +func (r *RingBuffer[T]) Empty() bool { + return !r.full && r.headPos == r.tailPos +} + +// PushBack adds a new element. +// If the ring buffer is full, its capacity is increased first. +func (r *RingBuffer[T]) PushBack(t T) { + if r.full || len(r.ring) == 0 { + r.grow() + } + r.ring[r.tailPos] = t + r.tailPos++ + if r.tailPos == len(r.ring) { + r.tailPos = 0 + } + if r.tailPos == r.headPos { + r.full = true + } +} + +// PopFront returns the next element. +// It must not be called when the buffer is empty, that means that +// callers might need to check if there are elements in the buffer first. +func (r *RingBuffer[T]) PopFront() T { + if r.Empty() { + panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: pop from an empty queue") + } + r.full = false + t := r.ring[r.headPos] + r.ring[r.headPos] = *new(T) + r.headPos++ + if r.headPos == len(r.ring) { + r.headPos = 0 + } + return t +} + +// Offset returns the offset element. +// It must not be called when the buffer is empty, that means that +// callers might need to check if there are elements in the buffer first +// and check if the index larger than buffer length. +func (r *RingBuffer[T]) Offset(index int) *T { + if r.Empty() || index >= r.Len() { + panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: offset from invalid index") + } + offset := (r.headPos + index) % len(r.ring) + return &r.ring[offset] +} + +// Front returns the front element. +// It must not be called when the buffer is empty, that means that +// callers might need to check if there are elements in the buffer first. +func (r *RingBuffer[T]) Front() *T { + if r.Empty() { + panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: front from an empty queue") + } + return &r.ring[r.headPos] +} + +// Back returns the back element. +// It must not be called when the buffer is empty, that means that +// callers might need to check if there are elements in the buffer first. +func (r *RingBuffer[T]) Back() *T { + if r.Empty() { + panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: back from an empty queue") + } + return r.Offset(r.Len() - 1) +} + +// Grow the maximum size of the queue. +// This method assume the queue is full. +func (r *RingBuffer[T]) grow() { + oldRing := r.ring + newSize := len(oldRing) * 2 + if newSize == 0 { + newSize = 1 + } + r.ring = make([]T, newSize) + headLen := copy(r.ring, oldRing[r.headPos:]) + copy(r.ring[headLen:], oldRing[:r.headPos]) + r.headPos, r.tailPos, r.full = 0, len(oldRing), false +} + +// Clear removes all elements. +func (r *RingBuffer[T]) Clear() { + var zeroValue T + for i := range r.ring { + r.ring[i] = zeroValue + } + r.headPos, r.tailPos, r.full = 0, 0, false +} diff --git a/transport/internet/hysteria/congestion/bbr/windowed_filter.go b/transport/internet/hysteria/congestion/bbr/windowed_filter.go new file mode 100644 index 000000000000..4773bce597dc --- /dev/null +++ b/transport/internet/hysteria/congestion/bbr/windowed_filter.go @@ -0,0 +1,162 @@ +package bbr + +import ( + "golang.org/x/exp/constraints" +) + +// Implements Kathleen Nichols' algorithm for tracking the minimum (or maximum) +// estimate of a stream of samples over some fixed time interval. (E.g., +// the minimum RTT over the past five minutes.) The algorithm keeps track of +// the best, second best, and third best min (or max) estimates, maintaining an +// invariant that the measurement time of the n'th best >= n-1'th best. + +// The algorithm works as follows. On a reset, all three estimates are set to +// the same sample. The second best estimate is then recorded in the second +// quarter of the window, and a third best estimate is recorded in the second +// half of the window, bounding the worst case error when the true min is +// monotonically increasing (or true max is monotonically decreasing) over the +// window. +// +// A new best sample replaces all three estimates, since the new best is lower +// (or higher) than everything else in the window and it is the most recent. +// The window thus effectively gets reset on every new min. The same property +// holds true for second best and third best estimates. Specifically, when a +// sample arrives that is better than the second best but not better than the +// best, it replaces the second and third best estimates but not the best +// estimate. Similarly, a sample that is better than the third best estimate +// but not the other estimates replaces only the third best estimate. +// +// Finally, when the best expires, it is replaced by the second best, which in +// turn is replaced by the third best. The newest sample replaces the third +// best. + +type WindowedFilterValue interface { + any +} + +type WindowedFilterTime interface { + constraints.Integer | constraints.Float +} + +type WindowedFilter[V WindowedFilterValue, T WindowedFilterTime] struct { + // Time length of window. + windowLength T + estimates []entry[V, T] + comparator func(V, V) int +} + +type entry[V WindowedFilterValue, T WindowedFilterTime] struct { + sample V + time T +} + +// Compares two values and returns true if the first is greater than or equal +// to the second. +func MaxFilter[O constraints.Ordered](a, b O) int { + if a > b { + return 1 + } else if a < b { + return -1 + } + return 0 +} + +// Compares two values and returns true if the first is less than or equal +// to the second. +func MinFilter[O constraints.Ordered](a, b O) int { + if a < b { + return 1 + } else if a > b { + return -1 + } + return 0 +} + +func NewWindowedFilter[V WindowedFilterValue, T WindowedFilterTime](windowLength T, comparator func(V, V) int) *WindowedFilter[V, T] { + return &WindowedFilter[V, T]{ + windowLength: windowLength, + estimates: make([]entry[V, T], 3, 3), + comparator: comparator, + } +} + +// Changes the window length. Does not update any current samples. +func (f *WindowedFilter[V, T]) SetWindowLength(windowLength T) { + f.windowLength = windowLength +} + +func (f *WindowedFilter[V, T]) GetBest() V { + return f.estimates[0].sample +} + +func (f *WindowedFilter[V, T]) GetSecondBest() V { + return f.estimates[1].sample +} + +func (f *WindowedFilter[V, T]) GetThirdBest() V { + return f.estimates[2].sample +} + +// Updates best estimates with |sample|, and expires and updates best +// estimates as necessary. +func (f *WindowedFilter[V, T]) Update(newSample V, newTime T) { + // Reset all estimates if they have not yet been initialized, if new sample + // is a new best, or if the newest recorded estimate is too old. + if f.comparator(f.estimates[0].sample, *new(V)) == 0 || + f.comparator(newSample, f.estimates[0].sample) >= 0 || + newTime-f.estimates[2].time > f.windowLength { + f.Reset(newSample, newTime) + return + } + + if f.comparator(newSample, f.estimates[1].sample) >= 0 { + f.estimates[1] = entry[V, T]{newSample, newTime} + f.estimates[2] = f.estimates[1] + } else if f.comparator(newSample, f.estimates[2].sample) >= 0 { + f.estimates[2] = entry[V, T]{newSample, newTime} + } + + // Expire and update estimates as necessary. + if newTime-f.estimates[0].time > f.windowLength { + // The best estimate hasn't been updated for an entire window, so promote + // second and third best estimates. + f.estimates[0] = f.estimates[1] + f.estimates[1] = f.estimates[2] + f.estimates[2] = entry[V, T]{newSample, newTime} + // Need to iterate one more time. Check if the new best estimate is + // outside the window as well, since it may also have been recorded a + // long time ago. Don't need to iterate once more since we cover that + // case at the beginning of the method. + if newTime-f.estimates[0].time > f.windowLength { + f.estimates[0] = f.estimates[1] + f.estimates[1] = f.estimates[2] + } + return + } + if f.comparator(f.estimates[1].sample, f.estimates[0].sample) == 0 && + newTime-f.estimates[1].time > f.windowLength/4 { + // A quarter of the window has passed without a better sample, so the + // second-best estimate is taken from the second quarter of the window. + f.estimates[1] = entry[V, T]{newSample, newTime} + f.estimates[2] = f.estimates[1] + return + } + + if f.comparator(f.estimates[2].sample, f.estimates[1].sample) == 0 && + newTime-f.estimates[2].time > f.windowLength/2 { + // We've passed a half of the window without a better estimate, so take + // a third-best estimate from the second half of the window. + f.estimates[2] = entry[V, T]{newSample, newTime} + } +} + +// Resets all estimates to new sample. +func (f *WindowedFilter[V, T]) Reset(newSample V, newTime T) { + f.estimates[2] = entry[V, T]{newSample, newTime} + f.estimates[1] = f.estimates[2] + f.estimates[0] = f.estimates[1] +} + +func (f *WindowedFilter[V, T]) Clear() { + f.estimates = make([]entry[V, T], 3, 3) +} diff --git a/transport/internet/hysteria/congestion/brutal/brutal.go b/transport/internet/hysteria/congestion/brutal/brutal.go new file mode 100644 index 000000000000..ba0bfda81104 --- /dev/null +++ b/transport/internet/hysteria/congestion/brutal/brutal.go @@ -0,0 +1,185 @@ +package brutal + +import ( + "fmt" + "os" + "strconv" + "time" + + "github.com/xtls/xray-core/transport/internet/hysteria/congestion/common" + + "github.com/apernet/quic-go/congestion" +) + +const ( + pktInfoSlotCount = 5 // slot index is based on seconds, so this is basically how many seconds we sample + minSampleCount = 50 + minAckRate = 0.8 + congestionWindowMultiplier = 2 + + debugEnv = "HYSTERIA_BRUTAL_DEBUG" + debugPrintInterval = 2 +) + +var _ congestion.CongestionControl = &BrutalSender{} + +type BrutalSender struct { + rttStats congestion.RTTStatsProvider + bps congestion.ByteCount + maxDatagramSize congestion.ByteCount + pacer *common.Pacer + + pktInfoSlots [pktInfoSlotCount]pktInfo + ackRate float64 + + debug bool + lastAckPrintTimestamp int64 +} + +type pktInfo struct { + Timestamp int64 + AckCount uint64 + LossCount uint64 +} + +func NewBrutalSender(bps uint64) *BrutalSender { + debug, _ := strconv.ParseBool(os.Getenv(debugEnv)) + bs := &BrutalSender{ + bps: congestion.ByteCount(bps), + maxDatagramSize: congestion.InitialPacketSize, + ackRate: 1, + debug: debug, + } + bs.pacer = common.NewPacer(func() congestion.ByteCount { + return congestion.ByteCount(float64(bs.bps) / bs.ackRate) + }) + return bs +} + +func (b *BrutalSender) SetRTTStatsProvider(rttStats congestion.RTTStatsProvider) { + b.rttStats = rttStats +} + +func (b *BrutalSender) TimeUntilSend(bytesInFlight congestion.ByteCount) congestion.Time { + return b.pacer.TimeUntilSend() +} + +func (b *BrutalSender) HasPacingBudget(now congestion.Time) bool { + return b.pacer.Budget(now) >= b.maxDatagramSize +} + +func (b *BrutalSender) CanSend(bytesInFlight congestion.ByteCount) bool { + return bytesInFlight <= b.GetCongestionWindow() +} + +func (b *BrutalSender) GetCongestionWindow() congestion.ByteCount { + rtt := b.rttStats.SmoothedRTT() + if rtt <= 0 { + return 10240 + } + cwnd := congestion.ByteCount(float64(b.bps) * rtt.Seconds() * congestionWindowMultiplier / b.ackRate) + if cwnd < b.maxDatagramSize { + cwnd = b.maxDatagramSize + } + return cwnd +} + +func (b *BrutalSender) OnPacketSent(sentTime congestion.Time, bytesInFlight congestion.ByteCount, + packetNumber congestion.PacketNumber, bytes congestion.ByteCount, isRetransmittable bool, +) { + b.pacer.SentPacket(sentTime, bytes) +} + +func (b *BrutalSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount, + priorInFlight congestion.ByteCount, eventTime congestion.Time, +) { + // Stub +} + +func (b *BrutalSender) OnCongestionEvent(number congestion.PacketNumber, lostBytes congestion.ByteCount, + priorInFlight congestion.ByteCount, +) { + // Stub +} + +func (b *BrutalSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, eventTime congestion.Time, ackedPackets []congestion.AckedPacketInfo, lostPackets []congestion.LostPacketInfo) { + currentTimestamp := int64(time.Duration(eventTime) / time.Second) + slot := currentTimestamp % pktInfoSlotCount + if b.pktInfoSlots[slot].Timestamp == currentTimestamp { + b.pktInfoSlots[slot].LossCount += uint64(len(lostPackets)) + b.pktInfoSlots[slot].AckCount += uint64(len(ackedPackets)) + } else { + // uninitialized slot or too old, reset + b.pktInfoSlots[slot].Timestamp = currentTimestamp + b.pktInfoSlots[slot].AckCount = uint64(len(ackedPackets)) + b.pktInfoSlots[slot].LossCount = uint64(len(lostPackets)) + } + b.updateAckRate(currentTimestamp) +} + +func (b *BrutalSender) SetMaxDatagramSize(size congestion.ByteCount) { + b.maxDatagramSize = size + b.pacer.SetMaxDatagramSize(size) + if b.debug { + b.debugPrint("SetMaxDatagramSize: %d", size) + } +} + +func (b *BrutalSender) updateAckRate(currentTimestamp int64) { + minTimestamp := currentTimestamp - pktInfoSlotCount + var ackCount, lossCount uint64 + for _, info := range b.pktInfoSlots { + if info.Timestamp < minTimestamp { + continue + } + ackCount += info.AckCount + lossCount += info.LossCount + } + if ackCount+lossCount < minSampleCount { + b.ackRate = 1 + if b.canPrintAckRate(currentTimestamp) { + b.lastAckPrintTimestamp = currentTimestamp + b.debugPrint("Not enough samples (total=%d, ack=%d, loss=%d, rtt=%d)", + ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds()) + } + return + } + rate := float64(ackCount) / float64(ackCount+lossCount) + if rate < minAckRate { + b.ackRate = minAckRate + if b.canPrintAckRate(currentTimestamp) { + b.lastAckPrintTimestamp = currentTimestamp + b.debugPrint("ACK rate too low: %.2f, clamped to %.2f (total=%d, ack=%d, loss=%d, rtt=%d)", + rate, minAckRate, ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds()) + } + return + } + b.ackRate = rate + if b.canPrintAckRate(currentTimestamp) { + b.lastAckPrintTimestamp = currentTimestamp + b.debugPrint("ACK rate: %.2f (total=%d, ack=%d, loss=%d, rtt=%d)", + rate, ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds()) + } +} + +func (b *BrutalSender) InSlowStart() bool { + return false +} + +func (b *BrutalSender) InRecovery() bool { + return false +} + +func (b *BrutalSender) MaybeExitSlowStart() {} + +func (b *BrutalSender) OnRetransmissionTimeout(packetsRetransmitted bool) {} + +func (b *BrutalSender) canPrintAckRate(currentTimestamp int64) bool { + return b.debug && currentTimestamp-b.lastAckPrintTimestamp >= debugPrintInterval +} + +func (b *BrutalSender) debugPrint(format string, a ...any) { + fmt.Printf("[BrutalSender] [%s] %s\n", + time.Now().Format("15:04:05"), + fmt.Sprintf(format, a...)) +} diff --git a/transport/internet/hysteria/congestion/common/pacer.go b/transport/internet/hysteria/congestion/common/pacer.go new file mode 100644 index 000000000000..779c2f1d1c45 --- /dev/null +++ b/transport/internet/hysteria/congestion/common/pacer.go @@ -0,0 +1,79 @@ +package common + +import ( + "time" + + "github.com/apernet/quic-go/congestion" +) + +const ( + maxBurstPackets = 10 + maxBurstPacingDelayMultiplier = 4 +) + +// Pacer implements a token bucket pacing algorithm. +type Pacer struct { + budgetAtLastSent congestion.ByteCount + maxDatagramSize congestion.ByteCount + lastSentTime congestion.Time + getBandwidth func() congestion.ByteCount // in bytes/s +} + +func NewPacer(getBandwidth func() congestion.ByteCount) *Pacer { + p := &Pacer{ + budgetAtLastSent: maxBurstPackets * congestion.InitialPacketSize, + maxDatagramSize: congestion.InitialPacketSize, + getBandwidth: getBandwidth, + } + return p +} + +func (p *Pacer) SentPacket(sendTime congestion.Time, size congestion.ByteCount) { + budget := p.Budget(sendTime) + if size > budget { + p.budgetAtLastSent = 0 + } else { + p.budgetAtLastSent = budget - size + } + p.lastSentTime = sendTime +} + +func (p *Pacer) Budget(now congestion.Time) congestion.ByteCount { + if p.lastSentTime.IsZero() { + return p.maxBurstSize() + } + budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9 + if budget < 0 { // protect against overflows + budget = congestion.ByteCount(1<<62 - 1) + } + return min(p.maxBurstSize(), budget) +} + +func (p *Pacer) maxBurstSize() congestion.ByteCount { + return max( + congestion.ByteCount((maxBurstPacingDelayMultiplier*congestion.MinPacingDelay).Nanoseconds())*p.getBandwidth()/1e9, + maxBurstPackets*p.maxDatagramSize, + ) +} + +// TimeUntilSend returns when the next packet should be sent. +// It returns the zero value if a packet can be sent immediately. +func (p *Pacer) TimeUntilSend() congestion.Time { + if p.budgetAtLastSent >= p.maxDatagramSize { + return 0 + } + diff := 1e9 * uint64(p.maxDatagramSize-p.budgetAtLastSent) + bw := uint64(p.getBandwidth()) + // We might need to round up this value. + // Otherwise, we might have a budget (slightly) smaller than the datagram size when the timer expires. + d := diff / bw + // this is effectively a math.Ceil, but using only integer math + if diff%bw > 0 { + d++ + } + return p.lastSentTime.Add(max(congestion.MinPacingDelay, time.Duration(d)*time.Nanosecond)) +} + +func (p *Pacer) SetMaxDatagramSize(s congestion.ByteCount) { + p.maxDatagramSize = s +} diff --git a/transport/internet/hysteria/congestion/utils.go b/transport/internet/hysteria/congestion/utils.go new file mode 100644 index 000000000000..1036760eef5e --- /dev/null +++ b/transport/internet/hysteria/congestion/utils.go @@ -0,0 +1,18 @@ +package congestion + +import ( + "github.com/apernet/quic-go" + "github.com/xtls/xray-core/transport/internet/hysteria/congestion/bbr" + "github.com/xtls/xray-core/transport/internet/hysteria/congestion/brutal" +) + +func UseBBR(conn *quic.Conn) { + conn.SetCongestionControl(bbr.NewBbrSender( + bbr.DefaultClock{}, + bbr.GetInitialPacketSize(conn.RemoteAddr()), + )) +} + +func UseBrutal(conn *quic.Conn, tx uint64) { + conn.SetCongestionControl(brutal.NewBrutalSender(tx)) +} diff --git a/transport/internet/hysteria/conn.go b/transport/internet/hysteria/conn.go new file mode 100644 index 000000000000..fe1259bb8f88 --- /dev/null +++ b/transport/internet/hysteria/conn.go @@ -0,0 +1,101 @@ +package hysteria + +import ( + "encoding/binary" + "io" + "time" + + "github.com/apernet/quic-go" + "github.com/xtls/xray-core/common/net" +) + +type interConn struct { + stream *quic.Stream + local net.Addr + remote net.Addr +} + +func (i *interConn) Read(b []byte) (int, error) { + return i.stream.Read(b) +} + +func (i *interConn) Write(b []byte) (int, error) { + return i.stream.Write(b) +} + +func (i *interConn) Close() error { + return i.stream.Close() +} + +func (i *interConn) LocalAddr() net.Addr { + return i.local +} + +func (i *interConn) RemoteAddr() net.Addr { + return i.remote +} + +func (i *interConn) SetDeadline(t time.Time) error { + return i.stream.SetDeadline(t) +} + +func (i *interConn) SetReadDeadline(t time.Time) error { + return i.stream.SetReadDeadline(t) +} + +func (i *interConn) SetWriteDeadline(t time.Time) error { + return i.stream.SetWriteDeadline(t) +} + +type InterUdpConn struct { + conn *quic.Conn + local net.Addr + remote net.Addr + + id uint32 + ch chan []byte + closed bool + closeFunc func() +} + +func (i *InterUdpConn) Read(p []byte) (int, error) { + b, ok := <-i.ch + if !ok { + return 0, io.EOF + } + n := copy(p, b) + return n, nil +} + +func (i *InterUdpConn) Write(p []byte) (int, error) { + binary.BigEndian.PutUint32(p, i.id) + if err := i.conn.SendDatagram(p); err != nil { + return 0, err + } + return len(p), nil +} + +func (i *InterUdpConn) Close() error { + i.closeFunc() + return nil +} + +func (i *InterUdpConn) LocalAddr() net.Addr { + return i.local +} + +func (i *InterUdpConn) RemoteAddr() net.Addr { + return i.remote +} + +func (i *InterUdpConn) SetDeadline(t time.Time) error { + return nil +} + +func (i *InterUdpConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (i *InterUdpConn) SetWriteDeadline(t time.Time) error { + return nil +} diff --git a/transport/internet/hysteria/dialer.go b/transport/internet/hysteria/dialer.go new file mode 100644 index 000000000000..5a694aa4236b --- /dev/null +++ b/transport/internet/hysteria/dialer.go @@ -0,0 +1,410 @@ +package hysteria + +import ( + "context" + go_tls "crypto/tls" + "encoding/binary" + "math/rand" + "net/http" + "net/url" + "strconv" + "sync" + "time" + + "github.com/apernet/quic-go" + "github.com/apernet/quic-go/http3" + "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/session" + "github.com/xtls/xray-core/common/task" + "github.com/xtls/xray-core/transport/internet" + "github.com/xtls/xray-core/transport/internet/finalmask" + "github.com/xtls/xray-core/transport/internet/hysteria/congestion" + "github.com/xtls/xray-core/transport/internet/hysteria/udphop" + "github.com/xtls/xray-core/transport/internet/stat" + "github.com/xtls/xray-core/transport/internet/tls" +) + +type udpSessionManager struct { + conn *quic.Conn + m map[uint32]*InterUdpConn + nextId uint32 + closed bool + mutex sync.RWMutex +} + +func (m *udpSessionManager) run() { + for { + d, err := m.conn.ReceiveDatagram(context.Background()) + if err != nil { + break + } + + if len(d) < 4 { + continue + } + sessionId := binary.BigEndian.Uint32(d[:4]) + + m.feed(sessionId, d) + } + + m.mutex.Lock() + defer m.mutex.Unlock() + + m.closed = true + for _, udpConn := range m.m { + m.close(udpConn) + } +} + +func (m *udpSessionManager) close(udpConn *InterUdpConn) { + if !udpConn.closed { + udpConn.closed = true + close(udpConn.ch) + delete(m.m, udpConn.id) + } +} + +func (m *udpSessionManager) udp() (*InterUdpConn, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.closed { + return nil, errors.New("closed") + } + + udpConn := &InterUdpConn{ + conn: m.conn, + local: m.conn.LocalAddr(), + remote: m.conn.RemoteAddr(), + + id: m.nextId, + ch: make(chan []byte, udpMessageChanSize), + } + udpConn.closeFunc = func() { + m.mutex.Lock() + defer m.mutex.Unlock() + m.close(udpConn) + } + m.m[m.nextId] = udpConn + m.nextId++ + + return udpConn, nil +} + +func (m *udpSessionManager) feed(sessionId uint32, d []byte) { + m.mutex.RLock() + defer m.mutex.RUnlock() + + udpConn, ok := m.m[sessionId] + if !ok { + return + } + + select { + case udpConn.ch <- d: + default: + } +} + +type client struct { + ctx context.Context + dest net.Destination + pktConn net.PacketConn + conn *quic.Conn + config *Config + tlsConfig *go_tls.Config + udpmaskManager *finalmask.UdpmaskManager + socketConfig *internet.SocketConfig + udpSM *udpSessionManager + mutex sync.Mutex +} + +func (c *client) status() Status { + if c.conn == nil { + return StatusUnknown + } + select { + case <-c.conn.Context().Done(): + return StatusInactive + default: + return StatusActive + } +} + +func (c *client) close() { + _ = c.conn.CloseWithError(closeErrCodeOK, "") + _ = c.pktConn.Close() + c.pktConn = nil + c.conn = nil + c.udpSM = nil +} + +func (c *client) dial() error { + status := c.status() + if status == StatusActive { + return nil + } + if status == StatusInactive { + c.close() + } + + var index int + if len(c.config.Ports) > 0 { + index = rand.Intn(len(c.config.Ports)) + c.dest.Port = net.Port(c.config.Ports[index]) + } + + raw, err := internet.DialSystem(c.ctx, c.dest, c.socketConfig) + if err != nil { + return errors.New("failed to dial to dest").Base(err) + } + + remote := raw.RemoteAddr() + + pktConn, ok := raw.(net.PacketConn) + if !ok { + raw.Close() + return errors.New("raw is not PacketConn") + } + + if len(c.config.Ports) > 0 { + addr := &udphop.UDPHopAddr{ + IP: remote.(*net.UDPAddr).IP, + Ports: c.config.Ports, + } + pktConn, err = udphop.NewUDPHopPacketConn(addr, time.Duration(c.config.Interval)*time.Second, c.udphopDialer, pktConn, index) + if err != nil { + return errors.New("udphop err").Base(err) + } + } + + if c.udpmaskManager != nil { + pktConn, err = c.udpmaskManager.WrapPacketConnClient(pktConn) + if err != nil { + return errors.New("mask err").Base(err) + } + } + + var quicConn *quic.Conn + rt := &http3.Transport{ + TLSClientConfig: c.tlsConfig, + QUICConfig: &quic.Config{ + InitialStreamReceiveWindow: c.config.InitStreamReceiveWindow, + MaxStreamReceiveWindow: c.config.MaxStreamReceiveWindow, + InitialConnectionReceiveWindow: c.config.InitConnReceiveWindow, + MaxConnectionReceiveWindow: c.config.MaxConnReceiveWindow, + MaxIdleTimeout: time.Duration(c.config.MaxIdleTimeout) * time.Second, + KeepAlivePeriod: time.Duration(c.config.KeepAlivePeriod) * time.Second, + DisablePathMTUDiscovery: c.config.DisablePathMtuDiscovery, + EnableDatagrams: true, + MaxDatagramFrameSize: MaxDatagramFrameSize, + DisablePathManager: true, + }, + Dial: func(ctx context.Context, _ string, tlsCfg *go_tls.Config, cfg *quic.Config) (*quic.Conn, error) { + qc, err := quic.DialEarly(ctx, pktConn, remote, tlsCfg, cfg) + if err != nil { + return nil, err + } + quicConn = qc + return qc, nil + }, + } + req := &http.Request{ + Method: http.MethodPost, + URL: &url.URL{ + Scheme: "https", + Host: URLHost, + Path: URLPath, + }, + Header: http.Header{ + RequestHeaderAuth: []string{c.config.Auth}, + CommonHeaderCCRX: []string{strconv.FormatUint(c.config.Down, 10)}, + CommonHeaderPadding: []string{authRequestPadding.String()}, + }, + } + resp, err := rt.RoundTrip(req) + if err != nil { + if quicConn != nil { + _ = quicConn.CloseWithError(closeErrCodeProtocolError, "") + } + _ = pktConn.Close() + return errors.New("RoundTrip err").Base(err) + } + if resp.StatusCode != StatusAuthOK { + _ = quicConn.CloseWithError(closeErrCodeProtocolError, "") + _ = pktConn.Close() + return errors.New("auth failed") + } + _ = resp.Body.Close() + + serverUdp, _ := strconv.ParseBool(resp.Header.Get(ResponseHeaderUDPEnabled)) + serverAuto := resp.Header.Get(CommonHeaderCCRX) + serverDown, _ := strconv.ParseUint(serverAuto, 10, 64) + + if serverAuto == "auto" || c.config.Up == 0 || serverDown == 0 { + congestion.UseBBR(quicConn) + } else { + congestion.UseBrutal(quicConn, min(c.config.Up, serverDown)) + } + + c.pktConn = pktConn + c.conn = quicConn + if serverUdp { + c.udpSM = &udpSessionManager{ + conn: quicConn, + m: make(map[uint32]*InterUdpConn), + nextId: 1, + } + go c.udpSM.run() + } + + return nil +} + +func (c *client) clean() { + c.mutex.Lock() + defer c.mutex.Unlock() + + if c.status() == StatusInactive { + c.close() + } +} + +func (c *client) tcp() (stat.Connection, error) { + c.mutex.Lock() + defer c.mutex.Unlock() + + err := c.dial() + if err != nil { + return nil, err + } + + stream, err := c.conn.OpenStream() + if err != nil { + return nil, err + } + + return &interConn{ + stream: stream, + local: c.conn.LocalAddr(), + remote: c.conn.RemoteAddr(), + }, nil +} + +func (c *client) udp() (stat.Connection, error) { + c.mutex.Lock() + defer c.mutex.Unlock() + + err := c.dial() + if err != nil { + return nil, err + } + + if c.udpSM == nil { + return nil, errors.New("server does not support udp") + } + + return c.udpSM.udp() +} + +func (c *client) setCtx(ctx context.Context) { + c.mutex.Lock() + defer c.mutex.Unlock() + + c.ctx = ctx +} + +func (c *client) udphopDialer(addr *net.UDPAddr) (net.PacketConn, error) { + c.mutex.Lock() + defer c.mutex.Unlock() + + if c.status() != StatusActive { + errors.LogDebug(c.ctx, "stop hop on disconnected QUIC waiting to be closed") + return nil, errors.New() + } + + raw, err := internet.DialSystem(c.ctx, net.DestinationFromAddr(addr), c.socketConfig) + if err != nil { + errors.LogDebug(c.ctx, "failed to dial to dest skip hop") + return nil, errors.New() + } + + pktConn, ok := raw.(net.PacketConn) + if !ok { + errors.LogDebug(c.ctx, "raw is not PacketConn skip hop") + raw.Close() + return nil, errors.New() + } + + return pktConn, nil +} + +type clientManager struct { + m map[string]*client + mutex sync.Mutex +} + +func (m *clientManager) clean() { + m.mutex.Lock() + defer m.mutex.Unlock() + + for _, c := range m.m { + c.clean() + } +} + +var manger *clientManager + +func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) { + tlsConfig := tls.ConfigFromStreamSettings(streamSettings) + if tlsConfig == nil { + return nil, errors.New("tls config is nil") + } + + addr := dest.NetAddr() + config := streamSettings.ProtocolSettings.(*Config) + + manger.mutex.Lock() + c, ok := manger.m[addr] + if !ok { + dest.Network = net.Network_UDP + c = &client{ + ctx: ctx, + dest: dest, + config: config, + tlsConfig: tlsConfig.GetTLSConfig(), + udpmaskManager: streamSettings.UdpmaskManager, + socketConfig: streamSettings.SocketSettings, + } + manger.m[addr] = c + } + c.setCtx(ctx) + manger.mutex.Unlock() + + outbounds := session.OutboundsFromContext(ctx) + targetUdp := len(outbounds) > 0 && outbounds[len(outbounds)-1].Target.Network == net.Network_UDP + + if targetUdp { + return c.udp() + } + return c.tcp() +} + +func init() { + manger = &clientManager{ + m: make(map[string]*client), + } + (&task.Periodic{ + Interval: 30 * time.Second, + Execute: func() error { + manger.clean() + return nil + }, + }).Start() +} + +func init() { + common.Must(internet.RegisterTransportDialer(protocolName, Dial)) +} diff --git a/transport/internet/hysteria/padding/padding.go b/transport/internet/hysteria/padding/padding.go new file mode 100644 index 000000000000..b134601ed628 --- /dev/null +++ b/transport/internet/hysteria/padding/padding.go @@ -0,0 +1,24 @@ +package padding + +import ( + "math/rand" +) + +const ( + paddingChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" +) + +// padding specifies a half-open range [Min, Max). +type Padding struct { + Min int + Max int +} + +func (p Padding) String() string { + n := p.Min + rand.Intn(p.Max-p.Min) + bs := make([]byte, n) + for i := range bs { + bs[i] = paddingChars[rand.Intn(len(paddingChars))] + } + return string(bs) +} diff --git a/transport/internet/hysteria/udphop/addr.go b/transport/internet/hysteria/udphop/addr.go new file mode 100644 index 000000000000..70dae2a23f35 --- /dev/null +++ b/transport/internet/hysteria/udphop/addr.go @@ -0,0 +1,65 @@ +package udphop + +import ( + "fmt" + "net" +) + +type InvalidPortError struct { + PortStr string +} + +func (e InvalidPortError) Error() string { + return fmt.Sprintf("%s is not a valid port number or range", e.PortStr) +} + +// UDPHopAddr contains an IP address and a list of ports. +type UDPHopAddr struct { + IP net.IP + Ports []uint32 + PortStr string +} + +func (a *UDPHopAddr) Network() string { + return "udphop" +} + +func (a *UDPHopAddr) String() string { + return net.JoinHostPort(a.IP.String(), a.PortStr) +} + +// addrs returns a list of net.Addr's, one for each port. +func (a *UDPHopAddr) addrs() ([]net.Addr, error) { + var addrs []net.Addr + for _, port := range a.Ports { + addr := &net.UDPAddr{ + IP: a.IP, + Port: int(port), + } + addrs = append(addrs, addr) + } + return addrs, nil +} + +// func ResolveUDPHopAddr(addr string) (*UDPHopAddr, error) { +// host, portStr, err := net.SplitHostPort(addr) +// if err != nil { +// return nil, err +// } +// ip, err := net.ResolveIPAddr("ip", host) +// if err != nil { +// return nil, err +// } +// result := &UDPHopAddr{ +// IP: ip.IP, +// PortStr: portStr, +// } + +// pu := utils.ParsePortUnion(portStr) +// if pu == nil { +// return nil, InvalidPortError{portStr} +// } +// result.Ports = pu.Ports() + +// return result, nil +// } diff --git a/transport/internet/hysteria/udphop/conn.go b/transport/internet/hysteria/udphop/conn.go new file mode 100644 index 000000000000..5615ec5b8942 --- /dev/null +++ b/transport/internet/hysteria/udphop/conn.go @@ -0,0 +1,297 @@ +package udphop + +import ( + "errors" + "math/rand" + "net" + "sync" + "syscall" + "time" +) + +const ( + packetQueueSize = 1024 + udpBufferSize = 2048 // QUIC packets are at most 1500 bytes long, so 2k should be more than enough + + defaultHopInterval = 30 * time.Second +) + +type udpHopPacketConn struct { + Addr net.Addr + Addrs []net.Addr + HopInterval time.Duration + ListenUDPFunc ListenUDPFunc + + connMutex sync.RWMutex + prevConn net.PacketConn + currentConn net.PacketConn + addrIndex int + + readBufferSize int + writeBufferSize int + + recvQueue chan *udpPacket + closeChan chan struct{} + closed bool + + bufPool sync.Pool +} + +type udpPacket struct { + Buf []byte + N int + Addr net.Addr + Err error +} + +type ListenUDPFunc = func(*net.UDPAddr) (net.PacketConn, error) + +func NewUDPHopPacketConn(addr *UDPHopAddr, hopInterval time.Duration, listenUDPFunc ListenUDPFunc, pktConn net.PacketConn, index int) (net.PacketConn, error) { + if hopInterval == 0 { + hopInterval = defaultHopInterval + } else if hopInterval < 5*time.Second { + return nil, errors.New("hop interval must be at least 5 seconds") + } + // if listenUDPFunc == nil { + // listenUDPFunc = func() (net.PacketConn, error) { + // return net.ListenUDP("udp", nil) + // } + // } + if listenUDPFunc == nil { + return nil, errors.New("nil listenUDPFunc") + } + addrs, err := addr.addrs() + if err != nil { + return nil, err + } + // curConn, err := listenUDPFunc() + // if err != nil { + // return nil, err + // } + hConn := &udpHopPacketConn{ + Addr: addr, + Addrs: addrs, + HopInterval: hopInterval, + ListenUDPFunc: listenUDPFunc, + prevConn: nil, + currentConn: pktConn, + addrIndex: index, + recvQueue: make(chan *udpPacket, packetQueueSize), + closeChan: make(chan struct{}), + bufPool: sync.Pool{ + New: func() interface{} { + return make([]byte, udpBufferSize) + }, + }, + } + go hConn.recvLoop(pktConn) + go hConn.hopLoop() + return hConn, nil +} + +func (u *udpHopPacketConn) recvLoop(conn net.PacketConn) { + for { + buf := u.bufPool.Get().([]byte) + n, addr, err := conn.ReadFrom(buf) + if err != nil { + u.bufPool.Put(buf) + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + // Only pass through timeout errors here, not permanent errors + // like connection closed. Connection close is normal as we close + // the old connection to exit this loop every time we hop. + u.recvQueue <- &udpPacket{nil, 0, nil, netErr} + } + return + } + select { + case u.recvQueue <- &udpPacket{buf, n, addr, nil}: + // Packet successfully queued + default: + // Queue is full, drop the packet + u.bufPool.Put(buf) + } + } +} + +func (u *udpHopPacketConn) hopLoop() { + ticker := time.NewTicker(u.HopInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + u.hop() + case <-u.closeChan: + return + } + } +} + +func (u *udpHopPacketConn) hop() { + u.connMutex.Lock() + defer u.connMutex.Unlock() + if u.closed { + return + } + // Update addrIndex to a new random value + u.addrIndex = rand.Intn(len(u.Addrs)) + newConn, err := u.ListenUDPFunc(u.Addrs[u.addrIndex].(*net.UDPAddr)) + if err != nil { + // Could be temporary, just skip this hop + return + } + // We need to keep receiving packets from the previous connection, + // because otherwise there will be packet loss due to the time gap + // between we hop to a new port and the server acknowledges this change. + // So we do the following: + // Close prevConn, + // move currentConn to prevConn, + // set newConn as currentConn, + // start recvLoop on newConn. + if u.prevConn != nil { + _ = u.prevConn.Close() // recvLoop for this conn will exit + } + u.prevConn = u.currentConn + u.currentConn = newConn + // Set buffer sizes if previously set + if u.readBufferSize > 0 { + _ = trySetReadBuffer(u.currentConn, u.readBufferSize) + } + if u.writeBufferSize > 0 { + _ = trySetWriteBuffer(u.currentConn, u.writeBufferSize) + } + go u.recvLoop(newConn) +} + +func (u *udpHopPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + for { + select { + case p := <-u.recvQueue: + if p.Err != nil { + return 0, nil, p.Err + } + // Currently we do not check whether the packet is from + // the server or not due to performance reasons. + n := copy(b, p.Buf[:p.N]) + u.bufPool.Put(p.Buf) + return n, u.Addr, nil + case <-u.closeChan: + return 0, nil, net.ErrClosed + } + } +} + +func (u *udpHopPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + u.connMutex.RLock() + defer u.connMutex.RUnlock() + if u.closed { + return 0, net.ErrClosed + } + // Skip the check for now, always write to the server, + // for the same reason as in ReadFrom. + return u.currentConn.WriteTo(b, u.Addrs[u.addrIndex]) +} + +func (u *udpHopPacketConn) Close() error { + u.connMutex.Lock() + defer u.connMutex.Unlock() + if u.closed { + return nil + } + // Close prevConn and currentConn + // Close closeChan to unblock ReadFrom & hopLoop + // Set closed flag to true to prevent double close + if u.prevConn != nil { + _ = u.prevConn.Close() + } + err := u.currentConn.Close() + close(u.closeChan) + u.closed = true + u.Addrs = nil // For GC + return err +} + +func (u *udpHopPacketConn) LocalAddr() net.Addr { + u.connMutex.RLock() + defer u.connMutex.RUnlock() + return u.currentConn.LocalAddr() +} + +func (u *udpHopPacketConn) SetDeadline(t time.Time) error { + u.connMutex.RLock() + defer u.connMutex.RUnlock() + if u.prevConn != nil { + _ = u.prevConn.SetDeadline(t) + } + return u.currentConn.SetDeadline(t) +} + +func (u *udpHopPacketConn) SetReadDeadline(t time.Time) error { + u.connMutex.RLock() + defer u.connMutex.RUnlock() + if u.prevConn != nil { + _ = u.prevConn.SetReadDeadline(t) + } + return u.currentConn.SetReadDeadline(t) +} + +func (u *udpHopPacketConn) SetWriteDeadline(t time.Time) error { + u.connMutex.RLock() + defer u.connMutex.RUnlock() + if u.prevConn != nil { + _ = u.prevConn.SetWriteDeadline(t) + } + return u.currentConn.SetWriteDeadline(t) +} + +// UDP-specific methods below + +func (u *udpHopPacketConn) SetReadBuffer(bytes int) error { + u.connMutex.Lock() + defer u.connMutex.Unlock() + u.readBufferSize = bytes + if u.prevConn != nil { + _ = trySetReadBuffer(u.prevConn, bytes) + } + return trySetReadBuffer(u.currentConn, bytes) +} + +func (u *udpHopPacketConn) SetWriteBuffer(bytes int) error { + u.connMutex.Lock() + defer u.connMutex.Unlock() + u.writeBufferSize = bytes + if u.prevConn != nil { + _ = trySetWriteBuffer(u.prevConn, bytes) + } + return trySetWriteBuffer(u.currentConn, bytes) +} + +func (u *udpHopPacketConn) SyscallConn() (syscall.RawConn, error) { + u.connMutex.RLock() + defer u.connMutex.RUnlock() + sc, ok := u.currentConn.(syscall.Conn) + if !ok { + return nil, errors.New("not supported") + } + return sc.SyscallConn() +} + +func trySetReadBuffer(pc net.PacketConn, bytes int) error { + sc, ok := pc.(interface { + SetReadBuffer(bytes int) error + }) + if ok { + return sc.SetReadBuffer(bytes) + } + return nil +} + +func trySetWriteBuffer(pc net.PacketConn, bytes int) error { + sc, ok := pc.(interface { + SetWriteBuffer(bytes int) error + }) + if ok { + return sc.SetWriteBuffer(bytes) + } + return nil +} diff --git a/transport/internet/hysteria/utils/portunion.go b/transport/internet/hysteria/utils/portunion.go new file mode 100644 index 000000000000..f76a6fd0a69d --- /dev/null +++ b/transport/internet/hysteria/utils/portunion.go @@ -0,0 +1,107 @@ +package utils + +import ( + "sort" + "strconv" + "strings" +) + +// PortUnion is a collection of multiple port ranges. +type PortUnion []PortRange + +// PortRange represents a range of ports. +// Start and End are inclusive. [Start, End] +type PortRange struct { + Start, End uint16 +} + +// ParsePortUnion parses a string of comma-separated port ranges (or single ports) into a PortUnion. +// Returns nil if the input is invalid. +// The returned PortUnion is guaranteed to be normalized. +func ParsePortUnion(s string) PortUnion { + if s == "all" || s == "*" { + // Wildcard special case + return PortUnion{PortRange{0, 65535}} + } + var result PortUnion + portStrs := strings.Split(s, ",") + for _, portStr := range portStrs { + if strings.Contains(portStr, "-") { + // Port range + portRange := strings.Split(portStr, "-") + if len(portRange) != 2 { + return nil + } + start, err := strconv.ParseUint(portRange[0], 10, 16) + if err != nil { + return nil + } + end, err := strconv.ParseUint(portRange[1], 10, 16) + if err != nil { + return nil + } + if start > end { + start, end = end, start + } + result = append(result, PortRange{uint16(start), uint16(end)}) + } else { + // Single port + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return nil + } + result = append(result, PortRange{uint16(port), uint16(port)}) + } + } + if result == nil { + return nil + } + return result.Normalize() +} + +// Normalize normalizes a PortUnion. +// No overlapping ranges, ranges are sorted from low to high. +func (u PortUnion) Normalize() PortUnion { + if len(u) == 0 { + return u + } + sort.Slice(u, func(i, j int) bool { + if u[i].Start == u[j].Start { + return u[i].End < u[j].End + } + return u[i].Start < u[j].Start + }) + normalized := PortUnion{u[0]} + for _, current := range u[1:] { + last := &normalized[len(normalized)-1] + if uint32(current.Start) <= uint32(last.End)+1 { + if current.End > last.End { + last.End = current.End + } + } else { + normalized = append(normalized, current) + } + } + return normalized +} + +// Ports returns all ports in the PortUnion as a slice. +func (u PortUnion) Ports() []uint16 { + var ports []uint16 + for _, r := range u { + for i := uint32(r.Start); i <= uint32(r.End); i++ { + ports = append(ports, uint16(i)) + } + } + return ports +} + +// Contains returns true if the PortUnion contains the given port. +func (u PortUnion) Contains(port uint16) bool { + for _, r := range u { + if port >= r.Start && port <= r.End { + return true + } + } + return false +} diff --git a/transport/internet/hysteria/utils/portunion_test.go b/transport/internet/hysteria/utils/portunion_test.go new file mode 100644 index 000000000000..ba056a374166 --- /dev/null +++ b/transport/internet/hysteria/utils/portunion_test.go @@ -0,0 +1,150 @@ +package utils + +import ( + "reflect" + "slices" + "testing" +) + +func TestParsePortUnion(t *testing.T) { + tests := []struct { + name string + s string + want PortUnion + }{ + { + name: "empty", + s: "", + want: nil, + }, + { + name: "all 1", + s: "all", + want: PortUnion{{0, 65535}}, + }, + { + name: "all 2", + s: "*", + want: PortUnion{{0, 65535}}, + }, + { + name: "single port", + s: "1234", + want: PortUnion{{1234, 1234}}, + }, + { + name: "multiple ports (unsorted)", + s: "5678,1234,9012", + want: PortUnion{{1234, 1234}, {5678, 5678}, {9012, 9012}}, + }, + { + name: "one range", + s: "1234-1240", + want: PortUnion{{1234, 1240}}, + }, + { + name: "one range (reversed)", + s: "1240-1234", + want: PortUnion{{1234, 1240}}, + }, + { + name: "multiple ports and ranges (reversed, unsorted, overlapping)", + s: "5678,1200-1236,9100-9012,1234-1240", + want: PortUnion{{1200, 1240}, {5678, 5678}, {9012, 9100}}, + }, + { + name: "multiple ports and ranges with 65535 (reversed, unsorted, overlapping)", + s: "5678,1200-1236,65531-65535,65532-65534,9100-9012,1234-1240", + want: PortUnion{{1200, 1240}, {5678, 5678}, {9012, 9100}, {65531, 65535}}, + }, + { + name: "multiple ports and ranges with 65535 (reversed, unsorted, overlapping) 2", + s: "5678,1200-1236,65532-65535,65531-65534,9100-9012,1234-1240", + want: PortUnion{{1200, 1240}, {5678, 5678}, {9012, 9100}, {65531, 65535}}, + }, + { + name: "invalid 1", + s: "1234-", + want: nil, + }, + { + name: "invalid 2", + s: "1234-ggez", + want: nil, + }, + { + name: "invalid 3", + s: "233,", + want: nil, + }, + { + name: "invalid 4", + s: "1234-1240-1250", + want: nil, + }, + { + name: "invalid 5", + s: "-,,", + want: nil, + }, + { + name: "invalid 6", + s: "http", + want: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ParsePortUnion(tt.s); !reflect.DeepEqual(got, tt.want) { + t.Errorf("ParsePortUnion() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestPortUnion_Ports(t *testing.T) { + tests := []struct { + name string + pu PortUnion + want []uint16 + }{ + { + name: "single port", + pu: PortUnion{{1234, 1234}}, + want: []uint16{1234}, + }, + { + name: "multiple ports", + pu: PortUnion{{1234, 1236}}, + want: []uint16{1234, 1235, 1236}, + }, + { + name: "multiple ports and ranges", + pu: PortUnion{{1234, 1236}, {5678, 5680}, {9000, 9002}}, + want: []uint16{1234, 1235, 1236, 5678, 5679, 5680, 9000, 9001, 9002}, + }, + { + name: "single port 65535", + pu: PortUnion{{65535, 65535}}, + want: []uint16{65535}, + }, + { + name: "port range with 65535", + pu: PortUnion{{65530, 65535}}, + want: []uint16{65530, 65531, 65532, 65533, 65534, 65535}, + }, + { + name: "multiple ports and ranges with 65535", + pu: PortUnion{{65530, 65535}, {1234, 1236}}, + want: []uint16{65530, 65531, 65532, 65533, 65534, 65535, 1234, 1235, 1236}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.pu.Ports(); !slices.Equal(got, tt.want) { + t.Errorf("PortUnion.Ports() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/transport/internet/memory_settings.go b/transport/internet/memory_settings.go index f133135376a8..db2b0d1f05de 100644 --- a/transport/internet/memory_settings.go +++ b/transport/internet/memory_settings.go @@ -1,6 +1,9 @@ package internet -import "github.com/xtls/xray-core/common/net" +import ( + "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/transport/internet/finalmask" +) // MemoryStreamConfig is a parsed form of StreamConfig. It is used to reduce the number of Protobuf parses. type MemoryStreamConfig struct { @@ -9,6 +12,8 @@ type MemoryStreamConfig struct { ProtocolSettings interface{} SecurityType string SecuritySettings interface{} + TcpmaskManager *finalmask.TcpmaskManager + UdpmaskManager *finalmask.UdpmaskManager SocketSettings *SocketConfig DownloadSettings *MemoryStreamConfig } @@ -45,5 +50,29 @@ func ToMemoryStreamConfig(s *StreamConfig) (*MemoryStreamConfig, error) { mss.SecuritySettings = ess } + if s != nil && len(s.Tcpmasks) > 0 { + var masks []finalmask.Tcpmask + for _, msg := range s.Tcpmasks { + instance, err := msg.GetInstance() + if err != nil { + return nil, err + } + masks = append(masks, instance.(finalmask.Tcpmask)) + } + mss.TcpmaskManager = finalmask.NewTcpmaskManager(masks) + } + + if s != nil && len(s.Udpmasks) > 0 { + var masks []finalmask.Udpmask + for _, msg := range s.Udpmasks { + instance, err := msg.GetInstance() + if err != nil { + return nil, err + } + masks = append(masks, instance.(finalmask.Udpmask)) + } + mss.UdpmaskManager = finalmask.NewUdpmaskManager(masks) + } + return mss, nil } diff --git a/transport/internet/splithttp/dialer.go b/transport/internet/splithttp/dialer.go index 3f6e9ea35a8f..45b770e7c6bd 100644 --- a/transport/internet/splithttp/dialer.go +++ b/transport/internet/splithttp/dialer.go @@ -13,8 +13,8 @@ import ( "sync/atomic" "time" - "github.com/quic-go/quic-go" - "github.com/quic-go/quic-go/http3" + "github.com/apernet/quic-go" + "github.com/apernet/quic-go/http3" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" diff --git a/transport/internet/splithttp/hub.go b/transport/internet/splithttp/hub.go index 5e9b9408f20b..442f196f7c4e 100644 --- a/transport/internet/splithttp/hub.go +++ b/transport/internet/splithttp/hub.go @@ -12,8 +12,8 @@ import ( "sync" "time" - "github.com/quic-go/quic-go" - "github.com/quic-go/quic-go/http3" + "github.com/apernet/quic-go" + "github.com/apernet/quic-go/http3" goreality "github.com/xtls/reality" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors"