Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions envoyauth/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,3 +441,51 @@ func preallocateForHTTPHeaders(result *http.Header, _ int) {
*result = make(http.Header)
}
}

// GetRequestQueryParametersToSet returns the query parameters to set in the request
func (result *EvalResult) GetRequestQueryParametersToSet() ([]*ext_core_v3.QueryParameter, error) {
switch decision := result.Decision.(type) {
case bool:
return nil, nil
case map[string]interface{}:
val, ok := decision["query_parameters_to_set"]
if !ok {
return nil, nil
}

params, ok := val.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("type assertion error, expected query_parameters_to_set to be a map but got '%T'", val)
}

result := make([]*ext_core_v3.QueryParameter, 0, len(params))

for key, value := range params {
switch v := value.(type) {
case string:
result = append(result, &ext_core_v3.QueryParameter{
Key: key,
Value: v,
})
case []interface{}:
result = slices.Grow(result, len(v))
for _, item := range v {
strItem, ok := item.(string)
if !ok {
return nil, fmt.Errorf("type assertion error: expected array element to be string but got '%T'", item)
}
result = append(result, &ext_core_v3.QueryParameter{
Key: key,
Value: strItem,
})
}
default:
return nil, fmt.Errorf("type assertion error, expected value to be string or array but got '%T'", value)
}
}

return result, nil
}

return nil, result.invalidDecisionErr()
}
161 changes: 161 additions & 0 deletions envoyauth/response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ import (
"context"
"encoding/json"
"reflect"
"sort"
"strings"
"testing"

ext_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
_structpb "github.com/golang/protobuf/ptypes/struct"
"github.com/open-policy-agent/opa/v1/bundle"
"github.com/open-policy-agent/opa/v1/storage"
Expand Down Expand Up @@ -214,6 +216,165 @@ func TestGetRequestQueryParametersToRemove(t *testing.T) {
}
}

func TestGetQueryParametersToSet(t *testing.T) {
tests := map[string]struct {
decision interface{}
exp []*ext_core_v3.QueryParameter
wantErr bool
}{
"bool_eval_result": {
true,
nil,
false,
},
"empty_map_result": {
map[string]interface{}{},
nil,
false,
},
"invalid_type": {
map[string]interface{}{
"query_parameters_to_set": "invalid",
},
nil,
true,
},
"invalid_value_type": {
map[string]interface{}{
"query_parameters_to_set": map[string]interface{}{
"test": 123,
},
},
nil,
true,
},
"invalid_array_value_type": {
map[string]interface{}{
"query_parameters_to_set": map[string]interface{}{
"test": []interface{}{123},
},
},
nil,
true,
},
"single_value": {
map[string]interface{}{
"query_parameters_to_set": map[string]interface{}{
"param1": "value1",
"param2": "value2",
},
},
[]*ext_core_v3.QueryParameter{
{
Key: "param1",
Value: "value1",
},
{
Key: "param2",
Value: "value2",
},
},
false,
},
"array_values": {
map[string]interface{}{
"query_parameters_to_set": map[string]interface{}{
"param1": []interface{}{"value1", "value2"},
"param2": []interface{}{"value3", "value4"},
},
},
[]*ext_core_v3.QueryParameter{
{
Key: "param1",
Value: "value1",
},
{
Key: "param1",
Value: "value2",
},
{
Key: "param2",
Value: "value3",
},
{
Key: "param2",
Value: "value4",
},
},
false,
},
"mixed_values": {
map[string]interface{}{
"query_parameters_to_set": map[string]interface{}{
"param1": "single",
"param2": []interface{}{"multi1", "multi2"},
},
},
[]*ext_core_v3.QueryParameter{
{
Key: "param1",
Value: "single",
},
{
Key: "param2",
Value: "multi1",
},
{
Key: "param2",
Value: "multi2",
},
},
false,
},
}

for name, tc := range tests {
t.Run(name, func(t *testing.T) {
er := EvalResult{
Decision: tc.decision,
}

result, err := er.GetRequestQueryParametersToSet()

if tc.wantErr {
if err == nil {
t.Fatal("Expected error but got nil")
}
} else {
if err != nil {
t.Fatalf("Unexpected error %v", err)
}

if len(result) != len(tc.exp) {
t.Fatalf("Expected %d parameters but got %d", len(tc.exp), len(result))
}

// sort first by key, then by value

sort.Slice(result, func(i, j int) bool {
if result[i].Key == result[j].Key {
return result[i].Value < result[j].Value
}
return result[i].Key < result[j].Key
})

sort.Slice(tc.exp, func(i, j int) bool {
if tc.exp[i].Key == tc.exp[j].Key {
return tc.exp[i].Value < tc.exp[j].Value
}
return tc.exp[i].Key < tc.exp[j].Key
})

for i, param := range result {
if param.Key != tc.exp[i].Key || param.Value != tc.exp[i].Value {
t.Fatalf("Parameter mismatch at index %d. Expected %v but got %v", i, tc.exp[i], param)
}
}
}
})
}
}

func TestGetRequestHTTPHeadersToRemove(t *testing.T) {
tests := map[string]struct {
decision interface{}
Expand Down
9 changes: 9 additions & 0 deletions internal/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -538,12 +538,21 @@ func (p *envoyExtAuthzGrpcServer) check(ctx context.Context, req interface{}) (*
return nil, stop, internalErr
}

var queryParamsToSet []*ext_core_v3.QueryParameter
queryParamsToSet, err = result.GetRequestQueryParametersToSet()
if err != nil {
err = errors.Wrap(err, "failed to get request query parameters to set")
internalErr = newInternalError(EnvoyAuthResultErr, err)
return nil, stop, internalErr
}

resp.HttpResponse = &ext_authz_v3.CheckResponse_OkResponse{
OkResponse: &ext_authz_v3.OkHttpResponse{
Headers: responseHeaders,
HeadersToRemove: headersToRemove,
ResponseHeadersToAdd: responseHeadersToAdd,
QueryParametersToRemove: queryParamsToRemove,
QueryParametersToSet: queryParamsToSet,
},
}
} else {
Expand Down
79 changes: 79 additions & 0 deletions internal/internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"net/http"
"net/http/httptest"
"reflect"
"sort"
"strings"
"sync"
"testing"
Expand Down Expand Up @@ -1562,6 +1563,84 @@ func TestCheckAllowObjectDecisionReqQueryParamsToRemove(t *testing.T) {
}
}

func TestCheckAllowObjectDecisionReqQueryParamsToSet(t *testing.T) {
var req ext_authz.CheckRequest
if err := util.Unmarshal([]byte(exampleAllowedRequest), &req); err != nil {
panic(err)
}

module := `
package envoy.authz

default allow = true

query_parameters_to_set := {
"foo": "value1",
"bar": ["value2", "value3"]
}

result["allowed"] = allow
result["query_parameters_to_set"] = query_parameters_to_set`

server := testAuthzServerWithModule(module, "envoy/authz/result", nil, withCustomLogger(&testPlugin{}))
ctx := context.Background()
output, err := server.Check(ctx, &req)
if err != nil {
t.Fatal(err)
}

if output.Status.Code != int32(code.Code_OK) {
t.Fatalf("Expected request to be allowed but got: %v", output)
}

response := output.GetOkResponse()
if response == nil {
t.Fatal("Expected OkHttpResponse struct but got nil")
}

queryParams := response.GetQueryParametersToSet()
if len(queryParams) != 3 {
t.Fatalf("Expected three query params but got %v", len(queryParams))
}

expectedQueryParamsToSet := []*ext_core.QueryParameter{
{
Key: "foo",
Value: "value1",
},
{
Key: "bar",
Value: "value2",
},
{
Key: "bar",
Value: "value3",
},
}

// sort first by key, then by value

sort.Slice(queryParams, func(i, j int) bool {
if queryParams[i].Key == queryParams[j].Key {
return queryParams[i].Value < queryParams[j].Value
}
return queryParams[i].Key < queryParams[j].Key
})

sort.Slice(expectedQueryParamsToSet, func(i, j int) bool {
if expectedQueryParamsToSet[i].Key == expectedQueryParamsToSet[j].Key {
return expectedQueryParamsToSet[i].Value < expectedQueryParamsToSet[j].Value
}
return expectedQueryParamsToSet[i].Key < expectedQueryParamsToSet[j].Key
})

for i, param := range queryParams {
if !reflect.DeepEqual(expectedQueryParamsToSet[i], param) {
t.Fatalf("Expected query param %v but got %v", expectedQueryParamsToSet[i], param)
}
}
}

func TestCheckAllowObjectDecisionReqHeadersToRemove(t *testing.T) {
var req ext_authz.CheckRequest
if err := util.Unmarshal([]byte(exampleAllowedRequestParsedPath), &req); err != nil {
Expand Down