In C is “i+=1;” atomic?

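This article examines whether "i += 1;" is atomic in C and points out that standard C does not define this. It stresses that in practice you must not assume any operation is atomic. It then identifies the only operation the standard guarantees to be atomic: assigning to or reading from a variable of type sig_atomic_t.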


Source: http://stackoverflow.com/questions/1790204/in-c-is-i-1-atomic/1790234#1790234

=================================================================================================================

The C standard does not define whether it is atomic or not.

In practice, you never write code which fails if a given operation is atomic, but you might well write code which fails if it isn't. So assume it isn't.

=================================================================================================================

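Up to C99, the only operation guaranteed by the C language standard to be atomic is assigning or retrieving a value to/from a variable of type sig_atomic_t, defined in <signal.h>. Even that guarantee concerns interruption by asynchronous signals, not access from other threads.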

(C99, chapter 7.14 Signal handling.)
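Here is a minimal sketch of that guarantee in use, assuming a hosted implementation. A signal handler stores to a volatile sig_atomic_t flag, and ordinary code reads the flag back. The names got_signal, on_sigint, and raise_and_check are hypothetical. raise is used only so the handler runs synchronously for the demonstration, since the standard requires raise not to return until the handler has returned.

```c
#include <signal.h>

/* sig_atomic_t is the one type whose reads and writes C99 guarantees
   to be atomic, and only with respect to asynchronous signal interrupts. */
static volatile sig_atomic_t got_signal = 0;

/* Hypothetical handler: a single store to a volatile sig_atomic_t
   object is what a portable handler can rely on. */
static void on_sigint(int signo)
{
    (void)signo;
    got_signal = 1;
}

/* Installs the handler, delivers SIGINT to this process with raise(),
   and returns the flag value observed afterwards (-1 on setup failure). */
int raise_and_check(void)
{
    got_signal = 0;
    if (signal(SIGINT, on_sigint) == SIG_ERR)
        return -1;
    if (raise(SIGINT) != 0)
        return -1;
    return got_signal; /* 1 once the handler has run */
}
```

Calling raise_and_check() should return 1, because the handler's store is complete by the time raise returns.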

=================================================================================================================

Defined in C, no. In practice, maybe. Write it in assembly.

The standard makes no guarantees.

Therefore a portable program would not make the assumption. It's not clear whether you mean "required to be atomic" or "happens to be atomic in my C code". The answer to that second question depends on a lot of things:

  • Not all machines even have an increment memory op. Some need to load and store the value in order to operate on it, so the answer there is "never".

  • On machines that do have an increment memory op, there is no assurance that the compiler will not output a load, increment, and store sequence anyway, or use some other non-atomic instruction.

  • On machines that do have an increment memory operation, it may or may not be atomic with respect to other CPU units.

  • On machines that do have an atomic increment memory op, it may not be specified as part of the architecture, but just a property of a particular edition of the CPU chip, or even just of certain core logic or motherboard designs.
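To make the load, increment, and store sequence concrete, here is a hypothetical single-threaded simulation (the names shared_i and lost_update_demo are made up). It writes out the three steps for two CPUs and interleaves them by hand. It illustrates the lost-update problem; it is not a real data race.

```c
/* Stands in for the shared variable i from the question. */
static int shared_i = 0;

/* Simulates two CPUs, A and B, each executing i += 1 as the sequence
   load, increment, store, interleaved as:
   load A, load B, increment A, increment B, store A, store B.
   The locals play the role of each CPU's register. */
int lost_update_demo(void)
{
    shared_i = 0;
    int reg_a = shared_i;  /* A: load (reads 0) */
    int reg_b = shared_i;  /* B: load (also reads 0) */
    reg_a = reg_a + 1;     /* A: increment in register */
    reg_b = reg_b + 1;     /* B: increment in register */
    shared_i = reg_a;      /* A: store 1 */
    shared_i = reg_b;      /* B: store 1, overwriting A's store */
    return shared_i;       /* 1, although two increments ran */
}
```

Both CPUs start from the same old value, so one increment disappears. A real machine can produce this interleaving whenever the compiler emits separate load and store instructions.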

As to "how do I do this atomically?", there is generally a way to do it quickly rather than resorting to (more expensive) negotiated mutual exclusion. Sometimes this involves special collision-detecting, retryable code sequences, for example load-linked/store-conditional or compare-and-swap loops. It's best to implement these in an assembly-language module. The code is target-specific anyway, so writing it in a high-level language (HLL) brings no portability benefit.
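C11, which postdates these answers, lets you write such a collision-detecting, retrying sequence portably with <stdatomic.h> instead of assembly. Here is a minimal sketch, assuming a C11 compiler that provides <stdatomic.h>. It swaps in the standard compare-and-swap primitive for hand-written assembly, and cas_increment is a hypothetical name.

```c
#include <stdatomic.h>

/* Increments *p with a compare-and-swap retry loop and returns the value
   it replaced. If another thread changes *p between our read and our CAS,
   the CAS fails, "expected" is refreshed with the current value, and we
   retry. */
int cas_increment(atomic_int *p)
{
    int expected = atomic_load(p);
    /* On failure, atomic_compare_exchange_weak stores the current value of
       *p into expected. The weak form may also fail spuriously; the loop
       absorbs that too. */
    while (!atomic_compare_exchange_weak(p, &expected, expected + 1)) {
        /* retry with the refreshed expected value */
    }
    return expected;
}
```

For an atomic_int x holding 41, cas_increment(&x) returns 41 and leaves 42 in x.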

Finally, because atomic operations that do not require (expensive) negotiated mutual exclusion are fast and hence useful, and in any case needed for portable code, systems typically have a library, generally written in assembly, that already implements similar functions.
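C11's <stdatomic.h> provides one such library routine, atomic_fetch_add. Here is a minimal sketch, assuming C11 atomics and POSIX threads (compile with -pthread on older toolchains). The names count_with_threads and worker, and the thread and iteration counts, are hypothetical.

```c
#include <pthread.h>
#include <stdatomic.h>
#include <stddef.h>

#define NTHREADS 4
#define ITERS 100000

/* Shared counter; every access goes through <stdatomic.h>. */
static atomic_long hits = 0;

/* Adds 1 to the counter ITERS times. Each atomic_fetch_add is a single
   indivisible read-modify-write, so concurrent calls never lose updates. */
static void *worker(void *arg)
{
    (void)arg;
    for (int n = 0; n < ITERS; n++)
        atomic_fetch_add(&hits, 1);
    return NULL;
}

/* Starts NTHREADS workers, waits for every thread that was created, and
   returns the final count, or -1 if not all threads could be created. */
long count_with_threads(void)
{
    pthread_t tid[NTHREADS];
    int created = 0;

    atomic_store(&hits, 0);
    while (created < NTHREADS &&
           pthread_create(&tid[created], NULL, worker, NULL) == 0)
        created++;
    for (int t = 0; t < created; t++)
        pthread_join(tid[t], NULL);
    return created == NTHREADS ? atomic_load(&hits) : -1;
}
```

count_with_threads() should return 400000 (4 threads times 100000 increments). Using a plain ++ on a non-atomic long instead would be a data race, and the total could come out lower.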

=================================================================================================================


[f"Molecule {i + 1}" for i in range(len(molecules))] try: img = Draw.MolsToGridImage( molecules, molsPerRow=num_per_row, subImgSize=(300, 300), legends=legends, useSVG=False, highlightAtomLists=None, highlightBondLists=None ) img.save(filename) print(f"分子网格图已保存至: {filename}") return img except Exception as e: print(f"生成分子网格图时出错: {e}") return None # ------------------------- 主函数 ------------------------- def main(): print("=" * 80) print("基于合法SMILES的分子生成GAN系统(改进版)") print("=" * 80) dataset_path = "D:\python\pythonProject1\DATA\Inhibitor1368_data.xlsx" if not os.path.exists(dataset_path): print(f"错误:数据集文件 '{dataset_path}' 不存在!") exit(1) print(f"开始加载数据集: {dataset_path}") print("=" * 80) print("开始训练分子生成GAN...") generator, discriminator = train_gan( dataset_path=dataset_path, epochs=200, batch_size=32, lr=1e-5, ) print("=" * 80) # 设置参考缓蚀剂分子(确保是单一片段含苯环的有效分子) reference_smiles = "NCCNCc1ccc(O)c2ncccc12" if generator: print("训练完成!模型和生成结果已保存") print("生成的分子可视化结果在'results'目录") print("损失曲线已保存为'gan_loss_curve.png'") print("合法分子数量曲线已保存为'valid_molecules_curve.png'") print("分子片段比例曲线已保存为'fragment_ratio_curve.png'") print("苯环比例曲线已保存为'benzene_ratio_curve.png'") # 基于参考分子生成1000个新分子(高斯噪声,强制含苯环) print("\n开始基于参考分子生成新分子(高斯噪声,强制含苯环)...") gaussian_mols = generate_molecules_with_reference_improved( generator, reference_smiles, num_samples=1000, noise_type="gaussian", force_benzene=True # 强制生成含苯环分子 ) save_molecules(gaussian_mols, prefix="ref_based", noise_type="gaussian_benzene") # 可视化最佳分子 if gaussian_mols: # 计算每个分子的QED分数 qed_scores = [] for mol in gaussian_mols: try: qed_scores.append(rdMolDescriptors.CalcQED(mol)) except: qed_scores.append(0) # 按QED分数排序 sorted_indices = sorted(range(len(qed_scores)), key=lambda i: qed_scores[i], reverse=True) top_molecules = [gaussian_mols[i] for i in sorted_indices[:25]] top_legends = [f"QED: {qed_scores[i]:.3f}" for i in sorted_indices[:25]] # 可视化 visualize_molecules_grid( top_molecules, num_per_row=5, filename="top_molecules_by_qed.png", 
legends=top_legends ) print("已生成并保存最佳分子的可视化结果") # 基于参考分子生成1000个新分子(均匀噪声,强制含苯环) print("\n开始基于参考分子生成新分子(均匀噪声,强制含苯环)...") uniform_mols = generate_molecules_with_reference_improved( generator, reference_smiles, num_samples=1000, noise_type="uniform", force_benzene=True # 强制生成含苯环分子 ) save_molecules(uniform_mols, prefix="ref_based", noise_type="uniform_benzene") # 分析两种噪声生成的分子差异 if gaussian_mols and uniform_mols: analyze_generated_molecules(gaussian_mols, uniform_mols) else: print("训练过程中未生成合法分子,请检查数据") print("=" * 80) if __name__ == "__main__": main()
最新发布
07-09