1、实现功能
将shp数据转化为带坐标系的二值图;
shp数据可以是面要素也可以是线要素;
shp数据必须包含一个用于栅格化的数值属性字段(如FID_1):每个要素在该字段上的取值即为栅格化后对应像元的值,随后所有大于0的像元被置为255,从而得到二值图。
2、实现代码
# -*- coding:utf-8 -*-
import argparse
import os
from osgeo import gdalconst, gdal, ogr
import numpy as np
import sys
import traceback
import shutil
# 用于打包所需要的依赖库,加入环境变量
# proj_str = os.path.dirname(sys.argv[0]) + '/proj'
# os.environ['PROJ_LIB'] = proj_str
def shp2Raster(shp, templatePic, output, field, nodata):
    """Rasterize the first layer of a vector data source onto a template raster's grid.

    The output GeoTIFF copies the size, geotransform and projection of
    ``templatePic``. Every pixel touched by a feature receives that feature's
    ``field`` attribute value; all other pixels receive ``nodata``.

    Args:
        shp: path to the vector data source (e.g. a .shp file, polygon or line
            features). Only its first layer is rasterized.
        templatePic: path to the template raster (e.g. a .tif) whose grid
            (size, geotransform, projection) the output reuses.
        output: path of the single-band 16-bit signed integer GeoTIFF to create.
        field: name of the numeric attribute whose value is burned into pixels.
        nodata: int/float written to pixels not covered by any feature; also
            registered as the band's NoData value.

    Raises:
        IOError: if an input cannot be opened or the output cannot be created.
        RuntimeError: if gdal.RasterizeLayer reports a failure.
    """
    template = gdal.Open(templatePic, gdalconst.GA_ReadOnly)
    if template is None:
        raise IOError("Cannot open template raster: %s" % templatePic)
    geo_transform = template.GetGeoTransform()
    proj = template.GetProjection()
    width = template.RasterXSize
    height = template.RasterYSize
    template = None  # only the grid metadata is needed from here on

    vector_ds = ogr.Open(shp)
    if vector_ds is None:
        raise IOError("Cannot open vector data: %s" % shp)
    layer = vector_ds.GetLayer()

    # Output raster: one band, 16-bit signed integers.
    target_ds = gdal.GetDriverByName('GTiff').Create(output, width, height, 1, gdal.GDT_Int16)
    if target_ds is None:
        raise IOError("Cannot create output raster: %s" % output)
    band = None
    try:
        target_ds.SetGeoTransform(geo_transform)
        target_ds.SetProjection(proj)
        band = target_ds.GetRasterBand(1)
        band.SetNoDataValue(nodata)
        # Fill explicitly so the background is `nodata` regardless of how the
        # driver initialises newly created blocks.
        band.Fill(nodata)
        # ALL_TOUCHED burns every pixel a feature touches, so thin line
        # features are not lost between pixel centres.
        err = gdal.RasterizeLayer(target_ds, [1], layer,
                                  options=["ATTRIBUTE=%s" % field, 'ALL_TOUCHED=TRUE'])
        if err != 0:
            raise RuntimeError("gdal.RasterizeLayer failed for %s (error code %s)" % (shp, err))
        band.FlushCache()
    finally:
        # GDAL Python objects are only closed/flushed once dereferenced.
        band = None
        target_ds = None
        layer = None
        vector_ds = None
def _find_sample_pairs(open_path):
    """Return a sorted list of (tif_path, shp_path) pairs, one per sample sub-folder.

    Each direct sub-folder of ``open_path`` must hold exactly one ``.tif`` and
    one ``.shp`` (case-insensitive). Folders that do not match are reported and
    skipped instead of silently pairing the wrong files.
    """
    pairs = []
    for folder_name in sorted(os.listdir(open_path)):
        folder = os.path.join(open_path, folder_name)
        if not os.path.isdir(folder):
            continue
        names = os.listdir(folder)
        tifs = sorted(n for n in names if n.lower().endswith(".tif"))
        shps = sorted(n for n in names if n.lower().endswith(".shp"))
        if len(tifs) != 1 or len(shps) != 1:
            print("Skipping %s: expected one .tif and one .shp, found %d and %d"
                  % (folder, len(tifs), len(shps)))
            continue
        pairs.append((os.path.join(folder, tifs[0]), os.path.join(folder, shps[0])))
    return pairs


def _to_binary_mask(array):
    """Return a uint8 array with 255 where ``array`` > 0 and 0 elsewhere.

    Works directly on the (Int16) raster values, avoiding the int8 narrowing
    that wrapped values above 127 and could not hold 255.
    """
    return np.where(np.asarray(array) > 0, 255, 0).astype(np.uint8)


def _label_name(tif_path):
    """Return the label file name for a sample image: '<stem>_LINE<ext>'."""
    stem, ext = os.path.splitext(os.path.basename(tif_path))
    return stem + "_LINE" + ext


def _write_binary_label(mask_path, label_path):
    """Convert the rasterized mask at ``mask_path`` into a 0/255 Byte GeoTIFF at ``label_path``."""
    dataset = gdal.Open(mask_path)
    if dataset is None:
        raise IOError("Cannot open rasterized mask: %s" % mask_path)
    try:
        width = dataset.RasterXSize
        height = dataset.RasterYSize
        band_array = dataset.GetRasterBand(1).ReadAsArray(0, 0, width, height)
        out = gdal.GetDriverByName("GTiff").Create(label_path, width, height, 1, gdal.GDT_Byte)
        if out is None:
            raise IOError("Cannot create label raster: %s" % label_path)
        try:
            # Keep the source georeferencing (projection + affine transform).
            out.SetProjection(dataset.GetProjection())
            out.SetGeoTransform(dataset.GetGeoTransform())
            out.GetRasterBand(1).WriteArray(_to_binary_mask(band_array))
            out.FlushCache()
        finally:
            out = None  # dereference so GDAL closes and flushes the file
    finally:
        dataset = None


def main():
    """Build binary line labels for every sample folder and write train.txt.

    Returns 0 on success and 1 on failure (the traceback is printed).
    """
    import tempfile

    parser = argparse.ArgumentParser()
    parser.add_argument('-o', '--open_path', type=str, default='./samples', help=u'样本目录')
    parser.add_argument('-l', '--labels', type=str, default='./labels', help=u'样本标签目录')
    args = parser.parse_args()
    open_path = os.path.realpath(args.open_path)  # normalise, e.g. drop a trailing slash
    label_path = args.labels
    text_path = os.path.join(label_path, 'train.txt')
    sample_line = os.path.join(label_path, "sample_line")
    if not os.path.exists(sample_line):
        os.makedirs(sample_line)
    # Private scratch directory: a hard-coded ./temp would later delete any
    # pre-existing user folder of that name.
    temp = tempfile.mkdtemp(prefix="shp2mask_")
    try:
        entries = os.listdir(open_path)
        if len(entries) == 1 and entries[0] == os.path.basename(open_path):
            raise IOError("Please put multiple sample folders directly in the zip file, don't nest folders.")
        lines = []
        for tif_path, shp_path in _find_sample_pairs(open_path):
            stem = os.path.splitext(os.path.basename(tif_path))[0]
            tmp = os.path.join(temp, stem + "_mask.tif")
            print(tmp)
            shp2Raster(shp_path, tif_path, tmp, 'FID_1', 0)
            outfile = os.path.join(sample_line, _label_name(tif_path))
            _write_binary_label(tmp, outfile)
            # Each train.txt line: "<image path> <label path>".
            lines.append(os.path.realpath(tif_path) + " " + os.path.realpath(outfile) + "\n")
        with open(text_path, "w") as f:
            f.writelines(lines)
        print("Making labels is done!")
        return 0
    except Exception:
        traceback.print_exc()
        return 1
    finally:
        # Best-effort cleanup of the scratch directory.
        shutil.rmtree(temp, ignore_errors=True)


if __name__ == "__main__":
    sys.exit(main())
3、实现效果