# train_cnn_lstm.py
# PyTorch CNN-LSTM for FY-4B GIIRS profile bias correction (bias = radiosonde - FY4B)
# Author: ChatGPT (adapt to your data)
# Usage:
# python train_cnn_lstm.py --data fy4bt0305_vertical.csv --epochs 60 --batch_size 32
import os
import argparse
import numpy as np
import pandas as pd
from scipy import interpolate
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
# ----------------------------
# Config / Column mapping
# ----------------------------
# Target vertical grid: 101 evenly spaced levels from 1000 hPa down to 100 hPa
# (9 hPa spacing), ordered surface -> top.
DEFAULT_PRESSURE_LEVELS = np.linspace(1000.0, 100.0, num=101)

# Logical field name -> CSV column name. Edit the right-hand sides to match
# your file; the keys are what the rest of the script looks up.
DATA_COLUMNS = dict(
    station="Station_Id_C",
    time="Time_Dev_WQ",
    lat="Lat_Dev",
    lon="Lon_Dev",
    pressure="PRS_HWC",
    temp="TEM",  # air temperature, assumed to be in degC
    dpt="DPT",   # dewpoint temperature, assumed to be in degC
    clm="CLM",   # optional cloud mask; filled with 0.0 when absent from the CSV
)
# TEM/DPT are assumed to be in degrees Celsius, as the Magnus formula below requires.
# If your data are in Kelvin, subtract 273.15 before computing humidity.
# Function to compute specific humidity from temperature (degC), dewpoint (degC) and pressure (hPa).
def compute_specific_humidity(t_c, td_c, p_hpa):
    """Compute specific humidity (kg/kg) from dewpoint and pressure.

    The actual vapour pressure ``e`` is the saturation vapour pressure at the
    dewpoint, obtained with the Magnus formula (a = 6.112 hPa, b = 17.67,
    c = 243.5 degC). Specific humidity is then
    ``q = 0.622 * e / (p - 0.378 * e)``.

    Args:
        t_c: Air temperature in degC. Not used by the calculation, because
            specific humidity depends only on dewpoint and pressure. Kept for
            API compatibility.
        td_c: Dewpoint temperature in degC (scalar or numpy array).
        p_hpa: Pressure in hPa (scalar or array broadcastable with ``td_c``).

    Returns:
        Specific humidity in kg/kg, broadcast over the inputs. Multiply by
        1000 for g/kg.
    """
    a = 6.112  # hPa
    b = 17.67
    c = 243.5  # degC
    e_td = a * np.exp(b * td_c / (td_c + c))  # vapour pressure (hPa)
    return 0.622 * (e_td / (p_hpa - 0.378 * e_td))
# ----------------------------
# Data utilities
# ----------------------------
def _clean_profile(p, temp, dpt):
    """Return (p, T, Q) with non-finite rows dropped and duplicate pressures averaged.

    Q is specific humidity (kg/kg) from compute_specific_humidity. The output
    arrays are sorted by decreasing pressure (surface -> top).
    """
    q = compute_specific_humidity(temp, dpt, p)
    # Drop rows where any quantity is missing/non-finite; otherwise interp1d
    # would spread NaN across the whole interpolated profile.
    valid = np.isfinite(p) & np.isfinite(temp) & np.isfinite(q)
    p, temp, q = p[valid], temp[valid], q[valid]
    # interp1d divides by differences between neighbouring pressures, so
    # repeated pressure values would give inf/NaN; average them instead.
    p_unique, inverse = np.unique(p, return_inverse=True)
    counts = np.bincount(inverse, minlength=p_unique.size)
    temp_mean = np.bincount(inverse, weights=temp, minlength=p_unique.size) / counts
    q_mean = np.bincount(inverse, weights=q, minlength=p_unique.size) / counts
    # np.unique sorts ascending; reverse so pressure decreases.
    return p_unique[::-1], temp_mean[::-1], q_mean[::-1]


def read_and_assemble_profiles(csv_path, pressure_levels=DEFAULT_PRESSURE_LEVELS, min_levels=20):
    """Read a level-per-row CSV and assemble profiles interpolated onto fixed levels.

    Rows are grouped by (station, time) using the column names in
    ``DATA_COLUMNS``. Within each group, rows with non-finite pressure,
    temperature or humidity are dropped, and duplicate pressures are averaged.
    Temperature and specific humidity (computed from dewpoint) are then
    linearly interpolated onto ``pressure_levels``.

    Args:
        csv_path: Path to the CSV file.
        pressure_levels: 1-D array of target pressure levels (hPa).
        min_levels: Minimum number of usable levels (finite values, distinct
            pressures) a profile needs to be kept. At least 2 are always
            required for interpolation.

    Returns:
        List of dicts with keys 'station', 'time', 'lat', 'lon', 'clm',
        'pressure' (the ``pressure_levels`` argument), 'obs_T' and 'obs_Q'.
        The CSV TEM/DPT columns are treated as radiosonde observations; no
        'fy4b_T' / 'fy4b_Q' keys are produced here.

    Raises:
        ValueError: if a column named in DATA_COLUMNS (other than the optional
            cloud mask) is missing from the CSV.
    """
    cols = DATA_COLUMNS
    df = pd.read_csv(csv_path, low_memory=False)
    # Ensure required columns exist; only the cloud mask may be absent.
    for key, name in cols.items():
        if name not in df.columns:
            if key == 'clm':
                df[name] = 0.0
            else:
                raise ValueError(f"Required column '{name}' not found in CSV. Please adjust DATA_COLUMNS mapping.")
    # Cast numeric fields; unparsable entries become NaN.
    for key in ('pressure', 'temp', 'dpt'):
        df[cols[key]] = pd.to_numeric(df[cols[key]], errors='coerce')
    df = df.dropna(subset=[cols['pressure']])

    levels = np.asarray(pressure_levels, dtype=float)
    samples = []
    for (sid, t), g in df.groupby([cols['station'], cols['time']]):
        p, temp, q = _clean_profile(
            g[cols['pressure']].to_numpy(dtype=float),
            g[cols['temp']].to_numpy(dtype=float),
            g[cols['dpt']].to_numpy(dtype=float),
        )
        if p.size < max(min_levels, 2):
            continue
        # Require at least 10 target levels inside the observed pressure span
        # (p[0] is the highest pressure, p[-1] the lowest).
        in_span = (levels <= p[0]) & (levels >= p[-1])
        if in_span.sum() < 10:
            continue
        # NOTE: target levels outside the observed span are linearly
        # extrapolated, which can produce unphysical values (e.g. negative Q).
        f_temp = interpolate.interp1d(p, temp, kind='linear', bounds_error=False, fill_value="extrapolate")
        f_q = interpolate.interp1d(p, q, kind='linear', bounds_error=False, fill_value="extrapolate")
        samples.append({
            "station": sid,
            "time": t,
            "lat": float(g[cols['lat']].iloc[0]),
            "lon": float(g[cols['lon']].iloc[0]),
            "clm": float(g[cols['clm']].iloc[0]),
            "pressure": pressure_levels,
            # TEM/DPT are assumed to be radiosonde observations; if the CSV
            # also holds FY-4B profiles, map them to separate keys upstream.
            "obs_T": f_temp(levels),
            "obs_Q": f_q(levels),
        })
    return samples
# ----------------------------
# Dataset for PyTorch
# ----------------------------
class ProfileDataset(Dataset):
    """PyTorch dataset of per-level input features and temperature-bias targets.

    Each item is a tuple ``(x, y, aux)``:
      * x   -- float32 tensor (2, levels): scaled [fy4b_T, fy4b_Q] per level;
      * y   -- float32 tensor (1, levels): temperature bias obs_T - fy4b_T;
      * aux -- float32 tensor (3,): [lat, lon, clm], unscaled.

    NOTE: when a sample dict has no 'fy4b_T' / 'fy4b_Q', the observed profile
    is used as the FY-4B input, so the target bias is identically zero.
    Provide real FY-4B profiles upstream for meaningful training.
    """

    def __init__(self, samples, scalers=None, use_aux=True):
        """
        Args:
            samples: Non-empty list of sample dicts (see
                read_and_assemble_profiles). Each needs 'obs_T', 'obs_Q',
                'lat', 'lon' and 'clm', and optionally 'fy4b_T' / 'fy4b_Q'.
                All profiles must have the same number of levels.
            scalers: Optional dict {'feat_scaler': fitted scaler}, used as-is.
                When None, a StandardScaler is fitted on this dataset's
                features pooled over all samples and levels.
            use_aux: Stored on the instance; not otherwise used here.

        Raises:
            ValueError: if ``samples`` is empty.
        """
        if not samples:
            raise ValueError("ProfileDataset needs at least one sample")
        self.samples = samples
        self.use_aux = use_aux
        X_list, Y_list, aux_list = [], [], []
        for s in samples:
            fy4b_T = s.get('fy4b_T', s['obs_T'])
            fy4b_Q = s.get('fy4b_Q', s['obs_Q'])
            # Input features per level: [fy4b_T, fy4b_Q] -> (levels, 2)
            X_list.append(np.stack([fy4b_T, fy4b_Q], axis=1).astype(np.float32))
            # Only the temperature bias is predicted (single output channel).
            y_T = (s['obs_T'] - fy4b_T).astype(np.float32)
            Y_list.append(y_T.reshape(-1, 1))
            aux_list.append(np.array([s['lat'], s['lon'], s['clm']], dtype=np.float32))
        self.X = np.stack(X_list)  # (N, levels, 2)
        self.Y = np.stack(Y_list)  # (N, levels, 1)
        self.aux = np.stack(aux_list)  # (N, 3)

        # Per-feature scaling, pooled over samples and levels.
        n, n_levels, n_feat = self.X.shape
        flat = self.X.reshape(n * n_levels, n_feat)
        if scalers is None:
            self.feat_scaler = StandardScaler()
            self.feat_scaler.fit(flat)
        else:
            self.feat_scaler = scalers['feat_scaler']
        # Cast back to float32: model parameters are float32 and a scaler may
        # return float64, which would make Conv1d fail with a dtype mismatch.
        self.X = self.feat_scaler.transform(flat).reshape(n, n_levels, n_feat).astype(np.float32, copy=False)
        self.target_scaler = None  # targets are not scaled

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        # Conv1d expects channels first, so return x as (feat, levels) and y as (1, levels).
        x = torch.from_numpy(self.X[idx]).permute(1, 0)
        y = torch.from_numpy(self.Y[idx]).permute(1, 0)
        aux = torch.from_numpy(self.aux[idx])
        return x, y, aux
# ----------------------------
# Model
# ----------------------------
class CNN_LSTM_Corrector(nn.Module):
    """Two Conv1d stages followed by a bidirectional LSTM over the level axis.

    Input to ``forward`` is a tuple ``(x, aux)`` with x of shape
    (batch, in_channels, levels); ``aux`` is accepted but not used.
    Output has shape (batch, 1, levels): one predicted bias value per level.
    """

    def __init__(self, in_channels=2, hidden_channels=64, lstm_hidden=64, levels=101):
        super().__init__()
        # conv -> batchnorm -> ReLU, twice; kernel 3 with padding 1 keeps the level count.
        self.conv1 = nn.Conv1d(in_channels, hidden_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(hidden_channels)
        self.conv2 = nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(hidden_channels)
        # Sequence-first LSTM: levels form the sequence, conv channels the features.
        self.lstm = nn.LSTM(
            input_size=hidden_channels,
            hidden_size=lstm_hidden,
            num_layers=1,
            bidirectional=True,
            batch_first=False,
        )
        # Per-level MLP head on the concatenated forward/backward LSTM states.
        self.fc1 = nn.Linear(2 * lstm_hidden, 64)
        self.fc2 = nn.Linear(64, 1)
        self.levels = levels  # recorded only; forward does not use it

    def forward(self, x_aux):
        """Map (x, aux) to a (batch, 1, levels) bias prediction."""
        features, _unused_aux = x_aux
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            features = torch.relu(norm(conv(features)))
        seq_in = features.permute(2, 0, 1)  # (levels, batch, hidden)
        seq_out, _state = self.lstm(seq_in)  # (levels, batch, 2*lstm_hidden)
        per_level = seq_out.transpose(0, 1)  # (batch, levels, 2*lstm_hidden)
        head = self.fc2(torch.relu(self.fc1(per_level)))  # (batch, levels, 1)
        return head.transpose(1, 2)  # (batch, 1, levels)
# ----------------------------
# Loss with smoothness constraint
# ----------------------------
def smoothness_loss(pred):
    """Mean squared level-to-level difference of a (batch, 1, levels) prediction.

    Penalises abrupt jumps between adjacent levels in the predicted bias.
    Returns a scalar tensor.
    """
    level_steps = torch.diff(pred, dim=2)
    return level_steps.square().mean()
# ----------------------------
# Training loop
# ----------------------------
def _combined_loss(model, batch, criterion, weight_smooth, device):
    """Run one (x, y, aux) batch; return (criterion + weight_smooth * smoothness, batch size)."""
    x, y, aux = (t.to(device) for t in batch)
    pred = model((x, aux))  # (batch, 1, levels)
    loss = criterion(pred, y) + weight_smooth * smoothness_loss(pred)
    return loss, x.size(0)


def train_model(model, train_loader, val_loader, epochs=50, lr=1e-3, weight_smooth=0.01, device='cpu', save_path='model.pth'):
    """Train with Adam on MSE + smoothness loss and checkpoint the best validation epoch.

    Args:
        model: Module called as ``model((x, aux))``.
        train_loader / val_loader: DataLoaders yielding (x, y, aux) batches.
        epochs: Number of training epochs.
        lr: Adam learning rate.
        weight_smooth: Weight of the smoothness penalty added to the MSE.
        device: Torch device string.
        save_path: Where the state_dict with the lowest validation loss is
            saved. A checkpoint is always written; if no epoch ever improves
            (e.g. all losses NaN, or ``epochs < 1``), the final weights are saved.

    Returns:
        The same ``model``, trained in place. It holds the final-epoch weights,
        not necessarily the best checkpoint written to ``save_path``.

    Raises:
        ValueError: if either loader's dataset is empty.
    """
    n_train = len(train_loader.dataset)
    n_val = len(val_loader.dataset)
    if n_train == 0 or n_val == 0:
        raise ValueError("train_loader and val_loader must both contain at least one sample")
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    # Start at +inf (not a large finite sentinel) so any finite first-epoch loss is saved.
    best_val = float('inf')
    saved = False
    for epoch in range(1, epochs + 1):
        model.train()
        train_loss = 0.0
        for batch in train_loader:
            optimizer.zero_grad()
            loss, n = _combined_loss(model, batch, criterion, weight_smooth, device)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * n  # weight by batch size for a per-sample mean
        train_loss /= n_train

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                loss, n = _combined_loss(model, batch, criterion, weight_smooth, device)
                val_loss += loss.item() * n
        val_loss /= n_val
        print(f"Epoch {epoch:03d} TrainLoss={train_loss:.6f} ValLoss={val_loss:.6f}")
        if val_loss < best_val:
            best_val = val_loss
            torch.save(model.state_dict(), save_path)
            print(f"Saved best model to {save_path}")
            saved = True
    # Callers reload save_path after training, so make sure it exists.
    if not saved:
        torch.save(model.state_dict(), save_path)
        print(f"No validation improvement recorded; saved final model to {save_path}")
    return model
# ----------------------------
# Main
# ----------------------------
def main(args):
    """Assemble profiles, train the corrector, evaluate it on a held-out split, and save metadata.

    Steps:
      1. Assemble profiles from ``args.data``.
      2. Split 80/20 into train+val / test, then 80/20 into train / val
         (random_state=42).
      3. Fit the feature scaler on the training split only.
      4. Train with ``train_model``, which writes the best checkpoint to
         ``args.save``.
      5. Reload that checkpoint and print test-set MSE, RMSE and mean error
         of the predicted bias.
      6. Pickle the scaler to ``<args.save stem>_meta.pkl``.

    Raises:
        RuntimeError: if fewer than 10 profiles survive assembly.
    """
    import math
    import pickle

    print("Reading and assembling profiles...")
    samples = read_and_assemble_profiles(args.data, pressure_levels=DEFAULT_PRESSURE_LEVELS, min_levels=20)
    print(f"Total samples: {len(samples)}")
    if len(samples) < 10:
        raise RuntimeError("Too few samples after assembling. Check CSV format and min_levels setting.")
    # 80/20 -> (train+val)/test, then 80/20 -> train/val.
    train_s, test_s = train_test_split(samples, test_size=0.2, random_state=42)
    train_s, val_s = train_test_split(train_s, test_size=0.2, random_state=42)
    train_ds = ProfileDataset(train_s)
    # Reuse the training-set scaler so val/test are scaled identically (no leakage).
    val_ds = ProfileDataset(val_s, scalers={"feat_scaler": train_ds.feat_scaler})
    test_ds = ProfileDataset(test_s, scalers={"feat_scaler": train_ds.feat_scaler})
    print("Dataset sizes:", len(train_ds), len(val_ds), len(test_ds))
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, drop_last=False)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = CNN_LSTM_Corrector(in_channels=2, hidden_channels=64, lstm_hidden=64, levels=DEFAULT_PRESSURE_LEVELS.size)
    print(model)
    train_model(model, train_loader, val_loader, epochs=args.epochs, lr=args.lr, weight_smooth=args.smooth, device=device, save_path=args.save)

    # Evaluate the best (lowest validation loss) checkpoint, not the last epoch.
    model.load_state_dict(torch.load(args.save, map_location=device))
    model.to(device).eval()
    sq_err_sum = 0.0
    err_sum = 0.0
    n_values = 0
    with torch.no_grad():
        for x, y, aux in test_loader:
            x, y, aux = x.to(device), y.to(device), aux.to(device)
            err = model((x, aux)) - y  # (batch, 1, levels): predicted minus target bias
            sq_err_sum += torch.sum(err ** 2).item()
            err_sum += torch.sum(err).item()
            n_values += err.numel()
    test_mse = sq_err_sum / n_values
    print(f"Test MSE (on bias): {test_mse:.6f}")
    print(f"Test RMSE (on bias): {math.sqrt(test_mse):.6f}")
    print(f"Test mean error, prediction - target (on bias): {err_sum / n_values:.6f}")

    # Save the feature scaler so inference can apply the same scaling.
    meta = {"feat_scaler": train_ds.feat_scaler}
    with open(os.path.splitext(args.save)[0] + "_meta.pkl", "wb") as f:
        pickle.dump(meta, f)
    print("Saved meta info.")
if __name__ == "__main__":
    # Command-line entry point: parse options, then run the full pipeline.
    cli = argparse.ArgumentParser()
    cli.add_argument("--data", default="fy4bt0305_vertical.csv", type=str, help="Path to CSV")
    cli.add_argument("--epochs", default=50, type=int)
    cli.add_argument("--batch_size", default=32, type=int)
    cli.add_argument("--lr", default=1e-3, type=float)
    cli.add_argument("--smooth", default=0.01, type=float, help="weight for smoothness loss")
    cli.add_argument("--save", default="cnn_lstm_corrector.pth", type=str)
    main(cli.parse_args())