diff --git a/envoyauth/response.go b/envoyauth/response.go index a7d0d7611..0018c195e 100644 --- a/envoyauth/response.go +++ b/envoyauth/response.go @@ -441,3 +441,51 @@ func preallocateForHTTPHeaders(result *http.Header, _ int) { *result = make(http.Header) } } + +// GetRequestQueryParametersToSet returns the query parameters to set in the request +func (result *EvalResult) GetRequestQueryParametersToSet() ([]*ext_core_v3.QueryParameter, error) { + switch decision := result.Decision.(type) { + case bool: + return nil, nil + case map[string]interface{}: + val, ok := decision["query_parameters_to_set"] + if !ok { + return nil, nil + } + + params, ok := val.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("type assertion error, expected query_parameters_to_set to be a map but got '%T'", val) + } + + result := make([]*ext_core_v3.QueryParameter, 0, len(params)) + + for key, value := range params { + switch v := value.(type) { + case string: + result = append(result, &ext_core_v3.QueryParameter{ + Key: key, + Value: v, + }) + case []interface{}: + result = slices.Grow(result, len(v)) + for _, item := range v { + strItem, ok := item.(string) + if !ok { + return nil, fmt.Errorf("type assertion error: expected array element to be string but got '%T'", item) + } + result = append(result, &ext_core_v3.QueryParameter{ + Key: key, + Value: strItem, + }) + } + default: + return nil, fmt.Errorf("type assertion error, expected value to be string or array but got '%T'", value) + } + } + + return result, nil + } + + return nil, result.invalidDecisionErr() +} diff --git a/envoyauth/response_test.go b/envoyauth/response_test.go index 37f17061d..c526120c9 100644 --- a/envoyauth/response_test.go +++ b/envoyauth/response_test.go @@ -4,9 +4,11 @@ import ( "context" "encoding/json" "reflect" + "sort" "strings" "testing" + ext_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" _structpb "github.com/golang/protobuf/ptypes/struct" "github.com/open-policy-agent/opa/v1/bundle" "github.com/open-policy-agent/opa/v1/storage" @@ -214,6 +216,165 @@ func TestGetRequestQueryParametersToRemove(t *testing.T) { } } +func TestGetQueryParametersToSet(t *testing.T) { + tests := map[string]struct { + decision interface{} + exp []*ext_core_v3.QueryParameter + wantErr bool + }{ + "bool_eval_result": { + true, + nil, + false, + }, + "empty_map_result": { + map[string]interface{}{}, + nil, + false, + }, + "invalid_type": { + map[string]interface{}{ + "query_parameters_to_set": "invalid", + }, + nil, + true, + }, + "invalid_value_type": { + map[string]interface{}{ + "query_parameters_to_set": map[string]interface{}{ + "test": 123, + }, + }, + nil, + true, + }, + "invalid_array_value_type": { + map[string]interface{}{ + "query_parameters_to_set": map[string]interface{}{ + "test": []interface{}{123}, + }, + }, + nil, + true, + }, + "single_value": { + map[string]interface{}{ + "query_parameters_to_set": map[string]interface{}{ + "param1": "value1", + "param2": "value2", + }, + }, + []*ext_core_v3.QueryParameter{ + { + Key: "param1", + Value: "value1", + }, + { + Key: "param2", + Value: "value2", + }, + }, + false, + }, + "array_values": { + map[string]interface{}{ + "query_parameters_to_set": map[string]interface{}{ + "param1": []interface{}{"value1", "value2"}, + "param2": []interface{}{"value3", "value4"}, + }, + }, + []*ext_core_v3.QueryParameter{ + { + Key: "param1", + Value: "value1", + }, + { + Key: "param1", + Value: "value2", + }, + { + Key: "param2", + Value: "value3", + }, + { + Key: "param2", + Value: "value4", + }, + }, + false, + }, + "mixed_values": { + map[string]interface{}{ + "query_parameters_to_set": map[string]interface{}{ + "param1": "single", + "param2": []interface{}{"multi1", "multi2"}, + }, + }, + []*ext_core_v3.QueryParameter{ + { + Key: "param1", + Value: "single", + }, + { + Key: "param2", + Value: "multi1", + }, + { + Key: "param2", + Value: "multi2", + }, + }, + false, + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + er := EvalResult{ + Decision: tc.decision, + } + + result, err := er.GetRequestQueryParametersToSet() + + if tc.wantErr { + if err == nil { + t.Fatal("Expected error but got nil") + } + } else { + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + + if len(result) != len(tc.exp) { + t.Fatalf("Expected %d parameters but got %d", len(tc.exp), len(result)) + } + + // sort first by key, then by value + + sort.Slice(result, func(i, j int) bool { + if result[i].Key == result[j].Key { + return result[i].Value < result[j].Value + } + return result[i].Key < result[j].Key + }) + + sort.Slice(tc.exp, func(i, j int) bool { + if tc.exp[i].Key == tc.exp[j].Key { + return tc.exp[i].Value < tc.exp[j].Value + } + return tc.exp[i].Key < tc.exp[j].Key + }) + + for i, param := range result { + if param.Key != tc.exp[i].Key || param.Value != tc.exp[i].Value { + t.Fatalf("Parameter mismatch at index %d. Expected %v but got %v", i, tc.exp[i], param) + } + } + } + }) + } +} + func TestGetRequestHTTPHeadersToRemove(t *testing.T) { tests := map[string]struct { decision interface{} diff --git a/internal/internal.go b/internal/internal.go index af001ac74..3ba2649a1 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -538,12 +538,21 @@ func (p *envoyExtAuthzGrpcServer) check(ctx context.Context, req interface{}) (* return nil, stop, internalErr } + var queryParamsToSet []*ext_core_v3.QueryParameter + queryParamsToSet, err = result.GetRequestQueryParametersToSet() + if err != nil { + err = errors.Wrap(err, "failed to get request query parameters to set") + internalErr = newInternalError(EnvoyAuthResultErr, err) + return nil, stop, internalErr + } + resp.HttpResponse = &ext_authz_v3.CheckResponse_OkResponse{ OkResponse: &ext_authz_v3.OkHttpResponse{ Headers: responseHeaders, HeadersToRemove: headersToRemove, ResponseHeadersToAdd: responseHeadersToAdd, QueryParametersToRemove: queryParamsToRemove, + QueryParametersToSet: queryParamsToSet, }, } } else { diff --git a/internal/internal_test.go b/internal/internal_test.go index 7ca08df5d..835bf4e93 100644 --- a/internal/internal_test.go +++ b/internal/internal_test.go @@ -11,6 +11,7 @@ import ( "net/http" "net/http/httptest" "reflect" + "sort" "strings" "sync" "testing" @@ -1562,6 +1563,84 @@ func TestCheckAllowObjectDecisionReqQueryParamsToRemove(t *testing.T) { } } +func TestCheckAllowObjectDecisionReqQueryParamsToSet(t *testing.T) { + var req ext_authz.CheckRequest + if err := util.Unmarshal([]byte(exampleAllowedRequest), &req); err != nil { + panic(err) + } + + module := ` + package envoy.authz + + default allow = true + + query_parameters_to_set := { + "foo": "value1", + "bar": ["value2", "value3"] + } + + result["allowed"] = allow + result["query_parameters_to_set"] = query_parameters_to_set` + + server := testAuthzServerWithModule(module, "envoy/authz/result", nil, withCustomLogger(&testPlugin{})) + ctx := context.Background() + output, err := server.Check(ctx, &req) + if err != nil { + t.Fatal(err) + } + + if output.Status.Code != int32(code.Code_OK) { + t.Fatalf("Expected request to be allowed but got: %v", output) + } + + response := output.GetOkResponse() + if response == nil { + t.Fatal("Expected OkHttpResponse struct but got nil") + } + + queryParams := response.GetQueryParametersToSet() + if len(queryParams) != 3 { + t.Fatalf("Expected three query params but got %v", len(queryParams)) + } + + expectedQueryParamsToSet := []*ext_core.QueryParameter{ + { + Key: "foo", + Value: "value1", + }, + { + Key: "bar", + Value: "value2", + }, + { + Key: "bar", + Value: "value3", + }, + } + + // sort first by key, then by value + + sort.Slice(queryParams, func(i, j int) bool { + if queryParams[i].Key == queryParams[j].Key { + return queryParams[i].Value < queryParams[j].Value + } + return queryParams[i].Key < queryParams[j].Key + }) + + sort.Slice(expectedQueryParamsToSet, func(i, j int) bool { + if expectedQueryParamsToSet[i].Key == expectedQueryParamsToSet[j].Key { + return expectedQueryParamsToSet[i].Value < expectedQueryParamsToSet[j].Value + } + return expectedQueryParamsToSet[i].Key < expectedQueryParamsToSet[j].Key + }) + + for i, param := range queryParams { + if !reflect.DeepEqual(expectedQueryParamsToSet[i], param) { + t.Fatalf("Expected query param %v but got %v", expectedQueryParamsToSet[i], param) + } + } +} + func TestCheckAllowObjectDecisionReqHeadersToRemove(t *testing.T) { var req ext_authz.CheckRequest if err := util.Unmarshal([]byte(exampleAllowedRequestParsedPath), &req); err != nil {