Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flows/actions/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (a *baseAction) Type() string { return a.Type_ }
func (a *baseAction) UUID() flows.ActionUUID { return a.UUID_ }

// Validate validates our action is valid
func (a *baseAction) Validate() error { return nil }
func (a *baseAction) Validate(bool) error { return nil }

// LocalizationUUID gets the UUID which identifies this object for localization
func (a *baseAction) LocalizationUUID() uuids.UUID { return uuids.UUID(a.UUID_) }
Expand Down
2 changes: 1 addition & 1 deletion flows/actions/call_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func NewCallWebhook(uuid flows.ActionUUID, method string, url string, headers ma
}

// Validate validates our action is valid
func (a *CallWebhookAction) Validate() error {
func (a *CallWebhookAction) Validate(strict bool) error {
for key := range a.Headers {
if !httpguts.ValidHeaderFieldName(key) {
return fmt.Errorf("header '%s' is not a valid HTTP header", key)
Expand Down
2 changes: 1 addition & 1 deletion flows/actions/remove_contact_groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func NewRemoveContactGroups(uuid flows.ActionUUID, groups []*assets.GroupReferen
}

// Validate validates our action is valid
func (a *RemoveContactGroupsAction) Validate() error {
func (a *RemoveContactGroupsAction) Validate(strict bool) error {
if a.AllGroups && len(a.Groups) > 0 {
return fmt.Errorf("can't specify specific groups when all_groups=true")
}
Expand Down
12 changes: 6 additions & 6 deletions flows/definition/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ type flow struct {
}

// NewFlow creates a new flow
func NewFlow(uuid assets.FlowUUID, name string, language i18n.Language, flowType flows.FlowType, revision int, expireAfter time.Duration, localization flows.Localization, nodes []flows.Node, ui json.RawMessage, a assets.Flow) (flows.Flow, error) {
func NewFlow(uuid assets.FlowUUID, name string, language i18n.Language, flowType flows.FlowType, revision int, expireAfter time.Duration, localization flows.Localization, nodes []flows.Node, ui json.RawMessage, a assets.Flow, strict bool) (flows.Flow, error) {
f := &flow{
uuid: uuid,
name: name,
Expand All @@ -74,7 +74,7 @@ func NewFlow(uuid assets.FlowUUID, name string, language i18n.Language, flowType
f.nodeMap[node.UUID()] = node
}

if err := f.validate(); err != nil {
if err := f.validate(strict); err != nil {
return nil, err
}

Expand Down Expand Up @@ -103,7 +103,7 @@ func (f *flow) ExpireAfter() time.Duration {
return f.expireAfter
}

func (f *flow) validate() error {
func (f *flow) validate(strict bool) error {
if len(f.nodes) > flows.MaxNodesPerFlow {
return fmt.Errorf("flow can't have more than %d nodes (has %d)", flows.MaxNodesPerFlow, len(f.nodes))
}
Expand All @@ -118,7 +118,7 @@ func (f *flow) validate() error {
}
seenUUIDs[uuids.UUID(node.UUID())] = true

if err := node.Validate(f, seenUUIDs); err != nil {
if err := node.Validate(f, seenUUIDs, strict); err != nil {
return fmt.Errorf("invalid node[uuid=%s]: %w", node.UUID(), err)
}
}
Expand Down Expand Up @@ -323,7 +323,7 @@ type flowEnvelope struct {
}

// ReadFlow reads a flow definition from the passed in byte array, migrating it to the spec version of the engine if necessary
func ReadFlow(data json.RawMessage, mc *migrations.Config) (flows.Flow, error) {
func ReadFlow(data []byte, mc *migrations.Config) (flows.Flow, error) {
return readFlow(data, mc, nil)
}

Expand Down Expand Up @@ -364,7 +364,7 @@ func readFlow(data json.RawMessage, mc *migrations.Config, a assets.Flow) (flows
e.Localization = make(localization)
}

return NewFlow(e.UUID, e.Name, e.Language, e.Type, e.Revision, time.Duration(e.ExpireAfterMinutes)*time.Minute, e.Localization, nodes, e.UI, a)
return NewFlow(e.UUID, e.Name, e.Language, e.Type, e.Revision, time.Duration(e.ExpireAfterMinutes)*time.Minute, e.Localization, nodes, e.UI, a, true)
}

// MarshalJSON marshals this flow into JSON
Expand Down
5 changes: 3 additions & 2 deletions flows/definition/flow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,9 @@ func TestNewFlow(t *testing.T) {
},
),
},
nil, // no UI
nil, // no asset
nil, // no UI
nil, // no asset
true, // strict validation
)
require.NoError(t, err)

Expand Down
18 changes: 10 additions & 8 deletions flows/definition/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,14 @@ func (n *node) Actions() []flows.Action { return n.actions }
func (n *node) Router() flows.Router { return n.router }
func (n *node) Exits() []flows.Exit { return n.exits }

func (n *node) Validate(flow flows.Flow, seenUUIDs map[uuids.UUID]bool) error {
if len(n.actions) > flows.MaxActionsPerNode {
return fmt.Errorf("node can't have more than %d actions (has %d)", flows.MaxActionsPerNode, len(n.actions))
}
if len(n.exits) > flows.MaxExitsPerNode {
return fmt.Errorf("node can't have more than %d exits (has %d)", flows.MaxExitsPerNode, len(n.exits))
func (n *node) Validate(flow flows.Flow, seenUUIDs map[uuids.UUID]bool, strict bool) error {
if strict {
if len(n.actions) > flows.MaxActionsPerNode {
return fmt.Errorf("node can't have more than %d actions (has %d)", flows.MaxActionsPerNode, len(n.actions))
}
if len(n.exits) > flows.MaxExitsPerNode {
return fmt.Errorf("node can't have more than %d exits (has %d)", flows.MaxExitsPerNode, len(n.exits))
}
}

// validate all the node's actions
Expand All @@ -59,14 +61,14 @@ func (n *node) Validate(flow flows.Flow, seenUUIDs map[uuids.UUID]bool) error {
}
seenUUIDs[uuids.UUID(action.UUID())] = true

if err := action.Validate(); err != nil {
if err := action.Validate(strict); err != nil {
return fmt.Errorf("invalid action[uuid=%s, type=%s]: %w", action.UUID(), action.Type(), err)
}
}

// check the router if there is one
if n.Router() != nil {
if err := n.Router().Validate(flow, n.Exits()); err != nil {
if err := n.Router().Validate(flow, n.Exits(), strict); err != nil {
return fmt.Errorf("invalid router: %w", err)
}
}
Expand Down
6 changes: 3 additions & 3 deletions flows/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ type Node interface {
Router() Router
Exits() []Exit

Validate(Flow, map[uuids.UUID]bool) error
Validate(Flow, map[uuids.UUID]bool, bool) error

EnumerateTemplates(Localization, func(Action, Router, i18n.Language, string))
EnumerateDependencies(Localization, func(Action, Router, i18n.Language, assets.Reference))
Expand All @@ -189,7 +189,7 @@ type Action interface {

UUID() ActionUUID
Execute(context.Context, Run, Step, ModifierCallback, EventCallback) error
Validate() error
Validate(bool) error
}

// Category is how routers map results to exits
Expand All @@ -209,7 +209,7 @@ type Router interface {
Categories() []Category
ResultName() string

Validate(Flow, []Exit) error
Validate(Flow, []Exit, bool) error
AllowTimeout() bool
Route(Run, Step, EventCallback) (ExitUUID, string, error)
RouteTimeout(Run, Step, EventCallback) (ExitUUID, error)
Expand Down
4 changes: 2 additions & 2 deletions flows/routers/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ func (r *baseRouter) EnumerateLocalizables(include func(uuids.UUID, string, []st
}
}

func (r *baseRouter) validate(flow flows.Flow, exits []flows.Exit) error {
if len(r.categories) > flows.MaxCategoriesPerRouter {
func (r *baseRouter) validate(flow flows.Flow, exits []flows.Exit, strict bool) error {
if strict && len(r.categories) > flows.MaxCategoriesPerRouter {
return fmt.Errorf("can't have more than %d categories (has %d)", flows.MaxCategoriesPerRouter, len(r.categories))
}

Expand Down
4 changes: 2 additions & 2 deletions flows/routers/random.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ func NewRandom(wait flows.Wait, resultName string, categories []flows.Category)
}

// Validate validates that the fields on this router are valid
func (r *RandomRouter) Validate(flow flows.Flow, exits []flows.Exit) error {
return r.validate(flow, exits)
func (r *RandomRouter) Validate(flow flows.Flow, exits []flows.Exit, strict bool) error {
return r.validate(flow, exits, strict)
}

// Route determines which exit to take from a node
Expand Down
12 changes: 6 additions & 6 deletions flows/routers/switch.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func NewCase(uuid uuids.UUID, type_ string, arguments []string, categoryUUID flo
}
}

func (c *Case) validate(r *SwitchRouter) error {
func (c *Case) validate(r *SwitchRouter, strict bool) error {
if !r.isValidCategory(c.CategoryUUID) {
return fmt.Errorf("category %s is not a valid category", c.CategoryUUID)
}
Expand All @@ -50,7 +50,7 @@ func (c *Case) validate(r *SwitchRouter) error {
return fmt.Errorf("%s is not a registered test function", c.Type)
}

if len(c.Arguments) > flows.MaxArgumentsPerCase {
if strict && len(c.Arguments) > flows.MaxArgumentsPerCase {
return fmt.Errorf("can't have more than %d arguments (has %d)", flows.MaxArgumentsPerCase, len(c.Arguments))
}

Expand Down Expand Up @@ -109,8 +109,8 @@ func NewSwitch(wait flows.Wait, resultName string, categories []flows.Category,
func (r *SwitchRouter) Cases() []*Case { return r.cases }

// Validate validates the arguments for this router
func (r *SwitchRouter) Validate(flow flows.Flow, exits []flows.Exit) error {
if len(r.cases) > flows.MaxCasesPerRouter {
func (r *SwitchRouter) Validate(flow flows.Flow, exits []flows.Exit, strict bool) error {
if strict && len(r.cases) > flows.MaxCasesPerRouter {
return fmt.Errorf("can't have more than %d cases (has %d)", flows.MaxCasesPerRouter, len(r.cases))
}

Expand All @@ -120,12 +120,12 @@ func (r *SwitchRouter) Validate(flow flows.Flow, exits []flows.Exit) error {
}

for _, c := range r.cases {
if err := c.validate(r); err != nil {
if err := c.validate(r, strict); err != nil {
return fmt.Errorf("invalid case[uuid=%s]: %s", c.UUID, err)
}
}

return r.validate(flow, exits)
return r.validate(flow, exits, strict)
}

// Route determines which exit to take from a node
Expand Down