package migration
import (
"database/sql"
"embed"
"errors"
"fmt"
"strings"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
"github.com/golang-migrate/migrate/v4/database/mysql"
"github.com/golang-migrate/migrate/v4/database/postgres"
"github.com/golang-migrate/migrate/v4/source/iofs"
"github.com/wxlbd/iot-platform/internal/types"
)
// migrationFS holds the files matched by scripts/*.sql, embedded into the
// binary through the go:embed directive below.
//
//go:embed scripts/*.sql
var migrationFS embed.FS
// Manager is the migration manager: it runs database migrations using the
// migration scripts stored in its embedded filesystem.
type Manager struct {
	// migrateFS is the embedded filesystem whose "scripts" directory is used
	// as the migration source.
	migrateFS embed.FS
}
// NewManager returns a migration Manager whose migration source is the
// package-level embedded filesystem migrationFS.
func NewManager() *Manager {
	mgr := new(Manager)
	mgr.migrateFS = migrationFS
	return mgr
}
// RunTenantMigration creates the tenant database when it does not exist yet
// and then applies all embedded migration scripts to it.
//
// Parameters:
//   - dbType: target database engine; only types.DBTypePostgres and
//     types.DBTypeMySQL are supported.
//   - config: connection settings of the tenant database. config.DbName is the
//     database that gets created and migrated. config itself is not modified;
//     a copy pointing at the engine's default database is used for creation.
//
// It returns an error when config is nil, config.DbName is empty, dbType is
// unsupported, or any connection, creation or migration step fails. Having no
// pending migrations (migrate.ErrNoChange) is treated as success.
func (m *Manager) RunTenantMigration(dbType types.DBType, config *types.TenantDBConfig) error {
	if config == nil {
		return errors.New("tenant database config is nil")
	}
	if config.DbName == "" {
		return errors.New("tenant database name is empty")
	}

	// Resolve everything that depends on dbType first, so an unsupported type
	// is rejected before any connection is attempted.
	var defaultDBName, createStmt string
	switch dbType {
	case types.DBTypePostgres:
		defaultDBName = "postgres" // PostgreSQL's default database
		// Quoting prevents SQL injection and keeps the name exactly as given;
		// an unquoted name would be folded to lower case and then no longer
		// match the name used for the second connection below.
		createStmt = "CREATE DATABASE " + quoteSQLIdentifier(config.DbName, `"`)
	case types.DBTypeMySQL:
		defaultDBName = "mysql" // MySQL's default database
		createStmt = "CREATE DATABASE IF NOT EXISTS " + quoteSQLIdentifier(config.DbName, "`")
	default:
		return fmt.Errorf("unsupported database type: %s", dbType)
	}

	// Connect to the default database; the tenant database may not exist yet.
	defaultConfig := *config
	defaultConfig.DbName = defaultDBName
	defaultDB, err := sql.Open(string(dbType), defaultConfig.GenerateDSN(dbType))
	if err != nil {
		return fmt.Errorf("connect to default database failed: %w", err)
	}
	defer defaultDB.Close()

	// Create the tenant database.
	if _, err := defaultDB.Exec(createStmt); err != nil {
		// The PostgreSQL statement has no IF NOT EXISTS clause, so an
		// "already exists" failure is the expected outcome on re-runs and is
		// ignored. Every other failure is reported.
		if dbType != types.DBTypePostgres || !strings.Contains(err.Error(), "already exists") {
			return fmt.Errorf("create database %q failed: %w", config.DbName, err)
		}
	}

	// Connect to the tenant database to run the migrations against it.
	db, err := sql.Open(string(dbType), config.GenerateDSN(dbType))
	if err != nil {
		return fmt.Errorf("connect to new database failed: %w", err)
	}
	defer db.Close()

	// Migration source: the "scripts" directory of the embedded filesystem.
	src, err := iofs.New(m.migrateFS, "scripts")
	if err != nil {
		return fmt.Errorf("create migration source failed: %w", err)
	}

	var driver database.Driver
	switch dbType {
	case types.DBTypeMySQL:
		driver, err = mysql.WithInstance(db, &mysql.Config{})
	case types.DBTypePostgres:
		driver, err = postgres.WithInstance(db, &postgres.Config{})
	default:
		// Already rejected above; kept so this switch stays exhaustive.
		return fmt.Errorf("unsupported database type: %s", dbType)
	}
	if err != nil {
		return fmt.Errorf("create migration driver failed: %w", err)
	}

	migration, err := migrate.NewWithInstance(
		"iofs", src,
		string(dbType), driver,
	)
	if err != nil {
		return fmt.Errorf("create migrator failed: %w", err)
	}
	// Release the source and database driver held by the migrator. Close
	// errors are deliberately dropped: the migration result is what matters
	// to the caller, and the deferred db.Close above remains as a safety net.
	defer func() { _, _ = migration.Close() }()

	// Apply all pending migrations.
	if err := migration.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
		return fmt.Errorf("run migration failed: %w", err)
	}
	return nil
}

// quoteSQLIdentifier wraps name in the given quote string and doubles every
// occurrence of that quote string inside name, so the result can be used as a
// single quoted identifier in a SQL statement.
func quoteSQLIdentifier(name, quote string) string {
	return quote + strings.ReplaceAll(name, quote, quote+quote) + quote
}
// github.com/golang-migrate/migrate库用法
// 最新推荐文章于 2025-11-13 15:05:30 发布
// 396
//
// 被折叠的 条评论
// 为什么被折叠?



