diff --git a/go.mod b/go.mod index 1f88639c85..d180427271 100644 --- a/go.mod +++ b/go.mod @@ -18,12 +18,21 @@ require ( github.com/stretchr/testify v1.11.1 github.com/tencent-connect/botgo v0.2.1 golang.org/x/oauth2 v0.35.0 + modernc.org/sqlite v1.46.1 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + modernc.org/libc v1.67.6 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect ) require ( diff --git a/go.sum b/go.sum index 0e95bf5cd2..852442a1a4 100644 --- a/go.sum +++ b/go.sum @@ -30,6 +30,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/github/copilot-sdk/go v0.1.23 h1:uExtO/inZQndCZMiSAA1hvXINiz9tqo/MZgQzFzurxw= @@ -62,6 +64,8 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8= github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -71,6 +75,8 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grbit/go-json v0.11.0 h1:bAbyMdYrYl/OjYsSqLH99N2DyQ291mHy726Mx+sYrnc= github.com/grbit/go-json v0.11.0/go.mod h1:IYpHsdybQ386+6g3VE6AXQ3uTGa5mquBme5/ZWmtzek= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= @@ -88,8 +94,12 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/larksuite/oapi-sdk-go/v3 v3.5.3 h1:xvf8Dv29kBXC5/DNDCLhHkAFW8l/0LlQJimO5Zn+JUk= github.com/larksuite/oapi-sdk-go/v3 v3.5.3/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mymmrac/telego v1.6.0 h1:Zc8rgyHozvd/7ZgyrigyHdAF9koHYMfilYfyB6wlFC0= github.com/mymmrac/telego v1.6.0/go.mod h1:xt6ZWA8zi8KmuzryE1ImEdl9JSwjHNpM4yhC7D8hU4Y= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= @@ -105,6 +115,8 @@ github.com/openai/openai-go/v3 v3.22.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbD github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= @@ -161,10 +173,14 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA= +golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -208,6 +224,7 @@ golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= @@ -233,6 +250,8 @@ golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ= +golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -259,3 +278,31 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= +modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc= +modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM= +modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA= +modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE= +modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= +modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI= +modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.46.1 h1:eFJ2ShBLIEnUWlLy12raN0Z1plqmFX9Qe3rjQTKt6sU= +modernc.org/sqlite v1.46.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/pkg/memory/migration.go b/pkg/memory/migration.go new file mode 100644 index 0000000000..abb2d7af3d --- /dev/null +++ b/pkg/memory/migration.go @@ -0,0 +1,137 @@ +package memory + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +// jsonSession mirrors the JSON structure used by pkg/session for deserialization. +type jsonSession struct { + Key string `json:"key"` + Messages []providers.Message `json:"messages"` + Summary string `json:"summary,omitempty"` + Created time.Time `json:"created"` + Updated time.Time `json:"updated"` +} + +// MigrateFromJSON reads JSON session files from sessionsDir and imports them +// into the given SQLiteStore. Successfully migrated files are renamed to +// .json.migrated (not deleted) as a backup. +// +// Returns the number of sessions successfully migrated. +// Files that are already .migrated or fail to parse are skipped (partial failures +// do not stop the migration of other files). +func MigrateFromJSON(ctx context.Context, sessionsDir string, store *SQLiteStore) (int, error) { + entries, err := os.ReadDir(sessionsDir) + if err != nil { + if os.IsNotExist(err) { + return 0, nil + } + return 0, fmt.Errorf("memory: read sessions dir: %w", err) + } + + migrated := 0 + for _, entry := range entries { + if entry.IsDir() { + continue + } + name := entry.Name() + if !strings.HasSuffix(name, ".json") { + continue + } + // Skip already-migrated files. + if strings.HasSuffix(name, ".migrated") { + continue + } + + filePath := filepath.Join(sessionsDir, name) + data, err := os.ReadFile(filePath) + if err != nil { + continue // skip unreadable files + } + + var sess jsonSession + if err := json.Unmarshal(data, &sess); err != nil { + continue // skip invalid JSON + } + + if sess.Key == "" { + continue // skip sessions with no key + } + + if err := importSession(ctx, store, &sess); err != nil { + continue // skip on import error + } + + // Rename to .json.migrated as backup. + _ = os.Rename(filePath, filePath+".migrated") + migrated++ + } + + return migrated, nil +} + +// importSession imports a single JSON session into the SQLite store. +func importSession(ctx context.Context, store *SQLiteStore, sess *jsonSession) error { + tx, err := store.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer tx.Rollback() + + createdAt := sess.Created.UTC().Format(time.RFC3339) + updatedAt := sess.Updated.UTC().Format(time.RFC3339) + + // INSERT OR IGNORE to handle partial migration retries. + _, err = tx.ExecContext(ctx, + `INSERT OR IGNORE INTO sessions (key, summary, created_at, updated_at) VALUES (?, ?, ?, ?)`, + sess.Key, sess.Summary, createdAt, updatedAt, + ) + if err != nil { + return err + } + + // Check if messages already exist (idempotent — skip if already imported). + var count int + err = tx.QueryRowContext(ctx, + `SELECT COUNT(*) FROM messages WHERE session_key = ?`, sess.Key, + ).Scan(&count) + if err != nil { + return err + } + if count > 0 { + // Already imported, nothing to do. + return tx.Commit() + } + + now := time.Now().UTC().Format(time.RFC3339) + for i, msg := range sess.Messages { + var toolCallsJSON *string + if len(msg.ToolCalls) > 0 { + data, err := json.Marshal(msg.ToolCalls) + if err != nil { + return err + } + s := string(data) + toolCallsJSON = &s + } + + _, err := tx.ExecContext(ctx, + `INSERT INTO messages (session_key, seq, role, content, tool_calls_json, tool_call_id, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?)`, + sess.Key, i+1, msg.Role, msg.Content, toolCallsJSON, msg.ToolCallID, now, + ) + if err != nil { + return err + } + } + + return tx.Commit() +} diff --git a/pkg/memory/migration_test.go b/pkg/memory/migration_test.go new file mode 100644 index 0000000000..82b31eef14 --- /dev/null +++ b/pkg/memory/migration_test.go @@ -0,0 +1,303 @@ +package memory + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +func writeJSONSession(t *testing.T, dir string, filename string, sess jsonSession) { + t.Helper() + data, err := json.MarshalIndent(sess, "", " ") + if err != nil { + t.Fatalf("marshal session: %v", err) + } + if err := os.WriteFile(filepath.Join(dir, filename), data, 0o644); err != nil { + t.Fatalf("write session file: %v", err) + } +} + +func TestMigrateFromJSON_Basic(t *testing.T) { + sessionsDir := t.TempDir() + store := openTestDB(t) + ctx := context.Background() + + writeJSONSession(t, sessionsDir, "test.json", jsonSession{ + Key: "test", + Messages: []providers.Message{ + {Role: "user", Content: "hello"}, + {Role: "assistant", Content: "hi"}, + }, + Summary: "A greeting.", + Created: time.Now(), + Updated: time.Now(), + }) + + count, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + if count != 1 { + t.Errorf("expected 1 migrated, got %d", count) + } + + history, err := store.GetHistory(ctx, "test") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 2 { + t.Fatalf("expected 2 messages, got %d", len(history)) + } + if history[0].Content != "hello" || history[1].Content != "hi" { + t.Errorf("unexpected messages: %+v", history) + } + + summary, err := store.GetSummary(ctx, "test") + if err != nil { + t.Fatalf("GetSummary: %v", err) + } + if summary != "A greeting." { + t.Errorf("summary = %q", summary) + } +} + +func TestMigrateFromJSON_WithToolCalls(t *testing.T) { + sessionsDir := t.TempDir() + store := openTestDB(t) + ctx := context.Background() + + writeJSONSession(t, sessionsDir, "tools.json", jsonSession{ + Key: "tools", + Messages: []providers.Message{ + { + Role: "assistant", + Content: "Searching...", + ToolCalls: []providers.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: &providers.FunctionCall{ + Name: "web_search", + Arguments: `{"q":"test"}`, + }, + }, + }, + }, + { + Role: "tool", + Content: "result", + ToolCallID: "call_1", + }, + }, + Created: time.Now(), + Updated: time.Now(), + }) + + count, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + if count != 1 { + t.Errorf("expected 1, got %d", count) + } + + history, err := store.GetHistory(ctx, "tools") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 2 { + t.Fatalf("expected 2 messages, got %d", len(history)) + } + if len(history[0].ToolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(history[0].ToolCalls)) + } + if history[0].ToolCalls[0].Function.Name != "web_search" { + t.Errorf("tool call function = %q", history[0].ToolCalls[0].Function.Name) + } + if history[1].ToolCallID != "call_1" { + t.Errorf("ToolCallID = %q", history[1].ToolCallID) + } +} + +func TestMigrateFromJSON_MultipleFiles(t *testing.T) { + sessionsDir := t.TempDir() + store := openTestDB(t) + ctx := context.Background() + + for i := 0; i < 3; i++ { + key := string(rune('a' + i)) + writeJSONSession(t, sessionsDir, key+".json", jsonSession{ + Key: key, + Messages: []providers.Message{{Role: "user", Content: "msg " + key}}, + Created: time.Now(), + Updated: time.Now(), + }) + } + + count, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + if count != 3 { + t.Errorf("expected 3, got %d", count) + } + + for i := 0; i < 3; i++ { + key := string(rune('a' + i)) + history, err := store.GetHistory(ctx, key) + if err != nil { + t.Fatalf("GetHistory(%q): %v", key, err) + } + if len(history) != 1 { + t.Errorf("session %q: expected 1 msg, got %d", key, len(history)) + } + } +} + +func TestMigrateFromJSON_InvalidJSON(t *testing.T) { + sessionsDir := t.TempDir() + store := openTestDB(t) + ctx := context.Background() + + // Write one valid and one invalid file. + writeJSONSession(t, sessionsDir, "good.json", jsonSession{ + Key: "good", + Messages: []providers.Message{{Role: "user", Content: "ok"}}, + Created: time.Now(), + Updated: time.Now(), + }) + if err := os.WriteFile(filepath.Join(sessionsDir, "bad.json"), []byte("{invalid json"), 0o644); err != nil { + t.Fatalf("write bad file: %v", err) + } + + count, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + if count != 1 { + t.Errorf("expected 1 (bad file skipped), got %d", count) + } + + // Good file should be migrated. + history, err := store.GetHistory(ctx, "good") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Errorf("expected 1 message, got %d", len(history)) + } +} + +func TestMigrateFromJSON_RenamesFiles(t *testing.T) { + sessionsDir := t.TempDir() + store := openTestDB(t) + ctx := context.Background() + + writeJSONSession(t, sessionsDir, "rename.json", jsonSession{ + Key: "rename", + Messages: []providers.Message{{Role: "user", Content: "hi"}}, + Created: time.Now(), + Updated: time.Now(), + }) + + _, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + + // Original .json should not exist. + if _, err := os.Stat(filepath.Join(sessionsDir, "rename.json")); !os.IsNotExist(err) { + t.Error("rename.json should have been renamed") + } + // .json.migrated should exist. + if _, err := os.Stat(filepath.Join(sessionsDir, "rename.json.migrated")); err != nil { + t.Errorf("rename.json.migrated should exist: %v", err) + } +} + +func TestMigrateFromJSON_Idempotent(t *testing.T) { + sessionsDir := t.TempDir() + store := openTestDB(t) + ctx := context.Background() + + writeJSONSession(t, sessionsDir, "idem.json", jsonSession{ + Key: "idem", + Messages: []providers.Message{{Role: "user", Content: "once"}}, + Created: time.Now(), + Updated: time.Now(), + }) + + count1, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("first migration: %v", err) + } + if count1 != 1 { + t.Errorf("first run: expected 1, got %d", count1) + } + + // Second run should find only .migrated files, skip them. + count2, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("second migration: %v", err) + } + if count2 != 0 { + t.Errorf("second run: expected 0, got %d", count2) + } + + // Data should still be intact. + history, err := store.GetHistory(ctx, "idem") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Errorf("expected 1 message, got %d", len(history)) + } +} + +func TestMigrateFromJSON_ColonInKey(t *testing.T) { + sessionsDir := t.TempDir() + store := openTestDB(t) + ctx := context.Background() + + // File is named telegram_123 (sanitized), but the key inside is telegram:123. + writeJSONSession(t, sessionsDir, "telegram_123.json", jsonSession{ + Key: "telegram:123", + Messages: []providers.Message{{Role: "user", Content: "from telegram"}}, + Created: time.Now(), + Updated: time.Now(), + }) + + count, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + if count != 1 { + t.Errorf("expected 1, got %d", count) + } + + // Should be stored under the real key "telegram:123", not the filename. + history, err := store.GetHistory(ctx, "telegram:123") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1 message, got %d", len(history)) + } + if history[0].Content != "from telegram" { + t.Errorf("content = %q", history[0].Content) + } + + // Looking up by sanitized name should find nothing. + history2, err := store.GetHistory(ctx, "telegram_123") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history2) != 0 { + t.Errorf("expected 0 messages for sanitized key, got %d", len(history2)) + } +} diff --git a/pkg/memory/sqlite.go b/pkg/memory/sqlite.go new file mode 100644 index 0000000000..8d382a891c --- /dev/null +++ b/pkg/memory/sqlite.go @@ -0,0 +1,347 @@ +package memory + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "os" + "path/filepath" + "time" + + _ "modernc.org/sqlite" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +// SQLiteStore implements Store backed by a SQLite database. +type SQLiteStore struct { + db *sql.DB +} + +// Open creates or opens a SQLite database at dbPath and returns a ready-to-use SQLiteStore. +func Open(ctx context.Context, dbPath string) (*SQLiteStore, error) { + if err := os.MkdirAll(filepath.Dir(dbPath), 0o755); err != nil { + return nil, fmt.Errorf("memory: create directory: %w", err) + } + + db, err := sql.Open("sqlite", dbPath) + if err != nil { + return nil, fmt.Errorf("memory: open database: %w", err) + } + + // Single connection — serializes all operations for safety. + db.SetMaxOpenConns(1) + + // Apply PRAGMAs for embedded-friendly performance. + pragmas := []string{ + "PRAGMA journal_mode=WAL", + "PRAGMA busy_timeout=5000", + "PRAGMA synchronous=NORMAL", + "PRAGMA foreign_keys=ON", + "PRAGMA cache_size=-512", + } + for _, p := range pragmas { + if _, err := db.ExecContext(ctx, p); err != nil { + db.Close() + return nil, fmt.Errorf("memory: pragma %q: %w", p, err) + } + } + + s := &SQLiteStore{db: db} + if err := s.ensureSchema(ctx); err != nil { + db.Close() + return nil, err + } + return s, nil +} + +func (s *SQLiteStore) ensureSchema(ctx context.Context) error { + const ddl = ` +CREATE TABLE IF NOT EXISTS sessions ( + key TEXT PRIMARY KEY, + summary TEXT NOT NULL DEFAULT '', + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_key TEXT NOT NULL REFERENCES sessions(key) ON DELETE CASCADE, + seq INTEGER NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL DEFAULT '', + tool_calls_json TEXT, + tool_call_id TEXT NOT NULL DEFAULT '', + created_at TEXT NOT NULL, + UNIQUE(session_key, seq) +);` + if _, err := s.db.ExecContext(ctx, ddl); err != nil { + return fmt.Errorf("memory: create schema: %w", err) + } + return nil +} + +// ensureSession inserts a session row if it doesn't already exist. +// Must be called inside a transaction. +func ensureSession(ctx context.Context, tx *sql.Tx, sessionKey string) error { + now := time.Now().UTC().Format(time.RFC3339) + _, err := tx.ExecContext(ctx, + `INSERT OR IGNORE INTO sessions (key, summary, created_at, updated_at) VALUES (?, '', ?, ?)`, + sessionKey, now, now, + ) + return err +} + +// touchSession updates the updated_at timestamp. Must be called inside a transaction. +func touchSession(ctx context.Context, tx *sql.Tx, sessionKey string) error { + now := time.Now().UTC().Format(time.RFC3339) + _, err := tx.ExecContext(ctx, + `UPDATE sessions SET updated_at = ? WHERE key = ?`, + now, sessionKey, + ) + return err +} + +// nextSeq returns the next sequence number for a session. Must be called inside a transaction. +func nextSeq(ctx context.Context, tx *sql.Tx, sessionKey string) (int, error) { + var maxSeq sql.NullInt64 + err := tx.QueryRowContext(ctx, + `SELECT MAX(seq) FROM messages WHERE session_key = ?`, + sessionKey, + ).Scan(&maxSeq) + if err != nil { + return 0, err + } + if maxSeq.Valid { + return int(maxSeq.Int64) + 1, nil + } + return 1, nil +} + +func (s *SQLiteStore) AddMessage(ctx context.Context, sessionKey, role, content string) error { + return s.AddFullMessage(ctx, sessionKey, providers.Message{ + Role: role, + Content: content, + }) +} + +func (s *SQLiteStore) AddFullMessage(ctx context.Context, sessionKey string, msg providers.Message) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("memory: begin tx: %w", err) + } + defer tx.Rollback() + + err = ensureSession(ctx, tx, sessionKey) + if err != nil { + return fmt.Errorf("memory: ensure session: %w", err) + } + + seq, err := nextSeq(ctx, tx, sessionKey) + if err != nil { + return fmt.Errorf("memory: next seq: %w", err) + } + + var toolCallsJSON *string + if len(msg.ToolCalls) > 0 { + var data []byte + data, err = json.Marshal(msg.ToolCalls) + if err != nil { + return fmt.Errorf("memory: marshal tool calls: %w", err) + } + str := string(data) + toolCallsJSON = &str + } + + now := time.Now().UTC().Format(time.RFC3339) + _, err = tx.ExecContext(ctx, + `INSERT INTO messages (session_key, seq, role, content, tool_calls_json, tool_call_id, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?)`, + sessionKey, seq, msg.Role, msg.Content, toolCallsJSON, msg.ToolCallID, now, + ) + if err != nil { + return fmt.Errorf("memory: insert message: %w", err) + } + + err = touchSession(ctx, tx, sessionKey) + if err != nil { + return fmt.Errorf("memory: touch session: %w", err) + } + + return tx.Commit() +} + +func (s *SQLiteStore) GetHistory(ctx context.Context, sessionKey string) ([]providers.Message, error) { + rows, err := s.db.QueryContext(ctx, + `SELECT role, content, tool_calls_json, tool_call_id + FROM messages + WHERE session_key = ? + ORDER BY seq ASC`, + sessionKey, + ) + if err != nil { + return nil, fmt.Errorf("memory: query messages: %w", err) + } + defer rows.Close() + + var messages []providers.Message + for rows.Next() { + var ( + role string + content string + toolCallsJSON sql.NullString + toolCallID string + ) + if err := rows.Scan(&role, &content, &toolCallsJSON, &toolCallID); err != nil { + return nil, fmt.Errorf("memory: scan message: %w", err) + } + + msg := providers.Message{ + Role: role, + Content: content, + ToolCallID: toolCallID, + } + if toolCallsJSON.Valid && toolCallsJSON.String != "" { + if err := json.Unmarshal([]byte(toolCallsJSON.String), &msg.ToolCalls); err != nil { + return nil, fmt.Errorf("memory: unmarshal tool calls: %w", err) + } + } + messages = append(messages, msg) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("memory: rows iteration: %w", err) + } + + if messages == nil { + messages = []providers.Message{} + } + return messages, nil +} + +func (s *SQLiteStore) GetSummary(ctx context.Context, sessionKey string) (string, error) { + var summary string + err := s.db.QueryRowContext(ctx, + `SELECT summary FROM sessions WHERE key = ?`, + sessionKey, + ).Scan(&summary) + if err == sql.ErrNoRows { + return "", nil + } + if err != nil { + return "", fmt.Errorf("memory: get summary: %w", err) + } + return summary, nil +} + +func (s *SQLiteStore) SetSummary(ctx context.Context, sessionKey, summary string) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("memory: begin tx: %w", err) + } + defer tx.Rollback() + + err = ensureSession(ctx, tx, sessionKey) + if err != nil { + return fmt.Errorf("memory: ensure session: %w", err) + } + + now := time.Now().UTC().Format(time.RFC3339) + _, err = tx.ExecContext(ctx, + `UPDATE sessions SET summary = ?, updated_at = ? WHERE key = ?`, + summary, now, sessionKey, + ) + if err != nil { + return fmt.Errorf("memory: set summary: %w", err) + } + + return tx.Commit() +} + +func (s *SQLiteStore) TruncateHistory(ctx context.Context, sessionKey string, keepLast int) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("memory: begin tx: %w", err) + } + defer tx.Rollback() + + if keepLast <= 0 { + _, err = tx.ExecContext(ctx, + `DELETE FROM messages WHERE session_key = ?`, + sessionKey, + ) + } else { + _, err = tx.ExecContext(ctx, + `DELETE FROM messages WHERE session_key = ? AND id NOT IN ( + SELECT id FROM messages WHERE session_key = ? ORDER BY seq DESC LIMIT ? + )`, + sessionKey, sessionKey, keepLast, + ) + } + if err != nil { + return fmt.Errorf("memory: truncate history: %w", err) + } + + err = touchSession(ctx, tx, sessionKey) + if err != nil { + return fmt.Errorf("memory: touch session: %w", err) + } + + return tx.Commit() +} + +func (s *SQLiteStore) SetHistory(ctx context.Context, sessionKey string, history []providers.Message) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("memory: begin tx: %w", err) + } + defer tx.Rollback() + + err = ensureSession(ctx, tx, sessionKey) + if err != nil { + return fmt.Errorf("memory: ensure session: %w", err) + } + + // Delete all existing messages for this session. + _, err = tx.ExecContext(ctx, + `DELETE FROM messages WHERE session_key = ?`, sessionKey, + ) + if err != nil { + return fmt.Errorf("memory: delete old messages: %w", err) + } + + // Insert new messages with sequential seq numbers. + now := time.Now().UTC().Format(time.RFC3339) + for i, msg := range history { + var toolCallsJSON *string + if len(msg.ToolCalls) > 0 { + var data []byte + data, err = json.Marshal(msg.ToolCalls) + if err != nil { + return fmt.Errorf("memory: marshal tool calls: %w", err) + } + str := string(data) + toolCallsJSON = &str + } + + _, err = tx.ExecContext(ctx, + `INSERT INTO messages (session_key, seq, role, content, tool_calls_json, tool_call_id, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?)`, + sessionKey, i+1, msg.Role, msg.Content, toolCallsJSON, msg.ToolCallID, now, + ) + if err != nil { + return fmt.Errorf("memory: insert message %d: %w", i, err) + } + } + + err = touchSession(ctx, tx, sessionKey) + if err != nil { + return fmt.Errorf("memory: touch session: %w", err) + } + + return tx.Commit() +} + +func (s *SQLiteStore) Close() error { + return s.db.Close() +} diff --git a/pkg/memory/sqlite_test.go b/pkg/memory/sqlite_test.go new file mode 100644 index 0000000000..0d350edb67 --- /dev/null +++ b/pkg/memory/sqlite_test.go @@ -0,0 +1,509 @@ +package memory + +import ( + "context" + "path/filepath" + "sync" + "testing" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +func openTestDB(t *testing.T) *SQLiteStore { + t.Helper() + dbPath := filepath.Join(t.TempDir(), "test.db") + store, err := Open(context.Background(), dbPath) + if err != nil { + t.Fatalf("Open(%q): %v", dbPath, err) + } + t.Cleanup(func() { store.Close() }) + return store +} + +func TestOpen_CreatesDatabase(t *testing.T) { + dir := t.TempDir() + dbPath := filepath.Join(dir, "sub", "sessions.db") + store, err := Open(context.Background(), dbPath) + if err != nil { + t.Fatalf("Open: %v", err) + } + defer store.Close() + + // Verify schema tables exist. + for _, table := range []string{"sessions", "messages"} { + var name string + err := store.db.QueryRow( + `SELECT name FROM sqlite_master WHERE type='table' AND name=?`, table, + ).Scan(&name) + if err != nil { + t.Errorf("table %q not found: %v", table, err) + } + } +} + +func TestOpen_ExistingDatabase(t *testing.T) { + dir := t.TempDir() + dbPath := filepath.Join(dir, "persist.db") + ctx := context.Background() + + // Write data, then close. + store, err := Open(ctx, dbPath) + if err != nil { + t.Fatalf("Open: %v", err) + } + err = store.AddMessage(ctx, "s1", "user", "hello") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + store.Close() + + // Re-open and verify persistence. + store2, err := Open(ctx, dbPath) + if err != nil { + t.Fatalf("Open (reopen): %v", err) + } + defer store2.Close() + + history, err := store2.GetHistory(ctx, "s1") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 || history[0].Content != "hello" { + t.Errorf("expected 1 message 'hello', got %v", history) + } +} + +func TestAddMessage_BasicRoundtrip(t *testing.T) { + store := openTestDB(t) + ctx := context.Background() + + if err := store.AddMessage(ctx, "s1", "user", "hi"); err != nil { + t.Fatalf("AddMessage: %v", err) + } + if err := store.AddMessage(ctx, "s1", "assistant", "hello"); err != nil { + t.Fatalf("AddMessage: %v", err) + } + + history, err := store.GetHistory(ctx, "s1") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 2 { + t.Fatalf("expected 2 messages, got %d", len(history)) + } + if history[0].Role != "user" || history[0].Content != "hi" { + t.Errorf("msg[0] = %+v", history[0]) + } + if history[1].Role != "assistant" || history[1].Content != "hello" { + t.Errorf("msg[1] = %+v", history[1]) + } +} + +func TestAddMessage_AutoCreatesSession(t *testing.T) { + store := openTestDB(t) + ctx := context.Background() + + // Adding a message to a non-existent session should auto-create it. + if err := store.AddMessage(ctx, "new-session", "user", "first"); err != nil { + t.Fatalf("AddMessage: %v", err) + } + + history, err := store.GetHistory(ctx, "new-session") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1 message, got %d", len(history)) + } +} + +func TestAddFullMessage_WithToolCalls(t *testing.T) { + store := openTestDB(t) + ctx := context.Background() + + msg := providers.Message{ + Role: "assistant", + Content: "Let me search.", + ToolCalls: []providers.ToolCall{ + { + ID: "call_123", + Type: "function", + Function: &providers.FunctionCall{ + Name: "web_search", + Arguments: `{"query":"test"}`, + }, + }, + }, + } + if err := store.AddFullMessage(ctx, "s1", msg); err != nil { + t.Fatalf("AddFullMessage: %v", err) + } + + history, err := store.GetHistory(ctx, "s1") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1 message, got %d", len(history)) + } + + got := history[0] + if got.Content != "Let me search." { + t.Errorf("content = %q", got.Content) + } + if len(got.ToolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(got.ToolCalls)) + } + tc := got.ToolCalls[0] + if tc.ID != "call_123" || tc.Type != "function" { + t.Errorf("tool call = %+v", tc) + } + if tc.Function == nil || tc.Function.Name != "web_search" { + t.Errorf("function = %+v", tc.Function) + } + if tc.Function.Arguments != `{"query":"test"}` { + t.Errorf("arguments = %q", tc.Function.Arguments) + } +} + +func TestAddFullMessage_ToolCallID(t *testing.T) { + store := openTestDB(t) + ctx := context.Background() + + msg := providers.Message{ + Role: "tool", + Content: "search result here", + ToolCallID: "call_123", + } + if err := store.AddFullMessage(ctx, "s1", msg); err != nil { + t.Fatalf("AddFullMessage: %v", err) + } + + history, err := store.GetHistory(ctx, "s1") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1 message, got %d", len(history)) + } + if history[0].ToolCallID != "call_123" { + t.Errorf("ToolCallID = %q, want %q", history[0].ToolCallID, "call_123") + } +} + +func TestGetHistory_EmptySession(t *testing.T) { + store := openTestDB(t) + ctx := context.Background() + + history, err := store.GetHistory(ctx, "nonexistent") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if history == nil { + t.Fatal("expected non-nil empty slice") + } + if len(history) != 0 { + t.Errorf("expected 0 messages, got %d", len(history)) + } +} + +func TestGetHistory_Ordering(t *testing.T) { + store := openTestDB(t) + ctx := context.Background() + + for i := 0; i < 10; i++ { + content := string(rune('a' + i)) + if err := store.AddMessage(ctx, "s1", "user", content); err != nil { + t.Fatalf("AddMessage(%d): %v", i, err) + } + } + + history, err := store.GetHistory(ctx, "s1") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 10 { + t.Fatalf("expected 10 messages, got %d", len(history)) + } + for i, msg := range history { + expected := string(rune('a' + i)) + if msg.Content != expected { + t.Errorf("msg[%d].Content = %q, want %q", i, msg.Content, expected) + } + } +} + +func TestSetSummary_GetSummary(t *testing.T) { + store := openTestDB(t) + ctx := context.Background() + + // No session yet — should return empty string. + summary, err := store.GetSummary(ctx, "s1") + if err != nil { + t.Fatalf("GetSummary: %v", err) + } + if summary != "" { + t.Errorf("expected empty summary, got %q", summary) + } + + // Set summary (auto-creates session). + err = store.SetSummary(ctx, "s1", "User asked about Go.") + if err != nil { + t.Fatalf("SetSummary: %v", err) + } + + summary, err = store.GetSummary(ctx, "s1") + if err != nil { + t.Fatalf("GetSummary: %v", err) + } + if summary != "User asked about Go." { + t.Errorf("summary = %q", summary) + } + + // Overwrite summary. + err = store.SetSummary(ctx, "s1", "Updated.") + if err != nil { + t.Fatalf("SetSummary: %v", err) + } + summary, err = store.GetSummary(ctx, "s1") + if err != nil { + t.Fatalf("GetSummary: %v", err) + } + if summary != "Updated." { + t.Errorf("summary = %q", summary) + } +} + +func TestTruncateHistory_KeepLast(t *testing.T) { + store := openTestDB(t) + ctx := context.Background() + + for i := 0; i < 10; i++ { + if err := store.AddMessage(ctx, "s1", "user", string(rune('a'+i))); err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + if err := store.TruncateHistory(ctx, "s1", 4); err != nil { + t.Fatalf("TruncateHistory: %v", err) + } + + history, err := store.GetHistory(ctx, "s1") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 4 { + t.Fatalf("expected 4 messages, got %d", len(history)) + } + // Should keep the last 4: g, h, i, j + expected := []string{"g", "h", "i", "j"} + for i, msg := range history { + if msg.Content != expected[i] { + t.Errorf("msg[%d].Content = %q, want %q", i, msg.Content, expected[i]) + } + } +} + +func TestTruncateHistory_KeepZero(t *testing.T) { + store := openTestDB(t) + ctx := context.Background() + + for i := 0; i < 5; i++ { + if err := store.AddMessage(ctx, "s1", "user", "msg"); err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + if err := store.TruncateHistory(ctx, "s1", 0); err != nil { + t.Fatalf("TruncateHistory: %v", err) + } + + history, err := store.GetHistory(ctx, "s1") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 0 { + t.Errorf("expected 0 messages, got %d", len(history)) + } +} + +func TestSetHistory_ReplacesAll(t *testing.T) { + store := openTestDB(t) + ctx := context.Background() + + // Add some initial messages. + for i := 0; i < 5; i++ { + if err := store.AddMessage(ctx, "s1", "user", "old"); err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + // Replace with new history. + newHistory := []providers.Message{ + {Role: "system", Content: "You are helpful."}, + {Role: "user", Content: "new question"}, + {Role: "assistant", Content: "new answer"}, + } + if err := store.SetHistory(ctx, "s1", newHistory); err != nil { + t.Fatalf("SetHistory: %v", err) + } + + history, err := store.GetHistory(ctx, "s1") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 3 { + t.Fatalf("expected 3 messages, got %d", len(history)) + } + for i, msg := range history { + if msg.Role != newHistory[i].Role || msg.Content != newHistory[i].Content { + t.Errorf("msg[%d] = %+v, want %+v", i, msg, newHistory[i]) + } + } +} + +func TestConcurrent_AddAndRead(t *testing.T) { + store := openTestDB(t) + ctx := context.Background() + + const goroutines = 10 + const msgsPerGoroutine = 20 + + var wg sync.WaitGroup + wg.Add(goroutines * 2) // writers + readers + + // Writers + for g := 0; g < goroutines; g++ { + go func(id int) { + defer wg.Done() + for i := 0; i < msgsPerGoroutine; i++ { + err := store.AddMessage(ctx, "concurrent", "user", "msg") + if err != nil { + t.Errorf("goroutine %d: AddMessage: %v", id, err) + return + } + } + }(g) + } + + // Readers + for g := 0; g < goroutines; g++ { + go func(id int) { + defer wg.Done() + for i := 0; i < msgsPerGoroutine; i++ { + _, err := store.GetHistory(ctx, "concurrent") + if err != nil { + t.Errorf("goroutine %d: GetHistory: %v", id, err) + return + } + } + }(g) + } + + wg.Wait() + + // Final count should be exactly goroutines * msgsPerGoroutine. + history, err := store.GetHistory(ctx, "concurrent") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + expected := goroutines * msgsPerGoroutine + if len(history) != expected { + t.Errorf("expected %d messages, got %d", expected, len(history)) + } +} + +func TestConcurrent_SummarizeRace(t *testing.T) { + // Simulates the #704 race: one goroutine does TruncateHistory while another does AddFullMessage. + store := openTestDB(t) + ctx := context.Background() + + // Seed with some messages. + for i := 0; i < 20; i++ { + if err := store.AddMessage(ctx, "race", "user", "msg"); err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + var wg sync.WaitGroup + wg.Add(2) + + // Summarize goroutine: set summary + truncate + go func() { + defer wg.Done() + for i := 0; i < 10; i++ { + _ = store.SetSummary(ctx, "race", "summary text") + _ = store.TruncateHistory(ctx, "race", 4) + } + }() + + // Main loop goroutine: add messages + go func() { + defer wg.Done() + for i := 0; i < 10; i++ { + _ = store.AddMessage(ctx, "race", "user", "new msg") + _ = store.AddMessage(ctx, "race", "assistant", "response") + } + }() + + wg.Wait() + + // No panic, no corruption — just verify we can still read. + _, err := store.GetHistory(ctx, "race") + if err != nil { + t.Fatalf("GetHistory after race: %v", err) + } + _, err = store.GetSummary(ctx, "race") + if err != nil { + t.Fatalf("GetSummary after race: %v", err) + } +} + +func BenchmarkAddMessage(b *testing.B) { + dbPath := filepath.Join(b.TempDir(), "bench.db") + store, err := Open(context.Background(), dbPath) + if err != nil { + b.Fatalf("Open: %v", err) + } + defer store.Close() + + ctx := context.Background() + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := store.AddMessage(ctx, "bench", "user", "benchmark message") + if err != nil { + b.Fatalf("AddMessage: %v", err) + } + } +} + +func BenchmarkGetHistory_100(b *testing.B) { + benchGetHistory(b, 100) +} + +func BenchmarkGetHistory_1000(b *testing.B) { + benchGetHistory(b, 1000) +} + +func benchGetHistory(b *testing.B, count int) { + dbPath := filepath.Join(b.TempDir(), "bench.db") + store, err := Open(context.Background(), dbPath) + if err != nil { + b.Fatalf("Open: %v", err) + } + defer store.Close() + + ctx := context.Background() + for i := 0; i < count; i++ { + if err := store.AddMessage(ctx, "bench", "user", "message content"); err != nil { + b.Fatalf("AddMessage: %v", err) + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := store.GetHistory(ctx, "bench"); err != nil { + b.Fatalf("GetHistory: %v", err) + } + } +} diff --git a/pkg/memory/store.go b/pkg/memory/store.go new file mode 100644 index 0000000000..d6981a70dd --- /dev/null +++ b/pkg/memory/store.go @@ -0,0 +1,38 @@ +package memory + +import ( + "context" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +// Store defines an interface for persistent session storage. +// Each method is an atomic operation — there is no separate Save() call. +type Store interface { + // AddMessage appends a simple text message to a session. + AddMessage(ctx context.Context, sessionKey, role, content string) error + + // AddFullMessage appends a complete message (with tool calls and tool call ID) to a session. + AddFullMessage(ctx context.Context, sessionKey string, msg providers.Message) error + + // GetHistory returns all messages for a session in insertion order. + // Returns an empty slice (not nil) if the session does not exist. + GetHistory(ctx context.Context, sessionKey string) ([]providers.Message, error) + + // GetSummary returns the conversation summary for a session. + // Returns an empty string if no summary exists. + GetSummary(ctx context.Context, sessionKey string) (string, error) + + // SetSummary updates the conversation summary for a session. + SetSummary(ctx context.Context, sessionKey, summary string) error + + // TruncateHistory removes all but the last keepLast messages from a session. + // If keepLast <= 0, all messages are removed. + TruncateHistory(ctx context.Context, sessionKey string, keepLast int) error + + // SetHistory replaces all messages in a session with the provided history. + SetHistory(ctx context.Context, sessionKey string, history []providers.Message) error + + // Close releases any resources held by the store. + Close() error +}