MAX_BONUS

本文定义了一系列用于计算用户优先级及基于此优先级的奖励比例的宏定义。包括最大用户优先级(MAX_USER_PRIO)、用户优先级(USER_PRIO)、优先级奖励比率(PRIO_BONUS_RATIO)以及最大奖励(MAX_BONUS)的计算方法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

MAX_BONUS = MAX_USER_PRIO * PRIO_BONUS_RATIO / 100 = (40*25/100) = 10

/* Largest bonus: MAX_USER_PRIO * PRIO_BONUS_RATIO / 100, evaluated in integer
 * arithmetic (40 * 25 / 100 = 10 with the values quoted in these notes). */
#define MAX_BONUS       (MAX_USER_PRIO * PRIO_BONUS_RATIO / 100)

MAX_USER_PRIO = USER_PRIO(MAX_PRIO) = MAX_PRIO - MAX_RT_PRIO = 140 - 100 = 40
----------------------------
/* Size of the user priority range: USER_PRIO applied to MAX_PRIO.
 * With MAX_PRIO = 140 and MAX_RT_PRIO = 100 (values taken from the trailing
 * comments; both macros are defined elsewhere) this evaluates to 40. */
#define MAX_USER_PRIO        (USER_PRIO( MAX_PRIO))    //USER_PRIO(140)
/* Shift a priority value p down by MAX_RT_PRIO, i.e. p - MAX_RT_PRIO. */
#define USER_PRIO(p)         ((p)- MAX_RT_PRIO)       //140-100

PRIO_BONUS_RATIO = 25
----------------------------
/* Percentage of MAX_USER_PRIO that MAX_BONUS is computed from. */
#define PRIO_BONUS_RATIO     25



# Source code follows (original request: "please modify it").
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from rdkit import Chem
from rdkit.Chem import Draw, rdMolDescriptors
from sklearn.preprocessing import StandardScaler
from torch_geometric.data import Batch, Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

# ------------------------- Constants -------------------------
VALID_ATOMS = (1, 6, 7, 8, 16)  # H, C, N, O, S
MAX_VALENCE = {1: 1, 6: 4, 7: 3, 8: 2, 16: 6}
# Generated node features are atomic_number / ATOM_SCALE (so S -> 1.0).
# Graphs built from RDKit molecules keep raw integer atomic numbers instead;
# the discriminator normalizes both to the same scale.
ATOM_SCALE = 16.0
NOISE_DIM = 16
# Bond orders not listed here become single bonds when a molecule is rebuilt
# (this includes RDKit's aromatic order 1.5).
_BOND_TYPE_BY_ORDER = {2.0: Chem.BondType.DOUBLE, 3.0: Chem.BondType.TRIPLE}


def set_seed(seed):
    """Seed torch (CPU and CUDA), numpy and ``random`` so runs are reproducible."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)


# ------------------------- Utilities -------------------------
def validate_atomic_nums(atomic_nums):
    """Coerce atom types to a flat list of Python ints drawn from ``VALID_ATOMS``.

    Accepts a torch tensor (any device, with or without grad), a numpy array,
    a list, or a scalar. Values are rounded to the nearest integer, and any
    value outside H/C/N/O/S is replaced with carbon (6).
    """
    if isinstance(atomic_nums, torch.Tensor):
        # detach() so tensors that require grad can be converted as well.
        atomic_nums = atomic_nums.detach().cpu().numpy()
    # atleast_1d + ravel: a squeezed 1-atom graph gives a 0-d value, and
    # multi-dimensional input is flattened.
    values = np.round(np.atleast_1d(np.asarray(atomic_nums, dtype=float)).ravel()).astype(int)
    allowed = set(VALID_ATOMS)
    # int(...) because RDKit's Chem.Atom expects a Python int.
    return [int(num) if num in allowed else 6 for num in values]


def _empty_edges():
    """Return an edge list with no edges: ``(2, 0)`` edge_index and empty bond types."""
    return torch.empty((2, 0), dtype=torch.long), torch.empty(0, dtype=torch.long)


def _mol_to_data(mol):
    """Convert an already-sanitized RDKit molecule into a PyG ``Data`` graph.

    ``x`` holds raw atomic numbers with shape (num_atoms, 1) and dtype long.
    Every bond is stored in both directions in ``edge_index``, with its RDKit
    bond order (1, 1.5, 2, 3) repeated in ``bond_types``.
    """
    atom_feats = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
    edges, orders = [], []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        order = bond.GetBondTypeAsDouble()
        edges += [(i, j), (j, i)]
        orders += [order, order]
    x = torch.tensor(atom_feats, dtype=torch.long).view(-1, 1)
    # view(2, -1) keeps the (2, 0) shape for molecules without bonds.
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous().view(2, -1)
    bond_types = torch.tensor(orders, dtype=torch.float32)
    return Data(x=x, edge_index=edge_index, bond_types=bond_types)


def smiles_to_graph(smiles):
    """Parse and sanitize *smiles* into a graph; return None if RDKit rejects it."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    try:
        Chem.SanitizeMol(mol)
    except Exception as e:
        print(f"分子净化失败: {e}")
        return None
    return _mol_to_data(mol)


def augment_mol(smiles):
    """Add explicit hydrogens and re-emit the SMILES; return the input if unparsable.

    NOTE(review): this is deterministic, despite being described as random
    augmentation. RDKit's SMILES parser removes explicit hydrogens again by
    default when the result is re-read, so the net effect is mostly
    canonicalization (verify for your RDKit version).
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return smiles
    try:
        mol = Chem.AddHs(mol)
    except Exception:  # best-effort: keep the molecule without explicit H
        pass
    return Chem.MolToSmiles(mol)


def generate_realistic_edges(node_features, avg_edge_count=20, atomic_nums=None):
    """Greedily add bonds between the most similar atom pairs under valence limits.

    Candidate pairs (u < v) are ranked by the dot product of their node
    features. The top ``avg_edge_count`` candidates are accepted while both
    atoms still have free valence. Where both atoms are heavy atoms with
    at least 2 free valence, a bond becomes double with probability 0.3.

    Returns:
        (edge_index, bond_types): edge_index has shape (2, E) and dtype long,
        with each bond listed once. bond_types has shape (E,) and dtype long,
        with 1 = single and 2 = double.
    """
    num_nodes = node_features.shape[0]
    if num_nodes < 2 or atomic_nums is None:
        return _empty_edges()
    atomic_nums = validate_atomic_nums(atomic_nums)

    with torch.no_grad():  # edge choice is discrete; no gradient is needed here
        feats = node_features.detach().float().reshape(num_nodes, -1)
        similarity = feats @ feats.t()
        pairs = torch.triu_indices(num_nodes, num_nodes, offset=1, device=feats.device)
        scores = similarity[pairs[0], pairs[1]]
        k = max(0, min(int(avg_edge_count), scores.numel()))
        ranked = torch.topk(scores, k=k).indices.tolist()
    rows, cols = pairs.tolist()

    used_valence = [0] * num_nodes
    edges, bond_types = [], []
    for pos in ranked:
        u, v = rows[pos], cols[pos]  # plain ints: tensors hash by identity
        atom_u, atom_v = atomic_nums[u], atomic_nums[v]
        free_u = MAX_VALENCE[atom_u] - used_valence[u]
        free_v = MAX_VALENCE[atom_v] - used_valence[v]
        if free_u < 1 or free_v < 1:
            continue
        make_double = (
            atom_u != 1 and atom_v != 1
            and free_u >= 2 and free_v >= 2
            and np.random.random() < 0.3
        )
        order = 2 if make_double else 1
        used_valence[u] += order
        used_valence[v] += order
        edges.append((u, v))
        bond_types.append(order)

    if not edges:
        return _empty_edges()
    return (
        torch.tensor(edges, dtype=torch.long).t().contiguous(),
        torch.tensor(bond_types, dtype=torch.long),
    )


def _add_bonds(mol, edge_index, bond_types):
    """Add each undirected edge of *edge_index* to *mol* once, skipping invalid ones.

    ``bond_types[j]`` (if present) gives edge j's order: 2 -> double,
    3 -> triple, anything else -> single.
    """
    num_atoms = mol.GetNumAtoms()
    seen = set()
    pairs = torch.as_tensor(edge_index, dtype=torch.long).reshape(2, -1).t().tolist()
    for j, (u, v) in enumerate(pairs):
        if u == v or not (0 <= u < num_atoms and 0 <= v < num_atoms) or (u, v) in seen:
            continue
        order = float(bond_types[j]) if j < len(bond_types) else 1.0
        mol.AddBond(u, v, _BOND_TYPE_BY_ORDER.get(order, Chem.BondType.SINGLE))
        seen.add((u, v))
        seen.add((v, u))


def build_valid_mol(atomic_nums, edge_index, bond_types=None):
    """Build and sanitize an RDKit molecule from atoms and edges; None if invalid."""
    atomic_nums = validate_atomic_nums(atomic_nums)
    mol = Chem.RWMol()
    for num in atomic_nums:
        mol.AddAtom(Chem.Atom(num))
    orders = bond_types.tolist() if bond_types is not None else []
    _add_bonds(mol, edge_index, orders)
    try:
        Chem.SanitizeMol(mol)
    except Exception as e:
        print(f"分子构建失败: {e}")
        return None
    return mol


def mol_to_graph(mol):
    """RDKit Mol -> PyG graph (same layout as ``smiles_to_graph``); None if invalid."""
    if mol is None:
        return None
    try:
        Chem.SanitizeMol(mol)
    except Exception:
        return None
    return _mol_to_data(mol)


# ------------------------- Dataset -------------------------
class MolecularGraphDataset(torch.utils.data.Dataset):
    """Molecule graphs plus standardized (Concentration, Efficiency) conditions.

    Reads columns 1-3 of the Excel file as SMILES, Concentration, Efficiency.
    Rows whose SMILES RDKit cannot sanitize are skipped.
    """

    def __init__(self, path):
        df = pd.read_excel(path, usecols=[1, 2, 3])
        df.columns = ['SMILES', 'Concentration', 'Efficiency']
        # reset_index so positional row numbers line up with cond_scaled below.
        df = df.dropna().reset_index(drop=True)
        df['SMILES'] = df['SMILES'].astype(str).apply(augment_mol)

        scaler = StandardScaler()
        cond_scaled = scaler.fit_transform(df[['Concentration', 'Efficiency']])
        self.scaler = scaler  # exposed so callers can inverse-transform conditions

        self.graphs, self.conditions = [], []
        self.edge_stats = {'edge_counts': []}
        for pos, smiles in enumerate(df['SMILES']):
            graph = smiles_to_graph(smiles)
            if graph is None:
                continue
            # Map atoms outside H/C/N/O/S to carbon; view(-1) also covers 1-atom graphs.
            atomic_nums = validate_atomic_nums(graph.x.view(-1))
            graph.x = torch.tensor(atomic_nums, dtype=torch.long).view(-1, 1)
            self.graphs.append(graph)
            self.conditions.append(torch.tensor(cond_scaled[pos], dtype=torch.float32))
            # Each bond is stored twice (both directions).
            self.edge_stats['edge_counts'].append(graph.edge_index.shape[1] // 2)

        valid_count = len(self.graphs)
        print(f"成功加载 {valid_count} 个分子,过滤掉 {len(df) - valid_count} 个无效分子")
        counts = self.edge_stats['edge_counts']
        self.avg_edge_count = int(np.mean(counts)) if counts else 20
        print(f"真实图平均边数: {self.avg_edge_count}")

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        return self.graphs[idx], self.conditions[idx]


# ------------------------- Generator -------------------------
class GraphGenerator(nn.Module):
    """Map (noise, condition) to ``max_atoms`` atom types per sample.

    Output has shape (batch, max_atoms, 1) and holds atomic_number / ATOM_SCALE.
    The forward pass picks the argmax atom type. Gradients flow back through
    the tempered softmax (straight-through estimator), so the generator can
    actually learn from the discriminator.
    """

    def __init__(self, noise_dim=16, condition_dim=2, hidden_dim=64,
                 max_atoms=10, temperature=0.8):
        super().__init__()
        self.max_atoms = max_atoms
        self.temperature = temperature
        # Layer indices 0/2/4 are unchanged, so existing checkpoints still load.
        self.fc = nn.Sequential(
            nn.Linear(noise_dim + condition_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, max_atoms * len(VALID_ATOMS)),
        )
        # Buffer so it follows .to(device); non-persistent to keep state_dict keys unchanged.
        self.register_buffer(
            'valid_atoms', torch.tensor(VALID_ATOMS, dtype=torch.long), persistent=False
        )
        self.dropout = nn.Dropout(0.2)

    def forward(self, z, cond):
        x = self.dropout(torch.cat([z, cond], dim=1))
        logits = self.fc(x).view(-1, self.max_atoms, len(VALID_ATOMS))
        probs = torch.softmax(logits / self.temperature, dim=-1)
        hard = nn.functional.one_hot(probs.argmax(dim=-1), probs.shape[-1]).to(probs.dtype)
        # Straight-through: the forward value is the hard one-hot, the gradient is d(probs).
        one_hot = probs + (hard - probs).detach()
        atom_values = self.valid_atoms.to(probs.dtype) / ATOM_SCALE
        return (one_hot @ atom_values).unsqueeze(-1)


# ------------------------- Discriminator -------------------------
class GraphDiscriminator(nn.Module):
    """Two-layer GCN critic over (atom value, summed bond order) node features.

    The returned probability is sigmoid(GCN logit + chemistry score). The
    chemistry score is a non-differentiable, RDKit-based prior computed per
    graph.
    """

    def __init__(self, node_feat_dim=1, condition_dim=2):
        super().__init__()
        self.conv1 = GCNConv(node_feat_dim + 1, 64)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = GCNConv(64, 32)
        self.bn2 = nn.BatchNorm1d(32)
        self.fc = nn.Sequential(
            nn.Linear(32 + condition_dim, 16),
            nn.LeakyReLU(0.2),
            nn.Linear(16, 1),
        )

    @staticmethod
    def _node_features(data):
        """Atom values on the generator's scale (integer atomic numbers are divided by 16)."""
        x = data.x.float()
        if not data.x.is_floating_point():
            x = x / ATOM_SCALE
        return x

    @staticmethod
    def _bond_features(data, num_nodes, device):
        """Per-node sum of outgoing bond orders (one column), scaled by 1/4.

        Falls back to order 1 per edge when ``bond_types`` is missing or does
        not have one entry per edge.
        """
        edge_index = data.edge_index
        bond_types = getattr(data, 'bond_types', None)
        if bond_types is None or bond_types.numel() != edge_index.shape[1]:
            weights = torch.ones(edge_index.shape[1], device=device)
        else:
            weights = bond_types.to(device=device, dtype=torch.float32)
        feats = torch.zeros(num_nodes, device=device)
        feats.index_add_(0, edge_index[0].to(device), weights)
        return feats.view(-1, 1) / 4.0

    @staticmethod
    def _chemistry_score(graph):
        """Rebuild one graph as an RDKit molecule and score it (-1.0 if it fails).

        Rings and H-bond donors/acceptors raise the score; valence violations lower it.
        """
        x = graph.x.view(-1)
        atomic_nums = validate_atomic_nums(x * ATOM_SCALE if x.is_floating_point() else x)
        mol = Chem.RWMol()
        for num in atomic_nums:
            mol.AddAtom(Chem.Atom(num))
        bond_types = getattr(graph, 'bond_types', None)
        _add_bonds(mol, graph.edge_index, bond_types.tolist() if bond_types is not None else [])
        try:
            Chem.SanitizeMol(mol)
            ring_count = rdMolDescriptors.CalcNumRings(mol)
            hbd = rdMolDescriptors.CalcNumHBD(mol)
            hba = rdMolDescriptors.CalcNumHBA(mol)
            valence_violations = sum(
                1 for atom in mol.GetAtoms()
                if atom.GetAtomicNum() in MAX_VALENCE
                and atom.GetExplicitValence() > MAX_VALENCE[atom.GetAtomicNum()]
            )
        except Exception as e:
            print(f"分子检查失败: {e}")
            return -1.0
        ring_bonus = ring_count * 0.5
        fg_bonus = (hbd + hba) * 0.3
        valence_penalty = valence_violations * 1.0
        # Bonuses add to the score and the penalty subtracts (the original negated all three).
        return (ring_bonus + fg_bonus - valence_penalty) / 10.0

    def forward(self, data, condition):
        x = self._node_features(data)
        x = torch.cat([x, self._bond_features(data, x.shape[0], x.device)], dim=1)
        x = self.bn1(self.conv1(x, data.edge_index)).relu()
        x = self.bn2(self.conv2(x, data.edge_index)).relu()
        x = global_mean_pool(x, data.batch)
        raw_output = self.fc(torch.cat([x, condition], dim=1))

        # Score each graph using ITS OWN edges (not the whole batch's edge_index).
        scores = [self._chemistry_score(data.get_example(i)) for i in range(data.num_graphs)]
        mol_penalty = torch.tensor(scores, dtype=torch.float32, device=x.device).view(-1, 1)
        return torch.sigmoid(raw_output + mol_penalty)


# ------------------------- Collate -------------------------
def custom_collate(batch):
    """Batch a list of (graph, condition) pairs into (Batch, stacked conditions)."""
    graphs, conditions = zip(*batch)
    return Batch.from_data_list(graphs), torch.stack(conditions)


# ------------------------- Training -------------------------
def _fake_graph(node_feats, avg_edge_count):
    """Turn one generator sample of shape (max_atoms, 1) into (atomic_nums, edge_index, bond_types)."""
    atomic_nums = validate_atomic_nums(node_feats.reshape(-1) * ATOM_SCALE)
    edge_index, bond_types = generate_realistic_edges(node_feats, avg_edge_count, atomic_nums)
    return atomic_nums, edge_index, bond_types


def _undirected(edge_index, bond_types):
    """Append reversed edges so generated graphs match the two-way layout of real graphs."""
    return (
        torch.cat([edge_index, edge_index.flip(0)], dim=1),
        torch.cat([bond_types, bond_types]),
    )


def _log_discriminator_scores(generator, discriminator, dataset, device, epoch,
                              avg_d_loss, avg_g_loss, num_samples=50, num_pairs=10):
    """Print the losses and D's score range on a few real and generated molecules."""
    generator.eval()
    discriminator.eval()
    d_real_scores, d_fake_scores = [], []
    with torch.no_grad():
        noise = torch.randn(num_samples, NOISE_DIM, device=device)
        conds = torch.randn(num_samples, 2, device=device)
        fake_nodes = generator(noise, conds)
        for i in range(min(num_samples, num_pairs)):
            real_graph, real_cond = dataset[np.random.randint(0, len(dataset))]
            real_batch = Batch.from_data_list([real_graph]).to(device)
            d_real = discriminator(real_batch, real_cond.unsqueeze(0).to(device))
            d_real_scores.append(d_real.item())

            atomic_nums, edge_index, bond_types = _fake_graph(fake_nodes[i], dataset.avg_edge_count)
            fake_graph = mol_to_graph(build_valid_mol(atomic_nums, edge_index, bond_types))
            if fake_graph is not None:
                fake_batch = Batch.from_data_list([fake_graph]).to(device)
                d_fake = discriminator(fake_batch, conds[i].unsqueeze(0))
                d_fake_scores.append(d_fake.item())

    print(f"Epoch {epoch}: D_loss={avg_d_loss:.4f}, G_loss={avg_g_loss:.4f}")
    if d_real_scores and d_fake_scores:
        print(f"D_real评分范围: {min(d_real_scores):.4f} - {max(d_real_scores):.4f}")
        print(f"D_fake评分范围: {min(d_fake_scores):.4f} - {max(d_fake_scores):.4f}")
    else:
        print("警告:生成的分子均不合法,无法评估D_fake评分")
    generator.train()
    discriminator.train()


def _save_generated_samples(generator, dataset, device, epoch, output_dir, num_samples=25):
    """Generate molecules and save the valid ones as a grid image sorted by exact mass."""
    generator.eval()
    with torch.no_grad():
        noise = torch.randn(num_samples, NOISE_DIM, device=device)
        conds = torch.randn(num_samples, 2, device=device)
        fake_nodes = generator(noise, conds)
    generated_mols = []
    for i in range(num_samples):
        atomic_nums, edge_index, bond_types = _fake_graph(fake_nodes[i], dataset.avg_edge_count)
        mol = build_valid_mol(atomic_nums, edge_index, bond_types)
        if mol is None or mol.GetNumAtoms() == 0 or mol.GetNumBonds() == 0:
            continue
        mol_weight = rdMolDescriptors.CalcExactMolWt(mol)
        # CalcCrippenDescriptors returns (logP, molar refractivity).
        logp = rdMolDescriptors.CalcCrippenDescriptors(mol)[0]
        tpsa = rdMolDescriptors.CalcTPSA(mol)
        generated_mols.append((mol, mol_weight, logp, tpsa))
    generator.train()

    if not generated_mols:
        print(f"Epoch {epoch}: 未生成任何合法分子")
        return
    generated_mols.sort(key=lambda item: item[1])
    mols = [mol for mol, _, _, _ in generated_mols]
    legends = [
        f"Molecule {i + 1}\nMW: {mw:.2f}\nLogP: {logp:.2f}\nTPSA: {tpsa:.2f}"
        for i, (_, mw, logp, tpsa) in enumerate(generated_mols)
    ]
    img = Draw.MolsToGridImage(mols, molsPerRow=5, subImgSize=(200, 200), legends=legends)
    img.save(os.path.join(output_dir, f"generated_molecules_epoch_{epoch}.png"))
    print(f"Epoch {epoch}: 保存了 {len(generated_mols)} 个合法分子的可视化结果")


def train_gan(dataset_path, epochs=1750, batch_size=32, lr=1e-4, device=None,
              output_dir="results"):
    """Train the conditional molecular-graph GAN.

    Args:
        dataset_path: Excel file read by ``MolecularGraphDataset``.
        epochs, batch_size, lr: usual training hyper-parameters.
        device: torch device; defaults to CUDA when available.
        output_dir: where checkpoints, sample images and the loss curve are
            written (created if missing). The process cwd is no longer changed.

    Returns:
        (generator, discriminator), or (None, None) if the dataset is empty.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")

    dataset = MolecularGraphDataset(dataset_path)
    if len(dataset) == 0:
        print("错误:数据集为空,请检查SMILES格式或RDKit版本")
        return None, None
    # NOTE(review): torch_geometric's DataLoader may replace collate_fn with its own
    # Collater; both produce (Batch, stacked conditions) for these pairs.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate)

    generator = GraphGenerator().to(device)
    discriminator = GraphDiscriminator().to(device)
    g_opt = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    d_opt = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    loss_fn = nn.BCELoss()
    d_loss_history, g_loss_history = [], []
    os.makedirs(output_dir, exist_ok=True)
    avg_edges = dataset.avg_edge_count

    for epoch in range(1, epochs + 1):
        generator.train()
        discriminator.train()
        total_d_loss, total_g_loss = 0.0, 0.0

        for real_graphs, conds in dataloader:
            real_batch = real_graphs.to(device)
            conds = conds.to(device)
            n_graphs = real_batch.num_graphs
            real_labels = torch.ones(n_graphs, 1, device=device)
            fake_labels = torch.zeros(n_graphs, 1, device=device)

            # ---- Discriminator step ----
            d_loss = loss_fn(discriminator(real_batch, conds), real_labels)
            with torch.no_grad():
                fake_nodes = generator(torch.randn(n_graphs, NOISE_DIM, device=device), conds)
            fake_graphs, kept = [], []
            for i in range(n_graphs):
                atomic_nums, edge_index, bond_types = _fake_graph(fake_nodes[i], avg_edges)
                mol = build_valid_mol(atomic_nums, edge_index, bond_types)
                graph = None
                if mol is not None and mol.GetNumAtoms() > 0 and mol.GetNumBonds() > 0:
                    graph = mol_to_graph(mol)
                if graph is None:
                    print(
                        f"过滤非法分子:原子数={mol.GetNumAtoms() if mol else 0}, 键数={mol.GetNumBonds() if mol else 0}")
                    continue
                fake_graphs.append(graph)
                kept.append(i)  # remember which condition produced this graph
            if fake_graphs:
                fake_batch = Batch.from_data_list(fake_graphs).to(device)
                kept_idx = torch.tensor(kept, dtype=torch.long, device=device)
                d_fake = discriminator(fake_batch, conds[kept_idx])
                d_loss = d_loss + loss_fn(d_fake, fake_labels[:len(kept)])
            d_opt.zero_grad()
            d_loss.backward()
            d_opt.step()
            total_d_loss += d_loss.item()

            # ---- Generator step (x keeps its autograd link to the generator) ----
            fake_nodes = generator(torch.randn(n_graphs, NOISE_DIM, device=device), conds)
            gen_graphs = []
            for i in range(n_graphs):
                _, edge_index, bond_types = _fake_graph(fake_nodes[i], avg_edges)
                edge_index, bond_types = _undirected(edge_index, bond_types)
                gen_graphs.append(Data(x=fake_nodes[i], edge_index=edge_index,
                                       bond_types=bond_types.float()))
            gen_batch = Batch.from_data_list(gen_graphs).to(device)
            g_loss = loss_fn(discriminator(gen_batch, conds), real_labels)
            g_opt.zero_grad()
            g_loss.backward()
            g_opt.step()
            total_g_loss += g_loss.item()

        avg_d_loss = total_d_loss / len(dataloader)
        avg_g_loss = total_g_loss / len(dataloader)
        d_loss_history.append(avg_d_loss)
        g_loss_history.append(avg_g_loss)

        if epoch % 10 == 0:
            _log_discriminator_scores(generator, discriminator, dataset, device, epoch,
                                      avg_d_loss, avg_g_loss)
        if epoch % 100 == 0:
            torch.save(generator.state_dict(),
                       os.path.join(output_dir, f"generator_epoch_{epoch}.pt"))
            torch.save(discriminator.state_dict(),
                       os.path.join(output_dir, f"discriminator_epoch_{epoch}.pt"))
            _save_generated_samples(generator, dataset, device, epoch, output_dir)

    plt.figure(figsize=(10, 5))
    plt.plot(d_loss_history, label='Discriminator Loss')
    plt.plot(g_loss_history, label='Generator Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('GAN Loss Curve')
    plt.legend()
    plt.savefig(os.path.join(output_dir, 'gan_loss_curve.png'))
    plt.close()
    return generator, discriminator


# ------------------------- Entry point -------------------------
if __name__ == "__main__":
    print("=" * 80)
    print("基于SMILES验证的分子生成GAN系统")
    print("=" * 80)
    dataset_path = "D:\\python\\pythonProject1\\DATA\\data-8.xlsx"
    if not os.path.exists(dataset_path):
        print(f"错误:数据集文件 '{dataset_path}' 不存在!")
        print("请准备一个Excel文件,包含SMILES、Concentration、Efficiency三列数据。")
        raise SystemExit(1)  # builtin exit() is only provided by the site module
    print(f"开始加载数据集: {dataset_path}")
    print("=" * 80)
    print("开始训练分子生成GAN...")
    generator, discriminator = train_gan(
        dataset_path=dataset_path,
        epochs=1750,
        batch_size=32,
        lr=1e-4,
    )
    print("=" * 80)
    if generator and discriminator:
        print("训练完成!")
        print("生成器和判别器模型已保存到'results'目录")
        print("生成的分子可视化结果已保存到'results'目录")
        print("损失曲线已保存为'gan_loss_curve.png'")
    else:
        print("训练未能启动:数据集为空或存在其他问题")
    print("=" * 80)
06-17
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值