Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion proxy/vless/encoding/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func EncodeRequestHeader(writer io.Writer, request *protocol.RequestHeader, requ
}

// DecodeRequestHeader decodes and returns (if successful) a RequestHeader from an input stream.
func DecodeRequestHeader(isfb bool, first *buf.Buffer, reader io.Reader, validator *vless.Validator) (*protocol.RequestHeader, *Addons, bool, error) {
func DecodeRequestHeader(isfb bool, first *buf.Buffer, reader io.Reader, validator vless.Validator) (*protocol.RequestHeader, *Addons, bool, error) {
buffer := buf.StackNew()
defer buffer.Release()

Expand Down
6 changes: 3 additions & 3 deletions proxy/vless/encoding/encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func TestRequestSerialization(t *testing.T) {
buffer := buf.StackNew()
common.Must(EncodeRequestHeader(&buffer, expectedRequest, expectedAddons))

Validator := new(vless.Validator)
Validator := new(vless.MemoryValidator)
Validator.Add(user)

actualRequest, actualAddons, _, err := DecodeRequestHeader(false, nil, &buffer, Validator)
Expand Down Expand Up @@ -83,7 +83,7 @@ func TestInvalidRequest(t *testing.T) {
buffer := buf.StackNew()
common.Must(EncodeRequestHeader(&buffer, expectedRequest, expectedAddons))

Validator := new(vless.Validator)
Validator := new(vless.MemoryValidator)
Validator.Add(user)

_, _, _, err := DecodeRequestHeader(false, nil, &buffer, Validator)
Expand Down Expand Up @@ -114,7 +114,7 @@ func TestMuxRequest(t *testing.T) {
buffer := buf.StackNew()
common.Must(EncodeRequestHeader(&buffer, expectedRequest, expectedAddons))

Validator := new(vless.Validator)
Validator := new(vless.MemoryValidator)
Validator.Add(user)

actualRequest, actualAddons, _, err := DecodeRequestHeader(false, nil, &buffer, Validator)
Expand Down
32 changes: 18 additions & 14 deletions proxy/vless/inbound/inbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,38 +45,42 @@ func init() {
}); err != nil {
return nil, err
}
return New(ctx, config.(*Config), dc)

c := config.(*Config)

validator := new(vless.MemoryValidator)
for _, user := range c.Clients {
u, err := user.ToMemoryUser()
if err != nil {
return nil, errors.New("failed to get VLESS user").Base(err).AtError()
}
if err := validator.Add(u); err != nil {
return nil, errors.New("failed to initiate user").Base(err).AtError()
}
}

return New(ctx, c, dc, validator)
}))
}

// Handler is an inbound connection handler that handles messages in VLess protocol.
type Handler struct {
inboundHandlerManager feature_inbound.Manager
policyManager policy.Manager
validator *vless.Validator
validator vless.Validator
dns dns.Client
fallbacks map[string]map[string]map[string]*Fallback // or nil
// regexps map[string]*regexp.Regexp // or nil
}

// New creates a new VLess inbound handler.
func New(ctx context.Context, config *Config, dc dns.Client) (*Handler, error) {
func New(ctx context.Context, config *Config, dc dns.Client, validator vless.Validator) (*Handler, error) {
v := core.MustFromContext(ctx)
handler := &Handler{
inboundHandlerManager: v.GetFeature(feature_inbound.ManagerType()).(feature_inbound.Manager),
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
validator: new(vless.Validator),
dns: dc,
}

for _, user := range config.Clients {
u, err := user.ToMemoryUser()
if err != nil {
return nil, errors.New("failed to get VLESS user").Base(err).AtError()
}
if err := handler.AddUser(ctx, u); err != nil {
return nil, errors.New("failed to initiate user").Base(err).AtError()
}
validator: validator,
}

if config.Fallbacks != nil {
Expand Down
16 changes: 11 additions & 5 deletions proxy/vless/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,21 @@ import (
"github.com/xtls/xray-core/common/uuid"
)

// Validator stores valid VLESS users.
type Validator struct {
type Validator interface {
Get(id uuid.UUID) *protocol.MemoryUser
Add(u *protocol.MemoryUser) error
Del(email string) error
}

// MemoryValidator stores valid VLESS users.
type MemoryValidator struct {
// Considering email's usage here, map + sync.Mutex/RWMutex may have better performance.
email sync.Map
users sync.Map
}

// Add a VLESS user, Email must be empty or unique.
func (v *Validator) Add(u *protocol.MemoryUser) error {
func (v *MemoryValidator) Add(u *protocol.MemoryUser) error {
if u.Email != "" {
_, loaded := v.email.LoadOrStore(strings.ToLower(u.Email), u)
if loaded {
Expand All @@ -29,7 +35,7 @@ func (v *Validator) Add(u *protocol.MemoryUser) error {
}

// Del a VLESS user with a non-empty Email.
func (v *Validator) Del(e string) error {
func (v *MemoryValidator) Del(e string) error {
if e == "" {
return errors.New("Email must not be empty.")
}
Expand All @@ -44,7 +50,7 @@ func (v *Validator) Del(e string) error {
}

// Get a VLESS user with UUID, nil if user doesn't exist.
func (v *Validator) Get(id uuid.UUID) *protocol.MemoryUser {
func (v *MemoryValidator) Get(id uuid.UUID) *protocol.MemoryUser {
u, _ := v.users.Load(id)
if u != nil {
return u.(*protocol.MemoryUser)
Expand Down