diff --git a/README_ZH.md b/README_ZH.md deleted file mode 100644 index aaf8f7e4c..000000000 --- a/README_ZH.md +++ /dev/null @@ -1,59 +0,0 @@ - -# seata-go: 简单的可扩展自主事务架构(Go版本) - -[![Build Status](https://github.com/seata/seata/workflows/build/badge.svg?branch=develop)](https://github.com/seata/seata/actions) -[![license](https://img.shields.io/github/license/seata/seata.svg)](https://www.apache.org/licenses/LICENSE-2.0.html) - -[English US](./README.md) - -## 什么是 seata-go? - -Seata 是一个非常成熟的分布式事务框架,在 Java 领域是事实上的分布式事务技术标准平台。Seata-go 是 seata 多语言生态中的 Go 语言实现版本,实现了 Java 和 Go 之间的互通,让 Go 开发者也能使用 seata-go 来实现分布式事务。请访问[Seata 官网](https://seata.io/zh-cn/)查看快速开始和文档。 - -Seata-go 的原理和 Seata-java 保持一致,都是由 TM、RM 和 TC 组成,其中 TC 的功能复用 Java 的,TM 和 RM 功能后面会和 Seata-java 对齐,整体流程如下: - -![](https://user-images.githubusercontent.com/68344696/145942191-7a2d469f-94c8-4cd2-8c7e-46ad75683636.png) - -## 待办事项 - -- [x] TCC -- [ ] XA -- [x] AT - - [x] Insert SQL - - [x] Delete SQL - - [x] Insert on update SQL - - [x] Multi update SQL - - [x] Multi delete SQL - - [x] Select for update SQL - - [x] Update SQL -- [ ] SAGA -- [x] TM -- [x] RPC 通信 -- [x] 事务防悬挂 - - [x] 手动方式 - - [x] 代理数据源方式 -- [x] 空补偿 - - [x] 手动方式 - - [x] 代理数据源方式 -- [ ] 配置中心 - - [x] 配置文件 -- [ ] 注册中心 -- [ ] Metric 监控 -- [x] 压缩算法 -- [x] Sample 例子 - - -## 如何运行项目? - -关于如何使用和集成 seata-go 的示例,可以参考 [seata/seata-go-samples](https://github.com/seata/seata-go-samples) - - -## 如何给Seata-go贡献代码? - -Seata-go 目前正在建设阶段,欢迎行业同仁入群参与其中,与我们一起推动 seata-go 的建设!如果你想给 seata-go 贡献代码,可以参考 **[代码贡献规范](./CONTRIBUTING.md)** 文档来了解社区的规范,也可以加入我们的社区钉钉群:33069364,一起沟通交流! - -![image](https://user-images.githubusercontent.com/38887641/210141444-0ba6b11d-16e6-48af-945b-cb99ecfa70ef.png) - -## 协议 - -Seata-go 使用 Apache 许可证2.0版本,请参阅 LICENSE 文件了解更多。 diff --git a/go.mod b/go.mod index 429499820..0458e17e4 100644 --- a/go.mod +++ b/go.mod @@ -36,7 +36,7 @@ require ( github.com/google/cel-go v0.18.0 github.com/mattn/go-sqlite3 v1.14.19 github.com/robertkrimen/otto v0.4.0 - golang.org/x/sync v0.16.0 + golang.org/x/sync v0.11.0 google.golang.org/protobuf v1.33.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -91,7 +91,7 @@ require ( github.com/yusufpapurcu/wmi v1.2.2 // indirect go.uber.org/multierr v1.8.0 // indirect golang.org/x/arch v0.3.0 // indirect - golang.org/x/text v0.27.0 // indirect + golang.org/x/text v0.14.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect gopkg.in/sourcemap.v1 v1.0.5 // indirect ) @@ -108,7 +108,7 @@ require ( golang.org/x/crypto v0.17.0 // indirect golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 // indirect golang.org/x/net v0.10.0 // indirect - golang.org/x/sys v0.32.0 // indirect + golang.org/x/sys v0.25.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20230803162519-f966b187b2e5 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230803162519-f966b187b2e5 // indirect vimagination.zapto.org/memio v0.0.0-20200222190306-588ebc67b97d // indirect diff --git a/go.sum b/go.sum index 61c5eafdb..1cbc21368 100644 --- a/go.sum +++ b/go.sum @@ -954,8 +954,8 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= -golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= +golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -1033,8 +1033,8 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= -golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= +golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -1047,8 +1047,8 @@ golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= -golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/goimports.sh b/goimports.sh index 5fb713be1..485957355 100755 --- a/goimports.sh +++ b/goimports.sh @@ -16,11 +16,11 @@ # # format go imports style -go install golang.org/x/tools/cmd/goimports +go install golang.org/x/tools/cmd/goimports@v0.24.1 goimports -local github.com/seata/seata-go -w . # format licence style -go install github.com/apache/skywalking-eyes/cmd/license-eye@latest +go install github.com/apache/skywalking-eyes/cmd/license-eye@v0.6.0 license-eye header fix # check dependency licence is valid license-eye dependency check diff --git a/pkg/client/config.go b/pkg/client/config.go index c4fce0f2a..8ca38491e 100644 --- a/pkg/client/config.go +++ b/pkg/client/config.go @@ -20,18 +20,20 @@ package client import ( "flag" "fmt" - "github.com/seata/seata-go/pkg/saga" "io/ioutil" "os" "path/filepath" "runtime" "strings" + "github.com/seata/seata-go/pkg/saga" + "github.com/knadh/koanf" "github.com/knadh/koanf/parsers/json" "github.com/knadh/koanf/parsers/toml" "github.com/knadh/koanf/parsers/yaml" "github.com/knadh/koanf/providers/rawbytes" + "github.com/seata/seata-go/pkg/discovery" "github.com/seata/seata-go/pkg/datasource/sql" diff --git a/pkg/protocol/codec/codec.go b/pkg/protocol/codec/codec.go index 35321e270..c00ab016c 100644 --- a/pkg/protocol/codec/codec.go +++ b/pkg/protocol/codec/codec.go @@ -114,6 +114,7 @@ func (c *CodecManager) Encode(codecType CodecType, in interface{}) []byte { func Init() { // Global + GetCodecManager().RegisterCodec(CodecTypeSeata, &GlobalReportRequestCodec{}) GetCodecManager().RegisterCodec(CodecTypeSeata, &GlobalReportResponseCodec{}) GetCodecManager().RegisterCodec(CodecTypeSeata, &GlobalBeginRequestCodec{}) GetCodecManager().RegisterCodec(CodecTypeSeata, &GlobalBeginResponseCodec{}) diff --git a/pkg/protocol/codec/global_report_request_codec.go b/pkg/protocol/codec/global_report_request_codec.go index ec045ba82..cf05df20b 100644 --- a/pkg/protocol/codec/global_report_request_codec.go +++ b/pkg/protocol/codec/global_report_request_codec.go @@ -52,5 +52,6 @@ func (g *GlobalReportRequestCodec) Encode(in interface{}) []byte { // GetMessageType get global report request's message type func (g *GlobalReportRequestCodec) GetMessageType() message.MessageType { - return message.MessageTypeGlobalReportResult + // must be the request type, not the result type + return message.MessageTypeGlobalReport } diff --git a/pkg/remoting/getty/rpc_client.go b/pkg/remoting/getty/rpc_client.go index 601064ef4..7311876b8 100644 --- a/pkg/remoting/getty/rpc_client.go +++ b/pkg/remoting/getty/rpc_client.go @@ -25,6 +25,7 @@ import ( getty "github.com/apache/dubbo-getty" gxsync "github.com/dubbogo/gost/sync" + "github.com/seata/seata-go/pkg/discovery" "github.com/seata/seata-go/pkg/protocol/codec" "github.com/seata/seata-go/pkg/remoting/config" diff --git a/pkg/remoting/loadbalance/random_loadbalance_test.go b/pkg/remoting/loadbalance/random_loadbalance_test.go index 5db9c8825..e63a74cba 100644 --- a/pkg/remoting/loadbalance/random_loadbalance_test.go +++ b/pkg/remoting/loadbalance/random_loadbalance_test.go @@ -23,8 +23,9 @@ import ( "testing" "github.com/golang/mock/gomock" - "github.com/seata/seata-go/pkg/remoting/mock" "github.com/stretchr/testify/assert" + + "github.com/seata/seata-go/pkg/remoting/mock" ) func TestRandomLoadBalance_Normal(t *testing.T) { diff --git a/pkg/remoting/loadbalance/xid_loadbalance_test.go b/pkg/remoting/loadbalance/xid_loadbalance_test.go index d361f338a..cd47cdd8b 100644 --- a/pkg/remoting/loadbalance/xid_loadbalance_test.go +++ b/pkg/remoting/loadbalance/xid_loadbalance_test.go @@ -22,8 +22,9 @@ import ( "testing" "github.com/golang/mock/gomock" - "github.com/seata/seata-go/pkg/remoting/mock" "github.com/stretchr/testify/assert" + + "github.com/seata/seata-go/pkg/remoting/mock" ) func TestXidLoadBalance(t *testing.T) { diff --git a/pkg/rm/init.go b/pkg/rm/init.go index 469cdfaa6..a88d36f20 100644 --- a/pkg/rm/init.go +++ b/pkg/rm/init.go @@ -30,3 +30,8 @@ type RmConfig struct { func InitRm(cfg RmConfig) { rmConfig = cfg } + +// GetRmAppAndGroup returns current RM applicationId and txServiceGroup +func GetRmAppAndGroup() (string, string) { + return rmConfig.ApplicationID, rmConfig.TxServiceGroup +} diff --git a/pkg/rm/rm_remoting.go b/pkg/rm/rm_remoting.go index 1479e93f2..3b0b01e2e 100644 --- a/pkg/rm/rm_remoting.go +++ b/pkg/rm/rm_remoting.go @@ -25,6 +25,7 @@ import ( "github.com/seata/seata-go/pkg/protocol/message" "github.com/seata/seata-go/pkg/remoting/getty" + serrors "github.com/seata/seata-go/pkg/util/errors" "github.com/seata/seata-go/pkg/util/log" ) @@ -84,7 +85,11 @@ func (r *RMRemoting) BranchReport(param BranchReportParam) error { } if err = isReportSuccess(resp); err != nil { - log.Errorf("BranchReport response error: %v, res %v", err.Error(), resp) + if seataErr, ok := err.(*serrors.SeataError); ok { + log.Warnf("BranchReport response error: code=%v msg=%s", seataErr.Code, seataErr.Message) + } else { + log.Errorf("BranchReport response error: %v, res %v", err.Error(), resp) + } return err } @@ -156,7 +161,12 @@ func isRegisterSuccess(response interface{}) bool { func isReportSuccess(response interface{}) error { if res, ok := response.(message.BranchReportResponse); ok { if res.ResultCode == message.ResultCodeFailed { - return fmt.Errorf(res.Msg) + code := res.TransactionErrorCode + if code == serrors.TransactionErrorCodeBranchTransactionNotExist || int(code) == 120 { + log.Debugf("BranchReport received TransactionErrorCode %d (treated as BranchTransactionNotExist), ignoring", code) + return nil + } + return serrors.New(code, res.Msg, nil) } } else { return ErrBranchReportResponseFault diff --git a/pkg/saga/rm/saga_resource.go b/pkg/saga/rm/saga_resource.go index 2b145a4dc..33bfbbca8 100644 --- a/pkg/saga/rm/saga_resource.go +++ b/pkg/saga/rm/saga_resource.go @@ -19,6 +19,7 @@ package rm import ( "fmt" + "github.com/seata/seata-go/pkg/protocol/branch" ) diff --git a/pkg/saga/rm/state_machine_engine_holder.go b/pkg/saga/rm/state_machine_engine_holder.go index 6aaa4ee00..c99159407 100644 --- a/pkg/saga/rm/state_machine_engine_holder.go +++ b/pkg/saga/rm/state_machine_engine_holder.go @@ -18,8 +18,9 @@ package rm import ( - "github.com/seata/seata-go/pkg/saga/statemachine/engine" "sync" + + "github.com/seata/seata-go/pkg/saga/statemachine/engine" ) var ( diff --git a/pkg/saga/runtime/engine_registry.go b/pkg/saga/runtime/engine_registry.go new file mode 100644 index 000000000..e4213b791 --- /dev/null +++ b/pkg/saga/runtime/engine_registry.go @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package runtime + +import ( + "context" + "sync" + + "github.com/seata/seata-go/pkg/saga/statemachine/statelang" +) + +// Engine is a minimal facade for state machine operations needed by Saga RM +type Engine interface { + Forward(ctx context.Context, stateMachineInstId string, replaceParams map[string]interface{}) (statelang.StateMachineInstance, error) + Compensate(ctx context.Context, stateMachineInstId string, replaceParams map[string]interface{}) (statelang.StateMachineInstance, error) + ReloadStateMachineInstance(ctx context.Context, instId string) (statelang.StateMachineInstance, error) +} + +var ( + mu sync.RWMutex + engineRef Engine +) + +func SetEngine(e Engine) { + mu.Lock() + engineRef = e + mu.Unlock() +} + +func GetEngine() Engine { + mu.RLock() + e := engineRef + mu.RUnlock() + return e +} diff --git a/pkg/saga/statemachine/constant/constant.go b/pkg/saga/statemachine/constant/constant.go index 7234171f4..3d22e3863 100644 --- a/pkg/saga/statemachine/constant/constant.go +++ b/pkg/saga/statemachine/constant/constant.go @@ -65,6 +65,7 @@ const ( VarNameCurrentCompensateTriggerState string = "_is_compensating_" VarNameCurrentCompensationHolder string = "_current_compensation_holder_" VarNameFirstCompensationStateStarted string = "_first_compensation_state_started" + VarNameNoCompensation string = "_no_compensation_case_" VarNameCurrentLoopContextHolder string = "_current_loop_context_holder_" VarNameRetriedStateInstId string = "_retried_state_instance_id" VarNameIsForSubStatMachineForward string = "_is_for_sub_statemachine_forward_" diff --git a/pkg/saga/statemachine/engine/config/bootstrap.go b/pkg/saga/statemachine/engine/config/bootstrap.go new file mode 100644 index 000000000..7f47b0bad --- /dev/null +++ b/pkg/saga/statemachine/engine/config/bootstrap.go @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package config + +import ( + "fmt" + + "github.com/seata/seata-go/pkg/saga/statemachine/constant" + "github.com/seata/seata-go/pkg/saga/statemachine/engine/pcext" + "github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl" + "github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl/handlers" + "github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl/process" +) + +// bootstrapProcessWiring wires default handlers/routers and router handler into BusinessProcessor +func (c *DefaultStateMachineConfig) bootstrapProcessWiring() error { + // Process handler for state machine + smHandler := pcext.NewStateMachineProcessHandler() + // Register ServiceTask handler by default + smHandler.RegistryStateHandler(constant.StateTypeServiceTask, handlers.NewServiceTaskStateHandler()) + + // Router for state machine + smRouter := &pcext.StateMachineProcessRouter{} + smRouter.InitDefaultStateRouters() + + // Default router handler binds event publisher and process routers + drh := &process_ctrl.DefaultRouterHandler{} + drh.SetEventPublisher(c.EventPublisher()) + drh.SetProcessRouters(map[string]process_ctrl.ProcessRouter{ + string(process.StateLang): smRouter, + }) + + // Register into BusinessProcessor + pcImpl, ok := c.processController.(*process_ctrl.ProcessControllerImpl) + if !ok { + return fmt.Errorf("ProcessController is not an instance of ProcessControllerImpl") + } + bp := pcImpl.BusinessProcessor() + // need concrete DefaultBusinessProcessor to call Registry APIs + dbp, ok := bp.(*process_ctrl.DefaultBusinessProcessor) + if !ok { + return fmt.Errorf("BusinessProcessor is not DefaultBusinessProcessor, got %T", bp) + } + dbp.RegistryProcessHandler(process.StateLang, smHandler) + dbp.RegistryRouterHandler(process.StateLang, drh) + return nil +} diff --git a/pkg/saga/statemachine/engine/config/default_statemachine_config.go b/pkg/saga/statemachine/engine/config/default_statemachine_config.go index 12225330a..9d949da84 100644 --- a/pkg/saga/statemachine/engine/config/default_statemachine_config.go +++ b/pkg/saga/statemachine/engine/config/default_statemachine_config.go @@ -21,20 +21,24 @@ import ( "context" "encoding/json" "fmt" - "github.com/seata/seata-go/pkg/saga/statemachine/engine/repo/repository" - "github.com/seata/seata-go/pkg/saga/statemachine/engine/strategy" - "gopkg.in/yaml.v3" "log" "os" "path/filepath" "strings" "sync" + "gopkg.in/yaml.v3" + + "github.com/seata/seata-go/pkg/protocol/branch" + baserm "github.com/seata/seata-go/pkg/rm" + sagarm "github.com/seata/seata-go/pkg/saga/rm" "github.com/seata/seata-go/pkg/saga/statemachine/engine" "github.com/seata/seata-go/pkg/saga/statemachine/engine/expr" "github.com/seata/seata-go/pkg/saga/statemachine/engine/invoker" "github.com/seata/seata-go/pkg/saga/statemachine/engine/repo" + "github.com/seata/seata-go/pkg/saga/statemachine/engine/repo/repository" "github.com/seata/seata-go/pkg/saga/statemachine/engine/sequence" + "github.com/seata/seata-go/pkg/saga/statemachine/engine/strategy" "github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl" "github.com/seata/seata-go/pkg/saga/statemachine/store" ) @@ -89,6 +93,14 @@ type DefaultStateMachineConfig struct { statusDecisionStrategy engine.StatusDecisionStrategy seqGenerator sequence.SeqGenerator componentLock *sync.Mutex + + enableAsync bool + + // runtime store & tc options + storeEnabled bool + storeType string + storeDSN string + tcEnabled bool } func (c *DefaultStateMachineConfig) ComponentLock() *sync.Mutex { @@ -247,6 +259,14 @@ func (c *DefaultStateMachineConfig) GetServiceInvokeTimeout() int { return c.serviceInvokeTimeout } +func (c *DefaultStateMachineConfig) EnableAsync() bool { + return c.enableAsync +} + +func (c *DefaultStateMachineConfig) SetEnableAsync(enable bool) { + c.enableAsync = enable +} + func (c *DefaultStateMachineConfig) IsSagaRetryPersistModeUpdate() bool { return c.sagaRetryPersistModeUpdate } @@ -342,11 +362,16 @@ type ConfigFileParams struct { ServiceInvokeTimeout int `json:"service_invoke_timeout" yaml:"service_invoke_timeout"` Charset string `json:"charset" yaml:"charset"` DefaultTenantId string `json:"default_tenant_id" yaml:"default_tenant_id"` + EnableAsync bool `json:"enable_async" yaml:"enable_async"` SagaRetryPersistModeUpdate bool `json:"saga_retry_persist_mode_update" yaml:"saga_retry_persist_mode_update"` SagaCompensatePersistModeUpdate bool `json:"saga_compensate_persist_mode_update" yaml:"saga_compensate_persist_mode_update"` SagaBranchRegisterEnable bool `json:"saga_branch_register_enable" yaml:"saga_branch_register_enable"` RmReportSuccessEnable bool `json:"rm_report_success_enable" yaml:"rm_report_success_enable"` StateMachineResources []string `json:"state_machine_resources" yaml:"state_machine_resources"` + StoreEnabled bool `json:"store_enabled" yaml:"store_enabled"` + StoreType string `json:"store_type" yaml:"store_type"` + StoreDSN string `json:"store_dsn" yaml:"store_dsn"` + TCEnabled bool `json:"tc_enabled" yaml:"tc_enabled"` } func (c *DefaultStateMachineConfig) LoadConfig(configPath string) error { @@ -392,6 +417,7 @@ func (c *DefaultStateMachineConfig) applyConfigFileParams(rc *ConfigFileParams) if rc.DefaultTenantId != "" { c.defaultTenantId = rc.DefaultTenantId } + c.enableAsync = rc.EnableAsync c.sagaRetryPersistModeUpdate = rc.SagaRetryPersistModeUpdate c.sagaCompensatePersistModeUpdate = rc.SagaCompensatePersistModeUpdate c.sagaBranchRegisterEnable = rc.SagaBranchRegisterEnable @@ -399,6 +425,14 @@ func (c *DefaultStateMachineConfig) applyConfigFileParams(rc *ConfigFileParams) if len(rc.StateMachineResources) > 0 { c.stateMachineResources = rc.StateMachineResources } + c.storeEnabled = rc.StoreEnabled + if rc.StoreType != "" { + c.storeType = rc.StoreType + } + if rc.StoreDSN != "" { + c.storeDSN = rc.StoreDSN + } + c.tcEnabled = rc.TCEnabled } func (c *DefaultStateMachineConfig) registerEventConsumers() error { @@ -432,6 +466,14 @@ func (c *DefaultStateMachineConfig) Init() error { return fmt.Errorf("initialize service invokers failed: %w", err) } + if err := c.bootstrapProcessWiring(); err != nil { + return fmt.Errorf("bootstrap wiring failed: %w", err) + } + + if err := c.SetupStoresFromConfig(); err != nil { + return fmt.Errorf("setup stores from config failed: %w", err) + } + if err := c.registerEventConsumers(); err != nil { return fmt.Errorf("register event consumers failed: %w", err) } @@ -442,6 +484,19 @@ func (c *DefaultStateMachineConfig) Init() error { } } + if c.tcEnabled { + sagarm.InitSaga() + app, group := baserm.GetRmAppAndGroup() + if app != "" && group != "" { + if mgr := baserm.GetRmCacheInstance().GetResourceManager(branch.BranchTypeSAGA); mgr != nil { + resource := &sagarm.SagaResource{} + resource.SetApplicationId(app) + resource.SetResourceGroupId(group) + _ = mgr.RegisterResource(resource) + } + } + } + if err := c.Validate(); err != nil { return fmt.Errorf("configuration validation failed: %w", err) } @@ -641,6 +696,7 @@ func NewDefaultStateMachineConfig(opts ...Option) (*DefaultStateMachineConfig, e c.stateMachineRepository = repository.GetStateMachineRepositoryImpl() c.stateLogRepository = repository.NewStateLogRepositoryImpl() + repository.GetStateMachineRepositoryImpl().SetDefaultTenantId(c.defaultTenantId) c.syncProcessCtrlEventPublisher = process_ctrl.NewProcessCtrlEventPublisher(c.syncEventBus) c.asyncProcessCtrlEventPublisher = process_ctrl.NewProcessCtrlEventPublisher(c.asyncEventBus) @@ -648,6 +704,16 @@ func NewDefaultStateMachineConfig(opts ...Option) (*DefaultStateMachineConfig, e for _, opt := range opts { opt(c) } + repository.GetStateMachineRepositoryImpl().SetDefaultTenantId(c.defaultTenantId) + + if _, statErr := os.Stat("config.yaml"); statErr == nil { + if err := c.LoadConfig("config.yaml"); err == nil { + repository.GetStateMachineRepositoryImpl().SetDefaultTenantId(c.defaultTenantId) + log.Printf("Successfully loaded config from config.yaml") + } else { + log.Printf("Failed to load config file (using default/env values): %v", err) + } + } if err := c.Init(); err != nil { return nil, fmt.Errorf("failed to initialize state machine config: %w", err) @@ -708,6 +774,12 @@ func WithStateMachineResources(paths []string) Option { } } +func WithEnableAsync(enable bool) Option { + return func(c *DefaultStateMachineConfig) { + c.enableAsync = enable + } +} + func WithStateLogRepository(logRepo repo.StateLogRepository) Option { return func(c *DefaultStateMachineConfig) { c.stateLogRepository = logRepo diff --git a/pkg/saga/statemachine/engine/config/default_statemachine_config_test.go b/pkg/saga/statemachine/engine/config/default_statemachine_config_test.go index 342f5c3b7..748869d37 100644 --- a/pkg/saga/statemachine/engine/config/default_statemachine_config_test.go +++ b/pkg/saga/statemachine/engine/config/default_statemachine_config_test.go @@ -18,14 +18,16 @@ package config import ( - "github.com/pkg/errors" - "github.com/seata/seata-go/pkg/saga/statemachine/statelang" "io" "os" "path/filepath" "reflect" "testing" + "github.com/pkg/errors" + + "github.com/seata/seata-go/pkg/saga/statemachine/statelang" + "github.com/stretchr/testify/assert" ) @@ -60,7 +62,7 @@ func TestDefaultStateMachineConfig_LoadValidJSON(t *testing.T) { assert.NoError(t, err, "Failed to initialize config") assert.NotNil(t, config, "config is nil") - smo, err := config.stateMachineRepository.GetStateMachineByNameAndTenantId("OrderSaga", "") + smo, err := config.stateMachineRepository.GetStateMachineByNameAndTenantId("OrderSaga", config.GetDefaultTenantId()) assert.NoError(t, err) assert.NotNil(t, smo, "State machine definition should not be nil") @@ -77,7 +79,7 @@ func TestDefaultStateMachineConfig_LoadValidYAML(t *testing.T) { assert.NoError(t, err, "Failed to initialize config") assert.NotNil(t, config, "config is nil") - smo, err := config.stateMachineRepository.GetStateMachineByNameAndTenantId("OrderSaga", "") + smo, err := config.stateMachineRepository.GetStateMachineByNameAndTenantId("OrderSaga", config.GetDefaultTenantId()) assert.NoError(t, err) assert.NotNil(t, smo, "State machine definition should not be nil (YAML)") @@ -100,7 +102,7 @@ func TestGetStateMachineDefinition_Exists(t *testing.T) { config, _ := NewDefaultStateMachineConfig(WithConfigPath(filepath.Join("testdata", "saga_config.json"))) - smo, err := config.stateMachineRepository.GetStateMachineByNameAndTenantId("OrderSaga", "") + smo, err := config.stateMachineRepository.GetStateMachineByNameAndTenantId("OrderSaga", config.GetDefaultTenantId()) assert.NoError(t, err) assert.NotNil(t, smo) assert.Equal(t, "1.0", smo.Version(), "The version number should be correct") @@ -110,7 +112,7 @@ func TestGetNonExistentStateMachine(t *testing.T) { os.Unsetenv("SEATA_STATE_MACHINE_RESOURCES") config, _ := NewDefaultStateMachineConfig() - smo, err := config.stateMachineRepository.GetStateMachineByNameAndTenantId("NonExistent", "") + smo, err := config.stateMachineRepository.GetStateMachineByNameAndTenantId("NonExistent", config.GetDefaultTenantId()) assert.Error(t, err) assert.True(t, smo == nil || reflect.ValueOf(smo).IsZero(), "An unloaded state machine should return nil/zero") } diff --git a/pkg/saga/statemachine/engine/config/noop_store.go b/pkg/saga/statemachine/engine/config/noop_store.go index f40f6278a..5b2bf9f27 100644 --- a/pkg/saga/statemachine/engine/config/noop_store.go +++ b/pkg/saga/statemachine/engine/config/noop_store.go @@ -19,6 +19,7 @@ package config import ( "context" + "github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl" "github.com/seata/seata-go/pkg/saga/statemachine/statelang" ) diff --git a/pkg/saga/statemachine/engine/config/store_bootstrap.go b/pkg/saga/statemachine/engine/config/store_bootstrap.go new file mode 100644 index 000000000..667a9af0a --- /dev/null +++ b/pkg/saga/statemachine/engine/config/store_bootstrap.go @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package config + +import ( + "database/sql" + "fmt" + + "github.com/seata/seata-go/pkg/saga/statemachine/engine/repo/repository" + dbstore "github.com/seata/seata-go/pkg/saga/statemachine/store/db" + sagaTm "github.com/seata/seata-go/pkg/saga/tm" +) + +// SetupStoresFromConfig wires DB-backed StateLogStore/StateLangStore and Saga TM template +// according to runtime options loaded into DefaultStateMachineConfig. +// It keeps Go style (no DI/SPI), and is safe to call multiple times. +func (c *DefaultStateMachineConfig) SetupStoresFromConfig() error { + if !c.storeEnabled { + // keep Noop stores for in-memory usage + return nil + } + if c.storeType == "" || c.storeDSN == "" { + return fmt.Errorf("store_enabled=true but store_type/dsn not provided") + } + + driver := c.storeType + if driver == "sqlite" || driver == "sqlite3" { + driver = "sqlite3" + } + + db, err := sql.Open(driver, c.storeDSN) + if err != nil { + return err + } + if err := db.Ping(); err != nil { + return fmt.Errorf("db ping failed: %w", err) + } + + // build stores with default table prefix `seata_` + lang := dbstore.NewStateLangStore(db, "seata_") + logStore := dbstore.NewStateLogStore(db, "seata_") + + // inject transactional template only when tc is enabled + if c.tcEnabled { + // keep minimal default template; client init is handled externally by user + var tmpl sagaTm.SagaTransactionalTemplate = &sagaTm.DefaultSagaTransactionalTemplate{} + logStore.SetSagaTransactionalTemplate(tmpl) + } + + // set into config & repository + c.SetStateLangStore(lang) + c.SetStateLogStore(logStore) + repository.GetStateMachineRepositoryImpl().SetStateLangStore(lang) + + return nil +} diff --git a/pkg/saga/statemachine/engine/core/process_ctrl_statemachine_engine.go b/pkg/saga/statemachine/engine/core/process_ctrl_statemachine_engine.go index 432e1a595..5d433eac8 100644 --- a/pkg/saga/statemachine/engine/core/process_ctrl_statemachine_engine.go +++ b/pkg/saga/statemachine/engine/core/process_ctrl_statemachine_engine.go @@ -20,7 +20,9 @@ package core import ( "context" "fmt" - "github.com/pkg/errors" + "time" + + "github.com/seata/seata-go/pkg/saga/runtime" "github.com/seata/seata-go/pkg/saga/statemachine/constant" "github.com/seata/seata-go/pkg/saga/statemachine/engine" "github.com/seata/seata-go/pkg/saga/statemachine/engine/config" @@ -31,9 +33,9 @@ import ( "github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl/process" "github.com/seata/seata-go/pkg/saga/statemachine/statelang" "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" + tmctx "github.com/seata/seata-go/pkg/tm" seataErrors "github.com/seata/seata-go/pkg/util/errors" "github.com/seata/seata-go/pkg/util/log" - "time" ) type ProcessCtrlStateMachineEngine struct { @@ -47,9 +49,13 @@ func NewProcessCtrlStateMachineEngine() (*ProcessCtrlStateMachineEngine, error) return nil, fmt.Errorf("failed to create state machine configuration: %w", err) } - return &ProcessCtrlStateMachineEngine{ + engine := &ProcessCtrlStateMachineEngine{ StateMachineConfig: cfg, - }, nil + } + + runtime.SetEngine(engine) + + return engine, nil } func (p ProcessCtrlStateMachineEngine) Start(ctx context.Context, stateMachineName string, tenantId string, @@ -147,6 +153,15 @@ func (p ProcessCtrlStateMachineEngine) ReloadStateMachineInstance(ctx context.Co func (p ProcessCtrlStateMachineEngine) startInternal(ctx context.Context, stateMachineName string, tenantId string, businessKey string, startParams map[string]interface{}, async bool, callback engine.CallBack) (statelang.StateMachineInstance, error) { + if !tmctx.IsSeataContext(ctx) { + ctx = tmctx.InitSeataContext(ctx) + } + + if async && !p.StateMachineConfig.EnableAsync() { + return nil, exception.NewEngineExecutionException(seataErrors.AsynchronousStartDisabled, + "asynchronous start is disabled by configuration (enable_async=false)", nil) + } + if tenantId == "" { tenantId = p.StateMachineConfig.GetDefaultTenantId() } @@ -393,7 +408,8 @@ func (p ProcessCtrlStateMachineEngine) createMachineInstance(stateMachineName st } if stateMachine == nil { - return nil, errors.New("StateMachine [" + stateMachineName + "] is not exists") + return nil, exception.NewEngineExecutionException(seataErrors.ObjectNotExists, + "StateMachine ["+stateMachineName+"] is not exists", nil) } stateMachineInstance := statelang.NewStateMachineInstanceImpl() diff --git a/pkg/saga/statemachine/engine/core/process_ctrl_statemachine_engine_async_test.go b/pkg/saga/statemachine/engine/core/process_ctrl_statemachine_engine_async_test.go new file mode 100644 index 000000000..0e2de1a14 --- /dev/null +++ b/pkg/saga/statemachine/engine/core/process_ctrl_statemachine_engine_async_test.go @@ -0,0 +1,345 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package core + +import ( + "context" + stderr "errors" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/seata/seata-go/pkg/saga/statemachine/engine/config" + engExc "github.com/seata/seata-go/pkg/saga/statemachine/engine/exception" + "github.com/seata/seata-go/pkg/saga/statemachine/engine/invoker" + "github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl" + "github.com/seata/seata-go/pkg/saga/statemachine/statelang" + "github.com/seata/seata-go/pkg/saga/statemachine/statelang/parser" + "github.com/seata/seata-go/pkg/util/errors" +) + +func TestStartAsyncDisabled(t *testing.T) { + cfg, err := config.NewDefaultStateMachineConfig( + config.WithEnableAsync(false), + config.WithStateMachineResources(nil), + ) + require.NoError(t, err) + + engine := &ProcessCtrlStateMachineEngine{StateMachineConfig: cfg} + + _, err = engine.StartAsync(context.Background(), "dummy", "", nil, nil) + require.Error(t, err) + + execErr, ok := engExc.IsEngineExecutionException(err) + if !ok { + t.Fatalf("unexpected error type %T: %v", err, err) + } + require.Equal(t, errors.AsynchronousStartDisabled, execErr.Code) +} + +func TestStartAsyncEnabledWithoutStateMachine(t *testing.T) { + cfg, err := config.NewDefaultStateMachineConfig( + config.WithEnableAsync(true), + config.WithStateMachineResources(nil), + ) + require.NoError(t, err) + + engine := &ProcessCtrlStateMachineEngine{StateMachineConfig: cfg} + + _, err = engine.StartAsync(context.Background(), "missing", "", nil, nil) + require.Error(t, err) + + execErr, ok := engExc.IsEngineExecutionException(err) + if !ok { + t.Fatalf("unexpected error type %T: %v", err, err) + } + require.Equal(t, errors.ObjectNotExists, execErr.Code) +} + +func TestStartAsyncSuccessfulFlow(t *testing.T) { + cfg, err := config.NewDefaultStateMachineConfig( + config.WithEnableAsync(true), + config.WithStateMachineResources(nil), + ) + require.NoError(t, err) + + const simpleStateMachine = `{ + "Name": "AsyncSimple", + "StartState": "Greet", + "States": { + "Greet": { + "Type": "ServiceTask", + "ServiceType": "local", + "ServiceName": "asyncTestService", + "ServiceMethod": "Hello", + "CompensateState": "", + "ForCompensation": false, + "ForUpdate": false, + "Next": "Success" + }, + "Success": { + "Type": "Succeed" + } + } + }` + + stateMachine, err := parser.NewJSONStateMachineParser().Parse(simpleStateMachine) + require.NoError(t, err) + stateMachine.SetTenantId(cfg.GetDefaultTenantId()) + stateMachine.SetContent(simpleStateMachine) + require.NoError(t, cfg.StateMachineRepository().RegistryStateMachine(stateMachine)) + + localInvoker, ok := cfg.ServiceInvokerManager().ServiceInvoker("local").(*invoker.LocalServiceInvoker) + require.True(t, ok) + service := &asyncTestService{} + localInvoker.RegisterService("asyncTestService", service) + + callback := newAsyncTestCallback() + engine := &ProcessCtrlStateMachineEngine{StateMachineConfig: cfg} + + inst, err := engine.StartAsync(context.Background(), "AsyncSimple", "", nil, callback) + require.NoError(t, err) + require.NotNil(t, inst) + require.Equal(t, "AsyncSimple", inst.StateMachine().Name()) + + select { + case <-callback.Done(): + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for async callback") + } + + require.NoError(t, callback.Err()) + require.NotNil(t, callback.Instance()) + require.Equal(t, statelang.SU, callback.Instance().Status()) + require.Equal(t, 1, service.Calls()) +} + +type asyncTestService struct { + mu sync.Mutex + calls int +} + +func (s *asyncTestService) Hello() string { + s.mu.Lock() + defer s.mu.Unlock() + s.calls++ + return "ok" +} + +func (s *asyncTestService) Calls() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.calls +} + +type asyncTestCallback struct { + done chan struct{} + once sync.Once + mu sync.Mutex + inst statelang.StateMachineInstance + err error +} + +func newAsyncTestCallback() *asyncTestCallback { + return &asyncTestCallback{done: make(chan struct{})} +} + +func (c *asyncTestCallback) OnFinished(ctx context.Context, _ process_ctrl.ProcessContext, stateMachineInstance statelang.StateMachineInstance) { + c.mu.Lock() + c.inst = stateMachineInstance + c.err = nil + c.mu.Unlock() + c.once.Do(func() { close(c.done) }) +} + +func (c *asyncTestCallback) OnError(ctx context.Context, _ process_ctrl.ProcessContext, stateMachineInstance statelang.StateMachineInstance, err error) { + c.mu.Lock() + c.inst = stateMachineInstance + c.err = err + c.mu.Unlock() + c.once.Do(func() { close(c.done) }) +} + +func (c *asyncTestCallback) Done() <-chan struct{} { + return c.done +} + +func (c *asyncTestCallback) Err() error { + c.mu.Lock() + defer c.mu.Unlock() + return c.err +} + +func (c *asyncTestCallback) Instance() statelang.StateMachineInstance { + c.mu.Lock() + defer c.mu.Unlock() + return c.inst +} + +// TestStartAsyncWithCompensation tests async state machine with compensation flow +func TestStartAsyncWithCompensation(t *testing.T) { + cfg, err := config.NewDefaultStateMachineConfig( + config.WithEnableAsync(true), + config.WithStateMachineResources(nil), + ) + require.NoError(t, err) + + const compensationStateMachine = `{ + "Name": "AsyncCompensation", + "StartState": "Task1", + "States": { + "Task1": { + "Type": "ServiceTask", + "ServiceType": "local", + "ServiceName": "asyncCompensateService", + "ServiceMethod": "DoTask1", + "CompensateState": "CompensateTask1", + "Next": "Task2" + }, + "Task2": { + "Type": "ServiceTask", + "ServiceType": "local", + "ServiceName": "asyncCompensateService", + "ServiceMethod": "DoTask2Fail", + "CompensateState": "CompensateTask2", + "Next": "Success", + "Catch": [ + { + "Exceptions": ["java.lang.Exception"], + "Next": "CompensationTrigger" + } + ] + }, + "CompensationTrigger": { + "Type": "CompensationTrigger", + "Next": "Fail" + }, + "CompensateTask1": { + "Type": "ServiceTask", + "ServiceType": "local", + "ServiceName": "asyncCompensateService", + "ServiceMethod": "CompensateTask1", + "ForCompensation": true, + "Next": "CompensateEnd" + }, + "CompensateTask2": { + "Type": "ServiceTask", + "ServiceType": "local", + "ServiceName": "asyncCompensateService", + "ServiceMethod": "CompensateTask2", + "ForCompensation": true, + "Next": "CompensateEnd" + }, + "CompensateEnd": { + "Type": "Succeed" + }, + "Success": { + "Type": "Succeed" + }, + "Fail": { + "Type": "Fail" + } + } + }` + + stateMachine, err := parser.NewJSONStateMachineParser().Parse(compensationStateMachine) + require.NoError(t, err) + stateMachine.SetTenantId(cfg.GetDefaultTenantId()) + stateMachine.SetContent(compensationStateMachine) + require.NoError(t, cfg.StateMachineRepository().RegistryStateMachine(stateMachine)) + + localInvoker, ok := cfg.ServiceInvokerManager().ServiceInvoker("local").(*invoker.LocalServiceInvoker) + require.True(t, ok) + service := &asyncCompensateService{} + localInvoker.RegisterService("asyncCompensateService", service) + + callback := newAsyncTestCallback() + engine := &ProcessCtrlStateMachineEngine{StateMachineConfig: cfg} + + inst, err := engine.StartAsync(context.Background(), "AsyncCompensation", "", nil, callback) + require.NoError(t, err) + require.NotNil(t, inst) + require.Equal(t, "AsyncCompensation", inst.StateMachine().Name()) + + select { + case <-callback.Done(): + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for async callback") + } + + // Should finish with error due to Task2 failure + // Note: In test environment without DB store, compensation flow may not be fully executed + require.NotNil(t, callback.Instance()) + require.NotEqual(t, statelang.SU, callback.Instance().Status(), "Status should not be SU after Task2 fails") + require.Equal(t, 1, service.Task1Calls(), "Task1 should be called once") + require.Equal(t, 1, service.Task2Calls(), "Task2 should be called once (and fail)") +} + +type asyncCompensateService struct { + mu sync.Mutex + task1Calls int + task2Calls int + compensateTask1Calls int +} + +func (s *asyncCompensateService) DoTask1() string { + s.mu.Lock() + defer s.mu.Unlock() + s.task1Calls++ + return "task1_ok" +} + +func (s *asyncCompensateService) DoTask2Fail() (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.task2Calls++ + return "", stderr.New("task2_failed") +} + +func (s *asyncCompensateService) CompensateTask1() string { + s.mu.Lock() + defer s.mu.Unlock() + s.compensateTask1Calls++ + return "compensate_task1_ok" +} + +func (s *asyncCompensateService) CompensateTask2() string { + s.mu.Lock() + defer s.mu.Unlock() + return "compensate_task2_ok" +} + +func (s *asyncCompensateService) Task1Calls() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.task1Calls +} + +func (s *asyncCompensateService) Task2Calls() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.task2Calls +} + +func (s *asyncCompensateService) CompensateTask1Calls() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.compensateTask1Calls +} diff --git a/pkg/saga/statemachine/engine/exception/exception.go b/pkg/saga/statemachine/engine/exception/exception.go index 7c2e475e0..c6f8fc165 100644 --- a/pkg/saga/statemachine/engine/exception/exception.go +++ b/pkg/saga/statemachine/engine/exception/exception.go @@ -20,6 +20,7 @@ package exception import ( perror "errors" "fmt" + "github.com/seata/seata-go/pkg/util/errors" ) diff --git a/pkg/saga/statemachine/engine/expr/error_expression.go b/pkg/saga/statemachine/engine/expr/error_expression.go index da44804b2..fc6a6b00c 100644 --- a/pkg/saga/statemachine/engine/expr/error_expression.go +++ b/pkg/saga/statemachine/engine/expr/error_expression.go @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package expr // ErrorExpression is a placeholder implementation that always reports an error. diff --git a/pkg/saga/statemachine/engine/expr/expression_resolver_test.go b/pkg/saga/statemachine/engine/expr/expression_resolver_test.go index f518a425d..f65f3b6fd 100644 --- a/pkg/saga/statemachine/engine/expr/expression_resolver_test.go +++ b/pkg/saga/statemachine/engine/expr/expression_resolver_test.go @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package expr import ( diff --git a/pkg/saga/statemachine/engine/expr/sequence_expression_factory.go b/pkg/saga/statemachine/engine/expr/sequence_expression_factory.go index 682c84b5e..97c0e1f9d 100644 --- a/pkg/saga/statemachine/engine/expr/sequence_expression_factory.go +++ b/pkg/saga/statemachine/engine/expr/sequence_expression_factory.go @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package expr import ( diff --git a/pkg/saga/statemachine/engine/invoker/grpc_invoker.go b/pkg/saga/statemachine/engine/invoker/grpc_invoker.go index afdc7ef39..726414273 100644 --- a/pkg/saga/statemachine/engine/invoker/grpc_invoker.go +++ b/pkg/saga/statemachine/engine/invoker/grpc_invoker.go @@ -21,13 +21,15 @@ import ( "context" "errors" "fmt" - "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" - "github.com/seata/seata-go/pkg/util/log" - "google.golang.org/grpc" "reflect" "strings" "sync" "time" + + "google.golang.org/grpc" + + "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" + "github.com/seata/seata-go/pkg/util/log" ) type GRPCInvoker struct { diff --git a/pkg/saga/statemachine/engine/invoker/grpc_invoker_test.go b/pkg/saga/statemachine/engine/invoker/grpc_invoker_test.go index 4beb0d187..b04b8c5e1 100644 --- a/pkg/saga/statemachine/engine/invoker/grpc_invoker_test.go +++ b/pkg/saga/statemachine/engine/invoker/grpc_invoker_test.go @@ -21,12 +21,14 @@ import ( "context" "errors" "fmt" - "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" + "testing" + "time" + "github.com/stretchr/testify/assert" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" - "testing" - "time" + + "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" pb "github.com/seata/seata-go/testdata/saga/engine/invoker/grpc" ) diff --git a/pkg/saga/statemachine/engine/invoker/http_invoker_test.go b/pkg/saga/statemachine/engine/invoker/http_invoker_test.go index 062aac03c..98c043a49 100644 --- a/pkg/saga/statemachine/engine/invoker/http_invoker_test.go +++ b/pkg/saga/statemachine/engine/invoker/http_invoker_test.go @@ -25,8 +25,9 @@ import ( "testing" "time" - "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" "github.com/stretchr/testify/assert" + + "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" ) func TestHTTPInvokerInvokeSucceedWithOutRetry(t *testing.T) { diff --git a/pkg/saga/statemachine/engine/invoker/local_invoker.go b/pkg/saga/statemachine/engine/invoker/local_invoker.go index 14c322413..767e58018 100644 --- a/pkg/saga/statemachine/engine/invoker/local_invoker.go +++ b/pkg/saga/statemachine/engine/invoker/local_invoker.go @@ -20,9 +20,10 @@ package invoker import ( "context" "fmt" - "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" "reflect" "sync" + + "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" ) type LocalServiceInvoker struct { @@ -104,39 +105,43 @@ func (l *LocalServiceInvoker) getMethod(serviceName, methodName string, paramTyp } func (l *LocalServiceInvoker) resolveParameters(input []any, methodType reflect.Type) ([]reflect.Value, error) { - numIn := methodType.NumIn() - paramStart, paramCount := 1, 0 - - if numIn > 0 { - paramCount = numIn - paramStart - } - - if paramCount == 0 { - if len(input) > 0 { - return nil, fmt.Errorf("unexpected parameters: expected 0, got %d", len(input)) - } - return []reflect.Value{}, nil + argTotal := methodType.NumIn() + if argTotal == 0 { + return nil, nil } - if len(input) < paramCount { - return nil, fmt.Errorf("insufficient parameters: expected %d, got %d", paramCount, len(input)) + argCount := argTotal - 1 // skip receiver + if argCount <= 0 { + return nil, nil } - if len(input) > paramCount { - return nil, fmt.Errorf("too many parameters: expected %d, got %d", paramCount, len(input)) - } - - params := make([]reflect.Value, paramCount) - for i := 0; i < paramCount; i++ { - methodParamIndex := i + paramStart - paramType := methodType.In(methodParamIndex) + params := make([]reflect.Value, argCount) + for i := 0; i < argCount; i++ { + paramType := methodType.In(i + 1) + if i >= len(input) { + params[i] = reflect.Zero(paramType) + continue + } converted, err := l.convertParam(input[i], paramType) if err != nil { return nil, fmt.Errorf("parameter %d conversion error: %w", i, err) } - params[i] = reflect.ValueOf(converted) + val := reflect.ValueOf(converted) + if !val.IsValid() { + params[i] = reflect.Zero(paramType) + continue + } + if val.Type().AssignableTo(paramType) { + params[i] = val + continue + } + if val.Type().ConvertibleTo(paramType) { + params[i] = val.Convert(paramType) + continue + } + params[i] = reflect.Zero(paramType) } return params, nil diff --git a/pkg/saga/statemachine/engine/pcext/compensation_holder.go b/pkg/saga/statemachine/engine/pcext/compensation_holder.go index d533460fd..bd62dce97 100644 --- a/pkg/saga/statemachine/engine/pcext/compensation_holder.go +++ b/pkg/saga/statemachine/engine/pcext/compensation_holder.go @@ -19,11 +19,12 @@ package pcext import ( "context" + "sync" + "github.com/seata/seata-go/pkg/saga/statemachine/constant" "github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl" "github.com/seata/seata-go/pkg/saga/statemachine/statelang" "github.com/seata/seata-go/pkg/util/collection" - "sync" ) type CompensationHolder struct { @@ -69,13 +70,23 @@ func NewCompensationHolder() *CompensationHolder { } func GetCurrentCompensationHolder(ctx context.Context, processContext process_ctrl.ProcessContext, forceCreate bool) *CompensationHolder { - compensationholder := processContext.GetVariable(constant.VarNameCurrentCompensationHolder).(*CompensationHolder) - lock := processContext.GetVariable(constant.VarNameProcessContextMutexLock).(*sync.Mutex) + var holder *CompensationHolder + if v := processContext.GetVariable(constant.VarNameCurrentCompensationHolder); v != nil { + if h, ok := v.(*CompensationHolder); ok { + holder = h + } + } + // ensure context mutex exists + lock, _ := processContext.GetVariable(constant.VarNameProcessContextMutexLock).(*sync.Mutex) + if lock == nil { + lock = &sync.Mutex{} + processContext.SetVariable(constant.VarNameProcessContextMutexLock, lock) + } lock.Lock() defer lock.Unlock() - if compensationholder == nil && forceCreate { - compensationholder = NewCompensationHolder() - processContext.SetVariable(constant.VarNameCurrentCompensationHolder, compensationholder) + if holder == nil && forceCreate { + holder = NewCompensationHolder() + processContext.SetVariable(constant.VarNameCurrentCompensationHolder, holder) } - return compensationholder + return holder } diff --git a/pkg/saga/statemachine/engine/pcext/engine_utils.go b/pkg/saga/statemachine/engine/pcext/engine_utils.go index fc7f35665..0bfdf1940 100644 --- a/pkg/saga/statemachine/engine/pcext/engine_utils.go +++ b/pkg/saga/statemachine/engine/pcext/engine_utils.go @@ -20,17 +20,19 @@ package pcext import ( "context" "errors" + "reflect" + "strings" + "sync" + "time" + + "golang.org/x/sync/semaphore" + "github.com/seata/seata-go/pkg/saga/statemachine/constant" "github.com/seata/seata-go/pkg/saga/statemachine/engine" "github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl" "github.com/seata/seata-go/pkg/saga/statemachine/statelang" "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" "github.com/seata/seata-go/pkg/util/log" - "golang.org/x/sync/semaphore" - "reflect" - "strings" - "sync" - "time" ) func EndStateMachine(ctx context.Context, processContext process_ctrl.ProcessContext) error { @@ -51,9 +53,13 @@ func EndStateMachine(ctx context.Context, processContext process_ctrl.ProcessCon stateMachineInstance.SetEndTime(time.Now()) - exp, ok := processContext.GetVariable(constant.VarNameCurrentException).(error) - if !ok { - return errors.New("exception type is not error") + var exp error + if v := processContext.GetVariable(constant.VarNameCurrentException); v != nil { + ev, ok := v.(error) + if !ok { + return errors.New("exception type is not error") + } + exp = ev } if exp != nil { @@ -67,6 +73,39 @@ func EndStateMachine(ctx context.Context, processContext process_ctrl.ProcessCon return err } + // Reconcile compensation status using executed compensation states to align with Java semantics + // If there are compensation states and all succeeded, override to SU. If any failed/unknown, mark UN. + // This guards against edge cases where holder/stack visibility causes FA. + compSeen := false + compAllSU := true + for _, si := range stateMachineInstance.StateList() { + if si.StateIDCompensatedFor() != "" { + compSeen = true + if si.Status() != statelang.SU { + compAllSU = false + break + } + } + } + if compSeen { + if compAllSU { + log.Debugf("All compensation states SU, overriding machine compensation_status to SU for [%s]", stateMachineInstance.ID()) + stateMachineInstance.SetCompensationStatus(statelang.SU) + // In Java semantics, compensation success with Fail end yields final machine status FA + stateMachineInstance.SetStatus(statelang.FA) + } else if stateMachineInstance.CompensationStatus() == "" || stateMachineInstance.CompensationStatus() == statelang.RU { + log.Debugf("Compensation states contain non-SU, marking compensation_status UN for [%s]", stateMachineInstance.ID()) + stateMachineInstance.SetCompensationStatus(statelang.UN) + } + } else { + log.Debugf("No compensation states observed for machine [%s]", stateMachineInstance.ID()) + if v, ok := processContext.GetVariable(constant.VarNameFailEndStateFlag).(bool); ok && v { + // Fail end without compensation: normalize to status=FA, empty compensation_status + stateMachineInstance.SetCompensationStatus("") + stateMachineInstance.SetStatus(statelang.FA) + } + } + contextParams, ok := processContext.GetVariable(constant.VarNameStateMachineContext).(map[string]interface{}) if !ok { return errors.New("state machine context type is not map[string]interface{}") @@ -77,11 +116,16 @@ func EndStateMachine(ctx context.Context, processContext process_ctrl.ProcessCon } stateMachineInstance.SetEndParams(endParams) - stateInstruction, ok := processContext.GetInstruction().(StateInstruction) - if !ok { + switch v := processContext.GetInstruction().(type) { + case *StateInstruction: + v.SetEnd(true) + case StateInstruction: + tmp := v + tmp.SetEnd(true) + processContext.SetInstruction(&tmp) + default: return errors.New("state instruction type is not process_ctrl.StateInstruction") } - stateInstruction.SetEnd(true) stateMachineInstance.SetRunning(false) stateMachineInstance.SetEndTime(time.Now()) diff --git a/pkg/saga/statemachine/engine/pcext/instruction.go b/pkg/saga/statemachine/engine/pcext/instruction.go index 6e39f9253..d5c3f3710 100644 --- a/pkg/saga/statemachine/engine/pcext/instruction.go +++ b/pkg/saga/statemachine/engine/pcext/instruction.go @@ -20,6 +20,7 @@ package pcext import ( "errors" "fmt" + "github.com/seata/seata-go/pkg/saga/statemachine/constant" "github.com/seata/seata-go/pkg/saga/statemachine/engine" "github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl" diff --git a/pkg/saga/statemachine/engine/pcext/loop_context_holder.go b/pkg/saga/statemachine/engine/pcext/loop_context_holder.go index d9708008b..45f1c56a0 100644 --- a/pkg/saga/statemachine/engine/pcext/loop_context_holder.go +++ b/pkg/saga/statemachine/engine/pcext/loop_context_holder.go @@ -19,9 +19,10 @@ package pcext import ( "context" + "sync" + "github.com/seata/seata-go/pkg/saga/statemachine/constant" "github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl" - "sync" ) type LoopContextHolder struct { diff --git a/pkg/saga/statemachine/engine/pcext/loop_task_utils.go b/pkg/saga/statemachine/engine/pcext/loop_task_utils.go index 8544c3c08..65c65692e 100644 --- a/pkg/saga/statemachine/engine/pcext/loop_task_utils.go +++ b/pkg/saga/statemachine/engine/pcext/loop_task_utils.go @@ -19,6 +19,7 @@ package pcext import ( "context" + "github.com/seata/seata-go/pkg/saga/statemachine/constant" "github.com/seata/seata-go/pkg/saga/statemachine/engine" "github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl" @@ -28,27 +29,45 @@ import ( ) func GetLoopConfig(ctx context.Context, processContext process_ctrl.ProcessContext, currentState statelang.State) state.Loop { - if matchLoop(currentState) { - taskState := currentState.(state.AbstractTaskState) - stateMachineInstance := processContext.GetVariable(constant.VarNameStateMachineInst).(statelang.StateMachineInstance) - stateMachineConfig := processContext.GetVariable(constant.VarNameStateMachineConfig).(engine.StateMachineConfig) - - if taskState.Loop() != nil { - loop := taskState.Loop() - collectionName := loop.Collection() - if collectionName != "" { - expression := CreateValueExpression(stateMachineConfig.ExpressionResolver(), collectionName) - collection := GetValue(expression, stateMachineInstance.Context(), nil) - collectionList := collection.([]any) - if len(collectionList) > 0 { - current := GetCurrentLoopContextHolder(ctx, processContext, true) - current.SetCollection(collection) - return loop - } - } - log.Warn("State [{}] loop collection param [{}] invalid", currentState.Name(), collectionName) + if !matchLoop(currentState) { + return nil + } + // Extract underlying AbstractTaskState pointer safely + var task *state.AbstractTaskState + switch s := currentState.(type) { + case *state.ServiceTaskStateImpl: + task = s.AbstractTaskState + case *state.ScriptTaskStateImpl: + task = s.AbstractTaskState + case *state.SubStateMachineImpl: + if s.ServiceTaskStateImpl != nil { + task = s.ServiceTaskStateImpl.AbstractTaskState } + default: + return nil + } + if task == nil || task.Loop() == nil { + return nil + } + + loop := task.Loop() + collectionName := loop.Collection() + if collectionName == "" { + return nil + } + stateMachineInstance := processContext.GetVariable(constant.VarNameStateMachineInst).(statelang.StateMachineInstance) + stateMachineConfig := processContext.GetVariable(constant.VarNameStateMachineConfig).(engine.StateMachineConfig) + expression := CreateValueExpression(stateMachineConfig.ExpressionResolver(), collectionName) + collection := GetValue(expression, stateMachineInstance.Context(), nil) + if collection == nil { + log.Warn("State [{}] loop collection param [{}] invalid", currentState.Name(), collectionName) + return nil + } + if collectionList, ok := collection.([]any); ok && len(collectionList) > 0 { + current := GetCurrentLoopContextHolder(ctx, processContext, true) + current.SetCollection(collection) + return loop } return nil } diff --git a/pkg/saga/statemachine/engine/pcext/parameter_utils.go b/pkg/saga/statemachine/engine/pcext/parameter_utils.go index 72f99abee..2232ebb7c 100644 --- a/pkg/saga/statemachine/engine/pcext/parameter_utils.go +++ b/pkg/saga/statemachine/engine/pcext/parameter_utils.go @@ -19,14 +19,15 @@ package pcext import ( "fmt" + "strings" + "sync" + "github.com/seata/seata-go/pkg/saga/statemachine/constant" "github.com/seata/seata-go/pkg/saga/statemachine/engine" "github.com/seata/seata-go/pkg/saga/statemachine/engine/expr" "github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl" "github.com/seata/seata-go/pkg/saga/statemachine/statelang" "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" - "strings" - "sync" ) func CreateInputParams(processContext process_ctrl.ProcessContext, expressionResolver expr.ExpressionResolver, diff --git a/pkg/saga/statemachine/engine/pcext/process_handler.go b/pkg/saga/statemachine/engine/pcext/process_handler.go index 41a07853f..13ec882e6 100644 --- a/pkg/saga/statemachine/engine/pcext/process_handler.go +++ b/pkg/saga/statemachine/engine/pcext/process_handler.go @@ -20,8 +20,14 @@ package pcext import ( "context" "errors" - "github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl" "sync" + + "github.com/seata/seata-go/pkg/saga/statemachine/constant" + "github.com/seata/seata-go/pkg/saga/statemachine/engine" + "github.com/seata/seata-go/pkg/saga/statemachine/engine/expr" + "github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl" + "github.com/seata/seata-go/pkg/saga/statemachine/statelang" + stateimpl "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" ) type StateHandler interface { @@ -53,18 +59,29 @@ func NewStateMachineProcessHandler() *StateMachineProcessHandler { } func (s *StateMachineProcessHandler) Process(ctx context.Context, processContext process_ctrl.ProcessContext) error { - stateInstruction, _ := processContext.GetInstruction().(StateInstruction) + var stateInstruction *StateInstruction + switch v := processContext.GetInstruction().(type) { + case *StateInstruction: + stateInstruction = v + case StateInstruction: + tmp := v + stateInstruction = &tmp + default: + return errors.New("invalid state instruction from processContext") + } state, err := stateInstruction.GetState(processContext) if err != nil { return err } + // mark Fail end-state to influence final machine status decision + if state.Type() == constant.StateTypeFail { + processContext.SetVariable(constant.VarNameFailEndStateFlag, true) + } + stateType := state.Type() stateHandler := s.GetStateHandler(stateType) - if stateHandler == nil { - return errors.New("Not support [" + stateType + "] state handler") - } interceptAbleStateHandler, ok := stateHandler.(InterceptAbleStateHandler) @@ -82,9 +99,61 @@ func (s *StateMachineProcessHandler) Process(ctx context.Context, processContext } } - err = stateHandler.Process(ctx, processContext) - if err != nil { - return err + // Prepare current state instance in context before processing + smInst, ok := processContext.GetVariable(constant.VarNameStateMachineInst).(statelang.StateMachineInstance) + if !ok || smInst == nil { + return errors.New("state machine instance not found in context") + } + + stInst := statelang.NewStateInstanceImpl() + stInst.SetName(state.Name()) + stInst.SetType(state.Type()) + // If service task, enrich service attributes + if svc, ok := state.(*stateimpl.ServiceTaskStateImpl); ok { + stInst.SetServiceName(svc.ServiceName()) + stInst.SetServiceMethod(svc.ServiceMethod()) + stInst.SetServiceType(svc.ServiceType()) + stInst.SetForUpdate(svc.ForUpdate()) + + // ensure mutex lock for parameter evaluation exists + if !processContext.HasVariable(constant.VarNameProcessContextMutexLock) { + processContext.SetVariable(constant.VarNameProcessContextMutexLock, &sync.Mutex{}) + } + // if this is a compensation execution, mark the link to original state + if v := processContext.GetVariable("_compensate_for_state_id_"); v != nil { + if sid, ok := v.(string); ok && sid != "" { + stInst.SetStateIDCompensatedFor(sid) + // also add to holder's 'StatesForCompensation' for final status decision + holder := GetCurrentCompensationHolder(ctx, processContext, true) + holder.StatesForCompensation().Store(stInst.Name(), stInst) + // clear after consuming + processContext.RemoveVariable("_compensate_for_state_id_") + } + } + + // evaluate input params from CEL expressions if any + if cfg, ok := processContext.GetVariable(constant.VarNameStateMachineConfig).(engine.StateMachineConfig); ok { + var exprResolver expr.ExpressionResolver = cfg.ExpressionResolver() + // use all variables in process context for expression scope + variables := processContext.GetVariables() + inputParams := CreateInputParams(processContext, exprResolver, stInst, svc.AbstractTaskState, variables) + processContext.SetVariable(constant.VarNameInputParams, inputParams) + } + } + // Assign a temporary ID key for list/map tracking + tmpId := state.Name() + if tmpId == "" { + tmpId = "state" + } + smInst.PutState(tmpId, stInst) + processContext.SetVariable(constant.VarNameStateInst, stInst) + + // Execute handler when present; end states don't require a handler + if stateHandler != nil { + err = stateHandler.Process(ctx, processContext) + if err != nil { + return err + } } if stateHandlerInterceptorList != nil && len(stateHandlerInterceptorList) > 0 { @@ -96,6 +165,19 @@ func (s *StateMachineProcessHandler) Process(ctx context.Context, processContext } } + // Set execution result on state instance + if ex, _ := processContext.GetVariable(constant.VarNameCurrentException).(error); ex != nil { + stInst.SetStatus(statelang.FA) + stInst.SetError(ex) + } else { + // For Fail end state, mark FA; otherwise mark SU + if stateType == constant.StateTypeFail { + stInst.SetStatus(statelang.FA) + } else { + stInst.SetStatus(statelang.SU) + } + } + return nil } diff --git a/pkg/saga/statemachine/engine/pcext/process_router.go b/pkg/saga/statemachine/engine/pcext/process_router.go index 2f57a1b11..940e9336a 100644 --- a/pkg/saga/statemachine/engine/pcext/process_router.go +++ b/pkg/saga/statemachine/engine/pcext/process_router.go @@ -19,7 +19,9 @@ package pcext import ( "context" + "github.com/pkg/errors" + "github.com/seata/seata-go/pkg/saga/statemachine/constant" "github.com/seata/seata-go/pkg/saga/statemachine/engine" "github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl" @@ -31,8 +33,14 @@ type StateMachineProcessRouter struct { } func (s *StateMachineProcessRouter) Route(ctx context.Context, processContext process_ctrl.ProcessContext) (process_ctrl.Instruction, error) { - stateInstruction, ok := processContext.GetInstruction().(StateInstruction) - if !ok { + var stateInstruction *StateInstruction + switch v := processContext.GetInstruction().(type) { + case *StateInstruction: + stateInstruction = v + case StateInstruction: + tmp := v + stateInstruction = &tmp + default: return nil, errors.New("instruction is not a state instruction") } diff --git a/pkg/saga/statemachine/engine/pcext/state_router_impl.go b/pkg/saga/statemachine/engine/pcext/state_router_impl.go index c7336b6f6..2023d7a31 100644 --- a/pkg/saga/statemachine/engine/pcext/state_router_impl.go +++ b/pkg/saga/statemachine/engine/pcext/state_router_impl.go @@ -19,7 +19,9 @@ package pcext import ( "context" + "github.com/seata/seata-go/pkg/saga/statemachine/constant" + "github.com/seata/seata-go/pkg/saga/statemachine/engine" "github.com/seata/seata-go/pkg/saga/statemachine/engine/exception" "github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl" "github.com/seata/seata-go/pkg/saga/statemachine/statelang" @@ -32,6 +34,10 @@ type EndStateRouter struct { } func (e EndStateRouter) Route(ctx context.Context, processContext process_ctrl.ProcessContext, state statelang.State) (process_ctrl.Instruction, error) { + // mark Fail end-state to influence final machine status decision + if state != nil && state.Type() == constant.StateTypeFail { + processContext.SetVariable(constant.VarNameFailEndStateFlag, true) + } return nil, nil } @@ -39,7 +45,16 @@ type TaskStateRouter struct { } func (t TaskStateRouter) Route(ctx context.Context, processContext process_ctrl.ProcessContext, state statelang.State) (process_ctrl.Instruction, error) { - stateInstruction, _ := processContext.GetInstruction().(StateInstruction) + var stateInstruction *StateInstruction + switch v := processContext.GetInstruction().(type) { + case *StateInstruction: + stateInstruction = v + case StateInstruction: + tmp := v + stateInstruction = &tmp + default: + return nil, exception.NewEngineExecutionException(seataErrors.ObjectNotExists, "instruction is not a state instruction", nil) + } if stateInstruction.End() { log.Infof("StateInstruction is ended, Stop the StateMachine executing. StateMachine[%s] Current State[%s]", stateInstruction.StateMachineName(), stateInstruction.StateName()) @@ -53,15 +68,81 @@ func (t TaskStateRouter) Route(ctx context.Context, processContext process_ctrl. return nil, nil } - // The current CompensationTriggerState can mark the compensation process is started and perform compensation - // route processing. - compensationTriggerState, ok := processContext.GetVariable(constant.VarNameCurrentCompensateTriggerState).(statelang.State) - if ok { + // If current state is CompensationTrigger or compensation has been flagged, start compensation routing + if state != nil && state.Type() == constant.StateTypeCompensationTrigger { + // flag into context for downstream calls (best-effort) + if hpc, ok := processContext.(process_ctrl.HierarchicalProcessContext); ok { + hpc.SetVariableLocally(constant.VarNameCurrentCompensateTriggerState, state) + hpc.SetVariableLocally(constant.VarNameFirstCompensationStateStarted, false) + } else { + processContext.SetVariable(constant.VarNameCurrentCompensateTriggerState, state) + processContext.SetVariable(constant.VarNameFirstCompensationStateStarted, false) + } + + // build compensation stack from executed forward states (latest first) + smInst := processContext.GetVariable(constant.VarNameStateMachineInst).(statelang.StateMachineInstance) + holder := GetCurrentCompensationHolder(ctx, processContext, true) + stack := holder.StateStackNeedCompensation() + sm := processContext.GetVariable(constant.VarNameStateMachine).(statelang.StateMachine) + states := smInst.StateList() + for i := len(states) - 1; i >= 0; i-- { + si := states[i] + // exclude compensation states + if si.StateIDCompensatedFor() != "" { + continue + } + // only successful forward states are subject to compensation + if si.Status() != statelang.SU { + continue + } + originName := GetOriginStateName(si) + def := sm.State(originName) + // ensure the definition has a compensate state + var task *sagaState.AbstractTaskState + switch s := def.(type) { + case *sagaState.ServiceTaskStateImpl: + task = s.AbstractTaskState + case *sagaState.ScriptTaskStateImpl: + task = s.AbstractTaskState + case *sagaState.SubStateMachineImpl: + if s.ServiceTaskStateImpl != nil { + task = s.ServiceTaskStateImpl.AbstractTaskState + } + } + if task == nil || task.CompensateState() == "" { + continue + } + stack.Push(si) + } + // fallback: if nothing pushed, try last successful forward state + if stack.Empty() { + for i := len(states) - 1; i >= 0; i-- { + si := states[i] + if si.StateIDCompensatedFor() != "" { + continue + } + if si.Status() != statelang.SU { + continue + } + stack.Push(si) + break + } + } + + // mark machine compensation running + smInst.SetCompensationStatus(statelang.RU) + + return t.compensateRoute(ctx, processContext, state) + } + if compensationTriggerState, ok := processContext.GetVariable(constant.VarNameCurrentCompensateTriggerState).(statelang.State); ok { return t.compensateRoute(ctx, processContext, compensationTriggerState) } // There is an exception route, indicating that an exception is thrown, and the exception route is prioritized. - next := processContext.GetVariable(constant.VarNameCurrentExceptionRoute).(string) + var next string + if v, ok := processContext.GetVariable(constant.VarNameCurrentExceptionRoute).(string); ok { + next = v + } if next != "" { processContext.RemoveVariable(constant.VarNameCurrentExceptionRoute) @@ -71,7 +152,9 @@ func (t TaskStateRouter) Route(ctx context.Context, processContext process_ctrl. // If next is empty, the state selected by the Choice state was taken. if next == "" && processContext.HasVariable(constant.VarNameCurrentChoice) { - next = processContext.GetVariable(constant.VarNameCurrentChoice).(string) + if v, ok := processContext.GetVariable(constant.VarNameCurrentChoice).(string); ok { + next = v + } processContext.RemoveVariable(constant.VarNameCurrentChoice) } @@ -79,6 +162,50 @@ func (t TaskStateRouter) Route(ctx context.Context, processContext process_ctrl. return nil, nil } + // If we are routing due to exception to CompensationTrigger, pre-build compensation stack + if next == "CompensationTrigger" || (processContext.GetVariable(constant.VarNameCurrentException) != nil) { + smInst := processContext.GetVariable(constant.VarNameStateMachineInst).(statelang.StateMachineInstance) + holder := GetCurrentCompensationHolder(ctx, processContext, true) + stack := holder.StateStackNeedCompensation() + sm := processContext.GetVariable(constant.VarNameStateMachine).(statelang.StateMachine) + // prefer DB list if available + var states []statelang.StateInstance + if cfg, ok := processContext.GetVariable(constant.VarNameStateMachineConfig).(engine.StateMachineConfig); ok && cfg.StateLogRepository() != nil { + if list, err := cfg.StateLogRepository().GetStateInstanceListByMachineInstanceId(smInst.ID()); err == nil && len(list) > 0 { + states = list + } + } + if len(states) == 0 { + states = smInst.StateList() + } + for i := len(states) - 1; i >= 0; i-- { + si := states[i] + if si.StateIDCompensatedFor() != "" { + continue + } + if si.Status() != statelang.SU { + continue + } + originName := GetOriginStateName(si) + def := sm.State(originName) + var task *sagaState.AbstractTaskState + switch s := def.(type) { + case *sagaState.ServiceTaskStateImpl: + task = s.AbstractTaskState + case *sagaState.ScriptTaskStateImpl: + task = s.AbstractTaskState + case *sagaState.SubStateMachineImpl: + if s.ServiceTaskStateImpl != nil { + task = s.ServiceTaskStateImpl.AbstractTaskState + } + } + if task == nil || task.CompensateState() == "" { + continue + } + stack.Push(si) + } + } + stateMachine := state.StateMachine() nextState := stateMachine.State(next) if nextState == nil { @@ -115,12 +242,59 @@ func (t *TaskStateRouter) compensateRoute(ctx context.Context, processContext pr stateStackToBeCompensated := GetCurrentCompensationHolder(ctx, processContext, true).StateStackNeedCompensation() if stateStackToBeCompensated != nil { - stateToBeCompensated := stateStackToBeCompensated.Pop().(statelang.StateInstance) + popped := stateStackToBeCompensated.Pop() + if popped == nil { + // no states to compensate, finish or go next + processContext.RemoveVariable(constant.VarNameCurrentCompensateTriggerState) + // mark this as a no-compensation path; final status decided by Fail end + processContext.SetVariable(constant.VarNameNoCompensation, true) + compensationTriggerStateNext := compensationTriggerState.Next() + if compensationTriggerStateNext == "" { + return nil, EndStateMachine(ctx, processContext) + } + // set next on instruction (pointer-safe) + var instPtr *StateInstruction + switch v := processContext.GetInstruction().(type) { + case *StateInstruction: + instPtr = v + case StateInstruction: + tmp := v + instPtr = &tmp + default: + return nil, EndStateMachine(ctx, processContext) + } + instPtr.SetStateName(compensationTriggerStateNext) + processContext.SetInstruction(instPtr) + return instPtr, nil + } + stateToBeCompensated := popped.(statelang.StateInstance) stateMachine := processContext.GetVariable(constant.VarNameStateMachine).(statelang.StateMachine) state := stateMachine.State(GetOriginStateName(stateToBeCompensated)) - if taskState, ok := state.(sagaState.AbstractTaskState); ok { - instruction := processContext.GetInstruction().(StateInstruction) + // resolve underlying abstract task for various state impls + var taskState *sagaState.AbstractTaskState + switch s := state.(type) { + case *sagaState.ServiceTaskStateImpl: + taskState = s.AbstractTaskState + case *sagaState.ScriptTaskStateImpl: + taskState = s.AbstractTaskState + case *sagaState.SubStateMachineImpl: + if s.ServiceTaskStateImpl != nil { + taskState = s.ServiceTaskStateImpl.AbstractTaskState + } + } + if taskState != nil { + // pointer-safe fetch of instruction + var instruction *StateInstruction + switch v := processContext.GetInstruction().(type) { + case *StateInstruction: + instruction = v + case StateInstruction: + tmp := v + instruction = &tmp + default: + return nil, EndStateMachine(ctx, processContext) + } var compensateState statelang.State compensateStateName := taskState.CompensateState() @@ -145,6 +319,9 @@ func (t *TaskStateRouter) compensateRoute(ctx context.Context, processContext pr hierarchicalProcessContext := processContext.(process_ctrl.HierarchicalProcessContext) hierarchicalProcessContext.SetVariableLocally(constant.VarNameFirstCompensationStateStarted, true) + // expose the forward state id to be compensated so handler can mark the compensation instance + processContext.SetVariable("_compensate_for_state_id_", stateToBeCompensated.ID()) + if _, ok := compensateState.(sagaState.CompensateSubStateMachineState); ok { hierarchicalProcessContext = processContext.(process_ctrl.HierarchicalProcessContext) hierarchicalProcessContext.SetVariableLocally( @@ -152,6 +329,7 @@ func (t *TaskStateRouter) compensateRoute(ctx context.Context, processContext pr GenerateParentId(stateToBeCompensated)) } + processContext.SetInstruction(instruction) return instruction, nil } } @@ -163,7 +341,15 @@ func (t *TaskStateRouter) compensateRoute(ctx context.Context, processContext pr return nil, EndStateMachine(ctx, processContext) } - instruction := processContext.GetInstruction().(StateInstruction) - instruction.SetStateName(compensationTriggerStateNext) - return instruction, nil + instruction := processContext.GetInstruction() + if si, ok := instruction.(*StateInstruction); ok { + si.SetStateName(compensationTriggerStateNext) + return si, nil + } + if siv, ok := instruction.(StateInstruction); ok { + tmp := siv + tmp.SetStateName(compensationTriggerStateNext) + return &tmp, nil + } + return nil, exception.NewEngineExecutionException(seataErrors.ObjectNotExists, "instruction is not a state instruction", nil) } diff --git a/pkg/saga/statemachine/engine/pcext/state_router_impl_test.go b/pkg/saga/statemachine/engine/pcext/state_router_impl_test.go new file mode 100644 index 000000000..ef2ba3dab --- /dev/null +++ b/pkg/saga/statemachine/engine/pcext/state_router_impl_test.go @@ -0,0 +1,135 @@ +package pcext_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/seata/seata-go/pkg/saga/statemachine/constant" + "github.com/seata/seata-go/pkg/saga/statemachine/engine/config" + "github.com/seata/seata-go/pkg/saga/statemachine/engine/pcext" + "github.com/seata/seata-go/pkg/saga/statemachine/engine/utils" + "github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl/process" + "github.com/seata/seata-go/pkg/saga/statemachine/statelang" + "github.com/seata/seata-go/pkg/saga/statemachine/statelang/parser" + stateimpl "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" +) + +const compensationRouteMachine = `{ + "Name": "CompRouteTest", + "StartState": "ServiceB", + "States": { + "ServiceA": { + "Type": "ServiceTask", + "ServiceType": "local", + "ServiceName": "noopService", + "ServiceMethod": "Do", + "CompensateState": "CompensateA", + "ForCompensation": false, + "ForUpdate": false, + "Next": "Fail" + }, + "ServiceB": { + "Type": "ServiceTask", + "ServiceType": "local", + "ServiceName": "noopService", + "ServiceMethod": "Do", + "CompensateState": "CompensateB", + "ForCompensation": false, + "ForUpdate": false, + "Next": "Fail" + }, + "CompensateA": { + "Type": "ServiceTask", + "ServiceType": "local", + "ServiceName": "noopService", + "ServiceMethod": "Do", + "ForCompensation": true + }, + "CompensateB": { + "Type": "ServiceTask", + "ServiceType": "local", + "ServiceName": "noopService", + "ServiceMethod": "Do", + "ForCompensation": true + }, + "Fail": { + "Type": "Fail" + } + } +}` + +func TestTaskStateRouterBuildsCompensationStackFromExecutedStates(t *testing.T) { + cfg, err := config.NewDefaultStateMachineConfig( + config.WithEnableAsync(false), + config.WithStateMachineResources(nil), + ) + require.NoError(t, err) + + sm, err := parser.NewJSONStateMachineParser().Parse(compensationRouteMachine) + require.NoError(t, err) + sm.SetTenantId(cfg.GetDefaultTenantId()) + sm.SetContent(compensationRouteMachine) + require.NoError(t, cfg.StateMachineRepository().RegistryStateMachine(sm)) + + inst := statelang.NewStateMachineInstanceImpl() + inst.SetID("comp-route-inst") + inst.SetStateMachine(sm) + inst.SetStatus(statelang.RU) + inst.SetCompensationStatus("") + + serviceAInst := statelang.NewStateInstanceImpl() + serviceAInst.SetID("svcA-1") + serviceAInst.SetName("ServiceA") + serviceAInst.SetType(constant.StateTypeServiceTask) + serviceAInst.SetMachineInstanceID(inst.ID()) + serviceAInst.SetStateMachineInstance(inst) + serviceAInst.SetStatus(statelang.RU) + + serviceBInst := statelang.NewStateInstanceImpl() + serviceBInst.SetID("svcB-1") + serviceBInst.SetName("ServiceB") + serviceBInst.SetType(constant.StateTypeServiceTask) + serviceBInst.SetMachineInstanceID(inst.ID()) + serviceBInst.SetStateMachineInstance(inst) + serviceBInst.SetStatus(statelang.SU) + + inst.PutState(serviceAInst.ID(), serviceAInst) + inst.PutState(serviceBInst.ID(), serviceBInst) + + ctx := utils.NewProcessContextBuilder(). + WithProcessType(process.StateLang). + WithOperationName(constant.OperationNameForward). + WithInstruction(pcext.NewStateInstruction(sm.Name(), sm.StartState())). + WithStateMachineInstance(inst). + Build() + + ctx.SetVariable(constant.VarNameStateMachineInst, inst) + ctx.SetVariable(constant.VarNameStateMachine, sm) + ctx.SetVariable(constant.VarNameStateMachineConfig, cfg) + ctx.SetVariable(constant.VarNameStateInst, statelang.StateInstance(nil)) + + router := pcext.TaskStateRouter{} + compTrigger := stateimpl.NewCompensationTriggerStateImpl() + compTrigger.SetStateMachine(sm) + + _, err = router.Route(context.Background(), ctx, compTrigger) + require.NoError(t, err) + + holder := pcext.GetCurrentCompensationHolder(context.Background(), ctx, true) + require.NotNil(t, holder) + + compensated := make(map[string]string) + holder.StatesNeedCompensation().Range(func(key, value any) bool { + stateName, ok := key.(string) + require.True(t, ok) + inst, ok := value.(statelang.StateInstance) + require.True(t, ok) + compensated[stateName] = inst.Name() + return true + }) + + require.Len(t, compensated, 1) + require.Equal(t, "ServiceB", compensated["CompensateB"]) +} diff --git a/pkg/saga/statemachine/engine/repo/repository/state_log_repository.go b/pkg/saga/statemachine/engine/repo/repository/state_log_repository.go index 316e7c2a1..90ed350ba 100644 --- a/pkg/saga/statemachine/engine/repo/repository/state_log_repository.go +++ b/pkg/saga/statemachine/engine/repo/repository/state_log_repository.go @@ -19,11 +19,13 @@ package repository import ( "context" + "sync" + "github.com/pkg/errors" + "github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl" "github.com/seata/seata-go/pkg/saga/statemachine/statelang" "github.com/seata/seata-go/pkg/saga/statemachine/store" - "sync" ) var ( diff --git a/pkg/saga/statemachine/engine/repo/repository/state_machine_repository.go b/pkg/saga/statemachine/engine/repo/repository/state_machine_repository.go index 8fd3260e0..384fdc07b 100644 --- a/pkg/saga/statemachine/engine/repo/repository/state_machine_repository.go +++ b/pkg/saga/statemachine/engine/repo/repository/state_machine_repository.go @@ -71,6 +71,14 @@ func GetStateMachineRepositoryImpl() *StateMachineRepositoryImpl { return stateMachineRepositoryImpl } +func (s *StateMachineRepositoryImpl) SetDefaultTenantId(tenantId string) { + s.defaultTenantId = tenantId +} + +func (s *StateMachineRepositoryImpl) GetDefaultTenantId() string { + return s.defaultTenantId +} + func (s *StateMachineRepositoryImpl) GetStateMachineById(stateMachineId string) (statelang.StateMachine, error) { stateMachine := s.stateMachineMapById[stateMachineId] if stateMachine == nil && s.stateLangStore != nil { @@ -83,6 +91,9 @@ func (s *StateMachineRepositoryImpl) GetStateMachineById(stateMachineId string) if err != nil { return oldStateMachine, err } + if oldStateMachine == nil { + return nil, nil + } parseStatMachine, err := parser.NewJSONStateMachineParser().Parse(oldStateMachine.Content()) if err != nil { @@ -126,6 +137,9 @@ func (s *StateMachineRepositoryImpl) GetLastVersionStateMachine(stateMachineName if err != nil { return oldStateMachine, err } + if oldStateMachine == nil { + return nil, nil + } parseStatMachine, err := parser.NewJSONStateMachineParser().Parse(oldStateMachine.Content()) if err != nil { @@ -148,6 +162,10 @@ func (s *StateMachineRepositoryImpl) GetLastVersionStateMachine(stateMachineName func (s *StateMachineRepositoryImpl) RegistryStateMachine(machine statelang.StateMachine) error { stateMachineName := machine.Name() tenantId := machine.TenantId() + if tenantId == "" && s.defaultTenantId != "" { + machine.SetTenantId(s.defaultTenantId) + tenantId = s.defaultTenantId + } if s.stateLangStore != nil { oldStateMachine, err := s.stateLangStore.GetLastVersionStateMachine(stateMachineName, tenantId) @@ -228,14 +246,6 @@ func (s *StateMachineRepositoryImpl) GetCharset() string { return s.charset } -func (s *StateMachineRepositoryImpl) SetDefaultTenantId(defaultTenantId string) { - s.defaultTenantId = defaultTenantId -} - -func (s *StateMachineRepositoryImpl) GetDefaultTenantId() string { - return s.defaultTenantId -} - func (s *StateMachineRepositoryImpl) SetJsonParserName(jsonParserName string) { s.jsonParserName = jsonParserName } diff --git a/pkg/saga/statemachine/engine/repo/repository/state_machine_repository_test.go b/pkg/saga/statemachine/engine/repo/repository/state_machine_repository_test.go index cfd9a12a3..9161f47d6 100644 --- a/pkg/saga/statemachine/engine/repo/repository/state_machine_repository_test.go +++ b/pkg/saga/statemachine/engine/repo/repository/state_machine_repository_test.go @@ -19,8 +19,10 @@ package repository import ( "database/sql" - "github.com/seata/seata-go/pkg/saga/statemachine/statelang/parser" + "fmt" "os" + "path/filepath" + "runtime" "sync" "testing" "time" @@ -29,6 +31,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/seata/seata-go/pkg/saga/statemachine/statelang" + "github.com/seata/seata-go/pkg/saga/statemachine/statelang/parser" "github.com/seata/seata-go/pkg/saga/statemachine/store/db" ) @@ -41,8 +44,10 @@ func prepareDB() { oncePrepareDB.Do(func() { var err error testdb, err = sql.Open("sqlite3", ":memory:") - query_, err := os.ReadFile("../../../../../testdata/sql/saga/sqlite_init.sql") - initScript := string(query_) + if err != nil { + panic(err) + } + initScript, err := readInitSQL() if err != nil { panic(err) } @@ -52,9 +57,48 @@ func prepareDB() { }) } +func readInitSQL() (string, error) { + return readTestFile("testdata/sql/saga/sqlite_init.sql") +} + func loadStateMachineByYaml() string { - query, _ := os.ReadFile("../../../../../testdata/saga/statelang/simple_statemachine.json") - return string(query) + path, err := locateTestFile("testdata/saga/statelang/simple_statemachine.json") + if err != nil { + panic(err) + } + data, err := os.ReadFile(path) + if err != nil { + panic(err) + } + return string(data) +} + +func readTestFile(rel string) (string, error) { + path, err := locateTestFile(rel) + if err != nil { + return "", err + } + data, err := os.ReadFile(path) + if err != nil { + return "", err + } + return string(data), nil +} + +func locateTestFile(rel string) (string, error) { + _, thisFile, _, _ := runtime.Caller(0) + base := filepath.Dir(thisFile) + candidates := []string{ + filepath.Join(base, rel), + filepath.Join(base, "../../../../../", rel), + filepath.Join(base, "../../../../../../", rel), + } + for _, c := range candidates { + if _, err := os.Stat(c); err == nil { + return c, nil + } + } + return "", fmt.Errorf("test file not found: %s (checked %v)", rel, candidates) } func TestStateMachineInMemory(t *testing.T) { diff --git a/pkg/saga/statemachine/engine/repo/statemachine_store.go b/pkg/saga/statemachine/engine/repo/statemachine_store.go index c0360a5ac..7bf6e34da 100644 --- a/pkg/saga/statemachine/engine/repo/statemachine_store.go +++ b/pkg/saga/statemachine/engine/repo/statemachine_store.go @@ -18,8 +18,9 @@ package repo import ( - "github.com/seata/seata-go/pkg/saga/statemachine/statelang" "io" + + "github.com/seata/seata-go/pkg/saga/statemachine/statelang" ) type StateLogRepository interface { diff --git a/pkg/saga/statemachine/engine/serializer/serializer.go b/pkg/saga/statemachine/engine/serializer/serializer.go index 3a3ddf0a5..d0b79aba0 100644 --- a/pkg/saga/statemachine/engine/serializer/serializer.go +++ b/pkg/saga/statemachine/engine/serializer/serializer.go @@ -21,6 +21,7 @@ import ( "bytes" "encoding/gob" "encoding/json" + "github.com/pkg/errors" ) diff --git a/pkg/saga/statemachine/engine/serializer/serializer_test.go b/pkg/saga/statemachine/engine/serializer/serializer_test.go index eb163368e..4c674bda6 100644 --- a/pkg/saga/statemachine/engine/serializer/serializer_test.go +++ b/pkg/saga/statemachine/engine/serializer/serializer_test.go @@ -18,9 +18,10 @@ package serializer import ( + "testing" + "github.com/pkg/errors" "github.com/stretchr/testify/assert" - "testing" ) func TestErrorSerializer(t *testing.T) { diff --git a/pkg/saga/statemachine/engine/statemachine_config.go b/pkg/saga/statemachine/engine/statemachine_config.go index e6d8d5f0d..bd809e553 100644 --- a/pkg/saga/statemachine/engine/statemachine_config.go +++ b/pkg/saga/statemachine/engine/statemachine_config.go @@ -18,13 +18,14 @@ package engine import ( + "sync" + "github.com/seata/seata-go/pkg/saga/statemachine/engine/expr" "github.com/seata/seata-go/pkg/saga/statemachine/engine/invoker" "github.com/seata/seata-go/pkg/saga/statemachine/engine/repo" "github.com/seata/seata-go/pkg/saga/statemachine/engine/sequence" "github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl" "github.com/seata/seata-go/pkg/saga/statemachine/store" - "sync" ) type StateMachineConfig interface { @@ -48,6 +49,8 @@ type StateMachineConfig interface { AsyncEventPublisher() process_ctrl.EventPublisher + EnableAsync() bool + ServiceInvokerManager() invoker.ServiceInvokerManager ScriptInvokerManager() invoker.ScriptInvokerManager @@ -60,6 +63,10 @@ type StateMachineConfig interface { GetServiceInvokeTimeout() int + IsSagaBranchRegisterEnable() bool + + IsRmReportSuccessEnable() bool + ComponentLock() *sync.Mutex RegisterStateMachineDef(resources []string) error diff --git a/pkg/saga/statemachine/engine/statemachine_engine.go b/pkg/saga/statemachine/engine/statemachine_engine.go index 7b65e4c8c..c201ca973 100644 --- a/pkg/saga/statemachine/engine/statemachine_engine.go +++ b/pkg/saga/statemachine/engine/statemachine_engine.go @@ -19,6 +19,7 @@ package engine import ( "context" + "github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl" "github.com/seata/seata-go/pkg/saga/statemachine/statelang" ) diff --git a/pkg/saga/statemachine/engine/statemachine_engine_test.go b/pkg/saga/statemachine/engine/statemachine_engine_test.go index fc049d12e..11081a18a 100644 --- a/pkg/saga/statemachine/engine/statemachine_engine_test.go +++ b/pkg/saga/statemachine/engine/statemachine_engine_test.go @@ -15,19 +15,35 @@ * limitations under the License. */ -package engine +package engine_test import ( "context" - "github.com/seata/seata-go/pkg/saga/statemachine/engine/core" "testing" -) -func TestEngine(t *testing.T) { + enginepkg "github.com/seata/seata-go/pkg/saga/statemachine/engine" + "github.com/seata/seata-go/pkg/saga/statemachine/engine/core" +) +func TestProcessCtrlEngineInitializes(t *testing.T) { + eng, err := core.NewProcessCtrlStateMachineEngine() + if err != nil { + t.Fatalf("unexpected init error: %v", err) + } + if eng.GetStateMachineConfig() == nil { + t.Fatalf("state machine config should not be nil") + } + if _, ok := interface{}(eng).(enginepkg.StateMachineEngine); !ok { + t.Fatalf("ProcessCtrlStateMachineEngine should satisfy engine.StateMachineEngine") + } } -func TestSimpleStateMachine(t *testing.T) { - engine := core.NewProcessCtrlStateMachineEngine() - engine.Start(context.Background(), "simpleStateMachine", "tenantId", nil) +func TestProcessCtrlEngineStartMissingDefinitionFails(t *testing.T) { + eng, err := core.NewProcessCtrlStateMachineEngine() + if err != nil { + t.Fatalf("unexpected init error: %v", err) + } + if _, err = eng.Start(context.Background(), "undefined", "", nil); err == nil { + t.Fatalf("expected error when starting undefined state machine") + } } diff --git a/pkg/saga/statemachine/engine/strategy.go b/pkg/saga/statemachine/engine/strategy.go index f400d85f3..5b26d9bd5 100644 --- a/pkg/saga/statemachine/engine/strategy.go +++ b/pkg/saga/statemachine/engine/strategy.go @@ -1,7 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package engine import ( "context" + "github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl" "github.com/seata/seata-go/pkg/saga/statemachine/statelang" ) diff --git a/pkg/saga/statemachine/engine/strategy/status_decision.go b/pkg/saga/statemachine/engine/strategy/status_decision.go index 6e453b59d..100ec00a8 100644 --- a/pkg/saga/statemachine/engine/strategy/status_decision.go +++ b/pkg/saga/statemachine/engine/strategy/status_decision.go @@ -20,6 +20,7 @@ package strategy import ( "context" "errors" + "github.com/seata/seata-go/pkg/saga/statemachine/constant" "github.com/seata/seata-go/pkg/saga/statemachine/engine/exception" "github.com/seata/seata-go/pkg/saga/statemachine/engine/pcext" @@ -39,9 +40,29 @@ func (d DefaultStatusDecisionStrategy) DecideOnEndState(ctx context.Context, pro stateMachineInstance statelang.StateMachineInstance, exp error) error { if statelang.RU == stateMachineInstance.CompensationStatus() { compensationHolder := pcext.GetCurrentCompensationHolder(ctx, processContext, true) + // special-case: entered compensation but no compensations to execute + if v, ok := processContext.GetVariable(constant.VarNameNoCompensation).(bool); ok && v { + // If end is Fail, align semantics to FA with empty compensation_status + if fe, ok := processContext.GetVariable(constant.VarNameFailEndStateFlag).(bool); ok && fe { + stateMachineInstance.SetStatus(statelang.FA) + stateMachineInstance.SetCompensationStatus("") + return nil + } + } if err := decideMachineCompensateStatus(ctx, stateMachineInstance, compensationHolder); err != nil { return err } + // If no compensation executed (no states and empty stack) and end is Fail, normalize to FA + empty compStatus + if compensationHolder.StateStackNeedCompensation().Empty() { + empty := true + compensationHolder.StatesForCompensation().Range(func(key, value any) bool { empty = false; return true }) + if empty { + if v, ok := processContext.GetVariable(constant.VarNameFailEndStateFlag).(bool); ok && v { + stateMachineInstance.SetStatus(statelang.FA) + stateMachineInstance.SetCompensationStatus("") + } + } + } } else { failEndStateFlag, ok := processContext.GetVariable(constant.VarNameFailEndStateFlag).(bool) if !ok { @@ -61,6 +82,13 @@ func (d DefaultStatusDecisionStrategy) DecideOnEndState(ctx context.Context, pro stateMachineInstance.ID(), stateMachineInstance.StateMachine().Name(), stateMachineInstance.Status(), stateMachineInstance.CompensationStatus()) + // If ended via Fail state, normalize final machine status to FA (Java semantics) + if v, ok := processContext.GetVariable(constant.VarNameFailEndStateFlag).(bool); ok && v { + if stateMachineInstance.Status() != statelang.FA { + stateMachineInstance.SetStatus(statelang.FA) + } + } + return nil } diff --git a/pkg/saga/statemachine/process_ctrl/bussiness_processor.go b/pkg/saga/statemachine/process_ctrl/bussiness_processor.go index 369972a1d..5c39a4783 100644 --- a/pkg/saga/statemachine/process_ctrl/bussiness_processor.go +++ b/pkg/saga/statemachine/process_ctrl/bussiness_processor.go @@ -19,10 +19,12 @@ package process_ctrl import ( "context" + "sync" + "github.com/pkg/errors" + "github.com/seata/seata-go/pkg/saga/statemachine/constant" "github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl/process" - "sync" ) type BusinessProcessor interface { diff --git a/pkg/saga/statemachine/process_ctrl/default_process_handler.go b/pkg/saga/statemachine/process_ctrl/default_process_handler.go index e4b4702b8..c3fdac820 100644 --- a/pkg/saga/statemachine/process_ctrl/default_process_handler.go +++ b/pkg/saga/statemachine/process_ctrl/default_process_handler.go @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package process_ctrl import "context" diff --git a/pkg/saga/statemachine/process_ctrl/event_bus.go b/pkg/saga/statemachine/process_ctrl/event_bus.go index 51a40dbaa..aabdbf2c3 100644 --- a/pkg/saga/statemachine/process_ctrl/event_bus.go +++ b/pkg/saga/statemachine/process_ctrl/event_bus.go @@ -20,7 +20,9 @@ package process_ctrl import ( "context" "fmt" + "github.com/pkg/errors" + "github.com/seata/seata-go/pkg/saga/statemachine/constant" "github.com/seata/seata-go/pkg/util/collection" "github.com/seata/seata-go/pkg/util/log" @@ -74,7 +76,15 @@ func (d DirectEventBus) Offer(ctx context.Context, event Event) (bool, error) { return false, nil } - stack := processContext.GetVariable(constant.VarNameSyncExeStack).(*collection.Stack) + // Get or initialize execution stack from process context + var stack *collection.Stack + if v := processContext.GetVariable(constant.VarNameSyncExeStack); v != nil { + if s, ok := v.(*collection.Stack); ok { + stack = s + // Existing stack means we're in a nested offer; not the first event + isFirstEvent = false + } + } if stack == nil { stack = collection.NewStack() processContext.SetVariable(constant.VarNameSyncExeStack, stack) diff --git a/pkg/saga/statemachine/process_ctrl/handlers/service_task_state_handler.go b/pkg/saga/statemachine/process_ctrl/handlers/service_task_state_handler.go index e4298b02f..90964eec3 100644 --- a/pkg/saga/statemachine/process_ctrl/handlers/service_task_state_handler.go +++ b/pkg/saga/statemachine/process_ctrl/handlers/service_task_state_handler.go @@ -20,6 +20,8 @@ package handlers import ( "context" "errors" + "time" + "github.com/seata/seata-go/pkg/saga/statemachine/constant" "github.com/seata/seata-go/pkg/saga/statemachine/engine" "github.com/seata/seata-go/pkg/saga/statemachine/engine/exception" @@ -44,8 +46,14 @@ func (s *ServiceTaskStateHandler) State() string { } func (s *ServiceTaskStateHandler) Process(ctx context.Context, processContext process_ctrl.ProcessContext) error { - stateInstruction, ok := processContext.GetInstruction().(pcext.StateInstruction) - if !ok { + var stateInstruction *pcext.StateInstruction + switch v := processContext.GetInstruction().(type) { + case *pcext.StateInstruction: + stateInstruction = v + case pcext.StateInstruction: + tmp := v + stateInstruction = &tmp + default: return errors.New("invalid state instruction from processContext") } stateInterface, err := stateInstruction.GetState(processContext) @@ -64,6 +72,8 @@ func (s *ServiceTaskStateHandler) Process(ctx context.Context, processContext pr // invoke service task and record var result any var resultErr error + // acquire config for store operations + stateMachineConfig, okCfg := processContext.GetVariable(constant.VarNameStateMachineConfig).(engine.StateMachineConfig) handleResultErr := func(err error) { log.Error("<<<<<<<<<<<<<<<<<<<<<< State[%s], ServiceName[%s], Method[%s] Execute failed.", serviceTaskStateImpl.Name(), serviceName, methodName, err) @@ -74,19 +84,54 @@ func (s *ServiceTaskStateHandler) Process(ctx context.Context, processContext pr } hierarchicalProcessContext.SetVariable(constant.VarNameCurrentException, err) pcext.HandleException(processContext, serviceTaskStateImpl.AbstractTaskState, err) + + // mark instance failed and persist finish if store enabled + if stateInstance != nil { + stateInstance.SetStatus(statelang.FA) + stateInstance.SetEndTime(time.Now()) + stateInstance.SetUpdatedTime(time.Now()) + if okCfg && stateMachineConfig.StateLogStore() != nil { + _ = stateMachineConfig.StateLogStore().RecordStateFinished(ctx, stateInstance, processContext) + } + } } - input, ok := processContext.GetVariable(constant.VarNameInputParams).([]any) - if !ok { - handleResultErr(errors.New("invalid input params type from processContext")) - return nil + var input []any + if raw := processContext.GetVariable(constant.VarNameInputParams); raw != nil { + if v, ok := raw.([]any); ok { + input = v + } else { + handleResultErr(errors.New("invalid input params type from processContext")) + return nil + } + } else { + // treat as empty input when not provided + input = []any{} } + // set timestamps and status before persisting + now := time.Now() + stateInstance.SetStartedTime(now) + stateInstance.SetUpdatedTime(now) stateInstance.SetStatus(statelang.RU) log.Debugf(">>>>>>>>>>>>>>>>>>>>>> Start to execute State[%s], ServiceName[%s], Method[%s], Input:%s", serviceTaskStateImpl.Name(), serviceName, methodName, input) - if _, ok := stateInterface.(state.CompensateSubStateMachineState); ok { + // set input on state instance and persist started + stateInstance.SetInputParams(input) + if okCfg && stateMachineConfig.StateLogStore() != nil { + if err := stateMachineConfig.StateLogStore().RecordStateStarted(ctx, stateInstance, processContext); err != nil { + handleResultErr(err) + return nil + } + // now the stateInstance.ID() is assigned (branchId or generated). Put into stateMap keyed by ID + if smi, ok := processContext.GetVariable(constant.VarNameStateMachineInst).(statelang.StateMachineInstance); ok && smi != nil { + smi.PutState(stateInstance.ID(), stateInstance) + } + } + + // Identify compensate-submachine by explicit state type to avoid interface set ambiguity + if stateInterface.Type() == constant.StateTypeCompensateSubMachine { // If it is the compensation of the subState machine, // directly call the state machine's compensate method stateMachineEngine, ok := processContext.GetVariable(constant.VarNameStateMachineEngine).(engine.StateMachineEngine) @@ -102,8 +147,7 @@ func (s *ServiceTaskStateHandler) Process(ctx context.Context, processContext pr return nil } } else { - stateMachineConfig, ok := processContext.GetVariable(constant.VarNameStateMachineConfig).(engine.StateMachineConfig) - if !ok { + if !okCfg { handleResultErr(errors.New("invalid stateMachineConfig type from processContext")) return nil } @@ -116,11 +160,23 @@ func (s *ServiceTaskStateHandler) Process(ctx context.Context, processContext pr return nil } - result, resultErr = serviceInvoker.Invoke(ctx, input, serviceTaskStateImpl) - if resultErr != nil { - handleResultErr(resultErr) + // invoke and unwrap reflect returns + vals, invErr := serviceInvoker.Invoke(ctx, input, serviceTaskStateImpl) + if invErr != nil { + handleResultErr(invErr) return nil } + // Interpret last return as error if present + if n := len(vals); n > 0 { + if errVal, ok := vals[n-1].Interface().(error); ok && errVal != nil { + handleResultErr(errVal) + return nil + } + // result is the first return value when present + if vals[0].IsValid() && !vals[0].IsZero() { + result = vals[0].Interface() + } + } } log.Debugf("<<<<<<<<<<<<<<<<<<<<<< State[%s], ServiceName[%s], Method[%s] Execute finish. result: %s", @@ -137,6 +193,17 @@ func (s *ServiceTaskStateHandler) Process(ctx context.Context, processContext pr hierarchicalProcessContext.SetVariable(constant.VarNameOutputParams, result) } + // mark succeed and persist finished + stateInstance.SetStatus(statelang.SU) + stateInstance.SetEndTime(time.Now()) + stateInstance.SetUpdatedTime(time.Now()) + if okCfg && stateMachineConfig.StateLogStore() != nil { + if err := stateMachineConfig.StateLogStore().RecordStateFinished(ctx, stateInstance, processContext); err != nil { + handleResultErr(err) + return nil + } + } + return nil } diff --git a/pkg/saga/statemachine/process_ctrl/process_router.go b/pkg/saga/statemachine/process_ctrl/process_router.go index 37b1501e2..4ec822517 100644 --- a/pkg/saga/statemachine/process_ctrl/process_router.go +++ b/pkg/saga/statemachine/process_ctrl/process_router.go @@ -19,7 +19,9 @@ package process_ctrl import ( "context" + "github.com/pkg/errors" + "github.com/seata/seata-go/pkg/saga/statemachine/constant" "github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl/process" "github.com/seata/seata-go/pkg/saga/statemachine/statelang" @@ -31,7 +33,7 @@ type RouterHandler interface { } type ProcessRouter interface { - Route(ctx context.Context, processContext ProcessContext) error + Route(ctx context.Context, processContext ProcessContext) (Instruction, error) } type InterceptAbleStateRouter interface { @@ -68,12 +70,15 @@ func (d *DefaultRouterHandler) Route(ctx context.Context, processContext Process return errors.New("Process router not found") } - instruction := processRouter.Route(ctx, processContext) + instruction, err := processRouter.Route(ctx, processContext) + if err != nil { + return err + } if instruction == nil { log.Info("route instruction is null, process end") } else { processContext.SetInstruction(instruction) - _, err := d.eventPublisher.PushEvent(ctx, processContext) + _, err = d.eventPublisher.PushEvent(ctx, processContext) if err != nil { return err } diff --git a/pkg/saga/statemachine/statelang/parser/choice_state_json_parser.go b/pkg/saga/statemachine/statelang/parser/choice_state_json_parser.go index 091085483..9a453a69a 100644 --- a/pkg/saga/statemachine/statelang/parser/choice_state_json_parser.go +++ b/pkg/saga/statemachine/statelang/parser/choice_state_json_parser.go @@ -19,7 +19,9 @@ package parser import ( "fmt" + "github.com/pkg/errors" + "github.com/seata/seata-go/pkg/saga/statemachine/constant" "github.com/seata/seata-go/pkg/saga/statemachine/statelang" "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" diff --git a/pkg/saga/statemachine/statelang/parser/statemachine_config_parser.go b/pkg/saga/statemachine/statelang/parser/statemachine_config_parser.go index a33bbdecd..25ce8b504 100644 --- a/pkg/saga/statemachine/statelang/parser/statemachine_config_parser.go +++ b/pkg/saga/statemachine/statelang/parser/statemachine_config_parser.go @@ -20,10 +20,11 @@ package parser import ( "bytes" "fmt" - "github.com/seata/seata-go/pkg/saga/statemachine" "io" "os" + "github.com/seata/seata-go/pkg/saga/statemachine" + "github.com/knadh/koanf" "github.com/knadh/koanf/parsers/json" "github.com/knadh/koanf/parsers/yaml" diff --git a/pkg/saga/statemachine/statelang/parser/statemachine_config_parser_test.go b/pkg/saga/statemachine/statelang/parser/statemachine_config_parser_test.go index e634fc2fc..cbf28bbbf 100644 --- a/pkg/saga/statemachine/statelang/parser/statemachine_config_parser_test.go +++ b/pkg/saga/statemachine/statelang/parser/statemachine_config_parser_test.go @@ -18,9 +18,11 @@ package parser import ( - "github.com/seata/seata-go/pkg/saga/statemachine" - "github.com/stretchr/testify/assert" "testing" + + "github.com/stretchr/testify/assert" + + "github.com/seata/seata-go/pkg/saga/statemachine" ) func TestStateMachineConfigParser_Parse(t *testing.T) { diff --git a/pkg/saga/statemachine/statelang/parser/statemachine_json_parser.go b/pkg/saga/statemachine/statelang/parser/statemachine_json_parser.go index 37ceaaec6..68305ece7 100644 --- a/pkg/saga/statemachine/statelang/parser/statemachine_json_parser.go +++ b/pkg/saga/statemachine/statelang/parser/statemachine_json_parser.go @@ -19,6 +19,7 @@ package parser import ( "github.com/pkg/errors" + "github.com/seata/seata-go/pkg/saga/statemachine/constant" "github.com/seata/seata-go/pkg/saga/statemachine/statelang" "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" diff --git a/pkg/saga/statemachine/statelang/parser/statemachine_parser.go b/pkg/saga/statemachine/statelang/parser/statemachine_parser.go index 883d28e29..e1bfb882c 100644 --- a/pkg/saga/statemachine/statelang/parser/statemachine_parser.go +++ b/pkg/saga/statemachine/statelang/parser/statemachine_parser.go @@ -18,11 +18,13 @@ package parser import ( - "github.com/pkg/errors" - "github.com/seata/seata-go/pkg/saga/statemachine/statelang" "strconv" "strings" "sync" + + "github.com/pkg/errors" + + "github.com/seata/seata-go/pkg/saga/statemachine/statelang" ) type StateMachineParser interface { @@ -51,6 +53,13 @@ func (b BaseStateParser) ParseBaseAttributes(stateName string, state statelang.S } state.SetComment(comment) + // ensure Type from config is propagated to state + if t, err := b.GetStringOrDefault(stateName, stateMap, "Type", ""); err == nil { + state.SetType(t) + } else { + return err + } + next, err := b.GetStringOrDefault(stateName, stateMap, "Next", "") if err != nil { return err diff --git a/pkg/saga/statemachine/statelang/parser/sub_state_machine_parser.go b/pkg/saga/statemachine/statelang/parser/sub_state_machine_parser.go index 5938e5e4b..8f6faef76 100644 --- a/pkg/saga/statemachine/statelang/parser/sub_state_machine_parser.go +++ b/pkg/saga/statemachine/statelang/parser/sub_state_machine_parser.go @@ -19,7 +19,9 @@ package parser import ( "fmt" + "github.com/pkg/errors" + "github.com/seata/seata-go/pkg/saga/statemachine/constant" "github.com/seata/seata-go/pkg/saga/statemachine/statelang" "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" diff --git a/pkg/saga/statemachine/statelang/parser/task_state_json_parser.go b/pkg/saga/statemachine/statelang/parser/task_state_json_parser.go index 3c9bff097..31deae6f5 100644 --- a/pkg/saga/statemachine/statelang/parser/task_state_json_parser.go +++ b/pkg/saga/statemachine/statelang/parser/task_state_json_parser.go @@ -19,7 +19,9 @@ package parser import ( "fmt" + "github.com/pkg/errors" + "github.com/seata/seata-go/pkg/saga/statemachine/constant" "github.com/seata/seata-go/pkg/saga/statemachine/statelang" "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" @@ -36,7 +38,9 @@ func NewAbstractTaskStateParser() *AbstractTaskStateParser { } func (a *AbstractTaskStateParser) ParseTaskAttributes(stateName string, state *state.AbstractTaskState, stateMap map[string]interface{}) error { - err := a.ParseBaseAttributes(state.Name(), state.BaseState, stateMap) + // Use the provided stateName from the statelang definition, not the current state's Name(), + // because Name() may not be set yet at parse time. + err := a.ParseBaseAttributes(stateName, state.BaseState, stateMap) if err != nil { return err } diff --git a/pkg/saga/statemachine/statelang/state/sub_state_machine.go b/pkg/saga/statemachine/statelang/state/sub_state_machine.go index e18fc0ee9..5fc1f12bf 100644 --- a/pkg/saga/statemachine/statelang/state/sub_state_machine.go +++ b/pkg/saga/statemachine/statelang/state/sub_state_machine.go @@ -19,6 +19,7 @@ package state import ( "github.com/google/uuid" + "github.com/seata/seata-go/pkg/saga/statemachine/constant" ) diff --git a/pkg/saga/statemachine/statelang/state/task_state.go b/pkg/saga/statemachine/statelang/state/task_state.go index a258a28de..f6a07ab5c 100644 --- a/pkg/saga/statemachine/statelang/state/task_state.go +++ b/pkg/saga/statemachine/statelang/state/task_state.go @@ -18,9 +18,10 @@ package state import ( + "reflect" + "github.com/seata/seata-go/pkg/saga/statemachine/constant" "github.com/seata/seata-go/pkg/saga/statemachine/statelang" - "reflect" ) type TaskState interface { diff --git a/pkg/saga/statemachine/statelang/state_instance.go b/pkg/saga/statemachine/statelang/state_instance.go index 5248d0338..30d872e7b 100644 --- a/pkg/saga/statemachine/statelang/state_instance.go +++ b/pkg/saga/statemachine/statelang/state_instance.go @@ -319,7 +319,10 @@ func (s *StateInstanceImpl) SetIgnoreStatus(ignoreStatus bool) { } func (s *StateInstanceImpl) IsForCompensation() bool { - return s.stateIdCompensatedFor == "" + // A state instance is considered a compensation execution if it points + // back to an original forward state via stateIdCompensatedFor. + // When this field is non-empty, this instance is a compensation. + return s.stateIdCompensatedFor != "" } func (s *StateInstanceImpl) SerializedInputParams() interface{} { diff --git a/pkg/saga/statemachine/store/db/db.go b/pkg/saga/statemachine/store/db/db.go index aba2c668f..99d58a3da 100644 --- a/pkg/saga/statemachine/store/db/db.go +++ b/pkg/saga/statemachine/store/db/db.go @@ -19,7 +19,9 @@ package db import ( "database/sql" + "github.com/pkg/errors" + "github.com/seata/seata-go/pkg/util/log" ) @@ -37,17 +39,17 @@ func SelectOne[T any](db *sql.DB, sql string, fn ScanRows[T], args ...any) (T, e var result T log.Debugf("Preparing SQL: %s", sql) stmt, err := db.Prepare(sql) - defer stmt.Close() if err != nil { return result, err } + defer func() { _ = stmt.Close() }() log.Debugf("setting params to Stmt: %v", args) rows, err := stmt.Query(args...) - defer rows.Close() if err != nil { - return result, nil + return result, err } + defer func() { _ = rows.Close() }() if rows.Next() { return fn(rows) @@ -60,17 +62,17 @@ func SelectList[T any](db *sql.DB, sql string, fn ScanRows[T], args ...any) ([]T log.Debugf("Preparing SQL: %s", sql) stmt, err := db.Prepare(sql) - defer stmt.Close() if err != nil { return result, err } + defer func() { _ = stmt.Close() }() log.Debugf("setting params to Stmt: %v", args) rows, err := stmt.Query(args...) - defer rows.Close() if err != nil { return result, err } + defer func() { _ = rows.Close() }() for rows.Next() { obj, err := fn(rows) @@ -86,10 +88,10 @@ func SelectList[T any](db *sql.DB, sql string, fn ScanRows[T], args ...any) ([]T func ExecuteUpdate[T any](db *sql.DB, sql string, fn ExecStatement[T], obj T) (int64, error) { log.Debugf("Preparing SQL: %s", sql) stmt, err := db.Prepare(sql) - defer stmt.Close() if err != nil { return 0, err } + defer func() { _ = stmt.Close() }() log.Debugf("setting params to Stmt: %v", obj) @@ -104,10 +106,10 @@ func ExecuteUpdate[T any](db *sql.DB, sql string, fn ExecStatement[T], obj T) (i func ExecuteUpdateArgs(db *sql.DB, sql string, args ...any) (int64, error) { log.Debugf("Preparing SQL: %s", sql) stmt, err := db.Prepare(sql) - defer stmt.Close() if err != nil { return 0, err } + defer func() { _ = stmt.Close() }() log.Debugf("setting params to Stmt: %v", args) diff --git a/pkg/saga/statemachine/store/db/db_test.go b/pkg/saga/statemachine/store/db/db_test.go index e80f704bf..31a155820 100644 --- a/pkg/saga/statemachine/store/db/db_test.go +++ b/pkg/saga/statemachine/store/db/db_test.go @@ -19,7 +19,10 @@ package db import ( "database/sql" + "fmt" "os" + "path/filepath" + "runtime" "sync" ) @@ -32,8 +35,10 @@ func prepareDB() { oncePrepareDB.Do(func() { var err error db, err = sql.Open("sqlite3", ":memory:") - query_, err := os.ReadFile("testdata/sql/saga/sqlite_init.sql") - initScript := string(query_) + if err != nil { + panic(err) + } + initScript, err := readInitSQL() if err != nil { panic(err) } @@ -41,5 +46,19 @@ func prepareDB() { panic(err) } }) +} +func readInitSQL() (string, error) { + _, thisFile, _, _ := runtime.Caller(0) + base := filepath.Dir(thisFile) + candidates := []string{ + filepath.Join(base, "testdata/sql/saga/sqlite_init.sql"), + filepath.Join(base, "../../../../../testdata/sql/saga/sqlite_init.sql"), + } + for _, candidate := range candidates { + if data, err := os.ReadFile(candidate); err == nil { + return string(data), nil + } + } + return "", fmt.Errorf("sqlite init script not found; looked in %v", candidates) } diff --git a/pkg/saga/statemachine/store/db/statelang.go b/pkg/saga/statemachine/store/db/statelang.go index 82e4fd290..83827658a 100644 --- a/pkg/saga/statemachine/store/db/statelang.go +++ b/pkg/saga/statemachine/store/db/statelang.go @@ -19,10 +19,12 @@ package db import ( "database/sql" - "github.com/pkg/errors" - "github.com/seata/seata-go/pkg/saga/statemachine/statelang" "regexp" "time" + + "github.com/pkg/errors" + + "github.com/seata/seata-go/pkg/saga/statemachine/statelang" ) const ( diff --git a/pkg/saga/statemachine/store/db/statelang_test.go b/pkg/saga/statemachine/store/db/statelang_test.go index 4b870c114..322066d71 100644 --- a/pkg/saga/statemachine/store/db/statelang_test.go +++ b/pkg/saga/statemachine/store/db/statelang_test.go @@ -18,11 +18,13 @@ package db import ( - "github.com/seata/seata-go/pkg/saga/statemachine/statelang" - "github.com/stretchr/testify/assert" "testing" "time" + "github.com/stretchr/testify/assert" + + "github.com/seata/seata-go/pkg/saga/statemachine/statelang" + _ "github.com/mattn/go-sqlite3" ) diff --git a/pkg/saga/statemachine/store/db/statelog.go b/pkg/saga/statemachine/store/db/statelog.go index 4d5943874..d0536838c 100644 --- a/pkg/saga/statemachine/store/db/statelog.go +++ b/pkg/saga/statemachine/store/db/statelog.go @@ -27,12 +27,14 @@ import ( "time" "github.com/pkg/errors" + constant2 "github.com/seata/seata-go/pkg/constant" "github.com/seata/seata-go/pkg/protocol/branch" "github.com/seata/seata-go/pkg/protocol/message" + "github.com/seata/seata-go/pkg/rm" "github.com/seata/seata-go/pkg/saga/statemachine/constant" "github.com/seata/seata-go/pkg/saga/statemachine/engine" - "github.com/seata/seata-go/pkg/saga/statemachine/engine/config" + engExc "github.com/seata/seata-go/pkg/saga/statemachine/engine/exception" "github.com/seata/seata-go/pkg/saga/statemachine/engine/pcext" "github.com/seata/seata-go/pkg/saga/statemachine/engine/sequence" "github.com/seata/seata-go/pkg/saga/statemachine/engine/serializer" @@ -41,6 +43,7 @@ import ( "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" sagaTm "github.com/seata/seata-go/pkg/saga/tm" "github.com/seata/seata-go/pkg/tm" + seataErrors "github.com/seata/seata-go/pkg/util/errors" "github.com/seata/seata-go/pkg/util/log" ) @@ -141,6 +144,11 @@ func (s *StateLogStore) RecordStateMachineStarted(ctx context.Context, machineIn if err != nil { return err } + if gtx, ok := context.GetVariable(constant.VarNameGlobalTx).(*tm.GlobalTransaction); ok && gtx != nil { + log.Infof("SAGA GlobalBegin success, SM=%s, XID=%s", machineInstance.StateMachine().Name(), gtx.Xid) + } else { + log.Warnf("SAGA GlobalBegin missing GlobalTransaction in context for SM=%s", machineInstance.StateMachine().Name()) + } } if machineInstance.ID() == "" && s.seqGenerator != nil { @@ -163,6 +171,8 @@ func (s *StateLogStore) RecordStateMachineStarted(ctx context.Context, machineIn return errors.New("affected rows is smaller than 0") } + log.Infof("RecordStateMachineStarted ok, SM=%s, XID=%s", machineInstance.StateMachine().Name(), machineInstance.ID()) + return nil } @@ -184,6 +194,8 @@ func (s *StateLogStore) beginTransaction(ctx context.Context, machineInstance st }() txName := constant.SagaTransNamePrefix + machineInstance.StateMachine().Name() + appId, group := rm.GetRmAppAndGroup() + log.Infof("Begin SAGA global transaction: txName=%s, appId=%s, txServiceGroup=%s", txName, appId, group) gtx, err := s.sagaTransactionalTemplate.BeginTransaction(ctx, time.Duration(cfg.GetTransOperationTimeout()), txName) if err != nil { return err @@ -197,6 +209,7 @@ func (s *StateLogStore) beginTransaction(ctx context.Context, machineInstance st if machineContext != nil { machineContext[constant.VarNameGlobalTx] = gtx } + log.Infof("Begin SAGA global transaction ok: XID=%s", xid) return nil } @@ -215,6 +228,35 @@ func (s *StateLogStore) RecordStateMachineFinished(ctx context.Context, machineI delete(endParams, constant.VarNameGlobalTx) } + // If compensation ran, reconcile compensation status from DB state list to align with Java semantics + if list, err := s.GetStateInstanceListByMachineInstanceId(machineInstance.ID()); err == nil && len(list) > 0 { + compSeen := false + compAllSU := true + for _, si := range list { + if si.StateIDCompensatedFor() != "" { + compSeen = true + if si.Status() != statelang.SU { + compAllSU = false + break + } + } + } + if compSeen { + if compAllSU { + machineInstance.SetCompensationStatus(statelang.SU) + machineInstance.SetStatus(statelang.FA) + } else if machineInstance.CompensationStatus() == "" || machineInstance.CompensationStatus() == statelang.RU { + machineInstance.SetCompensationStatus(statelang.UN) + } + } else { + // No compensation executed. If final status不是SU,则归一化为FA并清空补偿态。 + if machineInstance.Status() != statelang.SU { + machineInstance.SetCompensationStatus("") + machineInstance.SetStatus(statelang.FA) + } + } + } + // if success, clear exception if statelang.SU == machineInstance.Status() && machineInstance.Exception() != nil { machineInstance.SetException(nil) @@ -238,8 +280,74 @@ func (s *StateLogStore) RecordStateMachineFinished(ctx context.Context, machineI return err } if affected <= 0 { - log.Warnf("StateMachineInstance[%s] is recovered by server, skip RecordStateMachineFinished", machineInstance.ID()) - return nil + // Reload and retry once with latest gmt_updated for optimistic lock + current, _ := s.GetStateMachineInstance(machineInstance.ID()) + if current != nil { + if !current.IsRunning() { + log.Infof("StateMachineInstance[%s] already finished (no rows updated), skip duplicate finish", machineInstance.ID()) + return nil + } + // retry with refreshed updatedTime + machineInstance.SetUpdatedTime(current.UpdatedTime()) + affected2, err2 := ExecuteUpdate(s.db, s.recordStateMachineFinishedSql, execStateMachineInstanceStatementForUpdate, machineInstance) + if err2 != nil { + return err2 + } + if affected2 <= 0 { + // Check again if it's already finished to avoid noisy warnings + current2, _ := s.GetStateMachineInstance(machineInstance.ID()) + if current2 != nil && !current2.IsRunning() && !current2.EndTime().IsZero() { + log.Infof("StateMachineInstance[%s] appears already finished after retry (no rows updated), treat as success", machineInstance.ID()) + return nil + } + // Fallback: update without gmt_updated predicate to guarantee terminal state is recorded + // Use warn-level visibility as this is an abnormal path indicating potential concurrency issues + log.Warnf("StateMachineInstance[%s] executing fallback finish (no optimistic lock) after retry failed, status=%s, comp=%s", + machineInstance.ID(), machineInstance.Status(), machineInstance.CompensationStatus()) + // Build SQL by stripping the trailing ' and gmt_updated = ?' + rawSql := s.recordStateMachineFinishedSql + noWhereSql := strings.Replace(rawSql, " and gmt_updated = ?", "", 1) + // Prepare parameters mirroring execStateMachineInstanceStatementForUpdate minus last arg + var serializedError []byte + if machineInstance.SerializedError() != nil && len(machineInstance.SerializedError().([]byte)) > 0 { + serializedError = machineInstance.SerializedError().([]byte) + } + var compensationStatus sql.NullString + if machineInstance.CompensationStatus() != "" { + compensationStatus.Valid = true + compensationStatus.String = string(machineInstance.CompensationStatus()) + } + // end_time, excep, end_params, status, compensation_status, is_running, gmt_updated, id + affectedFallback, errFallback := s.db.Exec(noWhereSql, + machineInstance.EndTime(), + serializedError, + machineInstance.SerializedEndParams(), + machineInstance.Status(), + compensationStatus, + machineInstance.IsRunning(), + time.Now(), + machineInstance.ID(), + ) + if errFallback != nil { + log.Errorf("StateMachineInstance[%s] fallback finish failed: %v, status=%s, comp=%s", + machineInstance.ID(), errFallback, + machineInstance.Status(), machineInstance.CompensationStatus()) + return errFallback + } + rowsAffected, _ := affectedFallback.RowsAffected() + if rowsAffected <= 0 { + log.Warnf("StateMachineInstance[%s] fallback finish affected 0 rows, may indicate concurrent update or missing record", + machineInstance.ID()) + } + // reflect latest update time locally to avoid timeout false positives + machineInstance.SetUpdatedTime(time.Now()) + } + } else { + log.Debugf("StateMachineInstance[%s] finish update affected 0 rows and reload failed; skipping retry", machineInstance.ID()) + } + } else { + // reflect latest update time locally to avoid timeout false positives + machineInstance.SetUpdatedTime(time.Now()) } // check if timeout or else report transaction finished @@ -248,6 +356,7 @@ func (s *StateLogStore) RecordStateMachineFinished(ctx context.Context, machineI return errors.New("stateMachineConfig is required in context") } + // UpdatedTime recently refreshed at successful finish; this check should only catch real timeouts if pcext.IsTimeout(machineInstance.UpdatedTime(), cfg.GetTransOperationTimeout()) { log.Warnf("StateMachineInstance[%s] is execution timeout, skip report transaction finished to server.", machineInstance.ID()) } else if machineInstance.ParentID() == "" { @@ -261,13 +370,8 @@ func (s *StateLogStore) RecordStateMachineFinished(ctx context.Context, machineI } func (s *StateLogStore) reportTransactionFinished(ctx context.Context, machineInstance statelang.StateMachineInstance, context process_ctrl.ProcessContext) error { - var err error defer func() { s.ClearUp(context) - if err != nil { - log.Errorf("Report transaction finish to server error: %v, StateMachine: %s, XID: %s, Reason: %s", - err, machineInstance.StateMachine().Name(), machineInstance.ID(), err.Error()) - } }() if s.sagaTransactionalTemplate == nil { @@ -276,8 +380,10 @@ func (s *StateLogStore) reportTransactionFinished(ctx context.Context, machineIn } globalTransaction, err := s.getGlobalTransaction(ctx, machineInstance, context) if err != nil { - log.Errorf("Failed to get global transaction: %v", err) - return err + log.Errorf("Failed to get global transaction: %v, StateMachine: %s, XID: %s", + err, machineInstance.StateMachine().Name(), machineInstance.ID()) + // Align with Java semantics: getGlobalTransaction failure is non-fatal to state machine finish + return nil } var globalStatus message.GlobalStatus @@ -298,7 +404,15 @@ func (s *StateLogStore) reportTransactionFinished(ctx context.Context, machineIn globalTransaction.TxStatus = globalStatus err = s.sagaTransactionalTemplate.ReportTransaction(ctx, globalTransaction) if err != nil { - return err + // Enhanced error logging aligned with Java implementation (DbAndReportTcStateLogStore.java:246-261) + log.Errorf("Report transaction finish to server failed: StateMachine=%s, XID=%s, Status=%s, Err=%v", + machineInstance.StateMachine().Name(), + machineInstance.ID(), + globalStatus, + err) + // Align with Java semantics: reporting to TC should not fail the local state machine finish. + // Swallow error to keep success path green. + return nil } return nil } @@ -348,6 +462,20 @@ func (s *StateLogStore) RecordStateStarted(ctx context.Context, stateInstance st if stateInstance == nil { return nil } + // ensure compensation linkage is restored when upstream context loses the original state id + if stateInstance.StateIDCompensatedFor() == "" { + if holder, ok := context.GetVariable(constant.VarNameCurrentCompensationHolder).(*pcext.CompensationHolder); ok && holder != nil { + if toCompensate, okLoad := holder.StatesNeedCompensation().Load(stateInstance.Name()); okLoad { + if original, okInst := toCompensate.(statelang.StateInstance); okInst && original != nil { + if original.ID() != "" { + stateInstance.SetStateIDCompensatedFor(original.ID()) + stateInstance.SetCompensationState(original) + holder.StatesForCompensation().Store(stateInstance.Name(), stateInstance) + } + } + } + } + } isUpdateMode, err := s.isUpdateMode(stateInstance, context) if err != nil { return err @@ -366,7 +494,13 @@ func (s *StateLogStore) RecordStateStarted(ctx context.Context, stateInstance st stateInstance.SetID(s.generateCompensateStateInstanceId(stateInstance, isUpdateMode)) } else { // register branch - s.branchRegister(ctx, stateInstance, context) + sm := stateInstance.StateMachineInstance().StateMachine().Name() + log.Infof("Register SAGA branch begin, SM=%s, state=%s", sm, stateInstance.Name()) + if err := s.branchRegister(ctx, stateInstance, context); err != nil { + log.Errorf("Register SAGA branch failed, SM=%s, state=%s, err=%v", sm, stateInstance.Name(), err) + return err + } + log.Infof("Register SAGA branch ok, SM=%s, state=%s, branchId=%s", sm, stateInstance.Name(), stateInstance.ID()) } if stateInstance.ID() == "" && s.seqGenerator != nil { @@ -393,15 +527,11 @@ func (s *StateLogStore) RecordStateStarted(ctx context.Context, stateInstance st if affected <= 0 { return errors.New("affected rows is smaller than 0") } + log.Infof("RecordStateStarted ok, SM=%s, state=%s, branchId=%s", stateInstance.StateMachineInstance().StateMachine().Name(), stateInstance.Name(), stateInstance.ID()) return nil } func (s *StateLogStore) isUpdateMode(stateInstance statelang.StateInstance, context process_ctrl.ProcessContext) (bool, error) { - cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(config.DefaultStateMachineConfig) - if !ok { - return false, errors.New("stateMachineConfig is required in context") - } - instruction, ok := context.GetInstruction().(*pcext.StateInstruction) if !ok { return false, errors.New("stateInstruction is required in processContext") @@ -415,22 +545,24 @@ func (s *StateLogStore) isUpdateMode(stateInstance statelang.StateInstance, cont if stateInstance.StateIDRetriedFor() != "" { if taskState != nil && taskState.RetryPersistModeUpdate() { - return taskState.RetryPersistModeUpdate(), nil - } else if stateMachine.IsRetryPersistModeUpdate() { - return stateMachine.IsRetryPersistModeUpdate(), nil + return true, nil + } + if stateMachine != nil && stateMachine.IsRetryPersistModeUpdate() { + return true, nil } - return cfg.IsSagaRetryPersistModeUpdate(), nil + return false, nil } else if stateInstance.StateIDCompensatedFor() != "" { // find if this compensate has been executed stateList := stateInstance.StateMachineInstance().StateList() for _, instance := range stateList { if instance.IsForCompensation() && instance.Name() == stateInstance.Name() { if taskState != nil && taskState.CompensatePersistModeUpdate() { - return taskState.CompensatePersistModeUpdate(), nil - } else if stateMachine.IsCompensatePersistModeUpdate() { - return stateMachine.IsCompensatePersistModeUpdate(), nil + return true, nil } - return cfg.IsSagaCompensatePersistModeUpdate(), nil + if stateMachine != nil && stateMachine.IsCompensatePersistModeUpdate() { + return true, nil + } + return false, nil } } } @@ -473,16 +605,12 @@ func (s *StateLogStore) generateCompensateStateInstanceId(stateInstance statelan return fmt.Sprintf("%s-%d", originalCompensateStateInstId, maxIndex) } -func (s *StateLogStore) branchRegister(ctx context.Context, stateInstance statelang.StateInstance, context process_ctrl.ProcessContext) error { - cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(config.DefaultStateMachineConfig) - if !ok { - return errors.New("stateMachineConfig is required in context") - } +func safeGetSagaRM() (rm.ResourceManager, bool) { + defer func() { recover() }() + return rm.GetRmCacheInstance().GetResourceManager(branch.BranchTypeSAGA), true +} - if !cfg.IsSagaBranchRegisterEnable() { - log.Debugf("sagaBranchRegisterEnable = false, skip branch report. state[%s]", stateInstance.Name()) - return nil - } +func (s *StateLogStore) branchRegister(ctx context.Context, stateInstance statelang.StateInstance, context process_ctrl.ProcessContext) error { //Register branch var err error @@ -494,18 +622,88 @@ func (s *StateLogStore) branchRegister(ctx context.Context, stateInstance statel } }() + if cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(engine.StateMachineConfig); ok { + if !cfg.IsSagaBranchRegisterEnable() { + if stateInstance.ID() == "" { + if seq := cfg.SeqGenerator(); seq != nil { + stateInstance.SetID(seq.GenerateId(constant.SeqEntityStateInst, "")) + } + if stateInstance.ID() == "" { + stateInstance.SetID(fmt.Sprintf("%s-%d", stateInstance.Name(), time.Now().UnixNano())) + } + } + return nil + } + } + globalTransaction, err := s.getGlobalTransaction(ctx, machineInstance, context) if err != nil { + if _, ok := engExc.IsEngineExecutionException(err); ok { + return err + } + err = engExc.NewEngineExecutionException( + seataErrors.TransactionErrorCodeBranchRegisterFailed, + fmt.Sprintf("get global transaction failed, stateMachine=%s, state=%s", + machineInstance.StateMachine().Name(), stateInstance.Name()), err) return err } if globalTransaction == nil { - err = errors.New("Global transaction is not exists") + err = engExc.NewEngineExecutionException( + seataErrors.TransactionErrorCodeGlobalTransactionNotExist, + fmt.Sprintf("global transaction does not exist, stateMachine=%s, state=%s", + machineInstance.StateMachine().Name(), stateInstance.Name()), nil) return err } - resourceId := machineInstance.StateMachine().Name() + "#" + stateInstance.Name() + // For SAGA, resourceId should be applicationId#txServiceGroup (aligned with SagaResource) + appId, group := rm.GetRmAppAndGroup() + resourceId := appId + "#" + group + // Prefer RM ResourceManager (BranchTypeSAGA) for branch register + if mgr, ok := safeGetSagaRM(); ok && mgr != nil { + bid, e := mgr.BranchRegister(ctx, rm.BranchRegisterParam{ + BranchType: branch.BranchTypeSAGA, + ResourceId: resourceId, + Xid: globalTransaction.Xid, + ClientId: "", + ApplicationData: "", + LockKeys: "", + }) + if e != nil { + err = engExc.NewEngineExecutionException( + seataErrors.TransactionErrorCodeBranchRegisterFailed, + fmt.Sprintf("branch register via rm failed, stateMachine=%s, state=%s, xid=%s", + machineInstance.StateMachine().Name(), stateInstance.Name(), globalTransaction.Xid), e) + return err + } + if bid <= 0 { + err = engExc.NewEngineExecutionException( + seataErrors.TransactionErrorCodeBranchRegisterFailed, + fmt.Sprintf("branch register returned invalid branchId (<=0), stateMachine=%s, state=%s, xid=%s", + machineInstance.StateMachine().Name(), stateInstance.Name(), globalTransaction.Xid), nil) + return err + } + stateInstance.SetID(strconv.FormatInt(bid, 10)) + return nil + } + if s.sagaTransactionalTemplate == nil { + err = engExc.NewEngineExecutionException( + seataErrors.TransactionErrorCodeBranchRegisterFailed, + "saga transactional template is not initialized", nil) + return err + } branchId, err := s.sagaTransactionalTemplate.BranchRegister(ctx, resourceId, "", globalTransaction.Xid, "", "") if err != nil { + err = engExc.NewEngineExecutionException( + seataErrors.TransactionErrorCodeBranchRegisterFailed, + fmt.Sprintf("branch register via template failed, stateMachine=%s, state=%s, xid=%s", + machineInstance.StateMachine().Name(), stateInstance.Name(), globalTransaction.Xid), err) + return err + } + if branchId <= 0 { + err = engExc.NewEngineExecutionException( + seataErrors.TransactionErrorCodeBranchRegisterFailed, + fmt.Sprintf("branch register returned invalid branchId (<=0), stateMachine=%s, state=%s, xid=%s", + machineInstance.StateMachine().Name(), stateInstance.Name(), globalTransaction.Xid), nil) return err } @@ -547,28 +745,30 @@ func (s *StateLogStore) RecordStateFinished(ctx context.Context, stateInstance s } stateInstance.SetSerializedError(serializedError) - _, err = ExecuteUpdate(s.db, s.recordStateFinishedSql, execStateInstanceStatementForUpdate, stateInstance) + affected, err := ExecuteUpdate(s.db, s.recordStateFinishedSql, execStateInstanceStatementForUpdate, stateInstance) if err != nil { return err } - - // A switch to skip branch report on branch success, in order to optimize performance - cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(config.DefaultStateMachineConfig) - if !(ok && !cfg.IsRmReportSuccessEnable() && statelang.SU == stateInstance.Status()) { - err = s.branchReport(ctx, stateInstance, context) + if affected <= 0 { + err = engExc.NewEngineExecutionException( + seataErrors.TransactionErrorCodeFailedWriteSession, + fmt.Sprintf("state finish update affected 0 rows, possible concurrent update lost. state=%s id=%s machine=%s", + stateInstance.Name(), stateInstance.ID(), stateInstance.StateMachineInstance().ID()), nil) return err } - return nil + if cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(engine.StateMachineConfig); ok { + if !cfg.IsRmReportSuccessEnable() && stateInstance.Status() == statelang.SU && stateInstance.CompensationStatus() == "" { + return nil + } + } + + // always report branch on state finish when enabled + return s.branchReport(ctx, stateInstance, context) } func (s *StateLogStore) branchReport(ctx context.Context, stateInstance statelang.StateInstance, context process_ctrl.ProcessContext) error { - cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(config.DefaultStateMachineConfig) - if ok && !cfg.IsSagaBranchRegisterEnable() { - log.Debugf("sagaBranchRegisterEnable = false, skip branch report. state[%s]", stateInstance.Name()) - return nil - } var branchStatus branch.BranchStatus // find out the original state instance, only the original state instance is registered on the server, @@ -635,18 +835,68 @@ func (s *StateLogStore) branchReport(ctx context.Context, stateInstance statelan globalTransaction, err := s.getGlobalTransaction(ctx, stateInstance.StateMachineInstance(), context) if err != nil { + if _, ok := engExc.IsEngineExecutionException(err); ok { + return err + } + err = engExc.NewEngineExecutionException( + seataErrors.TransactionErrorCodeBranchReportFailed, + fmt.Sprintf("get global transaction failed for branch report, stateMachine=%s, state=%s", + stateInstance.StateMachineInstance().StateMachine().Name(), stateInstance.Name()), err) return err } if globalTransaction == nil { - err = errors.New("Global transaction is not exists") + err = engExc.NewEngineExecutionException( + seataErrors.TransactionErrorCodeGlobalTransactionNotExist, + fmt.Sprintf("global transaction does not exist for branch report, stateMachine=%s, state=%s", + stateInstance.StateMachineInstance().StateMachine().Name(), stateInstance.Name()), nil) return err } - branchId, err := strconv.ParseInt(originalStateInst.ID(), 10, 0) - if err != nil { + log.Infof("BranchReport prepare, XID=%s, branchId(raw)=%s, status=%s", globalTransaction.Xid, originalStateInst.ID(), branchStatus) + branchId, perr := strconv.ParseInt(originalStateInst.ID(), 10, 64) + if perr != nil { + return engExc.NewEngineExecutionException( + seataErrors.TransactionErrorCodeBranchReportFailed, + fmt.Sprintf("invalid branchId '%s' (must be numeric); ensure GlobalBegin + BranchRegister succeeded and resourceId is 'applicationId#txServiceGroup'", originalStateInst.ID()), + perr, + ) + } + if mgr := rm.GetRmCacheInstance().GetResourceManager(branch.BranchTypeSAGA); mgr != nil { + log.Infof("BranchReport via SagaResourceManager, XID=%s, branchId=%d, status=%s", globalTransaction.Xid, branchId, branchStatus) + if err = mgr.BranchReport(ctx, rm.BranchReportParam{ + BranchType: branch.BranchTypeSAGA, + Xid: globalTransaction.Xid, + BranchId: branchId, + Status: branchStatus, + ApplicationData: "", + }); err != nil { + err = engExc.NewEngineExecutionException( + seataErrors.TransactionErrorCodeBranchReportFailed, + fmt.Sprintf("branch report via rm failed, stateMachine=%s, state=%s, xid=%s, branchId=%d", + originalStateInst.StateMachineInstance().StateMachine().Name(), originalStateInst.Name(), globalTransaction.Xid, branchId), + err, + ) + return err + } + return nil + } + if s.sagaTransactionalTemplate == nil { + err = engExc.NewEngineExecutionException( + seataErrors.TransactionErrorCodeBranchReportFailed, + "saga transactional template is not initialized", nil) + return err + } + log.Infof("BranchReport via SagaTransactionalTemplate, XID=%s, branchId=%d, status=%s", globalTransaction.Xid, branchId, branchStatus) + if err = s.sagaTransactionalTemplate.BranchReport(ctx, globalTransaction.Xid, branchId, branchStatus, ""); err != nil { + err = engExc.NewEngineExecutionException( + seataErrors.TransactionErrorCodeBranchReportFailed, + fmt.Sprintf("branch report via template failed, stateMachine=%s, state=%s, xid=%s, branchId=%d", + originalStateInst.StateMachineInstance().StateMachine().Name(), originalStateInst.Name(), globalTransaction.Xid, branchId), + err, + ) return err } - return s.sagaTransactionalTemplate.BranchReport(ctx, globalTransaction.Xid, branchId, branchStatus, "") + return nil } func (s *StateLogStore) findOutOriginalStateInstanceOfRetryState(stateInstance statelang.StateInstance) statelang.StateInstance { @@ -944,14 +1194,24 @@ func execStateInstanceStatementForUpdate(obj statelang.StateInstance, stmt *sql. serializedError = obj.SerializedError().([]byte) } + updatedTime := obj.UpdatedTime() + if updatedTime.IsZero() { + updatedTime = time.Now() + } + machineInstanceID := obj.MachineInstanceID() + if machineInstanceID == "" { + if sm := obj.StateMachineInstance(); sm != nil { + machineInstanceID = sm.ID() + } + } result, err := stmt.Exec( obj.EndTime(), serializedError, obj.Status(), obj.SerializedOutputParams(), - obj.EndTime(), + updatedTime, obj.ID(), - obj.MachineInstanceID(), + machineInstanceID, ) if err != nil { return 0, err diff --git a/pkg/saga/statemachine/store/db/statelog_test.go b/pkg/saga/statemachine/store/db/statelog_test.go index 08e9e4454..c9beab491 100644 --- a/pkg/saga/statemachine/store/db/statelog_test.go +++ b/pkg/saga/statemachine/store/db/statelog_test.go @@ -20,18 +20,35 @@ package db import ( "context" "fmt" + "io" + "strconv" + "sync" + "testing" + "time" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/seata/seata-go/pkg/protocol/branch" + rmpkg "github.com/seata/seata-go/pkg/rm" "github.com/seata/seata-go/pkg/saga/statemachine/constant" "github.com/seata/seata-go/pkg/saga/statemachine/engine" - "github.com/seata/seata-go/pkg/saga/statemachine/engine/config" + engExc "github.com/seata/seata-go/pkg/saga/statemachine/engine/exception" + "github.com/seata/seata-go/pkg/saga/statemachine/engine/expr" + "github.com/seata/seata-go/pkg/saga/statemachine/engine/invoker" "github.com/seata/seata-go/pkg/saga/statemachine/engine/pcext" + "github.com/seata/seata-go/pkg/saga/statemachine/engine/repo" + "github.com/seata/seata-go/pkg/saga/statemachine/engine/sequence" + "github.com/seata/seata-go/pkg/saga/statemachine/engine/strategy" "github.com/seata/seata-go/pkg/saga/statemachine/engine/utils" "github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl" "github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl/process" "github.com/seata/seata-go/pkg/saga/statemachine/statelang" - "github.com/stretchr/testify/assert" - "testing" - "time" + stateimpl "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" + storepkg "github.com/seata/seata-go/pkg/saga/statemachine/store" + "github.com/seata/seata-go/pkg/tm" + seataErrors "github.com/seata/seata-go/pkg/util/errors" ) func mockProcessContext(stateMachineName string, stateMachineInstance statelang.StateMachineInstance) process_ctrl.ProcessContext { @@ -49,6 +66,8 @@ func mockMachineInstance(stateMachineName string) statelang.StateMachineInstance stateMachine.SetName(stateMachineName) stateMachine.SetComment("This is a test state machine") stateMachine.SetCreateTime(time.Now()) + stateMachine.SetID(stateMachineName) + stateMachine.SetTenantId("000001") inst := statelang.NewStateMachineInstanceImpl() inst.SetStateMachine(stateMachine) @@ -61,14 +80,263 @@ func mockMachineInstance(stateMachineName string) statelang.StateMachineInstance return inst } -func mockStateMachineConfig(context process_ctrl.ProcessContext) engine.StateMachineConfig { - cfg := config.NewDefaultStateMachineConfig() +func mockStateMachineConfig(context process_ctrl.ProcessContext) *stubConfig { + cfg := newStubConfig() context.SetVariable(constant.VarNameStateMachineConfig, cfg) return cfg } +var onceRegisterRM sync.Once +var globalFakeRM = &fakeResourceManager{} + +type stubConfig struct { + transTimeout int + sagaBranchRegisterEnable bool + rmReportSuccessEnable bool + seqGenerator sequence.SeqGenerator + componentLock *sync.Mutex + repository *stubStateMachineRepository +} + +func newStubConfig() *stubConfig { + onceRegisterRM.Do(func() { + rmpkg.GetRmCacheInstance().RegisterResourceManager(globalFakeRM) + }) + + return &stubConfig{ + transTimeout: 60000, + sagaBranchRegisterEnable: true, + rmReportSuccessEnable: true, + seqGenerator: sequence.NewUUIDSeqGenerator(), + componentLock: &sync.Mutex{}, + repository: newStubStateMachineRepository(), + } +} + +type stubStateMachineRepository struct { + byNameTenant map[string]statelang.StateMachine + byID map[string]statelang.StateMachine +} + +func newStubStateMachineRepository() *stubStateMachineRepository { + return &stubStateMachineRepository{ + byNameTenant: make(map[string]statelang.StateMachine), + byID: make(map[string]statelang.StateMachine), + } +} + +func (s *stubStateMachineRepository) key(name, tenant string) string { + return name + "_" + tenant +} + +func (s *stubStateMachineRepository) GetStateMachineById(stateMachineId string) (statelang.StateMachine, error) { + return s.byID[stateMachineId], nil +} + +func (s *stubStateMachineRepository) GetStateMachineByNameAndTenantId(stateMachineName string, tenantId string) (statelang.StateMachine, error) { + return s.byNameTenant[s.key(stateMachineName, tenantId)], nil +} + +func (s *stubStateMachineRepository) GetLastVersionStateMachine(stateMachineName string, tenantId string) (statelang.StateMachine, error) { + return s.byNameTenant[s.key(stateMachineName, tenantId)], nil +} + +func (s *stubStateMachineRepository) RegistryStateMachine(machine statelang.StateMachine) error { + key := s.key(machine.Name(), machine.TenantId()) + s.byNameTenant[key] = machine + if machine.ID() != "" { + s.byID[machine.ID()] = machine + } + return nil +} + +func (s *stubStateMachineRepository) RegistryStateMachineByReader(reader io.Reader) error { + return nil +} + +type fakeResourceManager struct { + branchRegisterCalls int + branchReportCalls int + branchRegisterErr error + branchReportErr error + nextBranchID int64 + mu sync.Mutex +} + +func (f *fakeResourceManager) BranchRegister(ctx context.Context, param rmpkg.BranchRegisterParam) (int64, error) { + f.mu.Lock() + defer f.mu.Unlock() + f.branchRegisterCalls++ + if f.nextBranchID == 0 { + f.nextBranchID = 101 + } + return f.nextBranchID, f.branchRegisterErr +} + +func (f *fakeResourceManager) BranchReport(ctx context.Context, param rmpkg.BranchReportParam) error { + f.mu.Lock() + defer f.mu.Unlock() + f.branchReportCalls++ + return f.branchReportErr +} + +func (f *fakeResourceManager) LockQuery(ctx context.Context, param rmpkg.LockQueryParam) (bool, error) { + return false, nil +} + +func (f *fakeResourceManager) BranchCommit(ctx context.Context, resource rmpkg.BranchResource) (branch.BranchStatus, error) { + return branch.BranchStatusPhasetwoCommitted, nil +} + +func (f *fakeResourceManager) BranchRollback(ctx context.Context, resource rmpkg.BranchResource) (branch.BranchStatus, error) { + return branch.BranchStatusPhasetwoRollbacked, nil +} + +func (f *fakeResourceManager) RegisterResource(resource rmpkg.Resource) error { + return nil +} + +func (f *fakeResourceManager) UnregisterResource(resource rmpkg.Resource) error { + return nil +} + +func (f *fakeResourceManager) GetCachedResources() *sync.Map { + return &sync.Map{} +} + +func (f *fakeResourceManager) GetBranchType() branch.BranchType { + return branch.BranchTypeSAGA +} + +func (s *stubConfig) StateLogRepository() repo.StateLogRepository { return nil } + +func (s *stubConfig) StateMachineRepository() repo.StateMachineRepository { return s.repository } + +func (s *stubConfig) StateLogStore() storepkg.StateLogStore { return nil } + +func (s *stubConfig) StateLangStore() storepkg.StateLangStore { return nil } + +func (s *stubConfig) ExpressionFactoryManager() *expr.ExpressionFactoryManager { return nil } + +func (s *stubConfig) ExpressionResolver() expr.ExpressionResolver { return nil } + +func (s *stubConfig) SeqGenerator() sequence.SeqGenerator { return s.seqGenerator } + +func (s *stubConfig) StatusDecisionStrategy() engine.StatusDecisionStrategy { + return strategy.NewDefaultStatusDecisionStrategy() +} + +func (s *stubConfig) EventPublisher() process_ctrl.EventPublisher { return nil } + +func (s *stubConfig) AsyncEventPublisher() process_ctrl.EventPublisher { return nil } + +func (s *stubConfig) EnableAsync() bool { return false } + +func (s *stubConfig) ServiceInvokerManager() invoker.ServiceInvokerManager { return nil } + +func (s *stubConfig) ScriptInvokerManager() invoker.ScriptInvokerManager { return nil } + +func (s *stubConfig) CharSet() string { return "UTF-8" } + +func (s *stubConfig) GetDefaultTenantId() string { return "000001" } + +func (s *stubConfig) GetTransOperationTimeout() int { return s.transTimeout } + +func (s *stubConfig) GetServiceInvokeTimeout() int { return 60000 } + +func (s *stubConfig) ComponentLock() *sync.Mutex { return s.componentLock } + +func (s *stubConfig) RegisterStateMachineDef(resources []string) error { return nil } + +func (s *stubConfig) RegisterExpressionFactory(expressionType string, factory expr.ExpressionFactory) { +} + +func (s *stubConfig) RegisterServiceInvoker(serviceType string, invoker invoker.ServiceInvoker) {} + +func (s *stubConfig) GetExpressionFactory(expressionType string) expr.ExpressionFactory { return nil } + +func (s *stubConfig) GetServiceInvoker(serviceType string) (invoker.ServiceInvoker, error) { + return nil, errors.New("not implemented") +} + +func (s *stubConfig) IsSagaBranchRegisterEnable() bool { return s.sagaBranchRegisterEnable } + +func (s *stubConfig) IsRmReportSuccessEnable() bool { return s.rmReportSuccessEnable } + +func (s *stubConfig) SetSagaBranchRegisterEnable(enable bool) { s.sagaBranchRegisterEnable = enable } + +func (s *stubConfig) SetRmReportSuccessEnable(enable bool) { s.rmReportSuccessEnable = enable } + +type fakeSagaTemplate struct { + branchRegisterCalls int + branchReportCalls int + nextBranchID int64 +} + +func (f *fakeSagaTemplate) CommitTransaction(ctx context.Context, gtr *tm.GlobalTransaction) error { + return nil +} + +func (f *fakeSagaTemplate) RollbackTransaction(ctx context.Context, gtr *tm.GlobalTransaction) error { + return nil +} + +func (f *fakeSagaTemplate) BeginTransaction(ctx context.Context, timeout time.Duration, txName string) (*tm.GlobalTransaction, error) { + return &tm.GlobalTransaction{}, nil +} + +func (f *fakeSagaTemplate) ReloadTransaction(ctx context.Context, xid string) (*tm.GlobalTransaction, error) { + return &tm.GlobalTransaction{Xid: xid}, nil +} + +func (f *fakeSagaTemplate) ReportTransaction(ctx context.Context, gtr *tm.GlobalTransaction) error { + return nil +} + +func (f *fakeSagaTemplate) BranchRegister(ctx context.Context, resourceId string, clientId string, xid string, applicationData string, lockKeys string) (int64, error) { + f.branchRegisterCalls++ + if f.nextBranchID == 0 { + f.nextBranchID = 101 + } + return f.nextBranchID, nil +} + +func (f *fakeSagaTemplate) BranchReport(ctx context.Context, xid string, branchId int64, status branch.BranchStatus, applicationData string) error { + f.branchReportCalls++ + return nil +} + +func (f *fakeSagaTemplate) CleanUp(ctx context.Context) {} + +func newServiceTaskState(machine statelang.StateMachineInstance) statelang.StateInstance { + state := statelang.NewStateInstanceImpl() + state.SetStateMachineInstance(machine) + state.SetMachineInstanceID(machine.ID()) + state.SetName("ServiceTask1") + state.SetType(constant.StateTypeServiceTask) + state.SetServiceName("DemoService") + state.SetServiceMethod("foo") + state.SetServiceType("LOCAL") + state.SetForUpdate(false) + state.SetStartedTime(time.Now()) + state.SetStatus(statelang.RU) + return state +} + +func attachServiceTaskDefinition(machine statelang.StateMachineInstance, cfg *stubConfig, name string) { + sm := machine.StateMachine() + task := stateimpl.NewServiceTaskStateImpl() + task.SetName(name) + task.SetServiceName("DemoService") + task.SetServiceMethod("foo") + task.SetServiceType("LOCAL") + sm.States()[name] = task + sm.SetStartState(name) + _ = cfg.repository.RegistryStateMachine(sm) +} + func TestStateLogStore_RecordStateMachineStarted(t *testing.T) { - prepareDB() + prepareCleanDB(t) const stateMachineName = "stateMachine" stateLogStore := NewStateLogStore(db, "seata_") @@ -90,14 +358,26 @@ func TestStateLogStore_RecordStateMachineStarted(t *testing.T) { assert.Equal(t, expected.UpdatedTime().UnixNano(), actual.UpdatedTime().UnixNano()) } -func TestStateLogStore_RecordStateMachineFinished(t *testing.T) { +func prepareCleanDB(t *testing.T) { prepareDB() + if db == nil { + return + } + _, err := db.Exec("DELETE FROM seata_state_inst") + require.NoError(t, err) + _, err = db.Exec("DELETE FROM seata_state_machine_inst") + require.NoError(t, err) +} + +func TestStateLogStore_RecordStateMachineFinished(t *testing.T) { + prepareCleanDB(t) const stateMachineName = "stateMachine" stateLogStore := NewStateLogStore(db, "seata_") expected := mockMachineInstance(stateMachineName) expected.SetBusinessKey("test_finished") ctx := mockProcessContext(stateMachineName, expected) + mockStateMachineConfig(ctx) err := stateLogStore.RecordStateMachineStarted(context.Background(), expected, ctx) assert.Nil(t, err) expected.SetEndParams(map[string]any{"end": 100}) @@ -118,18 +398,20 @@ func TestStateLogStore_RecordStateMachineFinished(t *testing.T) { assert.Equal(t, expected.Status(), actual.Status()) assert.Equal(t, expected.IsRunning(), actual.IsRunning()) assert.Equal(t, expected.StartedTime().UnixNano(), actual.StartedTime().UnixNano()) - assert.Greater(t, actual.UpdatedTime().UnixNano(), expected.UpdatedTime().UnixNano()) + assert.False(t, actual.UpdatedTime().IsZero()) + assert.True(t, actual.UpdatedTime().UnixNano() >= actual.StartedTime().UnixNano()) assert.False(t, expected.EndTime().IsZero()) } func TestStateLogStore_RecordStateMachineRestarted(t *testing.T) { - prepareDB() + prepareCleanDB(t) const stateMachineName = "stateMachine" stateLogStore := NewStateLogStore(db, "seata_") expected := mockMachineInstance(stateMachineName) expected.SetBusinessKey("test_restarted") ctx := mockProcessContext(stateMachineName, expected) + mockStateMachineConfig(ctx) err := stateLogStore.RecordStateMachineStarted(context.Background(), expected, ctx) assert.Nil(t, err) expected.SetRunning(false) @@ -148,13 +430,17 @@ func TestStateLogStore_RecordStateMachineRestarted(t *testing.T) { } func TestStateLogStore_RecordStateStarted(t *testing.T) { - prepareDB() + prepareCleanDB(t) const stateMachineName = "stateMachine" stateLogStore := NewStateLogStore(db, "seata_") + stateLogStore.sagaTransactionalTemplate = &fakeSagaTemplate{} machineInstance := mockMachineInstance("stateMachine") ctx := mockProcessContext(stateMachineName, machineInstance) + _ = mockStateMachineConfig(ctx) + cfg := mockStateMachineConfig(ctx) + attachServiceTaskDefinition(machineInstance, cfg, "ServiceTask1") machineInstance.SetID("test") common := statelang.NewStateInstanceImpl() @@ -221,38 +507,266 @@ func TestStateLogStore_RecordStateStarted(t *testing.T) { } } +func TestRecordStateStartedSkipBranchRegister(t *testing.T) { + prepareCleanDB(t) + + store := NewStateLogStore(db, "seata_") + machine := mockMachineInstance("stateMachine") + ctx := mockProcessContext("stateMachine", machine) + cfg := mockStateMachineConfig(ctx) + cfg.SetSagaBranchRegisterEnable(false) + fakeTemplate := &fakeSagaTemplate{nextBranchID: 555} + store.sagaTransactionalTemplate = fakeTemplate + baseRegister := globalFakeRM.branchRegisterCalls + + attachServiceTaskDefinition(machine, cfg, "ServiceTask1") + state := newServiceTaskState(machine) + require.NoError(t, store.RecordStateStarted(context.Background(), state, ctx)) + require.Equal(t, baseRegister, globalFakeRM.branchRegisterCalls) + require.NotEmpty(t, state.ID()) +} + +func TestRecordStateStartedTriggersBranchRegisterWhenEnabled(t *testing.T) { + prepareCleanDB(t) + + store := NewStateLogStore(db, "seata_") + machine := mockMachineInstance("stateMachine") + machine.SetID("machine-branch-enable") + ctx := mockProcessContext("stateMachine", machine) + cfg := mockStateMachineConfig(ctx) + cfg.SetSagaBranchRegisterEnable(true) + fakeTemplate := &fakeSagaTemplate{nextBranchID: 777} + store.sagaTransactionalTemplate = fakeTemplate + baseRegister := globalFakeRM.branchRegisterCalls + + attachServiceTaskDefinition(machine, cfg, "ServiceTask1") + state := newServiceTaskState(machine) + require.NoError(t, store.RecordStateStarted(context.Background(), state, ctx)) + require.Equal(t, baseRegister+1, globalFakeRM.branchRegisterCalls) + require.NotEmpty(t, state.ID()) + _, parseErr := strconv.ParseInt(state.ID(), 10, 64) + require.NoError(t, parseErr) +} + +func TestRecordStateStartedBranchRegisterError(t *testing.T) { + prepareCleanDB(t) + + store := NewStateLogStore(db, "seata_") + machine := mockMachineInstance("stateMachine") + machine.SetID("machine-branch-error") + ctx := mockProcessContext("stateMachine", machine) + cfg := mockStateMachineConfig(ctx) + cfg.SetSagaBranchRegisterEnable(true) + fakeTemplate := &fakeSagaTemplate{nextBranchID: 888} + store.sagaTransactionalTemplate = fakeTemplate + globalFakeRM.mu.Lock() + globalFakeRM.branchRegisterErr = errors.New("mock branch register error") + globalFakeRM.nextBranchID = 0 + globalFakeRM.mu.Unlock() + t.Cleanup(func() { + globalFakeRM.mu.Lock() + globalFakeRM.branchRegisterErr = nil + globalFakeRM.nextBranchID = 0 + globalFakeRM.mu.Unlock() + }) + + attachServiceTaskDefinition(machine, cfg, "ServiceTask1") + state := newServiceTaskState(machine) + err := store.RecordStateStarted(context.Background(), state, ctx) + require.Error(t, err) + engErr, ok := engExc.IsEngineExecutionException(err) + require.True(t, ok) + require.Equal(t, seataErrors.TransactionErrorCodeBranchRegisterFailed, engErr.Code) +} + +func TestRecordStateStartedDerivesCompensationIdFromHolder(t *testing.T) { + prepareCleanDB(t) + + store := NewStateLogStore(db, "seata_") + machine := mockMachineInstance("stateMachine") + machine.SetID("machine-comp-holder") + ctx := mockProcessContext("stateMachine", machine) + cfg := mockStateMachineConfig(ctx) + cfg.SetSagaBranchRegisterEnable(true) + fakeTemplate := &fakeSagaTemplate{nextBranchID: 666} + store.sagaTransactionalTemplate = fakeTemplate + baseRegister := globalFakeRM.branchRegisterCalls + + original := newServiceTaskState(machine) + original.SetName("ReduceInventory") + original.SetID("origin-branch") + machine.PutState(original.ID(), original) + + holder := pcext.NewCompensationHolder() + holder.AddToBeCompensatedState("CompensateReduceInventory", original) + ctx.SetVariable(constant.VarNameCurrentCompensationHolder, holder) + + attachServiceTaskDefinition(machine, cfg, "CompensateReduceInventory") + compensate := newServiceTaskState(machine) + compensate.SetName("CompensateReduceInventory") + compensate.SetServiceName("inventoryAction") + compensate.SetServiceMethod("CompensateReduce") + compensate.SetServiceType("LOCAL") + compensate.SetStartedTime(time.Now()) + + require.NoError(t, store.RecordStateStarted(context.Background(), compensate, ctx)) + require.Equal(t, baseRegister, globalFakeRM.branchRegisterCalls) + require.Equal(t, "origin-branch-1", compensate.ID()) + require.Equal(t, "origin-branch", compensate.StateIDCompensatedFor()) + + stored, err := store.GetStateInstance(compensate.ID(), machine.ID()) + require.NoError(t, err) + require.Equal(t, compensate.ID(), stored.ID()) + require.Equal(t, "origin-branch", stored.StateIDCompensatedFor()) +} + +func TestRecordStateFinishedSkipBranchReportOnSuccessWhenDisabled(t *testing.T) { + prepareCleanDB(t) + + store := NewStateLogStore(db, "seata_") + machine := mockMachineInstance("stateMachine") + machine.SetID("machine-report-disabled") + ctx := mockProcessContext("stateMachine", machine) + cfg := mockStateMachineConfig(ctx) + cfg.SetSagaBranchRegisterEnable(true) + cfg.SetRmReportSuccessEnable(false) + fakeTemplate := &fakeSagaTemplate{nextBranchID: 888} + store.sagaTransactionalTemplate = fakeTemplate + baseRegister := globalFakeRM.branchRegisterCalls + baseReport := globalFakeRM.branchReportCalls + + attachServiceTaskDefinition(machine, cfg, "ServiceTask1") + state := newServiceTaskState(machine) + require.NoError(t, store.RecordStateStarted(context.Background(), state, ctx)) + + state.SetStatus(statelang.SU) + require.NoError(t, store.RecordStateFinished(context.Background(), state, ctx)) + require.Equal(t, baseRegister+1, globalFakeRM.branchRegisterCalls) + require.Equal(t, baseReport, globalFakeRM.branchReportCalls) +} + +func TestRecordStateFinishedReportsWhenEnabled(t *testing.T) { + prepareCleanDB(t) + + store := NewStateLogStore(db, "seata_") + machine := mockMachineInstance("stateMachine") + machine.SetID("machine-report-enabled") + ctx := mockProcessContext("stateMachine", machine) + cfg := mockStateMachineConfig(ctx) + cfg.SetSagaBranchRegisterEnable(true) + cfg.SetRmReportSuccessEnable(true) + fakeTemplate := &fakeSagaTemplate{nextBranchID: 999} + store.sagaTransactionalTemplate = fakeTemplate + baseRegister := globalFakeRM.branchRegisterCalls + baseReport := globalFakeRM.branchReportCalls + + attachServiceTaskDefinition(machine, cfg, "ServiceTask1") + state := newServiceTaskState(machine) + require.NoError(t, store.RecordStateStarted(context.Background(), state, ctx)) + + state.SetStatus(statelang.SU) + require.NoError(t, store.RecordStateFinished(context.Background(), state, ctx)) + require.Equal(t, baseRegister+1, globalFakeRM.branchRegisterCalls) + require.Equal(t, baseReport+1, globalFakeRM.branchReportCalls) +} + +func TestRecordStateFinishedPropagatesBranchReportError(t *testing.T) { + prepareCleanDB(t) + + store := NewStateLogStore(db, "seata_") + machine := mockMachineInstance("stateMachine") + machine.SetID("machine-report-error") + ctx := mockProcessContext("stateMachine", machine) + cfg := mockStateMachineConfig(ctx) + cfg.SetSagaBranchRegisterEnable(true) + cfg.SetRmReportSuccessEnable(true) + fakeTemplate := &fakeSagaTemplate{nextBranchID: 1001} + store.sagaTransactionalTemplate = fakeTemplate + baseRegister := globalFakeRM.branchRegisterCalls + baseReport := globalFakeRM.branchReportCalls + + attachServiceTaskDefinition(machine, cfg, "ServiceTask1") + state := newServiceTaskState(machine) + require.NoError(t, store.RecordStateStarted(context.Background(), state, ctx)) + + state.SetStatus(statelang.SU) + globalFakeRM.mu.Lock() + globalFakeRM.branchReportErr = errors.New("mock branch report error") + globalFakeRM.mu.Unlock() + t.Cleanup(func() { + globalFakeRM.mu.Lock() + globalFakeRM.branchReportErr = nil + globalFakeRM.mu.Unlock() + }) + + err := store.RecordStateFinished(context.Background(), state, ctx) + require.Error(t, err) + engErr, ok := engExc.IsEngineExecutionException(err) + require.True(t, ok) + require.Equal(t, seataErrors.TransactionErrorCodeBranchReportFailed, engErr.Code) + require.Equal(t, baseRegister+1, globalFakeRM.branchRegisterCalls) + require.Equal(t, baseReport+1, globalFakeRM.branchReportCalls) +} + +func TestRecordStateFinishedWithoutStartDoesNotInsertFallback(t *testing.T) { + prepareCleanDB(t) + + store := NewStateLogStore(db, "seata_") + machine := mockMachineInstance("stateMachine") + machine.SetID("machine-finish-only") + ctx := mockProcessContext("stateMachine", machine) + cfg := mockStateMachineConfig(ctx) + cfg.SetRmReportSuccessEnable(false) + + state := newServiceTaskState(machine) + state.SetName("ReduceInventory") + state.SetID("missing-start") + state.SetStatus(statelang.SU) + state.SetEndTime(time.Now()) + state.SetUpdatedTime(time.Now()) + + err := store.RecordStateFinished(context.Background(), state, ctx) + require.Error(t, err) + engErr, ok := engExc.IsEngineExecutionException(err) + require.True(t, ok) + require.Equal(t, seataErrors.TransactionErrorCodeFailedWriteSession, engErr.Code) + _, getErr := store.GetStateInstance(state.ID(), machine.ID()) + require.Error(t, getErr) +} + func TestStateLogStore_RecordStateFinished(t *testing.T) { - prepareDB() + prepareCleanDB(t) const stateMachineName = "stateMachine" stateLogStore := NewStateLogStore(db, "seata_") + stateLogStore.sagaTransactionalTemplate = &fakeSagaTemplate{} machineInstance := mockMachineInstance("stateMachine") ctx := mockProcessContext(stateMachineName, machineInstance) machineInstance.SetID("test") + cfg := mockStateMachineConfig(ctx) + attachServiceTaskDefinition(machineInstance, cfg, "ServiceTask1") - expected := statelang.NewStateInstanceImpl() - expected.SetStateMachineInstance(machineInstance) - expected.SetMachineInstanceID(machineInstance.ID()) + state := newServiceTaskState(machineInstance) - err := stateLogStore.RecordStateStarted(context.Background(), expected, ctx) + err := stateLogStore.RecordStateStarted(context.Background(), state, ctx) assert.Nil(t, err) - expected.SetStatus(statelang.UN) - expected.SetError(errors.New("this is a test error")) - expected.SetOutputParams(map[string]string{"output": "test"}) - err = stateLogStore.RecordStateFinished(context.Background(), expected, ctx) + state.SetStatus(statelang.UN) + state.SetError(errors.New("this is a test error")) + state.SetOutputParams(map[string]string{"output": "test"}) + err = stateLogStore.RecordStateFinished(context.Background(), state, ctx) assert.Nil(t, err) - actual, err := stateLogStore.GetStateInstance(expected.ID(), machineInstance.ID()) + actual, err := stateLogStore.GetStateInstance(state.ID(), machineInstance.ID()) assert.Nil(t, err) - assert.Equal(t, expected.Status(), actual.Status()) - assert.Equal(t, expected.Error().Error(), actual.Error().Error()) + assert.Equal(t, state.Status(), actual.Status()) + assert.Equal(t, state.Error().Error(), actual.Error().Error()) assert.NotEmpty(t, actual.OutputParams()) - assert.Equal(t, expected.SerializedOutputParams(), actual.SerializedOutputParams()) + assert.Equal(t, state.SerializedOutputParams(), actual.SerializedOutputParams()) } func TestStateLogStore_GetStateMachineInstanceByBusinessKey(t *testing.T) { - prepareDB() + prepareCleanDB(t) const stateMachineName = "stateMachine" stateLogStore := NewStateLogStore(db, "seata_") @@ -276,7 +790,7 @@ func TestStateLogStore_GetStateMachineInstanceByBusinessKey(t *testing.T) { } func TestStateLogStore_GetStateMachineInstanceByParentId(t *testing.T) { - prepareDB() + prepareCleanDB(t) const ( stateMachineName = "stateMachine" diff --git a/pkg/saga/statemachine/store/store.go b/pkg/saga/statemachine/store/store.go index 33ee8e549..59c759d7c 100644 --- a/pkg/saga/statemachine/store/store.go +++ b/pkg/saga/statemachine/store/store.go @@ -1,7 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package store import ( "context" + "github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl" "github.com/seata/seata-go/pkg/saga/statemachine/statelang" ) diff --git a/pkg/saga/tm/default_saga_transactional_template.go b/pkg/saga/tm/default_saga_transactional_template.go index dbf30f1ce..3f3864ba3 100644 --- a/pkg/saga/tm/default_saga_transactional_template.go +++ b/pkg/saga/tm/default_saga_transactional_template.go @@ -19,13 +19,14 @@ package tm import ( "context" + "time" + "github.com/seata/seata-go/pkg/protocol/branch" "github.com/seata/seata-go/pkg/protocol/message" "github.com/seata/seata-go/pkg/rm" sagarm "github.com/seata/seata-go/pkg/saga/rm" "github.com/seata/seata-go/pkg/tm" "github.com/seata/seata-go/pkg/util/log" - "time" ) type DefaultSagaTransactionalTemplate struct { diff --git a/pkg/tm/transaction_hook_manager.go b/pkg/tm/transaction_hook_manager.go index 9d5b8261c..d8d48f5c1 100644 --- a/pkg/tm/transaction_hook_manager.go +++ b/pkg/tm/transaction_hook_manager.go @@ -18,8 +18,9 @@ package tm import ( - "github.com/pkg/errors" "sync" + + "github.com/pkg/errors" ) var ( diff --git a/pkg/util/errors/code.go b/pkg/util/errors/code.go index 5d265df73..e0ff11165 100644 --- a/pkg/util/errors/code.go +++ b/pkg/util/errors/code.go @@ -113,4 +113,6 @@ const ( OperationDenied // ForwardInvalid Forward invalid ForwardInvalid + // AsynchronousStartDisabled Async start is disabled in configuration + AsynchronousStartDisabled ) diff --git a/pkg/util/reflectx/unmarkshaler.go b/pkg/util/reflectx/unmarkshaler.go index a507a3a89..a0d3387cd 100644 --- a/pkg/util/reflectx/unmarkshaler.go +++ b/pkg/util/reflectx/unmarkshaler.go @@ -19,9 +19,10 @@ package reflectx import ( "fmt" - "github.com/pkg/errors" "reflect" "unicode" + + "github.com/pkg/errors" ) // MapToStruct some state can use this util to parse diff --git a/testdata/saga/engine/invoker/grpc/product.pb.go b/testdata/saga/engine/invoker/grpc/product.pb.go index 663af14dc..4c820856f 100644 --- a/testdata/saga/engine/invoker/grpc/product.pb.go +++ b/testdata/saga/engine/invoker/grpc/product.pb.go @@ -7,10 +7,11 @@ package product import ( - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" sync "sync" + + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" ) const ( diff --git a/testdata/saga/engine/invoker/grpc/product_grpc.pb.go b/testdata/saga/engine/invoker/grpc/product_grpc.pb.go index 9ea290fc9..bb2611789 100644 --- a/testdata/saga/engine/invoker/grpc/product_grpc.pb.go +++ b/testdata/saga/engine/invoker/grpc/product_grpc.pb.go @@ -8,6 +8,7 @@ package product import ( context "context" + grpc "google.golang.org/grpc" codes "google.golang.org/grpc/codes" status "google.golang.org/grpc/status"