diff --git a/common/xudp/xudp.go b/common/xudp/xudp.go index 1dfff16af004..e5da49f459b9 100644 --- a/common/xudp/xudp.go +++ b/common/xudp/xudp.go @@ -52,7 +52,7 @@ func GetGlobalID(ctx context.Context) (globalID [8]byte) { return } if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.Network == net.Network_UDP && - (inbound.Name == "dokodemo-door" || inbound.Name == "socks" || inbound.Name == "shadowsocks") { + (inbound.Name == "dokodemo-door" || inbound.Name == "socks" || inbound.Name == "shadowsocks" || inbound.Name == "tun") { h := blake3.New(8, BaseKey) h.Write([]byte(inbound.Source.String())) copy(globalID[:], h.Sum(nil)) diff --git a/go.mod b/go.mod index f0797bc0bfbb..a39162cca6dd 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( golang.org/x/net v0.48.0 golang.org/x/sync v0.19.0 golang.org/x/sys v0.39.0 + golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 google.golang.org/grpc v1.78.0 google.golang.org/protobuf v1.36.11 @@ -50,7 +51,6 @@ require ( golang.org/x/text v0.32.0 // indirect golang.org/x/time v0.12.0 // indirect golang.org/x/tools v0.39.0 // indirect - golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/proxy/tun/stack_gvisor.go b/proxy/tun/stack_gvisor.go index 81f307841d19..d062c3d0dedd 100644 --- a/proxy/tun/stack_gvisor.go +++ b/proxy/tun/stack_gvisor.go @@ -6,6 +6,7 @@ import ( "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" + "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -100,32 +101,35 @@ func (t *stackGVisor) Start() error { }) ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) - udpForwarder := udp.NewForwarder(ipStack, func(r *udp.ForwarderRequest) { - go func(r *udp.ForwarderRequest) { - var wq waiter.Queue - var id = r.ID() - - ep, err := r.CreateEndpoint(&wq) - if err != nil { - errors.LogError(t.ctx, err.String()) - return - } - - options := ep.SocketOptions() - options.SetReuseAddress(true) - options.SetReusePort(true) - - t.handler.HandleConnection( - gonet.NewUDPConn(&wq, ep), - // local address on the gVisor side is connection destination - net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort)), - ) - - // close the socket - ep.Close() - }(r) + // Use custom UDP packet handler, instead of strict gVisor forwarder, for FullCone NAT support + udpForwarder := newUdpConnectionHandler(t.ctx, t.handler, func(p []byte) { + // extract network protocol from the packet + var networkProtocol tcpip.NetworkProtocolNumber + switch header.IPVersion(p) { + case header.IPv4Version: + networkProtocol = header.IPv4ProtocolNumber + case header.IPv6Version: + networkProtocol = header.IPv6ProtocolNumber + default: + // discard packet with unknown network version + return + } + + ipStack.WriteRawPacket(defaultNIC, networkProtocol, buffer.MakeWithData(p)) + }) + ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + data := pkt.Data().AsRange().ToSlice() + if len(data) == 0 { + return false + } + // source/destination of the packet we process as incoming, on gVisor side are Remote/Local + // in other terms, src is the side behind tun, dst is the side behind gVisor + // this function handle packets passing from the tun to the gVisor, therefore the src/dst assignement + src := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort)) + dst := net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort)) + + return udpForwarder.HandlePacket(src, dst, data) }) - ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) t.stack = ipStack t.endpoint = linkEndpoint diff --git a/proxy/tun/udp_fullcone.go b/proxy/tun/udp_fullcone.go new file mode 100644 index 000000000000..bc50a7f3a305 --- /dev/null +++ b/proxy/tun/udp_fullcone.go @@ -0,0 +1,216 @@ +package tun + +import ( + "context" + "sync" + "sync/atomic" + "time" + + "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/buf" + c "github.com/xtls/xray-core/common/ctx" + "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/protocol" + "github.com/xtls/xray-core/common/session" + "github.com/xtls/xray-core/common/signal/done" + "github.com/xtls/xray-core/common/task" + "github.com/xtls/xray-core/transport" + "github.com/xtls/xray-core/transport/pipe" + "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/checksum" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// udp connection abstraction +type udpConn struct { + lastActive atomic.Int64 + reader buf.Reader + writer buf.Writer + done *done.Instance + cancel context.CancelFunc +} + +// sub-handler specifically for udp connections under main handler +type udpConnectionHandler struct { + sync.Mutex + ctx context.Context + handler *Handler + udpConns map[net.Destination]*udpConn + udpChecker *task.Periodic + writePacket func(p []byte) +} + +func newUdpConnectionHandler(ctx context.Context, h *Handler, writePacket func(p []byte)) *udpConnectionHandler { + handler := &udpConnectionHandler{ + ctx: ctx, + handler: h, + udpConns: make(map[net.Destination]*udpConn), + writePacket: writePacket, + } + + handler.udpChecker = &task.Periodic{Interval: time.Minute, Execute: handler.cleanupUDP} + handler.udpChecker.Start() + + return handler +} + +func (u *udpConnectionHandler) cleanupUDP() error { + u.Lock() + defer u.Unlock() + if len(u.udpConns) == 0 { + return errors.New("no connections") + } + now := time.Now().Unix() + for src, conn := range u.udpConns { + if now-conn.lastActive.Load() > 300 { + conn.cancel() + common.Must(conn.done.Close()) + common.Must(common.Close(conn.writer)) + delete(u.udpConns, src) + } + } + return nil +} + +// HandlePacket handles UDP packets coming from tun, to forward to the dispatcher +// this custom handler support FullCone NAT of returning packets, binding connection only by the source port +func (u *udpConnectionHandler) HandlePacket(src net.Destination, dst net.Destination, data []byte) bool { + u.Lock() + conn, found := u.udpConns[src] + if !found { + reader, writer := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024)) + conn = &udpConn{reader: reader, writer: writer, done: done.New()} + u.udpConns[src] = conn + u.Unlock() + + go func() { + ctx, cancel := context.WithCancel(u.ctx) + conn.cancel = cancel + defer func() { + cancel() + u.Lock() + delete(u.udpConns, src) + u.Unlock() + common.Must(conn.done.Close()) + common.Must(common.Close(conn.writer)) + }() + + inbound := &session.Inbound{ + Name: "tun", + Source: src, + CanSpliceCopy: 1, + User: &protocol.MemoryUser{Level: u.handler.config.UserLevel}, + } + ctx = session.ContextWithInbound(c.ContextWithID(ctx, session.NewID()), inbound) + ctx = session.SubContextFromMuxInbound(ctx) + link := &transport.Link{ + Reader: &buf.TimeoutWrapperReader{Reader: conn.reader}, + // reverse source and destination, indicating the packets to write are going in the other + // direction (written back to tun) and should have reversed addressing + Writer: &udpWriter{handler: u, src: dst, dst: src}, + } + _ = u.handler.dispatcher.DispatchLink(ctx, dst, link) + }() + } else { + conn.lastActive.Store(time.Now().Unix()) + u.Unlock() + } + + b := buf.New() + b.Write(data) + b.UDP = &dst + conn.writer.WriteMultiBuffer(buf.MultiBuffer{b}) + + return true +} + +type udpWriter struct { + handler *udpConnectionHandler + // address in the side of stack, where packet will be coming from + src net.Destination + // address on the side of tun, where packet will be destined to + dst net.Destination +} + +func (w *udpWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { + for _, b := range mb { + // use captured in the dispatched packet source address b.UDP as source, if available, + // otherwise use captured in the writer source w.src + srcAddr := w.src + if b.UDP != nil { + srcAddr = *b.UDP + } + + // validate address family matches + if srcAddr.Address.Family() != w.src.Address.Family() { + errors.LogWarning(context.Background(), "UDP return packet address family mismatch: expected ", w.src.Address.Family(), ", got ", srcAddr.Address.Family()) + b.Release() + continue + } + + payload := b.Bytes() + udpLen := header.UDPMinimumSize + len(payload) + srcIP := tcpip.AddrFromSlice(srcAddr.Address.IP()) + dstIP := tcpip.AddrFromSlice(w.dst.Address.IP()) + + // build packet with appropriate IP header size + isIPv4 := srcAddr.Address.Family().IsIPv4() + ipHdrSize := header.IPv6MinimumSize + if isIPv4 { + ipHdrSize = header.IPv4MinimumSize + } + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: ipHdrSize + header.UDPMinimumSize, + Payload: buffer.MakeWithData(payload), + }) + + // Build UDP header + udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) + udpHdr.Encode(&header.UDPFields{ + SrcPort: uint16(srcAddr.Port), + DstPort: uint16(w.dst.Port), + Length: uint16(udpLen), + }) + + // Calculate and set UDP checksum + xsum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, srcIP, dstIP, uint16(udpLen)) + udpHdr.SetChecksum(^udpHdr.CalculateChecksum(checksum.Checksum(payload, xsum))) + + // Build IP header + if isIPv4 { + ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize)) + ipHdr.Encode(&header.IPv4Fields{ + TotalLength: uint16(header.IPv4MinimumSize + udpLen), + TTL: 64, + Protocol: uint8(header.UDPProtocolNumber), + SrcAddr: srcIP, + DstAddr: dstIP, + }) + ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) + } else { + ipHdr := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) + ipHdr.Encode(&header.IPv6Fields{ + PayloadLength: uint16(udpLen), + TransportProtocol: header.UDPProtocolNumber, + HopLimit: 64, + SrcAddr: srcIP, + DstAddr: dstIP, + }) + } + + // Write raw packet to network stack + views := pkt.AsSlices() + var data []byte + for _, view := range views { + data = append(data, view...) + } + w.handler.writePacket(data) + pkt.DecRef() + b.Release() + } + return nil +}