# 输入: 影像文件夹和矢量文件夹 (input: a folder of raster images and a folder of vector files)
# 输出: 标签 (output: label images)
import os
import shutil
import cv2
import numpy as np
from osgeo import gdal, ogr, osr, gdalconst
import fnmatch
# Label value assigned to each feature category.
LABELS = dict(
    house=1,  # buildings
    road=2,  # roads
    water=3,  # water bodies
    vegetation=4,  # vegetation
)
# Display colour per label value, in OpenCV BGR channel order.
# Value 5 is the background value used when merging masks.
COLORS = {
    LABELS["house"]: (0, 0, 255),  # buildings - red
    LABELS["road"]: (0, 255, 255),  # roads - yellow
    LABELS["water"]: (255, 0, 0),  # water - blue
    LABELS["vegetation"]: (0, 255, 0),  # vegetation - green
    5: (255, 255, 255),  # background - white
}
def ShapeClip(baseFilePath, maskFilePath, saveFolderPath):
    """
    Clip a vector layer by a mask layer and write the result as a shapefile.

    The output layer copies the spatial reference and geometry type of the
    base layer; its name is ``<mask layer>_Clip_<base layer>``.
    NOTE(review): no reprojection is performed, so both inputs presumably
    need to share the same CRS.

    :param baseFilePath: vector file to be clipped (its first layer is used)
    :param maskFilePath: vector file whose first layer is the clip mask
    :param saveFolderPath: output path passed to the ESRI Shapefile driver.
        Despite the name, the caller in this file passes a ``.shp`` file path.
    :return: ``saveFolderPath``
    :raises ValueError: if the base or mask file cannot be opened
    :raises RuntimeError: if the output cannot be created or clipping fails
    """
    ogr.RegisterAll()
    gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
    baseData = ogr.Open(baseFilePath)
    if not baseData:
        raise ValueError(f"无法打开文件: {baseFilePath}")
    maskData = None
    outData = None
    try:
        baseLayer = baseData.GetLayer(0)
        spatial = baseLayer.GetSpatialRef()
        geomType = baseLayer.GetGeomType()
        baseLayerName = baseLayer.GetName()

        maskData = ogr.Open(maskFilePath)
        if not maskData:
            raise ValueError(f"无法打开掩膜文件: {maskFilePath}")
        maskLayer = maskData.GetLayer()
        outLayerName = maskLayer.GetName() + "_Clip_" + baseLayerName

        # Global GDAL option, set before the output is created so that it
        # applies to the new shapefile's attribute encoding.
        gdal.SetConfigOption("SHAPE_ENCODING", "GBK")
        driver = ogr.GetDriverByName("ESRI Shapefile")
        outData = driver.CreateDataSource(saveFolderPath)
        # `is None`, not `not outData`: a freshly created datasource has zero
        # layers and would therefore be falsy.
        if outData is None:
            raise RuntimeError(f"Cannot create output data source: {saveFolderPath}")
        outLayer = outData.CreateLayer(outLayerName, spatial, geomType)
        if outLayer is None:
            raise RuntimeError(f"Cannot create layer '{outLayerName}' in {saveFolderPath}")

        # Layer.Clip returns a non-zero OGRErr on failure (when exceptions are off).
        if baseLayer.Clip(maskLayer, outLayer):
            raise RuntimeError(f"Clipping {baseFilePath} by {maskFilePath} failed")
    finally:
        # Release the output first so it is flushed to disk, then the inputs.
        for ds in (outData, maskData, baseData):
            if ds is not None:
                ds.Release()
    return saveFolderPath
def shp2Raster(shp, templatePic, output, nodata, burn_value=255):
    """
    Rasterize a vector layer onto the pixel grid of a template raster.

    The output has the template's size, geotransform and projection, and a
    single Byte band. It is always written with the GTiff driver, whatever
    extension ``output`` has. Pixels touched by any feature
    (``ALL_TOUCHED=TRUE``) get ``burn_value``; all other pixels hold ``nodata``.

    :param shp: vector file to rasterize (its first layer is used)
    :param templatePic: template raster providing size, geotransform and projection
    :param output: output raster path
    :param nodata: nodata value, also used to fill pixels not covered by features
    :param burn_value: value burned into covered pixels. The default of 255 is
        what the Python binding burns when no value is given, and what the
        mask overlay treats as foreground.
    :raises ValueError: if the template raster or the vector file cannot be opened
    :raises RuntimeError: if the output cannot be created or rasterization fails
    """
    template = gdal.Open(templatePic, gdalconst.GA_ReadOnly)
    if template is None:
        raise ValueError(f"Cannot open template raster: {templatePic}")
    vector_ds = ogr.Open(shp)
    if vector_ds is None:
        raise ValueError(f"Cannot open vector file: {shp}")
    layer = vector_ds.GetLayer()

    target_ds = gdal.GetDriverByName('GTiff').Create(
        output, template.RasterXSize, template.RasterYSize, 1, gdal.GDT_Byte)
    if target_ds is None:
        raise RuntimeError(f"Cannot create output raster: {output}")
    band = None
    try:
        target_ds.SetGeoTransform(template.GetGeoTransform())
        target_ds.SetProjection(template.GetProjection())
        band = target_ds.GetRasterBand(1)
        band.SetNoDataValue(nodata)
        # Explicitly fill so uncovered pixels hold `nodata` regardless of the
        # driver's initial contents.
        band.Fill(nodata)
        err = gdal.RasterizeLayer(target_ds, [1], layer,
                                  burn_values=[burn_value],
                                  options=['ALL_TOUCHED=TRUE'])
        if err:
            raise RuntimeError(f"Rasterizing {shp} into {output} failed")
        band.FlushCache()
    finally:
        # Dropping the references closes (and flushes) the output dataset.
        band = None
        target_ds = None
def overlay_images_with_priority(folder_path, output_path, raster_prefix):
    """
    Merge the per-category masks of one raster into a single colour label image.

    For each category in ``LABELS`` the mask ``{raster_prefix}_{category}.png``
    is read from ``folder_path`` as grayscale, and pixels equal to 255 count as
    foreground. The label map starts as background (5), and each foreground
    pixel is set to its category's label value. The map is then rendered with
    ``COLORS`` (BGR channel order, as written by OpenCV) and saved.

    :param folder_path: directory holding the per-category mask rasters
    :param output_path: path of the colour label image to write
    :param raster_prefix: raster stem used to select this raster's masks
    :raises ValueError: if no mask could be read, or the masks differ in size
    :raises OSError: if the output image cannot be written
    """
    final_label = None
    # Categories are applied in descending label value (vegetation=4 ... house=1).
    # Later writes overwrite earlier ones, so where masks overlap the category
    # with the smallest label value wins (house > road > water > vegetation).
    for category, label in sorted(LABELS.items(), key=lambda x: -x[1]):
        # Exact path check: unlike an fnmatch pattern, characters such as
        # '[', '*' or '?' in the raster name are treated literally.
        image_path = os.path.join(folder_path, f"{raster_prefix}_{category}.png")
        if not os.path.isfile(image_path):
            print(f"Warning: No files found for category '{category}' with prefix '{raster_prefix}'")
            continue
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            print(f"Error: Unable to read {image_path}")
            continue
        if final_label is None:
            final_label = np.full(image.shape, 5, dtype=np.uint8)  # 5 = background
        elif image.shape != final_label.shape:
            raise ValueError(
                f"Mask {image_path} has shape {image.shape}, expected {final_label.shape}")
        final_label[image == 255] = label  # foreground is white (255)
    if final_label is None:
        raise ValueError(f"No valid images were found for prefix '{raster_prefix}' to generate the overlay.")

    final_image = np.zeros((*final_label.shape, 3), dtype=np.uint8)
    for label, color in COLORS.items():
        final_image[final_label == label] = color
    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(output_path, final_image):
        raise OSError(f"Failed to write overlay image to {output_path}")
    print(f"Overlay image successfully saved to {output_path}")
class MyLabel:
    """
    Build colour label images from a folder of rasters and a folder of vectors.

    Vector files are expected to be named after the keys of ``LABELS`` (e.g.
    ``house.shp``). Intermediate masks are saved as ``{raster}_{vector}.png``,
    and the overlay step looks them up as ``{raster}_{category}.png``.
    """

    def __init__(self, raster_folder, vector_folder, label_output_path):
        """
        :param raster_folder: directory containing the input ``*.tif`` rasters
        :param vector_folder: directory containing the input ``*.shp`` vectors
        :param label_output_path: directory receiving the final label images;
            temporary sub-directories are created (and removed) inside it
        """
        self.raster_folder = raster_folder
        self.vector_folder = vector_folder
        self.label_output_path = label_output_path

    def start_make_label(self):
        """
        Generate one colour label image per ``*.tif`` raster.

        For each raster:

        1. Its footprint is written as a polygon shapefile.
        2. Every ``*.shp`` vector is clipped to that footprint and rasterized
           onto the raster's grid.
        3. The resulting masks are merged into
           ``{label_output_path}/{raster_stem}.png``.

        Temporary directories are removed whether processing succeeds or fails.
        NOTE(review): no reprojection is done, so vectors presumably must share
        the raster's CRS.
        """
        print("开始制作标签")
        # Temporary working directories (recreated empty, removed at the end).
        shape_path = os.path.join(self.label_output_path, 'mask_boundary_shp')
        mask_clip_path = os.path.join(self.label_output_path, 'mask_clip_train')
        ogr.RegisterAll()
        self._prepare_directories([shape_path, mask_clip_path])
        try:
            raster_list = fnmatch.filter(os.listdir(self.raster_folder), '*.tif')
            vector_list = fnmatch.filter(os.listdir(self.vector_folder), '*.shp')
            if not raster_list or not vector_list:
                print("未找到符合条件的栅格或矢量文件。")
                return
            for raster in raster_list:
                p_raster = os.path.join(self.raster_folder, raster)
                raster_name = os.path.splitext(raster)[0]
                # The footprint depends only on the raster, so build it once per raster.
                footprint_shp = os.path.join(shape_path, f"{raster_name}.shp")
                self._create_polygon_shapefile(p_raster, footprint_shp)
                for vector in vector_list:
                    p_vector = os.path.join(self.vector_folder, vector)
                    vector_name = os.path.splitext(vector)[0]
                    # Clip the vector to the raster footprint.
                    mask_train_name = os.path.join(mask_clip_path, f"{raster_name}_{vector_name}.shp")
                    ShapeClip(p_vector, footprint_shp, mask_train_name)
                    # Rasterize the clipped vector onto the raster grid.
                    output = os.path.join(mask_clip_path, f"{raster_name}_{vector_name}.png")
                    shp2Raster(mask_train_name, p_raster, output, 0)
                # Merge all category masks of this raster into the final label.
                final_output = os.path.join(self.label_output_path, f"{raster_name}.png")
                overlay_images_with_priority(mask_clip_path, final_output, raster_name)
                print(f"最终标签已生成并保存为: {final_output}")
        finally:
            self._cleanup([shape_path, mask_clip_path])

    @staticmethod
    def _prepare_directories(dirs):
        """
        Recreate each directory empty, deleting any previous contents.

        :raises OSError: if a directory cannot be removed or created
        """
        for dir_path in dirs:
            try:
                if os.path.exists(dir_path):
                    shutil.rmtree(dir_path)
                # makedirs also creates a missing label_output_path parent.
                os.makedirs(dir_path)
            except OSError as e:
                print(f"无法创建目录 {dir_path}: {e}")
                raise

    @staticmethod
    def _create_polygon_shapefile(image_path, shapefile_path):
        """
        Write a shapefile holding one polygon: the footprint of a raster.

        The polygon's corners are the raster's pixel-space corners mapped
        through its geotransform, so rotated geotransforms are handled too.
        The raster's projection is used as the layer's spatial reference.

        :raises ValueError: if the raster cannot be opened
        :raises RuntimeError: if the shapefile or its feature cannot be written
        """
        dataset = gdal.Open(image_path)
        if not dataset:
            raise ValueError(f"无法打开影像文件 {image_path}")
        driver = ogr.GetDriverByName('ESRI Shapefile')
        data_source = driver.CreateDataSource(shapefile_path)
        if data_source is None:
            raise RuntimeError(f"Cannot create shapefile: {shapefile_path}")
        srs = osr.SpatialReference(wkt=dataset.GetProjection())
        layer = data_source.CreateLayer("polygon", srs, ogr.wkbPolygon)

        gt = dataset.GetGeoTransform()
        n_cols = dataset.RasterXSize
        n_rows = dataset.RasterYSize

        def to_geo(px, py):
            # GDAL affine transform: pixel/line -> georeferenced x/y.
            return (gt[0] + px * gt[1] + py * gt[2],
                    gt[3] + px * gt[4] + py * gt[5])

        ring = ogr.Geometry(ogr.wkbLinearRing)
        for px, py in ((0, 0), (n_cols, 0), (n_cols, n_rows), (0, n_rows)):
            # 2D points keep the ring consistent with the 2D wkbPolygon layer.
            ring.AddPoint_2D(*to_geo(px, py))
        ring.CloseRings()
        poly = ogr.Geometry(ogr.wkbPolygon)
        poly.AddGeometry(ring)

        feature = ogr.Feature(layer.GetLayerDefn())
        feature.SetGeometry(poly)
        if layer.CreateFeature(feature):
            raise RuntimeError(f"Cannot write footprint polygon to {shapefile_path}")
        feature = None
        # Dropping the references flushes and closes the shapefile.
        layer = None
        data_source = None
        dataset = None

    @staticmethod
    def _cleanup(dirs):
        """Best-effort removal of the given directories; failures are only reported."""
        for dir_path in dirs:
            try:
                if os.path.exists(dir_path):
                    shutil.rmtree(dir_path)
            except OSError as e:
                print(f"无法删除目录 {dir_path}: {e}")
# Example invocation. Guarded so that importing this module does not start
# processing the hard-coded folders.
if __name__ == "__main__":
    raster_folder = "I:/toWuda/DOM/a/sat_test_big"
    vector_folder = "I:/toWuda/paper_shiyan/shp"
    label_output_path = "I:/toWuda/paper_shiyan/TrainLabels"
    label_maker = MyLabel(raster_folder, vector_folder, label_output_path)
    label_maker.start_make_label()