A*(A_star)搜索总结

本文深入解析A*搜索算法,一种结合启发式估价函数的高效搜索方法,特别适用于寻找最优路径问题。通过实例说明如何利用A*算法求解第K最短路径,展示其在减少无效状态转移和加速搜索过程中的优势。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

\(A^*(A star)\)搜索总结

标签:算法——搜索
阅读体验:https://zybuluo.com/Junlier/note/1299772

定义

先复制一则定义

\(A^*\)算法在人工智能中是一种典型的启发式搜索算法
启发中的估价是用估价函数表示的:
f(n)=g(n)+h(n)
其中f(n)是节点n的估价函数
g(n)表示实际状态空间中从初始节点到n节点的实际代价
h(n)是从n到目标节点最佳路径的估计代价。
另外定义h'(n)为n到目标节点最佳路径的实际值。
如果h(n)≤h'(n)(即估价不超过实际代价),那么只要存在从初始状态走到目标状态的最小代价的解,搜索就一定能找到它;
用这种估价函数搜索的算法就叫\(A^*\)算法。

有点繁琐,但也看得过去

通俗来讲

\(A^*\)的核心在于上面所讲到的估价函数
他是干什么用的呢
就是我们在搜索的过程中,保证更优的先搜用的
还是有些繁琐对不对,嗯,我也不大会讲啊(没事我会加油


嘿,认真看下面,我可认真了的啊。。。

如果一个题目要求我们求前K个代价最小的解(只是一个典型,不是所有题目都这样)
假设我们现在有一个状态在\(now\)
已经要记录到答案里面的代价是\(D\)(我喜欢用这个)
我们发现如果爆搜的话状态会是乱的对不对,肯定会使搜索搜到太多
而如果直接把状态按照\(D\)排序的话不能保证答案就会正确(当然,不然就去贪心去)
所以我们引进一个估价函数\(g[状态]\)
当然要求一般是可以预处理出一个状态到答案状态的最优解
回到前面讲到的当前状态\(now\)
如果我们把与\(now\)并列的所有状态按各自的\(D+g[状态]\)排序呢?
既不影响答案的正确性,又可以减少坏状态的转移
(因为题目要求是K个最优状态,而这样待决策状态会有序且跑完K个就可以结束,所以会变快)

好吧,还有点蒙对不对,那我们看例题

例题

洛谷P2901 [USACO08MAR]牛慢跑Cow Jogging
好像其他很多\(oj\)都有,但是\(Bzoj\)是权限。。。

题目简述

要求我们求出从起点n到终点1的最短K条路径的长度
(只能从编号大的点往编号小的点走&边有边权)

很裸对吧?

  1. 预处理估价函数

先跑一遍反向边的\(SPFA\)预处理出每个点到1的最短路作为估价函数

  2. 直接跑\(A^*\)(这里用\(Bfs\)实现)

从n号点开始\(Bfs\),用堆来代替队列(实现上面所讲的排序)
这时候先到1节点的肯定答案更优(也就是路径更短)
原因很简单吧:估价函数保证答案合法,而排序之后答案有序
搜到K个到达1节点的路径就可以结束,快的飞起。。。

放个代码?

好不容易写一次注释

#include<bits/stdc++.h>
#define lst long long
#define ldb double
#define N 1050
#define M 10050
#define qw ljl[i].to
using namespace std;
const lst Inf=1e15;
/**
 * Reads the next decimal integer from stdin.
 *
 * Every non-digit character before the number is skipped; if any of the
 * skipped characters is '-', the result is negated. The character that
 * terminates the digit run is consumed as well.
 *
 * @return the parsed value, or 0 when end of input is reached before any
 *         digit was seen.
 */
int read()
{
    int value = 0;
    bool negative = false;

    // Keep the character as int: storing it in a char makes EOF
    // indistinguishable from data (the skip loop would never terminate) and
    // passing a negative char to isdigit is undefined behaviour.
    int ch = std::getchar();
    while (ch != EOF && !std::isdigit(ch))
    {
        if (ch == '-') negative = true;
        ch = std::getchar();
    }
    while (ch != EOF && std::isdigit(ch))
    {
        value = value * 10 + (ch - '0');
        ch = std::getchar();
    }
    return negative ? -value : value;
}

// Input parameters, read by main() in this order: node count, number of
// edge lines, and the number of path lengths to print.
int n;
int m;
int K;
// Number of answer lines printed so far by A_star_Bfs().
int Done;

// SPFA state: in[v] marks nodes currently sitting in Q; dis[v] is the
// shortest distance between node 1 and v computed by SPFA().
bool in[N];
lst dis[N];
queue<int> Q;

// Adjacency lists: hd[v] is the index of the most recently added edge of v
// (0 means none), cnt is the number of edge slots used so far.
int hd[N];
int cnt;

// One directed edge: target node, index of the next edge in the same list,
// and the edge weight.
struct EDGE
{
    int to;
    int nxt;
    int v;
};
// Edge pool; slot 0 is never used so that 0 can terminate a list.
EDGE ljl[M<<1];
void Add(int p,int q,int o){ljl[++cnt]=(EDGE){q,hd[p],o},hd[p]=cnt;}
// Queue-based shortest-path relaxation started from node 1.
// Only edges leading to a higher-numbered node are followed, i.e. the
// downhill edges are walked in reverse, so afterwards dis[v] holds the
// shortest distance from v down to node 1 (Inf if there is none).
void SPFA()
{
    for (int v = 2; v <= n; ++v) dis[v] = Inf;
    while (!Q.empty()) Q.pop();

    dis[1] = 0;
    in[1] = true;
    Q.push(1);

    while (!Q.empty())
    {
        const int u = Q.front();
        Q.pop();
        in[u] = false;

        for (int e = hd[u]; e; e = ljl[e].nxt)
        {
            const int to = ljl[e].to;
            if (to <= u) continue;              // not a reversed downhill edge

            const long long cand = dis[u] + ljl[e].v;
            if (cand >= dis[to]) continue;      // no improvement

            dis[to] = cand;
            if (in[to]) continue;               // already queued
            in[to] = true;
            Q.push(to);
        }
    }
}

//h[i]=g[i]+f[i]---->ans[i]=D+dis[i]
struct NODE{
    lst D;int id;
    bool operator<(const NODE &X) const
        {
            return D+dis[id]>X.D+dis[X.id];
        }
};priority_queue<NODE> H;
void A_star_Bfs()
{
    while(!H.empty())H.pop();
    H.push((NODE){0,n});
    while(!H.empty())
    {
        NODE temp=H.top();
        int now=temp.id;H.pop();
        if(now==1)
        {
            printf("%lld\n",temp.D);
            if(++Done==K)return;continue;
        }
        for(int i=hd[now];i;i=ljl[i].nxt)
            if(qw<now)H.push((NODE){temp.D+ljl[i].v,qw});
    }while(Done<K)++Done,puts("-1");
}

// Reads n, m, K and then m edges (p, q, weight), stores every edge in both
// directions, precomputes the distances to node 1 and prints the K answers.
int main()
{
    n = read();
    m = read();
    K = read();

    for (int e = 0; e < m; ++e)
    {
        const int p = read();
        const int q = read();
        const int o = read();
        // Both directions are stored; SPFA and the A* search each filter by
        // comparing node numbers to pick the direction they need.
        Add(p, q, o);
        Add(q, p, o);
    }

    SPFA();
    A_star_Bfs();
    return 0;
}
/************

1.A*算法在人工智能中是一种典型的启发式搜索算法
启发中的估价是用估价函数表示的:
f(n)=g(n)+h(n)
其中f(n)是节点n的估价函数
g(n)表示实际状态空间中从初始节点到n节点的实际代价
h(n)是从n到目标节点最佳路径的估计代价。
另外定义h'(n)为n到目标节点最佳路径的实际值。
如果h'(n)≥h(n)则如果存在从初始状态走到目标状态的最小代价的解
那么用该估价函数搜索的算法就叫A*算法。

2.第K最短路的算法
我们设源点为s,终点为t,我们设状态f(i)的g(i)为从s走到节点i的实际
距离,h(i)为从节点i到t的最短距离,从而满足A*算法的要求,
当第K次走到f(n-1)时表示此时的g(n-1)为第K最短路长度。

3.这里是kuai的xzy的。。。别怪我。。。

*************/ 

总结

暂时就讲这么多吧
主要是看到网上没有写的那么通俗的\(A^*\)搜索
就想自己总结一下(其实也不通俗。。。
撤撤撤溜了溜了_______

function [psinew,phinew]=new(h,dt) L = 2; a=1; % 空间域半长 波速 T = 1; % 终止时间 M = round(2*L/h); % 空间网格数 x = (0:h:2*L)'; x = x(1:end-1); % 周期边界处理 % 非线性参数 a1 = 0.5; b1 = -2; c1 = -1; a2 = 0.5; b2 = -3; c2 = -1; vc = 0.9; % 检查关键参数是否合理,避免后续计算出现NaN if vc - a^2 >= 0 error('参数设置错误:vc - a^2 必须小于0以保证sqrt内为正值'); end if c1*c2 - b1*b2 == 0 error('参数设置错误:c1*c2 - b1*b2 不能为0以避免除零'); end % 2-stage Gauss系数 c1g = 0.5 - sqrt(3)/6; c2g = 0.5 + sqrt(3)/6; a11 = 1/4; a12 = 1/4 - sqrt(3)/6; a21 = 1/4 + sqrt(3)/6; a22 = 1/4; b1g = 0.5; b2g = 0.5; %% 构造矩阵 A(五对角循环矩阵) I = speye(M); main_A = 150/180 * ones(M, 1); off1_A = 16/180 * ones(M, 1); off2_A = -1/180 * ones(M, 1); A = spdiags([off2_A, off1_A, main_A, off1_A, off2_A], -2:2, M, M); % 周期性边界处理 A(1,end) = 16/180; A(1,end-1) = -1/180; A(2,end) = -1/180; A(end-1,1) = -1/180; A(end,1) = 16/180; A(end,2) = -1/180; % 矩阵S(六阶差分格式) coef_S = 1/(60*h^2); main_S = -126*coef_S * ones(M, 1); off1_S = 64*coef_S * ones(M, 1); off2_S = -1*coef_S * ones(M, 1); S = spdiags([off2_S, off1_S, main_S, off1_S, off2_S], -2:2, M, M); % 周期性边界处理 S(1, end) = 64*coef_S; S(1, end-1) = -coef_S; S(2, end) = -1*coef_S; S(end-1, 1) = -1*coef_S; S(end, 1) = 64*coef_S; S(end, 2) = -coef_S; B = A \ S; % 紧致差分算子 %% 一维六阶紧致差分矩阵特征值 k = 0:M-1; theta = 2*pi*k/M; lambda_A = (150 + 32*cos(theta) - 2*cos(2*theta)) / 180; lambda_S = (-126 + 128*cos(theta) - 2*cos(2*theta)) * coef_S; lambda_B = lambda_S ./ lambda_A; lambda_B = lambda_B'; %% 初始条件 exact_psi = @(x, t) sqrt(2*a1*(c1-b2)/(b1*b2-c1*c2)) .* sech(sqrt(-a1/(vc-a^2))*(x-vc*t)); exact_phi = @(x, t) sqrt(2*a2*(b1-c2)/(c1*c2-b1*b2)) .* sech(sqrt(-a1/(vc-a^2))*(x-vc*t)); exact_psit = @(x, t) sqrt(-a1/(vc-a^2))*vc*sqrt(2*a1*(c1-b2)/(b1*b2-c1*c2)) .* sech(sqrt(-a1/(vc-a^2))*(x-vc*t)) .* tanh(sqrt(-a1/(vc-a^2))*(x-vc*t)); exact_phit = @(x, t) sqrt(-a1/(vc-a^2))*vc*sqrt(2*a2*(b1-c2)/(c1*c2-b1*b2)) .* sech(sqrt(-a1/(vc-a^2))*(x-vc*t)) .* tanh(sqrt(-a1/(vc-a^2))*(x-vc*t)); psi = exact_psi(x, 0); psi0 = exact_psi(x, 0); phi = exact_phi(x, 0); phi0 
= exact_phi(x, 0); u = exact_psit(x, 0); v = exact_phit(x, 0); % 检查初始条件是否有NaN/Inf check_nan_inf([psi, phi, u, v], '初始条件'); %% 初始质量 M0 M0 = h * sum(psi0.^2 + phi0.^2); % 初始能量 E0 term1 = h * sum(u.^2); term2 = - a^2 * h * sum((B*psi0).*psi0); term3 = a1 * h * sum(psi0.^2); term4 = h * sum(v.^2); term5 = - a^2 * h * sum((B*phi0).*phi0); term6 = a2 * h * sum(phi0.^2); term7 = (c2*b1/(4*c1)) * h * sum(psi0.^4); term8 = (b2/4) * h * sum(phi0.^4); term9 = (c2/2) * h * sum((psi0.*phi0).^2); E0 = (1/2) * ( (c2/c1)*(term1 + term2 + term3) + term4 + term5 + term6 ) + term7 + term8 + term9; Nt = round(T/dt); % 初始化守恒量记录 M_arr = zeros(Nt+1, 1); E_arr = zeros(Nt+1, 1); abs_err_M = zeros(Nt+1, 1); abs_err_E = zeros(Nt+1, 1); M_arr(1) = M0; E_arr(1) = E0; abs_err_M(1) = 0; abs_err_E(1) = 0; %% 初始化 prev_u = u; prev_v = v; prev_psi = psi; prev_phi = phi; prev_psi1 = psi; prev_psi2 = psi; prev_phi1 = phi; prev_phi2 = phi; prev_u1 = u; prev_u2 = u; % 前一时刻的阶段值 prev_v1 = v; prev_v2 = v; %% 时间层循环 for n = 1:Nt psi10 = prev_psi1; psi20 = prev_psi2; phi10 = prev_phi1; phi20 = prev_phi2; % 上一时间层的解 u = prev_u; u1 = prev_u1; u2 = prev_u2; v = prev_v; v1 = prev_v1; v2 = prev_v2; psi = prev_psi; psi1 = prev_psi1; psi2 = prev_psi2; phi = prev_phi; phi1 = prev_phi1; phi2 = prev_phi2; %% 使用预测龙格库塔求高阶预估值 % 迭代初值,第0次迭代的值 u1n_s = u1; u2n_s = u2; psi1n_s = psi1; psi2n_s = psi2; v1n_s = v1; v2n_s = v2; phi1n_s = phi1; phi2n_s = phi2; iterw=0;err=inf; % 迭代循环 while iterw < 20 && err > 1e-14 % 非线性项用S次迭代 p1n_s = a1*psi1n_s + b1*psi1n_s.^3 + c1*psi1n_s.*phi1n_s.^2; p2n_s = a1*psi2n_s + b1*psi2n_s.^3 + c1*psi2n_s.*phi2n_s.^2; q1n_s = a2*phi1n_s + b2*phi1n_s.^3 + c2*phi1n_s.*psi1n_s.^2; q2n_s = a2*phi2n_s + b2*phi2n_s.^3 + c2*phi2n_s.*psi2n_s.^2; % 检查中间变量是否有NaN/Inf check_nan_inf([p1n_s, p2n_s, q1n_s, q2n_s], '阶段变量p,q'); % 列方程组求解s+1次迭代的值 A11 = I; A12 = 0*I; A13 = -a^2*dt*a11*B; A14 = -a^2*dt*a12*B; A21 = 0*I; A22 = I; A23 = -a^2*dt*a21*B; A24 = -a^2*dt*a22*B; A31 = -dt*a11*I; A32 = -dt*a12*I; A33 = I; A34 = 0*I; 
A41 = -dt*a21*I; A42 = -dt*a22*I; A43 = 0*I; A44 = I; A1 = [A11, A12, A13, A14; A21, A22, A23, A24; A31, A32, A33, A34; A41, A42, A43, A44]; bw = [u-dt*a11*p1n_s-dt*a12*p2n_s; u-dt*a21*p1n_s-dt*a22*p2n_s; psi; psi ]; bo = [v-dt*a11*q1n_s-dt*a12*q2n_s; v-dt*a21*q1n_s-dt*a22*q2n_s; phi; phi ]; xw=A1\bw;xo=A1\bo; %new 就是s+1次迭代的值 u1n_new = xw(1:M); u2n_new = xw(M+1:2*M); psi1n_new = xw(2*M+1:3*M); psi2n_new = xw(3*M+1:4*M); v1n_new = xo(1:M); v2n_new = xo(M+1:2*M); phi1n_new = xo(2*M+1:3*M); phi2n_new = xo(3*M+1:4*M); % 计算误差 err_u1n = norm(u1n_new - u1n_s, inf); err_u2n = norm(u2n_new - u2n_s, inf); err_psi1n = norm(psi1n_new - psi1n_s, inf); err_psi2n = norm(psi2n_new - psi2n_s, inf); err_v1n = norm(v1n_new - v1n_s, inf); err_v2n = norm(v2n_new - v2n_s, inf); err_phi1n = norm(phi1n_new - phi1n_s, inf); err_phi2n = norm(phi2n_new - phi2n_s, inf); err = max([err_u1n , err_u2n , err_psi1n , err_psi2n,... err_v1n , err_v2n , err_phi1n , err_phi2n]); % 更新迭代值 u1n_s = u1n_new; u2n_s = u2n_new; psi1n_s = psi1n_new; psi2n_s = psi2n_new; v1n_s = v1n_new; v2n_s = v2n_new; phi1n_s = phi1n_new; phi2n_s = phi2n_new; % 迭代计数增加 iterw = iterw + 1; end % 输出最终迭代值作为预估值 (star) u1_star = u1n_s; u2_star = u2n_s; psi1_star = psi1n_s; psi2_star = psi2n_s; v1_star = v1n_s; v2_star = v2n_s; phi1_star = phi1n_s; phi2_star = phi2n_s; % 得到psi和phi的预估值 %% 然后解线性方程组求解K1(1,2),K2(1,2),K3(1,2),K4(1,2) %(0,lamda1,lamda2,eita1,eita2)五部分 % 系数矩阵如下 A11 = I; A12 = 0*I; A13 = -a^2*dt*a11*B; A14 = -a^2*dt*a12*B; A21 = 0*I; A22 = I; A23 = -a^2*dt*a21*B; A24 = -a^2*dt*a22*B; A31 = -dt*a11*I; A32 = -dt*a12*I; A33 = I; A34 = 0*I; A41 = -dt*a21*I; A42 = -dt*a22*I; A43 = 0*I; A44 = I; A1 = [A11, A12, A13, A14; A21, A22, A23, A24; A31, A32, A33, A34; A41, A42, A43, A44]; % 检查矩阵A1是否合理 if any(isnan(A1(:))) || any(isinf(A1(:))) error('矩阵A1包含NaN或Inf'); end % 右端项 bb0 = [a^2*B*psi; a^2*B*psi; u; u]; bb1 = [-(a1*psi1_star + b1*psi1_star.^3 + c1*psi1_star.*phi1_star.^2); 0*ones(M, 1); 0*ones(M, 1); 0*ones(M, 1)]; bb2 = 
[0*ones(M, 1); -(a1*psi2_star + b1*psi2_star.^3 + c1*psi2_star.*phi2_star.^2); 0*ones(M, 1); 0*ones(M, 1)]; bb3 = [psi1_star; 0*ones(M, 1); 0*ones(M, 1); 0*ones(M, 1)]; bb4 = [0*ones(M, 1); psi2_star; 0*ones(M, 1); 0*ones(M, 1)]; % 检查右端项是否有NaN/Inf check_nan_inf([bb0; bb1; bb2; bb3; bb4], '线性方程组右端项'); % 得到k1,k2 的五个部分,添加正则化以避免奇异 x0 = A1 \ bb0; x1 = A1 \ bb1; x2 = A1 \ bb2; x3 = A1 \ bb3; x4 = A1 \ bb4; % 右端项 f 和 g 函数分别是 psi 和 phi yb0 = [a^2*B*phi; a^2*B*phi; v; v]; yb1 = [-(a2*phi1_star + b2*phi1_star.^3 + c2*phi1_star.*psi1_star.^2); 0*ones(M, 1); 0*ones(M, 1); 0*ones(M, 1)]; yb2 = [0*ones(M, 1); -(a2*phi2_star + b2*phi2_star.^3 + c2*phi2_star.*psi2_star.^2); 0*ones(M, 1); 0*ones(M, 1)]; yb3 = [phi1_star; 0*ones(M, 1); 0*ones(M, 1); 0*ones(M, 1)]; yb4 = [0*ones(M, 1); phi2_star; 0*ones(M, 1); 0*ones(M, 1)]; % 得到k3(1,2),k4(1,2) 的五个部分 y0 = A1 \ yb0; y1 = A1 \ yb1; y2 = A1 \ yb2; y3 = A1 \ yb3; y4 = A1 \ yb4; k1_1_0 = x0(1:M); k1_1_lamda1 = x1(1:M); k1_1_lamda2 = x2(1:M); k1_1_eita1 = x3(1:M); k1_1_eita2 = x4(1:M); k1_2_0 = x0(M+1:2*M); k1_2_lamda1 = x1(M+1:2*M); k1_2_lamda2 = x2(M+1:2*M); k1_2_eita1 = x3(M+1:2*M); k1_2_eita2 = x4(M+1:2*M); k2_1_0 = x0(2*M+1:3*M); k2_1_lamda1 = x1(2*M+1:3*M); k2_1_lamda2 = x2(2*M+1:3*M); k2_1_eita1 = x3(2*M+1:3*M); k2_1_eita2 = x4(2*M+1:3*M); k2_2_0 = x0(3*M+1:4*M); k2_2_lamda1 = x1(3*M+1:4*M); k2_2_lamda2 = x2(3*M+1:4*M); k2_2_eita1 = x3(3*M+1:4*M); k2_2_eita2 = x4(3*M+1:4*M); k3_1_0 = y0(1:M); k3_1_lamda1 = y1(1:M); k3_1_lamda2 = y2(1:M); k3_1_eita1 = y3(1:M); k3_1_eita2 = y4(1:M); k3_2_0 = y0(M+1:2*M); k3_2_lamda1 = y1(M+1:2*M); k3_2_lamda2 = y2(M+1:2*M); k3_2_eita1 = y3(M+1:2*M); k3_2_eita2 = y4(M+1:2*M); k4_1_0 = y0(2*M+1:3*M); k4_1_lamda1 = y1(2*M+1:3*M); k4_1_lamda2 = y2(2*M+1:3*M); k4_1_eita1 = y3(2*M+1:3*M); k4_1_eita2 = y4(2*M+1:3*M); k4_2_0 = y0(3*M+1:4*M); k4_2_lamda1 = y1(3*M+1:4*M); k4_2_lamda2 = y2(3*M+1:4*M); k4_2_eita1 = y3(3*M+1:4*M); k4_2_eita2 = y4(3*M+1:4*M); %% 解线性方程得到了k1,2,3,4,的五个部分,下面求解优化问题的到lamda(1,2) 和 
eita(1,2) p1n=a1*psi1_star+b1*psi1_star.^3+c1*psi1_star.*phi1_star.^2; p2n=a1*psi2_star+b1*psi2_star.^3+c1*psi2_star.*phi2_star.^2; q1n=a2*phi1_star+b2*phi1_star.^3+c2*phi1_star.*psi1_star.^2; q2n=a2*phi2_star+b2*phi2_star.^3+c2*phi2_star.*psi2_star.^2; G = h*sum((b2/4)*phi.^4+(c2/2)*psi.^2.*phi.^2+(a2/2)*phi.^2+((c2*b1)/(4*c1)*psi.^4)+((c2*a1)/(2*c1)*psi.^2));%G(n) % 参数设置 delta_fd = 1e-10; % 有限差分步长 max_iter_bfgs = 100; % 最大BFGS迭代次数 tol_bfgs = 1e-7; % 梯度收敛容差 armijo_c = 1e-4; % Armijo条件常数 tau = 0.5; % 步长缩减因子 % 初始参数估计 (6维向量: [lamda1, lamda2, eita1, eita2, kesi1, kesi2]) x0 = [1; 1; 0; 0; 1; 1]; % 计算初始目标函数值 F_old = computeEquations(... x0, k2_1_0, k2_1_lamda1, k2_1_lamda2, k2_1_eita1, k2_1_eita2, ... k2_2_0, k2_2_lamda1, k2_2_lamda2, k2_2_eita1, k2_2_eita2, ... k4_1_0, k4_1_lamda1, k4_1_lamda2, k4_1_eita1, k4_1_eita2, ... k4_2_0, k4_2_lamda1, k4_2_lamda2, k4_2_eita1, k4_2_eita2, ... dt, b1g, b2g, psi, phi, h, psi0, phi0, p1n, q1n, psi10, phi10,... p2n, q2n, psi20, phi20, G); S_old = 0.5 * (F_old' * F_old); % 检查初始点是否有NaN/Inf check_nan_inf(F_old, '初始F值'); % 使用中心差分法计算初始梯度 g_old = zeros(6,1); for idx = 1:6 x_temp1 = x0; x_temp1(idx) = x0(idx) + delta_fd; F_temp1 = computeEquations(... x_temp1, k2_1_0, k2_1_lamda1, k2_1_lamda2, k2_1_eita1, k2_1_eita2, ... k2_2_0, k2_2_lamda1, k2_2_lamda2, k2_2_eita1, k2_2_eita2, ... k4_1_0, k4_1_lamda1, k4_1_lamda2, k4_1_eita1, k4_1_eita2, ... k4_2_0, k4_2_lamda1, k4_2_lamda2, k4_2_eita1, k4_2_eita2, ... dt, b1g, b2g, psi, phi, h, psi0, phi0, p1n, q1n, psi10, phi10,... p2n, q2n, psi20, phi20, G); S_temp1 = 0.5*(F_temp1'*F_temp1); x_temp2 = x0; x_temp2(idx) = x0(idx) - delta_fd; F_temp2 = computeEquations(... x_temp2, k2_1_0, k2_1_lamda1, k2_1_lamda2, k2_1_eita1, k2_1_eita2, ... k2_2_0, k2_2_lamda1, k2_2_lamda2, k2_2_eita1, k2_2_eita2, ... k4_1_0, k4_1_lamda1, k4_1_lamda2, k4_1_eita1, k4_1_eita2, ... k4_2_0, k4_2_lamda1, k4_2_lamda2, k4_2_eita1, k4_2_eita2, ... dt, b1g, b2g, psi, phi, h, psi0, phi0, p1n, q1n, psi10, phi10,... 
p2n, q2n, psi20, phi20, G); S_temp2 = 0.5*(F_temp2'*F_temp2); g_old(idx) = (S_temp1 - S_temp2)/(2*delta_fd); % 中心差分 end % 主BFGS循环 Bk = eye(6); % Hessian逆矩阵的初始近似 converged = false; fprintf('开始BFGS优化...\n'); for iter_bfgs = 1:max_iter_bfgs % 收敛检查 (检查梯度范数) if norm(g_old) < tol_bfgs converged = true; fprintf('BFGS在迭代%d收敛: |grad| = %.4e\n', iter_bfgs, norm(g_old)); break; end % 计算搜索方向 (dk = -Bk * g) dk = -Bk * g_old; dir_deriv = g_old' * dk; % 方向导数 % 检查是否下降方向 (应<0) if dir_deriv >= 0 warning('非下降方向 (dir_deriv = %.4e), 重置为负梯度', dir_deriv); dk = -g_old; dir_deriv = g_old' * dk; end % Armijo线搜索 (确保满足充分下降条件) alpha = 1.0; armijo_iter = 0; max_armijo_iter = 100; while armijo_iter < max_armijo_iter x_new = x0 + alpha * dk; % 计算新点目标函数 F_new = computeEquations(x_new, k2_1_0, k2_1_lamda1, k2_1_lamda2, k2_1_eita1, k2_1_eita2, ... k2_2_0, k2_2_lamda1, k2_2_lamda2, k2_2_eita1, k2_2_eita2, ... k4_1_0, k4_1_lamda1, k4_1_lamda2, k4_1_eita1, k4_1_eita2, ... k4_2_0, k4_2_lamda1, k4_2_lamda2, k4_2_eita1, k4_2_eita2, ... dt, b1g, b2g, psi, phi, h, psi0, phi0, p1n, q1n, psi10, phi10, p2n, q2n, psi20, phi20, ... G); S_new = 0.5 * (F_new' * F_new); % Armijo条件: S(x + αd) ≤ S(x) + c·α·∇S(x)·d if S_new <= S_old + armijo_c * alpha * dir_deriv break; end % 更新步长 alpha = tau * alpha; armijo_iter = armijo_iter + 1; end if armijo_iter == max_armijo_iter warning('Armijo搜索未在%d次迭代内收敛, alpha=%.2e', max_armijo_iter, alpha); end %% 计算新梯度 g_new = zeros(6,1); for idx = 1:6 x_temp1 = x_new; x_temp1(idx) = x_new(idx) + delta_fd; F_temp1 = computeEquations(... x_temp1, k2_1_0, k2_1_lamda1, k2_1_lamda2, k2_1_eita1, k2_1_eita2, ... k2_2_0, k2_2_lamda1, k2_2_lamda2, k2_2_eita1, k2_2_eita2, ... k4_1_0, k4_1_lamda1, k4_1_lamda2, k4_1_eita1, k4_1_eita2, ... k4_2_0, k4_2_lamda1, k4_2_lamda2, k4_2_eita1, k4_2_eita2, ... dt, b1g, b2g, psi, phi, h, psi0, phi0, p1n, q1n, psi10, phi10,... 
p2n, q2n, psi20, phi20, G); S_temp1 = 0.5*(F_temp1'*F_temp1); x_temp2 = x_new; x_temp2(idx) = x_new(idx) - delta_fd; F_temp2 = computeEquations(... x_temp2, k2_1_0, k2_1_lamda1, k2_1_lamda2, k2_1_eita1, k2_1_eita2, ... k2_2_0, k2_2_lamda1, k2_2_lamda2, k2_2_eita1, k2_2_eita2, ... k4_1_0, k4_1_lamda1, k4_1_lamda2, k4_1_eita1, k4_1_eita2, ... k4_2_0, k4_2_lamda1, k4_2_lamda2, k4_2_eita1, k4_2_eita2, ... dt, b1g, b2g, psi, phi, h, psi0, phi0, p1n, q1n, psi10, phi10,... p2n, q2n, psi20, phi20, G); S_temp2 = 0.5*(F_temp2'*F_temp2); g_new(idx) = (S_temp1 - S_temp2)/(2*delta_fd); % 中心差分 end %% BFGS更新 (更新逆海森矩阵近似Bk) sk = x_new - x0; yk = g_new - g_old; % 曲率条件检查 if yk' * sk > 0 % 确保正曲率 rho_k = 1 / (yk' * sk); % (BFGS更新) Bk = (eye(6) - rho_k * sk * yk') * Bk * (eye(6) - rho_k * yk * sk') + rho_k * (sk * sk'); else warning('曲率条件失败 (yᵀs=%.4e), 重置Hessian近似', yk' * sk); Bk = eye(6); end % 更新迭代点 x0 = x_new; F_old = F_new; S_old = S_new; g_old = g_new; % 打印迭代进度 if mod(iter_bfgs, 5) == 0 || iter_bfgs == 1 fprintf('BFGS迭代 %3d: S=%.4e, |grad|=%.4e, alpha=%.2e\n', ... iter_bfgs, S_old, norm(g_old), alpha); end end if ~converged warning('BFGS未收敛于%d次迭代, 最终|grad|=%.4e', max_iter_bfgs, norm(g_old)); end %% 提取优化参数 lamda1 = x0(1); lamda2 = x0(2); eita1 = x0(3); eita2 = x0(4); kesi1 = x0(5); kesi2 = x0(6); % 在控制台输出当前时间步的参数值 fprintf('时间步 %d/%d: lamda1=%.10f, lamda2=%.10f, eita1=%.10f, eita2=%.10f, kesi1=%.10f, kesi2=%.10f\n', ... 
n, Nt, lamda1, lamda2, eita1, eita2, kesi1, kesi2); %% 然后组装得到K1(1,2),K2(1,2),K3(1,2),K4(1,2) k1_1 = k1_1_0 + lamda1.*k1_1_lamda1 + lamda2.*k1_1_lamda2 + eita1.*k1_1_eita1 + eita2.*k1_1_eita2; k1_2 = k1_2_0 + lamda1.*k1_2_lamda1 + lamda2.*k1_2_lamda2 + eita1.*k1_2_eita1 + eita2.*k1_2_eita2; k2_1 = k2_1_0 + lamda1.*k2_1_lamda1 + lamda2.*k2_1_lamda2 + eita1.*k2_1_eita1 + eita2.*k2_1_eita2; k2_2 = k2_2_0 + lamda1.*k2_2_lamda1 + lamda2.*k2_2_lamda2 + eita1.*k2_2_eita1 + eita2.*k2_2_eita2; k3_1 = k3_1_0 + lamda1.*k3_1_lamda1 + lamda2.*k3_1_lamda2 + eita1.*k3_1_eita1 + eita2.*k3_1_eita2; k3_2 = k3_2_0 + lamda1.*k3_2_lamda1 + lamda2.*k3_2_lamda2 + eita1.*k3_2_eita1 + eita2.*k3_2_eita2; k4_1 = k4_1_0 + lamda1.*k4_1_lamda1 + lamda2.*k4_1_lamda2 + eita1.*k4_1_eita1 + eita2.*k4_1_eita2; k4_2 = k4_2_0 + lamda1.*k4_2_lamda1 + lamda2.*k4_2_lamda2 + eita1.*k4_2_eita1 + eita2.*k4_2_eita2; % 检查k值是否有NaN/Inf check_nan_inf([k1_1, k1_2, k2_1, k2_2, k3_1, k3_2, k4_1, k4_2], 'k值'); %% n+1层的解 unew = u + dt*b1g*k1_1 + dt*b2g*k1_2; psinew = psi + dt*b1g*k2_1 + dt*b2g*k2_2; vnew = v + dt*b1g*k3_1 + dt*b2g*k3_2; phinew = phi + dt*b1g*k4_1 + dt*b2g*k4_2; % 检查新解是否有NaN/Inf check_nan_inf([unew, psinew, vnew, phinew], '新时间步解'); % 计算当前时间步的质量和能量 M_current = h * sum(psinew.^2 + phinew.^2); term1 = h * sum(vnew.^2); term2 = -a^2 * h * sum((B*psinew).*psinew); term3 = a1 * h * sum(psinew.^2); term4 = h * sum(unew.^2); term5 = -a^2 * h * sum((B*phinew).*phinew); term6 = a2 * h * sum(phinew.^2); term7 = (c2*b1/(4*c1)) * h * sum(psinew.^4); term8 = (b2/4) * h * sum(phinew.^4); term9 = (c2/2) * h * sum((psinew.*phinew).^2); E_current = (1/2)*((c2/c1)*(term1 + term2 + term3) + term4 + term5 + term6) + term7 + term8 + term9; % 记录守恒量及误差 M_arr(n+1) = M_current; E_arr(n+1) = E_current; abs_err_M(n+1) = abs(M_current - M0); abs_err_E(n+1) = abs(E_current - E0); % 更新历史变量 prev_u = unew; prev_psi = psinew; prev_v = vnew; prev_phi = phinew; prev_psi1 = prev_psi + dt*(a11*k2_1 + a12*k2_2); prev_psi2 = prev_psi + 
dt*(a21*k2_1 + a22*k2_2); prev_phi1 = prev_phi + dt*(a11*k4_1 + a12*k4_2); prev_phi2 = prev_phi + dt*(a21*k4_1 + a22*k4_2); prev_u1 = prev_u + dt*(a11*k1_1 + a12*k1_2); prev_u2 = prev_u + dt*(a21*k1_1 + a22*k1_2); prev_v1 = prev_v + dt*(a11*k3_1 + a12*k3_2); prev_v2 = prev_v + dt*(a21*k3_1 + a22*k3_2); % 每10步显示进度 if mod(n, 10) == 0 fprintf('完成时间步:%d/%d\n', n, Nt); end end %% 绘制守恒量误差图 figure; semilogy((0:Nt)*dt, abs_err_M, 'b', 'LineWidth', 1.5); hold on; semilogy((0:Nt)*dt, abs_err_E, 'r', 'LineWidth', 1.5); xlabel('Time'); ylabel('Absolute Error (log scale)'); title('Conservation Laws (Logarithmic Scale)'); legend('Mass Error', 'Energy Error'); grid on; figure; plot(x, [psi0, psinew]); title('初始和最终波形对比'); legend('t=0','t=T'); end %% 六个方程式组成的方程组 function F = computeEquations(x,k2_1_0, k2_1_lamda1, k2_1_lamda2, k2_1_eita1, k2_1_eita2, ... k2_2_0, k2_2_lamda1, k2_2_lamda2, k2_2_eita1, k2_2_eita2, ... k4_1_0, k4_1_lamda1, k4_1_lamda2, k4_1_eita1, k4_1_eita2, ... k4_2_0, k4_2_lamda1, k4_2_lamda2, k4_2_eita1, k4_2_eita2, ... dt, b1g, b2g, psi, phi, h, f0, g0, p1n, q1n, f10, g10, p2n, q2n, f20, g20, ... 
G) % 非线性参数 a1 = 0.5; b1 = -2; c1 = -1; a2 = 0.5; b2 = -3; c2 = -1; % 解包参数 lamda1 = x(1); lamda2 = x(2); eita1 = x(3); eita2 = x(4); kesi1 = x(5); kesi2 = x(6); m1 = dt*b1g*k2_1_lamda1+ dt*b2g*k2_2_lamda1; m2 = dt*b1g*k2_1_lamda2+ dt*b2g*k2_2_lamda2; m3 = dt*b1g*k2_1_eita1 + dt*b2g*k2_2_eita1 ; m4 = dt*b1g*k2_1_eita2 + dt*b2g*k2_2_eita2 ; n1 = dt*b1g*k4_1_lamda1+ dt*b2g*k4_2_lamda1; n2 = dt*b1g*k4_1_lamda2+ dt*b2g*k4_2_lamda2; n3 = dt*b1g*k4_1_eita1 + dt*b2g*k4_2_eita1 ; n4 = dt*b1g*k4_1_eita2 + dt*b2g*k4_2_eita2 ; % 计算中间变量 k2_1 = k2_1_0 + lamda1*k2_1_lamda1 + lamda2*k2_1_lamda2 + eita1*k2_1_eita1 + eita2*k2_1_eita2; k2_2 = k2_2_0 + lamda1*k2_2_lamda1 + lamda2*k2_2_lamda2 + eita1*k2_2_eita1 + eita2*k2_2_eita2; k4_1 = k4_1_0 + lamda1*k4_1_lamda1 + lamda2*k4_1_lamda2 + eita1*k4_1_eita1 + eita2*k4_1_eita2; k4_2 = k4_2_0 + lamda1*k4_2_lamda1 + lamda2*k4_2_lamda2 + eita1*k4_2_eita1 + eita2*k4_2_eita2; fnw = psi + dt*b1g*k2_1 + dt*b2g*k2_2; gnw = phi + dt*b1g*k4_1 + dt*b2g*k4_2; Gnw = h*sum((b2/4)*gnw.^4+(c2/2)*(fnw.*gnw).^2+(a2/2)*gnw.^2+((c2*b1)/(4*c1)*fnw.^4)+((c2*a1)/(2*c1)*fnw.^2));%G(n+1) Gnwlamda1=h*sum(b2*gnw.^3.*n1+c2*fnw.*gnw.*(m1.*gnw+n1.*fnw)+a2*gnw.*n1+((c2*b1)/c1*fnw.^3.*m1)+((c2*a1)/c1*fnw.*m1)); Gnwlamda2=h*sum(b2*gnw.^3.*n2+c2*fnw.*gnw.*(m2.*gnw+n2.*fnw)+a2*gnw.*n2+((c2*b1)/c1*fnw.^3.*m2)+((c2*a1)/c1*fnw.*m2)); Gnweita1 =h*sum(b2*gnw.^3.*n3+c2*fnw.*gnw.*(m3.*gnw+n3.*fnw)+a2*gnw.*n3+((c2*b1)/c1*fnw.^3.*m3)+((c2*a1)/c1*fnw.*m3)); Gnweita2 =h*sum(b2*gnw.^3.*n4+c2*fnw.*gnw.*(m4.*gnw+n4.*fnw)+a2*gnw.*n4+((c2*b1)/c1*fnw.^3.*m4)+((c2*a1)/c1*fnw.*m4)); % 计算方程值 Llamda1 = 2*lamda1 - 2 + kesi1*2*h*sum(fnw.*m1 + gnw.*n1) + kesi2*Gnwlamda1 ... - kesi2*dt*b1g*h*((sum(p1n.*k2_1) + sum(q1n.*k4_1))+lamda1*(sum(p1n.*k2_1_lamda1) + sum(q1n.*k4_1_lamda1)) + (c2/c1)*eita1*(sum(f10.*k2_1_lamda1) + sum(g10.*k4_1_lamda1)))... 
- kesi2*dt*b2g*h*(lamda2*(sum(p2n.*k2_2_lamda1) + sum(q2n.*k4_2_lamda1)) + (c2/c1)*eita2*(sum(f20.*k2_2_lamda1) + sum(g20.*k4_2_lamda1))); Llamda2 = 2*lamda2 - 2 + kesi1*2*h*sum(fnw.*m2 + gnw.*n2) + kesi2*Gnwlamda2 ... - kesi2*dt*b1g*h*(lamda1*(sum(p1n.*k2_1_lamda2) + sum(q1n.*k4_1_lamda2)) + (c2/c1)*eita1*(sum(f10.*k2_1_lamda2) + sum(g10.*k4_1_lamda2)))... - kesi2*dt*b2g*h*((sum(p2n.*k2_2) + sum(q2n.*k4_2))+lamda2*(sum(p2n.*k2_2_lamda2) + sum(q2n.*k4_2_lamda2)) + (c2/c1)*eita2*(sum(f20.*k2_2_lamda2) + sum(g20.*k4_2_lamda2))); Leita1 = 2*eita1 + 2*kesi1*h*sum(fnw.*m3 + gnw.*n3) + kesi2*Gnweita1 ... - kesi2*dt*b1g*h*(lamda1*(sum(p1n.*k2_1_eita1) + sum(q1n.*k4_1_eita1)) + (c2/c1)*h*(sum(f10.*k2_1)+sum(g10.*k4_1)) + (c2/c1)*eita1*(sum(f10.*k2_1_eita1) + sum(g10.*k4_1_eita1)))... - kesi2*dt*b2g*h*(lamda2*(sum(p2n.*k2_2_eita1) + sum(q2n.*k4_2_eita1)) + (c2/c1)*eita2*(sum(f20.*k2_2_eita1) + sum(g20.*k4_2_eita1))); Leita2 = 2*eita2 + 2*kesi1*h*sum(fnw.*m4 + gnw.*n4) + kesi2*Gnweita2 ... - kesi2*dt*b1g*h*(lamda1*(sum(p1n.*k2_1_eita2) + sum(q1n.*k4_1_eita2)) + (c2/c1) * eita1*(sum(f10.*k2_1_eita2) + sum(g10.*k4_1_eita2)))... - kesi2*dt*b2g*h*(lamda2*(sum(p2n.*k2_2_eita2) + sum(q2n.*k4_2_eita2)) + (c2/c1)*(sum(f20.*k2_2)+sum(g20.*k4_2))+ (c2/c1)*eita2*(sum(f20.*k2_2_eita2) + sum(g20.*k4_2_eita2))); Lkesi1 = h*sum(gnw.^2 + fnw.^2) - h*sum(f0.^2 + g0.^2); Lkesi2 = Gnw - G ... - dt*b1g*(lamda1*h*(sum(p1n.*k2_1) + sum(q1n.*k4_1)) + (c2/c1)*eita1*h*sum(f10.*k2_1) + eita1*sum(g10.*k4_1)) ... - dt*b2g*(lamda2*h*(sum(p2n.*k2_2) + sum(q2n.*k4_2)) + (c2/c1)*eita2*h*sum(f20.*k2_2) + eita2*sum(g20.*k4_2)); F = [Llamda1; Llamda2; Leita1; Leita2; Lkesi1; Lkesi2]; end %% 辅助函数:检查变量是否包含NaN或Inf function check_nan_inf(var, varname) if any(isnan(var(:))) || any(isinf(var(:))) warning([varname, ' 包含NaN或Inf值']); end end 我这里面的拟牛顿法是否正确,
最新发布
08-15
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值