diff --git a/client.go b/client.go index 35ca91fba..748a0073a 100644 --- a/client.go +++ b/client.go @@ -99,6 +99,10 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int cresp = &Response{} ) + if metadata, ok := GetMetadata(ctx); ok { + creq.Metadata = metadata + } + if dl, ok := ctx.Deadline(); ok { creq.TimeoutNano = dl.Sub(time.Now()).Nanoseconds() } diff --git a/metadata.go b/metadata.go new file mode 100644 index 000000000..d03626321 --- /dev/null +++ b/metadata.go @@ -0,0 +1,86 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package ttrpc + +import "context" + +// Metadata represents the key-value pairs (similar to http.Header) to be passed to ttrpc server from a client. +type Metadata map[string]StringList + +// Get returns the metadata for a given key when they exist. +// If there is no metadata, a nil slice and false are returned. +func (m Metadata) Get(key string) ([]string, bool) { + list, ok := m[key] + if !ok || len(list.List) == 0 { + return nil, false + } + + return list.List, true +} + +// Set sets the provided values for a given key. +// The values will overwrite any existing values. +// If no values provided, a key will be deleted. +func (m Metadata) Set(key string, values ...string) { + if len(values) == 0 { + delete(m, key) + return + } + + m[key] = StringList{List: values} +} + +// Append appends additional values to the given key. +func (m Metadata) Append(key string, values ...string) { + if len(values) == 0 { + return + } + + list, ok := m[key] + if ok { + m.Set(key, append(list.List, values...)...) + } else { + m.Set(key, values...) + } +} + +type metadataKey struct{} + +// GetMetadata retrieves metadata from context.Context (previously attached with WithMetadata) +func GetMetadata(ctx context.Context) (Metadata, bool) { + metadata, ok := ctx.Value(metadataKey{}).(Metadata) + return metadata, ok +} + +// GetMetadataValue gets a specific metadata value by name from context.Context +func GetMetadataValue(ctx context.Context, name string) (string, bool) { + metadata, ok := GetMetadata(ctx) + if !ok { + return "", false + } + + if list, ok := metadata.Get(name); ok { + return list[0], true + } + + return "", false +} + +// WithMetadata attaches metadata map to a context.Context +func WithMetadata(ctx context.Context, headers Metadata) context.Context { + return context.WithValue(ctx, metadataKey{}, headers) +} diff --git a/metadata_test.go b/metadata_test.go new file mode 100644 index 000000000..c334e059e --- /dev/null +++ b/metadata_test.go @@ -0,0 +1,108 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package ttrpc + +import ( + "context" + "testing" +) + +func TestMetadata_Get(t *testing.T) { + metadata := make(Metadata) + metadata.Set("foo", "1", "2") + + if list, ok := metadata.Get("foo"); !ok { + t.Error("key not found") + } else if len(list) != 2 { + t.Errorf("unexpected number of values: %d", len(list)) + } else if list[0] != "1" { + t.Errorf("invalid metadata value at 0: %s", list[0]) + } else if list[1] != "2" { + t.Errorf("invalid metadata value at 1: %s", list[1]) + } +} + +func TestMetadata_GetInvalidKey(t *testing.T) { + metadata := make(Metadata) + metadata.Set("foo", "1", "2") + + if _, ok := metadata.Get("invalid"); ok { + t.Error("found invalid key") + } +} + +func TestMetadata_Unset(t *testing.T) { + metadata := make(Metadata) + metadata.Set("foo", "1", "2") + metadata.Set("foo") + + if _, ok := metadata.Get("foo"); ok { + t.Error("key not deleted") + } +} + +func TestMetadata_Replace(t *testing.T) { + metadata := make(Metadata) + metadata.Set("foo", "1", "2") + metadata.Set("foo", "3", "4") + + if list, ok := metadata.Get("foo"); !ok { + t.Error("key not found") + } else if len(list) != 2 { + t.Errorf("unexpected number of values: %d", len(list)) + } else if list[0] != "3" { + t.Errorf("invalid metadata value at 0: %s", list[0]) + } else if list[1] != "4" { + t.Errorf("invalid metadata value at 1: %s", list[1]) + } +} + +func TestMetadata_Append(t *testing.T) { + metadata := make(Metadata) + metadata.Set("foo", "1") + metadata.Append("foo", "2") + metadata.Append("bar", "3") + + if list, ok := metadata.Get("foo"); !ok { + t.Error("key not found") + } else if len(list) != 2 { + t.Errorf("unexpected number of values: %d", len(list)) + } else if list[0] != "1" { + t.Errorf("invalid metadata value at 0: %s", list[0]) + } else if list[1] != "2" { + t.Errorf("invalid metadata value at 1: %s", list[1]) + } + + if list, ok := metadata.Get("bar"); !ok { + t.Error("key not found") + } else if list[0] != "3" { + t.Errorf("invalid value: %s", list[0]) + } +} + +func TestMetadata_Context(t *testing.T) { + metadata := make(Metadata) + metadata.Set("foo", "bar") + + ctx := WithMetadata(context.Background(), metadata) + + if bar, ok := GetMetadataValue(ctx, "foo"); !ok { + t.Error("metadata not found") + } else if bar != "bar" { + t.Errorf("invalid metadata value: %q", bar) + } +} diff --git a/server.go b/server.go index dc605f4db..595a69a0d 100644 --- a/server.go +++ b/server.go @@ -466,6 +466,10 @@ func (c *serverConn) run(sctx context.Context) { var noopFunc = func() {} func getRequestContext(ctx context.Context, req *Request) (retCtx context.Context, cancel func()) { + if req.Metadata != nil { + ctx = WithMetadata(ctx, req.Metadata) + } + cancel = noopFunc if req.TimeoutNano == 0 { return ctx, cancel diff --git a/server_test.go b/server_test.go index 9ee7e0d65..d76ece1b8 100644 --- a/server_test.go +++ b/server_test.go @@ -61,6 +61,7 @@ func (tc *testingClient) Test(ctx context.Context, req *testPayload) (*testPaylo type testPayload struct { Foo string `protobuf:"bytes,1,opt,name=foo,proto3"` Deadline int64 `protobuf:"varint,2,opt,name=deadline,proto3"` + Metadata string `protobuf:"bytes,3,opt,name=metadata,proto3"` } func (r *testPayload) Reset() { *r = testPayload{} } @@ -75,6 +76,11 @@ func (s *testingServer) Test(ctx context.Context, req *testPayload) (*testPayloa if dl, ok := ctx.Deadline(); ok { tp.Deadline = dl.UnixNano() } + + if v, ok := GetMetadataValue(ctx, "foo"); ok { + tp.Metadata = v + } + return tp, nil } @@ -540,6 +546,8 @@ func roundTrip(ctx context.Context, t *testing.T, client *testingClient, value s } ) + ctx = WithMetadata(ctx, Metadata{"foo": makeStringList("bar")}) + resp, err := client.Test(ctx, tp) if err != nil { t.Fatal(err) @@ -547,7 +555,7 @@ func roundTrip(ctx context.Context, t *testing.T, client *testingClient, value s results <- callResult{ input: tp, - expected: &testPayload{Foo: strings.Repeat(tp.Foo, 2)}, + expected: &testPayload{Foo: strings.Repeat(tp.Foo, 2), Metadata: "bar"}, received: resp, } } diff --git a/types.go b/types.go index a6b3b818e..c8ecb3868 100644 --- a/types.go +++ b/types.go @@ -23,10 +23,11 @@ import ( ) type Request struct { - Service string `protobuf:"bytes,1,opt,name=service,proto3"` - Method string `protobuf:"bytes,2,opt,name=method,proto3"` - Payload []byte `protobuf:"bytes,3,opt,name=payload,proto3"` - TimeoutNano int64 `protobuf:"varint,4,opt,name=timeout_nano,proto3"` + Service string `protobuf:"bytes,1,opt,name=service,proto3"` + Method string `protobuf:"bytes,2,opt,name=method,proto3"` + Payload []byte `protobuf:"bytes,3,opt,name=payload,proto3"` + TimeoutNano int64 `protobuf:"varint,4,opt,name=timeout_nano,proto3"` + Metadata Metadata `protobuf:"bytes,5,opt,name=metadata,proto3" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` } func (r *Request) Reset() { *r = Request{} } @@ -41,3 +42,13 @@ type Response struct { func (r *Response) Reset() { *r = Response{} } func (r *Response) String() string { return fmt.Sprintf("%+#v", r) } func (r *Response) ProtoMessage() {} + +type StringList struct { + List []string `protobuf:"bytes,1,rep,name=list,proto3"` +} + +func (r *StringList) Reset() { *r = StringList{} } +func (r *StringList) String() string { return fmt.Sprintf("%+#v", r) } +func (r *StringList) ProtoMessage() {} + +func makeStringList(item ...string) StringList { return StringList{List: item} }