import os
import numpy as np
import pandas as pd
import rasterio
from rasterio.merge import merge
from rasterio.warp import reproject, Resampling
from rasterio.features import rasterize
from rasterio.transform import xy as transform_xy
from contextlib import ExitStack
import glob
import gc
from tqdm import tqdm
import warnings
import traceback
from datetime import datetime
from scipy.ndimage import median_filter
import json
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
# ===== Additional lightweight dependencies (no GeoPandas required) =====
import fiona
from shapely.geometry import shape, mapping
import pyproj
from shapely.ops import transform as shp_transform
warnings.filterwarnings('ignore')
class ImprovedMultiClassGenerator:
    """
    Improved multi-class training data generator (tailored to impervious surface extraction)

    Core improvements:
    1. Handles built-up area identification for years without nighttime-light data (1990-1991)
    2. Fixes the class ratios of the training and validation sets
    3. Outputs six-class and binary datasets simultaneously
    4. Optimizes handling of the "others" class
    """
    def __init__(self, use_auxiliary=True):
        """Initialize the generator."""
        # Path configuration (Windows paths kept as on the author's machine)
        self.landsat_path = r"D:\山西省影像"            # Shanxi Province Landsat imagery
        self.auxiliary_path = r"D:\山西省辅助数据"      # Shanxi Province auxiliary data
        self.output_path = r"D:\训练数据_不透水面优化"  # optimized impervious-surface training data
        self.use_auxiliary = use_auxiliary

        # AOI shapefile configuration
        self.aoi_path = r'D:\山西省shape\山西省\山西省_dissolve.shp'
        self.require_aoi = True

        # Create output directories
        os.makedirs(self.output_path, exist_ok=True)
        for subdir in ['train_6class', 'val_6class', 'test_6class',
                       'train_binary', 'val_binary', 'test_binary',
                       'statistics', 'visualizations']:
            os.makedirs(os.path.join(self.output_path, subdir), exist_ok=True)
        # Class definitions (six classes)
        self.class_definitions = {
            0: {'name': 'water', 'color': [0, 0, 255], 'description': 'water bodies'},
            1: {'name': 'forest', 'color': [0, 100, 0], 'description': 'forest'},
            2: {'name': 'grassland', 'color': [0, 255, 0], 'description': 'grassland/cropland'},
            3: {'name': 'bare_soil', 'color': [165, 42, 42], 'description': 'bare soil'},
            4: {'name': 'built_up', 'color': [255, 0, 0], 'description': 'built-up/impervious surface'},
            5: {'name': 'others', 'color': [128, 128, 128], 'description': 'others/unclassified'}
        }
        self.n_classes = len(self.class_definitions)

        # MODIS land cover mapping (optimized)
        self.modis_to_our_class = {
            0: 5,                           # unclassified -> others
            1: 1, 2: 1, 3: 1, 4: 1, 5: 1,  # forest classes
            6: 2, 7: 2, 8: 2, 9: 2, 10: 2, # grassland/shrubland classes
            11: 0,                          # wetlands -> water
            12: 2,                          # cropland
            13: 4,                          # urban
            14: 2,                          # cropland/natural vegetation mosaic
            15: 5,                          # snow/ice -> others
            16: 3,                          # barren
            17: 0,                          # water
            255: 5                          # fill value -> others
        }

        # 1. Decision thresholds (deliberately strict)
        self.thresholds = {
            'nightlight_urban': 10.0,  # raised to 10
            'slope_urban_max': 15.0,   # tightened to 15 degrees
            'ndvi_vegetation': 0.3,
            'ndbi_urban': 0.1,         # raised to 0.1
            'mndwi_water': 0.15,
            'ui_urban': 0.15,          # raised to 0.15
            'bui_urban': 0.2           # raised to 0.2
        }
        # Sampling configuration (optimized)
        # Fixed ratios: built-up 50%, all other classes share the remaining 50%
        self.sampling_config = {
            'train': {
                'total_samples': 100000,
                'strategy': 'fixed_ratio',
                'class_weights': {
                    0: 0.05,  # water 5%
                    1: 0.10,  # forest 10%
                    2: 0.15,  # grassland 15%
                    3: 0.15,  # bare soil 15%
                    4: 0.50,  # built-up 50% (the focus class)
                    5: 0.05   # others 5%
                }
            },
            'val': {
                'total_samples': 10000,
                'strategy': 'fixed_ratio',  # the validation set also uses fixed ratios
                'class_weights': {
                    0: 0.05,  # water 5%
                    1: 0.10,  # forest 10%
                    2: 0.15,  # grassland 15%
                    3: 0.15,  # bare soil 15%
                    4: 0.50,  # built-up 50%
                    5: 0.05   # others 5%
                }
            },
            'test': {
                'total_samples': 20000,
                'strategy': 'natural',  # the test set follows the natural distribution
                'class_weights': None
            }
        }
        # Logging and statistics
        self.log_file = os.path.join(
            self.output_path,
            f"generation_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"
        )
        self.stats = {
            'total_samples_6class': 0,
            'total_samples_binary': 0,
            'class_distribution_6class': {i: 0 for i in range(self.n_classes)},
            'class_distribution_binary': {0: 0, 1: 0},
            'years_processed': [],
            'errors': []
        }
        self.log("=" * 80)
        self.log("🚀 Improved impervious-surface training data generator")
        self.log(f"Output path: {self.output_path}")
        self.log("=" * 80)
    def log(self, message, level='INFO'):
        """Write a message to stdout and the log file."""
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        log_message = f"[{timestamp}] [{level}] {message}"
        print(log_message)
        try:
            with open(self.log_file, 'a', encoding='utf-8') as f:
                f.write(log_message + '\n')
        except Exception:
            pass  # never let a logging failure interrupt processing
    def create_enhanced_multiclass_labels_no_ntl(self, features, auxiliary_data, year):
        """
        Create labels for years without nighttime-light data (1990-1991),
        identifying built-up areas from a combination of spectral indices.
        """
        self.log(f"  No nighttime-light data for {year}; identifying built-up areas from spectral indices...")
        h, w = next(iter(features.values())).shape
        labels = np.full((h, w), 5, dtype=np.uint8)
        confidence = np.zeros((h, w), dtype=np.float32)

        # 1. Water
        if 'MNDWI' in features:
            water_mask = features['MNDWI'] > self.thresholds['mndwi_water']
            labels[water_mask] = 0
            confidence[water_mask] = 0.9

        # 2. Vegetation
        if 'NDVI' in features:
            vegetation_mask = features['NDVI'] > self.thresholds['ndvi_vegetation']
            # Use the DEM to separate forest from grassland
            if 'DEM' in auxiliary_data:
                forest_mask = vegetation_mask & (auxiliary_data['DEM'] > 1500) & (features['NDVI'] > 0.4)
                grassland_mask = vegetation_mask & ~forest_mask
                labels[forest_mask] = 1
                labels[grassland_mask] = 2
                confidence[forest_mask] = 0.8
                confidence[grassland_mask] = 0.75
            else:
                labels[vegetation_mask] = 2
                confidence[vegetation_mask] = 0.75

        # 3. Built-up areas (fallback scheme when nighttime lights are unavailable),
        #    combining several indices
        urban_mask = np.zeros((h, w), dtype=bool)
        # Require high NDBI AND low NDVI AND high BUI
        if all(x in features for x in ['NDBI', 'NDVI', 'BUI']):
            urban_mask = (features['NDBI'] > 0.15) & \
                         (features['NDVI'] < 0.1) & \
                         (features['BUI'] > 0.3)
            # If the UI index is available, it must also be satisfied
            if 'UI' in features:
                urban_mask &= features['UI'] > 0.15
        # Additional, stricter slope constraint
        if 'SLOPE' in auxiliary_data:
            urban_mask &= auxiliary_data['SLOPE'] < 10  # tightened from 20 degrees to 10
        # Optional: BSI as a supplementary condition (held to a stricter standard)
        if 'BSI' in features and 'SLOPE' in auxiliary_data:
            # Only consider BSI in extremely flat areas
            potential_urban = (features['BSI'] > 0.2) & (auxiliary_data['SLOPE'] < 5)
            if 'Blue' in features and 'Red' in features:
                # Require higher visible-band reflectance
                high_visible = (features['Blue'] > 1000) & (features['Red'] > 1000)
                potential_urban &= high_visible
            # Supplementary evidence only, not the primary criterion
            urban_mask |= potential_urban
        labels[urban_mask] = 4
        confidence[urban_mask] = 0.6  # lower confidence because nighttime lights are missing

        # 4. Bare soil
        bare_soil_mask = (
            (features['NDVI'] < 0.2) &
            (features['NDBI'] > -0.1) &
            (labels == 5)  # only within still-unclassified areas
        )
        labels[bare_soil_mask] = 3
        confidence[bare_soil_mask] = 0.65

        # 5. Guarantee a non-empty "others" class:
        #    low-confidence pixels stay labeled as others
        low_confidence = confidence < 0.5
        labels[low_confidence] = 5
        return labels, confidence
    def create_enhanced_multiclass_labels(self, features, auxiliary_data, year):
        """Create enhanced multi-class labels (main entry point)."""
        self.log("Creating enhanced multi-class labels...")
        # Check whether nighttime-light data are available
        has_nightlight = False
        if 'NTL' in auxiliary_data:
            ntl = auxiliary_data['NTL']
            has_nightlight = np.any(ntl > 0)
        # Choose the labeling scheme accordingly
        if not has_nightlight:
            return self.create_enhanced_multiclass_labels_no_ntl(features, auxiliary_data, year)

        # Regular pipeline when nighttime lights are available
        h, w = next(iter(features.values())).shape
        labels = np.full((h, w), 5, dtype=np.uint8)
        confidence = np.zeros((h, w), dtype=np.float32)

        # Use MODIS land cover as the baseline (when available)
        if 'LC_Type1' in auxiliary_data and year >= 2001:
            self.log("  Using MODIS land cover as the baseline...")
            modis_lc = auxiliary_data['LC_Type1'].astype(int)
            for modis_class, our_class in self.modis_to_our_class.items():
                mask = modis_lc == modis_class
                labels[mask] = our_class
                confidence[mask] = 0.7

        # Refine with spectral indices
        self.log("  Refining with spectral indices...")
        # Water
        if 'MNDWI' in features:
            water_mask = features['MNDWI'] > self.thresholds['mndwi_water']
            if 'WATER_MASK' in auxiliary_data:
                water_mask |= auxiliary_data['WATER_MASK'] > 0
            labels[water_mask] = 0
            confidence[water_mask] = 0.9
        # Vegetation
        if 'NDVI' in features:
            vegetation_mask = features['NDVI'] > self.thresholds['ndvi_vegetation']
            if 'DEM' in auxiliary_data:
                forest_mask = vegetation_mask & (auxiliary_data['DEM'] > 1500) & (features['NDVI'] > 0.4)
                grassland_mask = vegetation_mask & ~forest_mask
                labels[forest_mask] = 1
                labels[grassland_mask] = 2
                confidence[forest_mask] = 0.85
                confidence[grassland_mask] = 0.8
            else:
                labels[vegetation_mask] = 2
                confidence[vegetation_mask] = 0.8

        # Urban areas (from nighttime lights)
        self.log("  Identifying urban areas from nighttime lights...")
        ntl = auxiliary_data['NTL']
        # Corrected nighttime-light threshold computation
        if np.any(ntl > 0):
            # Use a higher percentile (90th)
            urban_threshold = max(self.thresholds['nightlight_urban'],
                                  np.percentile(ntl[ntl > 0], 90))
            urban_from_ntl = ntl > urban_threshold
            # Use AND logic rather than OR
            if 'SLOPE' in auxiliary_data:
                urban_from_ntl &= auxiliary_data['SLOPE'] < self.thresholds['slope_urban_max']
            # Several conditions must hold simultaneously
            if 'NDBI' in features and 'UI' in features:
                # AND logic: multiple indices must all point to urban
                urban_from_indices = (features['NDBI'] > self.thresholds['ndbi_urban']) & \
                                     (features['UI'] > self.thresholds['ui_urban'])
                # Lights and spectral indices must agree, unless the lights are very bright
                urban_from_ntl = urban_from_ntl & (urban_from_indices | (ntl > urban_threshold * 1.5))
            labels[urban_from_ntl] = 4
            confidence[urban_from_ntl] = 0.9

        # Bare soil
        bare_soil_mask = (
            (features['NDVI'] < 0.2) &
            (features['NDBI'] > -0.1) &
            (labels == 5)
        )
        labels[bare_soil_mask] = 3
        confidence[bare_soil_mask] = 0.7

        # Guarantee a non-empty "others" class (especially before 2001)
        if year <= 2000:
            # Demote low-confidence pixels to others, but keep water
            uncertain_mask = (confidence < 0.6) & (labels != 0)
            labels[uncertain_mask] = 5

        # Spatial refinement
        labels = self._spatial_refinement(labels, confidence)
        # Log the class distribution
        self._log_class_distribution(labels)
        return labels, confidence
    def adaptive_sampling(self, labels, features, quality_mask, split='train'):
        """Adaptive sampling (improved)."""
        config = self.sampling_config[split]
        total_samples = config['total_samples']
        strategy = config['strategy']
        self.log(f"[{split}] Adaptive sampling (strategy: {strategy})")

        # Class distribution within the valid area
        unique, counts = np.unique(labels[quality_mask], return_counts=True)
        if len(counts) == 0:
            self.log("  No valid pixels in the masked area; skipping sampling.", 'WARNING')
            return []
        actual_distribution = dict(zip(unique, counts))
        total_valid = sum(counts)

        # Determine target sample counts per strategy
        if strategy == 'fixed_ratio':
            # Fixed-ratio sampling
            weights = config['class_weights']
            target_samples = {}
            for i in range(self.n_classes):
                target_samples[i] = int(total_samples * weights.get(i, 0))
        elif strategy == 'natural':
            # Sample proportionally to the natural distribution
            target_samples = {}
            for cid, count in actual_distribution.items():
                target_samples[cid] = int(total_samples * count / total_valid)
            # Fill in classes that do not occur
            for i in range(self.n_classes):
                if i not in target_samples:
                    target_samples[i] = 0
        else:
            # Balanced sampling
            samples_per_class = total_samples // self.n_classes
            target_samples = {i: samples_per_class for i in range(self.n_classes)}

        # Adjust so the totals match exactly
        current_total = sum(target_samples.values())
        if current_total != total_samples:
            # Put the remainder on the largest class
            max_class = max(target_samples, key=lambda k: target_samples[k])
            target_samples[max_class] += (total_samples - current_total)

        # Draw the samples
        all_samples = []
        for class_id in range(self.n_classes):
            n_target = target_samples.get(class_id, 0)
            if n_target > 0:
                class_samples = self._smart_class_sampling(
                    labels, features, quality_mask, class_id, n_target
                )
                all_samples.extend(class_samples)
                actual_n = len(class_samples)
                class_name = self.class_definitions[class_id]['name']
                self.log(f"  {class_id}-{class_name}: target {n_target}, actual {actual_n}")
        return all_samples
    def process_single_year(self, year, split='train'):
        """Process one year of data, generating both six-class and binary samples."""
        try:
            self.log(f"\n{'='*60}")
            self.log(f"Processing year {year} (split={split})")

            # Skip if the outputs already exist
            output_file_6class = os.path.join(self.output_path, f"{split}_6class",
                                              f"{split}_6class_{year}.csv")
            output_file_binary = os.path.join(self.output_path, f"{split}_binary",
                                              f"{split}_binary_{year}.csv")
            if os.path.exists(output_file_6class) and os.path.exists(output_file_binary):
                self.log("Output files already exist; skipping")
                return True

            # 1. Load Landsat
            self.log("Loading Landsat data...")
            landsat_data = self.load_landsat_safe(year)
            if landsat_data is None:
                return False
            mosaic, transform, crs = landsat_data

            # 2. Build the AOI mask
            aoi_mask, _ = self.build_aoi_mask(mosaic.shape[1:], transform, crs)

            # 3. Load and align auxiliary data
            auxiliary_data = self.load_and_align_auxiliary_data(
                year, transform, mosaic.shape[1:], crs
            )

            # 4. Feature extraction
            self.log("Extracting enhanced features...")
            features = self.extract_enhanced_features(mosaic, auxiliary_data)

            # 5. Create enhanced multi-class labels (year-aware)
            labels, confidence = self.create_enhanced_multiclass_labels(
                features, auxiliary_data, year
            )

            # 6. Quality mask
            quality_mask = self.create_quality_mask(features, labels)
            quality_mask &= aoi_mask
            valid_ratio = np.sum(quality_mask) / quality_mask.size * 100
            self.log(f"Valid pixels: {valid_ratio:.2f}%")
            if valid_ratio < 0.1:
                self.log("Too few valid pixels; skipping", 'WARNING')
                return False

            # 7. Adaptive sampling
            sample_coords = self.adaptive_sampling(labels, features, quality_mask, split)
            if not sample_coords:
                self.log("No samples were generated!", 'WARNING')
                return False

            # 8. Extract sample records
            samples = self.process_samples(sample_coords, labels, features, auxiliary_data, year, transform)

            # 9. Write the six-class dataset
            df_6class = pd.DataFrame(samples)
            df_6class = df_6class.sample(frac=1, random_state=42).reset_index(drop=True)
            df_6class.to_csv(output_file_6class, index=False)

            # 10. Write the binary dataset (built-up vs. everything else)
            df_binary = df_6class.copy()
            df_binary['label_binary'] = (df_binary['label'] == 4).astype(int)
            df_binary['label_6class'] = df_binary['label']  # keep the original six-class label
            df_binary['label'] = df_binary['label_binary']  # the primary label becomes binary
            df_binary.to_csv(output_file_binary, index=False)

            # 11. Statistics and visualization
            self.update_stats(df_6class, df_binary, year, split)
            self.visualize_enhanced_labels(labels, auxiliary_data, year)

            # Free memory
            del mosaic, features, labels, auxiliary_data
            gc.collect()
            return True
        except Exception as e:
            self.log(f"Processing failed: {str(e)}", 'ERROR')
            self.log(traceback.format_exc(), 'DEBUG')
            return False
    def update_stats(self, df_6class, df_binary, year, split='train'):
        """Update running statistics."""
        n_total = len(df_6class)
        # Six-class statistics
        self.log(f"✅ [{split}] Six-class dataset statistics:")
        self.log(f"  Total samples: {n_total:,}")
        for class_id in range(self.n_classes):
            n_class = int((df_6class['label'] == class_id).sum())
            ratio = n_class / n_total * 100 if n_total > 0 else 0
            class_name = self.class_definitions[class_id]['name']
            self.log(f"  {class_id}-{class_name}: {n_class:,} ({ratio:.2f}%)")
            self.stats['class_distribution_6class'][class_id] += n_class
        # Binary statistics
        self.log(f"✅ [{split}] Binary dataset statistics:")
        n_impervious = int((df_binary['label'] == 1).sum())
        n_pervious = int((df_binary['label'] == 0).sum())
        self.log(f"  Impervious: {n_impervious:,} ({n_impervious/n_total*100:.2f}%)")
        self.log(f"  Pervious: {n_pervious:,} ({n_pervious/n_total*100:.2f}%)")
        self.stats['class_distribution_binary'][1] += n_impervious
        self.stats['class_distribution_binary'][0] += n_pervious
        self.stats['total_samples_6class'] += n_total
        self.stats['total_samples_binary'] += n_total
        if year not in self.stats['years_processed']:
            self.stats['years_processed'].append(year)
    def merge_datasets(self):
        """Merge the per-year datasets (six-class and binary)."""
        self.log("\nMerging datasets...")
        for data_type in ['6class', 'binary']:
            for split in ['train', 'val', 'test']:
                pattern = os.path.join(self.output_path, f"{split}_{data_type}",
                                       f"{split}_{data_type}_*.csv")
                files = glob.glob(pattern)
                if files:
                    all_data = []
                    for file in sorted(files):
                        df = pd.read_csv(file)
                        all_data.append(df)
                    merged = pd.concat(all_data, ignore_index=True)
                    output_file = os.path.join(self.output_path,
                                               f"all_{split}_{data_type}.csv")
                    merged.to_csv(output_file, index=False)
                    n_total = len(merged)
                    self.log(f"\n{split}_{data_type} merge complete: {n_total:,} samples")
                    if data_type == '6class':
                        for class_id in range(self.n_classes):
                            n_class = int((merged['label'] == class_id).sum())
                            ratio = n_class / n_total * 100
                            class_name = self.class_definitions[class_id]['name']
                            self.log(f"  {class_id}-{class_name}: {n_class:,} ({ratio:.2f}%)")
                    else:
                        n_impervious = int((merged['label'] == 1).sum())
                        n_pervious = int((merged['label'] == 0).sum())
                        self.log(f"  Impervious: {n_impervious:,} ({n_impervious/n_total*100:.2f}%)")
                        self.log(f"  Pervious: {n_pervious:,} ({n_pervious/n_total*100:.2f}%)")
    # Supporting methods (carried over from the original code)
    def build_aoi_mask(self, target_shape, target_transform, target_crs):
        """Rasterize the AOI shapefile into a boolean mask on the target grid."""
        if not self.require_aoi:
            return np.ones(target_shape, dtype=bool), None
        with fiona.open(self.aoi_path, 'r') as src:
            src_crs = src.crs_wkt or src.crs
            geoms = []
            for feat in src:
                geom = shape(feat['geometry'])
                if not geom.is_valid:
                    geom = geom.buffer(0)
                if geom.is_empty:
                    continue
                geoms.append(geom)
        if not geoms:
            raise ValueError("No valid geometries were read from the AOI shapefile.")
        proj = pyproj.Transformer.from_crs(src_crs, target_crs, always_xy=True) if src_crs else None

        def _tf(g):
            if proj is None:
                return g
            return shp_transform(proj.transform, g)

        geoms_proj = [mapping(_tf(g)) for g in geoms]
        mask = rasterize(
            shapes=[(g, 1) for g in geoms_proj],
            out_shape=target_shape,
            transform=target_transform,
            fill=0,
            all_touched=False,
            dtype='uint8'
        ).astype(bool)
        cov = mask.mean() * 100.0
        self.log(f"AOI coverage: {cov:.2f}%")
        return mask, geoms_proj
    def load_and_align_auxiliary_data(self, year, target_transform, target_shape, target_crs):
        """Load auxiliary rasters and align them to the Landsat grid."""
        auxiliary_data = {}
        if not self.use_auxiliary:
            return auxiliary_data
        try:
            self.log(f"Loading auxiliary data for {year}...")
            aux_pattern = os.path.join(self.auxiliary_path, str(year), f"shanxi_aux_{year}*.tif")
            aux_files = sorted(glob.glob(aux_pattern))
            if not aux_files:
                self.log(f"  No auxiliary data found for {year}", 'WARNING')
                return auxiliary_data
            self.log(f"  Found {len(aux_files)} auxiliary data file(s)")
            band_names = ['NTL', 'NTL_DMSP', 'NTL_VIIRS', 'DEM', 'SLOPE',
                          'LC_Type1', 'WATER_MASK', 'WATER_SRC', 'YEAR']
            if len(aux_files) > 1:
                self.log("  Merging multiple auxiliary files...")
                with ExitStack() as stack:
                    src_files = [stack.enter_context(rasterio.open(fp)) for fp in aux_files]
                    first_src = src_files[0]
                    mosaic, mosaic_transform = merge(src_files)
                    mosaic_crs = first_src.crs
                    for i, band_name in enumerate(band_names):  # merged array is 0-indexed
                        if i < mosaic.shape[0]:
                            self.log(f"    Aligning band: {band_name}")
                            aligned_band = np.zeros(target_shape, dtype=np.float32)
                            reproject(
                                source=mosaic[i],
                                destination=aligned_band,
                                src_transform=mosaic_transform,
                                src_crs=mosaic_crs,
                                dst_transform=target_transform,
                                dst_crs=target_crs,
                                resampling=Resampling.bilinear if band_name in ['DEM', 'SLOPE', 'NTL', 'NTL_DMSP', 'NTL_VIIRS']
                                else Resampling.nearest  # nearest for categorical layers
                            )
                            auxiliary_data[band_name] = aligned_band
            else:
                with rasterio.open(aux_files[0]) as src:
                    for i, band_name in enumerate(band_names, 1):  # rasterio bands are 1-indexed
                        if i <= src.count:
                            self.log(f"    Aligning band: {band_name}")
                            aligned_band = np.zeros(target_shape, dtype=np.float32)
                            reproject(
                                source=rasterio.band(src, i),
                                destination=aligned_band,
                                src_transform=src.transform,
                                src_crs=src.crs,
                                dst_transform=target_transform,
                                dst_crs=target_crs,
                                resampling=Resampling.bilinear if band_name in ['DEM', 'SLOPE', 'NTL', 'NTL_DMSP', 'NTL_VIIRS']
                                else Resampling.nearest
                            )
                            auxiliary_data[band_name] = aligned_band

            # Clean-up and post-processing
            if 'NTL' in auxiliary_data:
                ntl = auxiliary_data['NTL']
                ntl[ntl < 0] = 0
                p99 = np.percentile(ntl[ntl > 0], 99) if np.any(ntl > 0) else 1
                auxiliary_data['NTL'] = np.clip(ntl, 0, p99)
            if 'NTL_DMSP' in auxiliary_data:
                auxiliary_data['NTL_DMSP'] = np.maximum(auxiliary_data['NTL_DMSP'], 0)
            if 'NTL_VIIRS' in auxiliary_data:
                auxiliary_data['NTL_VIIRS'] = np.maximum(auxiliary_data['NTL_VIIRS'], 0)

            # Merge the nighttime-light sources (DMSP before 2012, VIIRS from 2012 on)
            if year < 2012:
                if 'NTL' not in auxiliary_data or np.all(auxiliary_data['NTL'] == 0):
                    if 'NTL_DMSP' in auxiliary_data:
                        auxiliary_data['NTL'] = auxiliary_data['NTL_DMSP']
            else:
                if 'NTL' not in auxiliary_data or np.all(auxiliary_data['NTL'] == 0):
                    if 'NTL_VIIRS' in auxiliary_data and np.any(auxiliary_data['NTL_VIIRS'] > 0):
                        auxiliary_data['NTL'] = auxiliary_data['NTL_VIIRS']
                    elif 'NTL_DMSP' in auxiliary_data:
                        auxiliary_data['NTL'] = auxiliary_data['NTL_DMSP']

            if 'SLOPE' in auxiliary_data:
                auxiliary_data['SLOPE'] = np.clip(auxiliary_data['SLOPE'], 0, 90)
            self.log(f"  ✅ Loaded {len(auxiliary_data)} auxiliary layers")
        except Exception as e:
            self.log(f"  ❌ Failed to load auxiliary data: {str(e)}", 'ERROR')
            self.log(traceback.format_exc(), 'DEBUG')
        return auxiliary_data
    def extract_enhanced_features(self, mosaic, auxiliary_data=None):
        """Extract spectral bands, spectral indices, and auxiliary features."""
        features = {}
        bands = ['Blue', 'Green', 'Red', 'NIR', 'SWIR1', 'SWIR2']
        for i, band in enumerate(bands):
            if i < mosaic.shape[0]:
                features[band] = mosaic[i].astype(np.float32)

        def safe_divide(a, b, default=0):
            """Element-wise division that maps x/0, NaN, and +/-inf to safe values."""
            with np.errstate(divide='ignore', invalid='ignore'):
                result = np.where(b != 0, a / b, default)
            return np.nan_to_num(result, nan=default, posinf=1, neginf=-1).astype(np.float32)

        # Spectral indices
        features['NDVI'] = safe_divide(features['NIR'] - features['Red'],
                                       features['NIR'] + features['Red'])
        features['NDBI'] = safe_divide(features['SWIR1'] - features['NIR'],
                                       features['SWIR1'] + features['NIR'])
        features['MNDWI'] = safe_divide(features['Green'] - features['SWIR1'],
                                        features['Green'] + features['SWIR1'])
        features['UI'] = safe_divide(features['SWIR2'] - features['NIR'],
                                     features['SWIR2'] + features['NIR'])
        features['BUI'] = features['NDBI'] - features['NDVI']
        features['BSI'] = safe_divide(
            (features['SWIR1'] + features['Red']) - (features['NIR'] + features['Blue']),
            (features['SWIR1'] + features['Red']) + (features['NIR'] + features['Blue'])
        )

        # Auxiliary features
        if auxiliary_data:
            if 'NTL' in auxiliary_data:
                features['NTL'] = auxiliary_data['NTL']
                features['NTL_log'] = np.log1p(auxiliary_data['NTL'])
            if 'DEM' in auxiliary_data:
                features['DEM'] = auxiliary_data['DEM']
            if 'SLOPE' in auxiliary_data:
                features['SLOPE'] = auxiliary_data['SLOPE']
            if 'WATER_MASK' in auxiliary_data:
                features['WATER_MASK'] = auxiliary_data['WATER_MASK']
        return features
    def _spatial_refinement(self, labels, confidence, kernel_size=3):
        """Smooth low-confidence labels with a median filter."""
        low_conf_mask = confidence < 0.7
        if np.any(low_conf_mask):
            smoothed = median_filter(labels, size=kernel_size)
            labels[low_conf_mask] = smoothed[low_conf_mask]
        return labels

    def _log_class_distribution(self, labels):
        """Log the per-class pixel distribution."""
        unique, counts = np.unique(labels, return_counts=True)
        total = labels.size
        self.log("  Class distribution:")
        for class_id, count in zip(unique, counts):
            ratio = count / total * 100
            class_name = self.class_definitions[class_id]['name']
            self.log(f"    {class_id}-{class_name}: {count:,} ({ratio:.2f}%)")

    def _smart_class_sampling(self, labels, features, quality_mask, class_id, n_samples):
        """Sample pixel coordinates for one class, up-weighting urban boundary pixels."""
        class_mask = (labels == class_id) & quality_mask
        class_indices = np.where(class_mask)
        n_available = len(class_indices[0])
        if n_available == 0:
            return []
        if n_available <= n_samples:
            return [(class_indices[0][i], class_indices[1][i]) for i in range(n_available)]
        weights = np.ones(n_available, dtype=np.float32)
        # For the built-up class, prefer boundary-like pixels
        if 'NDBI' in features and 'NDVI' in features and class_id == 4:
            for idx, (r, c) in enumerate(zip(class_indices[0], class_indices[1])):
                ndbi_val = features['NDBI'][r, c]
                ndvi_val = features['NDVI'][r, c]
                if ndbi_val < 0.1 or ndvi_val > 0.2:
                    weights[idx] *= 1.5
        weights = weights / weights.sum()
        indices = np.random.choice(n_available, n_samples, replace=False, p=weights)
        return [(class_indices[0][i], class_indices[1][i]) for i in indices]
    def load_landsat_safe(self, year):
        """Load and mosaic the Landsat tiles for one year, with error handling."""
        try:
            year_path = os.path.join(self.landsat_path, str(year))
            tif_files = sorted(glob.glob(os.path.join(year_path, "*.tif")))[:7]
            if not tif_files:
                self.log(f"No Landsat imagery found for {year}", 'WARNING')
                return None
            with ExitStack() as stack:
                src_files = [stack.enter_context(rasterio.open(fp)) for fp in tif_files]
                mosaic, transform = merge(src_files)
                crs = src_files[0].crs
            self.log(f"Landsat loaded: {mosaic.shape}")
            return mosaic, transform, crs
        except Exception as e:
            self.log(f"Landsat loading error: {str(e)}", 'ERROR')
            return None

    def create_quality_mask(self, features, labels):
        """Build a quality mask that drops invalid, cloudy, and edge pixels."""
        h, w = labels.shape
        quality_mask = np.ones((h, w), dtype=bool)
        for band in ['Blue', 'Green', 'Red', 'NIR', 'SWIR1', 'SWIR2']:
            if band in features:
                band_data = features[band]
                quality_mask &= ~np.isnan(band_data)
                quality_mask &= ~np.isinf(band_data)
                quality_mask &= (band_data > 0)
                quality_mask &= (band_data < 20000)
        # Simple brightness-based cloud screen
        if all(band in features for band in ['Blue', 'Green', 'Red']):
            cloud_mask = ((features['Blue'] > 8000) &
                          (features['Green'] > 8000) &
                          (features['Red'] > 8000))
            quality_mask &= ~cloud_mask
        # Discard a buffer along the image edges
        edge_buffer = 10
        quality_mask[:edge_buffer, :] = False
        quality_mask[-edge_buffer:, :] = False
        quality_mask[:, :edge_buffer] = False
        quality_mask[:, -edge_buffer:] = False
        valid_ratio = np.sum(quality_mask) / quality_mask.size * 100
        self.log(f"Quality mask: {valid_ratio:.1f}% valid pixels")
        return quality_mask
    def process_samples(self, sample_coords, labels, features, auxiliary_data, year, transform):
        """Turn sampled (row, col) coordinates into feature records."""
        samples = []
        for row, col in tqdm(sample_coords, desc="Extracting features",
                             disable=len(sample_coords) > 10000):
            x, y = transform_xy(transform, row, col)
            sample = {
                'longitude': float(x),
                'latitude': float(y),
                'label': int(labels[row, col]),
                'year': year,
                'row': row,
                'col': col
            }
            for feature_name, feature_data in features.items():
                if feature_name in auxiliary_data:
                    continue
                value = float(feature_data[row, col])
                if np.isnan(value) or np.isinf(value):
                    value = 0.0
                sample[feature_name] = value
            if auxiliary_data:
                for aux_name in ['NTL', 'NTL_log', 'DEM', 'SLOPE', 'WATER_MASK']:
                    if aux_name in features:
                        value = float(features[aux_name][row, col])
                        if np.isnan(value) or np.isinf(value):
                            value = 0.0
                        sample[aux_name] = value
            samples.append(sample)
        return samples
    def visualize_enhanced_labels(self, labels, auxiliary_data, year):
        """Visualize the enhanced labels and auxiliary layers."""
        try:
            fig, axes = plt.subplots(2, 3, figsize=(18, 12))
            step = max(1, labels.shape[0] // 1000)
            labels_vis = labels[::step, ::step]

            # Multi-class map
            colors = np.array([info['color'] for info in self.class_definitions.values()]) / 255.0
            rgb_image = np.zeros((labels_vis.shape[0], labels_vis.shape[1], 3))
            for i in range(self.n_classes):
                mask = labels_vis == i
                rgb_image[mask] = colors[i]
            axes[0, 0].imshow(rgb_image)
            axes[0, 0].set_title(f'{year} multi-class classification')
            axes[0, 0].axis('off')
            legend_patches = [Patch(color=colors[i], label=self.class_definitions[i]['name'])
                              for i in range(self.n_classes)]
            axes[0, 0].legend(handles=legend_patches, loc='upper right', fontsize=8)

            # Binary map (impervious surface)
            binary_labels = (labels_vis == 4).astype(int)
            axes[0, 1].imshow(binary_labels, cmap='RdYlBu_r', vmin=0, vmax=1)
            axes[0, 1].set_title('Binary (red = impervious)')
            axes[0, 1].axis('off')

            # Nighttime lights
            if auxiliary_data and 'NTL' in auxiliary_data:
                ntl_vis = auxiliary_data['NTL'][::step, ::step]
                if np.any(ntl_vis > 0):
                    ntl_vis_log = np.log1p(ntl_vis)
                    im3 = axes[0, 2].imshow(ntl_vis_log, cmap='hot', vmin=0)
                    axes[0, 2].set_title('Nighttime lights (log)')
                    plt.colorbar(im3, ax=axes[0, 2], fraction=0.046)
                else:
                    axes[0, 2].text(0.5, 0.5, f'No nighttime-light data for {year}',
                                    ha='center', va='center', transform=axes[0, 2].transAxes)
            axes[0, 2].axis('off')

            # DEM
            if auxiliary_data and 'DEM' in auxiliary_data:
                dem_vis = auxiliary_data['DEM'][::step, ::step]
                im4 = axes[1, 0].imshow(dem_vis, cmap='terrain')
                axes[1, 0].set_title('Elevation (DEM)')
                axes[1, 0].axis('off')
                plt.colorbar(im4, ax=axes[1, 0], fraction=0.046)
            else:
                axes[1, 0].axis('off')

            # Slope
            if auxiliary_data and 'SLOPE' in auxiliary_data:
                slope_vis = auxiliary_data['SLOPE'][::step, ::step]
                im5 = axes[1, 1].imshow(slope_vis, cmap='YlOrRd', vmin=0, vmax=30)
                axes[1, 1].set_title('Slope')
                axes[1, 1].axis('off')
                plt.colorbar(im5, ax=axes[1, 1], fraction=0.046)
            else:
                axes[1, 1].axis('off')

            # Class distribution bar chart
            unique, counts = np.unique(labels, return_counts=True)
            class_names = [self.class_definitions[i]['name'] for i in unique]
            axes[1, 2].bar(class_names, counts / 1000000)
            axes[1, 2].set_xlabel('Class')
            axes[1, 2].set_ylabel('Pixels (millions)')
            axes[1, 2].set_title('Class distribution')
            axes[1, 2].tick_params(axis='x', rotation=45)

            plt.tight_layout()
            output_file = os.path.join(self.output_path, 'visualizations',
                                       f'enhanced_labels_{year}.png')
            plt.savefig(output_file, dpi=100, bbox_inches='tight')
            plt.close()
            self.log(f"  Visualization saved: {output_file}")
        except Exception as e:
            self.log(f"  Visualization failed: {str(e)}", 'WARNING')
    def generate_train_val_test_split(self, years=None):
        """Generate the train, validation, and test sets, split by year."""
        if years is None:
            years = list(range(1990, 2019))
        # Year-based split
        test_years = [2016, 2017, 2018]
        val_years = [2014, 2015]
        train_years = [y for y in years if y not in test_years + val_years]
        self.log("\nDataset split:")
        self.log(f"Train: {len(train_years)} years ({train_years[0]}-{train_years[-1]})")
        self.log(f"Validation: {val_years}")
        self.log(f"Test: {test_years}")
        for split, split_years in [('train', train_years), ('val', val_years), ('test', test_years)]:
            self.log(f"\n{'='*60}")
            self.log(f"Generating the {split} set...")
            for year in split_years:
                self.process_single_year(year, split)
    def generate_statistics_report(self):
        """Write the JSON statistics report and print a summary."""
        report = {
            'generation_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            'n_classes': self.n_classes,
            'use_auxiliary': self.use_auxiliary,
            'output_path': self.output_path,
            'statistics_6class': {
                'total_samples': self.stats['total_samples_6class'],
                'class_distribution': {
                    i: {
                        'name': self.class_definitions[i]['name'],
                        'samples': self.stats['class_distribution_6class'][i],
                        'percentage': self.stats['class_distribution_6class'][i] / self.stats['total_samples_6class'] * 100
                        if self.stats['total_samples_6class'] > 0 else 0
                    }
                    for i in range(self.n_classes)
                }
            },
            'statistics_binary': {
                'total_samples': self.stats['total_samples_binary'],
                'impervious': self.stats['class_distribution_binary'][1],
                'pervious': self.stats['class_distribution_binary'][0],
                'impervious_ratio': self.stats['class_distribution_binary'][1] / self.stats['total_samples_binary'] * 100
                if self.stats['total_samples_binary'] > 0 else 0
            },
            'years_processed': len(self.stats['years_processed']),
            'sampling_strategy': self.sampling_config
        }
        report_file = os.path.join(self.output_path, 'generation_report.json')
        with open(report_file, 'w', encoding='utf-8') as f:
            json.dump(report, f, indent=2, ensure_ascii=False)
        self.log(f"\n📊 Statistics report saved: {report_file}")

        # Print the summary
        self.log("\n" + "=" * 60)
        self.log("📈 Generation summary:")
        self.log(f"Years processed: {len(self.stats['years_processed'])}")
        self.log("\nSix-class statistics:")
        self.log(f"Total samples: {self.stats['total_samples_6class']:,}")
        for class_id in range(self.n_classes):
            n_samples = self.stats['class_distribution_6class'][class_id]
            percentage = n_samples / self.stats['total_samples_6class'] * 100 if self.stats['total_samples_6class'] > 0 else 0
            class_name = self.class_definitions[class_id]['name']
            self.log(f"  {class_id}-{class_name}: {n_samples:,} ({percentage:.2f}%)")
        self.log("\nBinary statistics (impervious surface extraction):")
        self.log(f"Total samples: {self.stats['total_samples_binary']:,}")
        if self.stats['total_samples_binary'] > 0:
            imp_ratio = self.stats['class_distribution_binary'][1] / self.stats['total_samples_binary'] * 100
            per_ratio = self.stats['class_distribution_binary'][0] / self.stats['total_samples_binary'] * 100
            self.log(f"  Impervious: {self.stats['class_distribution_binary'][1]:,} ({imp_ratio:.2f}%)")
            self.log(f"  Pervious: {self.stats['class_distribution_binary'][0]:,} ({per_ratio:.2f}%)")
def main():
    """Entry point."""
    print("=" * 80)
    print("🚀 Improved impervious-surface training data generator")
    print("📊 Outputs six-class and binary datasets simultaneously")
    print("=" * 80)

    # Check for auxiliary data
    aux_path = r"D:\山西省辅助数据"
    if os.path.exists(aux_path):
        aux_years = [d for d in os.listdir(aux_path) if d.isdigit()]
        print(f"\n✅ Found auxiliary data for {len(aux_years)} year(s)")
        use_auxiliary = True
    else:
        print("\n⚠️ Auxiliary data directory not found")
        use_auxiliary = False

    # Create the generator
    generator = ImprovedMultiClassGenerator(use_auxiliary=use_auxiliary)
    print("\nChoose a processing mode:")
    print("1. Quick test (1990, 2000)")
    print("2. Key years (1990, 1991, 2000, 2010, 2018)")
    print("3. Full dataset (recommended) ⭐")
    choice = input("\nYour choice (1-3, default 3): ").strip() or '3'
    if choice == '1':
        generator.process_single_year(1990, 'test')
        generator.process_single_year(2000, 'test')
    elif choice == '2':
        for year in [1990, 1991, 2000, 2010, 2018]:
            generator.process_single_year(year, 'train')
    elif choice == '3':
        generator.generate_train_val_test_split()
        generator.merge_datasets()

    # Generate the report
    generator.generate_statistics_report()
    print("\n✅ Processing complete!")
    print(f"📁 Data saved to: {generator.output_path}")
    print("\n" + "=" * 60)
    print("💡 Key improvements:")
    print("1. For 1990-1991 (no nighttime lights), built-up areas are identified from spectral indices")
    print("2. Train and validation sets use fixed class ratios (built-up 50%)")
    print("3. The test set follows the natural distribution")
    print("4. Six-class and binary datasets are produced together")
    print("5. Every year keeps a non-empty 'others' class (via low-confidence pixels)")
    print("\n📊 Intended uses:")
    print("- Six-class: fine-grained land cover classification")
    print("- Binary: dedicated impervious surface extraction")
    print("=" * 60)
if __name__ == "__main__":
    main()