diff --git a/router_context.go b/router_context.go index 4464eb0..4abff41 100644 --- a/router_context.go +++ b/router_context.go @@ -33,6 +33,9 @@ type Context interface { // Response returns the response message entry. Response() *message.Entry + // RawResponseData returns the not yet encoded response data. + RawResponseData() interface{} + // SetResponse encodes data with session's codec and sets response message entry. SetResponse(id, data interface{}) error @@ -64,11 +67,12 @@ type Context interface { // routeContext implements the Context interface. type routeContext struct { - mu sync.RWMutex - storage map[string]interface{} - session Session - reqEntry *message.Entry - respEntry *message.Entry + mu sync.RWMutex + storage map[string]interface{} + session Session + reqEntry *message.Entry + respEntry *message.Entry + rawRespData interface{} } // Deadline implements the context.Context Deadline method. @@ -118,6 +122,17 @@ func (c *routeContext) Response() *message.Entry { return c.respEntry } +// RawResponseData returns the not yet encoded response data. +func (c *routeContext) RawResponseData() interface{} { + if c.rawRespData != nil { + return c.rawRespData + } + if c.respEntry != nil { + return c.respEntry.Data + } + return nil +} + // SetResponse implements Context.SetResponse method. func (c *routeContext) SetResponse(id, data interface{}) error { if c.Session().Codec() == nil { @@ -127,6 +142,7 @@ func (c *routeContext) SetResponse(id, data interface{}) error { if err != nil { return err } + c.rawRespData = data c.respEntry = &message.Entry{ ID: id, Data: dataRaw, diff --git a/router_context_test.go b/router_context_test.go index 0d2a131..2418307 100644 --- a/router_context_test.go +++ b/router_context_test.go @@ -149,6 +149,7 @@ func TestRouteContext_SetResponse(t *testing.T) { err := c.SetResponse(1, "test") assert.NoError(t, err) assert.Equal(t, c.respEntry, entry) + assert.Equal(t, c.rawRespData, "test") }) } @@ -246,3 +247,23 @@ func Test_routeContext_MustSetResponse(t *testing.T) { }) }) } + +func Test_routeContext_RawResponseData(t *testing.T) { + t.Run("when raw resp data is not nil", func(t *testing.T) { + c := newContext(nil, nil) + c.rawRespData = 123 + assert.Equal(t, c.RawResponseData(), 123) + }) + t.Run("when resp entry is not nil", func(t *testing.T) { + c := newContext(nil, nil) + c.rawRespData = nil + c.respEntry = &message.Entry{ + Data: []byte("123"), + } + assert.Equal(t, c.RawResponseData(), []byte("123")) + }) + t.Run("when resp entry is nil", func(t *testing.T) { + c := newContext(nil, nil) + assert.Nil(t, c.RawResponseData()) + }) +}