package migration
import (
"database/sql"
"embed"
"errors"
"fmt"
"strings"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
"github.com/golang-migrate/migrate/v4/database/mysql"
"github.com/golang-migrate/migrate/v4/database/postgres"
"github.com/golang-migrate/migrate/v4/source/iofs"
"github.com/wxlbd/iot-platform/internal/types"
)
var migrationFS embed.FS
// Manager runs tenant database migrations read from an embedded file system.
type Manager struct {
	// migrateFS is the file system migrations are loaded from;
	// RunTenantMigration reads its "scripts" directory.
	migrateFS embed.FS
}
// NewManager returns a Manager that loads its migrations from the
// package-level migrationFS.
func NewManager() *Manager {
	mgr := new(Manager)
	mgr.migrateFS = migrationFS
	return mgr
}
// RunTenantMigration makes sure the tenant database named by config.DbName
// exists and then applies every pending up-migration from the "scripts"
// directory of m.migrateFS to it.
//
// dbType selects both the database/sql driver (string(dbType) is passed to
// sql.Open) and the golang-migrate database driver; only types.DBTypeMySQL
// and types.DBTypePostgres are supported. config is not modified. A database
// that is already up to date (migrate.ErrNoChange) is not an error.
func (m *Manager) RunTenantMigration(dbType types.DBType, config *types.TenantDBConfig) error {
	if dbType != types.DBTypeMySQL && dbType != types.DBTypePostgres {
		// Reject early instead of failing later with an opaque sql.Open error.
		return fmt.Errorf("unsupported database type: %s", dbType)
	}
	if err := ensureTenantDatabase(dbType, config); err != nil {
		return err
	}
	return m.applyMigrations(dbType, config)
}

// ensureTenantDatabase creates config.DbName if it does not exist yet. It
// connects through the server's built-in database ("postgres" or "mysql")
// using the rest of config, because the tenant database may not exist.
func ensureTenantDatabase(dbType types.DBType, config *types.TenantDBConfig) error {
	adminConfig := *config
	switch dbType {
	case types.DBTypePostgres:
		adminConfig.DbName = "postgres"
	case types.DBTypeMySQL:
		adminConfig.DbName = "mysql"
	}

	adminDB, err := sql.Open(string(dbType), adminConfig.GenerateDSN(dbType))
	if err != nil {
		return fmt.Errorf("connect to default database failed: %w", err)
	}
	defer adminDB.Close()

	// Identifiers cannot be bound as parameters, so quote the name instead of
	// splicing it in raw. This prevents injection through DbName, and in
	// PostgreSQL it preserves the exact case the DSN will later connect with.
	quotedName := quoteIdentifier(dbType, config.DbName)

	if dbType == types.DBTypeMySQL {
		if _, err := adminDB.Exec("CREATE DATABASE IF NOT EXISTS " + quotedName); err != nil {
			return fmt.Errorf("create database %q failed: %w", config.DbName, err)
		}
		return nil
	}

	// PostgreSQL has no CREATE DATABASE IF NOT EXISTS. Look the database up
	// first rather than relying only on the (possibly localized) error text.
	var exists bool
	err = adminDB.QueryRow(
		"SELECT EXISTS (SELECT 1 FROM pg_database WHERE datname = $1)",
		config.DbName,
	).Scan(&exists)
	if err != nil {
		return fmt.Errorf("check database %q failed: %w", config.DbName, err)
	}
	if exists {
		return nil
	}
	if _, err := adminDB.Exec("CREATE DATABASE " + quotedName); err != nil {
		// Another process may have created it between the lookup and here.
		if !strings.Contains(err.Error(), "already exists") {
			return fmt.Errorf("create database %q failed: %w", config.DbName, err)
		}
	}
	return nil
}

// applyMigrations connects to the tenant database described by config and
// runs all pending up-migrations from the "scripts" directory of m.migrateFS.
func (m *Manager) applyMigrations(dbType types.DBType, config *types.TenantDBConfig) error {
	db, err := sql.Open(string(dbType), config.GenerateDSN(dbType))
	if err != nil {
		return fmt.Errorf("connect to new database failed: %w", err)
	}
	defer db.Close()

	src, err := iofs.New(m.migrateFS, "scripts")
	if err != nil {
		return fmt.Errorf("create migration source failed: %w", err)
	}

	var driver database.Driver
	switch dbType {
	case types.DBTypeMySQL:
		driver, err = mysql.WithInstance(db, &mysql.Config{})
	case types.DBTypePostgres:
		driver, err = postgres.WithInstance(db, &postgres.Config{})
	default:
		return fmt.Errorf("unsupported database type: %s", dbType)
	}
	if err != nil {
		return fmt.Errorf("create migration driver failed: %w", err)
	}

	migrator, err := migrate.NewWithInstance("iofs", src, string(dbType), driver)
	if err != nil {
		// The migrator does not own the driver yet, so release it here.
		_ = driver.Close()
		return fmt.Errorf("create migrator failed: %w", err)
	}
	defer func() {
		// Close shuts down the migration source and the database driver so
		// their resources do not outlive this call. A close failure cannot
		// change the outcome of Up, so its errors are deliberately ignored.
		_, _ = migrator.Close()
	}()

	if err := migrator.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
		return fmt.Errorf("run migration failed: %w", err)
	}
	return nil
}

// quoteIdentifier returns name quoted as an SQL identifier for dbType. It
// uses backticks for MySQL and double quotes otherwise (PostgreSQL), and
// doubles any embedded quote character so name cannot end the quoting early.
func quoteIdentifier(dbType types.DBType, name string) string {
	quote := `"`
	if dbType == types.DBTypeMySQL {
		quote = "`"
	}
	return quote + strings.ReplaceAll(name, quote, quote+quote) + quote
}