Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int
)

if metadata, ok := GetMetadata(ctx); ok {
creq.Metadata = metadata
metadata.setRequest(creq)
}

if dl, ok := ctx.Deadline(); ok {
Expand Down
2 changes: 1 addition & 1 deletion example/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func client() error {
}

ctx := context.Background()
md := ttrpc.Metadata{}
md := ttrpc.MD{}
md.Set("name", "koye")
ctx = ttrpc.WithMetadata(ctx, md)

Expand Down
55 changes: 38 additions & 17 deletions metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,53 +16,74 @@

package ttrpc

import "context"
import (
"context"
"strings"
)

// Metadata represents the key-value pairs (similar to http.Header) to be passed to ttrpc server from a client.
type Metadata map[string]StringList
// MD is the user type for ttrpc metadata
type MD map[string][]string

// Get returns the metadata for a given key when they exist.
// If there is no metadata, a nil slice and false are returned.
func (m Metadata) Get(key string) ([]string, bool) {
func (m MD) Get(key string) ([]string, bool) {
key = strings.ToLower(key)
list, ok := m[key]
if !ok || len(list.List) == 0 {
if !ok || len(list) == 0 {
return nil, false
}

return list.List, true
return list, true
}

// Set sets the provided values for a given key.
// The values will overwrite any existing values.
// If no values provided, a key will be deleted.
func (m Metadata) Set(key string, values ...string) {
func (m MD) Set(key string, values ...string) {
key = strings.ToLower(key)
if len(values) == 0 {
delete(m, key)
return
}

m[key] = StringList{List: values}
m[key] = values
}

// Append appends additional values to the given key.
func (m Metadata) Append(key string, values ...string) {
func (m MD) Append(key string, values ...string) {
key = strings.ToLower(key)
if len(values) == 0 {
return
}

list, ok := m[key]
current, ok := m[key]
if ok {
m.Set(key, append(list.List, values...)...)
m.Set(key, append(current, values...)...)
} else {
m.Set(key, values...)
}
}

func (m MD) setRequest(r *Request) {
for k, values := range m {
for _, v := range values {
r.Metadata = append(r.Metadata, &KeyValue{
Key: k,
Value: v,
})
}
}
}

func (m MD) fromRequest(r *Request) {
for _, kv := range r.Metadata {
m[kv.Key] = append(m[kv.Key], kv.Value)
}
}

type metadataKey struct{}

// GetMetadata retrieves metadata from context.Context (previously attached with WithMetadata)
func GetMetadata(ctx context.Context) (Metadata, bool) {
metadata, ok := ctx.Value(metadataKey{}).(Metadata)
func GetMetadata(ctx context.Context) (MD, bool) {
metadata, ok := ctx.Value(metadataKey{}).(MD)
return metadata, ok
}

Expand All @@ -81,6 +102,6 @@ func GetMetadataValue(ctx context.Context, name string) (string, bool) {
}

// WithMetadata attaches metadata map to a context.Context
func WithMetadata(ctx context.Context, headers Metadata) context.Context {
return context.WithValue(ctx, metadataKey{}, headers)
func WithMetadata(ctx context.Context, md MD) context.Context {
return context.WithValue(ctx, metadataKey{}, md)
}
24 changes: 12 additions & 12 deletions metadata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ import (
"testing"
)

func TestMetadata_Get(t *testing.T) {
metadata := make(Metadata)
func TestMetadataGet(t *testing.T) {
metadata := make(MD)
metadata.Set("foo", "1", "2")

if list, ok := metadata.Get("foo"); !ok {
Expand All @@ -36,17 +36,17 @@ func TestMetadata_Get(t *testing.T) {
}
}

func TestMetadata_GetInvalidKey(t *testing.T) {
metadata := make(Metadata)
func TestMetadataGetInvalidKey(t *testing.T) {
metadata := make(MD)
metadata.Set("foo", "1", "2")

if _, ok := metadata.Get("invalid"); ok {
t.Error("found invalid key")
}
}

func TestMetadata_Unset(t *testing.T) {
metadata := make(Metadata)
func TestMetadataUnset(t *testing.T) {
metadata := make(MD)
metadata.Set("foo", "1", "2")
metadata.Set("foo")

Expand All @@ -55,8 +55,8 @@ func TestMetadata_Unset(t *testing.T) {
}
}

func TestMetadata_Replace(t *testing.T) {
metadata := make(Metadata)
func TestMetadataReplace(t *testing.T) {
metadata := make(MD)
metadata.Set("foo", "1", "2")
metadata.Set("foo", "3", "4")

Expand All @@ -71,8 +71,8 @@ func TestMetadata_Replace(t *testing.T) {
}
}

func TestMetadata_Append(t *testing.T) {
metadata := make(Metadata)
func TestMetadataAppend(t *testing.T) {
metadata := make(MD)
metadata.Set("foo", "1")
metadata.Append("foo", "2")
metadata.Append("bar", "3")
Expand All @@ -94,8 +94,8 @@ func TestMetadata_Append(t *testing.T) {
}
}

func TestMetadata_Context(t *testing.T) {
metadata := make(Metadata)
func TestMetadataContext(t *testing.T) {
metadata := make(MD)
metadata.Set("foo", "bar")

ctx := WithMetadata(context.Background(), metadata)
Expand Down
6 changes: 4 additions & 2 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,8 +469,10 @@ func (c *serverConn) run(sctx context.Context) {
var noopFunc = func() {}

func getRequestContext(ctx context.Context, req *Request) (retCtx context.Context, cancel func()) {
if req.Metadata != nil {
ctx = WithMetadata(ctx, req.Metadata)
if len(req.Metadata) > 0 {
md := MD{}
md.fromRequest(req)
ctx = WithMetadata(ctx, md)
}

cancel = noopFunc
Expand Down
2 changes: 1 addition & 1 deletion server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ func roundTrip(ctx context.Context, t *testing.T, client *testingClient, value s
}
)

ctx = WithMetadata(ctx, Metadata{"foo": makeStringList("bar")})
ctx = WithMetadata(ctx, MD{"foo": []string{"bar"}})

resp, err := client.Test(ctx, tp)
if err != nil {
Expand Down
19 changes: 14 additions & 5 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ import (
)

type Request struct {
Service string `protobuf:"bytes,1,opt,name=service,proto3"`
Method string `protobuf:"bytes,2,opt,name=method,proto3"`
Payload []byte `protobuf:"bytes,3,opt,name=payload,proto3"`
TimeoutNano int64 `protobuf:"varint,4,opt,name=timeout_nano,proto3"`
Metadata Metadata `protobuf:"bytes,5,opt,name=metadata,proto3" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"`
Service string `protobuf:"bytes,1,opt,name=service,proto3"`
Method string `protobuf:"bytes,2,opt,name=method,proto3"`
Payload []byte `protobuf:"bytes,3,opt,name=payload,proto3"`
TimeoutNano int64 `protobuf:"varint,4,opt,name=timeout_nano,proto3"`
Metadata []*KeyValue `protobuf:"bytes,5,rep,name=metadata,proto3"`
}

func (r *Request) Reset() { *r = Request{} }
Expand All @@ -52,3 +52,12 @@ func (r *StringList) String() string { return fmt.Sprintf("%+#v", r) }
func (r *StringList) ProtoMessage() {}

func makeStringList(item ...string) StringList { return StringList{List: item} }

type KeyValue struct {
Key string `protobuf:"bytes,1,opt,name=key,proto3"`
Value string `protobuf:"bytes,2,opt,name=value,proto3"`
}

func (m *KeyValue) Reset() { *m = KeyValue{} }
func (*KeyValue) ProtoMessage() {}
func (m *KeyValue) String() string { return fmt.Sprintf("%+#v", m) }