726. Number of Atoms: parsing numbers out of a string (fetch_num)

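This post presents a Python algorithm that parses a chemical formula and counts how many atoms of each element it contains. It uses a stack (the same idea can be written recursively) to handle nested parentheses and the counts that follow them.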


Given a chemical formula (given as a string), return the count of each atom.

An atomic element always starts with an uppercase character, then zero or more lowercase letters, representing the name.

1 or more digits representing the count of that element may follow if the count is greater than 1. If the count is 1, no digits will follow. For example, H2O and H2O2 are possible, but H1O2 is impossible.

Two formulas concatenated together produce another formula. For example, H2O2He3Mg4 is also a formula.

A formula placed in parentheses, and a count (optionally added) is also a formula. For example, (H2O2) and (H2O2)3 are formulas.

Given a formula, output the count of all elements as a string in the following form: the first name (in sorted order), followed by its count (if that count is more than 1), followed by the second name (in sorted order), followed by its count (if that count is more than 1), and so on.
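The output format described above (names in sorted order, with the count omitted when it is 1) can be sketched on its own, separately from the parsing. The helper name `format_counts` is an assumption made for illustration and does not appear in the original problem or solution:

```python
def format_counts(counts):
    # Sort element names lexicographically; print the count only when it exceeds 1.
    return "".join(
        name + (str(n) if n > 1 else "")
        for name, n in sorted(counts.items())
    )

print(format_counts({"Mg": 1, "O": 2, "H": 2}))  # H2MgO2
```

Keeping the formatting apart from the parsing means the parser only has to produce a name-to-count mapping.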

Example 1:

Input: 
formula = "H2O"
Output: "H2O"
Explanation: 
The count of elements are {'H': 2, 'O': 1}.

 

Example 2:

Input: 
formula = "Mg(OH)2"
Output: "H2MgO2"
Explanation: 
The count of elements are {'H': 2, 'Mg': 1, 'O': 2}.

 

Example 3:

Input: 
formula = "K4(ON(SO3)2)2"
Output: "K4N2O14S4"
Explanation: 
The count of elements are {'K': 4, 'N': 2, 'O': 14, 'S': 4}.

 

Note:

  • All atom names consist of lowercase letters, except for the first character which is uppercase.
  • The length of formula will be in the range [1, 1000].
  • formula will only consist of letters, digits, and round parentheses, and is a valid formula as defined in the problem.

-------------------------------------------------------------

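Using a stack or recursion is the natural idea here. The pitfall in this kind of problem is parsing numbers and atom names out of the string. To keep things simple, write the main function first, move the fiddly operations into helper functions, and fill those in afterwards: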

from collections import defaultdict
class Solution:
    def countOfAtoms(self, formula: str) -> str:
        # sta is a stack of counters, one per open group; sta[0] holds the top-level counts
        l,sta,i = len(formula),[defaultdict(int)],0
        def fetch_num(off):
            # read the digits starting at off; returns (count, next index), count is 1 if there are no digits
            i,res = off,""
            while (i < l):
                if ("0"<=formula[i]<="9"):
                    res += formula[i]
                else:
                    break
                i += 1
            return (int(res),i) if res else (1,i)

        def fetch_atom(off):
            # formula[off] is the uppercase letter; append the lowercase letters that follow it
            i,res = off+1,formula[off]
            while (i < l):
                if ("a"<=formula[i]<="z"):
                    res += formula[i]
                else:
                    break
                i += 1
            return res,i

        while (i < l):
            if (formula[i] == '('):
                sta.append(defaultdict(int))
                i += 1
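            elif (formula[i] == ')'):
                # a group just closed: read its multiplier (1 if absent) and scale the group
                num,i = fetch_num(i+1)
                for k in sta[-1].keys():
                    sta[-1][k] *= num
                # merge the closed group into the enclosing one
                if (len(sta) >= 2):
                    top = sta.pop()
                    for k in top.keys():
                        sta[-1][k] += top[k]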
            elif ("A"<=formula[i]<="Z"):
                atom,i = fetch_atom(i)
                num,i = fetch_num(i)
                sta[-1][atom] += num
        res = ""
        for key in sorted(sta[-1].keys()):
            val = sta[-1][key]
            pair = key if val == 1 else "{0}{1}".format(key,val)
            res += pair
        return res

s = Solution()
print(s.countOfAtoms("H2O"))
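The post mentions recursion as the other natural approach but only shows the stack version. A minimal recursive-descent sketch follows, where each call parses one parenthesised group. The names `count_atoms_recursive`, `read_while`, `read_count` and `parse_group` are assumptions chosen for this illustration and are not part of the original solution:

```python
from collections import defaultdict

def count_atoms_recursive(formula: str) -> str:
    n = len(formula)
    pos = 0  # shared cursor into formula

    def read_while(pred):
        # Consume characters while pred holds; return the consumed substring.
        nonlocal pos
        start = pos
        while pos < n and pred(formula[pos]):
            pos += 1
        return formula[start:pos]

    def read_count():
        # A missing count means 1.
        digits = read_while(str.isdigit)
        return int(digits) if digits else 1

    def parse_group():
        # Parse until the matching ')' or the end of the string.
        nonlocal pos
        counts = defaultdict(int)
        while pos < n and formula[pos] != ")":
            if formula[pos] == "(":
                pos += 1                 # skip '('
                inner = parse_group()
                pos += 1                 # skip ')'
                mult = read_count()
                for name, c in inner.items():
                    counts[name] += c * mult
            else:
                name = formula[pos]      # uppercase first letter
                pos += 1
                name += read_while(str.islower)
                counts[name] += read_count()
        return counts

    total = parse_group()
    return "".join(k + (str(v) if v > 1 else "") for k, v in sorted(total.items()))

print(count_atoms_recursive("Mg(OH)2"))  # H2MgO2
```

The call stack plays the role that the explicit `sta` list plays in the stack version: returning from `parse_group` corresponds to popping a counter and merging it into the enclosing one.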

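Later I found that someone in the discussion section scans the string in reverse. The code is shorter, but I would not recommend this approach, because it is easy to get confused: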

import collections

class Solution:
    def countOfAtoms(self, formula):
        dic, coeff, stack, elem, cnt, i = collections.defaultdict(int), 1, [], "", 0, 0  
        for c in formula[::-1]:
            if c.isdigit():
                cnt += int(c) * (10 ** i)
                i += 1
            elif c == ")":
                stack.append(cnt)
                coeff *= cnt
                i = cnt = 0
            elif c == "(":
                coeff //= stack.pop()
                i = cnt = 0
            elif c.isupper():
                elem += c
                dic[elem[::-1]] += (cnt or 1) * coeff
                elem = ""
                i = cnt = 0
            elif c.islower():
                elem += c
        return "".join(k + str(v > 1 and v or "") for k, v in sorted(dic.items()))
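The reverse scan works because, reading right to left, a count is always seen before the atom or group it applies to, so the multiplier of every enclosing group is already known when an atom is reached. The sketch below restates that idea with regex tokens and an explicit multiplier stack, which may be easier to follow. It is an illustrative rewrite rather than the code from the discussion section, and the names `count_atoms_reversed`, `mult_stack` and `pending` are assumptions:

```python
import re
from collections import defaultdict

def count_atoms_reversed(formula: str) -> str:
    # Split into atom names, numbers and parentheses.
    tokens = re.findall(r"[A-Z][a-z]*|\d+|[()]", formula)
    counts = defaultdict(int)
    mult_stack = [1]  # top = product of the multipliers of all enclosing groups
    pending = 1       # count seen just to the right of the current token
    for tok in reversed(tokens):
        if tok.isdigit():
            pending = int(tok)
        elif tok == ")":
            # Entering a group from its right end: its multiplier is pending.
            mult_stack.append(mult_stack[-1] * pending)
            pending = 1
        elif tok == "(":
            # Leaving the group at its left end.
            mult_stack.pop()
            pending = 1
        else:
            counts[tok] += pending * mult_stack[-1]
            pending = 1
    return "".join(k + (str(v) if v > 1 else "") for k, v in sorted(counts.items()))

print(count_atoms_reversed("K4(ON(SO3)2)2"))  # K4N2O14S4
```

Tokenising first removes the digit-by-digit power-of-ten bookkeeping and the reversed atom names, which are the parts of the short version that are easiest to get wrong.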

 

structures = vr.structures print(f"Loaded {len(structures)} frames") # 存储体系据 system_data = { "rdf_results": {}, "peak_infos": {} } # 计算所有RDF分组 for group_name, pairs in rdf_groups.items(): system_data["rdf_results"][group_name] = {} system_data["peak_infos"][group_name] = {} group_y_max_current = 0 for center_sel, target_sel, label, color in pairs: print(f"\nCalculating RDF for: {label}") try: r, rdf, peak_info = calculate_rdf_parallel( structures, center_sel, target_sel, r_max=10.0, exclude_bonds=True, bond_threshold=1.3, workers=workers ) system_data["rdf_results"][group_name][label] = (r, rdf, color) system_data["peak_infos"][group_name][label] = peak_info if len(rdf) > 0: current_max = np.max(rdf) if current_max > group_y_max_current: group_y_max_current = current_max if peak_info["position"] is not None: print(f" Peak for {label}: {peak_info['position']:.3f} Å (g(r) = {peak_info['value']:.2f})") else: print(f" No significant peak found for {label} in 1.5-3.0 Å range") except Exception as e: print(f"Error calculating RDF for {label}: {str(e)}") system_data["rdf_results"][group_name][label] = (np.array([]), np.array([]), color) system_data["peak_infos"][group_name][label] = {"position": None, "value": None} if group_y_max_current > group_y_max[group_name]: group_y_max[group_name] = group_y_max_current all_system_data[system_name] = system_data elapsed = time.time() - start_time print(f"\nCompleted processing for {system_name} in {elapsed:.2f} seconds") except Exception as e: print(f"Error processing {system_name}: {str(e)}") # 为每个分组添加余量 for group_name in group_y_max: group_y_max[group_name] = max(group_y_max[group_name] * 1.15, 3.0) # 确保最小值 # 第二步:生成符合期刊要求的图表 for system_name, system_data in all_system_data.items(): print(f"\nGenerating publication-quality plots for {system_name}") for group_name, group_data in system_data["rdf_results"].items(): fig, ax = plt.subplots(figsize=(8, 6)) # 设置坐标轴范围 xlim = group_x_max.get(group_name, (0, 6.0)) ylim = (0, 
group_y_max[group_name]) for label, (r, rdf, color) in group_data.items(): if len(r) > 0 and len(rdf) > 0: ax.plot(r, rdf, color=color, label=label, linewidth=2.0) ax.set_xlim(xlim) ax.set_ylim(ylim) # 期刊格式标签 ax.set_xlabel('Radial Distance (Å)', fontweight='bold') ax.set_ylabel('g(r)', fontweight='bold') # 添加体系名称到标题 ax.set_title(f"{system_name}: {title_map[group_name]}", fontsize=16, pad=15) # 精简图例 ncol = 3 if group_name == "Same_Species_HBonding" else 1 # 图5使用三列图例 ax.legend(ncol=ncol, loc='best', framealpha=0.8, fontsize=10) # 添加氢键区域标记(除O-O聚集图外) if group_name != "O_O_Aggregation": ax.axvspan(1.5, 2.5, alpha=0.1, color='green', zorder=0) ax.text(1.7, ylim[1]*0.85, 'H-bond Region', fontsize=10) # 添加网格 ax.grid(True, linestyle='--', alpha=0.5) # 保存高分辨率图片 plt.tight_layout() filename = os.path.join("RDF_Plots", f"RDF_{system_name}_{group_name}.tiff") plt.savefig(filename, bbox_inches='tight', dpi=600, format='tiff') print(f"Saved publication plot: {filename}") plt.close() # 保存Origin兼容据 save_origin_data(system_name, system_data) print("\nAll RDF analysis completed successfully!") def save_origin_data(system_name, system_data): """保存Origin兼容格式据""" os.makedirs("Origin_Data", exist_ok=True) system_dir = os.path.join("Origin_Data", system_name) os.makedirs(system_dir, exist_ok=True) # 保存峰值信息 peak_info_path = os.path.join(system_dir, f"Peak_Positions_{system_name}.csv") with open(peak_info_path, 'w', newline='') as csvfile: writer = csv.writer(csvfile) writer.writerow(["Group", "Interaction", "Peak Position (A)", "g(r) Value"]) for group_name, peaks in system_data["peak_infos"].items(): for label, info in peaks.items(): if info["position"] is not None: writer.writerow([group_name, label, f"{info['position']:.3f}", f"{info['value']:.3f}"]) else: writer.writerow([group_name, label, "N/A", "N/A"]) print(f"Saved peak positions: {peak_info_path}") # 保存RDF据 for group_name, group_results in system_data["rdf_results"].items(): group_dir = os.path.join(system_dir, group_name) 
os.makedirs(group_dir, exist_ok=True) for label, (r, rdf, color) in group_results.items(): if len(r) > 0 and len(rdf) > 0: safe_label = label.replace(" ", "_").replace("/", "_").replace("=", "_") safe_label = safe_label.replace("(", "").replace(")", "").replace("$", "") filename = f"RDF_{system_name}_{group_name}_{safe_label}.csv" filepath = os.path.join(group_dir, filename) with open(filepath, 'w', newline='') as csvfile: writer = csv.writer(csvfile) writer.writerow(["Distance (A)", "g(r)"]) for i in range(len(r)): writer.writerow([f"{r[i]:.6f}", f"{rdf[i]:.6f}"]) print(f"Saved Origin data: {filename}") if __name__ == "__main__": # 设置命令行参 parser = argparse.ArgumentParser(description='Calculate RDF for VASP simulations') parser.add_argument('--workers', type=int, default=multiprocessing.cpu_count(), help=f'Number of parallel workers (default: {multiprocessing.cpu_count()})') args = parser.parse_args() print(f"Starting RDF analysis with {args.workers} workers...") main(workers=args.workers) 以上代码实现了湿法磷酸体系中水 磷酸 水合氢离子以及氟之间的RDF计算,其中将O和H分别归类。沿用该代码的框架,修改其中的判别逻辑。首先识别P,在P周围搜寻O原子,如果该O原子距离在1.6埃以内则视为Op,而对于Op在其周围搜寻H原子,如果该H在距离1.3埃以内则视为成键即该H为Hp,通过是否有Hp则可以识别P-OH与P=O/P-O,在这里P=O和P-O不能区分,我们将其记为P-O/P=O。接着体系中全部的O原子在去除Op之后剩下的O,在这些剩余的O周围搜寻整体的H,如果H的距离在1.2埃以内则视为成键,然后依照成键的H量判定:如果H的量为1,则记为-OH羟基(在这里不需要计算羟基部分,只是识别出来有利于逻辑完整性,并不参与RDF计算,也不需要特别标注表明),H的量为2,则记为H2O水(该O也随之记为Ow,对应的两个H也记为Hw),如果H的量为3,则记为水合氢离子(该O随之记为Oh,对应的三个H也记为Hh)。体系中存在质子转移的情况,所以需要每一帧重新识别原子的归属问题,如果H同时处于两个成键识别范围则按照就近原则,离哪个近则归属到哪一个(这里包括磷酸-磷酸,磷酸-水,磷酸-水合氢离子,水-水,水-水合氢离子,水合氢离子-水合氢离子,如果H同时处于某种情况下两个化学成键范围则采用就近原则),在实时重新归属质子的情况下,计算出包含质子转移部分的RDF,在这里,我们将排除自身化学键的阈值先设置为0,不需要只看氢键部分了。直接将-OH视为质子转移或者不完整而直接忽略即可,磷酸上的O需要通过H来进一步识别,所以符合距离的氧可暂时作为Op的候选,等H的识别完成再进行细分P=O/P-O和P-OH