Skip to content

Commit 4ad2051

Browse files
committed
updates to the cache implementation
1 parent 3c2b2cb commit 4ad2051

4 files changed

Lines changed: 158 additions & 73 deletions

File tree

engine/cache/cache.go

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,28 +17,38 @@ import (
1717

1818
type Cache struct {
1919
sync.Mutex
20+
start time.Time
21+
freq time.Duration
2022
done chan struct{}
23+
cdone chan struct{}
2124
cache repository.Repository
2225
db repository.Repository
2326
queue queue.Queue
2427
}
2528

26-
func New(database repository.Repository) (repository.Repository, error) {
29+
func New(database repository.Repository, done chan struct{}) (*Cache, error) {
2730
if db := assetdb.New(sqlrepo.SQLiteMemory, ""); db != nil {
2831
c := &Cache{
32+
start: time.Now(),
33+
freq: 10 * time.Minute,
34+
done: done,
35+
cdone: make(chan struct{}, 1),
2936
cache: db.Repo,
30-
done: make(chan struct{}, 1),
3137
db: database,
3238
queue: queue.NewQueue(),
3339
}
3440

3541
go c.processDBCallbacks()
3642
return c, nil
3743
}
38-
3944
return nil, errors.New("failed to create the cache repository")
4045
}
4146

47+
// StartTime returns the time that the cache was created.
48+
func (c *Cache) StartTime() time.Time {
49+
return c.start
50+
}
51+
4252
// Close implements the Repository interface.
4353
func (c *Cache) Close() error {
4454
c.Lock()
@@ -50,12 +60,12 @@ func (c *Cache) Close() error {
5060
}
5161
}
5262

53-
close(c.done)
63+
close(c.cdone)
5464
for {
5565
if c.queue.Empty() {
5666
break
5767
}
58-
time.Sleep(5 * time.Second)
68+
time.Sleep(2 * time.Second)
5969
}
6070
return nil
6171
}
@@ -66,6 +76,13 @@ func (c *Cache) GetDBType() string {
6676
}
6777

6878
func (c *Cache) appendToDBQueue(callback func()) {
79+
select {
80+
case <-c.done:
81+
return
82+
case <-c.cdone:
83+
return
84+
default:
85+
}
6986
c.queue.Append(callback)
7087
}
7188

@@ -75,6 +92,8 @@ loop:
7592
select {
7693
case <-c.done:
7794
break loop
95+
case <-c.cdone:
96+
break loop
7897
case <-c.queue.Signal():
7998
element, ok := c.queue.Next()
8099

@@ -87,10 +106,6 @@ loop:
87106
}
88107
}
89108
}
90-
91-
c.queue.Process(func(data interface{}) {
92-
if callback, ok := data.(func()); ok {
93-
callback()
94-
}
95-
})
109+
// drain the callback queue of all remaining elements
110+
c.queue.Process(func(data interface{}) {})
96111
}

engine/cache/entity.go

Lines changed: 94 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,16 @@ func (c *Cache) CreateEntity(asset oam.Asset) (*types.Entity, error) {
2121
return nil, err
2222
}
2323

24-
c.appendToDBQueue(func() {
25-
_, _ = c.db.CreateEntity(asset)
26-
})
24+
if tag, found := c.checkCacheEntityTag(entity, "cache_create_entity"); !found {
25+
if last, err := time.Parse("2006-01-02 15:04:05", tag.Value()); err == nil && time.Now().Add(-1*c.freq).After(last) {
26+
_ = c.cache.DeleteEntityTag(tag.ID)
27+
_ = c.createCacheEntityTag(entity, "cache_create_entity")
28+
29+
c.appendToDBQueue(func() {
30+
_, _ = c.db.CreateEntity(asset)
31+
})
32+
}
33+
}
2734

2835
return entity, nil
2936
}
@@ -62,18 +69,100 @@ func (c *Cache) FindEntityById(id string) (*types.Entity, error) {
6269

6370
// FindEntityByContent implements the Repository interface.
6471
func (c *Cache) FindEntityByContent(asset oam.Asset, since time.Time) ([]*types.Entity, error) {
72+
c.Lock()
73+
entities, err := c.cache.FindEntityByContent(asset, since)
74+
if err == nil && len(entities) == 1 {
75+
if !since.IsZero() && !since.Before(c.start) {
76+
c.Unlock()
77+
return entities, err
78+
}
79+
if _, found := c.checkCacheEntityTag(entities[0], "cache_find_entity_by_content"); found {
80+
c.Unlock()
81+
return entities, err
82+
}
83+
}
84+
c.Unlock()
85+
86+
var dberr error
87+
var dbentities []*types.Entity
88+
done := make(chan struct{}, 1)
89+
c.appendToDBQueue(func() {
90+
defer func() { done <- struct{}{} }()
91+
92+
dbentities, dberr = c.db.FindEntityByContent(asset, since)
93+
})
94+
<-done
95+
close(done)
96+
97+
if dberr != nil {
98+
return entities, err
99+
}
100+
65101
c.Lock()
66102
defer c.Unlock()
67103

68-
return c.cache.FindEntityByContent(asset, since)
104+
var results []*types.Entity
105+
for _, entity := range dbentities {
106+
if e, err := c.cache.CreateEntity(entity.Asset); err == nil {
107+
results = append(results, e)
108+
if tags, err := c.cache.GetEntityTags(entity, c.start, "cache_find_entity_by_content"); err == nil && len(tags) > 0 {
109+
for _, tag := range tags {
110+
_ = c.cache.DeleteEntityTag(tag.ID)
111+
}
112+
}
113+
_ = c.createCacheEntityTag(entity, "cache_find_entity_by_content")
114+
}
115+
}
116+
return results, nil
69117
}
70118

71119
// FindEntitiesByType implements the Repository interface.
72120
func (c *Cache) FindEntitiesByType(atype oam.AssetType, since time.Time) ([]*types.Entity, error) {
121+
c.Lock()
122+
entities, err := c.cache.FindEntitiesByType(atype, since)
123+
if err == nil && len(entities) > 0 {
124+
if !since.IsZero() && !since.Before(c.start) {
125+
c.Unlock()
126+
return entities, err
127+
}
128+
if _, found := c.checkCacheEntityTag(entities[0], "cache_find_entities_by_type"); found {
129+
c.Unlock()
130+
return entities, err
131+
}
132+
}
133+
c.Unlock()
134+
135+
var dberr error
136+
var dbentities []*types.Entity
137+
done := make(chan struct{}, 1)
138+
c.appendToDBQueue(func() {
139+
defer func() { done <- struct{}{} }()
140+
141+
dbentities, dberr = c.db.FindEntitiesByType(atype, since)
142+
})
143+
<-done
144+
close(done)
145+
146+
if dberr != nil {
147+
return entities, err
148+
}
149+
73150
c.Lock()
74151
defer c.Unlock()
75152

76-
return c.cache.FindEntitiesByType(atype, since)
153+
var results []*types.Entity
154+
for _, entity := range dbentities {
155+
if e, err := c.cache.CreateEntity(entity.Asset); err == nil {
156+
results = append(results, e)
157+
if tags, err := c.cache.GetEntityTags(entity, c.start, "cache_find_entities_by_type"); err == nil && len(tags) > 0 {
158+
for _, tag := range tags {
159+
_ = c.cache.DeleteEntityTag(tag.ID)
160+
}
161+
}
162+
_ = c.createCacheEntityTag(entity, "cache_find_entities_by_type")
163+
}
164+
}
165+
return results, nil
77166
}
78167

79168
// DeleteEntity implements the Repository interface.

engine/cache/tag.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010

1111
"github.com/owasp-amass/asset-db/types"
1212
oam "github.com/owasp-amass/open-asset-model"
13+
"github.com/owasp-amass/open-asset-model/property"
1314
)
1415

1516
// CreateEntityTag implements the Repository interface.
@@ -178,3 +179,33 @@ func (c *Cache) DeleteEdgeTag(id string) error {
178179

179180
return nil
180181
}
182+
183+
func (c *Cache) createCacheEntityTag(entity *types.Entity, name string) error {
184+
_, err := c.cache.CreateEntityTag(entity, &property.SimpleProperty{
185+
PropertyName: name,
186+
PropertyValue: time.Now().Format("2006-01-02 15:04:05"),
187+
})
188+
return err
189+
}
190+
191+
func (c *Cache) checkCacheEntityTag(entity *types.Entity, name string) (*types.EntityTag, bool) {
192+
if tags, err := c.cache.GetEntityTags(entity, time.Time{}, name); err == nil && len(tags) == 1 {
193+
return tags[0], true
194+
}
195+
return nil, false
196+
}
197+
198+
func (c *Cache) createCacheEdgeTag(edge *types.Edge, name string) error {
199+
_, err := c.cache.CreateEdgeTag(edge, &property.SimpleProperty{
200+
PropertyName: name,
201+
PropertyValue: time.Now().Format("2006-01-02 15:04:05"),
202+
})
203+
return err
204+
}
205+
206+
func (c *Cache) checkCacheEdgeTag(edge *types.Edge, name string) (*types.EdgeTag, bool) {
207+
if tags, err := c.cache.GetEdgeTags(edge, time.Time{}, name); err == nil && len(tags) == 1 {
208+
return tags[0], true
209+
}
210+
return nil, false
211+
}

engine/sessions/session.go

Lines changed: 7 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,20 @@
55
package sessions
66

77
import (
8-
"embed"
98
"errors"
109
"fmt"
1110
"log/slog"
1211
"path/filepath"
1312
"strings"
1413

15-
"github.com/glebarez/sqlite"
1614
"github.com/google/uuid"
1715
"github.com/owasp-amass/amass/v4/config"
1816
"github.com/owasp-amass/amass/v4/engine/cache"
1917
"github.com/owasp-amass/amass/v4/engine/pubsub"
2018
"github.com/owasp-amass/amass/v4/engine/sessions/scope"
2119
et "github.com/owasp-amass/amass/v4/engine/types"
2220
assetdb "github.com/owasp-amass/asset-db"
23-
pgmigrations "github.com/owasp-amass/asset-db/migrations/postgres"
24-
sqlitemigrations "github.com/owasp-amass/asset-db/migrations/sqlite3"
2521
"github.com/owasp-amass/asset-db/repository"
26-
migrate "github.com/rubenv/sql-migrate"
27-
"gorm.io/driver/postgres"
28-
"gorm.io/gorm"
2922
)
3023

3124
type Session struct {
@@ -36,8 +29,8 @@ type Session struct {
3629
scope *scope.Scope
3730
db *assetdb.AssetDB
3831
dsn string
39-
dbtype repository.DBType
40-
c cache.Cache
32+
dbtype string
33+
c *cache.Cache
4134
stats *et.SessionStats
4235
done chan struct{}
4336
}
@@ -55,7 +48,6 @@ func CreateSession(cfg *config.Config) (et.Session, error) {
5548
cfg: cfg,
5649
scope: scope.CreateFromConfigScope(cfg),
5750
ps: pubsub.NewLogger(),
58-
c: cache.NewOAMCache(nil),
5951
stats: new(et.SessionStats),
6052
done: make(chan struct{}),
6153
}
@@ -64,6 +56,11 @@ func CreateSession(cfg *config.Config) (et.Session, error) {
6456
if err := s.setupDB(); err != nil {
6557
return nil, err
6658
}
59+
60+
s.c = cache.New(s.db.Repo)
61+
if s.c == nil {
62+
return nil, errors.New("failed to create the session cache")
63+
}
6764
return s, nil
6865
}
6966

@@ -87,10 +84,6 @@ func (s *Session) Scope() *scope.Scope {
8784
return s.scope
8885
}
8986

90-
func (s *Session) DB() *assetdb.AssetDB {
91-
return s.db
92-
}
93-
9487
func (s *Session) Cache() cache.Cache {
9588
return s.c
9689
}
@@ -121,9 +114,6 @@ func (s *Session) setupDB() error {
121114
if err := s.selectDBMS(); err != nil {
122115
return err
123116
}
124-
if err := s.migrations(); err != nil {
125-
return err
126-
}
127117
return nil
128118
}
129119

@@ -169,43 +159,3 @@ func (s *Session) selectDBMS() error {
169159
s.db = store
170160
return nil
171161
}
172-
173-
func (s *Session) migrations() error {
174-
var name string
175-
var fs embed.FS
176-
var database gorm.Dialector
177-
178-
switch s.dbtype {
179-
case repository.SQLite:
180-
name = "sqlite3"
181-
fs = sqlitemigrations.Migrations()
182-
database = sqlite.Open(s.dsn)
183-
case repository.Postgres:
184-
name = "postgres"
185-
fs = pgmigrations.Migrations()
186-
database = postgres.Open(s.dsn)
187-
default:
188-
return fmt.Errorf("unsupported database type: %s", s.dbtype)
189-
}
190-
// Initialize the GORM database connection
191-
sql, err := gorm.Open(database, &gorm.Config{})
192-
if err != nil {
193-
return fmt.Errorf("failed to open database: %s", err)
194-
}
195-
// Set up migrations
196-
migrationsSource := migrate.EmbedFileSystemMigrationSource{
197-
FileSystem: fs,
198-
Root: "/",
199-
}
200-
// Extract the raw SQL database instance
201-
sqlDb, err := sql.DB()
202-
if err != nil {
203-
return fmt.Errorf("failed to extract raw SQL DB from GORM: %s", err)
204-
}
205-
// Run migrations
206-
_, err = migrate.Exec(sqlDb, name, migrationsSource, migrate.Up)
207-
if err != nil {
208-
return fmt.Errorf("failed to execute migrations: %s", err)
209-
}
210-
return nil
211-
}

0 commit comments

Comments
 (0)