gh_mirrors/ge/geekai插件开发指南:MJ绘画与Stable Diffusion集成实战
引言:AI绘画插件开发的痛点与解决方案
你是否在开发AI绘画插件时遇到以下挑战:如何设计统一的任务队列机制?如何处理多模型API的差异化调用?如何实现实时进度更新与前端交互?本文将通过实战案例,详细解析geekai项目中MidJourney(MJ)和Stable Diffusion(SD)两大绘画插件的集成方案,帮助开发者快速掌握AI绘画插件的核心开发技巧。
读完本文你将获得:
- 基于Go语言的AI绘画任务队列设计与实现
- MJ/SD多模型API的统一封装策略
- 任务进度实时追踪与前端通知机制
- 分布式环境下的资源管理与错误处理
一、项目架构概览
1.1 系统整体架构
geekai项目采用前后端分离架构,AI绘画插件作为核心服务模块,通过RESTful API与前端交互。以下是系统架构的核心组件:
1.2 插件模块结构
AI绘画插件模块位于api/service/目录下,采用领域驱动设计(DDD)思想组织代码:
api/service/
├── mj/ # MidJourney服务
│ ├── client.go # MJ API客户端
│ └── service.go # MJ服务实现
└── sd/ # Stable Diffusion服务
└── service.go # SD服务实现
二、核心概念与数据模型
2.1 任务生命周期
AI绘画任务从创建到完成经历以下阶段:
2.2 核心数据模型
任务模型(MidJourneyJob/SdJob):
| 字段名 | 类型 | 描述 |
|---|---|---|
| Id | uint | 任务ID |
| UserId | uint | 用户ID |
| TaskId | string | 外部API任务ID |
| Prompt | string | 提示词 |
| Progress | int | 进度(0-100) |
| ImgURL | string | 生成图片URL |
| Status | string | 任务状态 |
| CreatedAt | time.Time | 创建时间 |
| UpdatedAt | time.Time | 更新时间 |
任务参数模型(MjTask/SdTask):
// MJ任务参数
type MjTask struct {
Id uint // 任务ID
UserId uint // 用户ID
ClientId string // 客户端ID
Type string // 任务类型(Imagine/Upscale等)
Prompt string // 提示词
NegPrompt string // 反向提示词
ModelId string // 模型ID
// 其他参数...
}
// SD任务参数
type SdTask struct {
Id int // 任务ID
UserId uint // 用户ID
ClientId string // 客户端ID
TranslateModelId string // 翻译模型ID
Params struct { // 生成参数
Prompt string // 提示词
NegPrompt string // 反向提示词
Steps int // 步数
CfgScale float32 // CFG缩放
Width int // 宽度
Height int // 高度
Sampler string // 采样器
Scheduler string // 调度器
Seed int // 种子
HdFix bool // 是否高清修复
// 其他参数...
}
}
三、MidJourney插件实现
3.1 API客户端设计
client.go封装了MidJourney API的调用逻辑,采用接口隔离原则设计:
// MJ API客户端接口
type Client interface {
Imagine(task types.MjTask) (ImageRes, error) // 生成图片
Upscale(task types.MjTask) (ImageRes, error) // 放大图片
Variation(task types.MjTask) (ImageRes, error) // 变体生成
Blend(task types.MjTask) (ImageRes, error) // 图片融合
SwapFace(task types.MjTask) (ImageRes, error) // 人脸处理
QueryTask(taskId, channelId string) (Task, error) // 查询任务状态
}
3.2 服务实现核心逻辑
service.go实现了任务管理的核心功能,采用生产者-消费者模式处理任务队列:
// 启动任务消费者
func (s *Service) Run() {
// 加载未完成任务到队列
var jobs []model.MidJourneyJob
s.db.Where("task_id", "").Where("progress", 0).Find(&jobs)
for _, v := range jobs {
var task types.MjTask
err := utils.JsonDecode(v.TaskInfo, &task)
if err != nil {
logger.Errorf("decode task info with error: %v", err)
continue
}
task.Id = v.Id
s.clientIds[task.Id] = task.ClientId
s.PushTask(task)
}
// 启动消费者协程
go func() {
for {
var task types.MjTask
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
// 处理提示词翻译
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db,
fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt),
task.TranslateModelId)
if err == nil {
task.Prompt = content
}
}
// 根据任务类型调用不同API
var res ImageRes
switch task.Type {
case types.TaskImage:
res, err = s.client.Imagine(task)
case types.TaskUpscale:
res, err = s.client.Upscale(task)
case types.TaskVariation:
res, err = s.client.Variation(task)
// 其他任务类型...
}
// 处理API响应
if err != nil || (res.Code != 1 && res.Code != 22) {
// 错误处理逻辑
// ...
} else {
// 更新任务状态
// ...
}
}
}()
}
3.3 任务进度追踪
MJ服务通过定时查询API和WebSocket通知实现进度追踪:
// 同步任务进度
func (s *Service) SyncTaskProgress() {
go func() {
var jobs []model.MidJourneyJob
for {
err := s.db.Where("progress < ?", 100).Find(&jobs).Error
if err != nil {
continue
}
for _, job := range jobs {
// 超时任务处理
if time.Now().Sub(job.CreatedAt) > time.Minute*10 {
job.Progress = service.FailTaskProgress
job.ErrMsg = "任务超时"
s.db.Updates(&job)
continue
}
// 查询任务进度
task, err := s.client.QueryTask(job.TaskId, job.ChannelId)
if err != nil {
logger.Errorf("error with query task: %v", err)
continue
}
// 更新进度
job.Progress = utils.IntValue(
strings.Replace(task.Progress, "%", "", 1), 0)
s.db.Updates(&job)
// 通知前端
if job.Progress == 100 {
s.notifyQueue.RPush(service.NotifyMessage{
ClientId: s.clientIds[job.Id],
UserId: job.UserId,
JobId: int(job.Id),
Message: service.TaskStatusFinished})
}
}
time.Sleep(time.Second * 5)
}
}()
}
四、Stable Diffusion插件实现
4.1 API封装与调用
SD服务直接调用Stable Diffusion Web UI API,实现文生图功能:
// 文生图API调用
func (s *Service) Txt2Img(task types.SdTask) error {
body := Txt2ImgReq{
Prompt: task.Params.Prompt,
NegativePrompt: task.Params.NegPrompt,
Steps: task.Params.Steps,
CfgScale: task.Params.CfgScale,
Width: task.Params.Width,
Height: task.Params.Height,
SamplerName: task.Params.Sampler,
Scheduler: task.Params.Scheduler,
ForceTaskId: task.Params.TaskId,
}
// 高清修复参数
if task.Params.HdFix {
body.EnableHr = true
body.HrScale = task.Params.HdScale
body.HrUpscaler = task.Params.HdScaleAlg
body.HrSecondPassSteps = task.Params.HdSteps
body.DenoisingStrength = task.Params.HdRedrawRate
}
// 发送请求
response, err := s.httpClient.R().
SetHeader("Authorization", apiKey.Value).
SetBody(body).
SetSuccessResult(&res).
Post(apiURL)
// 处理响应...
}
4.2 任务队列与资源管理
SD服务使用Redis实现分布式任务队列,确保系统可扩展性:
// 初始化任务队列
func NewService(db *gorm.DB, manager *oss.UploaderManager,
levelDB *store.LevelDB, redisCli *redis.Client,
wsService *service.WebsocketService,
userService *service.UserService) *Service {
return &Service{
httpClient: req.C(),
taskQueue: store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli),
notifyQueue: store.NewRedisQueue("StableDiffusion_Queue", redisCli),
db: db,
wsService: wsService,
uploadManager: manager,
userService: userService,
}
}
// 推送任务到队列
func (s *Service) PushTask(task types.SdTask) {
logger.Debugf("add a new Stable-Diffusion task to the task list: %+v", task)
s.taskQueue.RPush(task)
}
五、前端交互与实时通知
5.1 WebSocket通知机制
服务通过WebSocket向前端推送任务状态更新:
// 通知前端任务状态
func (s *Service) CheckTaskNotify() {
go func() {
logger.Info("Running Stable-Diffusion task notify checking ...")
for {
var message service.NotifyMessage
err := s.notifyQueue.LPop(&message)
if err != nil {
continue
}
logger.Debugf("notify message: %+v", message)
client := s.wsService.Clients.Get(message.ClientId)
if client == nil {
continue
}
utils.SendChannelMsg(client, types.ChSd, message.Message)
}
}()
}
5.2 前端任务状态处理
Vue前端通过WebSocket接收任务状态更新:
// 监听SD任务状态
this.$socket.on('sd', (data) => {
const { jobId, message } = data;
if (message === 'finished') {
this.loadImageResult(jobId);
this.showNotification('任务完成', 'success');
} else if (message === 'failed') {
this.showNotification('任务失败', 'error');
} else if (message === 'running') {
this.updateTaskProgress(jobId);
}
});
六、错误处理与资源回收
6.1 健壮的错误处理策略
系统实现多层次错误处理机制:
// 任务错误处理
if err != nil || (res.Code != 1 && res.Code != 22) {
var errMsg string
if err != nil {
errMsg = err.Error()
} else {
errMsg = fmt.Sprintf("%v,%s", err, res.Description)
}
logger.Error("绘画任务执行失败:", errMsg)
job.Progress = service.FailTaskProgress
job.ErrMsg = errMsg
s.db.Updates(&job)
// 通知前端失败
s.notifyQueue.RPush(service.NotifyMessage{
ClientId: task.ClientId,
UserId: task.UserId,
JobId: int(job.Id),
Message: service.TaskStatusFailed
})
continue
}
6.2 超时任务处理
定期检查并处理超时任务,释放系统资源:
// 检查超时任务
func (s *Service) CheckTaskStatus() {
go func() {
logger.Info("Running Stable-Diffusion task status checking ...")
for {
var jobs []model.SdJob
res := s.db.Where("progress < ?", 100).Find(&jobs)
if res.Error != nil {
time.Sleep(5 * time.Second)
continue
}
for _, job := range jobs {
// 5分钟超时任务标记为失败
if time.Now().Sub(job.CreatedAt) > time.Minute*5 {
job.Progress = service.FailTaskProgress
job.ErrMsg = "任务超时"
s.db.Updates(&job)
}
}
// 恢复失败任务的算力
s.db.Where("progress", service.FailTaskProgress).
Where("power > ?", 0).Find(&jobs)
for _, job := range jobs {
err := s.userService.IncreasePower(job.UserId, job.Power, model.PowerLog{
Type: types.PowerRefund,
Model: "stable-diffusion",
Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%d, Err: %s", job.Id, job.ErrMsg),
})
// ...
}
time.Sleep(time.Second * 5)
}
}()
}
七、高级功能实现
7.1 提示词自动翻译
针对中文提示词,系统自动翻译成英文以获得更好效果:
// 提示词翻译
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db,
fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt),
task.TranslateModelId)
if err == nil {
task.Prompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
7.2 分布式图片存储
生成的图片通过统一上传管理器存储到多种存储后端:
// 上传图片
imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0])
if err != nil {
errChan <- fmt.Errorf("error with upload image: %v", err)
return
}
// 更新数据库
s.db.Model(&model.SdJob{Id: uint(task.Id)}).
UpdateColumns(model.SdJob{ImgURL: imgURL, Params: utils.JsonEncode(task.Params)})
八、性能优化与最佳实践
8.1 性能优化策略
- 任务优先级队列:根据用户等级设置任务优先级
- 批量处理:合并多个小任务减少API调用次数
- 缓存策略:缓存常用提示词翻译结果
- 资源池化:HTTP客户端连接池复用
8.2 代码最佳实践
- 接口抽象:依赖接口而非具体实现
- 错误处理:详细日志记录与用户友好提示
- 配置外部化:敏感参数通过配置文件注入
- 监控指标:关键操作添加性能指标监控
九、扩展与定制指南
9.1 添加新的AI绘画模型
要集成新的AI绘画模型,需实现以下步骤:
- 创建服务目录
api/service/newmodel/ - 实现服务接口
Service和Client - 添加任务数据模型
NewmodelJob - 实现任务队列与通知机制
- 前端添加对应UI组件
9.2 自定义任务参数
扩展任务参数示例:
// 添加风格迁移参数
type SdTaskParams struct {
// 原有参数...
StyleModel string `json:"style_model"` // 风格模型
StyleStrength float32 `json:"style_strength"` // 风格强度
}
十、总结与展望
本文详细介绍了geekai项目中AI绘画插件的设计与实现,涵盖任务队列、API集成、实时通知等核心技术点。通过采用Go语言的并发特性和Redis分布式队列,系统实现了高性能、高可用的AI绘画服务。
未来优化方向:
- 引入GPU资源管理调度
- 实现任务断点续传
- 多模型结果对比功能
- 提示词智能优化
掌握这些技术不仅能帮助你快速集成MJ/SD等主流AI绘画模型,还能为开发其他AI能力插件提供参考。建议开发者深入研究service.go中的任务调度逻辑,这是整个插件系统的核心所在。
附录:常见问题解答
Q1: 如何解决MJ API调用频率限制?
A1: 系统通过Redis实现分布式限流,代码位于mj/client.go的AcquireQuota方法。
Q2: 如何自定义SD模型参数?
A2: 修改types/sd_task.go中的SdTaskParams结构体,添加自定义参数。
Q3: 任务失败后如何调试?
A3: 查看logs/目录下的服务日志,关键错误会记录详细上下文信息。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考



