<think>
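We are converting the YOLOv8 model to YOLOv8-seg (the instance segmentation model) so that it accepts single-channel float32 TIFF images (range 0.0 to 1.0). The main changes are:
1. Change the model's input layer to a single channel (as before).
2. Change the dataset loader to support segmentation annotations (both detection boxes and segmentation masks).
3. Adjust the loss function for the segmentation task (YOLOv8-seg already includes a segmentation head, so the focus is on data loading and input adaptation).
Note: YOLOv8-seg has two outputs: the detection output (boxes and classes) and the segmentation output (mask prototypes and coefficients). The dataset therefore has to provide both bounding boxes and segmentation masks.
Steps:
1. Write a custom dataset class that loads TIFF images and their labels (bounding boxes and segmentation masks).
2. Modify the model's first convolution so that it accepts single-channel input.
3. Adjust the augmentation pipeline so that images, bounding boxes, and segmentation masks are augmented together.
4. Train with the YOLOv8-seg loss function (detection loss plus segmentation loss).
Because the images are single-channel float32 in the range 0.0 to 1.0, the data type does not need to change at load time, but note that:
- Segmentation masks are usually binary images (0 and 1) or integer label maps (for instance segmentation, one unique integer per instance, with 0 as background).
- In YOLOv8-seg, segmentation annotations are usually stored as polygon coordinates (in the label file); alternatively, mask images can be provided directly.
Here we assume the following annotation format:
- Bounding boxes: YOLO-format txt files (class, x_center, y_center, width, height)
- Segmentation masks: a PNG (or TIFF) file with the same name as the image, where each instance is marked with a different integer (0 means background).
Note, however, that the YOLOv8-seg training label format is similar to the detection format but additionally contains the polygon coordinates. Specifically, each line of a label file represents one instance, in the form:
class_id x1 y1 x2 y2 ... xn yn (normalized polygon coordinates)
Alternatively, separate segmentation mask files can be used.
To keep things simple, there are two options:
Option 1: use a separate segmentation mask file (one mask file per image, each instance marked with a different integer).
Option 2: store the polygon points in the label file (the official YOLO format).
We choose option 2, because YOLOv8-seg supports it officially. Each label file therefore contains:
class_id x1 y1 x2 y2 ... xn yn
where the polygon coordinates are normalized (relative to the image width and height).
However, the dataset loader has to provide both bounding boxes and segmentation polygons. So at load time we:
- read each instance's class and polygon points from the label file
- convert the polygon points to a bounding box (the polygon's enclosing rectangle)
- keep the polygon points for the segmentation loss
We could also use bounding-box annotations directly (if a separate bounding-box label file exists), but for consistency we assume the label file already contains the polygon points.
The modified dataset class needs to return:
- the image (single-channel float32, range 0 to 1)
- the bounding boxes (normalized xywh) and classes
- the polygon points of each instance (normalized coordinates)
In YOLOv8-seg, however, training does not actually need the original mask images, only the polygon points. So reading the polygon points from the label file is enough.
Implementation: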
</think>
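# Adapting YOLOv8-seg to Single-Channel float32 TIFF Images
I'll provide a complete solution that converts the YOLOv8 model to YOLOv8-seg (the instance segmentation model) and adapts it to single-channel float32 TIFF images (pixel range 0.0 to 1.0).
## Complete YOLOv8-seg Modification Plan
### 1. Custom dataset loader (with segmentation annotations)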
```python
import os

import albumentations as A
import cv2
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from tifffile import imread
from torch.utils.data import Dataset


class TiffSegmentationDataset(Dataset):
    """Custom TIFF dataset with instance-segmentation support."""

    def __init__(self, img_dir, label_dir, mask_dir, img_size=640, augment=True):
        """
        :param img_dir: directory of TIFF images
        :param label_dir: directory of detection labels (YOLO format)
        :param mask_dir: directory of segmentation masks (PNG format)
        :param img_size: model input size
        :param augment: whether to apply data augmentation
        """
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.mask_dir = mask_dir
        self.img_size = img_size
        self.augment = augment
        # Collect all TIFF files
        self.img_files = sorted(
            f for f in os.listdir(img_dir) if f.lower().endswith(('.tif', '.tiff'))
        )
        # Augmentation pipeline
        self.transform = self.build_transforms()
        # In-memory cache
        self.imgs = []
        self.labels = []
        self.masks = []
        self.shapes = []
        self.load_all_data()

    def load_all_data(self):
        """Preload all images, labels, and masks."""
        print(f"Loading {len(self.img_files)} TIFF images with segmentation masks...")
        for img_file in self.img_files:
            # Load the TIFF image
            img = imread(os.path.join(self.img_dir, img_file)).astype(np.float32)
            # Force a single channel (assumes channels-last for multi-channel files)
            if img.ndim > 2:
                img = img[:, :, 0]
            img = img[..., np.newaxis]  # (H, W, 1)

            # splitext handles both '.tif' and '.tiff'; chained str.replace
            # would turn 'x.tiff' into 'x.txtf'
            stem = os.path.splitext(img_file)[0]

            # Load detection labels: [class, x_center, y_center, width, height]
            label_path = os.path.join(self.label_dir, stem + '.txt')
            labels = []
            if os.path.exists(label_path):
                with open(label_path, 'r') as f:
                    for line in f:
                        fields = line.strip().split()
                        if fields:
                            labels.append([float(x) for x in fields])

            # Load the instance mask (integer instance IDs, 0 = background)
            mask_path = os.path.join(self.mask_dir, stem + '.png')
            if os.path.exists(mask_path):
                mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE).astype(np.float32)
                mask = mask[..., np.newaxis]  # (H, W, 1)
            else:
                mask = np.zeros(img.shape[:2] + (1,), dtype=np.float32)

            self.imgs.append(img)
            self.shapes.append(img.shape[:2])  # (H, W)
            self.labels.append(
                np.array(labels, dtype=np.float32) if labels
                else np.zeros((0, 5), dtype=np.float32)
            )
            self.masks.append(mask)

    def build_transforms(self):
        """Build the augmentation pipeline (masks are transformed with the image)."""
        # Base transforms
        transforms = [
            A.Resize(self.img_size, self.img_size, p=1.0),
            ToTensorV2(transpose_mask=True),
        ]
        # Extra augmentation for training
        if self.augment:
            transforms = [
                # Geometric transforms (applied to image, mask, and boxes)
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                A.Rotate(limit=10, p=0.5),
                # Albumentations < 1.4 signature; newer releases expect size=(h, w)
                A.RandomResizedCrop(self.img_size, self.img_size,
                                    scale=(0.8, 1.0), ratio=(0.8, 1.2), p=0.5),
                # Pixel-level transforms (image only)
                A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
                A.GaussianBlur(blur_limit=(3, 7), p=0.3),
                # Variance chosen for [0, 1] float images; (10, 50) is meant for
                # 8-bit data. Newer Albumentations releases use std_range instead.
                A.GaussNoise(var_limit=(1e-4, 1e-3), mean=0, p=0.3),
                # Resize and tensor conversion always come last
                *transforms,
            ]
        return A.Compose(
            transforms,
            bbox_params=A.BboxParams(
                format='yolo',
                min_visibility=0.4,
                label_fields=['class_labels'],
            ) if self.augment else None,
        )

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        img = self.imgs[idx]
        labels = self.labels[idx]
        mask = self.masks[idx]

        transform_kwargs = {'image': img, 'mask': mask}
        if self.augment:
            # bbox_params is set, so both fields must be passed even when empty
            transform_kwargs['bboxes'] = labels[:, 1:].tolist()
            transform_kwargs['class_labels'] = labels[:, 0].tolist()

        transformed = self.transform(**transform_kwargs)
        img_tensor = transformed['image']
        mask_tensor = transformed['mask']
        if mask_tensor.ndim == 2:
            mask_tensor = mask_tensor.unsqueeze(0)  # (1, H, W)

        if self.augment:
            # Boxes below min_visibility are dropped here; the mask still
            # contains those instances, so keep both in sync if that matters.
            bboxes = np.asarray(transformed['bboxes'], dtype=np.float32).reshape(-1, 4)
            classes = np.asarray(transformed['class_labels'], dtype=np.float32).reshape(-1, 1)
            labels = np.concatenate([classes, bboxes], axis=1)

        # Rows: (batch_idx, class, x_center, y_center, width, height);
        # batch_idx is filled in by the collate function
        label_tensor = torch.zeros((len(labels), 6), dtype=torch.float32)
        if len(labels) > 0:
            label_tensor[:, 1:] = torch.from_numpy(labels)
        return img_tensor, label_tensor, mask_tensor, torch.tensor(idx)
```
### 2. Modify the YOLOv8-seg input layer
```python
import torch
import torch.nn as nn
from ultralytics import YOLO


def adapt_yolov8seg_for_single_channel(model_path='yolov8n-seg.pt'):
    """
    Adapt a YOLOv8-seg model to single-channel input.
    :param model_path: path to the pretrained weights
    :return: the modified model
    """
    # Load the original model and take the underlying PyTorch module
    model = YOLO(model_path).model

    # Find the first convolution layer and remember its qualified name
    first_conv, conv_name = None, None
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            first_conv, conv_name = module, name
            break
    if first_conv is None:
        raise RuntimeError("No convolution layer found in the model")
    print(f"Original first conv layer: {first_conv}")

    # Create the new single-channel convolution
    new_conv = nn.Conv2d(
        in_channels=1,  # single-channel input
        out_channels=first_conv.out_channels,
        kernel_size=first_conv.kernel_size,
        stride=first_conv.stride,
        padding=first_conv.padding,
        dilation=first_conv.dilation,
        bias=first_conv.bias is not None,
    )

    # Convert the weights by averaging over the three RGB input channels.
    # (Summing instead of averaging would preserve the activation scale of a
    # gray image replicated to three channels.)
    with torch.no_grad():
        # Original shape: [out_channels, 3, kH, kW]
        # New shape:      [out_channels, 1, kH, kW]
        new_weight = first_conv.weight.data.mean(dim=1, keepdim=True)
        new_conv.weight.data.copy_(new_weight)
        # Copy the bias, if any
        if first_conv.bias is not None:
            new_conv.bias.data.copy_(first_conv.bias.data)

    # Replace the first convolution inside its parent module
    parent_name, _, attr = conv_name.rpartition('.')
    parent_module = model.get_submodule(parent_name) if parent_name else model
    setattr(parent_module, attr, new_conv)

    print("Successfully adapted YOLOv8-seg model for single-channel input")
    return model
```
### 3. Custom training loop (segmentation support)
```python
import os

import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from ultralytics.utils import DEFAULT_CFG
# Module path of recent Ultralytics 8.x releases; older releases used a
# different layout, so check your installed version.
from ultralytics.utils.loss import v8SegmentationLoss

# TiffSegmentationDataset and adapt_yolov8seg_for_single_channel are the
# definitions from sections 1 and 2.


def build_batch(imgs, labels, masks, device, mask_size):
    """Pack tensors into the dict that v8SegmentationLoss(preds, batch) reads.

    Assumes overlap_mask=True (the Ultralytics default): each mask holds
    per-image instance indices 1..n in the same order as the label rows.
    """
    # Downsample masks to the segmentation-head resolution
    masks = F.interpolate(masks.float(), size=(mask_size, mask_size), mode='nearest')
    return {
        'img': imgs.to(device, non_blocking=True),
        'batch_idx': labels[:, 0].to(device),
        'cls': labels[:, 1:2].to(device),
        'bboxes': labels[:, 2:6].to(device),  # normalized xywh
        'masks': masks.squeeze(1).to(device),  # (B, mask_size, mask_size)
    }


def split_loss_items(loss_items):
    """loss_items is assumed to be ordered (box, seg, cls, dfl)."""
    det_loss = loss_items[0] + loss_items[2] + loss_items[3]
    return det_loss, loss_items[1]


def train_yolov8seg_tiff():
    # Configuration
    config = {
        'data': {
            'train_img': 'path/to/train/images',
            'train_label': 'path/to/train/labels',
            'train_mask': 'path/to/train/masks',
            'val_img': 'path/to/val/images',
            'val_label': 'path/to/val/labels',
            'val_mask': 'path/to/val/masks'
        },
        'model': 'yolov8n-seg.pt',  # pretrained weights, adapted below
        'epochs': 100,
        'batch_size': 8,  # segmentation needs more memory
        'img_size': 640,
        'mask_size': 160,  # downsampled mask size
        'device': 'cuda:0' if torch.cuda.is_available() else 'cpu',
        'lr0': 0.001,
        'weight_decay': 0.0005,
        'save_dir': 'runs/train-seg',
        'name': 'yolov8seg_tiff'
    }

    # Datasets
    train_dataset = TiffSegmentationDataset(
        img_dir=config['data']['train_img'],
        label_dir=config['data']['train_label'],
        mask_dir=config['data']['train_mask'],
        img_size=config['img_size'],
        augment=True
    )
    val_dataset = TiffSegmentationDataset(
        img_dir=config['data']['val_img'],
        label_dir=config['data']['val_label'],
        mask_dir=config['data']['val_mask'],
        img_size=config['img_size'],
        augment=False
    )

    # Data loaders
    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True,
                              num_workers=4, pin_memory=True, collate_fn=seg_collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False,
                            num_workers=2, pin_memory=True, collate_fn=seg_collate_fn)

    # Build the single-channel model. Checkpoint parameters are loaded with
    # requires_grad=False, so re-enable gradients (the DFL layer stays frozen,
    # as in the Ultralytics trainer).
    device = torch.device(config['device'])
    model = adapt_yolov8seg_for_single_channel(config['model']).to(device).float()
    for name, param in model.named_parameters():
        param.requires_grad_('.dfl' not in name)
    model.args = DEFAULT_CFG  # the loss reads its gains (box, cls, dfl) from model.args
    model.train()

    # Loss function (detection and segmentation terms)
    loss_fn = v8SegmentationLoss(model)

    # Optimizer
    optimizer = AdamW(model.parameters(), lr=config['lr0'],
                      weight_decay=config['weight_decay'])

    # Learning-rate scheduler (stepped once per batch)
    scheduler = CosineAnnealingLR(optimizer, T_max=config['epochs'] * len(train_loader),
                                  eta_min=config['lr0'] * 0.01)

    save_dir = os.path.join(config['save_dir'], config['name'])
    os.makedirs(save_dir, exist_ok=True)

    # Training loop
    for epoch in range(config['epochs']):
        model.train()
        train_loss = train_det_loss = train_seg_loss = 0.0
        for batch_idx, (imgs, labels, masks, _) in enumerate(train_loader):
            batch = build_batch(imgs, labels, masks, device, config['mask_size'])

            # Forward pass and loss
            preds = model(batch['img'])
            loss, loss_items = loss_fn(preds, batch)
            loss = loss.sum()  # some releases return per-component losses
            det_loss, seg_loss = split_loss_items(loss_items)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            # Accumulate losses
            train_loss += loss.item()
            train_det_loss += det_loss.item()
            train_seg_loss += seg_loss.item()

            # Log every 10 batches
            if batch_idx % 10 == 0:
                print(f"Epoch: {epoch+1}/{config['epochs']} | "
                      f"Batch: {batch_idx}/{len(train_loader)} | "
                      f"Loss: {loss.item():.4f} | "
                      f"Det Loss: {det_loss.item():.4f} | "
                      f"Seg Loss: {seg_loss.item():.4f}")

        # Average training losses
        avg_train_loss = train_loss / len(train_loader)
        avg_det_loss = train_det_loss / len(train_loader)
        avg_seg_loss = train_seg_loss / len(train_loader)

        # Validation
        val_loss, val_det_loss, val_seg_loss = validate_seg(
            model, val_loader, loss_fn, device, config['mask_size'])
        print(f"Epoch {epoch+1} summary: "
              f"Train Loss: {avg_train_loss:.4f} (Det: {avg_det_loss:.4f}, Seg: {avg_seg_loss:.4f}) | "
              f"Val Loss: {val_loss:.4f} (Det: {val_det_loss:.4f}, Seg: {val_seg_loss:.4f})")

        # Save a checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            save_path = os.path.join(save_dir, f'epoch_{epoch+1}.pt')
            torch.save(model.state_dict(), save_path)
            print(f"Model saved to {save_path}")

    print("Training completed!")
    return model


def seg_collate_fn(batch):
    """Custom collate function (segmentation masks supported)."""
    imgs, labels, masks, idxs = zip(*batch)
    # Column 0 of each label row records which image in the batch it belongs to
    for i, label in enumerate(labels):
        label[:, 0] = i
    return torch.stack(imgs, 0), torch.cat(labels, 0), torch.stack(masks, 0), torch.stack(idxs, 0)


def validate_seg(model, val_loader, loss_fn, device, mask_size):
    """Evaluate the segmentation model on the validation set."""
    model.eval()
    total_loss = total_det_loss = total_seg_loss = 0.0
    with torch.no_grad():
        for imgs, labels, masks, _ in val_loader:
            batch = build_batch(imgs, labels, masks, device, mask_size)
            preds = model(batch['img'])
            loss, loss_items = loss_fn(preds, batch)
            det_loss, seg_loss = split_loss_items(loss_items)
            total_loss += loss.sum().item()
            total_det_loss += det_loss.item()
            total_seg_loss += seg_loss.item()
    n_batches = max(len(val_loader), 1)
    model.train()
    return total_loss / n_batches, total_det_loss / n_batches, total_seg_loss / n_batches
```
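### 4. TIFF segmentation inference functions
```python
import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from tifffile import imread, imwrite
from ultralytics.utils.ops import non_max_suppression


def preprocess_image(img, img_size):
    """Letterbox an (H, W, 1) float32 image to (1, img_size, img_size).

    Returns the tensor plus the scale and the left/top padding in pixels.
    """
    h, w = img.shape[:2]
    scale = min(img_size / h, img_size / w)
    new_h, new_w = round(h * scale), round(w * scale)
    resized = cv2.resize(img[..., 0], (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    top, left = (img_size - new_h) // 2, (img_size - new_w) // 2
    canvas = np.zeros((img_size, img_size), dtype=np.float32)
    canvas[top:top + new_h, left:left + new_w] = resized
    return torch.from_numpy(canvas)[None], scale, left, top


def predict_segmentation_on_tiff(model, image_path, conf_thres=0.25, iou_thres=0.5, img_size=640):
    """
    Run instance segmentation on a TIFF image.
    :param model: trained YOLOv8-seg model
    :param image_path: path to the TIFF image
    :param conf_thres: confidence threshold
    :param iou_thres: IoU threshold
    :param img_size: model input size
    :return: detections, segmentation masks
    """
    # Read the TIFF image and force a single channel
    img = imread(image_path).astype(np.float32)
    if img.ndim > 2:
        img = img[:, :, 0]
    img = img[..., np.newaxis]  # (H, W, 1)

    # Preprocess
    original_shape = img.shape[:2]
    img_tensor, scale, pad_x, pad_y = preprocess_image(img, img_size)
    device = next(model.parameters()).device
    img_tensor = img_tensor.unsqueeze(0).to(device)  # add batch dimension

    # Inference: in eval mode the model returns (predictions, extras), where
    # the prototypes are the last element of extras
    model.eval()
    with torch.no_grad():
        out = model(img_tensor)
    preds = out[0]
    proto = out[1][-1] if isinstance(out[1], (tuple, list)) else out[1]
    detections, masks = process_segmentation_output(preds, proto, conf_thres, iou_thres)

    if len(detections) == 0:
        return detections, np.zeros((0, *original_shape), dtype=np.uint8)

    # Map boxes back to original image coordinates
    detections[:, [0, 2]] -= pad_x  # x coordinates
    detections[:, [1, 3]] -= pad_y  # y coordinates
    detections[:, :4] /= scale

    # Upsample masks to the letterboxed input, strip the padding, then
    # resize to the original image
    new_h, new_w = round(original_shape[0] * scale), round(original_shape[1] * scale)
    masks = F.interpolate(masks[None], size=(img_size, img_size), mode='bilinear')[0]
    masks = masks[:, pad_y:pad_y + new_h, pad_x:pad_x + new_w]
    masks = F.interpolate(masks[None], size=original_shape, mode='bilinear')[0]
    masks = (masks > 0.5).cpu().numpy().astype(np.uint8)  # binarize
    return detections, masks


def process_segmentation_output(preds, proto, conf_thres, iou_thres):
    """Post-process the YOLOv8-seg output of a single image."""
    # preds: (1, 4 + nc + nm, N); proto: (1, nm, mh, mw)
    nc = preds.shape[1] - 4 - proto.shape[1]
    # NMS; nc tells it which trailing columns are mask coefficients
    detections = non_max_suppression(
        preds, conf_thres, iou_thres, multi_label=True, nc=nc
    )[0]  # first (only) image in the batch
    if len(detections) == 0:
        return detections, None
    masks = process_proto_masks(detections, proto)
    return detections, masks


def process_proto_masks(detections, proto):
    """Build instance masks from the prototype masks (not cropped to boxes)."""
    nm, mh, mw = proto.shape[1:]
    # Columns: x1, y1, x2, y2, conf, cls, then nm mask coefficients
    mask_coef = detections[:, 6:6 + nm]
    masks = torch.sigmoid(mask_coef @ proto[0].reshape(nm, -1)).view(-1, mh, mw)
    return masks


def visualize_segmentation_results(image_path, detections, masks, class_names):
    """Visualize segmentation results."""
    # Read the original image and convert [0, 1] grayscale to 8-bit BGR
    img = imread(image_path).astype(np.float32)
    if img.ndim > 2:
        img = img[:, :, 0]
    img = cv2.cvtColor((np.clip(img, 0, 1) * 255).astype(np.uint8), cv2.COLOR_GRAY2BGR)

    # Draw boxes and masks
    for i, det in enumerate(detections):
        x1, y1, x2, y2 = (int(v) for v in det[:4])
        conf, cls = float(det[4]), int(det[5])
        color = tuple(int(c) for c in np.random.randint(0, 255, size=3))
        cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
        label = f"{class_names[cls]} {conf:.2f}"
        cv2.putText(img, label, (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
        if masks is not None and i < len(masks):
            # Blend a colored mask into the image
            mask_img = np.zeros_like(img)
            mask_img[masks[i] == 1] = color
            img = cv2.addWeighted(img, 1, mask_img, 0.5, 0)

    # Save the result (8-bit RGB)
    output_path = os.path.splitext(image_path)[0] + '_segmentation.tif'
    imwrite(output_path, cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    print(f"Segmentation result saved to {output_path}")

    # Display
    cv2.imshow("Segmentation Result", img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
```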
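The coordinate mapping in the inference function relies on letterbox geometry: the image is scaled by a single factor and centered with padding, so predictions have to be shifted by the padding and divided by the scale. A minimal sketch of that arithmetic in plain Python follows; the helper names `letterbox_params`, `to_letterbox`, and `from_letterbox` are hypothetical and only illustrate the math.

```python
def letterbox_params(orig_h, orig_w, img_size):
    """Scale factor and x/y padding for fitting an image into a square input."""
    scale = min(img_size / orig_h, img_size / orig_w)
    pad_x = (img_size - orig_w * scale) / 2
    pad_y = (img_size - orig_h * scale) / 2
    return scale, pad_x, pad_y


def to_letterbox(box, scale, pad_x, pad_y):
    """Map an (x1, y1, x2, y2) box from original to letterboxed coordinates."""
    x1, y1, x2, y2 = box
    return (x1 * scale + pad_x, y1 * scale + pad_y,
            x2 * scale + pad_x, y2 * scale + pad_y)


def from_letterbox(box, scale, pad_x, pad_y):
    """Inverse mapping: subtract the padding first, then undo the scaling."""
    x1, y1, x2, y2 = box
    return ((x1 - pad_x) / scale, (y1 - pad_y) / scale,
            (x2 - pad_x) / scale, (y2 - pad_y) / scale)


# A 480x640 (H x W) image in a 640x640 input: no scaling, 80 px padding
# above and below.
scale, pad_x, pad_y = letterbox_params(480, 640, 640)
print(scale, pad_x, pad_y)
```

Note that this mapping is only valid if the preprocessing really letterboxes the image; a plain resize to a square stretches the two axes by different factors and needs separate x and y scales.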
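## Key Modifications Explained
### 1. Converting YOLOv8 to YOLOv8-seg
- **Model structure**: YOLOv8-seg adds a segmentation head that outputs mask prototypes and per-instance mask coefficients
- **Input layer**: as before, the first convolution has to change from 3 input channels to 1
- **Loss function**: `v8SegmentationLoss` handles the detection and segmentation losses together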
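One detail of the input-layer change is how the three RGB kernels are collapsed into one. The NumPy sketch below, using a randomly generated stand-in kernel rather than real YOLOv8 weights, compares averaging with summing for a grayscale patch that the original network would have seen replicated across three channels.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical first-layer kernel: (out_channels, in_channels=3, kH, kW)
w_rgb = rng.normal(size=(4, 3, 3, 3))
# One single-channel 3x3 patch, as seen at one kernel position
patch = rng.random((3, 3))

# Reference: the gray patch replicated into three identical channels
patch_rgb = np.stack([patch, patch, patch])            # (3, 3, 3)
out_ref = np.einsum("ocij,cij->o", w_rgb, patch_rgb)   # (4,)

# Collapse the kernel to one input channel in two ways
w_mean = w_rgb.mean(axis=1)                            # (4, 3, 3)
w_sum = w_rgb.sum(axis=1)                              # (4, 3, 3)

out_mean = np.einsum("oij,ij->o", w_mean, patch)
out_sum = np.einsum("oij,ij->o", w_sum, patch)

print(np.allclose(out_sum, out_ref))       # summing reproduces the reference
print(np.allclose(out_mean * 3, out_ref))  # averaging is a factor of 3 smaller
```

Averaging therefore starts fine-tuning with first-layer activations at one third of the pretrained scale. Either choice is usually recoverable during training, but summing keeps the pretrained activation statistics intact.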
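### 2. Segmentation data handling
- **Mask loading**: adds support for segmentation masks (PNG format)
- **Mask preprocessing**: masks are resized to the segmentation head's output size (typically 1/4 of the input)
- **Data augmentation**: the image and its mask are augmented in sync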
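The mask preprocessing step matters because the masks hold integer instance IDs: they must be downsampled with nearest-neighbour sampling, not with any averaging interpolation. The NumPy sketch below illustrates the difference on a toy mask; the two helper functions are hypothetical and use simple stride sampling in place of a library resize.

```python
import numpy as np


def downsample_instance_mask(mask, factor=4):
    """Nearest-neighbour downsampling by integer stride; instance IDs stay intact."""
    return mask[::factor, ::factor]


def downsample_by_averaging(mask, factor=4):
    """Block averaging, shown only to illustrate how IDs get blended."""
    h, w = mask.shape
    blocks = mask[:h - h % factor, :w - w % factor].astype(np.float32)
    blocks = blocks.reshape(h // factor, factor, w // factor, factor)
    return blocks.mean(axis=(1, 3))


# Toy 16x16 mask with two instances (IDs 1 and 3) on background 0
mask = np.zeros((16, 16), dtype=np.uint8)
mask[2:9, 2:9] = 1
mask[6:14, 10:15] = 3

small_nearest = downsample_instance_mask(mask)
small_avg = downsample_by_averaging(mask)

print(np.unique(small_nearest))  # only IDs that exist in the original mask
print(small_avg[0, 0])           # a fractional value that matches no instance
```

The same reasoning applies to augmentation: geometric transforms have to resample the mask with nearest-neighbour interpolation so that no new label values appear.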
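### 3. Segmentation training tuning
- **Batch size**: segmentation needs more memory, so reduce the batch size
- **Mask size**: pick a suitable downsampled mask size (typically 160×160 for a 640×640 input)
- **Loss weights**: balance the weights of the detection and segmentation losses
### 4. Segmentation inference
- **Output handling**: detection boxes and segmentation masks are processed together
- **Mask generation**: instance masks are built from the mask coefficients and the prototypes
- **Result fusion**: boxes and masks are overlaid for display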
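The mask-generation step is a single matrix product: each instance's coefficient vector forms a linear combination of the prototype masks, which is then passed through a sigmoid, restricted to the instance's box, and thresholded. Here is a small NumPy sketch of that computation, with a hypothetical `assemble_masks` helper and toy prototypes in place of real model output.

```python
import numpy as np


def assemble_masks(coeffs, protos, boxes, threshold=0.5):
    """Combine prototypes into binary instance masks.

    coeffs: (n, k) mask coefficients, one row per detection
    protos: (k, h, w) prototype masks
    boxes:  (n, 4) xyxy boxes in prototype-pixel coordinates
    """
    k, h, w = protos.shape
    logits = coeffs @ protos.reshape(k, -1)              # (n, h*w)
    probs = (1.0 / (1.0 + np.exp(-logits))).reshape(-1, h, w)

    # Zero out everything outside each instance's box
    ys = np.arange(h)[None, :, None]
    xs = np.arange(w)[None, None, :]
    x1, y1, x2, y2 = (boxes[:, i, None, None] for i in range(4))
    inside = (xs >= x1) & (xs < x2) & (ys >= y1) & (ys < y2)
    return (probs > threshold) & inside


# Two toy prototypes: one strongly positive, one strongly negative
protos = np.stack([np.full((4, 4), 5.0), np.full((4, 4), -5.0)])
coeffs = np.eye(2)                                       # each instance picks one prototype
boxes = np.array([[0, 0, 2, 2], [0, 0, 4, 4]])

masks = assemble_masks(coeffs, protos, boxes)
print(masks.shape, masks[0].sum(), masks[1].sum())
```

Cropping to the box is what keeps a prototype that fires on several objects from leaking into a neighbouring instance's mask.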
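## YOLOv8-seg-Specific Modifications
### 1. Segmentation head check
```python
from ultralytics.nn.modules import Segment


# Call this at the end of the model adaptation function
def check_segmentation_head(model):
    """Verify that the last layer is a segmentation head."""
    head = model.model[-1]
    if not isinstance(head, Segment):
        print("Warning: Model segmentation head may not be properly configured")
    else:
        # Swapping the first convolution leaves the head's weights untouched
        print(f"Segmentation head parameters remain unchanged ({head.nm} mask coefficients)")
    return model
```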
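### 2. Segmentation loss weighting
```python
from ultralytics.utils.loss import v8SegmentationLoss


# Custom segmentation loss weight
class CustomSegmentationLoss(v8SegmentationLoss):
    def __init__(self, model, seg_weight=1.5):
        super().__init__(model)
        self.seg_weight = seg_weight  # extra weight on the segmentation term

    def __call__(self, preds, batch):
        loss, loss_items = super().__call__(preds, batch)
        # Reweighting needs per-component losses ordered (box, seg, cls, dfl);
        # releases that return an already-summed scalar cannot be reweighted here
        if loss.ndim == 0:
            raise RuntimeError("Parent loss returned a scalar; cannot reweight the seg term")
        weights = loss.new_tensor([1.0, self.seg_weight, 1.0, 1.0])
        return (loss * weights).sum(), loss_items * weights
```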
### 3. Mask post-processing
```python
import cv2
import numpy as np


def refine_segmentation_masks(masks, detections, min_area=100):
    """
    Refine segmentation masks: remove small regions, then fill small holes.
    `detections` is currently unused and kept for API symmetry.
    """
    refined_masks = []
    kernel = np.ones((5, 5), np.uint8)
    for mask in masks:
        mask = np.ascontiguousarray(mask, dtype=np.uint8)
        # Remove small regions
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for cnt in contours:
            if cv2.contourArea(cnt) < min_area:
                cv2.drawContours(mask, [cnt], 0, 0, -1)
        # Fill small holes with a morphological closing
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        refined_masks.append(mask)
    return np.array(refined_masks)
```
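### 4. Combining instance masks
```python
import numpy as np


def combine_instance_masks(masks, shape=None):
    """
    Combine several instance masks into a single label image.
    `shape` is only needed to build an empty label image when there are no masks.
    """
    if len(masks) == 0:
        # masks[0] does not exist here, so the shape must come from the caller
        return None if shape is None else np.zeros(shape, dtype=np.uint16)
    combined = np.zeros_like(masks[0], dtype=np.uint16)
    for i, mask in enumerate(masks):
        combined[mask > 0] = i + 1  # instance IDs start at 1
    return combined
```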
## Complete Workflow
1. **Prepare the data**:
- Organize the TIFF images
- Detection labels in YOLO format
- Segmentation masks in PNG format
2. **Modify the model**:
```python
adapted_model = adapt_yolov8seg_for_single_channel('yolov8n-seg.pt')
torch.save(adapted_model.state_dict(), 'yolov8n-seg_single_channel.pt')
```
3. **Train the model**:
```python
trained_model = train_yolov8seg_tiff()
```
4. **Run inference**:
```python
detections, masks = predict_segmentation_on_tiff(trained_model, 'test.tif')
```
5. **Visualize the results**:
```python
visualize_segmentation_results('test.tif', detections, masks, class_names=['object'])
```
6. **Save the results**:
```python
from tifffile import imwrite  # imsave is the deprecated name of imwrite

# Save the combined mask
combined_mask = combine_instance_masks(masks)
imwrite('test_mask.tif', combined_mask)
```
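For the data preparation step, the detection labels and the PNG masks have to describe the same instances. One way to guarantee that is to derive the YOLO-format boxes from the instance mask itself. The sketch below assumes a mask whose pixel values are instance IDs (0 for background) and a single object class; the helper `mask_to_yolo_boxes` and its `class_id` parameter are hypothetical.

```python
import numpy as np


def mask_to_yolo_boxes(instance_mask, class_id=0):
    """Return rows of (class, x_center, y_center, width, height), normalized to [0, 1]."""
    h, w = instance_mask.shape
    rows = []
    for inst_id in np.unique(instance_mask):
        if inst_id == 0:  # skip background
            continue
        ys, xs = np.nonzero(instance_mask == inst_id)
        x1, x2 = xs.min(), xs.max() + 1   # +1 makes the right/bottom edge exclusive
        y1, y2 = ys.min(), ys.max() + 1
        rows.append([class_id,
                     (x1 + x2) / 2 / w, (y1 + y2) / 2 / h,
                     (x2 - x1) / w, (y2 - y1) / h])
    return np.array(rows, dtype=np.float32).reshape(-1, 5)


# Toy 10x20 (H x W) mask with one instance covering rows 2..5 and columns 4..11
mask = np.zeros((10, 20), dtype=np.uint8)
mask[2:6, 4:12] = 1
print(mask_to_yolo_boxes(mask))
```

Because `np.unique` returns the IDs in ascending order, the rows come out in instance-ID order, which keeps row `i` of the label file aligned with instance ID `i + 1` when the IDs are consecutive.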
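## Performance Tips
1. **Mixed-precision training**:
```python
# `batch` is the dict passed to the loss (img, batch_idx, cls, bboxes, masks)
scaler = torch.cuda.amp.GradScaler()
optimizer.zero_grad()
with torch.cuda.amp.autocast():
    preds = model(batch['img'])
    loss, loss_items = loss_fn(preds, batch)
    loss = loss.sum()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```
2. **Compressed mask storage**:
```python
from tifffile import imwrite

# Save the TIFF mask with LZW compression (needs the imagecodecs package)
imwrite('mask.tif', combined_mask, compression='lzw')
```
3. **Batched inference**:
```python
import torch


def batch_predict_segmentation(model, images, device, batch_size=4,
                               conf_thres=0.25, iou_thres=0.5):
    """Batched inference for better throughput.

    `images` is a list of already preprocessed (1, H, W) tensors of identical
    size; preprocessing is left to the caller.
    """
    results = []
    model.eval()
    for i in range(0, len(images), batch_size):
        batch = images[i:i + batch_size]
        batch_tensor = torch.stack(batch).to(device)
        with torch.no_grad():
            out = model(batch_tensor)
        preds = out[0]
        proto = out[1][-1] if isinstance(out[1], (tuple, list)) else out[1]
        # Process the output of each image separately
        for j in range(len(batch)):
            detections, masks = process_segmentation_output(
                preds[j:j + 1], proto[j:j + 1], conf_thres, iou_thres
            )
            results.append((detections, masks))
    return results
```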