偶尔看到有关float/double计算精度的问题,做个backup

本文探讨了Java中使用非整数类型时常见的问题,包括浮点数的特殊值、舍入误差及其比较方法,以及BigDecimal的正确使用。

http://www.ibm.com/developerworks/cn/java/j-jtp0114/index.html

 

许多程序员在其整个开发生涯中都不曾使用定点或浮点数,可能的例外是,偶尔在计时测试或基准测试程序中会用到。Java语言和类库支持两类非整数类型 ― IEEE 754 浮点( floatdouble ,包装类(wrapper class)为 FloatDouble ),以及任意精度的小数( java.math.BigDecimal )。在本月的 Java 理论和实践中,Brian Goetz 探讨了在 Java 程序中使用非整数类型时一些常碰到的陷阱和“gotcha”。请在本文的 论坛上提出您对本文的想法,以飨笔者和其他读者。(您也可以单击本文顶部或底部的讨论来访问论坛)。

虽然几乎每种处理器和编程语言都支持浮点运算,但大多数程序员很少注意它。这容易理解 ― 我们中大多数很少需要使用非整数类型。除了科学计算和偶尔的计时测试或基准测试程序,其它情况下几乎都用不着它。同样,大多数开发人员也容易忽略 java.math.BigDecimal 所提供的任意精度的小数 ― 大多数应用程序不使用它们。然而,在以整数为主的程序中有时确实会出人意料地需要表示非整型数据。例如,JDBC 使用 BigDecimal 作为 SQL DECIMAL 列的首选互换格式。

IEEE 浮点

Java 语言支持两种基本的浮点类型: floatdouble ,以及与它们对应的包装类 FloatDouble 。它们都依据 IEEE 754 标准,该标准为 32 位浮点和 64 位双精度浮点二进制小数定义了二进制标准。

IEEE 754 用科学记数法以底数为 2 的小数来表示浮点数。IEEE 浮点数用 1 位表示数字的符号,用 8 位来表示指数,用 23 位来表示尾数,即小数部分。作为有符号整数的指数可以有正负之分。小数部分用二进制(底数 2)小数来表示,这意味着最高位对应着值 ?(2 -1),第二位对应着 ?(2 -2),依此类推。对于双精度浮点数,用 11 位表示指数,52 位表示尾数。IEEE 浮点值的格式如图 1 所示。


图 1. IEEE 754 浮点数的格式
图 1. IEEE 754 浮点数的格式

因为用科学记数法可以有多种方式来表示给定数字,所以要规范化浮点数,以便用底数为 2 并且小数点左边为 1 的小数来表示,按照需要调节指数就可以得到所需的数字。所以,例如,数 1.25 可以表示为尾数为 1.01,指数为 0: (-1) 0*1.01 2*2 0

数 10.0 可以表示为尾数为 1.01,指数为 3: (-1) 0*1.01 2*2 3

特殊数字

除了编码所允许的值的标准范围(对于 float ,从 1.4e-45 到 3.4028235e+38),还有一些表示无穷大、负无穷大、 -0 和 NaN(它代表“不是一个数字”)的特殊值。这些值的存在是为了在出现错误条件(譬如算术溢出,给负数开平方根,除以 0 等)下,可以用浮点值集合中的数字来表示所产生的结果。

这些特殊的数字有一些不寻常的特征。例如, 0-0 是不同值,但在比较它们是否相等时,被认为是相等的。用一个非零数去除以无穷大的数,结果等于 0 。特殊数字 NaN 是无序的;使用 ==<> 运算符将 NaN 与其它浮点值比较时,结果为 false 。如果 f 为 NaN,则即使 (f == f) 也会得到 false 。如果想将浮点值与 NaN 进行比较,则使用 Float.isNaN() 方法。表 1 显示了无穷大和 NaN 的一些属性。

表 1. 特殊浮点值的属性

表达式结果
Math.sqrt(-1.0)-> NaN
0.0 / 0.0-> NaN
1.0 / 0.0-> 无穷大
-1.0 / 0.0-> 负无穷大
NaN + 1.0-> NaN
无穷大 + 1.0-> 无穷大
无穷大 + 无穷大-> 无穷大
NaN > 1.0-> false
NaN == 1.0-> false
NaN < 1.0-> false
NaN == NaN-> false
0.0 == -0.01-> true

基本浮点类型和包装类浮点有不同的比较行为

使事情更糟的是,在基本 float 类型和包装类 Float 之间,用于比较 NaN 和 -0 的规则是不同的。对于 float 值,比较两个 NaN 值是否相等将会得到 false ,而使用 Float.equals() 来比较两个 NaN Float 对象会得到 true 。造成这种现象的原因是,如果不这样的话,就不可能将 NaN Float 对象用作 HashMap 中的键。类似的,虽然 0-0 在表示为浮点值时,被认为是相等的,但使用 Float.compareTo() 来比较作为 Float 对象的 0-0 时,会显示 -0 小于 0




回页首

浮点中的危险

由于无穷大、NaN 和 0 的特殊行为,当应用浮点数时,可能看似无害的转换和优化实际上是不正确的。例如,虽然好象 0.0-f 很明显等于 -f ,但当 f0 时,这是不正确的。还有其它类似的 gotcha,表 2 显示了其中一些 gotcha。

表 2. 无效的浮点假定

这个表达式……不一定等于……当……
0.0 - f-ff 为 0
f < g! (f >= g)f 或 g 为 NaN
f == ftruef 为 NaN
f + g - gfg 为无穷大或 NaN

舍入误差

浮点运算很少是精确的。虽然一些数字(譬如 0.5 )可以精确地表示为二进制(底数 2)小数(因为 0.5 等于 2 -1),但其它一些数字(譬如 0.1 )就不能精确的表示。因此,浮点运算可能导致舍入误差,产生的结果接近 ― 但不等于 ― 您可能希望的结果。例如,下面这个简单的计算将得到 2.600000000000001 ,而不是 2.6


double s=0;
  for (int i=0; i<26; i++)
    s += 0.1;
  System.out.println(s);

类似的, .1*26 相乘所产生的结果不等于 .1 自身加 26 次所得到的结果。当将浮点数强制转换成整数时,产生的舍入误差甚至更严重,因为强制转换成整数类型会舍弃非整数部分,甚至对于那些“看上去似乎”应该得到整数值的计算,也存在此类问题。例如,下面这些语句:


double d = 29.0 * 0.01;
  System.out.println(d);
  System.out.println((int) (d * 100));

将得到以下输出:


0.29
  28

这可能不是您起初所期望的。




回页首

浮点数比较指南

由于存在 NaN 的不寻常比较行为和在几乎所有浮点计算中都不可避免地会出现舍入误差,解释浮点值的比较运算符的结果比较麻烦。

最好完全避免使用浮点数比较。当然,这并不总是可能的,但您应该意识到要限制浮点数比较。如果必须比较浮点数来看它们是否相等,则应该将它们差的绝对值同一些预先选定的小正数进行比较,这样您所做的就是测试它们是否“足够接近”。(如果不知道基本的计算范围,可以使用测试“abs(a/b - 1) < epsilon”,这种方法比简单地比较两者之差要更准确)。甚至测试看一个值是比零大还是比零小也存在危险 ―“以为”会生成比零略大值的计算事实上可能由于积累的舍入误差会生成略微比零小的数字。

NaN 的无序性质使得在比较浮点数时更容易发生错误。当比较浮点数时,围绕无穷大和 NaN 问题,一种避免 gotcha 的经验法则是显式地测试值的有效性,而不是试图排除无效值。在清单 1 中,有两个可能的用于特性的 setter 的实现,该特性只能接受非负数值。第一个实现会接受 NaN,第二个不会。第二种形式比较好,因为它显式地检测了您认为有效的值的范围。


清单 1. 需要非负浮点值的较好办法和较差办法

// Trying to test by exclusion -- this doesn't catch NaN or infinity
    public void setFoo(float foo) {
      if (foo < 0)
          throw new IllegalArgumentException(Float.toString(f));
        this.foo = foo;
    }
    // Testing by inclusion -- this does catch NaN
    public void setFoo(float foo) {
      if (foo >= 0 && foo < Float.INFINITY)
        this.foo = foo;
  else
        throw new IllegalArgumentException(Float.toString(f));
    }

不要用浮点值表示精确值

一些非整数值(如几美元和几美分这样的小数)需要很精确。浮点数不是精确值,所以使用它们会导致舍入误差。因此,使用浮点数来试图表示象货币量这样的精确数量不是一个好的想法。使用浮点数来进行美元和美分计算会得到灾难性的后果。浮点数最好用来表示象测量值这类数值,这类值从一开始就不怎么精确。




回页首

用于较小数的 BigDecimal

从 JDK 1.3 起,Java 开发人员就有了另一种数值表示法来表示非整数: BigDecimalBigDecimal 是标准的类,在编译器中不需要特殊支持,它可以表示任意精度的小数,并对它们进行计算。在内部,可以用任意精度任何范围的值和一个换算因子来表示 BigDecimal ,换算因子表示左移小数点多少位,从而得到所期望范围内的值。因此,用 BigDecimal 表示的数的形式为 unscaledValue*10 -scale

用于加、减、乘和除的方法给 BigDecimal 值提供了算术运算。由于 BigDecimal 对象是不可变的,这些方法中的每一个都会产生新的 BigDecimal 对象。因此,因为创建对象的开销, BigDecimal 不适合于大量的数学计算,但设计它的目的是用来精确地表示小数。如果您正在寻找一种能精确表示如货币量这样的数值,则 BigDecimal 可以很好地胜任该任务。

所有的 equals 方法都不能真正测试相等

如浮点类型一样, BigDecimal 也有一些令人奇怪的行为。尤其在使用 equals() 方法来检测数值之间是否相等时要小心。 equals() 方法认为,两个表示同一个数但换算值不同(例如, 100.00100.000 )的 BigDecimal 值是不相等的。然而, compareTo() 方法会认为这两个数是相等的,所以在从数值上比较两个 BigDecimal 值时,应该使用 compareTo() 而不是 equals()

另外还有一些情形,任意精度的小数运算仍不能表示精确结果。例如, 1 除以 9 会产生无限循环的小数 .111111... 。出于这个原因,在进行除法运算时, BigDecimal 可以让您显式地控制舍入。 movePointLeft() 方法支持 10 的幂次方的精确除法。

使用 BigDecimal 作为互换类型

SQL-92 包括 DECIMAL 数据类型,它是用于表示定点小数的精确数字类型,它可以对小数进行基本的算术运算。一些 SQL 语言喜欢称此类型为 NUMERIC 类型,其它一些 SQL 语言则引入了 MONEY 数据类型,MONEY 数据类型被定义为小数点右侧带有两位的小数。

如果希望将数字存储到数据库中的 DECIMAL 字段,或从 DECIMAL 字段检索值,则如何确保精确地转换该数字?您可能不希望使用由 JDBC PreparedStatementResultSet 类所提供的 setFloat()getFloat() 方法,因为浮点数与小数之间的转换可能会丧失精确性。相反,请使用 PreparedStatementResultSetsetBigDecimal()getBigDecimal() 方法。

对于 BigDecimal ,有几个可用的构造函数。其中一个构造函数以双精度浮点数作为输入,另一个以整数和换算因子作为输入,还有一个以小数的 String 表示作为输入。要小心使用 BigDecimal(double) 构造函数,因为如果不了解它,会在计算过程中产生舍入误差。请使用基于整数或 String 的构造函数。

构造 BigDecimal 数

对于 BigDecimal ,有几个可用的构造函数。其中一个构造函数以双精度浮点数作为输入,另一个以整数和换算因子作为输入,还有一个以小数的 String 表示作为输入。要小心使用 BigDecimal(double) 构造函数,因为如果不了解它,会在计算过程中产生舍入误差。请使用基于整数或 String 的构造函数。

如果使用 BigDecimal(double) 构造函数不恰当,在传递给 JDBC setBigDecimal() 方法时,会造成似乎很奇怪的 JDBC 驱动程序中的异常。例如,考虑以下 JDBC 代码,该代码希望将数字 0.01 存储到小数字段:

PreparedStatement ps =
    connection.prepareStatement("INSERT INTO Foo SET name=?, value=?");
  ps.setString(1, "penny");
  ps.setBigDecimal(2, new BigDecimal(0.01));
  ps.executeUpdate();

在执行这段似乎无害的代码时会抛出一些令人迷惑不解的异常(这取决于具体的 JDBC 驱动程序),因为 0.01 的双精度近似值会导致大的换算值,这可能会使 JDBC 驱动程序或数据库感到迷惑。JDBC 驱动程序会产生异常,但可能不会说明代码实际上错在哪里,除非意识到二进制浮点数的局限性。相反,使用 BigDecimal("0.01")BigDecimal(1, 2) 构造 BigDecimal 来避免这类问题,因为这两种方法都可以精确地表示小数。

 

下面是网上盛传的现成的应用例子.

import java.math.BigDecimal;
import java.math.RoundingMode;

/**
*
* 由于Java的简单类型不能够精确的对浮点数进行运算,这个工具类提供精
*
* 确的浮点数运算,包括加减乘除和四舍五入。
*
*/

public class MathOperUtil {

// 默认除法运算精度

private static final int DEF_DIV_SCALE = 10;

// 这个类不能实例化

private MathOperUtil() {

}

/**
*
* 提供精确的加法运算。
*
* @param v1
*            被加数
*
* @param v2
*            加数
*
* @return 两个参数的和
*
*/

public static double add(double v1, double v2) {

BigDecimal b1 = new BigDecimal(Double.toString(v1));

BigDecimal b2 = new BigDecimal(Double.toString(v2));

return b1.add(b2).doubleValue();

}

/**
*
* 提供精确的减法运算。
*
* @param v1
*            被减数
*
* @param v2
*            减数
*
* @return 两个参数的差
*
*/

public static double sub(double v1, double v2) {

BigDecimal b1 = new BigDecimal(Double.toString(v1));

BigDecimal b2 = new BigDecimal(Double.toString(v2));

return b1.subtract(b2).doubleValue();

}

/**
*
* 提供精确的乘法运算。
*
* @param v1
*            被乘数
*
* @param v2
*            乘数
*
* @return 两个参数的积
*
*/

public static double mul(double v1, double v2) {

BigDecimal b1 = new BigDecimal(Double.toString(v1));

BigDecimal b2 = new BigDecimal(Double.toString(v2));

return b1.multiply(b2).doubleValue();

}

/**
*
* 提供(相对)精确的除法运算,当发生除不尽的情况时,精确到
*
* 小数点以后10位,以后的数字四舍五入。
*
* @param v1
*            被除数
*
* @param v2
*            除数
*
* @return 两个参数的商
*
*/

public static double div(double v1, double v2) {

return div(v1, v2, DEF_DIV_SCALE);

}

/**
*
* 提供(相对)精确的除法运算。当发生除不尽的情况时,由scale参数指
*
* 定精度,以后的数字四舍五入。
*
* @param v1
*            被除数
*
* @param v2
*            除数
*
* @param scale
*            表示表示需要精确到小数点以后几位。
*
* @return 两个参数的商
*
*/

public static double div(double v1, double v2, int scale) {

if (scale < 0) {

throw new IllegalArgumentException(

"The scale must be a positive integer or zero");

}

BigDecimal b1 = new BigDecimal(Double.toString(v1));

BigDecimal b2 = new BigDecimal(Double.toString(v2));

return b1.divide(b2, scale, BigDecimal.ROUND_HALF_UP).doubleValue();

}

/**
*
* 提供精确的小数位四舍五入处理。
*
* @param v
*            需要四舍五入的数字
*
* @param scale
*            小数点后保留几位
*
* @return 四舍五入后的结果
*
*/

public static double round(double v, int scale) {

if (scale < 0) {

throw new IllegalArgumentException(

"The scale must be a positive integer or zero");

}

BigDecimal b = new BigDecimal(Double.toString(v));

return b.setScale(scale, RoundingMode.HALF_UP).doubleValue();

}

};

 

package com.kotei.overseas.navi.update; import static com.kotei.overseas.navi.security.DecryptUtil.dataVerification; import android.content.BroadcastReceiver; import android.content.Context; import android.content.Intent; import android.content.IntentFilter; import android.os.AsyncTask; import android.os.Handler; import android.os.Looper; import android.util.Log; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import androidx.annotation.NonNull; import com.here.sdk.core.engine.SDKNativeEngine; import com.here.sdk.maploader.MapDownloader; import com.here.sdk.maploader.MapDownloaderConstructionCallback; import com.kotei.overseas.navi.business.data.MapDataController; import com.kotei.overseas.navi.security.DecryptUtil; import com.kotei.overseas.navi.security.DfCert; import java.io.File; import java.io.IOException; import java.io.UncheckedIOException; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicInteger; import java.util.regex.Pattern; import java.util.stream.Stream; /** * USB离线更新系统 */ public class USBOfflineUpdater { private static final String TAG = "USBOfflineUpdater"; // 状态码定义 /** * 操作成功 */ public static final int SUCCESS = 0; /** * 错误:未检测到USB设备 */ public static final int ERROR_NO_USB = 1; /** * 错误:未找到升级包文件 */ public static final int ERROR_NO_UPDATE_PACKAGE = 2; /** * 错误:电池电量不足(低于安全阈值) */ public static final int ERROR_BATTERY_LOW = 3; /** * 错误:存储空间不足 */ public static final int ERROR_STORAGE_INSUFFICIENT = 4; /** * 错误:系统正在执行其他升级任务 */ public static final int ERROR_UPDATE_IN_PROGRESS = 5; /** * 错误:文件复制失败(检查存储权限或磁盘状态) */ public static final int ERROR_COPY_FAILED = 6; /** * 错误:升级包解压失败(文件可能损坏) */ public static final int ERROR_EXTRACT_FAILED = 7; /** * 错误:用户手动取消操作 */ public static final int ERROR_USER_CANCELED = 8; /** * 错误:未预期的系统异常 */ public static final int ERROR_UNEXPECTED = 9; /** * 错误:升级过程中USB设备被移除 */ public static final int ERROR_USB_REMOVED = 10; /** * 错误:车辆档位未处于停车挡(P档) */ public static final int ERROR_VEHICLE_SHIFTED = 11; /** * 错误:电池电量极低(无法维持升级过程) */ public static final int ERROR_BATTERY_TOO_LOW = 12; /** * 错误:文件校验失败(MD5/SHA256校验不匹配) */ public static final int ERROR_FILE_VERIFY_FAILED = 13; /** * 错误:文件解密或验签失败 */ public static final int ERROR_DECRYPT_OR_SIGN_FAILED = 14; // 更新阶段定义 /** * 空闲状态(未开始升级) */ private static final int PHASE_IDLE = 0; /** * 设备检测阶段(检查USB/存储设备) */ private static final int PHASE_DETECTING = 1; /** * 升级包校验阶段(验证完整性/签名) */ private static final int PHASE_CHECKING = 2; /** * 系统备份阶段(备份当前系统数据) */ private static final int PHASE_BACKUP = 3; /** * 文件复制阶段(写入升级包到临时分区) */ private static final int PHASE_COPYING = 4; /** * 解压阶段(解压升级包内容) */ private static final int PHASE_EXTRACTING = 5; /** * 清理阶段(删除临时文件) */ private static final int PHASE_CLEANUP = 6; /** * 回滚阶段(升级失败时恢复备份) */ private static final int PHASE_ROLLBACK = 7; // 权重分配比例 private static final float BACKUP_WEIGHT = 0.1f; // 备份阶段权重10% private static final float PACKAGE_COPY_WEIGHT = 0.29f; // 升级包拷贝阶段权重29% private static final float PACKAGE_VERIFY_WEIGHT = 0.31f; // 升级包解密验签阶段权重31% private static final float PACKAGE_EXTRACT_WEIGHT = 0.29f; // 升级包解压阶段权重29% private static final float VERIFICATION_WEIGHT = 0.01f; // 校验阶段权重1% // 声明 mProgress 为实例变量(非静态) private float mProgress = 0; // 当前进度值(0~1) private static USBOfflineUpdater instance; private final Context context; private UpdateTask currentTask; private UpdateListener updateListener; private SDKNativeEngine sdkNativeEngine; private MapDataController mapDataController; // 目录配置 private File usbRoot; private File cacheDir; private File storageDir; // 更新控制 public boolean isPaused = false; public volatile boolean isCancelled = false; public final AtomicInteger currentPhase = new AtomicInteger(PHASE_IDLE); // 错误信息 private String lastErrorMessage = ""; // 电量阈值 private static final int MIN_BATTERY_LEVEL = 30; // 最低电量百分比 private static final int MIN_BATTERY_LEVEL_CRITICAL = 15; // 严重低电量 //升级包格式 private static final String FILE_NAME_PATTERN = "^KVM_Navi_EU_" + "(?<version>\\d{1,3})" // 版本号(1-3位数字) + "_" + "(?<serial>\\d{1,2})" // 1-2位数字编号 + "(\\.\\w+)?$"; // 可选的文件扩展名 // 进度计算相关变量(新增) private long totalUpdateSize = 0; // 所有升级包总大小 private long UpdateSize = 0; // 当前升级包大小 // private long currentCopiedBytes = 0; // 当前已拷贝字节数 // private long currentVerifiedBytes = 0; // 当前已验签字节数 // private long currentExtractedBytes = 0; // 当前已解压字节数 private final AtomicLong currentCopiedBytes = new AtomicLong(0); private final AtomicLong currentVerifiedBytes = new AtomicLong(0); private final AtomicLong currentExtractedBytes = new AtomicLong(0); // // private long backupSize = 0; // 备份数据大小 private int currentPackageIndex = 0; // 当前处理的升级包索引 private int totalPackageCount = 0; // 总升级包数量 // 用于跟踪回滚状态 public boolean isRollingBack = false; private long rollbackTotalSize = 0; private long rollbackProcessedSize = 0; //--------------------zwx----多线程--------------- // 多线程控制变量 private final Object packageLock = new Object(); private volatile boolean updateSuccess = true; private int nextPackageIndex = 0; private Handler mainHandler = new Handler(Looper.getMainLooper()); private volatile CountDownLatch allPackagesLatch; private volatile ExecutorService updateExecutor; private final AtomicLong totalPackageSize = new AtomicLong(0); // 每个包的最大进度(17.8%) private static final double MAX_PACKAGE_PROGRESS = 17.8; // 总进度中用于升级包处理的部分(89%) private static final double TOTAL_PACKAGE_PROGRESS = 89.0; // 存储每个包的当前进度(0.0 ~ 17.8) private final Map<String, Double> packageProgressMap = new ConcurrentHashMap<>(); // 存储每个包的总大小(用于计算进度比例) private final Map<String, Long> packageTotalSizeMap = new ConcurrentHashMap<>(); // 存储每个包的当前处理阶段(拷贝、验签、解压) private final Map<String, Integer> packageStageMap = new ConcurrentHashMap<>(); // 阶段标识 private static final int STAGE_COPY = 0; private static final int STAGE_VERIFY = 1; private static final int STAGE_EXTRACT = 2; //-----------------------------zwxend---------------- // USB监听器 private final BroadcastReceiver usbReceiver = new BroadcastReceiver() { @Override public void onReceive(Context context, Intent intent) { String action = intent.getAction(); if (Intent.ACTION_MEDIA_MOUNTED.equals(action)) { File usbPath = new File(intent.getData().getPath()); if (usbPath.exists() && usbPath.canRead()) { usbRoot = usbPath; Log.i(TAG, "USB mounted: " + usbRoot.getAbsolutePath()); } } else if (Intent.ACTION_MEDIA_EJECT.equals(action) || Intent.ACTION_MEDIA_UNMOUNTED.equals(action)) { if (currentTask != null && currentPhase.get() > PHASE_CHECKING) { cancelUpdate(ERROR_USB_REMOVED, "USB设备被移除"); } usbRoot = null; Log.e(TAG, "USB removed"); } } }; // // 车辆状态监听器(模拟) private final BroadcastReceiver vehicleReceiver = new BroadcastReceiver() { @Override public void onReceive(Context context, Intent intent) { if ("com.example.ACTION_SHIFT_CHANGE".equals(intent.getAction())) { String shift = intent.getStringExtra("shift"); if (!"P".equals(shift) && currentPhase.get() > PHASE_CHECKING) { // cancelUpdate(ERROR_VEHICLE_SHIFTED, "车辆已退出P挡"); } } } }; // 单例模式 public static synchronized USBOfflineUpdater getInstance(Context context) { if (instance == null) { instance = new USBOfflineUpdater(context); } return instance; } public static synchronized USBOfflineUpdater getInstance() { return instance; } private USBOfflineUpdater(Context context) { this.context = context.getApplicationContext(); // 初始化SDK sdkNativeEngine = SDKNativeEngine.getSharedInstance(); mapDataController = MapDataController.getInstance(); try { DfCert.getInstance().getService(); } catch (Exception e) { Log.e(TAG, "Exception:" + e.toString()); } // 初始化目录(默认值) cacheDir = this.context.getCacheDir(); storageDir = new File(sdkNativeEngine.getOptions().persistentMapStoragePath); // 注册USB监听器 IntentFilter usbFilter = new IntentFilter(); usbFilter.addAction(Intent.ACTION_MEDIA_MOUNTED); usbFilter.addAction(Intent.ACTION_MEDIA_EJECT); usbFilter.addAction(Intent.ACTION_MEDIA_UNMOUNTED); usbFilter.addDataScheme("file"); context.registerReceiver(usbReceiver, usbFilter); // 注册车辆状态监听器(模拟) IntentFilter vehicleFilter = new IntentFilter("com.example.ACTION_SHIFT_CHANGE"); context.registerReceiver(vehicleReceiver, vehicleFilter); //清除数据存储目录下的预留数据 removeLegacy(); } public void initialization(UpdateListener listener) { isRollingBack = true; this.updateListener = listener; Thread USBOfflineUpdaterInitialization = new Thread(new Runnable() { @Override public void run() { // 启动时检查恢复 checkRecoveryOnStartup(); } }); USBOfflineUpdaterInitialization.setName("USBOfflineUpdaterInitialization"); USBOfflineUpdaterInitialization.start(); } // 动态设置目录 public void setDirectories(File usbRoot, File cacheDir, File storageDir) { if (usbRoot != null) { this.usbRoot = usbRoot; } if (cacheDir != null) { this.cacheDir = cacheDir; } if (storageDir != null) { this.storageDir = storageDir; } } /** * 检测升级包 * * @return 状态码 (SUCCESS 或错误码) */ public int detectUpdatePackages() { // 1. 检测USB是否插入 if (usbRoot == null || !usbRoot.exists() || !usbRoot.isDirectory()) { return ERROR_NO_USB; } File[] tempPackages = usbRoot.listFiles(); // 2. 查找升级包 (命名格式: update_v{版本号}_{日期}.zip) File[] packages = usbRoot.listFiles(file -> file.isFile() && file.getName().matches(FILE_NAME_PATTERN) ); return (packages != null && packages.length > 0) ? SUCCESS : ERROR_NO_UPDATE_PACKAGE; } /** * 环境检测 * * @return 状态码 (SUCCESS 或错误码) */ public int checkEnvironment() { // 1. 检测电量 int batteryLevel = PowerUtils.getBatteryLevel(context); if (batteryLevel < MIN_BATTERY_LEVEL) { return batteryLevel < MIN_BATTERY_LEVEL_CRITICAL ? ERROR_BATTERY_TOO_LOW : ERROR_BATTERY_LOW; } // 2. 检测缓存空间 (需大于15GB) long requiredSpace = 15L * 1024 * 1024 * 1024; // 15GB long availableSpace = StorageUtils.getAvailableSpace(cacheDir); if (availableSpace < requiredSpace) { Log.e(TAG, "缓存空间剩余:【" + availableSpace + "】"); return ERROR_STORAGE_INSUFFICIENT; } return SUCCESS; } /** * 判读是否正在进行离线更新 */ public boolean isOfflineUpdate() { return currentTask != null && !currentTask.isCancelled(); } /** * 开始更新 */ public void startUpdate(UpdateListener listener) { int result = checkEnvironment(); if (result != SUCCESS) { notifyListener(result, "环境检测不合格"); return; } if (isOfflineUpdate()) { notifyListener(ERROR_UPDATE_IN_PROGRESS, "已有更新任务正在进行"); return; } if (isRollingBack) { notifyListener(ERROR_UPDATE_IN_PROGRESS, "正在进行数据回滚"); return; } Log.i(TAG, "检测到更新任务触发,开始进行地图更新"); notifyProgress("开始进行更新"); // 计算总工作量(新增) calculateTotalWorkload(); this.updateListener = listener; currentTask = new UpdateTask(); currentTask.execute(); } // 计算总工作量(新增) private void calculateTotalWorkload() { totalUpdateSize = 0; File[] packages = getUpdatePackages(); totalPackageCount = packages != null ? packages.length : 0; if (packages != null) { for (File pkg : packages) { totalUpdateSize += pkg.length(); } } backupSize = estimateBackupSize(); Log.i(TAG, "总工作量计算: 升级包数量=" + totalPackageCount + ", 升级包大小=" + formatSize(totalUpdateSize) + ", 备份大小=" + formatSize(backupSize)); } // 估算备份大小方法(避免返回0导致除0错误) private long estimateBackupSize() { long storageSize = FileUtilszwx.getDirectorySize(storageDir); long size = (long) (storageSize * 1.2); return size > 0 ? size : 1; // 确保不为0 } // 获取更新包(新增) private File[] getUpdatePackages() { if (usbRoot == null) return new File[0]; return usbRoot.listFiles(file -> file.isFile() && file.getName().matches(FILE_NAME_PATTERN) ); } // 格式化文件大小(新增) public static String formatSize(long size) { if (size < 1024) return size + "B"; else if (size < 1024 * 1024) return String.format("%.1fKB", size / 1024.0); else if (size < 1024 * 1024 * 1024) return String.format("%.1fMB", size / (1024.0 * 1024)); else return String.format("%.1fGB", size / (1024.0 * 1024 * 1024)); } /** * 暂停更新 */ public void pauseUpdate() { isPaused = true; notifyProgress("更新已暂停"); } /** * 恢复更新 */ public void resumeUpdate() { isPaused = false; notifyProgress("更新已恢复"); } /** * 取消更新 */ public void cancelUpdate() { cancelUpdate(ERROR_USER_CANCELED, "用户取消更新"); } private void cancelUpdate(int errorCode, String message) { isCancelled = true; lastErrorMessage = message; notifyListener(errorCode, message); } // 进度通知 private void notifyProgress(String message) { new Handler(Looper.getMainLooper()).post(() -> { if (updateListener != null) { // 计算当前总进度(修改) float progress = calculateOverallProgress(); updateListener.onProgress(currentPhase.get(), progress, message); } }); } private float calculateOverallProgress() { if (isRollingBack) { // 回滚阶段:直接计算回滚进度 if (rollbackTotalSize > 0) { mProgress = 99; return mProgress; } return 0; } if (totalPackageCount == 0 && currentPhase.get() != PHASE_ROLLBACK) return 0; // 每个包的总权重(拷贝+验签+解压) float packageTotalWeight = PACKAGE_COPY_WEIGHT + PACKAGE_VERIFY_WEIGHT + PACKAGE_EXTRACT_WEIGHT; //再次确认totalPackageCount是否等于0 if (totalPackageCount == 0) { throw new IllegalStateException("totalPackageCount should not be 0 here!"); } // 每个包的阶段权重 float packageCopyWeight = PACKAGE_COPY_WEIGHT / totalPackageCount; float packageVerifyWeight = PACKAGE_VERIFY_WEIGHT / totalPackageCount; float packageExtractWeight = PACKAGE_EXTRACT_WEIGHT / totalPackageCount; switch (currentPhase.get()) { case PHASE_BACKUP: if(backupSize > 0) { mProgress = BACKUP_WEIGHT * (currentCopiedBytes.get() / (float) backupSize); }else{ mProgress = BACKUP_WEIGHT; } if(mProgress > BACKUP_WEIGHT) { mProgress = BACKUP_WEIGHT; } break; case PHASE_COPYING: // 基础:备份 + 已完成包的完整进度 float copyBase = BACKUP_WEIGHT + (packageTotalWeight * currentPackageIndex) / totalPackageCount; // 增量:当前包拷贝进度 float copyProgress = currentCopiedBytes.get() / (float) UpdateSize; mProgress = copyBase + packageCopyWeight * copyProgress; break; case PHASE_CHECKING: // 基础:备份 + 已完成包的完整进度 + 当前包拷贝完成 float verifyBase = BACKUP_WEIGHT + (packageTotalWeight * currentPackageIndex) / totalPackageCount + packageCopyWeight; // 增量:当前包验签进度 float verifyProgress = currentVerifiedBytes.get() / (float) UpdateSize; mProgress = verifyBase + packageVerifyWeight * verifyProgress; break; case PHASE_EXTRACTING: // 修复:添加当前包验签完成 float extractBase = BACKUP_WEIGHT + (packageTotalWeight * currentPackageIndex) / totalPackageCount + packageCopyWeight + packageVerifyWeight; // 添加这行 // 增量:当前包解压进度 float extractProgress = currentExtractedBytes.get() / (float) UpdateSize; mProgress = extractBase + packageExtractWeight * extractProgress; break; case PHASE_DETECTING: mProgress = BACKUP_WEIGHT + packageTotalWeight + VERIFICATION_WEIGHT * (currentVerifiedBytes.get() / (float) totalUpdateSize); break; case PHASE_CLEANUP: case PHASE_ROLLBACK: mProgress = 0.99f; break; } return Math.min(Math.round(mProgress * 10000) / 100.00f, 100.00f); } // 结果通知 private void notifyListener(int resultCode, String message) { new Handler(Looper.getMainLooper()).post(() -> { if (updateListener != null) { updateListener.onResult(resultCode, message); } }); } // 获取当前进度百分比 private int getCurrentProgress() { // 此处可添加子任务进度计算 return (int) mProgress; } // =============================== 核心更新逻辑 =============================== private class UpdateTask extends AsyncTask<Void, Void, Integer> { private File backupFile; private File[] updatePackages; @Override protected void onPreExecute() { currentPhase.set(PHASE_DETECTING); isCancelled = false; isPaused = false; currentCopiedBytes.set(0); currentExtractedBytes.set(0); mProgress = 0; // 重置进度为0 } // @Override // protected Integer doInBackground(Void... voids) { // try { // // 阶段1: 备份数据 // currentPhase.set(PHASE_BACKUP); // notifyProgress("开始备份数据..."); // // backupFile = new File(cacheDir, "backup.zip"); // if (backupFile.exists()) { //// backupFile.delete(); // if (!backupFile.delete()) { // throw new IOException("删除备份文件失败: " + backupFile.getAbsolutePath()); // } // } // // // 计算实际备份大小 // Log.i(TAG, "核对实际需要备份的数据大小"); // backupSize = estimateBackupSize(); // Log.i(TAG, "需要备份的数据大小:[" + formatSize(backupSize) + "]"); // // if (backupSize > 0) { // Log.i(TAG, "开始进行数据备份"); // boolean backupResult = FileUtilszwx.compressDirectoryWithProgress( // storageDir, // backupFile, // (copied, total) -> { // currentCopiedBytes.set(copied); // notifyProgress("备份数据: " + formatSize(copied) + "/" + formatSize(total)); // // // 检查是否被取消(场景1) // if (isCancelled) { // throw new InterruptedException("用户取消备份"); // } // } // ); // // if (!backupResult) { // lastErrorMessage = "备份创建失败"; // return ERROR_UNEXPECTED; // } // } else { // // 无数据需要备份,直接标记完成 // Log.i(TAG, "无备份数据,直接进行数据更新"); //// currentCopiedBytes = 1; // currentCopiedBytes.set(1); // // backupSize = 1; // notifyProgress("无数据需要备份"); // } // // // 检查是否被取消(场景1) // if (isCancelled) { // return ERROR_USER_CANCELED; // } // // // 阶段2: 处理升级包 // updatePackages = getUpdatePackages(); // Log.i(TAG, "开始处理升级包【拷贝、解密验签、解压】"); // for (currentPackageIndex = 0; currentPackageIndex < updatePackages.length; currentPackageIndex++) { // if (isCancelled) return ERROR_USER_CANCELED; // // // 处理暂停状态 // while (isPaused) { // Thread.sleep(500); // } // // File packageFile = updatePackages[currentPackageIndex]; // UpdateSize = updatePackages[currentPackageIndex].length(); // String packageName = packageFile.getName(); // long packageSize = packageFile.length(); // // // 备份完成后重置拷贝计数器 // currentCopiedBytes.set(0); // // // 阶段3: 拷贝升级包 // currentPhase.set(PHASE_COPYING); // File destFile = new File(storageDir, packageName); // // // 拷贝时更新进度(修改) // boolean copyResult = FileUtilszwx.copyFileWithProgress( // packageFile, // destFile, // (copied, total) -> { // currentCopiedBytes.set(copied); // notifyProgress(String.format("拷贝 %s: %s/%s", // packageName, // formatSize(copied), // formatSize(total))); // } // ); // // if (!copyResult) { // lastErrorMessage = "拷贝失败: " + packageName; // return ERROR_COPY_FAILED; // } // // // 阶段4:解密验签 // currentPhase.set(PHASE_CHECKING); // currentVerifiedBytes.set(0); // 重置验签计数器 // // // 创建进度回调适配器 // DecryptUtil.ProgressCallback decryptCallback = new DecryptUtil.ProgressCallback() { // @Override // public void onProgress(long processed, long total) { // // 直接更新验签进度计数器 // currentVerifiedBytes.set(processed); // // // 触发进度通知 // notifyProgress(String.format("解密验签 %s: %s/%s", // packageName, // formatSize(processed), // formatSize(total))); // } // }; // // // 执行解密验签(传入回调) // // 执行解密验签(传入回调) // if (!dataVerification(destFile.getAbsolutePath(), decryptCallback)) { // if (!isCancelled) { // return ERROR_DECRYPT_OR_SIGN_FAILED; // } // } // // // 确保进度设置为100% // currentVerifiedBytes.set(UpdateSize); // // // 阶段5: 解压升级包 // currentPhase.set(PHASE_EXTRACTING); // // // 修复:重置解压计数器 // currentExtractedBytes.set(0); // 重置计数器 // notifyProgress("解压升级包: " + packageName); // // // 解压时更新进度(修改) // boolean extractResult = FileUtilszwx.extractZipWithProgress( // destFile, // storageDir, // (extracted, total) -> { // currentExtractedBytes.set(extracted); // notifyProgress(String.format("解压 %s: %s/%s", // packageName, // formatSize(extracted), // formatSize(total))); // } // ); // // if (!extractResult) { // lastErrorMessage = "解压失败: " + packageName; // return ERROR_EXTRACT_FAILED; // } // // // 删除已解压的升级包以节省空间 // if (!destFile.delete()) { // Log.w(TAG, "删除升级包失败: " + destFile.getName()); // } // // // 更新解压进度(完成当前包) // // currentExtractedBytes.addAndGet(packageSize); // } // // if (!mapDataController.checkInstallationStatus()) { // notifyProgress("校验失败"); // return ERROR_FILE_VERIFY_FAILED; // } else { // notifyProgress("校验成功"); // } // // MapDownloader.fromEngineAsync(sdkNativeEngine, new MapDownloaderConstructionCallback() { // @Override // public void onMapDownloaderConstructedCompleted(@NonNull MapDownloader downloader) { // Log.i(TAG, "数据同步成功"); // } // }); // // // 阶段5: 清理工作 // currentPhase.set(PHASE_CLEANUP); // notifyProgress("清理缓存..."); // if (backupFile.exists() && !backupFile.delete()) { // Log.w(TAG, "删除备份文件失败"); // } // // // 最终进度设为100% // notifyProgress("更新完成"); // return SUCCESS; // } catch (InterruptedException e) { // // 场景1:备份未完成时被取消 // if (backupFile != null && backupFile.exists()) { //// backupFile.delete(); // if (!backupFile.delete()) { // Log.w(TAG, "删除备份文件失败: " + backupFile.getAbsolutePath()); // } // } // lastErrorMessage = "更新任务被中断"; // return ERROR_USER_CANCELED; // } catch (Exception e) { // lastErrorMessage = "未知错误: " + e.getMessage(); // Log.e(TAG, "更新失败", e); // return ERROR_UNEXPECTED; // } // } @Override protected Integer doInBackground(Void... voids) { try { // 阶段1: 备份数据 currentPhase.set(PHASE_BACKUP); notifyProgress("开始备份数据..."); backupFile = new File(cacheDir, "backup.zip"); if (backupFile.exists()) { if (!backupFile.delete()) { throw new IOException("删除旧备份文件失败: " + backupFile.getAbsolutePath()); } } // 计算备份大小 Log.i(TAG, "核对实际需要备份的数据大小"); backupSize = estimateBackupSize(); Log.i(TAG, "需要备份的数据大小: [" + formatSize(backupSize) + "]"); if (backupSize > 0) { Log.i(TAG, "开始进行数据备份"); boolean backupResult = FileUtilszwx.compressDirectoryWithProgress( storageDir, backupFile, (copied, total) -> { currentCopiedBytes.set(copied); notifyProgress("备份数据: " + formatSize(copied) + "/" + formatSize(total)); if (isCancelled()) { throw new InterruptedException("用户取消备份"); } } ); if (!backupResult) { lastErrorMessage = "备份创建失败"; return ERROR_UNEXPECTED; } } else { Log.i(TAG, "无数据需要备份,直接进行数据更新"); currentCopiedBytes.set(1); backupSize = 1; notifyProgress("无数据需要备份"); } if (isCancelled()) { return ERROR_USER_CANCELED; } // 阶段2: 获取升级包列表 updatePackages = getUpdatePackages(); int packageCount = updatePackages.length; if (packageCount == 0) { notifyProgress("无升级包"); return ERROR_USER_CANCELED; } Log.i(TAG, "开始并行处理升级包【拷贝、解密验签、解压】"); ExecutorService executor = Executors.newFixedThreadPool(Math.min(packageCount, 4)); // 限制线程数 allPackagesLatch = new CountDownLatch(packageCount); updateSuccess = true; // 提交所有任务 for (File packageFile : updatePackages) { executor.execute(new PackageProcessingTask(packageFile)); } // 等待所有任务完成或取消 while (!isCancelled() && allPackagesLatch.getCount() > 0) { Thread.sleep(200); // 避免 CPU 空转 } // 如果任务被取消,则中断所有任务 if (isCancelled()) { executor.shutdownNow(); return ERROR_USER_CANCELED; } // 关闭线程池 executor.shutdown(); // 检查是否全部成功 if (!updateSuccess) { lastErrorMessage = "升级包处理失败"; return ERROR_USER_CANCELED; } // 数据校验 if (!mapDataController.checkInstallationStatus()) { notifyProgress("校验失败"); return ERROR_FILE_VERIFY_FAILED; } else { notifyProgress("校验成功"); } // 异步构建 MapDownloader MapDownloader.fromEngineAsync(sdkNativeEngine, new MapDownloaderConstructionCallback() { @Override public void onMapDownloaderConstructedCompleted(@NonNull MapDownloader downloader) { Log.i(TAG, "数据同步成功"); } }); // 阶段3: 清理工作 currentPhase.set(PHASE_CLEANUP); notifyProgress("清理缓存..."); if (backupFile.exists() && !backupFile.delete()) { Log.w(TAG, "删除备份文件失败"); } notifyProgress("更新完成"); return SUCCESS; } catch (InterruptedException e) { if (backupFile != null && backupFile.exists()) { backupFile.delete(); } lastErrorMessage = "更新任务被中断"; return ERROR_USER_CANCELED; } catch (Exception e) { lastErrorMessage = "未知错误: " + e.getMessage(); Log.e(TAG, "更新失败", e); return ERROR_UNEXPECTED; } } @Override protected void onPostExecute(Integer resultCode) { if (resultCode == SUCCESS) { notifyListener(SUCCESS, "更新成功,请重启车机"); currentPhase.set(PHASE_IDLE); currentTask = null; } else { // 只有备份完成时才进行回滚(场景2) if (backupFile != null && backupFile.exists()) { // 场景2:进入回滚流程 isRollingBack = true; currentPhase.set(PHASE_ROLLBACK); // 先发送回滚进度通知(99%) notifyProgress("更新失败,正在回滚数据..."); // 保存错误消息,因为回滚完成后还需要使用 final String errorMessage = lastErrorMessage; // 启动回滚线程 new Thread(new Runnable() { @Override public void run() { try { // 执行回滚 performRollback(backupFile); } finally { // 回滚完成后删除备份 // backupFile.delete(); if (backupFile.exists() && !backupFile.delete()) { Log.w(TAG, "删除备份文件失败: " + backupFile.getAbsolutePath()); } if (!isCancelled) { // 回滚完成后发送最终结果 notifyListener(resultCode, getErrorMessage(resultCode)); } // 重置状态 currentPhase.set(PHASE_IDLE); currentTask = null; isRollingBack = false; } } }).start(); } else { // 场景1:没有备份文件,直接报告错误 notifyListener(resultCode, lastErrorMessage); currentPhase.set(PHASE_IDLE); currentTask = null; } } } } // ================== 启动时恢复检查 ================== private void checkRecoveryOnStartup() { File backupFile = findLatestBackupFile(); // 存在备份文件说明上次更新中断 if (backupFile != null && backupFile.exists()) { long fileSize = backupFile.length(); long expectedSize = estimateBackupSize(); // 场景3:备份未完成(文件大小小于预期大小的90%) if (fileSize < expectedSize * 0.9) { isRollingBack = true; Log.i(TAG, "检测到未完成的备份,删除: " + backupFile.getName()); // backupFile.delete(); if (backupFile.exists() && !backupFile.delete()) { Log.w(TAG, "删除备份文件失败: " + backupFile.getAbsolutePath()); } return; }else { // 场景4:备份已完成,启动回滚 Log.i(TAG, "检测到完整的备份,开始回滚: " + backupFile.getName()); currentPhase.set(PHASE_ROLLBACK); notifyProgress("检测到未完成更新,正在恢复数据..."); // 执行回滚 performRollback(backupFile); // 删除备份文件 // backupFile.delete(); if (backupFile.exists() && !backupFile.delete()) { Log.w(TAG, "删除备份文件失败: " + backupFile.getAbsolutePath()); } notifyProgress("数据恢复完成"); } } isRollingBack = false; this.updateListener = null; } private void checkStoragePerformance() { long writeSpeed = StorageUtils.measureWriteSpeed(storageDir); Log.d(TAG, "存储写入速度: " + formatSize(writeSpeed) + "/s"); if (writeSpeed < 50 * 1024 * 1024) { // 低于 50MB/s Log.w(TAG, "检测到低速存储设备,还原操作可能较慢"); } } // 新增回滚方法 private void performRollback(File backupFile) { try { // 1. 设置回滚进度为99% rollbackProcessedSize = backupFile.length() * 99 / 100; notifyProgress("开始恢复备份..."); // 2. 删除更新后的数据 FileUtilszwx.deleteRecursive(storageDir); if (!storageDir.mkdirs()) { Log.w(TAG, "创建存储目录失败"); } // 3. 恢复备份 rollbackTotalSize = backupFile.length(); rollbackProcessedSize = 0; boolean restoreSuccess = FileUtilszwx.extractZipWithProgress( backupFile, storageDir, (extracted, total) -> { rollbackProcessedSize = extracted; notifyProgress(String.format("恢复备份 %s: %s/%s", backupFile.getName(), formatSize(extracted), formatSize(total))); } ); if (!restoreSuccess) { Log.e(TAG, "备份恢复失败"); } } catch (Exception e) { Log.e(TAG, "回滚过程中发生错误", e); } } private File findLatestBackupFile() { File[] backups = cacheDir.listFiles(file -> file.isFile() && file.getName().startsWith("backup.zip") ); if (backups == null || backups.length == 0) { return null; } return backups[0]; } // 释放资源 public void release() { try { context.unregisterReceiver(usbReceiver); context.unregisterReceiver(vehicleReceiver); } catch (Exception e) { Log.w(TAG, "释放资源时出错", e); } } // ================== 接口定义 ================== public interface UpdateListener { void onProgress(int phase, float progress, String message); void onResult(int resultCode, String message); } public interface ProgressCallback { void onProgress(long progress, long total) throws InterruptedException; } public File getUsbRoot() { return usbRoot; } public File getCacheDir() { return cacheDir; } public File getStorageDir() { return storageDir; } public String getErrorMessage(int code) { return switch (code) { case ERROR_NO_USB -> "未检测到USB设备,请检查连接状态或更换接口"; case ERROR_NO_UPDATE_PACKAGE -> "升级包文件缺失,请确认存储路径"; case ERROR_BATTERY_LOW -> "电池电量不足(需≥20%)"; case ERROR_STORAGE_INSUFFICIENT -> "存储空间不足(需预留500MB以上)"; case ERROR_UPDATE_IN_PROGRESS -> "系统正在执行其他升级任务"; case ERROR_COPY_FAILED -> "文件复制失败,请检查存储权限"; case ERROR_EXTRACT_FAILED -> "升级包解压失败(可能文件损坏)"; case ERROR_USER_CANCELED -> "用户已取消升级操作"; case ERROR_UNEXPECTED -> "发生未预期的系统异常"; case ERROR_USB_REMOVED -> "升级过程中USB设备被移除"; case ERROR_VEHICLE_SHIFTED -> "请将车辆档位切换至P档"; case ERROR_BATTERY_TOO_LOW -> "电池电量极低(需≥10%)"; case ERROR_FILE_VERIFY_FAILED -> "文件校验失败(MD5/SHA256不匹配)"; case ERROR_DECRYPT_OR_SIGN_FAILED -> "文件解密/验签失败"; default -> "未知错误导致更新失败"; }; } void removeLegacy() { if (storageDir == null || !storageDir.exists() || !storageDir.isDirectory()) { return; } Pattern pattern = Pattern.compile(FILE_NAME_PATTERN); File[] files = storageDir.listFiles(); if (files == null) return; for (File file : files) { if (file.isFile() && pattern.matcher(file.getName()).matches()) { // 删除匹配的文件 try { Files.deleteIfExists(file.toPath()); } catch (IOException | SecurityException e) { // 处理异常(记录日志等) } } } // 删除sign文件夹(如果存在) Path signDir = Paths.get(storageDir.getAbsolutePath(), "sign"); if (Files.exists(signDir)) { try { // 递归删除整个目录 deleteDirectoryRecursively(signDir); } catch (IOException | SecurityException e) { // 处理异常 } } } private void deleteDirectoryRecursively(Path path) throws IOException { if (Files.isDirectory(path)) { // 使用 try-with-resources 确保 Stream 关闭 try (Stream<Path> children = Files.list(path)) { children.forEach(child -> { try { deleteDirectoryRecursively(child); } catch (IOException e) { // 处理子项删除异常 throw new UncheckedIOException(e); // 转换为 RuntimeException 以便在 Stream 中抛出 } }); } catch (UncheckedIOException e) { // 重新抛出原始 IOException throw e.getCause(); } } // 删除空目录或文件 Files.deleteIfExists(path); } //-----------------------------------多线程私有类------------------------ /** * 并行更新管道,用于并发处理多个升级包 */ /** * 并行更新管道,用于并发处理多个升级包 */ /** * 并行处理升级包 */ private void shutdownUpdateExecutor() { if (updateExecutor != null && !updateExecutor.isShutdown()) { updateExecutor.shutdownNow(); } } private void updateProgress(String format, String packageName, long processed, long total) { String message = String.format(format, packageName, formatSize(processed), formatSize(total)); mainHandler.post(() -> notifyProgress(message)); } public void processUpdatePackagesInParallel(File[] packages) { if (packages == null || packages.length == 0) { notifyListener(SUCCESS, "无升级包"); return; } // 初始化每个包的总大小进度 for (File packageFile : packages) { packageTotalSizeMap.put(packageFile.getName(), packageFile.length()); packageProgressMap.put(packageFile.getName(), 0.0); packageStageMap.put(packageFile.getName(), STAGE_COPY); } allPackagesLatch = new CountDownLatch(packages.length); updateExecutor = Executors.newFixedThreadPool(3); for (File packageFile : packages) { updateExecutor.submit(new PackageProcessingTask(packageFile)); } // 在后台等待所有任务完成 new Thread(() -> { try { allPackagesLatch.await(); updateExecutor.shutdown(); if (isCancelled) { notifyListener(ERROR_USER_CANCELED, "用户取消"); } else if (updateSuccess && mapDataController.checkInstallationStatus()) { notifyProgress("校验成功"); Log.d(TAG, "最终进度:" + packageProgressMap); // 输出最终进度 notifyListener(SUCCESS, "更新成功,请重启车机"); } else { notifyProgress("校验失败"); notifyListener(ERROR_FILE_VERIFY_FAILED, "文件校验失败"); } } catch (InterruptedException e) { Thread.currentThread().interrupt(); notifyListener(ERROR_UNEXPECTED, "任务中断"); } }).start(); } private class PackageProcessingTask implements Runnable { private final File packageFile; public PackageProcessingTask(File packageFile) { this.packageFile = packageFile; } @Override public void run() { if (isCancelled) { allPackagesLatch.countDown(); return; } String packageName = packageFile.getName(); File destFile = new File(storageDir, packageName); // 阶段1: 拷贝 try { boolean copyResult = FileUtilszwx.copyFileWithProgress(packageFile, destFile, (copied, total) -> { packageStageMap.put(packageName, STAGE_COPY); updatePackageProgress(packageName, copied, total, STAGE_COPY); updateProgress("拷贝 %s: %s/%s", packageName, copied, total); }); if (!copyResult || isCancelled) { updateSuccess = false; allPackagesLatch.countDown(); return; } } catch (Exception e) { Log.e(TAG, "拷贝失败", e); updateSuccess = false; allPackagesLatch.countDown(); return; } // 阶段2: 解密验签 try { boolean verifyResult = dataVerification(destFile.getAbsolutePath(), new DecryptUtil.ProgressCallback() { @Override public void onProgress(long processed, long total) { packageStageMap.put(packageName, STAGE_VERIFY); updatePackageProgress(packageName, processed, total, STAGE_VERIFY); updateProgress("解密验签 %s: %s/%s", packageName, processed, total); } }); if (!verifyResult || isCancelled) { updateSuccess = false; allPackagesLatch.countDown(); return; } } catch (Exception e) { Log.e(TAG, "解密验签失败", e); updateSuccess = false; allPackagesLatch.countDown(); return; } // 阶段3: 解压 try { boolean extractResult = FileUtilszwx.extractZipWithProgress(destFile, storageDir, (extracted, total) -> { packageStageMap.put(packageName, STAGE_EXTRACT); updatePackageProgress(packageName, extracted, total, STAGE_EXTRACT); updateProgress("解压 %s: %s/%s", packageName, extracted, total); }); if (!extractResult || isCancelled) { updateSuccess = false; allPackagesLatch.countDown(); return; } } catch (Exception e) { Log.e(TAG, "解压失败", e); updateSuccess = false; allPackagesLatch.countDown(); return; } // 删除升级包 if (!destFile.delete()) { Log.w(TAG, "删除升级包失败: " + destFile.getName()); } // 最终进度设为最大值 packageProgressMap.put(packageName, MAX_PACKAGE_PROGRESS); allPackagesLatch.countDown(); // 任务完成 } } /** * 更新包的进度,并限制最大为 MAX_PACKAGE_PROGRESS */ private void updatePackageProgress(String packageName, long processed, long total, int stage) { double stageRatio = (double) processed / total; double progress = 0.0; switch (stage) { case STAGE_COPY: progress = MAX_PACKAGE_PROGRESS / 3 * stageRatio; break; case STAGE_VERIFY: progress = MAX_PACKAGE_PROGRESS / 3 + MAX_PACKAGE_PROGRESS / 3 * stageRatio; break; case STAGE_EXTRACT: progress = MAX_PACKAGE_PROGRESS / 3 * 2 + MAX_PACKAGE_PROGRESS / 3 * stageRatio; break; } packageProgressMap.put(packageName, Math.min(progress, MAX_PACKAGE_PROGRESS)); } } 仔细检查代码是否存在冲突问题
08-22
private final Handler uiHandler = new Handler(Looper.getMainLooper()); private void notifyProgress(String message) { uiHandler.post(() -> { if (updateListener != null) { float progress = calculateOverallProgress(); updateListener.onProgress(currentPhase.get(), progress, message); } }); } private float calculateOverallProgress() { // if (isRollingBack) { // // 回滚阶段:直接计算回滚进度 // if (rollbackTotalSize > 0) { // mProgress = 99; // return mProgress; // } // return 0; // } if (totalPackageCount == 0 && currentPhase.get() != PHASE_ROLLBACK) return 0; // 每个包的总权重(拷贝+验签+解压) float packageTotalWeight = PACKAGE_COPY_WEIGHT + PACKAGE_VERIFY_WEIGHT + PACKAGE_EXTRACT_WEIGHT; //再次确认totalPackageCount是否等于0 if (totalPackageCount == 0) { throw new IllegalStateException("totalPackageCount should not be 0 here!"); } // 每个包的阶段权重 float packageCopyWeight = PACKAGE_COPY_WEIGHT / totalPackageCount; float packageVerifyWeight = PACKAGE_VERIFY_WEIGHT / totalPackageCount; float packageExtractWeight = PACKAGE_EXTRACT_WEIGHT / totalPackageCount; switch (currentPhase.get()) { case PHASE_BACKUP: if(backupSize > 0) { mProgress = BACKUP_WEIGHT * (currentCopiedBytes / (float) backupSize); }else{ mProgress = BACKUP_WEIGHT; } if(mProgress > BACKUP_WEIGHT) { mProgress = BACKUP_WEIGHT; } break; case PHASE_COPYING: // 基础:备份 + 已完成包的完整进度 float copyBase = BACKUP_WEIGHT + (packageTotalWeight * currentPackageIndex) / totalPackageCount; // 增量:当前包拷贝进度 float copyProgress = currentCopiedBytes / (float) UpdateSize; mProgress = copyBase + packageCopyWeight * copyProgress; break; case PHASE_CHECKING: // 基础:备份 + 已完成包的完整进度 + 当前包拷贝完成 float verifyBase = BACKUP_WEIGHT + (packageTotalWeight * currentPackageIndex) / totalPackageCount + packageCopyWeight; // 增量:当前包验签进度 float verifyProgress = currentVerifiedBytes / (float) UpdateSize; mProgress = verifyBase + packageVerifyWeight * verifyProgress; break; case PHASE_EXTRACTING: // 修复:添加当前包验签完成 float extractBase = BACKUP_WEIGHT + (packageTotalWeight * currentPackageIndex) / totalPackageCount + packageCopyWeight + packageVerifyWeight; // 添加这行 // 增量:当前包解压进度 float extractProgress = currentExtractedBytes / (float) UpdateSize; mProgress = extractBase + packageExtractWeight * extractProgress; break; case PHASE_DETECTING: mProgress = BACKUP_WEIGHT + packageTotalWeight + VERIFICATION_WEIGHT * (currentVerifiedBytes / (float) totalUpdateSize); break; case PHASE_CLEANUP: // case PHASE_ROLLBACK: // mProgress = 0.99f; // // break; case PHASE_ROLLBACK: // 增量:当前包拷贝进度 mProgress = rollBackCopiedBytes / (float) backupSize; break; } return Math.min(Math.round(mProgress * 10000) / 100.00f, 100.00f); } // 结果通知 private void notifyListener(int resultCode, String message) { new Handler(Looper.getMainLooper()).post(() -> { if (updateListener != null) { updateListener.onResult(resultCode, message); } }); } 总结目前的基于阶段计算进度的逻辑
09-02
void CappDlg::OnBnClickedAdmittanceControl() { // TODO: 在此添加控件通知处理程序代码 if (!m_bStiffnessControlActive) { if (!PathFileExists(_T("initial_stiffness_params.txt"))) { AfxMessageBox(_T("请先运行Python程序生成初始刚度参数")); return; } // 通知Python std::ofstream doneFile("control_done_signal.txt"); doneFile << "1"; doneFile.close(); StartStiffnessControl(); } else { StopStiffnessControl(); } } void CappDlg::OnBnClickedRotateyaxisIncremental() { //TODO: 在此添加控件通知处理程序代码 short ncid; ncid = TestOpenComm(0, 0); if (ncid < 0) { AfxMessageBox(_T("无法打开设备通信!")); return; } double deltaPose[12] = { 0,0,0,tx,ty,tz,0,0,0,0,0,0 }; // 初始化全0 // 2. 以最小角速度(0.1°/s)执行增量运动 short rv = BscImov( ncid, "VR", // 角速度模式(单位°/s) 0.1, // 0.1°/s(文档允许的最小值) "BASE", // 基于工具坐标系旋转 sToolno, // 工具号 deltaPose // 增量位姿 ); // 3. 错误处理 if (rv != 0) { char errMsg[256] = { 0 }; BscGetError(ncid); } rv = TestCloseComm(ncid); } void CappDlg::StartFileWatcher() { m_fileWatcherThread = std::thread([this]() { while (m_bFileWatcherRunning) { // 1. 检查初始参数 if (!m_bInitialParamsLoaded && PathFileExists(_T("initial_stiffness_params.txt"))) { std::lock_guard<std::mutex> lock(m_fileMutex); CStdioFile file; if (file.Open(_T("initial_stiffness_params.txt"), CFile::modeRead)) { CString strParams; file.ReadString(strParams); CT2CA convertedString(strParams); if (sscanf_s(convertedString, "%lf,%lf,%lf", &m_dStiffness[0], &m_dStiffness[1], &m_dStiffness[2]) == 3) { m_bInitialParamsLoaded = true; m_bParamsUpdated = true; CString strMsg; strMsg.Format(_T("初始刚度参数加载成功: %.1f, %.1f, %.1f"), m_dStiffness[0], m_dStiffness[1], m_dStiffness[2]); OutputDebugString(strMsg); // 通知Python std::ofstream doneFile("control_done_signal.txt"); doneFile << "1"; doneFile.close(); } file.Close(); //DeleteFile(_T("initial_stiffness_params.txt")); } } // 2. 检查新参数 if (PathFileExists(_T("new_stiffness_params.txt"))) { std::lock_guard<std::mutex> lock(m_fileMutex); CStdioFile file; if (file.Open(_T("new_stiffness_params.txt"), CFile::modeRead)) { CString strParams; file.ReadString(strParams); CT2CA convertedString(strParams); if (sscanf_s(convertedString, "%lf,%lf,%lf", &m_dStiffness[0], &m_dStiffness[1], &m_dStiffness[2]) == 3) { m_bParamsUpdated = true; CString strMsg; strMsg.Format(_T("新刚度参数接收: %.1f, %.1f, %.1f"), m_dStiffness[0], m_dStiffness[1], m_dStiffness[2]); OutputDebugString(strMsg); // 通知Python std::ofstream doneFile("control_done_signal.txt"); doneFile << "1"; doneFile.close(); } file.Close(); DeleteFile(_T("new_stiffness_params.txt")); } } // 3. 检查重置信号 if (PathFileExists(_T("reset_signal.txt"))) { std::lock_guard<std::mutex> lock(m_fileMutex); // 输出重置信息 OutputDebugString(_T("[INFO] 收到重置信号,开始重置机器人")); DeleteFile(_T("reset_signal.txt")); /*if (m_bStiffnessControlActive) { m_bStiffnessControlActive = false; if (m_stiffnessControlThread.joinable()) { m_stiffnessControlThread.join(); } OutputDebugString(_T("[INFO] 控制线程已停止")); }*/ ResetRobot(); SaveRobotState(); // 通知Python重置完成 std::ofstream doneFile("control_done_signal.txt"); doneFile << "1"; doneFile.close(); //// 更新UI显示 //GetDlgItem(IDC_ADMITTANCE_CONTROL)->SetWindowText(_T("启动变刚度控制")); //OutputDebugString(_T("[INFO] 机器人复位流程完成")); } Sleep(100); // 100ms检查一次 } }); } void CappDlg::SaveRobotStateAsync(const double* torque, const double* pose) { // 复制数据到局部变量 double localTorque[3], localPose[3]; memcpy(localTorque, torque, sizeof(localTorque)); memcpy(localPose, pose, sizeof(localPose)); // 异步保存 std::thread([=]() { std::lock_guard<std::mutex> lock(m_fileMutex); // 先写入临时文件 CString tempFile = _T("robot_state.tmp"); std::ofstream outfile(tempFile); if (outfile.is_open()) { outfile << localTorque[0] << "," << localTorque[1] << "," << localTorque[2] << "," << localPose[0] << "," << localPose[1] << "," << localPose[2]; outfile.close(); // 原子性替换文件 if (!::MoveFileEx(tempFile, _T("robot_state.csv"), MOVEFILE_REPLACE_EXISTING | MOVEFILE_WRITE_THROUGH)) { OutputDebugString(_T("[ERROR] 无法更新状态文件")); } // 通知Python状态已就绪 std::ofstream readyFile("state_ready_signal.txt"); readyFile << "1"; readyFile.close(); } }).detach(); } void CappDlg::SaveRobotState() { // 使用互斥锁保护共享数据 std::lock_guard<std::mutex> lock(m_mutex); // 如果通信端口忙,则跳过本次保存 if (m_bCommInUse) { OutputDebugString(_T("[INFO] 延迟保存状态:通信端口正忙")); return; } // 获取当前状态 double torque[3], pose[3]; GetTorqueData(torque); memcpy(pose, current_orientation, sizeof(pose)); // 异步保存 SaveRobotStateAsync(torque, pose); } void CappDlg::StartStiffnessControl() { if (!m_bStiffnessControlActive) { m_bStiffnessControlActive = true; m_stiffnessControlThread = std::thread(&CappDlg::StiffnessControlThread, this); // 更新按钮文本 GetDlgItem(IDC_ADMITTANCE_CONTROL)->SetWindowText(_T("停止变刚度控制")); } } void CappDlg::StopStiffnessControl() { if (m_bStiffnessControlActive) { m_bStiffnessControlActive = false; if (m_stiffnessControlThread.joinable()) { m_stiffnessControlThread.join(); } // 更新按钮文本 GetDlgItem(IDC_ADMITTANCE_CONTROL)->SetWindowText(_T("启动变刚度控制")); } } void CappDlg::StiffnessControlThread() { const double dt = 0.005; // 控制周期5ms double de[3] = { 0 }; double dde[3] = { 0 }; double last_orientation[3] = { 0 }; // 线程启动时打开连接 short nCid = SafeOpenComm(); if (nCid < 0) { OutputDebugString(_T("[ERROR] 控制线程启动失败:无法连接机器人")); return; } while (m_bStiffnessControlActive) { // 增加连接状态检查 if (m_nRobotCid < 0) { SafeCloseComm(); m_nRobotCid = SafeOpenComm(); if (m_nRobotCid < 0) { Sleep(1000); continue; } } // 检查连接有效性 if (nCid < 0) { nCid = SafeOpenComm(); if (nCid < 0) { Sleep(static_cast<DWORD>(dt * 1000)); continue; } } // 带重试的姿态获取 double dPos[12] = { 0 }; WORD rconf, toolno; bool bSuccess = false; for (int i = 0; i < 3 && !bSuccess; i++) { if (BscIsRobotPos(nCid, "BASE", 0, &rconf, &toolno, dPos) == 0) { bSuccess = true; // 更新当前姿态 current_orientation[0] = dPos[3]; current_orientation[1] = dPos[4]; current_orientation[2] = dPos[5]; // 获取力矩数据 GetTorqueData(dPos_force); // 计算调整量 bool orientation_changed = ( fabs(current_orientation[0] - last_orientation[0]) > 0.1 || fabs(current_orientation[1] - last_orientation[1]) > 0.1 || fabs(current_orientation[2] - last_orientation[2]) > 0.1); if (m_bParamsUpdated || orientation_changed) { double adjust[3]; for (int i = 0; i < 3; i++) { adjust[i] = OrientationControl(i); adjust[i] = clamp(adjust[i], -0.5, 0.5); dPos[3 + i] += adjust[i]; } // 发送运动指令 BscMovj(nCid, 0.3, "BASE", 0, 0, dPos); m_bParamsUpdated = false; memcpy(last_orientation, current_orientation, sizeof(last_orientation)); // 异步保存状态 SaveRobotStateAsync(dPos_force + 3, current_orientation); } } else if (i == 2) { // 最后一次尝试失败 OutputDebugString(_T("[ERROR] 连续获取姿态失败")); SafeCloseComm(); nCid = -1; } } Sleep(static_cast<DWORD>(dt * 1000)); } // 线程结束时关闭连接 SafeCloseComm(); } double CappDlg::OrientationControl(int axis) { // 1. 获取当前力矩 double tau = dPos_force[3 + axis]; // tx,ty,tz // 2. 计算姿态误差 double e = current_orientation[axis] - m_dTargetOrientation[axis]; // 3. 二阶系统计算 de[axis] += dt * dde[axis]; e += dt * de[axis]; dde[axis] = (tau - m_dDamping[axis] * de[axis] - m_dStiffness[axis] * e) / m_dInertia[axis]; return e; } void CappDlg::GetTorqueData(double* torque) { // 从传感器获取力矩数据 Sensor sensor; std::vector<float> data = sensor.getMeasurement(); for (int i = 0; i < 3; i++) { torque[i] = data[i+3] / 1000000.0; //数据需要转换 } } void CappDlg::OnDestroy() { // 停止所有线程 m_bRunning = false; m_bStiffnessControlActive = false; m_bFileWatcherRunning = false; // 等待线程结束 if (m_stiffnessControlThread.joinable()) { m_stiffnessControlThread.join(); } if (m_fileWatcherThread.joinable()) { m_fileWatcherThread.join(); } // 清理信号文件 DeleteFile(_T("control_done_signal.txt")); DeleteFile(_T("state_ready_signal.txt")); DeleteFile(_T("reset_signal.txt")); CDialogEx::OnDestroy(); } void CappDlg::OnBnClickedBtnReset() { // TODO: 在此添加控件通知处理程序代码 double dPos[12] = { 896, 13, -259, 175, 5, 0 }; // 重置位置 short nCid = TestOpenComm(0, 0); if (nCid >= 0) { BscMovj(nCid, 0.1, "BASE", 0, 0, dPos); TestCloseComm(nCid); } } //void CappDlg::ResetRobot() //{ // double dPos[12] = { 896, 13, -259, 175, 5, 0 }; // 重置位置 // short nCid = TestOpenComm(0, 0); // if (nCid >= 0) { // BscMovj(nCid, 0.1, "BASE", 0, 0, dPos); // TestCloseComm(nCid); // } //} void CappDlg::ResetRobot() { // 设置复位状态 m_bIsResetting = true; m_bResetCompleted = false; OutputDebugString(_T("[INFO] 开始机器人复位流程")); // 先停止所有使用通信端口的线程 StopStiffnessControl(); // 确保通信端口关闭 SafeCloseComm(); // 定义安全的初始位置(根据实际需求调整) double dPos[12] = { 896, 13, -259, 180, 0, 0 }; // 位置姿态值 // 使用C++11的随机数生成器,比rand()更安全 std::random_device rd; std::mt19937 gen(rd()); std::uniform_real_distribution<> dis_rx(175.0, 179.0); std::uniform_real_distribution<> dis_ry(1.0, 5.0); dPos[3] = dis_rx(gen); // rx dPos[4] = dis_ry(gen); // ry dPos[5] = 0; // rz (固定) // 记录随机生成的位置 CString strRandomPos; strRandomPos.Format(_T("[INFO] 随机生成的复位位置: [%.2f, %.2f, %.2f, %.2f, %.2f, %.2f]"), dPos[0], dPos[1], dPos[2], dPos[3], dPos[4], dPos[5]); OutputDebugString(strRandomPos); // 打开通信连接 short nCid = TestOpenComm(0, 0); if (nCid >= 0) { // 以较低速度移动到初始位置,确保安全 short result = BscMovj(nCid, 0.1, "BASE", 0, 0, dPos); if (result == 0) { OutputDebugString(_T("[INFO] 复位命令已发送,等待机器人移动到初始位置")); // 等待机器人完成移动(根据实际机器人速度调整等待时间) Sleep(15000); // 等待15秒 // 验证复位后的位置 double currentPos[12] = { 0 }; WORD rconf, toolno; if (BscIsRobotPos(nCid, "BASE", 0, &rconf, &toolno, currentPos) == 0) { CString strMsg; strMsg.Format(_T("[INFO] 机器人已复位到位置: [%.2f, %.2f, %.2f, %.2f, %.2f, %.2f]"), currentPos[0], currentPos[1], currentPos[2], currentPos[3], currentPos[4], currentPos[5]); OutputDebugString(strMsg); // 检查复位是否成功(角度误差是否在容差范围内) double reset_errors[3] = { CalculateAngleError(currentPos[3], m_dTargetOrientation[0]), CalculateAngleError(currentPos[4], m_dTargetOrientation[1]), CalculateAngleError(currentPos[5], m_dTargetOrientation[2]) }; bool reset_successful = true; for (int i = 0; i < 3; i++) { if (reset_errors[i] > 5.0) { // 复位后允许的最大误差(度) reset_successful = false; break; } } if (reset_successful) { OutputDebugString(_T("[INFO] 机器人复位成功")); m_bResetCompleted = true; } else { OutputDebugString(_T("[WARN] 机器人复位后角度误差仍较大")); m_bResetCompleted = false; } } else { OutputDebugString(_T("[WARN] 复位后无法获取当前位置")); m_bResetCompleted = false; } } else { OutputDebugString(_T("[ERROR] 复位命令发送失败")); m_bResetCompleted = false; } // 关闭通信连接 SafeCloseComm(); // 复位完成后立即关闭 } else { OutputDebugString(_T("[ERROR] 无法打开通信连接进行复位")); m_bResetCompleted = false; } // 无论复位是否成功,都完成复位流程 m_bIsResetting = false; if (m_bStiffnessControlActive) { StartStiffnessControl(); } } // 安全打开连接(确保全局唯一) //short CappDlg::SafeOpenComm() //{ // // 如果已有有效连接,直接返回 // if (m_nRobotCid >= 0) { // return m_nRobotCid; // } // // // 设置连接使用标志 // if (m_bCommInUse.exchange(true)) { // OutputDebugString(_T("[WARN] 通信端口正忙,等待...")); // return -1; // } // // // 带重试的连接逻辑 // short nCid = -1; // for (int i = 0; i < MAX_RETRY; i++) { // nCid = TestOpenComm(0, 0); // if (nCid >= 0) { // m_nRobotCid = nCid; // OutputDebugString(_T("[INFO] 机器人连接成功")); // break; // } // OutputDebugString(_T("[WARN] 连接尝试失败,重试...")); // Sleep(RETRY_DELAY_MS); // } // // m_bCommInUse = false; // // if (nCid < 0) { // OutputDebugString(_T("[ERROR] 连接失败:达到最大重试次数")); // } // return nCid; //} short CappDlg::SafeOpenComm() { std::lock_guard<std::mutex> lock(m_commMutex); // 硬件初始化冷却期(复位后至少等待10秒) static std::chrono::time_point<std::chrono::steady_clock> lastResetTime; auto now = std::chrono::steady_clock::now(); if (now - lastResetTime < std::chrono::seconds(10)) { OutputDebugString(_T("[INFO] 处于复位冷却期,暂不连接")); return -1; } // 指数退避重试(最大等待5秒) int retryDelay = 500; // 初始500ms for (int i = 0; i < 15; i++) { // 增加重试次数 m_nRobotCid = TestOpenComm(0, 0); if (m_nRobotCid >= 0) { OutputDebugString(_T("[INFO] 连接成功")); return m_nRobotCid; } Sleep(retryDelay); retryDelay = min(retryDelay * 2, 5000); // 上限5秒 } return -1; } void CappDlg::SafeCloseComm() { std::lock_guard<std::mutex> lock(m_commMutex); if (m_nRobotCid >= 0) { TestCloseComm(m_nRobotCid); m_nRobotCid = -1; OutputDebugString(_T("[INFO] 机器人连接已关闭")); } } double CappDlg::CalculateAngleError(double current, double target) { double error = current - target; // 将误差标准化到[-180°, 180°] while (error > 180.0) { error -= 360.0; } while (error < -180.0) { error += 360.0; } return fabs(error); }以上是我的MFC程序 import numpy as np import pandas as pd import gymnasium as gym from gymnasium import spaces import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F from torch.distributions import Normal from watchdog.observers import Observer from watchdog.events import FileSystemEventHandler import time import skfuzzy as fuzz from skfuzzy import control as ctrl import os import logging import matplotlib.pyplot as plt from collections import deque import logging from logging.handlers import RotatingFileHandler import datetime import os # 创建日志目录 log_dir = "training_logs" os.makedirs(log_dir, exist_ok=True) # 带时间戳的日志文件名 timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") log_filename = os.path.join(log_dir, f'stiffness_control_{timestamp}.log') # 基本配置 - 控制台输出 logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[logging.StreamHandler()] ) # 添加文件日志 - 带轮转功能 file_handler = RotatingFileHandler( log_filename, maxBytes=10*1024*1024, # 10MB backupCount=5, encoding='utf-8' ) file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) logging.getLogger().addHandler(file_handler) # 设置更详细的文件日志级别 file_handler.setLevel(logging.DEBUG) # 日志记录 logging.info("日志系统初始化完成,日志将保存到: %s", log_filename) # 配置日志 #logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s') # 文件监控类 class StiffnessWatcher(FileSystemEventHandler): def __init__(self, env): super().__init__() self.env = env self.updated = False def on_modified(self, event): if event.is_directory: return if event.src_path.endswith("stiffness_state.csv"): self.updated = True logging.debug("检测到状态文件更新") # 机器人刚度控制环境 class StiffnessControlEnv(gym.Env): def __init__(self): super(StiffnessControlEnv, self).__init__() self.target_dir = r"F:\lry\weizi\app" os.makedirs(self.target_dir, exist_ok=True) # 信号文件路径 self.control_done_signal = os.path.join(self.target_dir, "control_done_signal.txt") self.state_ready_signal = os.path.join(self.target_dir, "state_ready_signal.txt") self.reset_signal = os.path.join(self.target_dir, "reset_signal.txt") self.initial_params_file = os.path.join(self.target_dir, "initial_stiffness_params.txt") self.new_params_file = os.path.join(self.target_dir, "new_stiffness_params.txt") self.state_file = os.path.join(self.target_dir, "robot_state.csv") # 清除旧信号文件 self.cleanup_previous_files() self.state_log_file = os.path.join(self.target_dir, "arm_state_log.csv") self._init_state_log() # 初始化状态日志 # 目标姿态: rx=180°, ry=0°, rz=0° self.target_orientation = np.array([180.0, 0.0, 0.0], dtype=np.float32) self.max_angle_error = 20.0 # 最大允许角度误差(度) # 动作空间: 三个轴的刚度参数[K_rx, K_ry, K_rz] self.action_space = spaces.Box( low=np.array([50, 50, 50], dtype=np.float32), high=np.array([500, 500, 500], dtype=np.float32), dtype=np.float32 ) # 状态空间: 力矩(tx, ty, tz) + 姿态(rx, ry, rz) self.observation_space = spaces.Box( low=np.array([-2.0, -2.0, -2.0, -180.0, -30.0, -10.0], dtype=np.float32), high=np.array([2.0, 2.0, 2.0, 180.0, 30.0, 10.0], dtype=np.float32), dtype=np.float32 ) self.state = np.zeros(6, dtype=np.float32) self.current_step = 0 self.max_steps = 200 self.last_error = np.inf self.last_torque = np.zeros(3, dtype=np.float32) # 创建模糊奖励系统 self.fuzzy_reward_system = self.create_fuzzy_system() # 文件监控 self.watcher = StiffnessWatcher(self) self.observer = Observer() self.observer.schedule(self.watcher, path=self.target_dir, recursive=True) self.observer.start() logging.info("环境初始化完成,开始监控文件夹") def _init_state_log(self): """初始化机械臂状态日志文件""" with open(self.state_log_file, 'w') as f: f.write("timestamp,step,tx,ty,tz,rx,ry,rz,target_rx,target_ry,target_rz\n") def _log_arm_state(self, step, state, action=None): """ 记录机械臂状态到日志文件终端 :param step: 当前步数 :param state: 状态数组 [tx,ty,tz,rx,ry,rz] :param action: 当前动作 [K_rx, K_ry, K_rz] (可选) """ timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f") tx, ty, tz, rx, ry, rz = state rx_display = self.to_180_degrees(rx) # 写入CSV日志 with open(self.state_log_file, 'a') as f: f.write(f"{timestamp},{step},{tx:.4f},{ty:.4f},{tz:.4f}," f"{rx:.2f},{ry:.2f},{rz:.2f}\n") # f"{self.target_orientation[0]}," # f"{self.target_orientation[1]}," # f"{self.target_orientation[2]}\n") # 实时控制台输出 status_msg = (f"机械臂状态 [Step {step}]:\n" f" - 力矩(tx,ty,tz): {tx:.3f}, {ty:.3f}, {tz:.3f} N·m\n" f" - 姿态(rx,ry,rz): {rx_display:.2f}, {ry:.2f}, {rz:.2f} °\n") # f" - 目标姿态: {self.target_orientation[0]:.2f}, " # f"{self.target_orientation[1]:.2f}, {self.target_orientation[2]:.2f} °") if action is not None: status_msg += f"\n - 当前刚度参数: {action[0]:.1f}, {action[1]:.1f}, {action[2]:.1f} N·m/rad" logging.info(status_msg) print(f"\n{'-'*40}\n{status_msg}\n{'-'*40}") # 控制台分隔线显示 def to_360_degrees(self, angle): """将-180°~180°的角度转换为0°~360°""" angle = float(angle) if angle < -180 or angle > 180: logging.warning(f"输入角度超出范围: {angle}°,将被标准化到-180°~180°") angle = (angle + 180) % 360 - 180 return (angle + 360) % 360 def to_180_degrees(self, angle): """将0°~360°的角度转换回-180°~180°(用于输出)""" angle = float(angle) if angle < 0 or angle > 360: logging.warning(f"输入角度超出范围: {angle}°,将被标准化到0°~360°") angle = angle % 360 return (angle + 180) % 360 - 180 def calculate_angle_error(self, current, target): """计算两个角度之间的最短误差,考虑周期性(-180°180°是同一点)""" error = current - target # 将误差标准化到[-180°, 180°] error = (error + 180) % 360 - 180 return abs(error) def cleanup_previous_files(self): """清除之前的信号文件""" for file in [self.control_done_signal, self.state_ready_signal, self.reset_signal, self.initial_params_file, self.new_params_file]: if os.path.exists(file): try: os.remove(file) except Exception as e: logging.warning(f"无法删除旧文件 {file}: {e}") def create_fuzzy_system(self): # 定义模糊变量 torque_range = np.arange(0, 2.1, 0.1) orientation_error_range = np.arange(0, 30.1, 0.1) reward_range = np.arange(0, 2.1, 0.1) # 输入变量 torque = ctrl.Antecedent(torque_range, 'torque') error = ctrl.Antecedent(orientation_error_range, 'error') # 输出变量 reward = ctrl.Consequent(reward_range, 'reward') # 定义隶属函数 torque['VS'] = fuzz.trimf(torque_range, [0, 0, 0.5]) torque['S'] = fuzz.trimf(torque_range, [0, 0.5, 1.0]) torque['M'] = fuzz.trimf(torque_range, [0.5, 1.0, 1.5]) torque['L'] = fuzz.trimf(torque_range, [1.0, 1.5, 2.0]) torque['VL'] = fuzz.trapmf(torque_range, [1.5, 2.0, 2.0, 2.0]) error['VS'] = fuzz.trimf(orientation_error_range, [0, 0, 3]) error['S'] = fuzz.trimf(orientation_error_range, [0, 3, 6]) error['M'] = fuzz.trimf(orientation_error_range, [3, 6, 10]) error['L'] = fuzz.trimf(orientation_error_range, [6, 10, 20]) error['VL'] = fuzz.trapmf(orientation_error_range, [10, 20, 30, 30]) reward['VL'] = fuzz.trimf(reward_range, [0, 0, 0.5]) reward['L'] = fuzz.trimf(reward_range, [0, 0.5, 1.0]) reward['M'] = fuzz.trimf(reward_range, [0.5, 1.0, 1.5]) reward['H'] = fuzz.trimf(reward_range, [1.0, 1.5, 2.0]) reward['VH'] = fuzz.trapmf(reward_range, [1.5, 2.0, 2.0, 2.0]) # 定义模糊规则 rules = [ ctrl.Rule(torque['VS'] & error['VS'], reward['VH']), ctrl.Rule(torque['S'] & error['VS'], reward['H']), ctrl.Rule(torque['M'] & error['VS'], reward['M']), ctrl.Rule(torque['L'] & error['VS'], reward['L']), ctrl.Rule(torque['VL'] & error['VS'], reward['VL']), ctrl.Rule(torque['VS'] & error['S'], reward['H']), ctrl.Rule(torque['S'] & error['S'], reward['H']), ctrl.Rule(torque['M'] & error['S'], reward['M']), ctrl.Rule(torque['L'] & error['S'], reward['L']), ctrl.Rule(torque['VL'] & error['S'], reward['VL']), ctrl.Rule(torque['VS'] & error['M'], reward['M']), ctrl.Rule(torque['S'] & error['M'], reward['M']), ctrl.Rule(torque['M'] & error['M'], reward['L']), ctrl.Rule(torque['L'] & error['M'], reward['L']), ctrl.Rule(torque['VL'] & error['M'], reward['VL']), ctrl.Rule(torque['VS'] & error['L'], reward['L']), ctrl.Rule(torque['S'] & error['L'], reward['L']), ctrl.Rule(torque['M'] & error['L'], reward['VL']), ctrl.Rule(torque['L'] & error['L'], reward['VL']), ctrl.Rule(torque['VL'] & error['L'], reward['VL']), ctrl.Rule(torque['VS'] & error['VL'], reward['VL']), ctrl.Rule(torque['S'] & error['VL'], reward['VL']), ctrl.Rule(torque['M'] & error['VL'], reward['VL']), ctrl.Rule(torque['L'] & error['VL'], reward['VL']), ctrl.Rule(torque['VL'] & error['VL'], reward['VL']), ] control_system = ctrl.ControlSystem(rules) return ctrl.ControlSystemSimulation(control_system) def reset(self): """重置环境并与MFC同步""" initial_params_path = os.path.join(self.target_dir, "initial_stiffness_params.txt") if not os.path.exists(initial_params_path): default_params = "300.0, 300.0, 300.0" # 默认刚度值 with open(initial_params_path, 'w') as f: f.write(default_params) logging.info(f"已自动生成初始参数文件: {default_params}") # 1. 发送重置信号 self.safe_write_file(self.reset_signal, "1") logging.info("已发送重置信号,等待MFC响应...") # 2. 等待MFC确认重置完成 if not self.wait_for_signal(self.control_done_signal, timeout=150.0): logging.error("MFC重置响应超时") raise TimeoutError("MFC reset timeout") # 3. 检查初始参数文件 if not os.path.exists(self.initial_params_file): logging.error("初始刚度参数文件不存在,请先生成初始参数") raise FileNotFoundError("initial_stiffness_params.txt not found") # 4. 读取初始参数(示例值,实际应从文件读取) with open(self.initial_params_file, 'r') as f: params = f.read().strip() logging.info(f"已加载初始刚度参数: {params}") #os.remove(self.initial_params_file) # 5. 读取初始状态 state = self.read_state(self.state_file) logging.info(f"初始状态: {state}") self._log_arm_state(0, state) # 初始状态记为step 0 return state, {} def safe_write_file(self, path, content): """安全写入文件,带有重试机制""" max_retries = 3 for attempt in range(max_retries): try: temp_path = path + ".tmp" with open(temp_path, "w") as f: f.write(content) if os.name == 'nt': if os.path.exists(path): os.remove(path) os.rename(temp_path, path) else: os.replace(temp_path, path) return True except Exception as e: logging.warning(f"写入文件失败 (尝试 {attempt+1}/{max_retries}): {e}") time.sleep(0.1) logging.error(f"写入文件失败: {path}") return False def read_state(self, state_path): """安全读取状态文件""" max_retries = 3 for attempt in range(max_retries): try: if not os.path.exists(state_path): raise FileNotFoundError(f"状态文件不存在: {state_path}") with open(state_path, 'r') as f: content = f.read().strip() if not content: raise ValueError("状态文件为空") data = [float(x) for x in content.split(',')] data[3] = self.to_360_degrees(data[3]) # 将-180°~180°转换为0°~360° if len(data) != 6: raise ValueError(f"无效数据长度: {len(data)}") return np.array(data, dtype=np.float32) except Exception as e: if attempt == max_retries - 1: logging.error(f"读取状态文件失败: {e}") return np.zeros(6, dtype=np.float32) time.sleep(0.5) def step(self, action): """执行一步动作""" # 1. 发送刚度参数 params_str = f"{action[0]:.2f},{action[1]:.2f},{action[2]:.2f}" self.safe_write_file(self.new_params_file, params_str) logging.info(f"已发送刚度参数: {params_str}") # 2. 等待MFC确认参数接收 if not self.wait_for_signal(self.control_done_signal, timeout=55.0): logging.warning("MFC参数接收确认超时") return self.state, -10, True, False, {} # 3. 等待状态更新 if not self.wait_for_signal(self.state_ready_signal, timeout=15.0): logging.warning("状态更新超时") return self.state, -10, True, False, {} # 4. 读取新状态 new_state = self.read_state(self.state_file) # 5. 计算角度误差(考虑rx的特殊性) rx, ry, rz = new_state[3:6] rx_error = self.calculate_angle_error(rx, self.target_orientation[0]) ry_error = self.calculate_angle_error(ry, self.target_orientation[1]) rz_error = self.calculate_angle_error(rz, self.target_orientation[2]) angle_errors = np.array([rx_error, ry_error, rz_error]) # 6. 检查角度误差是否超限 error_exceeded = np.any(angle_errors > self.max_angle_error) # 7. 计算奖励 reward = self.calculate_reward(new_state) # 8. 确定完成标志(结合原有逻辑角度误差) done = error_exceeded or np.all(angle_errors < 0.5) or (self.current_step >= self.max_steps) # 9. 如果误差超限,执行复位 if error_exceeded: logging.warning(f"角度误差超限: {angle_errors} > {self.max_angle_error},触发复位") self._reset_robot() reward -= 100 # 惩罚误差超限 # 10. 更新状态 self.state = new_state self.current_step += 1 self._log_arm_state(self.current_step, new_state, action) # 11. 返回结果 return new_state, reward, done, False, { "error": angle_errors, "torque": new_state[:3], "orientation": new_state[3:6] } def _reset_robot(self): """执行机器人复位流程""" # 发送重置信号 self.safe_write_file("reset_signal.txt", "1") # 等待MFC响应 if not self.wait_for_signal("control_done_signal.txt", timeout=150): logging.error("等待MFC重置超时") # 重置内部计数器 self.current_step = 0 def wait_for_signal(self, signal_path, timeout=10.0): """等待信号文件出现并删除它""" start_time = time.time() while time.time() - start_time < timeout: if os.path.exists(signal_path): try: os.remove(signal_path) return True except Exception as e: logging.warning(f"删除信号文件失败: {e}") time.sleep(0.1) time.sleep(0.2) # 避免CPU占用过高 return False def wait_for_update(self): """仅通过修改时间(mtime)检测文件更新,并打印文件内容到日志""" timeout = 10.0 # 超时时间10秒 start_time = time.time() state_path = os.path.join(self.target_dir, "stiffness_state.csv") # 记录初始修改时间,处理文件不存在的情况 initial_mtime = 0 if os.path.exists(state_path): initial_mtime = os.path.getmtime(state_path) logging.debug(f"初始文件修改时间: {initial_mtime}") while time.time() - start_time < timeout: if os.path.exists(state_path): current_mtime = os.path.getmtime(state_path) # 检查修改时间是否变化(添加微小阈值避免浮点数误差) if current_mtime > initial_mtime + 1e-6: # 读取文件内容并打印到日志 try: with open(state_path, 'r') as f: content = f.read().strip() logging.info(f"状态文件已更新,内容: {content}") # 关键新增:打印文件内容 except Exception as e: logging.warning(f"读取更新后的文件内容失败: {e}") logging.debug(f"文件修改时间更新: 旧={initial_mtime} → 新={current_mtime}") time.sleep(0.1) # 等待文件系统刷新 return True time.sleep(0.2) # 轮询间隔200ms logging.warning(f"等待状态更新超时(超时时间:{timeout}秒)") return False # def wait_for_update(self): # """仅通过修改时间(mtime)检测文件更新,提高检测灵敏度""" # timeout = 10.0 # 延长超时时间至10秒,给机械臂更多响应时间 # start_time = time.time() # state_path = os.path.join(self.target_dir, "stiffness_state.csv") # # 记录初始修改时间,处理文件不存在的情况 # initial_mtime = 0 # if os.path.exists(state_path): # initial_mtime = os.path.getmtime(state_path) # logging.debug(f"初始文件修改时间: {initial_mtime}") # while time.time() - start_time < timeout: # if os.path.exists(state_path): # current_mtime = os.path.getmtime(state_path) # # 关键修改:仅检查修改时间是否变化,且设置一个小的阈值避免浮点数误差 # if current_mtime > initial_mtime + 1e-6: # 添加微小阈值,避免精度误差导致的误判 # logging.debug(f"文件已更新,新修改时间: {current_mtime}") # time.sleep(0.1) # 等待文件系统完全刷新 # return True # time.sleep(0.2) # 轮询间隔调整为200ms,平衡性能响应速度 # logging.warning(f"等待状态更新超时(超时时间:{timeout}秒)") # return False def calculate_reward(self, state): """计算奖励,综合考虑力矩、姿态误差误差变化""" torque = state[0:3] orientation = state[3:6] # 计算当前误差(考虑rx的周期性) rx_error = self.calculate_angle_error(orientation[0], self.target_orientation[0]) error = np.array([ rx_error, abs(orientation[1] - self.target_orientation[1]), abs(orientation[2] - self.target_orientation[2]) ]) current_error = np.linalg.norm(error) # 模糊奖励 self.fuzzy_reward_system.input['torque'] = np.mean(np.abs(torque)) self.fuzzy_reward_system.input['error'] = current_error self.fuzzy_reward_system.compute() base_reward = self.fuzzy_reward_system.output['reward'] # 初始化改进奖励 improvement_reward = 0.0 # 只有在不是第一次计算时才计算改进奖励 if hasattr(self, 'last_error') and not np.isinf(self.last_error): error_diff = self.last_error - current_error if error_diff > 0: # 误差减小 improvement_reward = 2.0 * error_diff else: # 误差增大 improvement_reward = 0.5 * error_diff # 注意这里已经是负值 # 更新last_errorlast_torque self.last_error = current_error self.last_torque = torque.copy() # 其他奖励组件保持不变... goal_reward = 5.0 if np.all(error < 5.0) else 0.0 error_penalty = -0.2 * current_error if current_error > 15 else 0 step_penalty = -0.01 stability_reward = 0.2 if np.linalg.norm(torque - self.last_torque) < 0.1 else 0 # 总奖励 total_reward = base_reward + improvement_reward + goal_reward + error_penalty + step_penalty + stability_reward logging.info(f"奖励组成: 基础={base_reward:.4f}, 改进={improvement_reward:.4f}, " f"目标={goal_reward:.4f}, 误差惩罚={error_penalty:.4f}, " f"步数={step_penalty:.4f}, 稳定={stability_reward:.4f}") return total_reward def close(self): self.observer.stop() self.observer.join() logging.info("环境已关闭") # 策略网络 class PolicyNetwork(nn.Module): def __init__(self, state_dim, action_dim, hidden_dim=128): super(PolicyNetwork, self).__init__() self.fc1 = nn.Linear(state_dim, hidden_dim) self.fc2 = nn.Linear(hidden_dim, hidden_dim) # 均值标准差层 self.mean_layer = nn.Linear(hidden_dim, action_dim) self.log_std_layer = nn.Linear(hidden_dim, action_dim) # 初始化权重 self.apply(self._init_weights) def _init_weights(self, module): if isinstance(module, nn.Linear): nn.init.kaiming_normal_(module.weight, a=0.01) if module.bias is not None: module.bias.data.zero_() def forward(self, x): x = F.leaky_relu(self.fc1(x), 0.01) x = F.leaky_relu(self.fc2(x), 0.01) # 计算均值标准差 mean = self.mean_layer(x) log_std = self.log_std_layer(x) # 限制标准差范围,避免探索性过大或过小 log_std = torch.clamp(log_std, -20, 2) std = log_std.exp() return mean, std def get_action(self, state): state = torch.FloatTensor(state).unsqueeze(0) mean, std = self.forward(state) # 创建正态分布 dist = Normal(mean, std) action = dist.sample() # 确保动作在有效范围内 action = torch.clamp(action, torch.FloatTensor([50, 50, 50]), torch.FloatTensor([500, 500, 500])) # 计算对数概率 log_prob = dist.log_prob(action).sum(-1) return action.detach().numpy()[0], log_prob.detach().numpy()[0] def evaluate(self, state, action): mean, std = self.forward(state) dist = Normal(mean, std) # 计算对数概率熵 action_logprobs = dist.log_prob(action).sum(-1) dist_entropy = dist.entropy().sum(-1) return action_logprobs, dist_entropy # 价值网络 class ValueNetwork(nn.Module): def __init__(self, state_dim, hidden_dim=128): super(ValueNetwork, self).__init__() self.fc1 = nn.Linear(state_dim, hidden_dim) self.fc2 = nn.Linear(hidden_dim, hidden_dim) self.value_layer = nn.Linear(hidden_dim, 1) # 初始化权重 self.apply(self._init_weights) def _init_weights(self, module): if isinstance(module, nn.Linear): nn.init.kaiming_normal_(module.weight, a=0.01) if module.bias is not None: module.bias.data.zero_() def forward(self, x): x = F.leaky_relu(self.fc1(x), 0.01) x = F.leaky_relu(self.fc2(x), 0.01) value = self.value_layer(x) return value # PPO代理 class PPO: def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99, eps_clip=0.2, value_coef=0.5, entropy_coef=0.01, hidden_dim=128): self.gamma = gamma self.eps_clip = eps_clip self.value_coef = value_coef self.entropy_coef = entropy_coef # 策略网络价值网络 self.policy = PolicyNetwork(state_dim, action_dim, hidden_dim) self.value = ValueNetwork(state_dim, hidden_dim) # 优化器 self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=lr) self.value_optimizer = optim.Adam(self.value.parameters(), lr=lr) # 存储轨迹 self.states = [] self.actions = [] self.logprobs = [] self.rewards = [] self.is_terminals = [] def select_action(self, state): action, log_prob = self.policy.get_action(state) return action, log_prob def store_transition(self, state, action, log_prob, reward, done): self.states.append(state) self.actions.append(action) self.logprobs.append(log_prob) self.rewards.append(reward) self.is_terminals.append(done) def update(self, epochs=10): # 转换为张量 states = torch.FloatTensor(self.states) actions = torch.FloatTensor(self.actions) old_logprobs = torch.FloatTensor(self.logprobs) # 计算回报 returns = [] discounted_return = 0 for reward, is_terminal in zip(reversed(self.rewards), reversed(self.is_terminals)): if is_terminal: discounted_return = 0 discounted_return = reward + (self.gamma * discounted_return) returns.insert(0, discounted_return) returns = torch.FloatTensor(returns) # 标准化回报 returns = (returns - returns.mean()) / (returns.std() + 1e-8) # 优化策略价值网络 for _ in range(epochs): # 获取当前策略的对数概率熵 logprobs, entropy = self.policy.evaluate(states, actions) # 计算策略比率 ratio = torch.exp(logprobs - old_logprobs.detach()) # 计算优势函数 values = self.value(states).squeeze() advantage = returns - values.detach() # 标准化优势 advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8) # 计算策略损失 surr1 = ratio * advantage surr2 = torch.clamp(ratio, 1-self.eps_clip, 1+self.eps_clip) * advantage policy_loss = -torch.min(surr1, surr2).mean() # 计算价值损失 value_loss = F.mse_loss(values, returns) # 计算总损失 loss = policy_loss + self.value_coef * value_loss - self.entropy_coef * entropy.mean() # 更新策略网络 self.policy_optimizer.zero_grad() loss.backward() nn.utils.clip_grad_norm_(self.policy.parameters(), 0.5) # 梯度裁剪 self.policy_optimizer.step() # 更新价值网络 self.value_optimizer.zero_grad() value_loss.backward() nn.utils.clip_grad_norm_(self.value.parameters(), 0.5) # 梯度裁剪 self.value_optimizer.step() # 清空轨迹 self.states.clear() self.actions.clear() self.logprobs.clear() self.rewards.clear() self.is_terminals.clear() # 训练函数 def train_ppo(env, total_episodes=1000, max_steps_per_episode=200, update_timestep=2000, save_interval=100, log_interval=10, render_interval=50): state_dim = env.observation_space.shape[0] action_dim = env.action_space.shape[0] # 创建PPO代理 agent = PPO(state_dim, action_dim, lr=3e-4, gamma=0.99, eps_clip=0.2, value_coef=0.5, entropy_coef=0.01, hidden_dim=128) # 记录训练过程 episode_rewards = [] avg_rewards = [] best_reward = -np.inf # 训练循环 global_timestep = 0 for episode in range(total_episodes): state, _ = env.reset() episode_reward = 0 for step in range(max_steps_per_episode): global_timestep += 1 # 选择动作 action, log_prob = agent.select_action(state) # 执行动作 next_state, reward, terminated, truncated, info = env.step(action) done = terminated or truncated # 存储转换 agent.store_transition(state, action, log_prob, reward, done) # 更新状态 state = next_state episode_reward += reward # 检查是否需要更新策略 if global_timestep % update_timestep == 0: agent.update() logging.info(f"更新策略: 全局步数={global_timestep}") # 检查是否完成 if done: break # 记录奖励 episode_rewards.append(episode_reward) avg_reward = np.mean(episode_rewards[-10:]) avg_rewards.append(avg_reward) # 保存最佳模型 if avg_reward > best_reward: best_reward = avg_reward torch.save(agent.policy.state_dict(), 'best_stiffness_policy.pth') logging.info(f"保存最佳模型: 平均奖励={avg_reward:.2f}") # 定期保存模型 if (episode + 1) % save_interval == 0: torch.save(agent.policy.state_dict(), f'stiffness_policy_episode_{episode+1}.pth') # 打印训练信息 if (episode + 1) % log_interval == 0: logging.info(f"Episode {episode+1}/{total_episodes} - " f"奖励: {episode_reward:.2f}, 平均奖励: {avg_reward:.2f}, " f"步数: {step+1}, 误差: {info['error'].round(2)}") # 绘制学习曲线 if (episode + 1) % 50 == 0: plt.figure(figsize=(10, 5)) plt.plot(episode_rewards, alpha=0.5, label='单轮奖励') plt.plot(avg_rewards, label='10轮平均奖励') plt.xlabel('轮次') plt.ylabel('奖励') plt.title('PPO训练进度') plt.legend() plt.grid(True) plt.savefig('training_curve.png') plt.close() # 保存最终模型 torch.save(agent.policy.state_dict(), 'final_stiffness_policy.pth') logging.info("训练完成,模型已保存") return agent, episode_rewards, avg_rewards # 测试训练好的策略 def test_policy(env, policy_path='best_stiffness_policy.pth', episodes=5): state_dim = env.observation_space.shape[0] action_dim = env.action_space.shape[0] # 加载策略网络 policy = PolicyNetwork(state_dim, action_dim) policy.load_state_dict(torch.load(policy_path)) policy.eval() for episode in range(episodes): state, _ = env.reset() done = False episode_reward = 0 steps = 0 logging.info(f"\n===== 测试第 {episode+1}/{episodes} 轮 =====") while not done: # 选择确定性动作(均值) with torch.no_grad(): state_tensor = torch.FloatTensor(state).unsqueeze(0) mean, _ = policy(state_tensor) action = mean.squeeze().numpy() # 执行动作 next_state, reward, terminated, truncated, info = env.step(action) done = terminated or truncated episode_reward += reward steps += 1 # 打印当前状态 logging.info(f"步骤 {steps}: 动作={action.round(2)}, 奖励={reward:.4f}, " f"误差={info['error'].round(2)}, 力矩={info['torque'].round(3)}") state = next_state logging.info(f"📊 第 {episode+1} 轮完成: 总奖励={episode_reward:.4f}, 总步数={steps}") # 主函数 if __name__ == "__main__": try: logging.info("===== 开始机器人刚度控制强化学习训练 =====") # 创建环境 env = StiffnessControlEnv() # 训练模型 agent, rewards, avg_rewards = train_ppo( env, total_episodes=500, max_steps_per_episode=200, update_timestep=2000, save_interval=50, log_interval=10 ) # 测试模型 test_policy(env) except KeyboardInterrupt: logging.info("训练被用户中断") finally: # 确保环境被正确关闭 if 'env' in locals() and env is not None: env.close() logging.info("程序已退出") 这是我的python程序 修改我的程序,我要求,先启动python程序但并不开始强化学习,直到MFC程序启动变刚度导纳控制后给python力矩姿态信息,python收到后开始强化学习,把计算的刚度参数传递给MFC使机械臂开始调整姿态,当机械臂完成一次调整后把力矩姿态信息再传给python计算新的刚度参数,当机械臂到达目标姿态时,MFC程序把力矩姿态信息传给python后,证明一个回合结束,该开始下一回合,MFC程序使机械臂位置重置,重置完成后给python力矩姿态信息开始新的回合,直至训练结束,要求尽量在我的程序基础上不修改太多,保证程序运行时尽量不出现线程阻塞或者无法连接机械臂的情况,修改保持原有程序的主要结构,只调整交互流程控制逻辑。MFC程序负责主动控制机械臂重置,Python程序专注于强化学习计算。
07-16
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值