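"""
制作多分类标签的代码
Build multi-class label images from raster imagery and vector data.

输入: 影像文件夹 (*.tif) 和矢量文件夹 (*.shp)
Input: a folder of raster images (*.tif) and a folder of vector files (*.shp).

输出: 标签 (每幅影像对应一张彩色标签图, *.png)
Output: one colour label image (*.png) per raster image.
"""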

import os
import shutil
import cv2
import numpy as np
from osgeo import gdal, ogr, osr, gdalconst
import fnmatch

# Category name -> integer label value written into the label map.
LABELS = dict(
    house=1,        # buildings
    road=2,         # roads
    water=3,        # water bodies
    vegetation=4,   # vegetation
)

# Label value -> colour used when rendering the label map. The tuples are in
# OpenCV's BGR channel order because the rendered image is saved with
# cv2.imwrite. Value 5 is the background value given to unlabelled pixels.
COLORS = {
    LABELS["house"]: (0, 0, 255),         # red
    LABELS["road"]: (0, 255, 255),        # yellow
    LABELS["water"]: (255, 0, 0),         # blue
    LABELS["vegetation"]: (0, 255, 0),    # green
    5: (255, 255, 255),                   # background - white
}

def ShapeClip(baseFilePath, maskFilePath, saveFolderPath):
    """
    Clip a vector layer with a mask layer and write the result as a shapefile.

    The first layer of ``baseFilePath`` is clipped by the first layer of
    ``maskFilePath``. The output layer reuses the base layer's spatial
    reference and geometry type.

    :param baseFilePath: vector file to clip
    :param maskFilePath: vector file whose geometries define the clip area
    :param saveFolderPath: output path handed to the "ESRI Shapefile" driver
        (the caller in this file passes a ``.shp`` file path)
    :return: ``saveFolderPath``
    :raises ValueError: if an input cannot be opened, the output cannot be
        created, or the clip operation reports an error
    """
    ogr.RegisterAll()
    gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
    baseData = maskData = outData = None
    try:
        baseData = ogr.Open(baseFilePath)
        if not baseData:
            raise ValueError(f"无法打开文件: {baseFilePath}")
        baseLayer = baseData.GetLayer(0)
        spatial = baseLayer.GetSpatialRef()
        geomType = baseLayer.GetGeomType()
        baseLayerName = baseLayer.GetName()

        maskData = ogr.Open(maskFilePath)
        if not maskData:
            raise ValueError(f"无法打开掩膜文件: {maskFilePath}")
        maskLayer = maskData.GetLayer()
        maskLayerName = maskLayer.GetName()

        outLayerName = maskLayerName + "_Clip_" + baseLayerName
        # Global GDAL option: attribute-table encoding for shapefiles opened
        # or created from this point on.
        gdal.SetConfigOption("SHAPE_ENCODING", "GBK")
        driver = ogr.GetDriverByName("ESRI Shapefile")
        outData = driver.CreateDataSource(saveFolderPath)
        if outData is None:
            raise ValueError(f"无法创建输出文件: {saveFolderPath}")
        outLayer = outData.CreateLayer(outLayerName, spatial, geomType)
        if outLayer is None:
            raise ValueError(f"无法创建输出图层: {outLayerName}")

        # Clip() returns a non-zero OGRErr on failure when GDAL exceptions are
        # disabled; with exceptions enabled it raises on its own.
        if baseLayer.Clip(maskLayer, outLayer):
            raise ValueError(f"矢量裁剪失败: {baseFilePath}")
    finally:
        # Release every dataset that was opened, even on error, so the output
        # shapefile is closed before the caller uses it.
        for ds in (outData, baseData, maskData):
            if ds is not None:
                ds.Release()
    return saveFolderPath

def shp2Raster(shp, templatePic, output, nodata):
    """
    Rasterize a vector file onto the pixel grid of a template raster.

    Creates a single-band Byte raster with the template's size, geotransform
    and projection. It registers ``nodata`` as the band's NoData value and
    burns 255 into every cell touched by a feature (``ALL_TOUCHED=TRUE``).

    :param shp: vector file to rasterize (its first layer is used)
    :param templatePic: raster whose size, geotransform and projection are copied
    :param output: output path; always written with the GTiff driver,
        whatever the file extension
    :param nodata: value registered as the band's NoData value
    :raises ValueError: if an input cannot be opened or the output cannot be created
    """
    template = gdal.Open(templatePic, gdalconst.GA_ReadOnly)
    if template is None:
        raise ValueError(f"无法打开模板栅格: {templatePic}")
    geo_transform = template.GetGeoTransform()
    proj = template.GetProjection()
    x_size = template.RasterXSize
    y_size = template.RasterYSize
    template = None  # only the grid description is needed; close the template

    vector_ds = ogr.Open(shp)
    if vector_ds is None:
        raise ValueError(f"无法打开矢量文件: {shp}")
    layer = vector_ds.GetLayer()

    target_ds = gdal.GetDriverByName('GTiff').Create(output, x_size, y_size, 1, gdal.GDT_Byte)
    if target_ds is None:
        raise ValueError(f"无法创建输出栅格: {output}")
    target_ds.SetGeoTransform(geo_transform)
    target_ds.SetProjection(proj)
    band = target_ds.GetRasterBand(1)
    band.SetNoDataValue(nodata)
    # Burn value made explicit: overlay_images_with_priority treats 255 as
    # foreground, which is also what the bindings burn when burn_values is omitted.
    gdal.RasterizeLayer(target_ds, [1], layer, burn_values=[255], options=['ALL_TOUCHED=TRUE'])
    band.FlushCache()
    band = None
    target_ds = None  # closing the dataset finalises the file on disk
    vector_ds = None

def overlay_images_with_priority(folder_path, output_path, raster_prefix):
    """
    Merge the per-category masks of one raster into a single colour label image.

    For every category in LABELS, the mask ``<folder_path>/<raster_prefix>_<category>.png``
    is read as grayscale. Its pixels equal to 255 are set to that category's
    label value. Pixels claimed by no mask keep the background value 5.

    Categories are applied in descending label order, and each one overwrites
    earlier assignments. Where masks overlap, the category with the SMALLEST
    label value therefore wins (house > road > water > vegetation).

    The label map is finally coloured with COLORS and written to ``output_path``.

    :param folder_path: directory holding the per-category masks
    :param output_path: path of the colour label image to write
    :param raster_prefix: raster base name used as the mask-file prefix
    :raises ValueError: if no mask could be read for ``raster_prefix``
    :raises OSError: if the output image cannot be written
    """
    final_label = None

    # Descending label value; later categories overwrite earlier ones, so the
    # smallest label value ends up with the highest effective priority.
    for category, label in sorted(LABELS.items(), key=lambda x: -x[1]):
        image_path = os.path.join(folder_path, f"{raster_prefix}_{category}.png")
        # Exact-name lookup: fnmatch would treat '[', ']', '*' and '?' in a
        # raster name as wildcards and fail to find the file.
        if not os.path.isfile(image_path):
            print(f"Warning: No files found for category '{category}' with prefix '{raster_prefix}'")
            continue

        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            print(f"Error: Unable to read {image_path}")
            continue

        if final_label is None:
            # Every pixel starts as background (5) until a mask claims it.
            final_label = np.full(image.shape, 5, dtype=np.uint8)

        # Foreground (white, 255) pixels of this mask take this category's label.
        final_label[image == 255] = label

    if final_label is None:
        raise ValueError(f"No valid images were found for prefix '{raster_prefix}' to generate the overlay.")

    final_image = np.zeros(final_label.shape + (3,), dtype=np.uint8)
    for label, color in COLORS.items():
        final_image[final_label == label] = color

    # cv2.imwrite signals failure by returning False rather than raising.
    if not cv2.imwrite(output_path, final_image):
        raise OSError(f"Failed to write overlay image to {output_path}")
    print(f"Overlay image successfully saved to {output_path}")




class MyLabel:
    """
    Batch-generate colour label images for every raster in a folder.

    For each ``*.tif`` in ``raster_folder`` and each ``*.shp`` in
    ``vector_folder``, the vector is clipped to the raster's footprint and
    rasterized onto the raster's grid. The resulting per-vector masks are then
    merged into ``<label_output_path>/<raster name>.png``.

    The merge step looks masks up as ``<raster name>_<category>.png``, so
    vector files are expected to be named after the keys of LABELS
    (e.g. ``house.shp``).
    """

    def __init__(self, raster_folder, vector_folder, label_output_path):
        self.raster_folder = raster_folder          # directory of input *.tif rasters
        self.vector_folder = vector_folder          # directory of input *.shp vectors
        self.label_output_path = label_output_path  # final labels + temporary work dirs

    def start_make_label(self):
        """
        Run the full pipeline for every raster/vector pair.

        The temporary directories ``mask_boundary_shp`` and ``mask_clip_train``
        are (re)created under ``label_output_path``. They are removed when the
        run finishes, including when it stops early or raises.
        """
        print("开始制作标签")

        # Working directories for intermediate files.
        shape_path = os.path.join(self.label_output_path, 'mask_boundary_shp')
        mask_clip_path = os.path.join(self.label_output_path, 'mask_clip_train')

        ogr.RegisterAll()
        self._prepare_directories([shape_path, mask_clip_path])

        try:
            raster_list = fnmatch.filter(os.listdir(self.raster_folder), '*.tif')
            vector_list = fnmatch.filter(os.listdir(self.vector_folder), '*.shp')

            if not raster_list or not vector_list:
                print("未找到符合条件的栅格或矢量文件。")
                return

            for raster in raster_list:
                self._process_raster(raster, vector_list, shape_path, mask_clip_path)
        finally:
            # Remove intermediate files even if a raster fails part-way.
            self._cleanup([shape_path, mask_clip_path])

    def _process_raster(self, raster, vector_list, shape_path, mask_clip_path):
        """Build and save the merged label image for a single raster file."""
        p_raster = os.path.join(self.raster_folder, raster)
        raster_name = os.path.splitext(raster)[0]

        # The footprint depends only on the raster, so build it once and reuse
        # it as the clip mask for every vector.
        footprint_shp = os.path.join(shape_path, f"{raster_name}.shp")
        if not self._create_polygon_shapefile(p_raster, footprint_shp):
            print(f"跳过影像 {p_raster}")
            return

        for vector in vector_list:
            p_vector = os.path.join(self.vector_folder, vector)
            vector_name = os.path.splitext(vector)[0]
            stem = f"{raster_name}_{vector_name}"

            # Clip the vector to the raster footprint.
            clipped_shp = os.path.join(mask_clip_path, f"{stem}.shp")
            ShapeClip(p_vector, footprint_shp, clipped_shp)

            # Rasterize onto the raster's grid; overlay_images_with_priority
            # finds the mask by this exact file name.
            shp2Raster(clipped_shp, p_raster, os.path.join(mask_clip_path, f"{stem}.png"), 0)

        # Merge the masks of all vectors for the current raster.
        final_output = os.path.join(self.label_output_path, f"{raster_name}.png")
        overlay_images_with_priority(mask_clip_path, final_output, raster_name)
        print(f"最终标签已生成并保存为: {final_output}")

    @staticmethod
    def _prepare_directories(dirs):
        """Recreate each directory empty, creating missing parent directories."""
        for dir_path in dirs:
            try:
                if os.path.exists(dir_path):
                    shutil.rmtree(dir_path)
                # makedirs (not mkdir) so this also works when
                # label_output_path itself does not exist yet.
                os.makedirs(dir_path)
            except OSError as e:
                print(f"无法创建目录 {dir_path}: {e}")

    @staticmethod
    def _create_polygon_shapefile(image_path, shapefile_path):
        """
        Write a one-polygon shapefile covering the footprint of a raster.

        The four corners of the pixel grid are mapped through the raster's
        geotransform. This gives the exact footprint for north-up rasters and
        also for rotated/sheared geotransforms.

        :return: True on success; False if the raster cannot be opened or the
            shapefile cannot be created
        """
        dataset = gdal.Open(image_path)
        if not dataset:
            print(f"无法打开影像文件 {image_path}")
            return False

        gt = dataset.GetGeoTransform()
        width = dataset.RasterXSize    # number of pixel columns
        height = dataset.RasterYSize   # number of pixel rows
        srs = osr.SpatialReference(wkt=dataset.GetProjection())
        dataset = None

        def to_geo(px, ln):
            # Affine geotransform: pixel/line -> georeferenced X/Y.
            return (gt[0] + px * gt[1] + ln * gt[2],
                    gt[3] + px * gt[4] + ln * gt[5])

        driver = ogr.GetDriverByName('ESRI Shapefile')
        data_source = driver.CreateDataSource(shapefile_path)
        if data_source is None:
            print(f"无法创建矢量文件 {shapefile_path}")
            return False

        try:
            layer = data_source.CreateLayer("polygon", srs, ogr.wkbPolygon)
            if layer is None:
                print(f"无法创建图层 {shapefile_path}")
                return False

            ring = ogr.Geometry(ogr.wkbLinearRing)
            for px, ln in ((0, 0), (width, 0), (width, height), (0, height)):
                ring.AddPoint(*to_geo(px, ln))
            ring.CloseRings()

            poly = ogr.Geometry(ogr.wkbPolygon)
            poly.AddGeometry(ring)

            feature = ogr.Feature(layer.GetLayerDefn())
            feature.SetGeometry(poly)
            layer.CreateFeature(feature)
            feature = None
        finally:
            # Destroy() closes the datasource, writing the shapefile to disk.
            data_source.Destroy()
        return True

    @staticmethod
    def _cleanup(dirs):
        """Delete each directory tree if it exists; failures are reported, not raised."""
        for dir_path in dirs:
            try:
                if os.path.exists(dir_path):
                    shutil.rmtree(dir_path)
            except OSError as e:
                print(f"无法删除目录 {dir_path}: {e}")

# Example invocation: edit the paths below, then run this file as a script.
raster_folder = "I:/toWuda/DOM/a/sat_test_big"
vector_folder = "I:/toWuda/paper_shiyan/shp"
label_output_path = "I:/toWuda/paper_shiyan/TrainLabels"

if __name__ == "__main__":
    # Guarded so that importing this module (e.g. to reuse MyLabel) does not
    # start processing the hard-coded folders.
    label_maker = MyLabel(raster_folder, vector_folder, label_output_path)
    label_maker.start_make_label()

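# NOTE: the lines below are residue from the web page this script was copied
# from. They are not part of the program and are kept only as comments.
# 评论
# 添加红包
#
# 请填写红包祝福语或标题
#
# 红包个数最小为10个
#
# 红包金额最低5元
#
# 当前余额3.43前往充值 >
# 需支付:10.00
# 成就一亿技术人!
# 领取后你会自动成为博主和红包主的粉丝 规则
# hope_wisdom
# 发出的红包
# 实付
# 使用余额支付
# 点击重新获取
# 扫码支付
# 钱包余额 0
#
# 抵扣说明:
#
# 1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
# 2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。
#
# 余额充值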