import torch
from torch_geometric.data import Data
from rdkit import Chem

def smi_to_graph(smi_file):
    graphs = []
    with open(smi_file, 'r') as f:
        for line in f:
            smi, mol_id = line.strip().split('\t')  # assumed format: SMILES\tID
            # Parse the SMILES string into an RDKit molecule
            mol = Chem.MolFromSmiles(smi)
            if not mol:
                continue
            # Atom features (nodes)
            atom_features = []
            for atom in mol.GetAtoms():
                features = [
                    atom.GetAtomicNum(),     # atomic number
                    atom.GetDegree(),        # degree
                    atom.GetFormalCharge(),  # formal charge
                    int(atom.IsInRing())     # whether the atom is in a ring
                ]
                atom_features.append(features)
            # Bonds (edges), added in both directions for an undirected graph
            edge_index = []
            for bond in mol.GetBonds():
                i = bond.GetBeginAtomIdx()
                j = bond.GetEndAtomIdx()
                edge_index.append([i, j])
                edge_index.append([j, i])
            # Build the PyG data object
            graph = Data(
                x=torch.tensor(atom_features, dtype=torch.float),
                edge_index=torch.tensor(edge_index, dtype=torch.long).t().contiguous(),
                y=torch.tensor([int(mol_id)], dtype=torch.long)  # molecule ID kept as a label
            )
            graphs.append(graph)
    return graphs
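
# Quick self-contained check of smi_to_graph (illustrative only, not part of the
# original pipeline): write two well-known molecules to a temporary SMILES file
# and inspect the resulting PyG Data objects.
import tempfile
with tempfile.NamedTemporaryFile("w", suffix=".smi", delete=False) as _tmp:
    _tmp.write("CC(=O)Oc1ccccc1C(=O)O\t0\n")           # aspirin, ID 0
    _tmp.write("CN1C=NC2=C1C(=O)N(C(=O)N2C)C\t1\n")     # caffeine, ID 1
_demo = smi_to_graph(_tmp.name)
print(_demo[0])  # Data(x=[13, 4], edge_index=[2, 26], y=[1])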
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.loader import DataLoader

class JCZS_GNN(torch.nn.Module):
    def __init__(self, input_dim=4, hidden_dim=128, output_dim=1, num_classes=3):
        super().__init__()
        # Graph convolution layers
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        # Drug-loading regression head (pooled embedding is hidden_dim-dimensional)
        self.reg_head = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, output_dim)
        )
        # Jun-Chen-Zuo-Shi (monarch-minister-assistant-courier) classification head
        self.cls_head = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        # Graph convolutions
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        # Global pooling: one embedding per molecule
        graph_embed = global_mean_pool(x, batch)
        # Multi-task outputs
        y_reg = self.reg_head(graph_embed)   # drug-loading prediction
        y_cls = self.cls_head(graph_embed)   # Jun/Chen/Zuo-Shi class logits
        # Return raw logits; apply softmax only at inference time,
        # since F.cross_entropy in training expects unnormalized logits.
        return y_reg, y_cls
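
# Illustrative shape check (not part of the pipeline): a forward pass on a made-up
# two-atom graph; the regression head returns (1, 1), the class head (1, 3) logits.
_dummy = Data(x=torch.rand(2, 4),
              edge_index=torch.tensor([[0, 1], [1, 0]], dtype=torch.long),
              batch=torch.zeros(2, dtype=torch.long))
_reg_out, _cls_logits = JCZS_GNN()(_dummy)
print(_reg_out.shape, _cls_logits.shape)  # torch.Size([1, 1]) torch.Size([1, 3])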
# Training loop
def train_gnn(dataset, epochs=100):
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    model = JCZS_GNN()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(epochs):
        model.train()
        for batch in loader:
            optimizer.zero_grad()
            y_reg_pred, y_cls_pred = model(batch)
            # Drug-loading regression loss
            # (assumes batch.y_reg holds the drug-loading labels)
            reg_loss = F.mse_loss(y_reg_pred.squeeze(), batch.y_reg)
            # Jun/Chen/Zuo-Shi classification loss (weighted cross-entropy)
            weights = torch.tensor([3.0, 2.0, 1.0])  # Jun, Chen, Zuo-Shi weights
            # (assumes batch.y_cls holds the class labels)
            cls_loss = F.cross_entropy(y_cls_pred, batch.y_cls, weight=weights)
            total_loss = reg_loss + cls_loss
            total_loss.backward()
            optimizer.step()
    return model  # return the trained model so downstream steps can reuse it
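
# train_gnn() assumes every Data object carries y_reg (drug-loading value) and
# y_cls (0=Jun, 1=Chen, 2=Zuo-Shi role) attributes, which smi_to_graph does not set.
# A minimal sketch of attaching them, given two hypothetical label lists:
def attach_labels(graphs, loading_labels, role_labels):
    for g, loading, role in zip(graphs, loading_labels, role_labels):
        g.y_reg = torch.tensor([loading], dtype=torch.float)  # drug-loading target
        g.y_cls = torch.tensor([role], dtype=torch.long)      # Jun/Chen/Zuo-Shi class
    return graphs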
import torch
import torch.nn as nn

class ConditionedGenerator(nn.Module):
    def __init__(self, latent_dim=100, cond_dim=10, vocab_size=100, max_len=80):
        # cond_dim = 7-dim syndrome vector + 3 Jun/Chen/Zuo-Shi weights (see below)
        super().__init__()
        self.embed = nn.Embedding(vocab_size, 256)
        self.cond_proj = nn.Linear(cond_dim, 256)
        self.latent_proj = nn.Linear(latent_dim, 256)
        # Transformer decoder
        decoder_layer = nn.TransformerDecoderLayer(d_model=256, nhead=8)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
        self.fc_out = nn.Linear(256, vocab_size)
        self.max_len = max_len

    def forward(self, z, condition, start_token):
        # Project the condition vector: (1, batch, 256)
        cond_embed = self.cond_proj(condition).unsqueeze(0)
        # Project the noise vector: (1, batch, 256)
        latent_embed = self.latent_proj(z).unsqueeze(0)
        # Both serve as the decoder "memory" the generated sequence attends to
        memory = torch.cat([cond_embed, latent_embed], dim=0)  # (2, batch, 256)
        # Start-token embedding
        tokens = start_token
        token_embeds = self.embed(tokens).permute(1, 0, 2)  # (seq_len, batch, 256)
        # Greedy autoregressive generation
        for _ in range(self.max_len - 1):
            output = self.decoder(token_embeds, memory)
            logits = self.fc_out(output[-1])  # (batch, vocab_size)
            next_token = torch.argmax(logits, dim=1)
            tokens = torch.cat([tokens, next_token.unsqueeze(1)], dim=1)
            # Append the new token's embedding for the next step
            new_embed = self.embed(next_token).unsqueeze(0)
            token_embeds = torch.cat([token_embeds, new_embed], dim=0)
        return tokens
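
# The generator works on token indices, but the listing never defines the
# vocabulary, the SOS_IDX start symbol, or the decode_smiles() helper used later.
# A minimal character-level sketch follows; the exact character set and special
# token layout are assumptions (46 symbols, well under vocab_size=100).
SMILES_CHARS = list("#%()+-./0123456789=@BCFHIKLNOPS[\\]acilnors")
PAD_IDX, SOS_IDX, EOS_IDX = 0, 1, 2
IDX_TO_CHAR = {i + 3: ch for i, ch in enumerate(SMILES_CHARS)}
CHAR_TO_IDX = {ch: i for i, ch in IDX_TO_CHAR.items()}

def decode_smiles(token_tensor):
    # Map a (batch, seq_len) tensor of token indices back to SMILES strings,
    # stopping at the first EOS and skipping special or unknown tokens.
    smiles_list = []
    for row in token_tensor.tolist():
        chars = []
        for idx in row:
            if idx == EOS_IDX:
                break
            if idx in (PAD_IDX, SOS_IDX):
                continue
            chars.append(IDX_TO_CHAR.get(idx, ''))
        smiles_list.append(''.join(chars))
    return smiles_list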
# Discriminator (with Jun/Chen/Zuo-Shi weight-compliance check)
class JCZS_Discriminator(nn.Module):
    def __init__(self, gnn_model):
        super().__init__()
        self.gnn = gnn_model  # share the trained GNN weights
        self.fc = nn.Sequential(
            nn.Linear(128, 128),  # input size matches the GNN hidden_dim
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, mol_graph):
        # Extract molecular features with the shared GNN layers
        x, edge_index, batch = mol_graph.x, mol_graph.edge_index, mol_graph.batch
        x = F.relu(self.gnn.conv1(x, edge_index))
        x = F.relu(self.gnn.conv2(x, edge_index))
        graph_embed = global_mean_pool(x, batch)
        # Jun/Chen/Zuo-Shi weight-compliance check
        cls_probs = F.softmax(self.gnn.cls_head(graph_embed), dim=1)
        # Deviation of the expected role weight (3*Jun + 2*Chen + 1*Zuo-Shi) from 2.5
        jczs_compliance = torch.abs(
            3.0 * cls_probs[:, 0] + 2.0 * cls_probs[:, 1] + 1.0 * cls_probs[:, 2] - 2.5
        )
        # Discriminator output, down-weighted when the role weighting deviates
        validity_score = torch.sigmoid(self.fc(graph_embed)).squeeze(-1)
        return validity_score * (1.0 - jczs_compliance)
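
# Illustrative shape check: the discriminator returns one score per molecule in
# the batch (here a single made-up two-atom graph and an untrained GNN).
_disc = JCZS_Discriminator(JCZS_GNN())
_score = _disc(Data(x=torch.rand(2, 4),
                    edge_index=torch.tensor([[0, 1], [1, 0]], dtype=torch.long),
                    batch=torch.zeros(2, dtype=torch.long)))
print(_score.shape)  # torch.Size([1])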
# 1. Data preparation
smi_file = r"D:\PyCharm 2025.2.0.1\project\分子数据.smi"  # file containing SMILES and IDs
dataset = smi_to_graph(smi_file)
# 2. Train the GNN
# (assumes drug-loading and Jun/Chen/Zuo-Shi labels have been attached to each graph)
gnn_model = train_gnn(dataset)  # train_gnn returns the trained model
# 3. Train the GAN
generator = ConditionedGenerator()
discriminator = JCZS_Discriminator(gnn_model)  # shares the trained GNN weights
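
# Step 3 only instantiates the GAN; the original listing does not include the
# adversarial training loop itself. Below is a minimal sketch, using the
# character vocabulary / decode_smiles() defined earlier and a hypothetical
# _smiles_to_data() helper (same featurization as smi_to_graph, for one SMILES).
# Because the generator decodes with argmax, it is not differentiable end-to-end;
# a full generator update would need REINFORCE or Gumbel-softmax, so only the
# discriminator update is spelled out here.
def _smiles_to_data(smi):
    # Hypothetical helper: featurize a single SMILES string like smi_to_graph does.
    mol = Chem.MolFromSmiles(smi)
    if not mol:
        return None
    feats = [[a.GetAtomicNum(), a.GetDegree(), a.GetFormalCharge(), int(a.IsInRing())]
             for a in mol.GetAtoms()]
    edges = []
    for b in mol.GetBonds():
        i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        edges += [[i, j], [j, i]]
    edge_index = (torch.tensor(edges, dtype=torch.long).t().contiguous()
                  if edges else torch.zeros((2, 0), dtype=torch.long))
    return Data(x=torch.tensor(feats, dtype=torch.float), edge_index=edge_index,
                batch=torch.zeros(len(feats), dtype=torch.long))

def train_gan_discriminator(generator, discriminator, real_dataset, cond_vector, epochs=10):
    # cond_vector: (1, 10) condition, built as in generate_molecule() below.
    d_opt = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
    loader = DataLoader(real_dataset, batch_size=32, shuffle=True)
    for epoch in range(epochs):
        for real_batch in loader:
            d_opt.zero_grad()
            # Real molecules should get high validity scores
            real_score = discriminator(real_batch).clamp(1e-6, 1 - 1e-6)
            real_loss = F.binary_cross_entropy(real_score, torch.ones_like(real_score))
            # Generated molecules should get low validity scores
            with torch.no_grad():
                tokens = generator(torch.randn(1, 100), cond_vector,
                                   torch.tensor([[SOS_IDX]]))
            fake_graph = _smiles_to_data(decode_smiles(tokens)[0])
            fake_loss = torch.tensor(0.0)
            if fake_graph is not None:  # untrained generators rarely emit valid SMILES
                fake_score = discriminator(fake_graph).clamp(1e-6, 1 - 1e-6)
                fake_loss = F.binary_cross_entropy(fake_score, torch.zeros_like(fake_score))
            (real_loss + fake_loss).backward()
            d_opt.step()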
def syndrome_to_vector(syndrome):
    # 7-dimensional syndrome condition vectors (illustrative target profiles)
    mapping = {
        "气滞血瘀": [1, 0, 0, -1, -1, 0, 0],  # qi stagnation & blood stasis: TNF-α↓, VEGF↓
        "阴虚火旺": [0, 1, 0, 0, 0, -1, 1]    # yin deficiency & hyperactive fire: IL-6↑, CRP↑
    }
    return torch.tensor(mapping[syndrome], dtype=torch.float)
# 4. Generate new molecules
def generate_molecule(syndrome_type):
    # Syndrome condition vector (e.g. "气滞血瘀" → [TNF-α↓, VEGF↓, ...])
    condition = syndrome_to_vector(syndrome_type)
    # Append the Jun/Chen/Zuo-Shi weights
    jczs_weights = torch.tensor([3.0, 2.0, 1.0])
    cond_vector = torch.cat([condition, jczs_weights]).unsqueeze(0)  # (1, 10)
    # Generate a molecule
    z = torch.randn(1, 100)  # noise vector
    start_token = torch.tensor([[SOS_IDX]])  # start-of-sequence symbol
    smiles_seq = generator(z, cond_vector, start_token)
    return decode_smiles(smiles_seq)

# Example: generate a new molecule for the "气滞血瘀" (qi stagnation & blood stasis) syndrome
new_molecule = generate_molecule("气滞血瘀")