Skip to content

Commit a41ef12

Browse files
authored
Re-introduce GetResponseHTTPHeaders method (#667)
Signed-off-by: Pushpalanka Jayawardhana <[email protected]>
1 parent a8f1a90 commit a41ef12

File tree

2 files changed

+216
-49
lines changed

2 files changed

+216
-49
lines changed

envoyauth/response.go

Lines changed: 94 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@ import (
44
"context"
55
"encoding/json"
66
"fmt"
7-
"net/http"
8-
"slices"
9-
107
ext_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
118
ext_type_v3 "github.com/envoyproxy/go-control-plane/envoy/type/v3"
129
_structpb "github.com/golang/protobuf/ptypes/struct"
@@ -16,6 +13,8 @@ import (
1613
"github.com/open-policy-agent/opa/storage"
1714
"github.com/open-policy-agent/opa/topdown/builtins"
1815
"google.golang.org/protobuf/types/known/structpb"
16+
"net/http"
17+
"slices"
1918
)
2019

2120
// EvalResult - Captures the result from evaluating a query against an input
@@ -185,20 +184,9 @@ func (result *EvalResult) GetRequestHTTPHeadersToRemove() ([]string, error) {
185184
return result.getStringSliceFromDecision("request_headers_to_remove")
186185
}
187186

188-
func (result *EvalResult) getHeadersFromDecision(fieldName string) ([]*ext_core_v3.HeaderValueOption, error) {
189-
switch decision := result.Decision.(type) {
190-
case bool:
191-
return nil, nil
192-
case map[string]interface{}:
193-
val, ok := decision[fieldName]
194-
if !ok {
195-
return nil, nil
196-
}
197-
198-
return transformHeadersToEnvoy(val)
199-
default:
200-
return nil, result.invalidDecisionErr()
201-
}
187+
// GetResponseHTTPHeaders - returns the http headers to return if they are part of the decision as http header
188+
func (result *EvalResult) GetResponseHTTPHeaders() (http.Header, error) {
189+
return result.getHTTPHeadersFromDecision("headers")
202190
}
203191

204192
// GetResponseEnvoyHeaderValueOptions - returns the http headers to return if they are part of the decision as envoy header value options
@@ -343,54 +331,112 @@ func makeHeaderValueOption(k, v string) *ext_core_v3.HeaderValueOption {
343331
}
344332
}
345333

346-
func makeEnvoyHeaderValueOptionsFromHeadersMap(hvo []*ext_core_v3.HeaderValueOption, headers map[string]any) ([]*ext_core_v3.HeaderValueOption, error) {
347-
hvo = slices.Grow(hvo, len(headers))
334+
func (result *EvalResult) getHeadersFromDecision(fieldName string) ([]*ext_core_v3.HeaderValueOption, error) {
335+
return getHeadersWithTransformation(result.Decision, fieldName, preallocateForEnvoyHeaders, collectEnvoyHeaders)
336+
}
337+
338+
func (result *EvalResult) getHTTPHeadersFromDecision(fieldName string) (http.Header, error) {
339+
return getHeadersWithTransformation(result.Decision, fieldName, preallocateForHTTPHeaders, collectHTTPHeaders)
340+
}
341+
342+
func getHeadersWithTransformation[T any](
343+
decision any,
344+
fieldName string,
345+
preallocate func(*T, int),
346+
collector func(string, string, *T),
347+
) (T, error) {
348+
var result T
349+
350+
switch decision := decision.(type) {
351+
case bool:
352+
return result, nil
353+
case map[string]interface{}:
354+
headersList, err := extractHeadersFromDecision(decision, fieldName)
355+
if err != nil {
356+
return result, err
357+
}
358+
359+
for _, headers := range headersList {
360+
if err := collectHeaderValues(headers, preallocate, collector, &result); err != nil {
361+
return result, err
362+
}
363+
}
364+
return result, nil
365+
default:
366+
return result, fmt.Errorf("illegal value for policy evaluation result: %T", decision)
367+
}
368+
}
369+
370+
func extractHeadersFromDecision(decision map[string]interface{}, fieldName string) ([]map[string]interface{}, error) {
371+
val, ok := decision[fieldName]
372+
if !ok {
373+
return nil, nil
374+
}
375+
376+
switch v := val.(type) {
377+
case map[string]interface{}:
378+
return []map[string]interface{}{v}, nil
379+
case []interface{}:
380+
headersList := make([]map[string]interface{}, 0, len(v))
381+
for _, item := range v {
382+
headers, ok := item.(map[string]interface{})
383+
if !ok {
384+
return nil, fmt.Errorf("type assertion error, expected headers to be of type 'object' but got '%T'", item)
385+
}
386+
headersList = append(headersList, headers)
387+
}
388+
return headersList, nil
389+
default:
390+
return nil, fmt.Errorf("type assertion error, expected headers to be of type 'object' but got '%T'", v)
391+
}
392+
}
393+
394+
func collectHeaderValues[T any](
395+
headers map[string]interface{},
396+
preallocate func(*T, int),
397+
collector func(string, string, *T),
398+
result *T,
399+
) error {
348400
for key, value := range headers {
349401
switch val := value.(type) {
350402
case string:
351-
hvo = append(hvo, makeHeaderValueOption(key, val))
403+
preallocate(result, 1)
404+
collector(key, val, result)
352405
case []string:
353-
hvo = slices.Grow(hvo, len(val))
406+
preallocate(result, len(val))
354407
for _, v := range val {
355-
hvo = append(hvo, makeHeaderValueOption(key, v))
408+
collector(key, v, result)
356409
}
357410
case []interface{}:
358-
hvo = slices.Grow(hvo, len(val))
411+
preallocate(result, len(val))
359412
for _, v := range val {
360413
s, ok := v.(string)
361414
if !ok {
362-
return nil, fmt.Errorf("invalid value type %T for header '%s'", v, key)
415+
return fmt.Errorf("invalid value type %T for header '%s'", v, key)
363416
}
364-
hvo = append(hvo, makeHeaderValueOption(key, s))
417+
collector(key, s, result)
365418
}
366419
default:
367-
return nil, fmt.Errorf("type assertion error for header '%s'", key)
420+
return fmt.Errorf("type assertion error for header '%s'", key)
368421
}
369422
}
370-
return hvo, nil
423+
return nil
371424
}
372425

373-
func transformHeadersToEnvoy(input any) ([]*ext_core_v3.HeaderValueOption, error) {
374-
switch input := input.(type) {
375-
case []any:
376-
var (
377-
hvo []*ext_core_v3.HeaderValueOption
378-
err error
379-
)
380-
for _, val := range input {
381-
headers, ok := val.(map[string]any)
382-
if !ok {
383-
return nil, fmt.Errorf("type assertion error, expected headers to be of type 'object' but got '%T'", val)
384-
}
426+
func collectEnvoyHeaders(key, value string, result *[]*ext_core_v3.HeaderValueOption) {
427+
*result = append(*result, makeHeaderValueOption(key, value))
428+
}
385429

386-
hvo, err = makeEnvoyHeaderValueOptionsFromHeadersMap(hvo, headers)
387-
if err != nil {
388-
return nil, err
389-
}
390-
}
391-
return hvo, nil
392-
case map[string]any:
393-
return makeEnvoyHeaderValueOptionsFromHeadersMap(nil, input)
430+
func collectHTTPHeaders(key, value string, result *http.Header) {
431+
result.Add(key, value)
432+
}
433+
434+
func preallocateForEnvoyHeaders(result *[]*ext_core_v3.HeaderValueOption, additional int) {
435+
*result = slices.Grow(*result, additional)
436+
}
437+
438+
func preallocateForHTTPHeaders(result *http.Header, _ int) {
439+
if *result == nil {
440+
*result = make(http.Header)
394441
}
395-
return nil, fmt.Errorf("type assertion error, expected headers to be of type 'object' but got '%T'", input)
396442
}

envoyauth/response_test.go

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ func TestGetResponseHTTPHeadersToAdd(t *testing.T) {
373373
}
374374
}
375375

376-
func TestGetResponseHeaders(t *testing.T) {
376+
func TestGetResponseHeaderValueOptions(t *testing.T) {
377377
input := make(map[string]interface{})
378378
er := EvalResult{
379379
Decision: input,
@@ -455,6 +455,127 @@ func TestGetResponseHeaders(t *testing.T) {
455455
if len(result) != 2 {
456456
t.Fatalf("Expected two header but got %v", len(result))
457457
}
458+
459+
testAddHeaders["foo"] = []interface{}{"bar", "baz"}
460+
input["headers"] = testAddHeaders
461+
462+
result, err = er.GetResponseEnvoyHeaderValueOptions()
463+
464+
if err != nil {
465+
t.Fatalf("Expected no error but got %v", err)
466+
}
467+
468+
if len(result) != 2 {
469+
t.Fatalf("Expected two header but got %v", len(result))
470+
}
471+
472+
if seen["bar"] != 1 || seen["baz"] != 1 {
473+
t.Errorf("expected 'bar' and 'baz', got %v", seen)
474+
}
475+
}
476+
477+
func TestGetResponseHeaders(t *testing.T) {
478+
input := make(map[string]interface{})
479+
er := EvalResult{
480+
Decision: input,
481+
}
482+
483+
result, err := er.GetResponseHTTPHeaders()
484+
if err != nil {
485+
t.Fatalf("Expected no error but got %v", err)
486+
}
487+
488+
if len(result) != 0 {
489+
t.Fatal("Expected no headers")
490+
}
491+
492+
badHeader := "test"
493+
input["headers"] = badHeader
494+
495+
_, err = er.GetResponseHTTPHeaders()
496+
if err == nil {
497+
t.Fatal("Expected error but got nil")
498+
}
499+
500+
testHeaders := make(map[string]interface{})
501+
testHeaders["foo"] = "bar"
502+
input["headers"] = testHeaders
503+
504+
result, err = er.GetResponseHTTPHeaders()
505+
if err != nil {
506+
t.Fatalf("Expected no error but got %v", err)
507+
}
508+
509+
if len(result) != 1 {
510+
t.Fatalf("Expected one header but got %v", len(result))
511+
}
512+
513+
testHeaders["baz"] = 1
514+
515+
_, err = er.GetResponseHTTPHeaders()
516+
if err == nil {
517+
t.Fatal("Expected error but got nil")
518+
}
519+
520+
input["headers"] = []interface{}{
521+
map[string]interface{}{
522+
"foo": "bar",
523+
},
524+
map[string]interface{}{
525+
"foo": "baz",
526+
},
527+
}
528+
529+
result, err = er.GetResponseHTTPHeaders()
530+
if err != nil {
531+
t.Fatalf("Expected no error but got %v", err)
532+
}
533+
534+
if len(result.Values("foo")) != 2 {
535+
t.Fatalf("Expected two header values but got %v", result.Values("foo"))
536+
}
537+
538+
seen := map[string]int{}
539+
for _, values := range result {
540+
for _, value := range values {
541+
seen[value]++
542+
}
543+
}
544+
545+
if seen["bar"] != 1 || seen["baz"] != 1 {
546+
t.Errorf("expected 'bar' and 'baz', got %v", seen)
547+
}
548+
549+
testAddHeaders := make(map[string]interface{})
550+
testAddHeaders["foo"] = []string{"bar", "baz"}
551+
input["headers"] = testAddHeaders
552+
553+
result, err = er.GetResponseHTTPHeaders()
554+
555+
if err != nil {
556+
t.Fatalf("Expected no error but got %v", err)
557+
}
558+
559+
if len(result.Values("foo")) != 2 {
560+
t.Fatalf("Expected two header but got %v", len(result.Values("foo")))
561+
}
562+
563+
testAddHeaders["foo"] = []interface{}{"bar", "baz"}
564+
input["headers"] = testAddHeaders
565+
566+
result, err = er.GetResponseHTTPHeaders()
567+
568+
if err != nil {
569+
t.Fatalf("Expected no error but got %v", err)
570+
}
571+
572+
if len(result.Values("foo")) != 2 {
573+
t.Fatalf("Expected two header but got %v", len(result.Values("foo")))
574+
}
575+
576+
if seen["bar"] != 1 || seen["baz"] != 1 {
577+
t.Errorf("expected 'bar' and 'baz', got %v", seen)
578+
}
458579
}
459580

460581
func TestGetResponseBody(t *testing.T) {

0 commit comments

Comments
 (0)