# Source code follows; please revise it.
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Draw, rdMolDescriptors
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.loader import DataLoader
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
import os
import numpy as np
import random
# 设置随机种子以确保结果可复现
def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices), then puts cuDNN into deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


set_seed(42)
# ------------------------- 工具函数 -------------------------
def validate_atomic_nums(atomic_nums):
    """Round atom types to integers and replace unsupported elements with carbon.

    Args:
        atomic_nums: A scalar, list, NumPy array or torch tensor of (possibly
            fractional) atomic numbers. Any shape is accepted; the input is
            flattened.

    Returns:
        A flat list of plain Python ``int`` values. Every entry is one of
        1 (H), 6 (C), 7 (N), 8 (O) or 16 (S); anything else becomes 6.
    """
    valid_atoms = {1, 6, 7, 8, 16}  # H, C, N, O, S
    if isinstance(atomic_nums, torch.Tensor):
        # detach() so tensors that track gradients can be converted as well.
        atomic_nums = atomic_nums.detach().cpu().numpy()
    # atleast_1d makes scalars and 0-d arrays iterable (e.g. one-atom molecules).
    flat = np.atleast_1d(np.asarray(atomic_nums)).ravel()
    rounded = np.round(flat).astype(int).tolist()
    # tolist() yields built-in ints; NumPy integer scalars are not accepted by
    # RDKit's Chem.Atom constructor.
    return [num if num in valid_atoms else 6 for num in rounded]
def smiles_to_graph(smiles):
    """Convert a SMILES string into a PyTorch Geometric ``Data`` graph.

    Args:
        smiles: SMILES string to parse with RDKit.

    Returns:
        ``Data`` with ``x`` of shape ``(num_atoms, 1)`` holding atomic numbers
        (``torch.long``) and ``edge_index`` of shape ``(2, 2 * num_bonds)``
        listing every bond in both directions, or ``None`` when the SMILES
        cannot be parsed or sanitized.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    try:
        Chem.SanitizeMol(mol)
    except Exception as e:
        print(f"分子净化失败: {e}")
        return None
    atom_feats = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
    edges = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edges += [(i, j), (j, i)]
    x = torch.tensor(atom_feats, dtype=torch.long).view(-1, 1)
    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    else:
        # A bond-less molecule must still give a (2, 0) tensor; transposing an
        # empty 1-D tensor would leave shape (0,) and break shape[1] / batching.
        edge_index = torch.empty((2, 0), dtype=torch.long)
    return Data(x=x, edge_index=edge_index)
def augment_mol(smiles):
    """Return the SMILES of the molecule with explicit hydrogens added.

    Despite being used as "augmentation", this step is deterministic: it
    parses the SMILES, calls ``Chem.AddHs`` and writes the result back out
    with ``Chem.MolToSmiles``.

    Args:
        smiles: Input SMILES string.

    Returns:
        The rewritten SMILES, or the input unchanged when it cannot be parsed.
        If adding hydrogens fails, the SMILES of the parsed molecule without
        the extra hydrogens is returned.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return smiles
    try:
        mol = Chem.AddHs(mol)
    except Exception:
        # Best effort: keep the molecule without explicit hydrogens. Only
        # Exception is caught so KeyboardInterrupt/SystemExit still propagate.
        pass
    return Chem.MolToSmiles(mol)
def generate_realistic_edges(node_features, avg_edge_count=20, atomic_nums=None):
    """Propose bonds between generated atoms while respecting maximum valences.

    Candidate atom pairs are ranked by the dot-product similarity of their
    node features. The ``avg_edge_count`` best-ranked unordered pairs are
    visited in order, and a pair becomes a bond only if both atoms still have
    free valence. A bond between two non-hydrogen atoms that both have at
    least two free valences is promoted to a double bond with probability 0.3
    (drawn from ``np.random``).

    Args:
        node_features: Tensor whose first dimension indexes the atoms.
        avg_edge_count: Maximum number of candidate pairs to consider.
        atomic_nums: Atom types, one per node; passed through
            ``validate_atomic_nums``. If ``None`` no edges are produced.

    Returns:
        ``(edge_index, bond_types)``: ``edge_index`` is a ``(2, E)`` long
        tensor with each bond listed once (source index < target index), and
        ``bond_types`` is a length-``E`` long tensor of bond orders (1 or 2).
        Both are empty (``E == 0``) when nothing can be connected.
    """
    no_edges = (torch.empty((2, 0), dtype=torch.long),
                torch.empty(0, dtype=torch.long))
    num_nodes = node_features.shape[0]
    if num_nodes < 2 or atomic_nums is None:
        return no_edges

    atomic_nums = validate_atomic_nums(atomic_nums)
    max_valence = {1: 1, 6: 4, 7: 3, 8: 2, 16: 6}

    # The ranking needs no gradient; detach keeps autograd out of this step.
    feats = node_features.detach().reshape(num_nodes, -1).float()
    similarity = feats @ feats.t()
    similarity = similarity / (torch.max(similarity) + 1e-8)

    # Rank unordered pairs only (strict upper triangle). Ranking the full N x N
    # matrix spends the edge budget on self-pairs and mirrored duplicates.
    rows, cols = torch.triu_indices(num_nodes, num_nodes, offset=1,
                                    device=similarity.device)
    pair_scores = similarity[rows, cols]
    num_candidates = min(int(avg_edge_count), pair_scores.numel())
    if num_candidates <= 0:
        return no_edges
    ranked = torch.topk(pair_scores, k=num_candidates).indices.tolist()
    # Plain Python ints from here on: tensors hash by identity, so they cannot
    # be used as dictionary keys or set members for valence bookkeeping.
    rows, cols = rows.tolist(), cols.tolist()

    used_valence = [0] * num_nodes
    edges, bond_types = [], []
    for k in ranked:
        u, v = rows[k], cols[k]
        atom_u, atom_v = atomic_nums[u], atomic_nums[v]
        free_u = max_valence[atom_u] - used_valence[u]
        free_v = max_valence[atom_v] - used_valence[v]
        if free_u < 1 or free_v < 1:
            continue
        order = 1
        if (atom_u != 1 and atom_v != 1
                and free_u >= 2 and free_v >= 2
                and np.random.random() < 0.3):
            order = 2
        edges.append((u, v))
        bond_types.append(order)
        used_valence[u] += order
        used_valence[v] += order

    if not edges:
        return no_edges
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    return edge_index, torch.tensor(bond_types, dtype=torch.long)
def build_valid_mol(atomic_nums, edge_index, bond_types=None):
    """Assemble an RDKit molecule from atoms and bonds and sanitize it.

    Args:
        atomic_nums: Atom types; passed through ``validate_atomic_nums``.
        edge_index: ``(2, E)`` tensor of atom index pairs. Self-loops,
            out-of-range indices and repeated pairs (in either direction) are
            skipped. Anything that is not a 2-D tensor is treated as "no bonds".
        bond_types: Optional tensor or sequence aligned with the columns of
            ``edge_index``; a value of 2 makes a double bond, anything else
            (or a missing entry) a single bond.

    Returns:
        The sanitized ``Chem.RWMol``, or ``None`` if sanitization fails.
    """
    atomic_nums = validate_atomic_nums(atomic_nums)
    mol = Chem.RWMol()
    for num in atomic_nums:
        # int(): Chem.Atom does not accept NumPy integer scalars.
        mol.AddAtom(Chem.Atom(int(num)))

    if bond_types is None:
        orders = []
    elif isinstance(bond_types, torch.Tensor):
        orders = bond_types.reshape(-1).tolist()
    else:
        orders = list(bond_types)

    if isinstance(edge_index, torch.Tensor) and edge_index.dim() == 2:
        pairs = list(zip(edge_index[0].tolist(), edge_index[1].tolist()))
    else:
        pairs = []

    num_atoms = len(atomic_nums)
    added_edges = set()
    for j, (u, v) in enumerate(pairs):
        if not (0 <= u < num_atoms and 0 <= v < num_atoms):
            continue
        if u == v or (u, v) in added_edges:
            continue
        bond_type = Chem.BondType.SINGLE
        if j < len(orders) and orders[j] == 2:
            bond_type = Chem.BondType.DOUBLE
        mol.AddBond(u, v, bond_type)
        added_edges.add((u, v))
        added_edges.add((v, u))

    try:
        Chem.SanitizeMol(mol)
    except Exception as e:
        print(f"分子构建失败: {e}")
        return None
    return mol
def mol_to_graph(mol):
    """Convert an RDKit molecule into a PyTorch Geometric ``Data`` graph.

    Args:
        mol: RDKit molecule, or ``None``. Note that the molecule is sanitized
            in place.

    Returns:
        ``Data`` with ``x`` of shape ``(num_atoms, 1)`` holding atomic numbers
        (``torch.long``) and ``edge_index`` of shape ``(2, 2 * num_bonds)``
        listing every bond in both directions, or ``None`` when ``mol`` is
        ``None`` or cannot be sanitized.
    """
    if mol is None:
        return None
    try:
        Chem.SanitizeMol(mol)
    except Exception:
        # Only Exception is caught so KeyboardInterrupt/SystemExit propagate.
        return None
    atom_feats = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
    edges = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edges += [(i, j), (j, i)]
    x = torch.tensor(atom_feats, dtype=torch.long).view(-1, 1)
    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    else:
        # Keep the (2, 0) shape for bond-less molecules so shape[1] and
        # batching keep working.
        edge_index = torch.empty((2, 0), dtype=torch.long)
    return Data(x=x, edge_index=edge_index)
# ------------------------- 数据集 -------------------------
class MolecularGraphDataset(torch.utils.data.Dataset):
    """Molecular graphs paired with standardized (concentration, efficiency) conditions.

    Items are ``(graph, condition)`` tuples: ``graph.x`` holds atomic numbers
    as ``torch.long`` of shape ``(num_atoms, 1)``, with unsupported elements
    replaced by carbon, and ``condition`` is a float32 tensor of length 2.
    """

    def __init__(self, path):
        """Load the Excel sheet at ``path`` and build the graphs.

        The 2nd to 4th columns of the sheet are read as SMILES, concentration
        and efficiency. Rows with missing values are dropped, and rows whose
        SMILES cannot be converted to a graph are skipped.

        Args:
            path: Path to the Excel file.
        """
        df = pd.read_excel(path, usecols=[1, 2, 3])
        df.columns = ['SMILES', 'Concentration', 'Efficiency']
        # Reset the index so rows and scaled conditions share positions.
        df = df.dropna().reset_index(drop=True)
        df['SMILES'] = df['SMILES'].apply(augment_mol)
        self.graphs, self.conditions = [], []
        # Kept on the instance so scaled conditions can be mapped back later.
        self.scaler = StandardScaler()
        if len(df):
            cond_scaled = self.scaler.fit_transform(df[['Concentration', 'Efficiency']])
        else:
            # The scaler cannot be fitted on zero rows; leave the dataset empty.
            cond_scaled = np.empty((0, 2))
        self.edge_stats = {'edge_counts': []}
        valid_count = 0
        # Pair each row with its condition by position, not by index label.
        for smiles, cond in zip(df['SMILES'], cond_scaled):
            graph = smiles_to_graph(smiles)
            if graph is None:
                continue
            # view(-1) keeps a list even for one-atom molecules, where
            # squeeze().tolist() would collapse to a bare int.
            atomic_nums = validate_atomic_nums(graph.x.view(-1).tolist())
            graph.x = torch.tensor(atomic_nums, dtype=torch.long).view(-1, 1)
            if graph.edge_index.dim() != 2:
                # Bond-less molecule: normalise to an empty (2, 0) edge list.
                graph.edge_index = torch.empty((2, 0), dtype=torch.long)
            self.graphs.append(graph)
            self.conditions.append(torch.tensor(cond, dtype=torch.float32))
            # Each bond is stored in both directions, hence the division by 2.
            self.edge_stats['edge_counts'].append(graph.edge_index.shape[1] // 2)
            valid_count += 1
        print(f"成功加载 {valid_count} 个分子,过滤掉 {len(df) - valid_count} 个无效分子")
        self.avg_edge_count = int(np.mean(self.edge_stats['edge_counts'])) if self.edge_stats['edge_counts'] else 20
        print(f"真实图平均边数: {self.avg_edge_count}")

    def __len__(self):
        """Return the number of successfully loaded molecules."""
        return len(self.graphs)

    def __getitem__(self, idx):
        """Return the ``(graph, condition)`` pair at position ``idx``."""
        return self.graphs[idx], self.conditions[idx]
# ------------------------- 生成器 -------------------------
class GraphGenerator(nn.Module):
    """Map noise and a condition vector to the atom types of a 10-atom molecule.

    The network emits one score per (atom slot, atom type) pair for 10 slots
    and 5 atom types (H, C, N, O, S), and picks the best-scoring type per slot.
    """

    def __init__(self, noise_dim=16, condition_dim=2, hidden_dim=64):
        """Build the MLP.

        Args:
            noise_dim: Size of the noise vector.
            condition_dim: Size of the condition vector.
            hidden_dim: Width of the two hidden layers.
        """
        super().__init__()
        # The last layer emits raw scores. The per-atom softmax is applied in
        # forward(); an extra softmax across all 50 outputs does not change
        # which type wins per slot, it only flattens the scores.
        self.fc = nn.Sequential(
            nn.Linear(noise_dim + condition_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 10 * 5),  # 10 atom slots x 5 atom types
        )
        # Registered as a buffer so .to(device) moves it with the module;
        # non-persistent so the state_dict keys stay unchanged.
        self.register_buffer(
            "valid_atoms",
            torch.tensor([1, 6, 7, 8, 16], dtype=torch.long),
            persistent=False,
        )
        self.dropout = nn.Dropout(0.2)

    def forward(self, z, cond, straight_through=False):
        """Generate node features for a batch.

        Args:
            z: Noise of shape ``(batch, noise_dim)``.
            cond: Conditions of shape ``(batch, condition_dim)``.
            straight_through: When ``False`` (default) the output is built
                from an ``argmax`` lookup and carries no gradient. When
                ``True`` the same values are returned through a
                straight-through estimator, so gradients reach the network
                parameters; callers must then ``detach()`` the result before
                converting it to NumPy.

        Returns:
            Float tensor of shape ``(batch, 10, 1)`` holding
            ``atomic_number / 16`` for each atom slot, on the module's device.
        """
        x = torch.cat([z, cond], dim=1)
        x = self.dropout(x)
        logits = self.fc(x).view(-1, 10, 5)
        temperature = 0.8
        atom_probs = torch.softmax(logits / temperature, dim=-1)
        atom_indices = torch.argmax(atom_probs, dim=-1)
        atom_values = self.valid_atoms.to(atom_probs.dtype)
        if straight_through:
            hard = nn.functional.one_hot(
                atom_indices, num_classes=atom_probs.shape[-1]
            ).to(atom_probs.dtype)
            # Forward value equals the one-hot choice exactly (the bracket is
            # zero); the backward pass uses the gradient of atom_probs.
            selection = hard + (atom_probs - atom_probs.detach())
            node_feats = (selection * atom_values).sum(dim=-1, keepdim=True) / 16.0
        else:
            node_feats = atom_values[atom_indices].unsqueeze(-1) / 16.0
        return node_feats
# ------------------------- 判别器 -------------------------
class GraphDiscriminator(nn.Module):
    """GCN-based critic that scores a batch of molecular graphs given conditions.

    The learned score of each graph is shifted by a chemistry term computed
    with RDKit before the sigmoid is applied.

    NOTE(review): the chemistry term decodes atoms as ``round(x * 16)``, so
    ``data.x`` is assumed to hold ``atomic_number / 16``. Graphs that store raw
    atomic numbers would decode to unsupported values and be treated as carbon.
    """

    _MAX_VALENCE = {1: 1, 6: 4, 7: 3, 8: 2, 16: 6}

    def __init__(self, node_feat_dim=1, condition_dim=2):
        """Build the two GCN layers and the scoring head.

        Args:
            node_feat_dim: Number of features per node in ``data.x``.
            condition_dim: Size of the condition vector.
        """
        super().__init__()
        # +1 input channel for the per-node bond-order feature.
        self.conv1 = GCNConv(node_feat_dim + 1, 64)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = GCNConv(64, 32)
        self.bn2 = nn.BatchNorm1d(32)
        self.fc = nn.Sequential(
            nn.Linear(32 + condition_dim, 16),
            nn.LeakyReLU(0.2),
            nn.Linear(16, 1),
        )

    @staticmethod
    def _node_bond_features(x, edge_index, bond_types):
        """Sum the orders of the bonds touching each node into an ``(N, 1)`` feature.

        ``bond_types`` has one entry per edge, so it cannot be concatenated to
        the per-node features directly. Zeros are returned when bond types are
        missing or do not line up with ``edge_index``.
        """
        feats = x.new_zeros(x.shape[0])
        if bond_types is None or edge_index.dim() != 2:
            return feats.view(-1, 1)
        orders = bond_types.to(device=x.device, dtype=x.dtype).reshape(-1)
        if orders.numel() == 0 or orders.numel() != edge_index.shape[1]:
            return feats.view(-1, 1)
        feats.index_add_(0, edge_index[0], orders)
        feats.index_add_(0, edge_index[1], orders)
        return feats.view(-1, 1)

    @classmethod
    def _chemistry_adjustment(cls, subgraph):
        """Return the RDKit-based score shift for one graph of the batch.

        The graph is rebuilt as an RDKit molecule from its own atoms and edges.
        Rings and H-bond donors/acceptors raise the score, valence violations
        lower it; a molecule that fails sanitization gets ``-1.0``.
        """
        # detach(): the node features may be part of the generator's graph.
        decoded = (subgraph.x.detach().reshape(-1) * 16).cpu().numpy().round().astype(int)
        atomic_nums = validate_atomic_nums(decoded)
        mol = Chem.RWMol()
        for num in atomic_nums:
            # int(): Chem.Atom does not accept NumPy integer scalars.
            mol.AddAtom(Chem.Atom(int(num)))

        bond_types = getattr(subgraph, 'bond_types', None)
        orders = bond_types.reshape(-1).tolist() if bond_types is not None else []
        # Use this graph's own edges (locally indexed), not the batch-wide list.
        edge_index = subgraph.edge_index
        if edge_index.dim() == 2:
            pairs = list(zip(edge_index[0].tolist(), edge_index[1].tolist()))
        else:
            pairs = []

        num_atoms = len(atomic_nums)
        added_edges = set()
        for j, (u, v) in enumerate(pairs):
            if u >= num_atoms or v >= num_atoms or u == v or (u, v) in added_edges:
                continue
            bond_type = Chem.BondType.SINGLE
            if j < len(orders) and orders[j] == 2:
                bond_type = Chem.BondType.DOUBLE
            mol.AddBond(u, v, bond_type)
            added_edges.add((u, v))
            added_edges.add((v, u))

        try:
            Chem.SanitizeMol(mol)
            ring_count = rdMolDescriptors.CalcNumRings(mol)
            hbd = rdMolDescriptors.CalcNumHBD(mol)
            hba = rdMolDescriptors.CalcNumHBA(mol)
            valence_violations = 0
            for atom in mol.GetAtoms():
                limit = cls._MAX_VALENCE.get(atom.GetAtomicNum())
                if limit is not None and atom.GetExplicitValence() > limit:
                    valence_violations += 1
            ring_bonus = ring_count * 0.5
            fg_bonus = (hbd + hba) * 0.3
            valence_penalty = valence_violations * 1.0
            # NOTE(review): the bonuses used to be negated together with the
            # penalty, pushing ring-rich molecules towards the -1.0 given to
            # invalid ones. Bonuses now add and the penalty subtracts.
            return (ring_bonus + fg_bonus - valence_penalty) / 10.0
        except Exception as e:
            print(f"分子检查失败: {e}")
            return -1.0

    def forward(self, data, condition):
        """Score every graph in ``data``.

        Args:
            data: Batched graphs with ``x``, ``edge_index``, ``batch`` and,
                optionally, a per-edge ``bond_types`` attribute.
            condition: Tensor of shape ``(num_graphs, condition_dim)``.

        Returns:
            Tensor of shape ``(num_graphs, 1)`` with values in ``(0, 1)``.
        """
        x, edge_index, batch = data.x.float(), data.edge_index, data.batch
        bond_features = self._node_bond_features(
            x, edge_index, getattr(data, 'bond_types', None)
        )
        x = torch.cat([x, bond_features], dim=1)
        x = self.conv1(x, edge_index)
        x = self.bn1(x).relu()
        x = self.conv2(x, edge_index)
        x = self.bn2(x).relu()
        x = global_mean_pool(x, batch)
        x = torch.cat([x, condition], dim=1)
        raw_output = self.fc(x)
        # The chemistry term is computed outside autograd, so it shifts the
        # logit but contributes no gradient.
        mol_penalty = [
            self._chemistry_adjustment(data.get_example(i))
            for i in range(data.num_graphs)
        ]
        mol_penalty = torch.tensor(mol_penalty, dtype=torch.float32, device=x.device).view(-1, 1)
        return torch.sigmoid(raw_output + mol_penalty)
# ------------------------- 自定义批处理函数 -------------------------
def custom_collate(batch):
    """Merge ``(graph, condition)`` samples into one mini-batch.

    Args:
        batch: Sequence of ``(graph, condition)`` pairs.

    Returns:
        ``(batch_graphs, batch_conditions)``: the graphs combined with
        ``Batch.from_data_list`` and the conditions stacked along a new
        leading dimension.
    """
    graph_list = [sample[0] for sample in batch]
    condition_list = [sample[1] for sample in batch]
    return Batch.from_data_list(graph_list), torch.stack(condition_list)
# ------------------------- 训练流程 -------------------------
def _decode_atomic_nums(node_feats):
    """Decode node features stored as atomic number / 16 into validated atom types."""
    scaled = (node_feats.detach().reshape(-1) * 16).cpu().numpy().round().astype(int)
    return validate_atomic_nums(scaled)


def _propose_topology(node_feats, avg_edge_count):
    """Decode the atoms of one generated sample and propose bonds between them."""
    atomic_nums = _decode_atomic_nums(node_feats)
    edge_index, bond_types = generate_realistic_edges(
        node_feats.detach(), avg_edge_count, atomic_nums
    )
    if edge_index.dim() != 2:
        # Normalise "no edges" to the (2, 0) layout graph batching expects.
        edge_index = torch.empty((2, 0), dtype=torch.long)
    return atomic_nums, edge_index.long(), bond_types.long().reshape(-1)


def _scale_atom_feats(graph):
    """Rescale integer atomic numbers in ``graph.x`` to atomic number / 16."""
    # An integer dtype marks raw atomic numbers; float features are left alone.
    if not torch.is_floating_point(graph.x):
        graph.x = graph.x.float() / 16.0
    return graph


def _mol_to_scaled_graph(mol):
    """Convert an RDKit molecule to a graph on the atomic number / 16 scale, or None."""
    graph = mol_to_graph(mol)
    if graph is None:
        return None
    if graph.edge_index.dim() != 2:
        graph.edge_index = torch.empty((2, 0), dtype=torch.long)
    return _scale_atom_feats(graph)


def _has_atoms_and_bonds(mol):
    """Return True for a molecule that exists and has at least one atom and one bond."""
    return mol is not None and mol.GetNumAtoms() > 0 and mol.GetNumBonds() > 0


def train_gan(dataset_path, epochs=1750, batch_size=32, lr=1e-4, device=None):
    """Train the conditional molecular GAN and write artefacts to ``results/``.

    Every 10 epochs discriminator scores on a few real and generated
    molecules are printed. Every 100 epochs both models are checkpointed and
    a grid image of valid generated molecules is saved. A loss curve is saved
    at the end. All files go to the ``results`` directory under the current
    working directory; the working directory itself is not changed.

    Args:
        dataset_path: Excel file understood by ``MolecularGraphDataset``.
        epochs: Number of training epochs.
        batch_size: Mini-batch size.
        lr: Learning rate of both Adam optimizers.
        device: Torch device; defaults to CUDA when available, else CPU.

    Returns:
        ``(generator, discriminator)``, or ``(None, None)`` if the dataset is
        empty.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")
    dataset = MolecularGraphDataset(dataset_path)
    if len(dataset) == 0:
        print("错误:数据集为空,请检查SMILES格式或RDKit版本")
        return None, None
    # The plain PyTorch loader is used so that custom_collate is really applied;
    # to my knowledge the PyG loader discards a user-supplied collate_fn.
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate
    )
    generator = GraphGenerator().to(device)
    discriminator = GraphDiscriminator().to(device)
    g_opt = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    d_opt = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    loss_fn = nn.BCELoss()
    d_loss_history, g_loss_history = [], []
    out_dir = "results"
    os.makedirs(out_dir, exist_ok=True)
    noise_dim = 16
    avg_edges = dataset.avg_edge_count

    for epoch in range(1, epochs + 1):
        generator.train()
        discriminator.train()
        total_d_loss, total_g_loss = 0.0, 0.0
        for real_graphs, conds in dataloader:
            # Real graphs store raw atomic numbers; bring them to the same
            # atomic number / 16 scale as generated node features.
            real_batch = _scale_atom_feats(real_graphs.to(device))
            conds = conds.to(device)
            num_graphs = real_batch.num_graphs
            real_labels = torch.ones(num_graphs, 1, device=device)
            fake_labels = torch.zeros(num_graphs, 1, device=device)

            # ---- discriminator step ----
            d_real = discriminator(real_batch, conds)
            loss_real = loss_fn(d_real, real_labels)
            noise = torch.randn(num_graphs, noise_dim, device=device)
            fake_nodes = generator(noise, conds).detach()
            fake_mols, kept = [], []
            for i, node_feats in enumerate(fake_nodes):
                atomic_nums, edge_index, bond_types = _propose_topology(node_feats, avg_edges)
                mol = build_valid_mol(atomic_nums, edge_index, bond_types)
                graph = _mol_to_scaled_graph(mol) if _has_atoms_and_bonds(mol) else None
                if graph is not None:
                    fake_mols.append(graph)
                    kept.append(i)
                else:
                    print(
                        f"过滤非法分子:原子数={mol.GetNumAtoms() if mol else 0}, 键数={mol.GetNumBonds() if mol else 0}")
            if fake_mols:
                fake_batch = Batch.from_data_list(fake_mols).to(device)
                # Use the conditions of the samples that survived filtering,
                # not simply the first len(fake_mols) rows.
                kept_idx = torch.tensor(kept, dtype=torch.long, device=device)
                d_fake = discriminator(fake_batch, conds[kept_idx])
                loss_fake = loss_fn(d_fake, fake_labels[:len(fake_mols)])
                d_loss = loss_real + loss_fake
                d_opt.zero_grad()
                d_loss.backward()
                d_opt.step()
                total_d_loss += d_loss.item()

            # ---- generator step ----
            # NOTE(review): the generator only learns if its output carries a
            # gradient. Output selected through argmax does not, which makes
            # g_opt.step() a no-op; confirm against GraphGenerator.forward.
            noise = torch.randn(num_graphs, noise_dim, device=device)
            fake_nodes = generator(noise, conds)
            fake_graphs = []
            for node_feats in fake_nodes:
                _, edge_index, bond_types = _propose_topology(node_feats, avg_edges)
                fake_graphs.append(Data(
                    x=node_feats,
                    edge_index=edge_index,
                    bond_types=bond_types
                ))
            fake_batch = Batch.from_data_list(fake_graphs).to(device)
            g_loss = loss_fn(discriminator(fake_batch, conds), real_labels)
            g_opt.zero_grad()
            g_loss.backward()
            g_opt.step()
            total_g_loss += g_loss.item()

        avg_d_loss = total_d_loss / len(dataloader)
        avg_g_loss = total_g_loss / len(dataloader)
        d_loss_history.append(avg_d_loss)
        g_loss_history.append(avg_g_loss)

        if epoch % 10 == 0:
            generator.eval()
            discriminator.eval()
            with torch.no_grad():
                num_samples = 50
                noise = torch.randn(num_samples, noise_dim, device=device)
                eval_conds = torch.randn(num_samples, 2, device=device)
                fake_nodes = generator(noise, eval_conds)
                d_real_scores, d_fake_scores = [], []
                for i in range(min(num_samples, 10)):
                    real_idx = np.random.randint(0, len(dataset))
                    real_graph, real_cond = dataset[real_idx]
                    # Scale the freshly built batch, not the stored graph.
                    real_batch = _scale_atom_feats(
                        Batch.from_data_list([real_graph]).to(device)
                    )
                    real_cond = real_cond.unsqueeze(0).to(device)
                    d_real = discriminator(real_batch, real_cond)
                    d_real_scores.append(d_real.item())

                    atomic_nums, edge_index, bond_types = _propose_topology(
                        fake_nodes[i], avg_edges
                    )
                    mol = build_valid_mol(atomic_nums, edge_index, bond_types)
                    fake_graph = _mol_to_scaled_graph(mol) if mol is not None else None
                    if fake_graph is not None:
                        fake_batch = Batch.from_data_list([fake_graph]).to(device)
                        fake_cond = eval_conds[i].unsqueeze(0)
                        d_fake = discriminator(fake_batch, fake_cond)
                        d_fake_scores.append(d_fake.item())
            print(f"Epoch {epoch}: D_loss={avg_d_loss:.4f}, G_loss={avg_g_loss:.4f}")
            if d_real_scores and d_fake_scores:
                print(f"D_real评分范围: {min(d_real_scores):.4f} - {max(d_real_scores):.4f}")
                print(f"D_fake评分范围: {min(d_fake_scores):.4f} - {max(d_fake_scores):.4f}")
            else:
                print("警告:生成的分子均不合法,无法评估D_fake评分")
            generator.train()
            discriminator.train()

        if epoch % 100 == 0:
            torch.save(generator.state_dict(),
                       os.path.join(out_dir, f"generator_epoch_{epoch}.pt"))
            torch.save(discriminator.state_dict(),
                       os.path.join(out_dir, f"discriminator_epoch_{epoch}.pt"))
            generator.eval()
            with torch.no_grad():
                num_samples = 25
                noise = torch.randn(num_samples, noise_dim, device=device)
                sample_conds = torch.randn(num_samples, 2, device=device)
                fake_nodes = generator(noise, sample_conds)
                generated_mols = []
                for node_feats in fake_nodes:
                    atomic_nums, edge_index, bond_types = _propose_topology(
                        node_feats, avg_edges
                    )
                    mol = build_valid_mol(atomic_nums, edge_index, bond_types)
                    if _has_atoms_and_bonds(mol):
                        mol_weight = rdMolDescriptors.CalcExactMolWt(mol)
                        # CalcCrippenDescriptors returns (logP, MR).
                        logp = rdMolDescriptors.CalcCrippenDescriptors(mol)[0]
                        tpsa = rdMolDescriptors.CalcTPSA(mol)
                        generated_mols.append((mol, mol_weight, logp, tpsa))
            if generated_mols:
                generated_mols.sort(key=lambda item: item[1])
                mols = [mol for mol, _, _, _ in generated_mols]
                legends = [
                    f"Molecule {i + 1}\nMW: {mw:.2f}\nLogP: {logp:.2f}\nTPSA: {tpsa:.2f}"
                    for i, (_, mw, logp, tpsa) in enumerate(generated_mols)
                ]
                img = Draw.MolsToGridImage(
                    mols,
                    molsPerRow=5,
                    subImgSize=(200, 200),
                    legends=legends
                )
                img.save(os.path.join(out_dir, f"generated_molecules_epoch_{epoch}.png"))
                print(f"Epoch {epoch}: 保存了 {len(generated_mols)} 个合法分子的可视化结果")
            else:
                print(f"Epoch {epoch}: 未生成任何合法分子")
            generator.train()

    plt.figure(figsize=(10, 5))
    plt.plot(d_loss_history, label='Discriminator Loss')
    plt.plot(g_loss_history, label='Generator Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('GAN Loss Curve')
    plt.legend()
    plt.savefig(os.path.join(out_dir, 'gan_loss_curve.png'))
    plt.close()
    return generator, discriminator
# ------------------------- 主函数 -------------------------
if __name__ == "__main__":
    import sys

    print("=" * 80)
    print("基于SMILES验证的分子生成GAN系统")
    print("=" * 80)
    # An optional first command-line argument overrides the default dataset path.
    default_dataset_path = "D:\\python\\pythonProject1\\DATA\\data-8.xlsx"
    dataset_path = sys.argv[1] if len(sys.argv) > 1 else default_dataset_path
    if not os.path.exists(dataset_path):
        print(f"错误:数据集文件 '{dataset_path}' 不存在!")
        print("请准备一个Excel文件,包含SMILES、Concentration、Efficiency三列数据。")
        # sys.exit instead of the site-provided exit(), which may be absent.
        sys.exit(1)
    print(f"开始加载数据集: {dataset_path}")
    print("=" * 80)
    print("开始训练分子生成GAN...")
    generator, discriminator = train_gan(
        dataset_path=dataset_path,
        epochs=1750,
        batch_size=32,
        lr=1e-4,
    )
    print("=" * 80)
    # train_gan returns (None, None) when training could not start.
    if generator is not None and discriminator is not None:
        print("训练完成!")
        print("生成器和判别器模型已保存到'results'目录")
        print("生成的分子可视化结果已保存到'results'目录")
        print("损失曲线已保存为'gan_loss_curve.png'")
    else:
        print("训练未能启动:数据集为空或存在其他问题")
    print("=" * 80)