diff --git a/client.go b/client.go index 804024e49..ef130be1d 100644 --- a/client.go +++ b/client.go @@ -112,7 +112,7 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int ) if metadata, ok := GetMetadata(ctx); ok { - creq.Metadata = metadata + metadata.setRequest(creq) } if dl, ok := ctx.Deadline(); ok { diff --git a/example/cmd/main.go b/example/cmd/main.go index f452381ee..f1f43ab70 100644 --- a/example/cmd/main.go +++ b/example/cmd/main.go @@ -109,7 +109,7 @@ func client() error { } ctx := context.Background() - md := ttrpc.Metadata{} + md := ttrpc.MD{} md.Set("name", "koye") ctx = ttrpc.WithMetadata(ctx, md) diff --git a/metadata.go b/metadata.go index d03626321..ce8c0d13c 100644 --- a/metadata.go +++ b/metadata.go @@ -16,53 +16,74 @@ package ttrpc -import "context" +import ( + "context" + "strings" +) -// Metadata represents the key-value pairs (similar to http.Header) to be passed to ttrpc server from a client. -type Metadata map[string]StringList +// MD is the user type for ttrpc metadata +type MD map[string][]string // Get returns the metadata for a given key when they exist. // If there is no metadata, a nil slice and false are returned. -func (m Metadata) Get(key string) ([]string, bool) { +func (m MD) Get(key string) ([]string, bool) { + key = strings.ToLower(key) list, ok := m[key] - if !ok || len(list.List) == 0 { + if !ok || len(list) == 0 { return nil, false } - return list.List, true + return list, true } // Set sets the provided values for a given key. // The values will overwrite any existing values. // If no values provided, a key will be deleted. -func (m Metadata) Set(key string, values ...string) { +func (m MD) Set(key string, values ...string) { + key = strings.ToLower(key) if len(values) == 0 { delete(m, key) return } - - m[key] = StringList{List: values} + m[key] = values } // Append appends additional values to the given key. -func (m Metadata) Append(key string, values ...string) { +func (m MD) Append(key string, values ...string) { + key = strings.ToLower(key) if len(values) == 0 { return } - - list, ok := m[key] + current, ok := m[key] if ok { - m.Set(key, append(list.List, values...)...) + m.Set(key, append(current, values...)...) } else { m.Set(key, values...) } } +func (m MD) setRequest(r *Request) { + for k, values := range m { + for _, v := range values { + r.Metadata = append(r.Metadata, &KeyValue{ + Key: k, + Value: v, + }) + } + } +} + +func (m MD) fromRequest(r *Request) { + for _, kv := range r.Metadata { + m[kv.Key] = append(m[kv.Key], kv.Value) + } +} + type metadataKey struct{} // GetMetadata retrieves metadata from context.Context (previously attached with WithMetadata) -func GetMetadata(ctx context.Context) (Metadata, bool) { - metadata, ok := ctx.Value(metadataKey{}).(Metadata) +func GetMetadata(ctx context.Context) (MD, bool) { + metadata, ok := ctx.Value(metadataKey{}).(MD) return metadata, ok } @@ -81,6 +102,6 @@ func GetMetadataValue(ctx context.Context, name string) (string, bool) { } // WithMetadata attaches metadata map to a context.Context -func WithMetadata(ctx context.Context, headers Metadata) context.Context { - return context.WithValue(ctx, metadataKey{}, headers) +func WithMetadata(ctx context.Context, md MD) context.Context { + return context.WithValue(ctx, metadataKey{}, md) } diff --git a/metadata_test.go b/metadata_test.go index c334e059e..d7fc09559 100644 --- a/metadata_test.go +++ b/metadata_test.go @@ -21,8 +21,8 @@ import ( "testing" ) -func TestMetadata_Get(t *testing.T) { - metadata := make(Metadata) +func TestMetadataGet(t *testing.T) { + metadata := make(MD) metadata.Set("foo", "1", "2") if list, ok := metadata.Get("foo"); !ok { @@ -36,8 +36,8 @@ func TestMetadata_Get(t *testing.T) { } } -func TestMetadata_GetInvalidKey(t *testing.T) { - metadata := make(Metadata) +func TestMetadataGetInvalidKey(t *testing.T) { + metadata := make(MD) metadata.Set("foo", "1", "2") if _, ok := metadata.Get("invalid"); ok { @@ -45,8 +45,8 @@ func TestMetadata_GetInvalidKey(t *testing.T) { } } -func TestMetadata_Unset(t *testing.T) { - metadata := make(Metadata) +func TestMetadataUnset(t *testing.T) { + metadata := make(MD) metadata.Set("foo", "1", "2") metadata.Set("foo") @@ -55,8 +55,8 @@ func TestMetadata_Unset(t *testing.T) { } } -func TestMetadata_Replace(t *testing.T) { - metadata := make(Metadata) +func TestMetadataReplace(t *testing.T) { + metadata := make(MD) metadata.Set("foo", "1", "2") metadata.Set("foo", "3", "4") @@ -71,8 +71,8 @@ func TestMetadata_Replace(t *testing.T) { } } -func TestMetadata_Append(t *testing.T) { - metadata := make(Metadata) +func TestMetadataAppend(t *testing.T) { + metadata := make(MD) metadata.Set("foo", "1") metadata.Append("foo", "2") metadata.Append("bar", "3") @@ -94,8 +94,8 @@ func TestMetadata_Append(t *testing.T) { } } -func TestMetadata_Context(t *testing.T) { - metadata := make(Metadata) +func TestMetadataContext(t *testing.T) { + metadata := make(MD) metadata.Set("foo", "bar") ctx := WithMetadata(context.Background(), metadata) diff --git a/server.go b/server.go index 5c33559f6..f5d87ba46 100644 --- a/server.go +++ b/server.go @@ -469,8 +469,10 @@ func (c *serverConn) run(sctx context.Context) { var noopFunc = func() {} func getRequestContext(ctx context.Context, req *Request) (retCtx context.Context, cancel func()) { - if req.Metadata != nil { - ctx = WithMetadata(ctx, req.Metadata) + if len(req.Metadata) > 0 { + md := MD{} + md.fromRequest(req) + ctx = WithMetadata(ctx, md) } cancel = noopFunc diff --git a/server_test.go b/server_test.go index d76ece1b8..2af8c4a70 100644 --- a/server_test.go +++ b/server_test.go @@ -546,7 +546,7 @@ func roundTrip(ctx context.Context, t *testing.T, client *testingClient, value s } ) - ctx = WithMetadata(ctx, Metadata{"foo": makeStringList("bar")}) + ctx = WithMetadata(ctx, MD{"foo": []string{"bar"}}) resp, err := client.Test(ctx, tp) if err != nil { diff --git a/types.go b/types.go index c8ecb3868..9a1c19a72 100644 --- a/types.go +++ b/types.go @@ -23,11 +23,11 @@ import ( ) type Request struct { - Service string `protobuf:"bytes,1,opt,name=service,proto3"` - Method string `protobuf:"bytes,2,opt,name=method,proto3"` - Payload []byte `protobuf:"bytes,3,opt,name=payload,proto3"` - TimeoutNano int64 `protobuf:"varint,4,opt,name=timeout_nano,proto3"` - Metadata Metadata `protobuf:"bytes,5,opt,name=metadata,proto3" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` + Service string `protobuf:"bytes,1,opt,name=service,proto3"` + Method string `protobuf:"bytes,2,opt,name=method,proto3"` + Payload []byte `protobuf:"bytes,3,opt,name=payload,proto3"` + TimeoutNano int64 `protobuf:"varint,4,opt,name=timeout_nano,proto3"` + Metadata []*KeyValue `protobuf:"bytes,5,rep,name=metadata,proto3"` } func (r *Request) Reset() { *r = Request{} } @@ -52,3 +52,12 @@ func (r *StringList) String() string { return fmt.Sprintf("%+#v", r) } func (r *StringList) ProtoMessage() {} func makeStringList(item ...string) StringList { return StringList{List: item} } + +type KeyValue struct { + Key string `protobuf:"bytes,1,opt,name=key,proto3"` + Value string `protobuf:"bytes,2,opt,name=value,proto3"` +} + +func (m *KeyValue) Reset() { *m = KeyValue{} } +func (*KeyValue) ProtoMessage() {} +func (m *KeyValue) String() string { return fmt.Sprintf("%+#v", m) }