Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 66 additions & 11 deletions pkg/cron/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ type CronService struct {
mu sync.RWMutex
running bool
stopChan chan struct{}
wakeChan chan struct{}
gronx *gronx.Gronx
}

Expand All @@ -73,6 +74,7 @@ func NewCronService(storePath string, onJob JobHandler) *CronService {
storePath: storePath,
onJob: onJob,
gronx: gronx.New(),
wakeChan: make(chan struct{}),
}
// Initialize and load store on creation
cs.loadStore()
Expand All @@ -97,6 +99,9 @@ func (cs *CronService) Start() error {
}

cs.stopChan = make(chan struct{})
if cs.wakeChan == nil {
cs.wakeChan = make(chan struct{})
}
cs.running = true
go cs.runLoop(cs.stopChan)

Expand All @@ -119,14 +124,47 @@ func (cs *CronService) Stop() {
}

func (cs *CronService) runLoop(stopChan chan struct{}) {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
timer := time.NewTimer(time.Hour)
if !timer.Stop() {
<-timer.C
}
defer timer.Stop()

for {
// every loop, recalculate the next wake time
cs.mu.RLock()
nextWake := cs.getNextWakeMS()
cs.mu.RUnlock()

var delay time.Duration
now := time.Now().UnixMilli()

if nextWake == nil {
// no jobs, sleep for a long time (or until a new job is added)
delay = time.Hour
} else {
diff := *nextWake - now
if diff <= 0 {
delay = 0
} else {
delay = time.Duration(diff) * time.Millisecond
}
}

timer.Reset(delay)

select {
case <-stopChan:
return
case <-ticker.C:
case <-cs.wakeChan: // wake on new job or update
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
continue
case <-timer.C:
cs.checkJobs()
}
}
Expand Down Expand Up @@ -264,22 +302,19 @@ func (cs *CronService) executeJobByID(jobID string) {
}

func (cs *CronService) computeNextRun(schedule *CronSchedule, nowMS int64) *int64 {
if schedule.Kind == "at" {
switch schedule.Kind {
case "at":
if schedule.AtMS != nil && *schedule.AtMS > nowMS {
return schedule.AtMS
}
return nil
}

if schedule.Kind == "every" {
case "every":
if schedule.EveryMS == nil || *schedule.EveryMS <= 0 {
return nil
}
next := nowMS + *schedule.EveryMS
return &next
}

if schedule.Kind == "cron" {
case "cron":
if schedule.Expr == "" {
return nil
}
Expand All @@ -294,9 +329,19 @@ func (cs *CronService) computeNextRun(schedule *CronSchedule, nowMS int64) *int6

nextMS := nextTime.UnixMilli()
return &nextMS
default:
log.Printf("[cron] unknown schedule kind '%s'", schedule.Kind)
return nil
}
}

return nil
// wake up the loop to re-evaluate next wake time immediately (e.g. after add/update/remove jobs)
func (cs *CronService) notify() {
select {
case cs.wakeChan <- struct{}{}:
default:
// if the channel is full, it means the loop will wake up soon anyway, so we can skip sending
}
}

func (cs *CronService) recomputeNextRuns() {
Expand Down Expand Up @@ -400,6 +445,8 @@ func (cs *CronService) AddJob(
return nil, err
}

cs.notify()

return &job, nil
}

Expand All @@ -411,6 +458,9 @@ func (cs *CronService) UpdateJob(job *CronJob) error {
if cs.store.Jobs[i].ID == job.ID {
cs.store.Jobs[i] = *job
cs.store.Jobs[i].UpdatedAtMS = time.Now().UnixMilli()

cs.notify()

return cs.saveStoreUnsafe()
}
}
Expand Down Expand Up @@ -441,6 +491,8 @@ func (cs *CronService) removeJobUnsafe(jobID string) bool {
}
}

cs.notify()

return removed
}

Expand All @@ -463,6 +515,9 @@ func (cs *CronService) EnableJob(jobID string, enabled bool) *CronJob {
if err := cs.saveStoreUnsafe(); err != nil {
log.Printf("[cron] failed to save store after enable: %v", err)
}

cs.notify()

return job
}
}
Expand Down
199 changes: 199 additions & 0 deletions pkg/cron/service_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package cron

import (
"fmt"
"os"
"path/filepath"
"runtime"
"sync"
"testing"
"time"
)

func TestSaveStore_FilePermissions(t *testing.T) {
Expand Down Expand Up @@ -36,3 +39,199 @@ func TestSaveStore_FilePermissions(t *testing.T) {
func int64Ptr(v int64) *int64 {
return &v
}

func setupService(handler JobHandler) (*CronService, string) {
tmpFile := fmt.Sprintf("test_cron_%d.json", time.Now().UnixNano())
cs := NewCronService(tmpFile, handler)
return cs, tmpFile
}

func TestCronService_CRUD(t *testing.T) {
cs, path := setupService(nil)
defer os.Remove(path)

// Test AddJob
at := time.Now().Add(time.Hour).UnixMilli()
job, err := cs.AddJob("Task1", CronSchedule{Kind: "at", AtMS: &at}, "msg", true, "ch", "to")
if err != nil || job.ID == "" {
t.Fatalf("AddJob failed: %v", err)
}

// Test ListJobs
if len(cs.ListJobs(true)) != 1 {
t.Error("ListJobs should return 1 job")
}

// Test UpdateJob
job.Name = "UpdatedName"
err = cs.UpdateJob(job)
if err != nil || cs.store.Jobs[0].Name != "UpdatedName" {
t.Error("UpdateJob failed")
}

// Test EnableJob
cs.EnableJob(job.ID, false)
if cs.store.Jobs[0].Enabled != false || cs.store.Jobs[0].State.NextRunAtMS != nil {
t.Error("EnableJob(false) failed to clear state")
}

// Test RemoveJob
removed := cs.RemoveJob(job.ID)
if !removed || len(cs.store.Jobs) != 0 {
t.Error("RemoveJob failed")
}
}

// 2. Test Cron Expression Calculation Logic
func TestCronService_ComputeNextRun(t *testing.T) {
cs, path := setupService(nil)
defer os.Remove(path)

now := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC).UnixMilli()

tests := []struct {
name string
schedule CronSchedule
wantNil bool
}{
{"Valid Cron", CronSchedule{Kind: "cron", Expr: "0 * * * *"}, false},
{"Invalid Cron", CronSchedule{Kind: "cron", Expr: "invalid"}, true},
{"Every MS", CronSchedule{Kind: "every", EveryMS: int64Ptr(5000)}, false},
{"At Future", CronSchedule{Kind: "at", AtMS: int64Ptr(now + 1000)}, false},
{"At Past", CronSchedule{Kind: "at", AtMS: int64Ptr(now - 1000)}, true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := cs.computeNextRun(&tt.schedule, now)
if (got == nil) != tt.wantNil {
t.Errorf("%s: got %v, wantNil %v", tt.name, got, tt.wantNil)
}
})
}
}

// 3. Test Execution Flow
func TestCronService_ExecutionFlow(t *testing.T) {
var mu sync.Mutex
executedJobs := make(map[string]bool)

handler := func(job *CronJob) (string, error) {
mu.Lock()
executedJobs[job.ID] = true
mu.Unlock()
return "ok", nil
}

cs, path := setupService(handler)
defer os.Remove(path)

// Start the service
if err := cs.Start(); err != nil {
t.Fatalf("Start failed: %v", err)
}
defer cs.Stop()

// Add a job then runs 100ms from now
target := time.Now().Add(100 * time.Millisecond).UnixMilli()
job, _ := cs.AddJob("FastJob", CronSchedule{Kind: "at", AtMS: &target}, "", false, "", "")

// Check for job execution with a timeout
success := false
for range 20 {
mu.Lock()
if executedJobs[job.ID] {
success = true
mu.Unlock()
break
}
mu.Unlock()
time.Sleep(100 * time.Millisecond)
}

if !success {
t.Error("Job was not executed in time")
}

// check that the job is removed after execution (DeleteAfterRun = true)
status := cs.Status()
if status["jobs"].(int) != 0 {
t.Errorf("Job should be deleted after run, got count: %v", status["jobs"])
}
}

func TestCronService_PersistenceIntegrity(t *testing.T) {
tmpFile := "persist_test.json"
defer os.Remove(tmpFile)

// write a job and persist
cs1 := NewCronService(tmpFile, nil)
at := int64(2000000000000)
cs1.AddJob("PersistMe", CronSchedule{Kind: "at", AtMS: &at}, "payload", true, "ch1", "")

// check file exists
if _, err := os.Stat(tmpFile); os.IsNotExist(err) {
t.Fatal("Store file was not created")
}

// reload and check data integrity
cs2 := NewCronService(tmpFile, nil)
if err := cs2.Load(); err != nil {
t.Fatalf("Failed to load store: %v", err)
}

jobs := cs2.ListJobs(true)
if len(jobs) != 1 || jobs[0].Name != "PersistMe" {
t.Errorf("Data corruption after reload. Got: %+v", jobs)
}

// test loading invalid JSON
os.WriteFile(tmpFile, []byte("{invalid json}"), 0o644)
cs3 := NewCronService(tmpFile, nil)
err := cs3.loadStore()
if err == nil {
t.Error("Should return error when loading invalid JSON")
}
}

func TestCronService_ConcurrentAccess(t *testing.T) {
cs, path := setupService(nil)
defer os.Remove(path)

cs.Start()
defer cs.Stop()

var wg sync.WaitGroup
workers := 10
iterations := 50

wg.Add(workers * 2)

// add jobs concurrently
for i := range workers {
go func(id int) {
defer wg.Done()
for j := range iterations {
at := time.Now().Add(time.Hour).UnixMilli()
cs.AddJob(fmt.Sprintf("Job-%d-%d", id, j), CronSchedule{Kind: "at", AtMS: &at}, "", false, "", "")
time.Sleep(100 * time.Microsecond)
}
}(i)
}

// read and update jobs concurrently
for range workers {
go func() {
defer wg.Done()
for j := range iterations {
jobs := cs.ListJobs(true)
if len(jobs) > 0 {
cs.EnableJob(jobs[0].ID, j%2 == 0)
}
time.Sleep(100 * time.Microsecond)
}
}()
}

wg.Wait()
}
Loading