diff --git a/base/copier/copier.go b/common/copier/copier.go similarity index 99% rename from base/copier/copier.go rename to common/copier/copier.go index 621af599e..fd8b2c4c2 100644 --- a/base/copier/copier.go +++ b/common/copier/copier.go @@ -16,8 +16,9 @@ package copier import ( "encoding" - "github.com/juju/errors" "reflect" + + "github.com/juju/errors" ) func Copy(dst, src interface{}) error { diff --git a/base/copier/copier_test.go b/common/copier/copier_test.go similarity index 99% rename from base/copier/copier_test.go rename to common/copier/copier_test.go index fca089e67..eee876b90 100644 --- a/base/copier/copier_test.go +++ b/common/copier/copier_test.go @@ -15,10 +15,11 @@ package copier import ( + "testing" + "github.com/juju/errors" "github.com/stretchr/testify/assert" "google.golang.org/protobuf/proto" - "testing" ) func TestPrimitives(t *testing.T) { diff --git a/base/jsonutil/json.go b/common/jsonutil/json.go similarity index 100% rename from base/jsonutil/json.go rename to common/jsonutil/json.go diff --git a/base/jsonutil/json_test.go b/common/jsonutil/json_test.go similarity index 99% rename from base/jsonutil/json_test.go rename to common/jsonutil/json_test.go index 7e7610542..b82f4de6d 100644 --- a/base/jsonutil/json_test.go +++ b/common/jsonutil/json_test.go @@ -15,8 +15,9 @@ package jsonutil import ( - "github.com/stretchr/testify/assert" "testing" + + "github.com/stretchr/testify/assert" ) func TestUnmarshal(t *testing.T) { diff --git a/common/parallel/parallel.go b/common/parallel/parallel.go index 071244784..07c35bf84 100644 --- a/common/parallel/parallel.go +++ b/common/parallel/parallel.go @@ -17,7 +17,7 @@ package parallel import ( "sync" - "github.com/gorse-io/gorse/base" + "github.com/gorse-io/gorse/common/util" "github.com/juju/errors" "github.com/samber/lo" "modernc.org/mathutil" @@ -56,7 +56,7 @@ func Parallel(nJobs, nWorkers int, worker func(workerId, jobId int) error) error // start workers workerId := j wg.Go(func() { - defer base.CheckPanic() + defer util.CheckPanic() for { // read job jobId, ok := <-c diff --git a/common/parallel/parallel_test.go b/common/parallel/parallel_test.go index 4562aab68..5cc90b41e 100644 --- a/common/parallel/parallel_test.go +++ b/common/parallel/parallel_test.go @@ -20,13 +20,13 @@ import ( "time" mapset "github.com/deckarep/golang-set/v2" - "github.com/gorse-io/gorse/base" + "github.com/gorse-io/gorse/common/util" "github.com/stretchr/testify/assert" ) func TestParallel(t *testing.T) { synctest.Test(t, func(t *testing.T) { - a := base.RangeInt(10000) + a := util.RangeInt(10000) b := make([]int, len(a)) workerIds := make([]int, len(a)) // multiple threads @@ -55,7 +55,7 @@ func TestParallel(t *testing.T) { func TestFor(t *testing.T) { synctest.Test(t, func(t *testing.T) { // multiple threads - a := base.RangeInt(10000) + a := util.RangeInt(10000) b := make([]int, len(a)) For(len(a), 4, func(jobId int) { b[jobId] = a[jobId] @@ -73,7 +73,7 @@ func TestFor(t *testing.T) { func TestForEach(t *testing.T) { synctest.Test(t, func(t *testing.T) { - a := base.RangeInt(10000) + a := util.RangeInt(10000) b := make([]int, len(a)) // multiple threads ForEach(a, 4, func(i, v int) { @@ -94,7 +94,7 @@ func TestForEach(t *testing.T) { func TestBatchParallel(t *testing.T) { synctest.Test(t, func(t *testing.T) { - a := base.RangeInt(10000) + a := util.RangeInt(10000) b := make([]int, len(a)) workerIds := make([]int, len(a)) // multiple threads diff --git a/base/random.go b/common/util/random.go similarity index 99% rename from base/random.go rename to common/util/random.go index 72a9a7b7c..758581fe0 100644 --- a/base/random.go +++ b/common/util/random.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package base +package util import ( "math/rand" diff --git a/base/random_test.go b/common/util/random_test.go similarity index 99% rename from base/random_test.go rename to common/util/random_test.go index 364c6dcc3..4b26730a6 100644 --- a/base/random_test.go +++ b/common/util/random_test.go @@ -11,7 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -package base +package util import ( "testing" diff --git a/base/util.go b/common/util/util.go similarity index 99% rename from base/util.go rename to common/util/util.go index bc2176f5d..a41169bdd 100644 --- a/base/util.go +++ b/common/util/util.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package base +package util import ( "fmt" diff --git a/base/util_test.go b/common/util/util_test.go similarity index 99% rename from base/util_test.go rename to common/util/util_test.go index 095334123..f189a3e25 100644 --- a/base/util_test.go +++ b/common/util/util_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package base +package util import ( "testing" diff --git a/dataset/dataset.go b/dataset/dataset.go index cc56367f2..ddbf43898 100644 --- a/dataset/dataset.go +++ b/dataset/dataset.go @@ -24,7 +24,6 @@ import ( "github.com/chewxy/math32" mapset "github.com/deckarep/golang-set/v2" - "github.com/gorse-io/gorse/base" "github.com/gorse-io/gorse/common/util" "github.com/gorse-io/gorse/model" "github.com/gorse-io/gorse/storage/data" @@ -65,7 +64,7 @@ type CTRSplit interface { CountContextLabels() int CountPositive() int CountNegative() int - GetIndex() base.UnifiedIndex + GetIndex() UnifiedIndex GetTarget(i int) float32 Get(i int) ([]int32, []float32, float32) } @@ -230,7 +229,7 @@ func (d *Dataset) AddFeedback(userId, itemId string) { func (d *Dataset) SampleUserNegatives(excludeSet CFSplit, numCandidates int) [][]int32 { if len(d.negatives) == 0 { - rng := base.NewRandomGenerator(0) + rng := util.NewRandomGenerator(0) d.negatives = make([][]int32, d.CountUsers()) for userIndex := 0; userIndex < d.CountUsers(); userIndex++ { s1 := mapset.NewSet(d.GetUserFeedback()[userIndex]...) @@ -252,7 +251,7 @@ func (d *Dataset) SplitCF(numTestUsers int, seed int64) (CFSplit, CFSplit) { trainSet.itemFeedback, testSet.itemFeedback = make([][]int32, d.CountItems()), make([][]int32, d.CountItems()) trainSet.userDict, testSet.userDict = d.userDict, d.userDict trainSet.itemDict, testSet.itemDict = d.itemDict, d.itemDict - rng := base.NewRandomGenerator(seed) + rng := util.NewRandomGenerator(seed) if numTestUsers >= d.CountUsers() || numTestUsers <= 0 { for userIndex := int32(0); userIndex < int32(d.CountUsers()); userIndex++ { if len(d.userFeedback[userIndex]) > 0 { diff --git a/dataset/dict.go b/dataset/dict.go index 5b2356f72..5fe2ad1a9 100644 --- a/dataset/dict.go +++ b/dataset/dict.go @@ -14,8 +14,6 @@ package dataset -import "github.com/gorse-io/gorse/base" - type FreqDict struct { si map[string]int32 is []string @@ -77,8 +75,8 @@ func (d *FreqDict) Freq(id int32) int32 { return d.cnt[id] } -func (d *FreqDict) ToIndex() *base.Index { - return &base.Index{ +func (d *FreqDict) ToIndex() *Index { + return &Index{ Numbers: d.si, Names: d.is, } diff --git a/base/index.go b/dataset/index.go similarity index 99% rename from base/index.go rename to dataset/index.go index bdfcc0704..e320c2b98 100644 --- a/base/index.go +++ b/dataset/index.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package base +package dataset import ( "encoding/binary" diff --git a/base/index_test.go b/dataset/index_test.go similarity index 98% rename from base/index_test.go rename to dataset/index_test.go index 40db7bfda..42a427e23 100644 --- a/base/index_test.go +++ b/dataset/index_test.go @@ -1,4 +1,4 @@ -package base +package dataset import ( "bytes" diff --git a/base/unified_index.go b/dataset/unified_index.go similarity index 99% rename from base/unified_index.go rename to dataset/unified_index.go index c4f6163fc..0153466eb 100644 --- a/base/unified_index.go +++ b/dataset/unified_index.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package base +package dataset import ( "encoding/binary" diff --git a/base/unified_index_test.go b/dataset/unified_index_test.go similarity index 99% rename from base/unified_index_test.go rename to dataset/unified_index_test.go index b595caa5b..9a260842a 100644 --- a/base/unified_index_test.go +++ b/dataset/unified_index_test.go @@ -11,7 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -package base +package dataset import ( "bytes" diff --git a/master/master.go b/master/master.go index 6b34b9a5d..76d6e5969 100644 --- a/master/master.go +++ b/master/master.go @@ -26,7 +26,6 @@ import ( "github.com/coreos/go-oidc/v3/oidc" "github.com/emicklei/go-restful/v3" - "github.com/gorse-io/gorse/base" "github.com/gorse-io/gorse/common/log" "github.com/gorse-io/gorse/common/monitor" "github.com/gorse-io/gorse/common/parallel" @@ -305,7 +304,7 @@ func (m *Master) Shutdown() { } func (m *Master) RunTasksLoop() { - defer base.CheckPanic() + defer util.CheckPanic() var ( err error //firstLoop = true diff --git a/master/rest.go b/master/rest.go index ddde771a0..6927328cc 100644 --- a/master/rest.go +++ b/master/rest.go @@ -34,7 +34,6 @@ import ( "github.com/go-viper/mapstructure/v2" "github.com/gorilla/securecookie" _ "github.com/gorse-io/dashboard" - "github.com/gorse-io/gorse/base" "github.com/gorse-io/gorse/cmd/version" "github.com/gorse-io/gorse/common/expression" "github.com/gorse-io/gorse/common/log" @@ -1039,7 +1038,7 @@ func (m *Master) importExportUsers(response http.ResponseWriter, request *http.R return } // validate user id - if err = base.ValidateId(user.UserId); err != nil { + if err = util.ValidateId(user.UserId); err != nil { server.BadRequest(restful.NewResponse(response), fmt.Errorf("invalid user id `%v` at line %d (%s)", user.UserId, lineCount, err.Error())) return @@ -1131,14 +1130,14 @@ func (m *Master) importExportItems(response http.ResponseWriter, request *http.R return } // validate item id - if err = base.ValidateId(item.ItemId); err != nil { + if err = util.ValidateId(item.ItemId); err != nil { server.BadRequest(restful.NewResponse(response), fmt.Errorf("invalid item id `%v` at line %d (%s)", item.ItemId, lineCount, err.Error())) return } // validate categories for _, category := range item.Categories { - if err = base.ValidateId(category); err != nil { + if err = util.ValidateId(category); err != nil { server.BadRequest(restful.NewResponse(response), fmt.Errorf("invalid category `%v` at line %d (%s)", category, lineCount, err.Error())) return @@ -1243,19 +1242,19 @@ func (m *Master) importExportFeedback(response http.ResponseWriter, request *htt return } // validate feedback type - if err = base.ValidateId(feedback.FeedbackType); err != nil { + if err = util.ValidateId(feedback.FeedbackType); err != nil { server.BadRequest(restful.NewResponse(response), fmt.Errorf("invalid feedback type `%v` at line %d (%s)", feedback.FeedbackType, lineCount, err.Error())) return } // validate user id - if err = base.ValidateId(feedback.UserId); err != nil { + if err = util.ValidateId(feedback.UserId); err != nil { server.BadRequest(restful.NewResponse(response), fmt.Errorf("invalid user id `%v` at line %d (%s)", feedback.UserId, lineCount, err.Error())) return } // validate item id - if err = base.ValidateId(feedback.ItemId); err != nil { + if err = util.ValidateId(feedback.ItemId); err != nil { server.BadRequest(restful.NewResponse(response), fmt.Errorf("invalid item id `%v` at line %d (%s)", feedback.ItemId, lineCount, err.Error())) return diff --git a/master/rpc_test.go b/master/rpc_test.go index c0be61613..2d5a8f0a3 100644 --- a/master/rpc_test.go +++ b/master/rpc_test.go @@ -24,7 +24,6 @@ import ( "testing" "time" - "github.com/gorse-io/gorse/base" "github.com/gorse-io/gorse/common/monitor" "github.com/gorse-io/gorse/common/util" "github.com/gorse-io/gorse/config" @@ -48,10 +47,10 @@ func newRankingDataset() (*dataset.Dataset, *dataset.Dataset) { } func newClickDataset() (*ctr.Dataset, *ctr.Dataset) { - dataset := &ctr.Dataset{ - Index: base.NewUnifiedMapIndexBuilder().Build(), + dataSet := &ctr.Dataset{ + Index: dataset.NewUnifiedMapIndexBuilder().Build(), } - return dataset, dataset + return dataSet, dataSet } type mockMasterRPC struct { diff --git a/master/tasks.go b/master/tasks.go index db44b6eab..db3c16c80 100644 --- a/master/tasks.go +++ b/master/tasks.go @@ -25,7 +25,6 @@ import ( "github.com/c-bata/goptuna" "github.com/c-bata/goptuna/tpe" mapset "github.com/deckarep/golang-set/v2" - "github.com/gorse-io/gorse/base" "github.com/gorse-io/gorse/common/expression" "github.com/gorse-io/gorse/common/log" "github.com/gorse-io/gorse/common/monitor" @@ -275,7 +274,7 @@ func (m *Master) LoadDataFromDatabase( // STEP 1: pull users userLabelCount := make(map[string]int) userLabelFirst := make(map[string]int32) - userLabelIndex := base.NewMapIndex() + userLabelIndex := dataset.NewMapIndex() userLabels := make([][]lo.Tuple2[int32, float32], 0, estimatedNumUsers) start := time.Now() userChan, errChan := database.GetUserStream(newCtx, batchSize) @@ -327,9 +326,9 @@ func (m *Master) LoadDataFromDatabase( var items []data.Item itemLabelCount := make(map[string]int) itemLabelFirst := make(map[string]int32) - itemLabelIndex := base.NewMapIndex() + itemLabelIndex := dataset.NewMapIndex() itemLabels := make([][]lo.Tuple2[int32, float32], 0, estimatedNumItems) - itemEmbeddingIndexer := base.NewMapIndex() + itemEmbeddingIndexer := dataset.NewMapIndex() itemEmbeddingDimension := make([]map[int]int, 0) itemEmbeddings := make([][][]float32, 0, estimatedNumItems) start = time.Now() @@ -429,11 +428,11 @@ func (m *Master) LoadDataFromDatabase( for _, f := range feedback { // convert user and item id to index userIndex := dataSet.GetUserDict().Id(f.UserId) - if userIndex == base.NotId { + if userIndex == dataset.NotId { continue } itemIndex := dataSet.GetItemDict().Id(f.ItemId) - if itemIndex == base.NotId { + if itemIndex == dataset.NotId { continue } // insert feedback to positive set @@ -514,11 +513,11 @@ func (m *Master) LoadDataFromDatabase( for feedback := range feedbackChan { for _, f := range feedback { userIndex := dataSet.GetUserDict().Id(f.UserId) - if userIndex == base.NotId { + if userIndex == dataset.NotId { continue } itemIndex := dataSet.GetItemDict().Id(f.ItemId) - if itemIndex == base.NotId { + if itemIndex == dataset.NotId { continue } negativeSet[userIndex].Add(itemIndex) @@ -544,7 +543,7 @@ func (m *Master) LoadDataFromDatabase( // STEP 5: create click-through rate dataset start = time.Now() - unifiedIndex := base.NewUnifiedMapIndexBuilder() + unifiedIndex := dataset.NewUnifiedMapIndexBuilder() unifiedIndex.ItemIndex = dataSet.GetItemDict().ToIndex() unifiedIndex.UserIndex = dataSet.GetUserDict().ToIndex() unifiedIndex.ItemLabelIndex = itemLabelIndex @@ -1120,7 +1119,7 @@ func (m *Master) collectGarbage(ctx context.Context, dataSet *dataset.Dataset) e log.Logger().Error("invalid subset", zap.String("subset", subset)) return nil } - if dataSet.GetUserDict().Id(splits[1]) == base.NotId || !lo.ContainsBy(m.Config.Recommend.UserToUser, func(cfg config.UserToUserConfig) bool { + if dataSet.GetUserDict().Id(splits[1]) == dataset.NotId || !lo.ContainsBy(m.Config.Recommend.UserToUser, func(cfg config.UserToUserConfig) bool { return cfg.Name == splits[0] }) { return m.CacheClient.DeleteScores(ctx, []string{cache.UserToUser}, cache.ScoreCondition{ @@ -1134,7 +1133,7 @@ func (m *Master) collectGarbage(ctx context.Context, dataSet *dataset.Dataset) e log.Logger().Error("invalid subset", zap.String("subset", subset)) return nil } - if dataSet.GetItemDict().Id(splits[1]) == base.NotId || !lo.ContainsBy(m.Config.Recommend.ItemToItem, func(cfg config.ItemToItemConfig) bool { + if dataSet.GetItemDict().Id(splits[1]) == dataset.NotId || !lo.ContainsBy(m.Config.Recommend.ItemToItem, func(cfg config.ItemToItemConfig) bool { return cfg.Name == splits[0] }) { return m.CacheClient.DeleteScores(ctx, []string{cache.ItemToItem}, cache.ScoreCondition{ @@ -1143,7 +1142,7 @@ func (m *Master) collectGarbage(ctx context.Context, dataSet *dataset.Dataset) e }) } case cache.CollaborativeFiltering: - if dataSet.GetUserDict().Id(subset) == base.NotId { + if dataSet.GetUserDict().Id(subset) == dataset.NotId { return m.CacheClient.DeleteScores(ctx, []string{cache.CollaborativeFiltering}, cache.ScoreCondition{ Subset: lo.ToPtr(subset), Before: lo.ToPtr(dataSet.GetTimestamp()), diff --git a/model/cf/model.go b/model/cf/model.go index 4cfe0f3e5..6fbdfca14 100644 --- a/model/cf/model.go +++ b/model/cf/model.go @@ -26,13 +26,12 @@ import ( "github.com/c-bata/goptuna" "github.com/chewxy/math32" mapset "github.com/deckarep/golang-set/v2" - "github.com/gorse-io/gorse/base" - "github.com/gorse-io/gorse/base/copier" "github.com/gorse-io/gorse/common/encoding" "github.com/gorse-io/gorse/common/floats" "github.com/gorse-io/gorse/common/log" "github.com/gorse-io/gorse/common/monitor" "github.com/gorse-io/gorse/common/parallel" + "github.com/gorse-io/gorse/common/util" "github.com/gorse-io/gorse/dataset" "github.com/gorse-io/gorse/model" "github.com/gorse-io/gorse/protocol" @@ -301,17 +300,6 @@ func (baseModel *BaseMatrixFactorization) Invalid() bool { baseModel.UserFactor == nil } -// Clone a model with deep copy. -func Clone(m MatrixFactorization) MatrixFactorization { - var copied MatrixFactorization - if err := copier.Copy(&copied, m); err != nil { - panic(err) - } else { - copied.SetParams(copied.GetParams()) - return copied - } -} - func GetModelName(m Model) string { switch m.(type) { case *BPR: @@ -419,13 +407,13 @@ func (bpr *BPR) Fit(ctx context.Context, trainSet, valSet dataset.CFSplit, confi zap.Any("config", config)) bpr.Init(trainSet) // Create buffers - temp := base.NewMatrix32(config.Jobs, bpr.nFactors) - userFactor := base.NewMatrix32(config.Jobs, bpr.nFactors) - positiveItemFactor := base.NewMatrix32(config.Jobs, bpr.nFactors) - negativeItemFactor := base.NewMatrix32(config.Jobs, bpr.nFactors) - rng := make([]base.RandomGenerator, config.Jobs) + temp := util.NewMatrix32(config.Jobs, bpr.nFactors) + userFactor := util.NewMatrix32(config.Jobs, bpr.nFactors) + positiveItemFactor := util.NewMatrix32(config.Jobs, bpr.nFactors) + negativeItemFactor := util.NewMatrix32(config.Jobs, bpr.nFactors) + rng := make([]util.RandomGenerator, config.Jobs) for i := 0; i < config.Jobs; i++ { - rng[i] = base.NewRandomGenerator(bpr.GetRandomGenerator().Int63()) + rng[i] = util.NewRandomGenerator(bpr.GetRandomGenerator().Int63()) } // Convert array to hashmap userFeedback := make([]mapset.Set[int32], trainSet.CountUsers()) @@ -604,7 +592,7 @@ func (als *ALS) Fit(ctx context.Context, trainSet, valSet dataset.CFSplit, confi zap.Any("config", config)) als.Init(trainSet) // Create temporary matrix - s := base.NewMatrix32(als.nFactors, als.nFactors) + s := util.NewMatrix32(als.nFactors, als.nFactors) userPredictions := make([][]float32, config.Jobs) itemPredictions := make([][]float32, config.Jobs) userRes := make([][]float32, config.Jobs) diff --git a/model/ctr/data.go b/model/ctr/data.go index 4ac3dd15c..d6744626e 100644 --- a/model/ctr/data.go +++ b/model/ctr/data.go @@ -22,8 +22,9 @@ import ( "strings" mapset "github.com/deckarep/golang-set/v2" - "github.com/gorse-io/gorse/base" - "github.com/gorse-io/gorse/base/jsonutil" + "github.com/gorse-io/gorse/common/jsonutil" + "github.com/gorse-io/gorse/common/util" + "github.com/gorse-io/gorse/dataset" "github.com/gorse-io/gorse/model" "github.com/juju/errors" "github.com/samber/lo" @@ -139,7 +140,7 @@ func convertEmbeddings(result []Embedding, prefix string, o any) []Embedding { // Dataset for click-through-rate models. type Dataset struct { - Index base.UnifiedIndex + Index dataset.UnifiedIndex UserLabels [][]lo.Tuple2[int32, float32] ItemLabels [][]lo.Tuple2[int32, float32] ContextLabels [][]lo.Tuple2[int32, float32] @@ -148,7 +149,7 @@ type Dataset struct { Target []float32 ItemEmbeddings [][][]float32 ItemEmbeddingDimension []int - ItemEmbeddingIndex *base.Index + ItemEmbeddingIndex *dataset.Index PositiveCount int NegativeCount int } @@ -183,7 +184,7 @@ func (dataset *Dataset) CountNegative() int { return dataset.NegativeCount } -func (dataset *Dataset) GetIndex() base.UnifiedIndex { +func (dataset *Dataset) GetIndex() dataset.UnifiedIndex { return dataset.Index } @@ -314,7 +315,7 @@ func LoadDataFromBuiltIn(name string) (train, test *Dataset, err error) { if test.ContextLabels, test.Target, testMaxLabel, err = LoadLibFMFile(testFilePath); err != nil { return nil, nil, err } - unifiedIndex := base.NewUnifiedDirectIndex(mathutil.MaxInt32(trainMaxLabel, testMaxLabel) + 1) + unifiedIndex := dataset.NewUnifiedDirectIndex(mathutil.MaxInt32(trainMaxLabel, testMaxLabel) + 1) train.Index = unifiedIndex test.Index = unifiedIndex return @@ -335,7 +336,7 @@ func (dataset *Dataset) Split(ratio float32, seed int64) (*Dataset, *Dataset) { } // split by random numTestSize := int(float32(dataset.Count()) * ratio) - rng := base.NewRandomGenerator(seed) + rng := util.NewRandomGenerator(seed) sampledIndex := mapset.NewSet(rng.Sample(0, dataset.Count(), numTestSize)...) for i := 0; i < len(dataset.Target); i++ { if sampledIndex.Contains(i) { diff --git a/model/ctr/data_test.go b/model/ctr/data_test.go index dfe2a0e89..6f09ba33b 100644 --- a/model/ctr/data_test.go +++ b/model/ctr/data_test.go @@ -18,7 +18,7 @@ import ( "fmt" "testing" - "github.com/gorse-io/gorse/base" + "github.com/gorse-io/gorse/dataset" "github.com/samber/lo" "github.com/stretchr/testify/assert" ) @@ -111,14 +111,14 @@ func TestLoadDataFromBuiltIn(t *testing.T) { func TestDataset_Split(t *testing.T) { // create dataset - unifiedIndex := base.NewUnifiedMapIndexBuilder() - dataset := NewMapIndexDataset() + unifiedIndex := dataset.NewUnifiedMapIndexBuilder() + dataSet := NewMapIndexDataset() numUsers, numItems := 5, 6 for i := 0; i < numUsers; i++ { unifiedIndex.AddUser(fmt.Sprintf("user%v", i)) unifiedIndex.AddUserLabel(fmt.Sprintf("user_label%v", 2*i)) unifiedIndex.AddUserLabel(fmt.Sprintf("user_label%v", 2*i+1)) - dataset.UserLabels = append(dataset.UserLabels, []lo.Tuple2[int32, float32]{ + dataSet.UserLabels = append(dataSet.UserLabels, []lo.Tuple2[int32, float32]{ {A: int32(2 * i), B: 1}, {A: int32(2*i + 1), B: 1}, }) @@ -128,7 +128,7 @@ func TestDataset_Split(t *testing.T) { unifiedIndex.AddItemLabel(fmt.Sprintf("item_label%v", 3*i)) unifiedIndex.AddItemLabel(fmt.Sprintf("item_label%v", 3*i+1)) unifiedIndex.AddItemLabel(fmt.Sprintf("item_label%v", 3*i+2)) - dataset.ItemLabels = append(dataset.ItemLabels, []lo.Tuple2[int32, float32]{ + dataSet.ItemLabels = append(dataSet.ItemLabels, []lo.Tuple2[int32, float32]{ {A: int32(3 * i), B: 1}, {A: int32(3*i + 1), B: 1}, {A: int32(3*i + 2), B: 1}, @@ -137,44 +137,44 @@ func TestDataset_Split(t *testing.T) { for i := 0; i < numUsers; i++ { for j := 0; j < numItems; j++ { if i+j > 4 { - dataset.Users = append(dataset.Users, int32(i)) - dataset.Items = append(dataset.Items, int32(j)) - dataset.ContextLabels = append(dataset.ContextLabels, []lo.Tuple2[int32, float32]{{A: int32(i * j), B: 0.5}}) - dataset.Target = append(dataset.Target, 1) - dataset.PositiveCount++ + dataSet.Users = append(dataSet.Users, int32(i)) + dataSet.Items = append(dataSet.Items, int32(j)) + dataSet.ContextLabels = append(dataSet.ContextLabels, []lo.Tuple2[int32, float32]{{A: int32(i * j), B: 0.5}}) + dataSet.Target = append(dataSet.Target, 1) + dataSet.PositiveCount++ } else { - dataset.Users = append(dataset.Users, int32(i)) - dataset.Items = append(dataset.Items, int32(j)) - dataset.ContextLabels = append(dataset.ContextLabels, []lo.Tuple2[int32, float32]{{A: int32(i * j), B: 0.5}}) - dataset.Target = append(dataset.Target, -1) - dataset.NegativeCount++ + dataSet.Users = append(dataSet.Users, int32(i)) + dataSet.Items = append(dataSet.Items, int32(j)) + dataSet.ContextLabels = append(dataSet.ContextLabels, []lo.Tuple2[int32, float32]{{A: int32(i * j), B: 0.5}}) + dataSet.Target = append(dataSet.Target, -1) + dataSet.NegativeCount++ } } } - dataset.Index = unifiedIndex.Build() + dataSet.Index = unifiedIndex.Build() - assert.Equal(t, numUsers*numItems, dataset.Count()) - assert.Equal(t, numUsers, dataset.CountUsers()) - assert.Equal(t, numItems, dataset.CountItems()) - assert.Equal(t, numUsers*numItems/2, dataset.PositiveCount) - assert.Equal(t, numUsers*numItems/2, dataset.NegativeCount) + assert.Equal(t, numUsers*numItems, dataSet.Count()) + assert.Equal(t, numUsers, dataSet.CountUsers()) + assert.Equal(t, numItems, dataSet.CountItems()) + assert.Equal(t, numUsers*numItems/2, dataSet.PositiveCount) + assert.Equal(t, numUsers*numItems/2, dataSet.NegativeCount) - features, values, target := dataset.Get(2) + features, values, target := dataSet.Get(2) assert.Equal(t, []int32{ 0, - dataset.Index.CountUsers() + 2, - dataset.Index.CountUsers() + dataset.Index.CountItems() + 0, - dataset.Index.CountUsers() + dataset.Index.CountItems() + 1, - dataset.Index.CountUsers() + dataset.Index.CountItems() + dataset.Index.CountUserLabels() + 6, - dataset.Index.CountUsers() + dataset.Index.CountItems() + dataset.Index.CountUserLabels() + 7, - dataset.Index.CountUsers() + dataset.Index.CountItems() + dataset.Index.CountUserLabels() + 8, + dataSet.Index.CountUsers() + 2, + dataSet.Index.CountUsers() + dataSet.Index.CountItems() + 0, + dataSet.Index.CountUsers() + dataSet.Index.CountItems() + 1, + dataSet.Index.CountUsers() + dataSet.Index.CountItems() + dataSet.Index.CountUserLabels() + 6, + dataSet.Index.CountUsers() + dataSet.Index.CountItems() + dataSet.Index.CountUserLabels() + 7, + dataSet.Index.CountUsers() + dataSet.Index.CountItems() + dataSet.Index.CountUserLabels() + 8, 0, }, features) assert.Equal(t, []float32{1, 1, 1, 1, 1, 1, 1, 0.5}, values) assert.Equal(t, float32(-1), target) // split - train, test := dataset.Split(0.2, 0) + train, test := dataSet.Split(0.2, 0) assert.Equal(t, numUsers, train.CountUsers()) assert.Equal(t, numItems, train.CountItems()) assert.Equal(t, 24, train.Count()) diff --git a/model/ctr/model.go b/model/ctr/model.go index 18615678b..7a25337d1 100644 --- a/model/ctr/model.go +++ b/model/ctr/model.go @@ -24,8 +24,6 @@ import ( "github.com/c-bata/goptuna" "github.com/chewxy/math32" - "github.com/gorse-io/gorse/base" - "github.com/gorse-io/gorse/base/copier" "github.com/gorse-io/gorse/common/encoding" "github.com/gorse-io/gorse/common/log" "github.com/gorse-io/gorse/common/monitor" @@ -115,7 +113,7 @@ type FactorizationMachineSpawner interface { type BaseFactorizationMachines struct { model.BaseModel - Index base.UnifiedIndex + Index dataset.UnifiedIndex } func (b *BaseFactorizationMachines) Init(trainSet dataset.CTRSplit) { @@ -156,20 +154,6 @@ func UnmarshalModel(r io.Reader) (FactorizationMachines, error) { return nil, fmt.Errorf("unknown model: %v", header) } -// Clone a model with deep copy. -func Clone(m FactorizationMachines) FactorizationMachines { - if cloner, ok := m.(FactorizationMachineCloner); ok { - return cloner.Clone() - } - var copied FactorizationMachines - if err := copier.Copy(&copied, m); err != nil { - panic(err) - } else { - copied.SetParams(copied.GetParams()) - return copied - } -} - func Spawn(m FactorizationMachines) FactorizationMachines { if cloner, ok := m.(FactorizationMachineSpawner); ok { return cloner.Spawn() @@ -285,25 +269,25 @@ func (fm *FMV2) BatchPredict(inputs []lo.Tuple4[string, string, []Label, []Label x := make([]lo.Tuple2[[]int32, []float32], len(inputs)) for i, input := range inputs { // encode user - if userIndex := fm.Index.EncodeUser(input.A); userIndex != base.NotId { + if userIndex := fm.Index.EncodeUser(input.A); userIndex != dataset.NotId { x[i].A = append(x[i].A, userIndex) x[i].B = append(x[i].B, 1) } // encode item - if itemIndex := fm.Index.EncodeItem(input.B); itemIndex != base.NotId { + if itemIndex := fm.Index.EncodeItem(input.B); itemIndex != dataset.NotId { x[i].A = append(x[i].A, itemIndex) x[i].B = append(x[i].B, 1) } // encode user labels for _, userFeature := range input.C { - if userFeatureIndex := fm.Index.EncodeUserLabel(userFeature.Name); userFeatureIndex != base.NotId { + if userFeatureIndex := fm.Index.EncodeUserLabel(userFeature.Name); userFeatureIndex != dataset.NotId { x[i].A = append(x[i].A, userFeatureIndex) x[i].B = append(x[i].B, userFeature.Value) } } // encode item labels for _, itemFeature := range input.D { - if itemFeatureIndex := fm.Index.EncodeItemLabel(itemFeature.Name); itemFeatureIndex != base.NotId { + if itemFeatureIndex := fm.Index.EncodeItemLabel(itemFeature.Name); itemFeatureIndex != dataset.NotId { x[i].A = append(x[i].A, itemFeatureIndex) x[i].B = append(x[i].B, itemFeature.Value) } @@ -400,7 +384,7 @@ func (fm *FMV2) Marshal(w io.Writer) error { return errors.Trace(err) } // write index - if err := base.MarshalUnifiedIndex(w, fm.Index); err != nil { + if err := dataset.MarshalUnifiedIndex(w, fm.Index); err != nil { return errors.Trace(err) } // write dataset stats @@ -425,7 +409,7 @@ func (fm *FMV2) Unmarshal(r io.Reader) error { } fm.SetParams(fm.Params) // read index - fm.Index, err = base.UnmarshalUnifiedIndex(r) + fm.Index, err = dataset.UnmarshalUnifiedIndex(r) if err != nil { return errors.Trace(err) } diff --git a/model/ctr/optimize_test.go b/model/ctr/optimize_test.go index 92d17f7f0..c12baa12e 100644 --- a/model/ctr/optimize_test.go +++ b/model/ctr/optimize_test.go @@ -20,7 +20,6 @@ import ( "github.com/c-bata/goptuna" "github.com/c-bata/goptuna/tpe" - "github.com/gorse-io/gorse/base" "github.com/gorse-io/gorse/dataset" "github.com/gorse-io/gorse/model" "github.com/samber/lo" @@ -30,7 +29,7 @@ import ( // NewMapIndexDataset creates a data set. func NewMapIndexDataset() *Dataset { s := new(Dataset) - s.Index = base.NewUnifiedDirectIndex(0) + s.Index = dataset.NewUnifiedDirectIndex(0) return s } @@ -46,11 +45,11 @@ func (m *mockFactorizationMachineForSearch) Invalid() bool { panic("implement me") } -func (m *mockFactorizationMachineForSearch) GetUserIndex() base.Index { +func (m *mockFactorizationMachineForSearch) GetUserIndex() dataset.Index { panic("don't call me") } -func (m *mockFactorizationMachineForSearch) GetItemIndex() base.Index { +func (m *mockFactorizationMachineForSearch) GetItemIndex() dataset.Index { panic("don't call me") } diff --git a/model/model.go b/model/model.go index cf76687ae..55dc7e451 100644 --- a/model/model.go +++ b/model/model.go @@ -16,7 +16,7 @@ package model import ( "github.com/c-bata/goptuna" - "github.com/gorse-io/gorse/base" + "github.com/gorse-io/gorse/common/util" ) // Model is the interface for all models. Any model in this @@ -33,7 +33,7 @@ type Model interface { // ID sets, random generator and fitting options are managed the BaseModel model. type BaseModel struct { Params Params // Hyper-parameters - rng base.RandomGenerator // Random generator + rng util.RandomGenerator // Random generator randState int64 // Random seed } @@ -41,7 +41,7 @@ type BaseModel struct { func (model *BaseModel) SetParams(params Params) { model.Params = params model.randState = model.Params.GetInt64(RandomState, 0) - model.rng = base.NewRandomGenerator(model.randState) + model.rng = util.NewRandomGenerator(model.randState) } // GetParams returns all hyper-parameters. @@ -49,6 +49,6 @@ func (model *BaseModel) GetParams() Params { return model.Params } -func (model *BaseModel) GetRandomGenerator() base.RandomGenerator { +func (model *BaseModel) GetRandomGenerator() util.RandomGenerator { return model.rng } diff --git a/server/server.go b/server/server.go index 4049692c4..6a16f7960 100644 --- a/server/server.go +++ b/server/server.go @@ -27,7 +27,6 @@ import ( "time" "github.com/emicklei/go-restful/v3" - "github.com/gorse-io/gorse/base" "github.com/gorse-io/gorse/cmd/version" "github.com/gorse-io/gorse/common/log" "github.com/gorse-io/gorse/common/util" @@ -146,7 +145,7 @@ func (s *Server) Shutdown() { // Sync this server to the master. func (s *Server) Sync() { - defer base.CheckPanic() + defer util.CheckPanic() log.Logger().Info("start meta sync", zap.Duration("meta_timeout", s.Config.Master.MetaTimeout)) for { var meta *protocol.Meta diff --git a/storage/data/database.go b/storage/data/database.go index a24210e5b..aa715c7e4 100644 --- a/storage/data/database.go +++ b/storage/data/database.go @@ -25,8 +25,8 @@ import ( "time" "github.com/XSAM/otelsql" - "github.com/gorse-io/gorse/base/jsonutil" "github.com/gorse-io/gorse/common/expression" + "github.com/gorse-io/gorse/common/jsonutil" "github.com/gorse-io/gorse/common/log" "github.com/gorse-io/gorse/storage" "github.com/juju/errors" @@ -122,9 +122,9 @@ type ItemPatch struct { // User stores meta data about user. type User struct { - UserId string `gorm:"primaryKey" mapstructure:"user_id"` - Labels any `gorm:"serializer:json" mapstructure:"labels"` - Comment string `mapstructure:"comment"` + UserId string `gorm:"primaryKey" mapstructure:"user_id"` + Labels any `gorm:"serializer:json" mapstructure:"labels"` + Comment string `mapstructure:"comment"` } // UserPatch is the modification on a user. diff --git a/storage/data/sql.go b/storage/data/sql.go index a1a3369e9..7d6bc5b8b 100644 --- a/storage/data/sql.go +++ b/storage/data/sql.go @@ -23,8 +23,8 @@ import ( mapset "github.com/deckarep/golang-set/v2" _ "github.com/go-sql-driver/mysql" - "github.com/gorse-io/gorse/base/jsonutil" "github.com/gorse-io/gorse/common/expression" + "github.com/gorse-io/gorse/common/jsonutil" "github.com/gorse-io/gorse/common/log" "github.com/gorse-io/gorse/storage" "github.com/juju/errors" diff --git a/worker/worker.go b/worker/worker.go index bbc8f801a..2e1e43dea 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -30,7 +30,6 @@ import ( "time" mapset "github.com/deckarep/golang-set/v2" - "github.com/gorse-io/gorse/base" "github.com/gorse-io/gorse/cmd/version" "github.com/gorse-io/gorse/common/expression" "github.com/gorse-io/gorse/common/heap" @@ -136,7 +135,7 @@ func NewWorker( return &Worker{ rankers: make([]ctr.FactorizationMachines, jobs), Settings: config.NewSettings(), - randGenerator: base.NewRand(time.Now().UTC().UnixNano()), + randGenerator: util.NewRand(time.Now().UTC().UnixNano()), // config cacheFile: cacheFile, masterHost: masterHost, @@ -161,7 +160,7 @@ func (w *Worker) SetOneMode(settings *config.Settings) { // Sync this worker to the master. func (w *Worker) Sync() { - defer base.CheckPanic() + defer util.CheckPanic() log.Logger().Info("start meta sync", zap.Duration("meta_timeout", w.Config.Master.MetaTimeout)) for { var meta *protocol.Meta @@ -272,7 +271,7 @@ func (w *Worker) Sync() { // Pull user index and ranking model from master. func (w *Worker) Pull() { - defer base.CheckPanic() + defer util.CheckPanic() for range w.syncedChan.C { pulled := false @@ -511,7 +510,7 @@ func (w *Worker) Recommend(users []data.User) { defer span.End() go func() { - defer base.CheckPanic() + defer util.CheckPanic() completedCount, previousCount := 0, 0 ticker := time.NewTicker(10 * time.Second) for { diff --git a/worker/worker_test.go b/worker/worker_test.go index 6d6abe045..62d784fb7 100644 --- a/worker/worker_test.go +++ b/worker/worker_test.go @@ -30,7 +30,6 @@ import ( "github.com/c-bata/goptuna" mapset "github.com/deckarep/golang-set/v2" - "github.com/gorse-io/gorse/base" "github.com/gorse-io/gorse/common/expression" "github.com/gorse-io/gorse/common/monitor" "github.com/gorse-io/gorse/common/parallel" @@ -531,10 +530,10 @@ func newRankingDataset() (*dataset.Dataset, *dataset.Dataset) { } func newClickDataset() (*ctr.Dataset, *ctr.Dataset) { - dataset := &ctr.Dataset{ - Index: base.NewUnifiedMapIndexBuilder().Build(), + dataSet := &ctr.Dataset{ + Index: dataset.NewUnifiedMapIndexBuilder().Build(), } - return dataset, dataset + return dataSet, dataSet } type mockMaster struct { @@ -572,7 +571,7 @@ func newMockMaster(t *testing.T) *mockMaster { // create user index userIndexBuffer := bytes.NewBuffer(nil) - err = base.MarshalIndex(userIndexBuffer, base.NewMapIndex()) + err = dataset.MarshalIndex(userIndexBuffer, dataset.NewMapIndex()) assert.NoError(t, err) return &mockMaster{