diff --git a/flows/actions/base.go b/flows/actions/base.go index a9e43c0bb..e74ff76ee 100644 --- a/flows/actions/base.go +++ b/flows/actions/base.go @@ -68,7 +68,7 @@ func (a *baseAction) Type() string { return a.Type_ } func (a *baseAction) UUID() flows.ActionUUID { return a.UUID_ } // Validate validates our action is valid -func (a *baseAction) Validate() error { return nil } +func (a *baseAction) Validate(bool) error { return nil } // LocalizationUUID gets the UUID which identifies this object for localization func (a *baseAction) LocalizationUUID() uuids.UUID { return uuids.UUID(a.UUID_) } diff --git a/flows/actions/call_webhook.go b/flows/actions/call_webhook.go index 4e628e164..b7a78d586 100644 --- a/flows/actions/call_webhook.go +++ b/flows/actions/call_webhook.go @@ -74,7 +74,7 @@ func NewCallWebhook(uuid flows.ActionUUID, method string, url string, headers ma } // Validate validates our action is valid -func (a *CallWebhookAction) Validate() error { +func (a *CallWebhookAction) Validate(strict bool) error { for key := range a.Headers { if !httpguts.ValidHeaderFieldName(key) { return fmt.Errorf("header '%s' is not a valid HTTP header", key) diff --git a/flows/actions/remove_contact_groups.go b/flows/actions/remove_contact_groups.go index 581fbd312..009f0c9cd 100644 --- a/flows/actions/remove_contact_groups.go +++ b/flows/actions/remove_contact_groups.go @@ -48,7 +48,7 @@ func NewRemoveContactGroups(uuid flows.ActionUUID, groups []*assets.GroupReferen } // Validate validates our action is valid -func (a *RemoveContactGroupsAction) Validate() error { +func (a *RemoveContactGroupsAction) Validate(strict bool) error { if a.AllGroups && len(a.Groups) > 0 { return fmt.Errorf("can't specify specific groups when all_groups=true") } diff --git a/flows/definition/flow.go b/flows/definition/flow.go index ae50aa7f5..ce8acee25 100644 --- a/flows/definition/flow.go +++ b/flows/definition/flow.go @@ -54,7 +54,7 @@ type flow struct { } // NewFlow creates a new flow -func NewFlow(uuid assets.FlowUUID, name string, language i18n.Language, flowType flows.FlowType, revision int, expireAfter time.Duration, localization flows.Localization, nodes []flows.Node, ui json.RawMessage, a assets.Flow) (flows.Flow, error) { +func NewFlow(uuid assets.FlowUUID, name string, language i18n.Language, flowType flows.FlowType, revision int, expireAfter time.Duration, localization flows.Localization, nodes []flows.Node, ui json.RawMessage, a assets.Flow, strict bool) (flows.Flow, error) { f := &flow{ uuid: uuid, name: name, @@ -74,7 +74,7 @@ func NewFlow(uuid assets.FlowUUID, name string, language i18n.Language, flowType f.nodeMap[node.UUID()] = node } - if err := f.validate(); err != nil { + if err := f.validate(strict); err != nil { return nil, err } @@ -103,7 +103,7 @@ func (f *flow) ExpireAfter() time.Duration { return f.expireAfter } -func (f *flow) validate() error { +func (f *flow) validate(strict bool) error { if len(f.nodes) > flows.MaxNodesPerFlow { return fmt.Errorf("flow can't have more than %d nodes (has %d)", flows.MaxNodesPerFlow, len(f.nodes)) } @@ -118,7 +118,7 @@ func (f *flow) validate() error { } seenUUIDs[uuids.UUID(node.UUID())] = true - if err := node.Validate(f, seenUUIDs); err != nil { + if err := node.Validate(f, seenUUIDs, strict); err != nil { return fmt.Errorf("invalid node[uuid=%s]: %w", node.UUID(), err) } } @@ -323,7 +323,7 @@ type flowEnvelope struct { } // ReadFlow reads a flow definition from the passed in byte array, migrating it to the spec version of the engine if necessary -func ReadFlow(data json.RawMessage, mc *migrations.Config) (flows.Flow, error) { +func ReadFlow(data []byte, mc *migrations.Config) (flows.Flow, error) { return readFlow(data, mc, nil) } @@ -364,7 +364,7 @@ func readFlow(data json.RawMessage, mc *migrations.Config, a assets.Flow) (flows e.Localization = make(localization) } - return NewFlow(e.UUID, e.Name, e.Language, e.Type, e.Revision, time.Duration(e.ExpireAfterMinutes)*time.Minute, e.Localization, nodes, e.UI, a) + return NewFlow(e.UUID, e.Name, e.Language, e.Type, e.Revision, time.Duration(e.ExpireAfterMinutes)*time.Minute, e.Localization, nodes, e.UI, a, true) } // MarshalJSON marshals this flow into JSON diff --git a/flows/definition/flow_test.go b/flows/definition/flow_test.go index 63c3008e7..e9cf66fc5 100644 --- a/flows/definition/flow_test.go +++ b/flows/definition/flow_test.go @@ -310,8 +310,9 @@ func TestNewFlow(t *testing.T) { }, ), }, - nil, // no UI - nil, // no asset + nil, // no UI + nil, // no asset + true, // strict validation ) require.NoError(t, err) diff --git a/flows/definition/node.go b/flows/definition/node.go index 6385d13ef..7391be739 100644 --- a/flows/definition/node.go +++ b/flows/definition/node.go @@ -37,12 +37,14 @@ func (n *node) Actions() []flows.Action { return n.actions } func (n *node) Router() flows.Router { return n.router } func (n *node) Exits() []flows.Exit { return n.exits } -func (n *node) Validate(flow flows.Flow, seenUUIDs map[uuids.UUID]bool) error { - if len(n.actions) > flows.MaxActionsPerNode { - return fmt.Errorf("node can't have more than %d actions (has %d)", flows.MaxActionsPerNode, len(n.actions)) - } - if len(n.exits) > flows.MaxExitsPerNode { - return fmt.Errorf("node can't have more than %d exits (has %d)", flows.MaxExitsPerNode, len(n.exits)) +func (n *node) Validate(flow flows.Flow, seenUUIDs map[uuids.UUID]bool, strict bool) error { + if strict { + if len(n.actions) > flows.MaxActionsPerNode { + return fmt.Errorf("node can't have more than %d actions (has %d)", flows.MaxActionsPerNode, len(n.actions)) + } + if len(n.exits) > flows.MaxExitsPerNode { + return fmt.Errorf("node can't have more than %d exits (has %d)", flows.MaxExitsPerNode, len(n.exits)) + } } // validate all the node's actions @@ -59,14 +61,14 @@ func (n *node) Validate(flow flows.Flow, seenUUIDs map[uuids.UUID]bool) error { } seenUUIDs[uuids.UUID(action.UUID())] = true - if err := action.Validate(); err != nil { + if err := action.Validate(strict); err != nil { return fmt.Errorf("invalid action[uuid=%s, type=%s]: %w", action.UUID(), action.Type(), err) } } // check the router if there is one if n.Router() != nil { - if err := n.Router().Validate(flow, n.Exits()); err != nil { + if err := n.Router().Validate(flow, n.Exits(), strict); err != nil { return fmt.Errorf("invalid router: %w", err) } } diff --git a/flows/interfaces.go b/flows/interfaces.go index de11d4c44..6dbce7d9f 100644 --- a/flows/interfaces.go +++ b/flows/interfaces.go @@ -173,7 +173,7 @@ type Node interface { Router() Router Exits() []Exit - Validate(Flow, map[uuids.UUID]bool) error + Validate(Flow, map[uuids.UUID]bool, bool) error EnumerateTemplates(Localization, func(Action, Router, i18n.Language, string)) EnumerateDependencies(Localization, func(Action, Router, i18n.Language, assets.Reference)) @@ -189,7 +189,7 @@ type Action interface { UUID() ActionUUID Execute(context.Context, Run, Step, ModifierCallback, EventCallback) error - Validate() error + Validate(bool) error } // Category is how routers map results to exits @@ -209,7 +209,7 @@ type Router interface { Categories() []Category ResultName() string - Validate(Flow, []Exit) error + Validate(Flow, []Exit, bool) error AllowTimeout() bool Route(Run, Step, EventCallback) (ExitUUID, string, error) RouteTimeout(Run, Step, EventCallback) (ExitUUID, error) diff --git a/flows/routers/base.go b/flows/routers/base.go index 4926a450e..ade2ecfb2 100644 --- a/flows/routers/base.go +++ b/flows/routers/base.go @@ -90,8 +90,8 @@ func (r *baseRouter) EnumerateLocalizables(include func(uuids.UUID, string, []st } } -func (r *baseRouter) validate(flow flows.Flow, exits []flows.Exit) error { - if len(r.categories) > flows.MaxCategoriesPerRouter { +func (r *baseRouter) validate(flow flows.Flow, exits []flows.Exit, strict bool) error { + if strict && len(r.categories) > flows.MaxCategoriesPerRouter { return fmt.Errorf("can't have more than %d categories (has %d)", flows.MaxCategoriesPerRouter, len(r.categories)) } diff --git a/flows/routers/random.go b/flows/routers/random.go index 851d078a6..5a4e46e9a 100644 --- a/flows/routers/random.go +++ b/flows/routers/random.go @@ -29,8 +29,8 @@ func NewRandom(wait flows.Wait, resultName string, categories []flows.Category) } // Validate validates that the fields on this router are valid -func (r *RandomRouter) Validate(flow flows.Flow, exits []flows.Exit) error { - return r.validate(flow, exits) +func (r *RandomRouter) Validate(flow flows.Flow, exits []flows.Exit, strict bool) error { + return r.validate(flow, exits, strict) } // Route determines which exit to take from a node diff --git a/flows/routers/switch.go b/flows/routers/switch.go index 62719f053..df33d2677 100644 --- a/flows/routers/switch.go +++ b/flows/routers/switch.go @@ -41,7 +41,7 @@ func NewCase(uuid uuids.UUID, type_ string, arguments []string, categoryUUID flo } } -func (c *Case) validate(r *SwitchRouter) error { +func (c *Case) validate(r *SwitchRouter, strict bool) error { if !r.isValidCategory(c.CategoryUUID) { return fmt.Errorf("category %s is not a valid category", c.CategoryUUID) } @@ -50,7 +50,7 @@ func (c *Case) validate(r *SwitchRouter) error { return fmt.Errorf("%s is not a registered test function", c.Type) } - if len(c.Arguments) > flows.MaxArgumentsPerCase { + if strict && len(c.Arguments) > flows.MaxArgumentsPerCase { return fmt.Errorf("can't have more than %d arguments (has %d)", flows.MaxArgumentsPerCase, len(c.Arguments)) } @@ -109,8 +109,8 @@ func NewSwitch(wait flows.Wait, resultName string, categories []flows.Category, func (r *SwitchRouter) Cases() []*Case { return r.cases } // Validate validates the arguments for this router -func (r *SwitchRouter) Validate(flow flows.Flow, exits []flows.Exit) error { - if len(r.cases) > flows.MaxCasesPerRouter { +func (r *SwitchRouter) Validate(flow flows.Flow, exits []flows.Exit, strict bool) error { + if strict && len(r.cases) > flows.MaxCasesPerRouter { return fmt.Errorf("can't have more than %d cases (has %d)", flows.MaxCasesPerRouter, len(r.cases)) } @@ -120,12 +120,12 @@ func (r *SwitchRouter) Validate(flow flows.Flow, exits []flows.Exit) error { } for _, c := range r.cases { - if err := c.validate(r); err != nil { + if err := c.validate(r, strict); err != nil { return fmt.Errorf("invalid case[uuid=%s]: %s", c.UUID, err) } } - return r.validate(flow, exits) + return r.validate(flow, exits, strict) } // Route determines which exit to take from a node