Using gaussian noise, generated 1000 molecules from the reference molecule; valid after initial filtering: 1000
After further filtering, molecules containing a benzene ring: 939
Benzene-ring ratio: 93.90%
[15:31:55] WARNING: could not find number of expected rings. Switching to an approximate ring finding algorithm.
After further filtering, single-fragment molecules: 0
Single-fragment ratio: 0.00%
No molecules to save
Generating new molecules from the reference molecule (uniform noise, benzene ring enforced)...
[15:31:55] WARNING: could not find number of expected rings. Switching to an approximate ring finding algorithm.
[15:31:55] WARNING: could not find number of expected rings. Switching to an approximate ring finding algorithm.
[15:31:55] WARNING: could not find number of expected rings. Switching to an approximate ring finding algorithm.
Using uniform noise, generated 1000 molecules from the reference molecule; valid after initial filtering: 1000
[15:31:56] WARNING: could not find number of expected rings. Switching to an approximate ring finding algorithm.
After further filtering, molecules containing a benzene ring: 957
Benzene-ring ratio: 95.70%
[15:31:56] WARNING: could not find number of expected rings. Switching to an approximate ring finding algorithm.
[15:31:56] WARNING: could not find number of expected rings. Switching to an approximate ring finding algorithm.
After further filtering, single-fragment molecules: 1
Single-fragment ratio: 0.10%

This is the output from running the code below.

import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Draw, rdMolDescriptors
from rdkit.Chem import QED  # QED lives in its own module, not in rdMolDescriptors
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv, global_mean_pool, GATConv, GraphConv
from torch_geometric.loader import DataLoader
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
import os
import numpy as np
import random
from functools import lru_cache
from torch.cuda import amp
# Set random seeds for reproducibility
def set_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
set_seed(42)
# ------------------------- Utility functions -------------------------
def validate_atomic_nums(atomic_nums):
valid_atoms = {1, 6, 7, 8, 16, 9, 17} # H, C, N, O, S, F, Cl
if isinstance(atomic_nums, torch.Tensor):
atomic_nums = atomic_nums.cpu().numpy() if atomic_nums.is_cuda else atomic_nums.numpy()
if isinstance(atomic_nums, list):
atomic_nums = np.array(atomic_nums)
if atomic_nums.ndim > 1:
atomic_nums = atomic_nums.flatten()
atomic_nums = np.round(atomic_nums).astype(int)
return [num if num in valid_atoms else 6 for num in atomic_nums]
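# Illustrative example (not part of the original script): unsupported elements fall
# back to carbon, e.g. validate_atomic_nums([6, 15, 8]) -> [6, 6, 8] (P replaced by C).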
def smiles_to_graph(smiles):
mol = Chem.MolFromSmiles(smiles)
if mol is None:
return None
Chem.SanitizeMol(mol)
atom_feats = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
edge_index = []
for bond in mol.GetBonds():
i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
edge_index += [(i, j), (j, i)]
x = torch.tensor(atom_feats, dtype=torch.long).view(-1, 1)
edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
return Data(x=x, edge_index=edge_index)
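# Illustrative sketch (assumes RDKit parses the SMILES): benzene yields a graph with
# 6 nodes and 12 directed edge entries, since each bond is stored in both directions.
#   g = smiles_to_graph("c1ccccc1")
#   assert g.x.shape == (6, 1) and g.edge_index.shape == (2, 12)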
def augment_mol(smiles):
mol = Chem.MolFromSmiles(smiles)
if mol:
try:
mol = Chem.AddHs(mol)
except:
pass
return Chem.MolToSmiles(mol)
return smiles
# ------------------------- Enhanced edge-generation algorithm (improved) -------------------------
def generate_realistic_edges_improved(node_features, avg_edge_count=20, atomic_nums=None, force_min_edges=True):
if node_features is None or node_features.numel() == 0:
return torch.empty((2, 0), dtype=torch.long), torch.tensor([])
num_nodes = node_features.shape[0]
if num_nodes < 2:
return torch.empty((2, 0), dtype=torch.long), torch.tensor([])
atomic_nums = validate_atomic_nums(atomic_nums)
atomic_nums = [int(num) for num in atomic_nums]
max_valence = {1: 1, 6: 4, 7: 3, 8: 2, 16: 6, 9: 1, 17: 1}
    # Compute node similarity and add a connection bias
similarity = torch.matmul(node_features, node_features.transpose(0, 1)).squeeze()
    # Extra bias toward connecting nodes, to keep the molecule whole (avoid isolated atoms)
connectivity_bias = torch.ones_like(similarity) - torch.eye(num_nodes, dtype=torch.long)
similarity = similarity + 0.5 * connectivity_bias.float()
norm_similarity = similarity / (torch.max(similarity) + 1e-8)
edge_probs = norm_similarity.view(-1)
num_possible_edges = num_nodes * (num_nodes - 1)
num_edges = min(avg_edge_count, num_possible_edges)
_, indices = torch.topk(edge_probs, k=num_edges)
edge_set = set()
edge_index = []
used_valence = {i: 0 for i in range(num_nodes)}
bond_types = []
tried_edges = set()
    # Guarantee at least one bond, wiring up the benzene ring first
if num_nodes >= 2 and force_min_edges:
if num_nodes >= 6 and all(atomic_nums[i] == 6 for i in range(6)):
            # Connect the six ring atoms first as the core scaffold
for i in range(6):
j = (i + 1) % 6
if (i, j) not in tried_edges:
edge_set.add((i, j))
edge_set.add((j, i))
edge_index.append([i, j])
edge_index.append([j, i])
used_valence[i] += 1
used_valence[j] += 1
bond_types.append(2 if i % 2 == 0 else 1)
bond_types.append(2 if i % 2 == 0 else 1)
tried_edges.add((i, j))
tried_edges.add((j, i))
            # Attach the ring to the first non-ring atom (keeps the molecule in one piece)
if num_nodes > 6:
first_non_benzene = 6
                # Pick the ring atom most similar to the first non-ring atom
max_sim = -1
best_benzene = 0
for i in range(6):
sim = norm_similarity[i, first_non_benzene]
if sim > max_sim:
max_sim = sim
best_benzene = i
if (best_benzene, first_non_benzene) not in tried_edges:
edge_set.add((best_benzene, first_non_benzene))
edge_set.add((first_non_benzene, best_benzene))
edge_index.append([best_benzene, first_non_benzene])
edge_index.append([first_non_benzene, best_benzene])
used_valence[best_benzene] += 1
used_valence[first_non_benzene] += 1
bond_types.append(1)
bond_types.append(1)
tried_edges.add((best_benzene, first_non_benzene))
tried_edges.add((first_non_benzene, best_benzene))
        else:
            # Minimal connection for an ordinary molecule (ensures connectivity)
max_sim = -1
best_u, best_v = 0, 1
for i in range(num_nodes):
for j in range(i + 1, num_nodes):
if norm_similarity[i, j] > max_sim and (i, j) not in tried_edges:
max_sim = norm_similarity[i, j]
best_u, best_v = i, j
edge_set.add((best_u, best_v))
edge_set.add((best_v, best_u))
edge_index.append([best_u, best_v])
edge_index.append([best_v, best_u])
used_valence[best_u] += 1
used_valence[best_v] += 1
bond_types.append(1)
bond_types.append(1)
tried_edges.add((best_u, best_v))
tried_edges.add((best_v, best_u))
    # Improved edge generation that favors connecting not-yet-connected parts
for idx in indices:
u, v = (idx // num_nodes).item(), (idx % num_nodes).item()
if u == v or (u, v) in tried_edges or (v, u) in tried_edges:
continue
tried_edges.add((u, v))
tried_edges.add((v, u))
atom_u, atom_v = atomic_nums[u], atomic_nums[v]
max_u, max_v = max_valence[atom_u], max_valence[atom_v]
avail_u, avail_v = max_u - used_valence[u], max_v - used_valence[v]
        # Disallow H-H bonds
if atom_u == 1 and atom_v == 1:
continue
        # Keep adding edges while below the n-1 bonds needed to connect n nodes
if len(edge_set) < num_nodes - 1:
edge_set.add((u, v))
edge_set.add((v, u))
edge_index.append([u, v])
edge_index.append([v, u])
used_valence[u] += 1
used_valence[v] += 1
bond_types.append(1)
bond_types.append(1)
continue
        # Regular connection logic
if avail_u >= 1 and avail_v >= 1:
edge_set.add((u, v))
edge_set.add((v, u))
edge_index.append([u, v])
edge_index.append([v, u])
used_valence[u] += 1
used_valence[v] += 1
bond_types.append(1)
bond_types.append(1)
continue
        # Occasionally upgrade to a double bond (adds molecular stability)
if atom_u != 1 and atom_v != 1 and avail_u >= 2 and avail_v >= 2 and random.random() < 0.3:
edge_set.add((u, v))
edge_set.add((v, u))
edge_index.append([u, v])
edge_index.append([v, u])
used_valence[u] += 2
used_valence[v] += 2
bond_types.append(2)
bond_types.append(2)
    # Final connectivity fallback
if len(edge_index) == 0 and num_nodes >= 2:
edge_index = [[0, 1], [1, 0]]
bond_types = [1, 1]
used_valence[0], used_valence[1] = 1, 1
return torch.tensor(edge_index).t().contiguous(), torch.tensor(bond_types)
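# Illustrative usage (sketch; the feature values are arbitrary, normalized as in the
# generator below): build edges for 8 nodes whose first 6 atoms are carbon.
#   feats = torch.rand(8, 1)
#   nums = [6, 6, 6, 6, 6, 6, 8, 7]
#   ei, bt = generate_realistic_edges_improved(feats, avg_edge_count=10, atomic_nums=nums)
#   # ei has shape (2, E) with both directions of every bond; bt holds 1/2 bond orders.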
# Note: lru_cache hashes the Mol object itself (identity-based), so these caches only
# pay off when the very same Mol instance is queried repeatedly.
@lru_cache(maxsize=128)
def cached_calculate_tpsa(mol):
try:
return rdMolDescriptors.CalcTPSA(mol) if mol else 0
except:
return 0
@lru_cache(maxsize=128)
def cached_calculate_ring_penalty(mol):
try:
if mol:
ssr = Chem.GetSymmSSSR(mol)
return len(ssr) * 0.2
return 0
except:
return 0
# ------------------------- Gumbel-Softmax implementation -------------------------
def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
gumbels = -(torch.empty_like(logits).exponential_() + eps).log()
gumbels = (logits + gumbels) / tau
y_soft = gumbels.softmax(dim)
if hard:
index = y_soft.max(dim, keepdim=True)[1]
y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
ret = y_hard - y_soft.detach() + y_soft
else:
ret = y_soft
return ret
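# Sketch of the straight-through behaviour (illustrative): with hard=True the forward
# pass is one-hot while gradients still flow through the soft sample.
#   logits = torch.randn(2, 5, requires_grad=True)
#   y = gumbel_softmax(logits, tau=0.8, hard=True)
#   assert torch.allclose(y.sum(dim=-1), torch.ones(2))  # one-hot rows, y still differentiable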
# ------------------------- Core optimization: stronger chemical constraints (improved) -------------------------
def build_valid_mol_improved(atomic_nums, edge_index, bond_types=None):
atomic_nums = validate_atomic_nums(atomic_nums)
atomic_nums = [int(num) for num in atomic_nums]
bond_type_map = {1: Chem.BondType.SINGLE, 2: Chem.BondType.DOUBLE}
mol = Chem.RWMol()
atom_indices = []
for num in atomic_nums:
atom = Chem.Atom(num)
idx = mol.AddAtom(atom)
atom_indices.append(idx)
added_edges = set()
if bond_types is not None:
if isinstance(bond_types, torch.Tensor):
bond_types = bond_types.cpu().numpy().tolist() if bond_types.is_cuda else bond_types.numpy().tolist()
else:
bond_types = []
max_valence = {1: 1, 6: 4, 7: 3, 8: 2, 16: 6, 9: 1, 17: 1}
current_valence = {i: 0 for i in range(len(atomic_nums))}
if edge_index.numel() > 0:
        # Rank edges so bonds that keep the molecule in one fragment come first (ring links prioritized)
edge_importance = []
for j in range(edge_index.shape[1]):
u, v = edge_index[0, j].item(), edge_index[1, j].item()
            # Edge importance: ring-to-substituent > intra-ring > everything else
importance = 2.0 if (u < 6 and v >= 6) or (v < 6 and u >= 6) else 1.5 if (u < 6 and v < 6) else 1.0
edge_importance.append((importance, j))
        # Sort by importance, descending
edge_order = [j for _, j in sorted(edge_importance, key=lambda x: x[0], reverse=True)]
for j in edge_order:
u, v = edge_index[0, j].item(), edge_index[1, j].item()
if u >= len(atomic_nums) or v >= len(atomic_nums) or u == v or (u, v) in added_edges:
continue
atom_u = mol.GetAtomWithIdx(atom_indices[u])
atom_v = mol.GetAtomWithIdx(atom_indices[v])
max_u = max_valence[atomic_nums[u]]
max_v = max_valence[atomic_nums[v]]
used_u = current_valence[u]
used_v = current_valence[v]
remain_u = max_u - used_u
remain_v = max_v - used_v
if remain_u >= 1 and remain_v >= 1:
bond_type = bond_type_map.get(bond_types[j] if j < len(bond_types) else 1, Chem.BondType.SINGLE)
bond_order = 1 if bond_type == Chem.BondType.SINGLE else 2
if bond_order > min(remain_u, remain_v):
bond_order = 1
bond_type = Chem.BondType.SINGLE
mol.AddBond(atom_indices[u], atom_indices[v], bond_type)
added_edges.add((u, v))
added_edges.add((v, u))
current_valence[u] += bond_order
current_valence[v] += bond_order
    # More robust valence repair (prefers keeping the molecule connected)
try:
Chem.SanitizeMol(mol)
return mol
except:
        # First try adding hydrogens (keeps the molecule intact)
try:
mol = Chem.AddHs(mol)
Chem.SanitizeMol(mol)
return mol
except:
pass
        # Selectively remove only the bonds that break valence rules
try:
problematic_bonds = []
for bond in mol.GetBonds():
begin = bond.GetBeginAtom()
end = bond.GetEndAtom()
if begin.GetExplicitValence() > max_valence[begin.GetAtomicNum()] or \
end.GetExplicitValence() > max_valence[end.GetAtomicNum()]:
problematic_bonds.append(bond.GetIdx())
            # Remove double bonds first, then links outside the ring
problematic_bonds.sort(key=lambda idx: (
mol.GetBondWithIdx(idx).GetBondTypeAsDouble(),
-1 if (mol.GetBondWithIdx(idx).GetBeginAtomIdx() < 6 and
mol.GetBondWithIdx(idx).GetEndAtomIdx() >= 6) else 1
), reverse=True)
for bond_idx in problematic_bonds:
mol.RemoveBond(mol.GetBondWithIdx(bond_idx).GetBeginAtomIdx(),
mol.GetBondWithIdx(bond_idx).GetEndAtomIdx())
try:
Chem.SanitizeMol(mol)
break
except:
continue
Chem.SanitizeMol(mol)
return mol
except:
            # Last resort: strip hydrogens
try:
mol = Chem.RemoveHs(mol)
Chem.SanitizeMol(mol)
return mol
except:
return None
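# Illustrative sketch: rebuilding ethane from raw arrays (values are arbitrary).
#   ei = torch.tensor([[0, 1], [1, 0]])
#   m = build_valid_mol_improved([6, 6], ei, torch.tensor([1, 1]))
#   # m is an RWMol; Chem.MolToSmiles(m) should give "CC" after sanitization.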
def mol_to_graph(mol):
if not mol:
return None
try:
Chem.SanitizeMol(mol)
atom_feats = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
edge_index = []
for bond in mol.GetBonds():
i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
edge_index += [(i, j), (j, i)]
x = torch.tensor(atom_feats, dtype=torch.long).view(-1, 1)
edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
return Data(x=x, edge_index=edge_index)
except:
return None
def filter_valid_mols(mols):
valid_mols = []
for mol in mols:
if mol is None:
continue
try:
Chem.SanitizeMol(mol)
rdMolDescriptors.CalcExactMolWt(mol)
valid_mols.append(mol)
except:
continue
return valid_mols
# ------------------------- New: fragment-filtering utility (improved) -------------------------
def filter_single_fragment_mols_improved(mol_list):
valid_mols = []
for mol in mol_list:
if mol is None:
continue
try:
            # Add hydrogens first so the molecule is complete
mol = Chem.AddHs(mol)
Chem.SanitizeMol(mol)
            # Count fragments robustly
frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
if len(frags) == 1:
                # Also require a minimum size (filters out tiny molecules)
mol_weight = rdMolDescriptors.CalcExactMolWt(mol)
if mol_weight > 50 and mol.GetNumAtoms() > 5:
valid_mols.append(mol)
except:
continue
return valid_mols
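# Illustrative check (sketch): ethanol is one fragment but fails the size gate above
# (MW ~46 < 50), while a two-fragment input such as "CCO.O" is rejected outright.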
# ------------------------- New: benzene-ring detection -------------------------
def has_benzene_ring(mol):
if mol is None:
return False
try:
benzene_smarts = Chem.MolFromSmarts("c1ccccc1")
return mol.HasSubstructMatch(benzene_smarts)
except Exception as e:
print(f"检测苯环失败: {e}")
return False
def filter_benzene_mols(mols):
return [mol for mol in mols if has_benzene_ring(mol)]
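# Illustrative sketch: the aromatic SMARTS matches phenol but not cyclohexane.
#   assert has_benzene_ring(Chem.MolFromSmiles("c1ccccc1O"))
#   assert not has_benzene_ring(Chem.MolFromSmiles("C1CCCCC1"))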
# ------------------------- Dataset -------------------------
class MolecularGraphDataset(torch.utils.data.Dataset):
def __init__(self, path):
df = pd.read_excel(path, usecols=[0, 1, 2])
df.columns = ['SMILES', 'Concentration', 'Efficiency']
        df.dropna(inplace=True)
        df.reset_index(drop=True, inplace=True)  # realign the index so cond_scaled[i] below matches row i
df['SMILES'] = df['SMILES'].apply(augment_mol)
self.graphs, self.conditions = [], []
scaler = StandardScaler()
cond_scaled = scaler.fit_transform(df[['Concentration', 'Efficiency']])
self.edge_stats = {'edge_counts': []}
valid_count = 0
for i, row in df.iterrows():
graph = smiles_to_graph(row['SMILES'])
if graph:
atomic_nums = graph.x.squeeze().tolist()
atomic_nums = validate_atomic_nums(atomic_nums)
graph.x = torch.tensor(atomic_nums, dtype=torch.long).view(-1, 1)
self.graphs.append(graph)
self.conditions.append(torch.tensor(cond_scaled[i], dtype=torch.float32))
self.edge_stats['edge_counts'].append(graph.edge_index.shape[1] // 2)
valid_count += 1
print(f"成功加载 {valid_count} 个分子")
self.avg_edge_count = int(np.mean(self.edge_stats['edge_counts'])) if self.edge_stats['edge_counts'] else 20
print(f"真实图平均边数: {self.avg_edge_count}")
def __len__(self):
return len(self.graphs)
def __getitem__(self, idx):
return self.graphs[idx], self.conditions[idx]
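# Illustrative usage (sketch; "inhibitors.xlsx" is a placeholder path):
#   ds = MolecularGraphDataset("inhibitors.xlsx")
#   graph, cond = ds[0]   # a torch_geometric Data object and a scaled 2-vector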
# ------------------------- Enhanced generator (with benzene-ring conditioning) -------------------------
class EnhancedGraphGenerator(nn.Module):
def __init__(self, noise_dim=16, condition_dim=2, benzene_condition_dim=1, hidden_dim=128, num_atoms=15):
super().__init__()
self.num_atoms = num_atoms
        self.benzene_embedding = nn.Embedding(2, hidden_dim)  # benzene-condition embedding
        # Process noise and conditions
self.initial_transform = nn.Sequential(
nn.Linear(noise_dim + condition_dim + hidden_dim, hidden_dim),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(hidden_dim),
nn.Linear(hidden_dim, hidden_dim)
)
        # Dedicated benzene-ring generation path
self.benzene_generator = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 6 * 7)  # 6 ring atoms x 7 possible atom types
)
        # Generation path for the rest of the molecule
self.rest_generator = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, (num_atoms - 6) * 7)  # remaining atoms
)
        # Pairwise connection prediction
self.edge_predictor = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, num_atoms * num_atoms)  # pairwise connection probabilities
)
self.tau = 0.8
        # Registered as a buffer so it follows the module to the GPU; a plain tensor
        # attribute would stay on the CPU and break the indexing in forward() under CUDA.
        self.register_buffer('valid_atoms', torch.tensor([1, 6, 7, 8, 16, 9, 17], dtype=torch.long))
def forward(self, z, cond, benzene_condition):
        # Embed the benzene condition
benzene_feat = self.benzene_embedding(benzene_condition)
        # Concatenate noise and conditions
x = torch.cat([z, cond, benzene_feat], dim=1)
shared_repr = self.initial_transform(x)
        # Generate the ring part (the first 6 atoms)
benzene_logits = self.benzene_generator(shared_repr).view(-1, 6, 7)
benzene_probs = gumbel_softmax(benzene_logits, tau=self.tau, hard=False, dim=-1)
        # Force the first 6 atoms to be carbon (the benzene ring)
        benzene_indices = torch.ones_like(benzene_probs.argmax(dim=-1))  # index 1 in valid_atoms is carbon (6)
benzene_nodes = self.valid_atoms[benzene_indices].float().view(-1, 6, 1) / 17.0
        # Generate the rest of the molecule
if self.num_atoms > 6:
rest_logits = self.rest_generator(shared_repr).view(-1, self.num_atoms - 6, 7)
rest_probs = gumbel_softmax(rest_logits, tau=self.tau, hard=False, dim=-1)
rest_indices = rest_probs.argmax(dim=-1)
rest_nodes = self.valid_atoms[rest_indices].float().view(-1, self.num_atoms - 6, 1) / 17.0
            # Concatenate the ring and the rest
node_feats = torch.cat([benzene_nodes, rest_nodes], dim=1)
else:
node_feats = benzene_nodes
return node_feats
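# Shape sketch (illustrative, batch of B with the default hidden_dim=128): z (B,16) +
# cond (B,2) + benzene embedding (B,128) -> shared_repr (B,128); ring logits (B,6,7)
# and rest logits (B,num_atoms-6,7) become atomic numbers normalized by 17, so
# node_feats has shape (B, num_atoms, 1).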
# ------------------------- Enhanced discriminator -------------------------
class EnhancedGraphDiscriminator(nn.Module):
def __init__(self, node_feat_dim=1, condition_dim=2):
super().__init__()
        # Graph convolution layers
self.conv1 = GCNConv(node_feat_dim + 1, 64)
self.bn1 = nn.BatchNorm1d(64)
self.conv2 = GATConv(64, 32, heads=3, concat=False)
self.bn2 = nn.BatchNorm1d(32)
self.conv3 = GraphConv(32, 16)
self.bn3 = nn.BatchNorm1d(16)
        # Attention head for pooling
self.attention = nn.Sequential(
nn.Linear(16, 8),
nn.Tanh(),
nn.Linear(8, 1)
)
        # Condition processing
self.condition_processor = nn.Sequential(
nn.Linear(condition_dim, 16),
nn.LeakyReLU(0.2)
)
        # Classifier
self.classifier = nn.Sequential(
nn.Linear(16 + 16, 16),
nn.LeakyReLU(0.2),
nn.Linear(16, 1)
)
def forward(self, data, condition):
x, edge_index, batch = data.x.float(), data.edge_index, data.batch
        # Extract bond-type features
bond_features = torch.zeros(x.size(0), 1).to(x.device)
if hasattr(data, 'bond_types') and data.bond_types is not None:
bond_types = data.bond_types
for i in range(edge_index.size(1)):
u, v = edge_index[0, i].item(), edge_index[1, i].item()
bond_type = bond_types[i] if i < len(bond_types) else 1
bond_features[u] = max(bond_features[u], bond_type)
bond_features[v] = max(bond_features[v], bond_type)
        # Concatenate atom and bond features
x = torch.cat([x, bond_features], dim=1)
        # Graph convolutions
x = self.conv1(x, edge_index)
x = self.bn1(x).relu()
x = self.conv2(x, edge_index)
x = self.bn2(x).relu()
x = self.conv3(x, edge_index)
x = self.bn3(x).relu()
        # Attention pooling (note: this softmax runs over every node in the batch, not per graph)
attn_weights = self.attention(x).squeeze()
attn_weights = torch.softmax(attn_weights, dim=0)
x = global_mean_pool(x * attn_weights.unsqueeze(-1), batch)
        # Process the condition
cond_repr = self.condition_processor(condition)
        # Concatenate graph representation and condition
x = torch.cat([x, cond_repr], dim=1)
        # Classify
return torch.sigmoid(self.classifier(x))
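# Illustrative shape check (sketch): a single benzene graph through the discriminator.
#   g = smiles_to_graph("c1ccccc1"); batch = Batch.from_data_list([g])
#   score = EnhancedGraphDiscriminator()(batch, torch.zeros(1, 2))  # shape (1, 1), in (0, 1)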
# ------------------------- Training loop (with benzene-ring conditioning) -------------------------
def train_gan(dataset_path, epochs=2000, batch_size=32, lr=1e-5, device=None,
generator_class=EnhancedGraphGenerator):
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
dataset = MolecularGraphDataset(dataset_path)
if len(dataset) == 0:
print("错误:数据集为空")
return None, None
def collate_fn(data_list):
graphs, conditions = zip(*data_list)
batch_graphs = Batch.from_data_list(graphs)
batch_conditions = torch.stack(conditions)
return batch_graphs, batch_conditions
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    # Build the generator and discriminator
generator = generator_class(noise_dim=16, condition_dim=2).to(device)
discriminator = EnhancedGraphDiscriminator().to(device)
    # Give the generator a larger learning rate so it keeps up with the discriminator
g_opt = optim.Adam(generator.parameters(), lr=lr * 3, betas=(0.5, 0.999))
d_opt = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
g_scheduler = optim.lr_scheduler.ReduceLROnPlateau(g_opt, 'min', patience=50, factor=0.5)
d_scheduler = optim.lr_scheduler.ReduceLROnPlateau(d_opt, 'min', patience=50, factor=0.5)
    loss_fn = nn.BCELoss()
    # Note: BCELoss on sigmoid outputs can error under CUDA autocast; BCEWithLogitsLoss
    # is the autocast-safe alternative if this path is exercised on a GPU.
    use_amp = device.type == 'cuda'
    scaler = amp.GradScaler(enabled=use_amp)
d_loss_history, g_loss_history = [], []
valid_mol_counts = []
fragment_counts = []
    benzene_ratios = []  # track the ratio of benzene-containing molecules
os.makedirs("results", exist_ok=True)
os.chdir("results")
for epoch in range(1, epochs + 1):
generator.train()
discriminator.train()
total_d_loss, total_g_loss = 0.0, 0.0
for real_batch, conds in dataloader:
real_batch = real_batch.to(device)
conds = conds.to(device)
batch_size = real_batch.num_graphs
real_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device)
            # Discriminator step
with amp.autocast(enabled=use_amp):
d_real = discriminator(real_batch, conds)
loss_real = loss_fn(d_real, real_labels)
noise = torch.randn(batch_size, 16).to(device)
                # Benzene condition (each sample has a 50% chance of requiring a ring)
benzene_condition = torch.randint(0, 2, (batch_size,), device=device)
fake_nodes = generator(noise, conds, benzene_condition).detach()
fake_mols = []
for i in range(fake_nodes.shape[0]):
node_feats = fake_nodes[i]
atomic_nums = (node_feats.squeeze() * 17).cpu().numpy().round().astype(int)
atomic_nums = validate_atomic_nums(atomic_nums)
atomic_nums = [int(num) for num in atomic_nums]
edge_index, bond_types = generate_realistic_edges_improved(
node_feats, dataset.avg_edge_count, atomic_nums
)
mol = build_valid_mol_improved(atomic_nums, edge_index, bond_types)
                    if mol and mol.GetNumAtoms() > 0 and mol.GetNumBonds() > 0:
                        g = mol_to_graph(mol)
                        if g is not None:  # mol_to_graph can still fail; never batch a None
                            fake_mols.append(g)
if fake_mols:
fake_batch = Batch.from_data_list(fake_mols).to(device)
conds_subset = conds[:len(fake_mols)]
d_fake = discriminator(fake_batch, conds_subset)
loss_fake = loss_fn(d_fake, fake_labels[:len(fake_mols)])
d_loss = loss_real + loss_fake
else:
                    # If no molecule survived, fall back to a simple placeholder molecule
mol = Chem.MolFromSmiles("CCO")
fake_graph = mol_to_graph(mol)
fake_batch = Batch.from_data_list([fake_graph]).to(device)
d_fake = discriminator(fake_batch, conds[:1])
loss_fake = loss_fn(d_fake, fake_labels[:1])
d_loss = loss_real + loss_fake
            d_opt.zero_grad()
            scaler.scale(d_loss).backward()
            scaler.unscale_(d_opt)  # unscale before clipping so the norm is computed on true grads
            torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1.0)
            scaler.step(d_opt)
            scaler.update()
total_d_loss += d_loss.item()
            # Generator step
with amp.autocast(enabled=use_amp):
noise = torch.randn(batch_size, 16).to(device)
benzene_condition = torch.randint(0, 2, (batch_size,), device=device)
fake_nodes = generator(noise, conds, benzene_condition)
fake_graphs = []
for i in range(fake_nodes.shape[0]):
node_feats = fake_nodes[i]
atomic_nums = (node_feats.squeeze() * 17).cpu().numpy().round().astype(int)
atomic_nums = validate_atomic_nums(atomic_nums)
atomic_nums = [int(num) for num in atomic_nums]
edge_index, bond_types = generate_realistic_edges_improved(
node_feats, dataset.avg_edge_count, atomic_nums
)
fake_graphs.append(Data(x=node_feats, edge_index=edge_index, bond_types=bond_types))
valid_fake_graphs = []
for graph in fake_graphs:
if graph.edge_index.numel() == 0:
graph.edge_index = torch.tensor([[0, 0]], dtype=torch.long).t().to(device)
graph.bond_types = torch.tensor([1], dtype=torch.long).to(device)
valid_fake_graphs.append(graph)
fake_batch = Batch.from_data_list(valid_fake_graphs).to(device)
g_loss = loss_fn(discriminator(fake_batch, conds), real_labels)
            g_opt.zero_grad()
            scaler.scale(g_loss).backward()
            scaler.unscale_(g_opt)  # unscale before clipping, as above
            torch.nn.utils.clip_grad_norm_(generator.parameters(), 1.0)
            scaler.step(g_opt)
            scaler.update()
total_g_loss += g_loss.item()
avg_d_loss = total_d_loss / len(dataloader)
avg_g_loss = total_g_loss / len(dataloader)
d_loss_history.append(avg_d_loss)
g_loss_history.append(avg_g_loss)
g_scheduler.step(avg_g_loss)
d_scheduler.step(avg_d_loss)
if epoch % 10 == 0:
generator.eval()
discriminator.eval()
with torch.no_grad():
num_samples = 50
noise = torch.randn(num_samples, 16).to(device)
conds = torch.randn(num_samples, 2).to(device)
                benzene_condition = torch.ones(num_samples, dtype=torch.long, device=device)  # force benzene rings
fake_nodes = generator(noise, conds, benzene_condition)
d_real_scores, d_fake_scores = [], []
for i in range(min(num_samples, 10)):
real_idx = np.random.randint(0, len(dataset))
real_graph, real_cond = dataset[real_idx]
real_batch = Batch.from_data_list([real_graph]).to(device)
real_cond = real_cond.unsqueeze(0).to(device)
d_real = discriminator(real_batch, real_cond)
d_real_scores.append(d_real.item())
node_feats = fake_nodes[i]
atomic_nums = (node_feats.squeeze() * 17).cpu().numpy().round().astype(int)
atomic_nums = validate_atomic_nums(atomic_nums)
atomic_nums = [int(num) for num in atomic_nums]
edge_index, bond_types = generate_realistic_edges_improved(
node_feats, dataset.avg_edge_count, atomic_nums
)
mol = build_valid_mol_improved(atomic_nums, edge_index, bond_types)
if mol:
fake_graph = mol_to_graph(mol)
if fake_graph:
fake_batch = Batch.from_data_list([fake_graph]).to(device)
fake_cond = conds[i].unsqueeze(0)
d_fake = discriminator(fake_batch, fake_cond)
d_fake_scores.append(d_fake.item())
if d_real_scores and d_fake_scores:
print(f"Epoch {epoch}: D_loss={avg_d_loss:.4f}, G_loss={avg_g_loss:.4f}")
print(f"D_real评分: {np.mean(d_real_scores):.4f} ± {np.std(d_real_scores):.4f}")
print(f"D_fake评分: {np.mean(d_fake_scores):.4f} ± {np.std(d_fake_scores):.4f}")
print(f"学习率: G={g_opt.param_groups[0]['lr']:.8f}, D={d_opt.param_groups[0]['lr']:.8f}")
else:
print(f"Epoch {epoch}: D_loss={avg_d_loss:.4f}, G_loss={avg_g_loss:.4f}")
generator.train()
if epoch % 100 == 0:
torch.save(generator.state_dict(), f"generator_epoch_{epoch}.pt")
torch.save(discriminator.state_dict(), f"discriminator_epoch_{epoch}.pt")
generator.eval()
with torch.no_grad():
num_samples = 25
noise = torch.randn(num_samples, 16).to(device)
conds = torch.randn(num_samples, 2).to(device)
                benzene_condition = torch.ones(num_samples, dtype=torch.long, device=device)  # force benzene rings
fake_nodes = generator(noise, conds, benzene_condition)
generated_mols = []
for i in range(num_samples):
node_feats = fake_nodes[i]
atomic_nums = (node_feats.squeeze() * 17).cpu().numpy().round().astype(int)
atomic_nums = validate_atomic_nums(atomic_nums)
atomic_nums = [int(num) for num in atomic_nums]
edge_index, bond_types = generate_realistic_edges_improved(
node_feats, dataset.avg_edge_count, atomic_nums
)
mol = build_valid_mol_improved(atomic_nums, edge_index, bond_types)
if mol and mol.GetNumAtoms() > 0 and mol.GetNumBonds() > 0:
try:
mol_weight = rdMolDescriptors.CalcExactMolWt(mol)
logp = rdMolDescriptors.CalcCrippenDescriptors(mol)[0]
tpsa = cached_calculate_tpsa(mol)
generated_mols.append((mol, mol_weight, logp, tpsa))
except Exception as e:
print(f"计算分子描述符时出错: {e}")
            # Filter invalid molecules
valid_mols = filter_valid_mols([mol for mol, _, _, _ in generated_mols])
            # Ratio of benzene-containing molecules
benzene_mols = filter_benzene_mols(valid_mols)
benzene_ratio = len(benzene_mols) / len(valid_mols) if valid_mols else 0
benzene_ratios.append(benzene_ratio)
            # Fragment statistics (using the improved filter)
single_fragment_mols = filter_single_fragment_mols_improved(valid_mols)
fragment_ratio = len(single_fragment_mols) / len(valid_mols) if valid_mols else 0
fragment_counts.append(fragment_ratio)
valid_mol_counts.append(len(single_fragment_mols))
print(f"Epoch {epoch}: 生成{len(generated_mols)}个分子,初步过滤后保留{len(valid_mols)}个合法分子")
print(f"Epoch {epoch}: 含苯环分子比例: {len(benzene_mols)}/{len(valid_mols)} = {benzene_ratio:.2%}")
print(
f"Epoch {epoch}: 单一片段分子比例: {len(single_fragment_mols)}/{len(valid_mols)} = {fragment_ratio:.2%}")
if single_fragment_mols:
filtered_mols = []
for mol in single_fragment_mols:
try:
mol_weight = rdMolDescriptors.CalcExactMolWt(mol)
logp = rdMolDescriptors.CalcCrippenDescriptors(mol)[0]
tpsa = cached_calculate_tpsa(mol)
filtered_mols.append((mol, mol_weight, logp, tpsa))
except:
continue
if filtered_mols:
filtered_mols.sort(key=lambda x: x[1])
mols = [mol for mol, _, _, _ in filtered_mols]
legends = [
f"Mol {i + 1}\nMW: {mw:.2f}\nLogP: {logp:.2f}\nTPSA: {tpsa:.2f}"
for i, (_, mw, logp, tpsa) in enumerate(filtered_mols)
]
img = Draw.MolsToGridImage(mols, molsPerRow=5, subImgSize=(200, 200), legends=legends)
img.save(f"generated_molecules_epoch_{epoch}.png")
print(f"Epoch {epoch}: 保存{len(filtered_mols)}个单一片段分子的可视化结果")
                else:
                    print(f"Epoch {epoch}: no valid single-fragment molecules to display after filtering")
            else:
                print(f"Epoch {epoch}: no valid single-fragment molecules were generated")
generator.train()
    # Plot the loss curves
plt.figure(figsize=(10, 5))
plt.plot(d_loss_history, label='Discriminator Loss')
plt.plot(g_loss_history, label='Generator Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('GAN Loss Curve')
plt.legend()
plt.savefig('gan_loss_curve.png')
plt.close()
    # Plot the valid-molecule counts
plt.figure(figsize=(10, 5))
plt.plot([i * 100 for i in range(1, len(valid_mol_counts) + 1)], valid_mol_counts,
label='Valid Single-Fragment Molecules')
plt.xlabel('Epoch')
plt.ylabel('Number of Valid Molecules')
plt.title('Number of Valid Single-Fragment Molecules Generated per Epoch')
plt.legend()
plt.savefig('valid_molecules_curve.png')
plt.close()
    # Plot the single-fragment ratio
plt.figure(figsize=(10, 5))
plt.plot([i * 100 for i in range(1, len(fragment_counts) + 1)], fragment_counts, label='Single-Fragment Ratio')
plt.xlabel('Epoch')
plt.ylabel('Single-Fragment Molecules Ratio')
plt.title('Ratio of Single-Fragment Molecules in Generated Molecules')
plt.legend()
plt.savefig('fragment_ratio_curve.png')
plt.close()
    # Plot the benzene-ring ratio
plt.figure(figsize=(10, 5))
plt.plot([i * 100 for i in range(1, len(benzene_ratios) + 1)], benzene_ratios, label='Benzene Ring Ratio')
plt.xlabel('Epoch')
plt.ylabel('Benzene Ring Molecules Ratio')
plt.title('Ratio of Molecules Containing Benzene Rings')
plt.legend()
plt.savefig('benzene_ratio_curve.png')
plt.close()
return generator, discriminator
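# Illustrative usage (sketch; "inhibitors.xlsx" is a placeholder path):
#   gen, disc = train_gan("inhibitors.xlsx", epochs=200, batch_size=32, lr=1e-5)
# Checkpoints land in results/ every 100 epochs as generator_epoch_<N>.pt.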
# ------------------------- Reference-molecule processing and batch generation (improved) -------------------------
def process_reference_smiles(reference_smiles):
ref_graph = smiles_to_graph(reference_smiles)
if ref_graph is None:
raise ValueError("参考分子SMILES转换图结构失败,请检查SMILES合法性")
ref_concentration = 1.0
ref_efficiency = 0.8
ref_condition = torch.tensor([ref_concentration, ref_efficiency], dtype=torch.float32)
return ref_graph, ref_condition
def generate_molecules_with_reference_improved(generator, reference_smiles, num_samples=1000,
noise_type="gaussian", force_benzene=True, device=None):
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ref_graph, ref_condition = process_reference_smiles(reference_smiles)
ref_condition_batch = ref_condition.unsqueeze(0).repeat(num_samples, 1).to(device)
if noise_type.lower() == "gaussian":
noise = torch.randn(num_samples, 16).to(device)
elif noise_type.lower() == "uniform":
noise = 2 * torch.rand(num_samples, 16).to(device) - 1
else:
raise ValueError("噪声类型必须是 'gaussian' 或 'uniform'")
generator.eval()
generated_mols = []
    # Force benzene-containing molecules (or sample the condition at random)
benzene_condition = torch.ones(num_samples, dtype=torch.long, device=device) if force_benzene else \
torch.randint(0, 2, (num_samples,), device=device)
with torch.no_grad():
fake_nodes = generator(noise, ref_condition_batch, benzene_condition)
for i in range(num_samples):
node_feats = fake_nodes[i]
atomic_nums = (node_feats.squeeze() * 17).cpu().numpy().round().astype(int)
atomic_nums = validate_atomic_nums(atomic_nums)
atomic_nums = [int(num) for num in atomic_nums]
            # Use the improved edge generator
edge_index, bond_types = generate_realistic_edges_improved(
node_feats, 20, atomic_nums
)
            # Use the improved molecule builder
mol = build_valid_mol_improved(atomic_nums, edge_index, bond_types)
if mol and mol.GetNumAtoms() > 0 and mol.GetNumBonds() > 0:
generated_mols.append(mol)
    # Filter invalid molecules
    valid_mols = filter_valid_mols(generated_mols)
    print(f"Using {noise_type} noise, generated {num_samples} molecules from the reference molecule; valid after initial filtering: {len(valid_mols)}")
    # Keep only benzene-containing molecules (if enforced)
    if force_benzene:
        benzene_mols = filter_benzene_mols(valid_mols)
        print(f"After further filtering, molecules containing a benzene ring: {len(benzene_mols)}")
        benzene_ratio = len(benzene_mols) / len(valid_mols) if valid_mols else 0  # guard against an empty list
        print(f"Benzene-ring ratio: {benzene_ratio:.2%}")
        valid_mols = benzene_mols
    # Improved single-fragment filter
    single_fragment_mols = filter_single_fragment_mols_improved(valid_mols)
    print(f"After further filtering, single-fragment molecules: {len(single_fragment_mols)}")
    # Fragment ratio
    fragment_ratio = len(single_fragment_mols) / len(valid_mols) if valid_mols else 0
    print(f"Single-fragment ratio: {fragment_ratio:.2%}")
return single_fragment_mols
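# Illustrative usage (sketch; assumes a trained generator and a valid reference SMILES):
#   mols = generate_molecules_with_reference_improved(
#       generator, "NCCNCc1ccc(O)c2ncccc12", num_samples=100, noise_type="uniform")
#   print(len(mols), "single-fragment molecules")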
# ------------------------- Molecular property calculation and saving -------------------------
def calculate_molecular_properties(mols):
properties = []
for mol in mols:
try:
mol_weight = rdMolDescriptors.CalcExactMolWt(mol)
logp = rdMolDescriptors.CalcCrippenDescriptors(mol)[0]
tpsa = rdMolDescriptors.CalcTPSA(mol)
hba = rdMolDescriptors.CalcNumHBA(mol)
hbd = rdMolDescriptors.CalcNumHBD(mol)
rot_bonds = rdMolDescriptors.CalcNumRotatableBonds(mol)
n_count = sum(1 for atom in mol.GetAtoms() if atom.GetAtomicNum() == 7)
o_count = sum(1 for atom in mol.GetAtoms() if atom.GetAtomicNum() == 8)
s_count = sum(1 for atom in mol.GetAtoms() if atom.GetAtomicNum() == 16)
p_count = sum(1 for atom in mol.GetAtoms() if atom.GetAtomicNum() == 15)
frags = Chem.GetMolFrags(mol, asMols=True)
fragment_count = len(frags)
            # Benzene check
has_benzene = has_benzene_ring(mol)
properties.append({
'SMILES': Chem.MolToSmiles(mol),
'MW': mol_weight,
'LogP': logp,
'TPSA': tpsa,
'HBA': hba,
'HBD': hbd,
'RotBonds': rot_bonds,
'N_count': n_count,
'O_count': o_count,
'S_count': s_count,
'P_count': p_count,
'FragmentCount': fragment_count,
'HasBenzene': has_benzene
})
except Exception as e:
print(f"计算分子属性时出错: {e}")
continue
return properties
def save_molecules(mols, prefix="generated", noise_type="gaussian"):
if not mols:
print("没有分子可保存")
return
subdir = f"{prefix}_{noise_type}"
os.makedirs(subdir, exist_ok=True)
if len(mols) <= 100:
legends = [f"Mol {i + 1}" for i in range(len(mols))]
img = Draw.MolsToGridImage(mols, molsPerRow=5, subImgSize=(300, 300), legends=legends)
img.save(f"{subdir}/{prefix}_{noise_type}_molecules.png")
properties = calculate_molecular_properties(mols)
df = pd.DataFrame(properties)
df.to_csv(f"{subdir}/{prefix}_{noise_type}_properties.csv", index=False)
with open(f"{subdir}/{prefix}_{noise_type}_smiles.smi", "w") as f:
for props in properties:
f.write(f"{props['SMILES']}\n")
print(f"已保存 {len(mols)} 个单一片段分子到目录: {subdir}")
# ------------------------- New: analysis of generated molecules -------------------------
def analyze_generated_molecules(mols_gaussian, mols_uniform):
print("\n===== 生成分子分析报告 =====")
count_gaussian = len(mols_gaussian)
count_uniform = len(mols_uniform)
print(f"高斯噪声生成单一片段分子: {count_gaussian}")
print(f"均匀噪声生成单一片段分子: {count_uniform}")
def calculate_avg_properties(mols):
if not mols:
return {}
props = calculate_molecular_properties(mols)
avg_props = {}
for key in props[0].keys():
if key != 'SMILES':
avg_props[key] = sum(p[key] for p in props) / len(props)
return avg_props
avg_gaussian = calculate_avg_properties(mols_gaussian)
avg_uniform = calculate_avg_properties(mols_uniform)
if avg_gaussian and avg_uniform:
print("\n高斯噪声生成分子的平均属性:")
for key, value in avg_gaussian.items():
print(f" {key}: {value:.2f}")
print("\n均匀噪声生成分子的平均属性:")
for key, value in avg_uniform.items():
print(f" {key}: {value:.2f}")
print("\n属性差异 (均匀 - 高斯):")
for key in avg_gaussian.keys():
if key != 'SMILES':
diff = avg_uniform[key] - avg_gaussian[key]
print(f" {key}: {diff:+.2f}")
if mols_gaussian and mols_uniform:
properties = ['MW', 'LogP', 'TPSA', 'HBA', 'HBD', 'RotBonds']
plt.figure(figsize=(15, 10))
for i, prop in enumerate(properties, 1):
plt.subplot(2, 3, i)
gaussian_vals = [p[prop] for p in calculate_molecular_properties(mols_gaussian)]
uniform_vals = [p[prop] for p in calculate_molecular_properties(mols_uniform)]
plt.hist(gaussian_vals, bins=20, alpha=0.5, label='Gaussian')
plt.hist(uniform_vals, bins=20, alpha=0.5, label='Uniform')
plt.title(f'{prop} Distribution')
plt.xlabel(prop)
plt.ylabel('Frequency')
plt.legend()
plt.tight_layout()
plt.savefig('molecular_property_distribution.png')
plt.close()
print("\n分子属性分布图已保存为 'molecular_property_distribution.png'")
# ------------------------- New: molecule visualization -------------------------
def visualize_molecules_grid(molecules, num_per_row=5, filename="molecules_grid.png", legends=None):
"""创建高质量的分子网格可视化图"""
if not molecules:
print("没有分子可可视化")
return
if legends is None:
legends = [f"Molecule {i + 1}" for i in range(len(molecules))]
try:
img = Draw.MolsToGridImage(
molecules,
molsPerRow=num_per_row,
subImgSize=(300, 300),
legends=legends,
useSVG=False,
highlightAtomLists=None,
highlightBondLists=None
)
img.save(filename)
print(f"分子网格图已保存至: {filename}")
return img
except Exception as e:
print(f"生成分子网格图时出错: {e}")
return None
# ------------------------- Main -------------------------
def main():
print("=" * 80)
print("基于合法SMILES的分子生成GAN系统(改进版)")
print("=" * 80)
dataset_path = "D:\python\pythonProject1\DATA\Inhibitor1368_data.xlsx"
    if not os.path.exists(dataset_path):
        print(f"Error: dataset file '{dataset_path}' does not exist!")
        raise SystemExit(1)
print(f"开始加载数据集: {dataset_path}")
print("=" * 80)
print("开始训练分子生成GAN...")
generator, discriminator = train_gan(
dataset_path=dataset_path,
epochs=200,
batch_size=32,
lr=1e-5,
)
print("=" * 80)
    # Reference corrosion-inhibitor molecule (a valid single-fragment molecule containing a benzene ring)
reference_smiles = "NCCNCc1ccc(O)c2ncccc12"
if generator:
print("训练完成!模型和生成结果已保存")
print("生成的分子可视化结果在'results'目录")
print("损失曲线已保存为'gan_loss_curve.png'")
print("合法分子数量曲线已保存为'valid_molecules_curve.png'")
print("分子片段比例曲线已保存为'fragment_ratio_curve.png'")
print("苯环比例曲线已保存为'benzene_ratio_curve.png'")
        # Generate 1000 new molecules from the reference (Gaussian noise, benzene ring enforced)
        print("\nGenerating new molecules from the reference molecule (Gaussian noise, benzene ring enforced)...")
gaussian_mols = generate_molecules_with_reference_improved(
generator,
reference_smiles,
num_samples=1000,
noise_type="gaussian",
            force_benzene=True  # force benzene-containing molecules
)
save_molecules(gaussian_mols, prefix="ref_based", noise_type="gaussian_benzene")
        # Visualize the best molecules
if gaussian_mols:
            # QED score for each molecule (rdkit.Chem.QED, imported above)
            qed_scores = []
            for mol in gaussian_mols:
                try:
                    qed_scores.append(QED.qed(mol))
                except:
                    qed_scores.append(0)
            # Sort by QED score
sorted_indices = sorted(range(len(qed_scores)), key=lambda i: qed_scores[i], reverse=True)
top_molecules = [gaussian_mols[i] for i in sorted_indices[:25]]
top_legends = [f"QED: {qed_scores[i]:.3f}" for i in sorted_indices[:25]]
            # Visualize
visualize_molecules_grid(
top_molecules,
num_per_row=5,
filename="top_molecules_by_qed.png",
legends=top_legends
)
print("已生成并保存最佳分子的可视化结果")
        # Generate 1000 new molecules from the reference (uniform noise, benzene ring enforced)
        print("\nGenerating new molecules from the reference molecule (uniform noise, benzene ring enforced)...")
uniform_mols = generate_molecules_with_reference_improved(
generator,
reference_smiles,
num_samples=1000,
noise_type="uniform",
            force_benzene=True  # force benzene-containing molecules
)
save_molecules(uniform_mols, prefix="ref_based", noise_type="uniform_benzene")
        # Compare molecules generated from the two noise types
if gaussian_mols and uniform_mols:
analyze_generated_molecules(gaussian_mols, uniform_mols)
else:
print("训练过程中未生成合法分子,请检查数据")
print("=" * 80)
if __name__ == "__main__":
main()
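If you only want to sample from a saved checkpoint without retraining, a minimal sketch looks like this (it assumes the class and function definitions above are importable, and that results/generator_epoch_200.pt exists from a previous run; both are assumptions, not part of the original script):

# Minimal sampling sketch (assumes the definitions above and an existing checkpoint).
gen = EnhancedGraphGenerator(noise_dim=16, condition_dim=2)
gen.load_state_dict(torch.load("results/generator_epoch_200.pt", map_location="cpu"))
gen.eval()
mols = generate_molecules_with_reference_improved(
    gen, "NCCNCc1ccc(O)c2ncccc12", num_samples=100,
    noise_type="gaussian", force_benzene=True, device=torch.device("cpu"))
print(f"{len(mols)} single-fragment, benzene-containing molecules")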