import numpy as np
import matplotlib.pyplot as plt
from pymatgen.io.vasp import Vasprun
from pymatgen.core import Structure
from scipy.optimize import curve_fit
from tqdm import tqdm
import matplotlib as mpl
from collections import defaultdict
from pymatgen.analysis.local_env import VoronoiNN
from multiprocessing import Pool, cpu_count
import time
import argparse
import os
from scipy.spatial import cKDTree
import csv
# ================ Configuration ================
# Default location of the vasprun.xml file (used when --vasprun is not given).
DEFAULT_VASPRUN_PATH = "/path/to/your/vasprun4.xml"  # change to your actual path
# ===============================================
# 1. Journal-style plotting defaults, applied once at import time.
mpl.rcParams.update({
    'font.family': 'DejaVu Sans',
    'font.sans-serif': ['DejaVu Sans'],
    'font.size': 10,
    'axes.linewidth': 1.5,
    'xtick.major.width': 1.5,
    'ytick.major.width': 1.5,
    'lines.linewidth': 2,
    'figure.dpi': 300,
    'figure.figsize': (4.0, 3.0),
    'savefig.dpi': 600,
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.05,
    'pdf.fonttype': 42,
})
def get_hbond_config():
    """Return the list of hydrogen-bond definitions analysed by this script.

    Each entry is a dict with the keys ``name``, ``donor_type``,
    ``acceptor_type``, ``h_type``, ``distance_threshold``,
    ``angle_threshold`` and ``color`` (in that order).  The thresholds are
    the H···acceptor distance cut-off and the minimum D-H···A angle in
    degrees (presumably the distance is in Å, matching the structure
    coordinates -- confirm against the caller's data).
    """
    field_names = (
        "name", "donor_type", "acceptor_type", "h_type",
        "distance_threshold", "angle_threshold", "color",
    )
    # One row per bond type, columns in the order of ``field_names``.
    rows = (
        ("P-O/P=O···Hw", "water_oxygens", "P-O/P=O", "water_hydrogens", 2.375, 143.99, "#1f77b4"),
        ("P-O/P=O···Hh", "hydronium_oxygens", "P-O/P=O", "hydronium_hydrogens", 2.275, 157.82, "#ff7f0e"),
        ("P-O/P=O···Hp", "P-OH", "P-O/P=O", "phosphate_hydrogens", 2.175, 155.00, "#2ca02c"),
        ("P-OH···Ow", "P-OH", "water_oxygens", "phosphate_hydrogens", 2.275, 155.13, "#d62728"),
        ("Hw···Ow", "water_oxygens", "water_oxygens", "water_hydrogens", 2.375, 138.73, "#9467bd"),
        ("Hh···Ow", "hydronium_oxygens", "water_oxygens", "hydronium_hydrogens", 2.225, 155.31, "#8c564b"),
        ("Hw···F", "water_oxygens", "fluoride_atoms", "water_hydrogens", 2.225, 137.68, "#e377c2"),
        ("Hh···F", "hydronium_oxygens", "fluoride_atoms", "hydronium_hydrogens", 2.175, 154.64, "#7f7f7f"),
        ("Hp···F", "P-OH", "fluoride_atoms", "phosphate_hydrogens", 2.275, 153.71, "#bcbd22"),
    )
    return [dict(zip(field_names, row)) for row in rows]
def identify_atom_types(struct):
    """Classify the atoms of one structure into chemical roles.

    Args:
        struct: a pymatgen-style structure; it must be iterable over sites
            exposing ``species_string`` and ``coords`` and provide
            ``lattice.abc`` and ``get_distance(i, j)``.

    Returns:
        ``(atom_types, hydrogen_owners)`` where

        * ``atom_types`` maps role names to lists of atom indices
          (``"phosphate_oxygens"`` maps to a dict with the sub-types
          ``"P-O/P=O"`` and ``"P-OH"``);
        * ``hydrogen_owners`` maps each H index to the index of the closest
          O within 1.3 of it (H atoms with no such O are omitted).

    Classification rules: an O within 1.6 of any P is a phosphate oxygen
    (``P-OH`` if it owns at least one H, otherwise ``P-O/P=O``); any other O
    owning exactly two H is water, exactly three H is hydronium.
    """
    species = [site.species_string for site in struct]

    def indices_of(symbol):
        return [i for i, sp in enumerate(species) if sp == symbol]

    atom_types = {
        "phosphate_oxygens": {"P-O/P=O": [], "P-OH": []},
        "phosphate_hydrogens": [],
        "water_oxygens": [],
        "water_hydrogens": [],
        "hydronium_oxygens": [],
        "hydronium_hydrogens": [],
        "fluoride_atoms": indices_of("F"),
        "aluminum_atoms": indices_of("Al"),
    }

    # Periodic KD-tree over all atoms.  cKDTree(boxsize=...) raises if any
    # coordinate lies outside [0, L), so wrap the Cartesian coordinates first.
    # NOTE(review): like the boxsize argument itself, this assumes an
    # orthorhombic cell whose axes are aligned with x, y, z -- confirm.
    box = np.asarray(struct.lattice.abc, dtype=float)
    all_coords = np.array([site.coords for site in struct], dtype=float).reshape(-1, 3)
    all_coords = np.mod(all_coords, box)
    all_coords[all_coords >= box] = 0.0  # np.mod can round up to exactly L
    kdtree = cKDTree(all_coords, boxsize=box)

    # Phosphate oxygens: every O within 1.6 of a P.  An O bridging two P
    # atoms is found once per P, so de-duplicate while keeping first-seen order.
    phosphate_oxygens = []
    phosphate_oxygen_set = set()
    for p_idx in indices_of("P"):
        for idx in kdtree.query_ball_point(all_coords[p_idx], r=1.6):
            if species[idx] == "O" and idx not in phosphate_oxygen_set:
                phosphate_oxygen_set.add(idx)
                phosphate_oxygens.append(idx)

    # Assign every H to its nearest O within 1.3 (periodic distance).
    hydrogen_owners = {}
    for h_idx in indices_of("H"):
        nearby_oxygens = [
            idx for idx in kdtree.query_ball_point(all_coords[h_idx], r=1.3)
            if species[idx] == "O"
        ]
        if not nearby_oxygens:
            continue
        hydrogen_owners[h_idx] = min(
            nearby_oxygens, key=lambda o_idx: struct.get_distance(h_idx, o_idx)
        )

    # Reverse map O -> [H, ...], built once instead of rescanning per oxygen.
    hydrogens_of = {}
    for h_idx, owner_o in hydrogen_owners.items():
        hydrogens_of.setdefault(owner_o, []).append(h_idx)

    # Phosphate oxygens carrying an H are P-OH, the rest are P-O/P=O.
    for o_idx in phosphate_oxygens:
        sub_type = "P-OH" if o_idx in hydrogens_of else "P-O/P=O"
        atom_types["phosphate_oxygens"][sub_type].append(o_idx)

    # Remaining oxygens: 2 H -> water, 3 H -> hydronium, anything else is ignored.
    for o_idx in indices_of("O"):
        if o_idx in phosphate_oxygen_set:
            continue
        attached_hs = hydrogens_of.get(o_idx, [])
        if len(attached_hs) == 2:
            atom_types["water_oxygens"].append(o_idx)
            atom_types["water_hydrogens"].extend(attached_hs)
        elif len(attached_hs) == 3:
            atom_types["hydronium_oxygens"].append(o_idx)
            atom_types["hydronium_hydrogens"].extend(attached_hs)

    # Hydrogens bound to P-OH oxygens.
    for o_idx in atom_types["phosphate_oxygens"]["P-OH"]:
        atom_types["phosphate_hydrogens"].extend(hydrogens_of[o_idx])

    return atom_types, hydrogen_owners
def calculate_angle(struct, donor_idx, h_idx, acceptor_idx):
    """Return the D-H···A angle in degrees, measured at the hydrogen atom.

    The H->D and H->A vectors are built from fractional coordinates with a
    minimum-image correction and then converted to Cartesian space with the
    lattice matrix.

    Args:
        struct: object exposing ``frac_coords`` (N x 3 array) and
            ``lattice.matrix`` (3 x 3 array).
        donor_idx: index of the donor atom D.
        h_idx: index of the hydrogen atom H.
        acceptor_idx: index of the acceptor atom A.

    Returns:
        The angle between H->D and H->A in degrees, or ``None`` when either
        vector is shorter than 1e-6 (coincident atoms).
    """
    frac_coords = struct.frac_coords
    h_frac = frac_coords[h_idx]

    # Row 0: H->D, row 1: H->A, both in fractional coordinates.
    deltas = np.array(
        [frac_coords[donor_idx] - h_frac, frac_coords[acceptor_idx] - h_frac],
        dtype=float,
    )
    # Minimum image: subtracting the nearest integer maps every component
    # into [-0.5, 0.5] even when the stored coordinates are unwrapped by
    # several lattice translations (a single +/-1 shift would not).
    # NOTE(review): component-wise wrapping in fractional space is the true
    # minimum image only for (near-)orthogonal cells.
    deltas -= np.round(deltas)

    vec_hd, vec_ha = deltas @ struct.lattice.matrix

    norm_hd = np.linalg.norm(vec_hd)
    norm_ha = np.linalg.norm(vec_ha)
    # Guard against division by zero for coincident atoms.
    if norm_hd < 1e-6 or norm_ha < 1e-6:
        return None

    # Clip to absorb floating-point overshoot before arccos.
    cos_theta = np.clip(np.dot(vec_hd, vec_ha) / (norm_hd * norm_ha), -1.0, 1.0)
    return np.degrees(np.arccos(cos_theta))
# 2. 氢键识别函数 (优化版)
def precompute_structural_features(structure):
    """Compute the per-frame data shared by all hydrogen-bond searches.

    Args:
        structure: the structure of one frame; it must expose ``volume`` and
            be accepted by ``identify_atom_types``.

    Returns:
        A dict with the keys ``'structure'``, ``'volume'``, ``'atom_types'``,
        ``'hydrogen_owners'`` and ``'donor_to_hydrogens'`` (a
        ``defaultdict(list)`` mapping each owning atom to its hydrogens).
    """
    volume = structure.volume
    atom_types, hydrogen_owners = identify_atom_types(structure)

    # Invert the H -> owner map into owner -> [H, ...].
    hydrogens_by_donor = defaultdict(list)
    for hydrogen, owner in hydrogen_owners.items():
        hydrogens_by_donor[owner].append(hydrogen)

    return {
        'structure': structure,
        'volume': volume,
        'atom_types': atom_types,
        'hydrogen_owners': hydrogen_owners,
        'donor_to_hydrogens': hydrogens_by_donor,
    }
def _resolve_atom_indices(atom_types, type_name):
    """Return the atom indices registered under ``type_name``, or ``None``.

    ``type_name`` may be a top-level key of ``atom_types`` (a dict value is
    flattened over its sub-types) or a sub-type key nested one level down,
    e.g. ``"P-OH"`` inside ``atom_types["phosphate_oxygens"]``.
    """
    if type_name in atom_types:
        group = atom_types[type_name]
        if isinstance(group, dict):
            return [idx for members in group.values() for idx in members]
        return list(group)
    for group in atom_types.values():
        if isinstance(group, dict) and type_name in group:
            return list(group[type_name])
    return None


def identify_h_bonds_optimized(features, hbond_config):
    """Find all hydrogen bonds of one configured type in one frame.

    Args:
        features: dict produced by ``precompute_structural_features``; the
            keys ``'structure'`` and ``'atom_types'`` are read.
        hbond_config: one entry of ``get_hbond_config()``; the keys
            ``donor_type``, ``acceptor_type``, ``h_type``,
            ``distance_threshold`` and ``angle_threshold`` are read.

    Returns:
        A list of ``(donor_idx, h_idx, acceptor_idx)`` tuples.  A bond is
        reported when the acceptor lies within ``distance_threshold`` of the
        hydrogen and the D-H···A angle is at least ``angle_threshold``.
        The list is empty when any of the three atom groups is missing or
        empty.
    """
    h_bonds = []
    struct = features['structure']
    atom_types = features['atom_types']

    # The configured type names may be nested sub-types ("P-O/P=O", "P-OH"),
    # which a plain ``name in atom_types`` lookup never finds.
    acceptors = _resolve_atom_indices(atom_types, hbond_config["acceptor_type"])
    hydrogens = _resolve_atom_indices(atom_types, hbond_config["h_type"])
    donors = _resolve_atom_indices(atom_types, hbond_config["donor_type"])
    if not acceptors or not hydrogens or not donors:
        return h_bonds

    # A repeated acceptor index would report the same bond twice.
    acceptors = list(dict.fromkeys(acceptors))
    donor_set = set(donors)

    # Periodic KD-trees.  cKDTree(boxsize=...) raises for coordinates outside
    # [0, L), so wrap the Cartesian coordinates first.
    # NOTE(review): assumes an orthorhombic, axis-aligned cell -- confirm.
    box = np.asarray(struct.lattice.abc, dtype=float)
    all_coords = np.array([site.coords for site in struct], dtype=float).reshape(-1, 3)
    all_coords = np.mod(all_coords, box)
    all_coords[all_coords >= box] = 0.0  # np.mod can round up to exactly L
    kdtree = cKDTree(all_coords, boxsize=box)
    acceptor_kdtree = cKDTree(all_coords[acceptors], boxsize=box)

    max_distance = hbond_config["distance_threshold"]
    min_angle = hbond_config["angle_threshold"]

    for h_idx in hydrogens:
        h_coords = all_coords[h_idx]

        # Donor: the closest atom of the donor type within 1.2 of the hydrogen.
        donor_candidates = [
            idx for idx in kdtree.query_ball_point(h_coords, r=1.2)
            if idx in donor_set
        ]
        if not donor_candidates:
            continue
        donor_idx = min(
            donor_candidates, key=lambda d_idx: struct.get_distance(h_idx, d_idx)
        )

        # Acceptors within the distance cut-off; positions index into ``acceptors``.
        for position in acceptor_kdtree.query_ball_point(h_coords, r=max_distance):
            a_idx = acceptors[position]
            if a_idx == donor_idx:
                continue  # an atom cannot accept its own hydrogen
            angle = calculate_angle(struct, donor_idx, h_idx, a_idx)
            if angle is not None and angle >= min_angle:
                h_bonds.append((donor_idx, h_idx, a_idx))

    return h_bonds
# 3. 并行处理框架
def process_frame(args):
    """Analyse one trajectory frame (worker function for ``Pool.imap``).

    Args:
        args: tuple ``(i, structure, hbond_configs)`` -- the frame index, the
            frame's structure and the list of hydrogen-bond definitions.

    Returns:
        ``(i, frame_results)`` where ``frame_results`` maps each bond-type
        name to ``(bonds, density)``; ``density`` is the bond count converted
        to mol/L.  If the analysis of the frame raises, the error is printed
        and ``(i, {})`` is returned so that one bad frame does not abort the
        whole run.
    """
    # Unpack outside the try block so that ``i`` is always bound in the
    # handler below; malformed ``args`` fail loudly here instead.
    i, structure, hbond_configs = args
    try:
        features = precompute_structural_features(structure)
        frame_results = {}
        for config in hbond_configs:
            bonds = identify_h_bonds_optimized(features, config)
            # count / (V * N_A) in mol/L with V in Å^3: 1 Å^3 = 1e-27 L.
            # (The factor 1e24 would give mol/cm^3, not the mol/L that the
            # plots and CSV report.)  Assumes pymatgen's volume is in Å^3.
            density = len(bonds) / features['volume'] * 1e27 / 6.022e23
            frame_results[config["name"]] = (bonds, density)
        return i, frame_results
    except Exception as e:
        print(f"Error in process_frame: {str(e)}")
        return i, {}
def calculate_tcf_parallel(h_bond_list, max_dt, n_workers):
    """Compute the hydrogen-bond time correlation function in parallel.

    For every time origin ``t0`` the number of bonds present at ``t0`` that
    are also present at ``t0 + dt`` is counted; the counts are averaged over
    all origins and normalised so that the value at ``dt = 0`` is 1.

    Args:
        h_bond_list: one list of bond tuples per frame.
        max_dt: number of lag frames to evaluate.
        n_workers: size of the process pool.

    Returns:
        Array of length ``max_dt``.  It is all zeros when there are not more
        frames than ``max_dt`` (a warning is printed) or when no time origin
        contains any bond.
    """
    total_frames = len(h_bond_list)
    n_origins = total_frames - max_dt
    if n_origins <= 0:
        print(f"Warning: Not enough frames for TCF calculation (total_frames={total_frames}, max_dt={max_dt})")
        return np.zeros(max_dt)
    if max_dt <= 0:
        return np.zeros(0)

    # Send each worker only the window it needs (re-based to index 0)
    # instead of pickling the whole trajectory for every origin.  A generator
    # keeps the parent from materialising all windows at once.
    tasks = (
        (0, h_bond_list[t0:t0 + max_dt], max_dt) for t0 in range(n_origins)
    )

    tcf = np.zeros(max_dt)
    counts = np.zeros(max_dt, dtype=int)
    with Pool(n_workers) as pool:
        for tcf_part, counts_part in tqdm(pool.imap(process_tcf_task, tasks),
                                          total=n_origins,
                                          desc="Parallel TCF Calculation"):
            tcf += tcf_part
            counts += counts_part

    # Average over the origins that contributed to each lag.
    tcf = tcf / np.where(counts > 0, counts, 1)

    # tcf[0] is the mean number of bonds per origin; dividing by it (rather
    # than by the bond count of frame 0 alone) gives C(0) = 1 and still works
    # when the first frame happens to contain no bonds.
    mean_bonds_at_origin = tcf[0]
    if mean_bonds_at_origin > 0:
        return tcf / mean_bonds_at_origin
    return np.zeros(max_dt)
def process_tcf_task(args):
    """Count surviving bonds for one time origin (worker for ``Pool.imap``).

    Args:
        args: tuple ``(t0, h_bond_list, max_dt)`` -- the origin frame index,
            the per-frame bond lists and the number of lags to evaluate.

    Returns:
        ``(local_tcf, local_counts)``, two arrays of length ``max_dt``.
        ``local_tcf[dt]`` is the number of bonds of frame ``t0`` that are
        also present in frame ``t0 + dt``; ``local_counts[dt]`` is 1 for
        every lag that falls inside the trajectory and 0 otherwise.
    """
    t0, h_bond_list, max_dt = args
    reference = set(h_bond_list[t0])
    survivors = np.zeros(max_dt)
    samples = np.zeros(max_dt, dtype=int)

    # Only lags that stay inside the trajectory are evaluated.
    usable_lags = min(max_dt, len(h_bond_list) - t0)
    for lag in range(usable_lags):
        still_bonded = reference.intersection(set(h_bond_list[t0 + lag]))
        survivors[lag] = len(still_bonded)
        samples[lag] = 1
    return survivors, samples
# 4. 主函数(优化版)
def _fit_lifetime(name, tcf, t_ps, max_time_window):
    """Fit an exponential decay to one TCF and return ``(lifetime, reliability)``."""
    mask = (tcf > 0.01) & (t_ps < max_time_window * 0.8)
    if np.sum(mask) <= 3:
        print(f"Not enough data points for fitting {name}")
        return np.nan, "insufficient data"
    if np.max(tcf[mask]) < 0.1:
        # The remaining TCF values are too small for a meaningful fit.
        print(f"TCF values too small for fitting {name}")
        return np.nan, "TCF values too small"
    try:
        popt, _ = curve_fit(exp_decay, t_ps[mask], tcf[mask], p0=[1.0, 1.0], maxfev=5000)
    except Exception as e:
        print(f"Fit failed for {name}: {str(e)}")
        return np.nan, "fit failed"

    tau = popt[0]
    # A non-positive or non-finite decay time is not a lifetime; do not
    # report it as a successful fit.
    if not np.isfinite(tau) or tau <= 0:
        print(f"Fit failed for {name}: non-physical lifetime τ = {tau}")
        return np.nan, "fit failed"

    if tau > max_time_window * 0.8:
        rel = "unreliable (τ > 80% time window)"
    else:
        rel = "reliable"
    print(f"Fit successful: τ = {tau:.3f} ps ({rel})")
    return tau, rel


def _plot_densities(bond_names, colors, densities, frame_interval, n_frames):
    """Plot the per-frame bond concentrations and save ``hbond_densities.pdf``."""
    fig, ax = plt.subplots()
    time_axis = np.arange(n_frames) * frame_interval
    for name, color in zip(bond_names, colors):
        ax.plot(time_axis, densities[name], color=color, label=name, alpha=0.7)
    ax.set_xlabel('Time (ps)')
    ax.set_ylabel('Concentration (mol/L)')
    ax.legend(frameon=False, fontsize=9)
    fig.tight_layout()
    fig.savefig('hbond_densities.pdf')
    plt.close(fig)  # release the figure once it is on disk
    print("Saved hbond_densities.pdf")


def _plot_tcfs(bond_names, colors, tcf_results, lifetimes, reliability,
               frame_interval, max_time_window):
    """Plot all time correlation functions and save ``hbond_tcf.pdf``."""
    fig, ax = plt.subplots()
    for i, (name, color) in enumerate(zip(bond_names, colors)):
        t_ps = np.arange(len(tcf_results[name])) * frame_interval
        label = f'{name}'
        if name in lifetimes and not np.isnan(lifetimes[name]):
            label += f' (τ={lifetimes[name]:.3f} ps)'
            if reliability.get(name, "").startswith("unreliable"):
                label += '*'
        ax.semilogy(t_ps, tcf_results[name], color=color, label=label)
        # Mark the upper limit of the time window once.
        if i == 0:
            ax.axvline(x=max_time_window, color='gray', linestyle='--', alpha=0.5)
            ax.text(max_time_window, 0.5, 'Time window limit',
                    rotation=90, va='center', ha='right', fontsize=8)
    ax.set_xlabel('Time (ps)')
    ax.set_ylabel('Hydrogen Bond TCF')
    ax.set_xlim(0, max_time_window * 1.1)
    ax.set_ylim(1e-3, 1.1)
    ax.legend(frameon=False, loc='best', fontsize=9)
    # Footnote explaining the '*' marker.
    if any("unreliable" in rel for rel in reliability.values() if isinstance(rel, str)):
        ax.text(0.95, 0.05, "* Unreliable fit (τ > 80% time window)",
                transform=ax.transAxes, ha='right', fontsize=8)
    fig.tight_layout()
    fig.savefig('hbond_tcf.pdf')
    plt.close(fig)
    print("Saved hbond_tcf.pdf")


def _report_results(bond_names, avg_densities, lifetimes, reliability):
    """Write the per-bond-type results to CSV and print a summary table."""
    os.makedirs("HBond_Lifetime_Results", exist_ok=True)
    results_path = os.path.join("HBond_Lifetime_Results", "detailed_results.csv")
    # Bond names contain non-ASCII characters, so fix the encoding instead of
    # relying on the platform default.
    with open(results_path, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["Bond Type", "Avg Density (mol/L)", "Lifetime (ps)", "Reliability"])
        for name in bond_names:
            lifetime = lifetimes.get(name, np.nan)
            rel = reliability.get(name, "N/A")
            writer.writerow([name, f"{avg_densities[name]:.6f}", f"{lifetime:.6f}", rel])
    print(f"Saved detailed results: {results_path}")

    print("\nResults Summary:")
    print("="*70)
    print("{:<20} {:<15} {:<15} {:<20}".format(
        "Bond Type", "Avg Density", "Lifetime (ps)", "Reliability"))
    print("-"*70)
    for name in bond_names:
        lifetime = lifetimes.get(name, np.nan)
        rel = reliability.get(name, "N/A")
        print("{:<20} {:<15.4f} {:<15.3f} {:<20}".format(
            name, avg_densities[name], lifetime, rel
        ))
    print("="*70)


def main_optimized(workers, vasprun_path):
    """Run the complete hydrogen-bond density and lifetime analysis.

    Reads the trajectory from ``vasprun_path``, identifies the configured
    hydrogen bonds in every frame, computes one time correlation function
    per bond type, fits an exponential lifetime to each and writes
    ``hbond_densities.pdf``, ``hbond_tcf.pdf`` and
    ``HBond_Lifetime_Results/detailed_results.csv`` in the working directory.

    Args:
        workers: number of worker processes for the process pools.
        vasprun_path: path to the vasprun.xml file.

    Returns:
        None.  Errors while loading the file or while processing frames are
        printed and end the analysis early.
    """
    print(f"Using {workers} workers for parallel processing")
    print(f"Using vasprun file: {vasprun_path}")

    # Time axis: every ``skip_steps``-th ionic step is read, so one frame
    # spans time_per_step * skip_steps picoseconds.
    # NOTE(review): 0.2e-3 ps per step is hard-coded -- confirm it matches
    # the time step of the simulation.
    time_per_step = 0.2e-3
    skip_steps = 5
    frame_interval = time_per_step * skip_steps
    print(f"Time per frame: {frame_interval:.6f} ps")

    # Load the trajectory and report the atom classification of frame 0.
    print(f"Reading vasprun.xml...")
    try:
        vr = Vasprun(vasprun_path, ionic_step_skip=skip_steps)
        structures = vr.structures
        total_frames = len(structures)
        print(f"Loaded {total_frames} frames")
        test_atom_types, _ = identify_atom_types(structures[0])
        print("Atom type counts in first frame:")
        for key, value in test_atom_types.items():
            if isinstance(value, dict):
                for sub_key, sub_value in value.items():
                    print(f" {key}.{sub_key}: {len(sub_value)}")
            else:
                print(f" {key}: {len(value)}")
    except Exception as e:
        print(f"Error loading vasprun file: {str(e)}")
        return

    total_sim_time = total_frames * frame_interval
    print(f"Total simulation time: {total_sim_time:.3f} ps")

    # TCF window: at most 2 ps and at most half of the trajectory.
    max_time_window = min(2.0, total_sim_time * 0.5)
    max_dt_frames = int(max_time_window / frame_interval)
    print(f"Setting TCF time window: {max_time_window:.3f} ps ({max_dt_frames} frames)")

    hbond_configs = get_hbond_config()
    bond_names = [config["name"] for config in hbond_configs]
    colors = [config["color"] for config in hbond_configs]

    # ---- Per-frame hydrogen-bond identification (parallel) ----
    print(f"Parallel hydrogen bond calculation with {workers} workers...")
    start_time = time.time()
    tasks = [(i, structure, hbond_configs) for i, structure in enumerate(structures)]
    densities = {name: np.zeros(total_frames) for name in bond_names}
    h_bond_lists = {name: [None] * total_frames for name in bond_names}
    try:
        with Pool(workers) as pool:
            results = list(tqdm(pool.imap(process_frame, tasks),
                                total=total_frames,
                                desc="Processing Frames"))
        failed_frames = 0
        for i, frame_results in results:
            if not frame_results:
                failed_frames += 1
            for name in bond_names:
                # Bond types missing from a frame's results count as "no bonds".
                bonds, density = frame_results.get(name, ([], 0))
                densities[name][i] = density
                h_bond_lists[name][i] = bonds
        if failed_frames:
            # Make failed frames visible instead of silently averaging zeros.
            print(f"Warning: {failed_frames} of {total_frames} frames returned no results "
                  "and are counted as having no hydrogen bonds")
        print("\nHydrogen bond counts per frame (average):")
        for name in bond_names:
            avg_count = np.mean([len(bonds) for bonds in h_bond_lists[name]])
            print(f" {name}: {avg_count:.2f}")
    except Exception as e:
        print(f"Error during parallel processing: {str(e)}")
        import traceback
        traceback.print_exc()
        return

    avg_densities = {name: np.mean(densities[name]) for name in bond_names}
    hb_time = time.time() - start_time
    print(f"H-bond calculation completed in {hb_time:.2f} seconds")

    # ---- Time correlation functions and lifetime fits ----
    print("Parallel TCF calculation...")
    tcf_results = {}
    lifetimes = {}
    reliability = {}
    for name in bond_names:
        print(f"\nProcessing bond type: {name}")
        h_bond_list = h_bond_lists[name]

        total_hbonds = sum(len(bonds) for bonds in h_bond_list)
        if total_hbonds == 0:
            print(f"No hydrogen bonds found for {name}, skipping TCF calculation")
            tcf_results[name] = np.zeros(max_dt_frames)
            lifetimes[name] = np.nan
            reliability[name] = "no bonds found"
            continue

        start_time = time.time()
        try:
            tcf = calculate_tcf_parallel(h_bond_list, max_dt_frames, workers)
            tcf_results[name] = tcf
        except Exception as e:
            print(f"Error calculating TCF for {name}: {str(e)}")
            tcf_results[name] = np.zeros(max_dt_frames)
            lifetimes[name] = np.nan
            reliability[name] = "TCF calculation failed"
            continue

        t_ps = np.arange(len(tcf)) * frame_interval
        lifetimes[name], reliability[name] = _fit_lifetime(name, tcf, t_ps, max_time_window)

        tcf_time = time.time() - start_time
        print(f"TCF calculation for {name} completed in {tcf_time:.2f} seconds")

    # ---- Output: figures, CSV, console summary ----
    _plot_densities(bond_names, colors, densities, frame_interval, len(structures))
    _plot_tcfs(bond_names, colors, tcf_results, lifetimes, reliability,
               frame_interval, max_time_window)
    _report_results(bond_names, avg_densities, lifetimes, reliability)
def exp_decay(t, tau, A):
    """Single-exponential decay model ``A * exp(-t / tau)`` used for fitting."""
    decay = np.exp(-t / tau)
    return A * decay
if __name__ == "__main__":
    # Command-line entry point: parse the options, echo them, run the analysis.
    default_workers = max(1, cpu_count() - 4)
    cli = argparse.ArgumentParser(description='Hydrogen Bond Lifetime Analysis')
    cli.add_argument(
        '--workers',
        type=int,
        default=default_workers,
        help='Number of parallel workers (default: CPU cores - 4)',
    )
    cli.add_argument(
        '--vasprun',
        type=str,
        default=DEFAULT_VASPRUN_PATH,
        help='Path to vasprun.xml file',
    )
    options = cli.parse_args()
    n_workers = options.workers
    vasprun_file = options.vasprun
    print(f"Starting analysis with {n_workers} workers")
    print(f"Using vasprun file: {vasprun_file}")
    main_optimized(workers=n_workers, vasprun_path=vasprun_file)
# End of script.