Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ DB_SSL_MODE=disable
DB_MAX_IDLE_CONNS=10
DB_MAX_OPEN_CONNS=30
DB_CONN_MAX_LIFETIME=3600
# DB_EXTRAS in GORM format
DB_EXTRAS=
DB_CHARSET=

DIFY_INVOCATION_CONNECTION_IDLE_TIMEOUT=120

Expand Down
52 changes: 28 additions & 24 deletions internal/db/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,31 +58,35 @@ func autoMigrate() error {
func Init(config *app.Config) {
var err error
if config.DBType == "postgresql" {
DifyPluginDB, err = pg.InitPluginDB(
config.DBHost,
int(config.DBPort),
config.DBDatabase,
config.DBDefaultDatabase,
config.DBUsername,
config.DBPassword,
config.DBSslMode,
config.DBMaxIdleConns,
config.DBMaxOpenConns,
config.DBConnMaxLifetime,
)
DifyPluginDB, err = pg.InitPluginDB(&pg.PGConfig{
Host: config.DBHost,
Port: int(config.DBPort),
DBName: config.DBDatabase,
DefaultDBName: config.DBDefaultDatabase,
User: config.DBUsername,
Pass: config.DBPassword,
SSLMode: config.DBSslMode,
MaxIdleConns: config.DBMaxIdleConns,
MaxOpenConns: config.DBMaxOpenConns,
ConnMaxLifetime: config.DBConnMaxLifetime,
Charset: config.DBCharset,
Extras: config.DBExtras,
})
} else if config.DBType == "mysql" {
DifyPluginDB, err = mysql.InitPluginDB(
config.DBHost,
int(config.DBPort),
config.DBDatabase,
config.DBDefaultDatabase,
config.DBUsername,
config.DBPassword,
config.DBSslMode,
config.DBMaxIdleConns,
config.DBMaxOpenConns,
config.DBConnMaxLifetime,
)
DifyPluginDB, err = mysql.InitPluginDB(&mysql.MySQLConfig{
Host: config.DBHost,
Port: int(config.DBPort),
DBName: config.DBDatabase,
DefaultDBName: config.DBDefaultDatabase,
User: config.DBUsername,
Pass: config.DBPassword,
SSLMode: config.DBSslMode,
MaxIdleConns: config.DBMaxIdleConns,
MaxOpenConns: config.DBMaxOpenConns,
ConnMaxLifetime: config.DBConnMaxLifetime,
Charset: config.DBCharset,
Extras: config.DBExtras,
})
} else {
log.Panic("unsupported database type: %v", config.DBType)
}
Expand Down
42 changes: 29 additions & 13 deletions internal/db/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,47 @@ import (
"gorm.io/gorm"
)

func InitPluginDB(host string, port int, dbName string, defaultDbName string, user string, password string, sslMode string, maxIdleConns int, maxOpenConns int, connMaxLifetime int) (*gorm.DB, error) {
type MySQLConfig struct {
Host string
Port int
DBName string
DefaultDBName string
User string
Pass string
SSLMode string
MaxIdleConns int
MaxOpenConns int
ConnMaxLifetime int
Charset string
Extras string
}

func InitPluginDB(config *MySQLConfig) (*gorm.DB, error) {
// TODO: MySQL dose not support DB_EXTRAS now
initializer := mysqlDbInitializer{
host: host,
port: port,
user: user,
password: password,
sslMode: sslMode,
host: config.Host,
port: config.Port,
user: config.User,
password: config.Pass,
sslMode: config.SSLMode,
}

// first try to connect to target database
db, err := initializer.connect(dbName)
db, err := initializer.connect(config.DBName)
if err != nil {
// if connection fails, try to create database
db, err = initializer.connect(defaultDbName)
db, err = initializer.connect(config.DefaultDBName)
if err != nil {
return nil, err
}

err = initializer.createDatabaseIfNotExists(db, dbName)
err = initializer.createDatabaseIfNotExists(db, config.DBName)
if err != nil {
return nil, err
}

// connect to the new db
db, err = initializer.connect(dbName)
db, err = initializer.connect(config.DBName)
if err != nil {
return nil, err
}
Expand All @@ -44,9 +60,9 @@ func InitPluginDB(host string, port int, dbName string, defaultDbName string, us
}

// configure connection pool
pool.SetMaxIdleConns(maxIdleConns)
pool.SetMaxOpenConns(maxOpenConns)
pool.SetConnMaxLifetime(time.Duration(connMaxLifetime) * time.Second)
pool.SetMaxIdleConns(config.MaxIdleConns)
pool.SetMaxOpenConns(config.MaxOpenConns)
pool.SetConnMaxLifetime(time.Duration(config.ConnMaxLifetime) * time.Second)

return db, nil
}
Expand Down
63 changes: 53 additions & 10 deletions internal/db/pg/pg.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,56 @@ import (
"gorm.io/gorm"
)

func InitPluginDB(host string, port int, db_name string, default_db_name string, user string, pass string, sslmode string, maxIdleConns int, maxOpenConns int, connMaxLifetime int) (*gorm.DB, error) {
type PGConfig struct {
Host string
Port int
DBName string
DefaultDBName string
User string
Pass string
SSLMode string
MaxIdleConns int
MaxOpenConns int
ConnMaxLifetime int
Charset string
Extras string
}

func InitPluginDB(config *PGConfig) (*gorm.DB, error) {
// first try to connect to target database
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, port, user, pass, db_name, sslmode)
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
config.Host,
config.Port,
config.User,
config.Pass,
config.DBName,
config.SSLMode,
)
if config.Charset != "" {
dsn = fmt.Sprintf("%s client_encoding=%s", dsn, config.Charset)
}
if config.Extras != "" {
dsn = fmt.Sprintf("%s %s", dsn, config.Extras)
}

db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
if err != nil {
// if connection fails, try to create database
dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, port, user, pass, default_db_name, sslmode)
dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
config.Host,
config.Port,
config.User,
config.Pass,
config.DefaultDBName,
config.SSLMode,
)
if config.Charset != "" {
dsn = fmt.Sprintf("%s client_encoding=%s", dsn, config.Charset)
}
if config.Extras != "" {
dsn = fmt.Sprintf("%s %s", dsn, config.Extras)
}

db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
if err != nil {
return nil, err
Expand All @@ -27,21 +70,21 @@ func InitPluginDB(host string, port int, db_name string, default_db_name string,
defer pgsqlDB.Close()

// check if the db exists
rows, err := pgsqlDB.Query(fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname = '%s'", db_name))
rows, err := pgsqlDB.Query(fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname = '%s'", config.DBName))
if err != nil {
return nil, err
}

if !rows.Next() {
// create database
_, err = pgsqlDB.Exec(fmt.Sprintf("CREATE DATABASE %s", db_name))
_, err = pgsqlDB.Exec(fmt.Sprintf("CREATE DATABASE %s", config.DBName))
if err != nil {
return nil, err
}
}

// connect to the new db
dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", host, port, user, pass, db_name, sslmode)
dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", config.Host, config.Port, config.User, config.Pass, config.DBName, config.SSLMode)
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
if err != nil {
return nil, err
Expand All @@ -68,9 +111,9 @@ func InitPluginDB(host string, port int, db_name string, default_db_name string,
}

// configure connection pool
pgsqlDB.SetMaxIdleConns(maxIdleConns)
pgsqlDB.SetMaxOpenConns(maxOpenConns)
pgsqlDB.SetConnMaxLifetime(time.Duration(connMaxLifetime) * time.Second)
pgsqlDB.SetMaxIdleConns(config.MaxIdleConns)
pgsqlDB.SetMaxOpenConns(config.MaxOpenConns)
pgsqlDB.SetConnMaxLifetime(time.Duration(config.ConnMaxLifetime) * time.Second)

return db, nil
}
8 changes: 5 additions & 3 deletions internal/types/app/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,11 @@ type Config struct {
DBSslMode string `envconfig:"DB_SSL_MODE" validate:"required,oneof=disable require"`

// database connection pool settings
DBMaxIdleConns int `envconfig:"DB_MAX_IDLE_CONNS" default:"10"`
DBMaxOpenConns int `envconfig:"DB_MAX_OPEN_CONNS" default:"30"`
DBConnMaxLifetime int `envconfig:"DB_CONN_MAX_LIFETIME" default:"3600"`
DBMaxIdleConns int `envconfig:"DB_MAX_IDLE_CONNS" default:"10"`
DBMaxOpenConns int `envconfig:"DB_MAX_OPEN_CONNS" default:"30"`
DBConnMaxLifetime int `envconfig:"DB_CONN_MAX_LIFETIME" default:"3600"`
DBExtras string `envconfig:"DB_EXTRAS"`
DBCharset string `envconfig:"DB_CHARSET"`

// persistence storage
PersistenceStoragePath string `envconfig:"PERSISTENCE_STORAGE_PATH"`
Expand Down