Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions router_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,20 @@ import (

// NewContext creates a routeContext pointer.
func NewContext() *routeContext {
return &routeContext{}
return &routeContext{
rawCtx: context.Background(),
}
}

// Context is a generic context in a message routing.
// It allows us to pass variables between handler and middlewares.
type Context interface {
context.Context

// WithContext sets the underline context.
// It's very useful to control the workflow when send to response channel.
WithContext(ctx context.Context) Context

// Session returns the current session.
Session() Session

Expand Down Expand Up @@ -74,6 +80,7 @@ type Context interface {

// routeContext implements the Context interface.
type routeContext struct {
rawCtx context.Context
mu sync.RWMutex
storage map[string]interface{}
session Session
Expand All @@ -82,18 +89,18 @@ type routeContext struct {
}

// Deadline implements the context.Context Deadline method.
func (c *routeContext) Deadline() (deadline time.Time, ok bool) {
return
func (c *routeContext) Deadline() (time.Time, bool) {
return c.rawCtx.Deadline()
}

// Done implements the context.Context Done method.
func (c *routeContext) Done() <-chan struct{} {
return nil
return c.rawCtx.Done()
}

// Err implements the context.Context Err method.
func (c *routeContext) Err() error {
return nil
return c.rawCtx.Err()
}

// Value implements the context.Context Value method.
Expand All @@ -105,6 +112,12 @@ func (c *routeContext) Value(key interface{}) interface{} {
return nil
}

// WithContext sets the underline context.
func (c *routeContext) WithContext(ctx context.Context) Context {
c.rawCtx = ctx
return c
}

// Session implements Context.Session method.
func (c *routeContext) Session() Session {
return c.session
Expand Down Expand Up @@ -234,6 +247,7 @@ func (c *routeContext) Remove(key string) {
// Copy implements Context.Copy method.
func (c *routeContext) Copy() Context {
return &routeContext{
rawCtx: c.rawCtx,
storage: c.storage,
session: c.session,
reqEntry: c.reqEntry,
Expand All @@ -242,6 +256,7 @@ func (c *routeContext) Copy() Context {
}

func (c *routeContext) reset(sess *session, reqEntry *message.Entry) {
c.rawCtx = context.Background()
c.session = sess
c.reqEntry = reqEntry
c.respEntry = nil
Expand Down
5 changes: 4 additions & 1 deletion router_context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ import (
)

func newContext(sess *session, msg *message.Entry) *routeContext {
return &routeContext{session: sess, reqEntry: msg}
ctx := NewContext()
ctx.session = sess
ctx.reqEntry = msg
return ctx
}

func Test_routeContext_Deadline(t *testing.T) {
Expand Down
2 changes: 2 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ func (s *session) ID() string {
// Returns error if session is closed.
func (s *session) Send(ctx Context) (ok bool) {
select {
case <-ctx.Done():
return false
case <-s.closed:
return false
case s.respQueue <- ctx:
Expand Down
40 changes: 30 additions & 10 deletions session_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package easytcp

import (
"context"
"fmt"
"github.com/DarthPestilane/easytcp/internal/mock"
"github.com/DarthPestilane/easytcp/message"
Expand Down Expand Up @@ -160,7 +161,23 @@ func TestTCPSession_Send(t *testing.T) {
}
sess := newSession(nil, &sessionOption{})
sess.Close() // close session
assert.False(t, sess.Send(&routeContext{respEntry: entry}))
c := sess.NewContext().SetRequestMessage(entry)
assert.False(t, sess.Send(c))
})
t.Run("when ctx is done", func(t *testing.T) {
sess := newSession(nil, &sessionOption{})
c := sess.NewContext()
ctx, cancel := context.WithCancel(context.Background())
c.WithContext(ctx)

done := make(chan struct{})
go func() {
assert.False(t, sess.Send(c))
close(done)
}()

cancel()
<-done
})
t.Run("when send succeed", func(t *testing.T) {
entry := &message.Entry{
Expand All @@ -171,7 +188,9 @@ func TestTCPSession_Send(t *testing.T) {
sess := newSession(nil, &sessionOption{})
sess.respQueue = make(chan Context) // no buffer
go func() { <-sess.respQueue }()
assert.True(t, sess.Send(&routeContext{respEntry: entry}))

c := sess.NewContext().SetRequestMessage(entry)
assert.True(t, sess.Send(c))
sess.Close()
})
}
Expand Down Expand Up @@ -202,7 +221,7 @@ func TestTCPSession_writeOutbound(t *testing.T) {
packer.EXPECT().Pack(gomock.Any()).AnyTimes().Return(nil, nil)

sess := newSession(nil, &sessionOption{Packer: packer, respQueueSize: 1024})
sess.respQueue <- &routeContext{respEntry: nil}
sess.respQueue <- sess.NewContext()
doneLoop := make(chan struct{})
go func() {
sess.writeOutbound(0, 10) // should stop looping and return
Expand All @@ -226,7 +245,7 @@ func TestTCPSession_writeOutbound(t *testing.T) {
sess := newSession(nil, &sessionOption{Packer: packer})
done := make(chan struct{})
go func() {
sess.respQueue <- &routeContext{respEntry: entry}
sess.respQueue <- sess.NewContext().SetResponseMessage(entry)
close(done)
}()
time.Sleep(time.Microsecond * 15)
Expand All @@ -247,7 +266,7 @@ func TestTCPSession_writeOutbound(t *testing.T) {
packer.EXPECT().Pack(gomock.Any()).Return(nil, nil)

sess := newSession(nil, &sessionOption{Packer: packer, respQueueSize: 100})
sess.respQueue <- &routeContext{respEntry: entry} // push to queue
sess.respQueue <- sess.NewContext().SetResponseMessage(entry) // push to queue
doneLoop := make(chan struct{})
go func() {
sess.writeOutbound(0, 10)
Expand All @@ -270,7 +289,7 @@ func TestTCPSession_writeOutbound(t *testing.T) {
p1, _ := net.Pipe()
_ = p1.Close()
sess := newSession(p1, &sessionOption{Packer: packer})
go func() { sess.respQueue <- &routeContext{respEntry: entry} }()
go func() { sess.respQueue <- sess.NewContext().SetResponseMessage(entry) }()
go sess.writeOutbound(time.Millisecond*10, 10)
_, ok := <-sess.closed
assert.False(t, ok)
Expand All @@ -288,7 +307,7 @@ func TestTCPSession_writeOutbound(t *testing.T) {

p1, _ := net.Pipe()
sess := newSession(p1, &sessionOption{Packer: packer})
go func() { sess.respQueue <- &routeContext{respEntry: entry} }()
go func() { sess.respQueue <- sess.NewContext().SetResponseMessage(entry) }()
go sess.writeOutbound(time.Millisecond*10, 10)
_, ok := <-sess.closed
assert.False(t, ok)
Expand All @@ -308,7 +327,7 @@ func TestTCPSession_writeOutbound(t *testing.T) {
p1, _ := net.Pipe()
assert.NoError(t, p1.Close())
sess := newSession(p1, &sessionOption{Packer: packer})
go func() { sess.respQueue <- &routeContext{respEntry: entry} }()
go func() { sess.respQueue <- sess.NewContext().SetResponseMessage(entry) }()
sess.writeOutbound(0, 10) // should stop looping and return
_, ok := <-sess.closed
assert.False(t, ok)
Expand Down Expand Up @@ -340,7 +359,7 @@ func TestTCPSession_writeOutbound(t *testing.T) {
})

sess := newSession(conn, &sessionOption{Packer: packer})
go func() { sess.respQueue <- &routeContext{respEntry: entry} }()
go func() { sess.respQueue <- sess.NewContext().SetResponseMessage(entry) }()
sess.writeOutbound(0, 10) // should stop looping and return
_, ok := <-sess.closed
assert.False(t, ok)
Expand All @@ -359,7 +378,8 @@ func TestTCPSession_writeOutbound(t *testing.T) {
p1, p2 := net.Pipe()
sess := newSession(p1, &sessionOption{Packer: packer})
go func() {
_ = sess.Send(&routeContext{respEntry: entry})
c := sess.NewContext().SetResponseMessage(entry)
_ = sess.Send(c)
}()
done := make(chan struct{})
go func() {
Expand Down