MATLAB 中的 [~,m]=size(coord) 是什么意思

本文详细解析了 MATLAB 中 size 函数的使用方法及应用案例,包括如何获取矩阵的行数、列数,以及占位符 ~ 的作用,并特别说明该占位符语法自 R2009b 版本起才受支持。
部署运行你感兴趣的模型镜像
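[~,m]=size(coord) 中,size 返回 coord 的行数和列数两个输出;~ 是输出占位符,表示忽略第一个输出(行数),只把第二个输出(列数)赋给 m。该占位符语法自 MATLAB R2009b 起才支持,更早的版本需要写成 [tmp,m]=size(coord)。例如 coord 为 2×90000 的矩阵时,m 的值为 90000;这条语句等价于 m = size(coord,2)。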

您可能感兴趣的与本文相关的镜像

Stable-Diffusion-3.5

Stable-Diffusion-3.5

图片生成
Stable-Diffusion

Stable Diffusion 3.5 (SD 3.5) 是由 Stability AI 推出的新一代文本到图像生成模型,相比 3.0 版本,它提升了图像质量、运行速度和硬件效率

clc; clear; % ========== 初始化COMSOL模型 ========== try model = mphopen('double-layer3.mph'); % 替换为实际模型路径 model.param.set('lambda', '1.064[um]', '工作波长'); catch ME error('无法打开COMSOL模型文件: %s', ME.message); end % ========== 光纤参数定义 ========== dco_value = 80; % μm (纤芯直径) core_radius = dco_value / 2 * 1e-6; % 转换为米 n_high = 1.4738; % 反谐振单元折射率(As2S3) n_low = 1.4496; % 背景材料折射率(SiO2) n_air = 1.0; % 环境折射率 % ========== PSO参数设置 ========== lb = [0.6, 0.3]; % [dcl2/dco, dcl1/dcl2] 下限 ub = [0.8, 0.4]; % 上限 nPop = 20; % 粒子数 maxIter = 15; % 迭代次数 w = 0.7; % 惯性权重 c1 = 1.5; % 认知系数 c2 = 1.5; % 社会系数 % ========== 初始化优化算法 ========== particles = lb + rand(nPop, 2) .* (ub - lb); velocities = 0.1 * rand(nPop, 2); particles(1, :) = [0.61, 0.35]; % 种子粒子(可自定义初始值) % ========== 结果存储初始化 ========== nx = 50; ny = 50; [dcl2_grid, dcl1_ratio_grid] = meshgrid(linspace(lb(1), ub(1), nx), linspace(lb(2), ub(2), ny)); fund_loss_grid = nan(nx, ny); ho_loss_grid = nan(nx, ny); mfa_grid = nan(nx, ny); homsr_grid = nan(nx, ny); convergence_MFA = zeros(maxIter, 1); % ========== 最优解初始化 ========== pBest = particles; pBestScore = inf(nPop, 1); gBest = particles(1, :); gBestScore = inf; best_MFA = 0; best_neff = 0; best_loss = 0; best_ho_loss = NaN; best_ho_neff = NaN; best_HOMSR = Inf; % ========== 主优化循环 ========== for iter = 1:maxIter fprintf('========== 迭代 %d/%d ==========\n', iter, maxIter); for i = 1:nPop % 更新结构参数 dcl2_ratio = particles(i, 1); dcl1_ratio = particles(i, 2); model.param.set('dco', sprintf('%.2f', dco_value)); model.param.set('k1', sprintf('%.4f', dcl2_ratio)); % dcl2/dco model.param.set('k2', sprintf('%.4f', dcl1_ratio)); % dcl1/dcl2 try % ===== 生成计算坐标 ===== x = linspace(-1.5*core_radius, 1.5*core_radius, 300); y = linspace(-1.5*core_radius, 1.5*core_radius, 300); dx = x(2) - x(1); dy = y(2) - y(1); area_element = dx * dy; [X, Y] = meshgrid(x, y); coord = [X(:)'; Y(:)']; % 2xN 坐标矩阵 num_points = size(coord, 2); % 记录坐标点数量(关键参考值) fprintf('坐标点数量: %d\n', num_points); % ===== 第一次求解:基模搜索(基准值1.44957) ===== 
eigen_solver = 'sol1'; dataset_name = 'dset1'; % 清除现有特征值设置 tags_java = model.sol(eigen_solver).feature.tags; tags_cell = cellfun(@(x) char(x), cell(tags_java), 'UniformOutput', false); if any(strcmp('e1', tags_cell)) model.sol(eigen_solver).feature.remove('e1'); end % 设置基模搜索参数 e1 = model.sol(eigen_solver).feature.create('e1', 'Eigenvalue'); e1.set('shift', '1.44957'); % 基模基准值 e1.set('neigs', 20); % 求解20个模式 model.sol(eigen_solver).runAll; % 获取基模候选数据 neff_fund_candidates = mphinterp(model, 'emw.neff', 'coord', [0;0], 'dataset', dataset_name, 'solnum', 'all'); loss_fund_candidates = mphinterp(model, 'emw.dampzdB', 'coord', [0;0], 'dataset', dataset_name, 'solnum', 'all'); mfa_fund_candidates = mphinterp(model, 'MFA', 'coord', [0;0], 'dataset', dataset_name, 'solnum', 'all') * 1e12; % 获取电场分布并强制维度匹配 n_modes_fund = length(neff_fund_candidates); E_field_fund = cell(1, n_modes_fund); for k = 1:n_modes_fund Ex = mphinterp(model, 'emw.Ex', 'coord', coord, 'dataset', dataset_name, 'solnum', k); Ey = mphinterp(model, 'emw.Ey', 'coord', coord, 'dataset', dataset_name, 'solnum', k); % 强制调整为与坐标点数量一致(关键修复) if length(Ex) ~= num_points warning('粒子 %d 模式 %d 电场Ex维度不匹配,强制调整: %d -> %d', ... i, k, length(Ex), num_points); % 使用线性插值强制匹配维度 Ex = interp1(linspace(1, length(Ex), length(Ex)), Ex, ... linspace(1, length(Ex), num_points), 'linear', 0); Ey = interp1(linspace(1, length(Ey), length(Ey)), Ey, ... linspace(1, length(Ey), num_points), 'linear', 0); end % 确保是列向量 Ex = Ex(:); Ey = Ey(:); E_field_fund{k} = [Ex, Ey]; end % ===== 基模识别(LP01模式特征) ===== fund_mode_idx = NaN; for k = 1:n_modes_fund % 基模特征:1. 轴对称 2. 纤芯能量占比高 3. 
无径向节点 is_axsym = is_axially_symmetric(E_field_fund{k}, X, Y); core_ratio = compute_core_energy_ratio(E_field_fund{k}, coord, core_radius, n_high, n_low, area_element); radial_nodes = count_radial_nodes(E_field_fund{k}, coord, core_radius); if is_axsym && (core_ratio > 0.7) && (radial_nodes == 0) fund_mode_idx = k; break; end end if isnan(fund_mode_idx) error('未在第一次求解中识别到基模!'); end % 基模特性 fundamental_neff = real(neff_fund_candidates(fund_mode_idx)); fundamental_loss = loss_fund_candidates(fund_mode_idx); fundamental_MFA = mfa_fund_candidates(fund_mode_idx); fprintf('--> 基模识别: 模式 %d (neff=%.6f, 损耗=%.4f dB/m)\n', ... fund_mode_idx, fundamental_neff, fundamental_loss); % ===== 第二次求解:高阶模搜索(基准值1.44953) ===== % 清除现有特征值设置 if any(strcmp('e1', tags_cell)) model.sol(eigen_solver).feature.remove('e1'); end % 设置高阶模搜索参数 e1 = model.sol(eigen_solver).feature.create('e1', 'Eigenvalue'); e1.set('shift', '1.44953'); % 高阶模基准值 e1.set('neigs', 60); % 求解更多模式 model.sol(eigen_solver).runAll; % 获取高阶模候选数据 neff_ho_candidates = mphinterp(model, 'emw.neff', 'coord', [0;0], 'dataset', dataset_name, 'solnum', 'all'); loss_ho_candidates = mphinterp(model, 'emw.dampzdB', 'coord', [0;0], 'dataset', dataset_name, 'solnum', 'all'); % 获取电场分布并强制维度匹配 n_modes_ho = length(neff_ho_candidates); E_field_ho = cell(1, n_modes_ho); for k = 1:n_modes_ho Ex = mphinterp(model, 'emw.Ex', 'coord', coord, 'dataset', dataset_name, 'solnum', k); Ey = mphinterp(model, 'emw.Ey', 'coord', coord, 'dataset', dataset_name, 'solnum', k); % 强制调整为与坐标点数量一致(关键修复) if length(Ex) ~= num_points warning('粒子 %d 高阶模式 %d 电场Ex维度不匹配,强制调整: %d -> %d', ... i, k, length(Ex), num_points); Ex = interp1(linspace(1, length(Ex), length(Ex)), Ex, ... linspace(1, length(Ex), num_points), 'linear', 0); Ey = interp1(linspace(1, length(Ey), length(Ey)), Ey, ... 
linspace(1, length(Ey), num_points), 'linear', 0); end % 确保是列向量 Ex = Ex(:); Ey = Ey(:); E_field_ho{k} = [Ex, Ey]; end % ===== 高阶模识别与最小损耗筛选 ===== ho_losses = []; ho_indices = []; for k = 1:n_modes_ho % 高阶模特征:1. 非轴对称 2. 有径向节点 3. 排除基模特征 is_high_order = ~is_axially_symmetric(E_field_ho{k}, X, Y); core_ratio = compute_core_energy_ratio(E_field_ho{k}, coord, core_radius, n_high, n_low, area_element); radial_nodes = count_radial_nodes(E_field_ho{k}, coord, core_radius); % 判定为高阶模的条件 if (is_high_order || radial_nodes >= 1) && (core_ratio < 0.7) ho_indices(end+1) = k; ho_losses(end+1) = loss_ho_candidates(k); end end % 寻找最小损耗高阶模 if ~isempty(ho_losses) [min_ho_loss, min_idx] = min(ho_losses); min_ho_idx = ho_indices(min_idx); best_ho_neff = real(neff_ho_candidates(min_ho_idx)); HOMSR = min_ho_loss / fundamental_loss; fprintf('--> 高阶模识别: 模式 %d (损耗=%.4f dB/m, HOMSR=%.2f)\n', ... min_ho_idx, min_ho_loss, HOMSR); else min_ho_loss = NaN; best_ho_neff = NaN; HOMSR = Inf; fprintf('警告: 未识别到有效高阶模!\n'); end % ===== 适应度计算 ===== if ~isnan(min_ho_loss) && (fundamental_loss < 0.1) && (HOMSR > 100) fitness = -fundamental_MFA; % 最大化模场面积 else fitness = inf; end % 存储绘图数据 [~, xi] = min(abs(dcl2_grid(1,:) - dcl2_ratio)); [~, yi] = min(abs(dcl1_ratio_grid(:,1) - dcl1_ratio)); fund_loss_grid(xi, yi) = fundamental_loss; ho_loss_grid(xi, yi) = min_ho_loss; mfa_grid(xi, yi) = fundamental_MFA; homsr_grid(xi, yi) = HOMSR; catch ME warning('粒子 %d 迭代失败: %s (文件: %s, 行号: %d)', ... 
i, ME.message, ME.stack(1).name, ME.stack(1).line); fitness = inf; continue; end % 更新个体最优 if fitness < pBestScore(i) pBest(i, :) = particles(i, :); pBestScore(i) = fitness; end % 更新全局最优 if fitness < gBestScore gBest = particles(i, :); gBestScore = fitness; best_MFA = fundamental_MFA; best_neff = fundamental_neff; best_loss = fundamental_loss; best_ho_loss = min_ho_loss; best_ho_neff = best_ho_neff; best_HOMSR = HOMSR; end end % 记录收敛历史 convergence_MFA(iter) = best_MFA; % 输出迭代信息 fprintf('迭代 %d 最优: MFA=%.4f μm², 基模损耗=%.4f dB/m, HOMSR=%.2f\n', ... iter, best_MFA, best_loss, best_HOMSR); % 更新粒子位置 r1 = rand(nPop, 1); r2 = rand(nPop, 1); velocities = w*velocities + c1*r1.*(pBest - particles) + c2*r2.*(gBest - particles); particles = max(min(particles + velocities, ub), lb); end % ========== 结果可视化 ========== % 1. 结构可视化 try figure; mphplot(model, 'pg2', 'surface', 'on'); title('反谐振光纤结构'); catch warning('结构可视化失败'); end % 2. 基模场分布 try figure; mphplot(model, 'pg1', 'dataset', dataset_name, 'solnum', fund_mode_idx, 'data', 'emw.normE'); title(sprintf('基模场分布 (neff=%.4f)', best_neff)); catch warning('基模场分布可视化失败'); end % 3. 
优化结果热力图 figure('Position', [100, 100, 1200, 800]); subplot(2,2,1); if any(~isnan(fund_loss_grid(:))) contourf(dcl2_grid, dcl1_ratio_grid, fund_loss_grid', 20, 'LineColor', 'none'); colorbar; title('基模损耗 (dB/m)'); xlabel('dcl2/dco'); ylabel('dcl1/dcl2'); else title('基模损耗 (无有效数据)'); end subplot(2,2,2); if any(~isnan(ho_loss_grid(:))) contourf(dcl2_grid, dcl1_ratio_grid, ho_loss_grid', 20, 'LineColor', 'none'); colorbar; title('高阶模最小损耗 (dB/m)'); xlabel('dcl2/dco'); ylabel('dcl1/dcl2'); else title('高阶模最小损耗 (无有效数据)'); end subplot(2,2,3); if any(~isnan(mfa_grid(:))) contourf(dcl2_grid, dcl1_ratio_grid, mfa_grid', 20, 'LineColor', 'none'); colorbar; title('模场面积 (μm²)'); xlabel('dcl2/dco'); ylabel('dcl1/dcl2'); else title('模场面积 (无有效数据)'); end subplot(2,2,4); if any(~isnan(homsr_grid(:))) contourf(dcl2_grid, dcl1_ratio_grid, homsr_grid', 20, 'LineColor', 'none'); colorbar; title('高阶模抑制比'); xlabel('dcl2/dco'); ylabel('dcl1/dcl2'); else title('高阶模抑制比 (无有效数据)'); end % 4. 收敛曲线 if any(convergence_MFA > 0) figure; plot(1:maxIter, convergence_MFA, 'r-s', 'LineWidth', 1.5); xlabel('迭代次数'); ylabel('模场面积 (μm²)'); title('MFA收敛曲线'); grid on; else warning('无有效收敛数据'); end % ========== 输出最终结果 ========== fprintf('\n====== 最优结构参数 ======\n'); fprintf('dcl2/dco 比例: %.4f\n', gBest(1)); fprintf('dcl1/dcl2 比例: %.4f\n', gBest(2)); fprintf('实际 dcl2: %.2f μm\n', dco_value * gBest(1)); fprintf('实际 dcl1: %.2f μm\n', dco_value * gBest(1) * gBest(2)); fprintf('\n====== 模式特性 ======\n'); fprintf('基模有效折射率(neff): %.6f\n', best_neff); fprintf('基模损耗: %.4f dB/m\n', best_loss); fprintf('模场面积(MFA): %.2f μm²\n', best_MFA); fprintf('高阶模最小损耗: %.4f dB/m\n', best_ho_loss); fprintf('高阶模抑制比(HOMSR): %.2f\n', best_HOMSR); % ========== 辅助函数 ========== function is_sym = is_axially_symmetric(E_field, X, Y) % 判断场分布是否轴对称(旋转不变性) E_mag = sqrt(sum(abs(E_field).^2, 2)); % 电场模值 E_mag = reshape(E_mag, size(X)); % 恢复2D场分布 % 测试多个角度的旋转对称性 angles = 0:30:150; % 测试角度 corr_scores = zeros(size(angles)); for j = 1:length(angles) theta = 
deg2rad(angles(j)); X_rot = X*cos(theta) - Y*sin(theta); Y_rot = X*sin(theta) + Y*cos(theta); % 插值旋转后的场分布 E_rot = griddata(X(:), Y(:), E_mag(:), X_rot, Y_rot, 'cubic'); E_rot(isnan(E_rot)) = 0; % 处理NaN值 corr_scores(j) = corr2(E_mag, E_rot); end is_sym = all(corr_scores > 0.9); % 高相关性判定为轴对称 end function ratio = compute_core_energy_ratio(E_field, coord, core_radius, n_core, n_clad, area_element) % 计算纤芯能量占比,增强维度一致性保障 r = sqrt(coord(1,:).^2 + coord(2,:).^2); n_map = n_clad * ones(size(r)); n_map(r <= core_radius) = n_core; % 提取电场分量并确保为列向量 Ex = E_field(:,1); Ey = E_field(:,2); % 最终维度检查与修复(最后一道防线) if length(Ex) ~= length(n_map) warning('强制修复能量计算维度不匹配: Ex长度=%d, n_map长度=%d', length(Ex), length(n_map)); % 强制调整到相同长度 target_len = max(length(Ex), length(n_map)); Ex = interp1(1:length(Ex), Ex, linspace(1, length(Ex), target_len)); Ey = interp1(1:length(Ey), Ey, linspace(1, length(Ey), target_len)); n_map = interp1(1:length(n_map), n_map, linspace(1, length(n_map), target_len)); r = interp1(1:length(r), r, linspace(1, length(r), target_len)); end energy_density = (n_map.^2) .* (abs(Ex).^2 + abs(Ey).^2); core_energy = sum(energy_density(r <= core_radius)) * area_element; total_energy = sum(energy_density) * area_element; % 避免除零错误 if total_energy == 0 || isnan(total_energy) ratio = 0; else ratio = core_energy / total_energy; end end function count = count_radial_nodes(E_field, coord, core_radius) % 计算径向节点数(场强为零的环数) r = sqrt(coord(1,:).^2 + coord(2,:).^2); E_mag = sqrt(sum(abs(E_field).^2, 2)); % 确保维度一致 if length(E_mag) ~= length(r) warning('强制修复径向节点计算维度不匹配: E_mag长度=%d, r长度=%d', length(E_mag), length(r)); target_len = max(length(E_mag), length(r)); E_mag = interp1(1:length(E_mag), E_mag, linspace(1, length(E_mag), target_len)); r = interp1(1:length(r), r, linspace(1, length(r), target_len)); end % 筛选纤芯内的点并按半径排序 idx = r <= core_radius & r > 0; [r_sorted, idx_sort] = sort(r(idx)); E_sorted = E_mag(idx(idx_sort)); % 平滑处理 E_smoothed = smooth(E_sorted, 10); % 寻找过零点(节点) crossings = 0; 
for j = 2:length(E_smoothed) if sign(E_smoothed(j)) ~= sign(E_smoothed(j-1)) && ... abs(E_smoothed(j)) < 0.3*max(E_smoothed) % 排除边缘效应 crossings = crossings + 1; end end count = crossings; end ========== 迭代 1/15 ========== 坐标点数量: 90000 警告: 粒子 1 迭代失败: 逻辑 AND (&&)和 OR (||)运算符的操作数必须可转换为标量逻辑值。请使用 ANY 或 ALL 函数将操作数简化为标量逻辑值。 (文件: compute_core_energy_ratio, 行号: 28) > 位置:demo2 (第 233 行) 坐标点数量: 90000 根据报错修改代码
07-30
% 清除工作区,关闭窗口 clear; clc; close all; % ============================================== % 参数定义(对应arguments.py) % ============================================== classdef Args properties time_date date seed = 125 n_agent = 3 clip_obs = 5 actor_num = 2 clip_range = 200 action_bound = 1 demo_length = 25 Use_GUI = true env_params % 环境参数结构体 train_params % 训练参数结构体 end methods function obj = Args() % 初始化时间 obj.time_date = clock; obj.date = sprintf('%d_%d_%d_%d', ... obj.time_date(2), obj.time_date(3), ... obj.time_date(4), obj.time_date(5)); % 环境参数(对应env_params) obj.env_params = struct(); obj.env_params.n_agents = obj.n_agent; obj.env_params.dim_observation = 21; obj.env_params.dim_action = 5; obj.env_params.dim_achieved_goal = 3; obj.env_params.clip_obs = obj.clip_obs; obj.env_params.dim_goal = 3; obj.env_params.action_max = 1; obj.env_params.grid_size = 30; obj.env_params.observation_dim = 35; obj.env_params.action_dim = 5; obj.env_params.clip_obs = false; obj.env_params.max_timesteps = 500; % 训练参数(对应train_params) obj.train_params = struct(); obj.train_params.learner_step = int64(1e6); obj.train_params.update_tar_interval = 40; obj.train_params.evalue_interval = 240; obj.train_params.evalue_time = 5; obj.train_params.store_interval = 2; obj.train_params.actor_num = obj.actor_num; obj.train_params.date = obj.date; obj.train_params.checkpoint = []; obj.train_params.polyak = 0.95; obj.train_params.action_l2 = 1; obj.train_params.noise_eps = 0.01; obj.train_params.random_eps = 0.3; obj.train_params.theta = 0.1; obj.train_params.Is_train_discrim = true; obj.train_params.roll_time = 2; obj.train_params.gamma = 0.98; obj.train_params.batch_size = 256; obj.train_params.buffer_size = int64(1e6); obj.train_params.device = 'cpu'; obj.train_params.lr_actor = 0.001; obj.train_params.lr_critic = 0.001; obj.train_params.lr_disc = 0.001; obj.train_params.clip_obs = obj.clip_obs; obj.train_params.clip_range = obj.clip_range; obj.train_params.add_demo = false; obj.train_params.save_dir = 
'saved_models/'; obj.train_params.seed = obj.seed; obj.train_params.env_name = sprintf('grid_world_seed%d_%s', ... obj.seed, obj.date); obj.train_params.demo_name = 'armrobot_100_push_demo.npz'; obj.train_params.replay_strategy = 'future'; obj.train_params.replay_k = 4; % 合并环境参数到训练参数 obj.train_params = merge_structs(obj.train_params, obj.env_params); end end end % ============================================== % 主训练函数(对应train.py) % ============================================== function train() % 初始化参数 args = Args(); train_params = args.train_params; env_params = args.env_params; actor_num = train_params.actor_num; % 创建保存目录 model_path = fullfile(train_params.save_dir, train_params.env_name); if ~exist(train_params.save_dir, 'dir') mkdir(train_params.save_dir); logger('info', sprintf('creating directory %s', train_params.save_dir)); else logger('info', sprintf('directory %s already exists', train_params.save_dir)); end if ~exist(model_path, 'dir') mkdir(model_path); logger('info', sprintf('creating directory %s', model_path)); else logger('info', sprintf('directory %s already exists', model_path)); end % 设置随机种子 setup_seed(args.seed); logger('info', sprintf('New experiment date: %s, seed: %d', args.date, args.seed)); % 生成障碍物坐标并保存 origin_obstacle_states = []; for _ = 1:60 x = randi([1,29]); y = randi([1,29]); coord = [x, y]; if isempty(find(all(origin_obstacle_states == coord, 2), 1)) origin_obstacle_states = [origin_obstacle_states; coord]; end end % 保存障碍物坐标 fileID = fopen('origin_obstacle_states.txt', 'w'); for i = 1:size(origin_obstacle_states, 1) fprintf(fileID, '[%d, %d]\n', origin_obstacle_states(i,1), origin_obstacle_states(i,2)); end fclose(fileID); % 初始化并行池(替代Python multiprocessing) if isempty(gcp('nocreate')) parpool(actor_num); end % 初始化进程通信队列(全局变量模拟) global data_queue evalue_queue actor_queues; data_queue = {}; evalue_queue = {}; actor_queues = cell(1, actor_num); for i = 1:actor_num actor_queues{i} = {}; end % 启动Actor进程(并行执行) logger('info', 'Starting 
actor processes...'); actor_indices = 1:actor_num; % 并行启动Actor parfor i = actor_indices is_display = (i == 1); % 仅第一个Actor显示GUI actor_worker(data_queue, actor_queues{i}, i, origin_obstacle_states, is_display, env_params); end % 启动Learner和Evaluator(MATLAB中使用后台任务) learner_handle = parfeval(@learn, 0, model_path, data_queue, evalue_queue, actor_queues); evaluator_handle = parfeval(@evaluate_worker, 0, train_params, env_params, model_path, ... train_params.evalue_time, evalue_queue, origin_obstacle_states); % 等待完成 wait(learner_handle); wait(evaluator_handle); end % ============================================== % Actor工作函数(对应collection_experiments.py的actor_worker) % ============================================== function actor_worker(data_queue, actor_queue, actor_id, obstacles, is_display, env_params) % 初始化环境 env = Gridworld(obstacles, is_display); % 初始化策略网络 policy = actor_model(env_params); store_item = {'obs', 'next_obs', 'acts', 'r'}; mb_store_dict = cell2struct(cell(length(store_item), 1), store_item, 1); rolltime_count = 1; while true % 从队列获取最新模型参数 if ~isempty(actor_queue) policy_params = actor_queue{end}; actor_queue = {}; % 清空队列 policy = update_model(policy, policy_params); end % 收集样本 for rollouts_times = 1:env_params.store_interval ep_store_dict = cell2struct(cell(length(store_item), 1), store_item, 1); obs = env.reset(); for t = 1:env_params.max_timesteps % 选择动作 actions = select_action(policy, obs, true); % 探索模式 % 环境步进 [next_obs, reward, done, info] = env.step(actions); is_done = info{1}; % 存储数据 store_data.obs = obs; if t == env_params.max_timesteps store_data.next_obs = obs; else store_data.next_obs = next_obs; end store_data.acts = actions; store_data.r = reward; % 保存到episode字典 for key = store_item ep_store_dict.(key{1}) = [ep_store_dict.(key{1}); {store_data.(key{1})}]; end obs = next_obs; end % 如果完成则保存到批次字典 if is_done for key = store_item mb_store_dict.(key{1}) = [mb_store_dict.(key{1}); {ep_store_dict.(key{1})}]; end end end % 将数据放入队列 global data_queue; 
data_queue{end+1} = mb_store_dict; pause(0.1); % 避免资源竞争 end end % ============================================== % 评估函数(对应evaluator.py逻辑) % ============================================== function evaluate_worker(train_params, env_params, model_path, evalue_time, evalue_queue, obstacles) % 初始化评估环境 env = Gridworld(obstacles, true); % 评估时显示GUI policy = actor_model(env_params); while true % 检查是否有模型更新 if ~isempty(evalue_queue) model_params = evalue_queue{end}; evalue_queue = {}; policy = update_model(policy, model_params); end % 执行评估 total_rewards = []; for i = 1:evalue_time obs = env.reset(); episode_reward = 0; for t = 1:env_params.max_timesteps actions = select_action(policy, obs, false); % 不探索 [next_obs, reward, done, ~] = env.step(actions); episode_reward = episode_reward + sum(reward); obs = next_obs; if done break; end end total_rewards = [total_rewards, episode_reward]; end % 记录评估结果 avg_reward = mean(total_rewards); logger('info', sprintf('Evaluation average reward: %.2f', avg_reward)); % 保存评估数据 save(fullfile(model_path, 'evaluation_results.mat'), 'avg_reward', 'total_rewards'); pause(train_params.evalue_interval); % 评估间隔 end end % ============================================== % 学习函数(对应learner.py逻辑) % ============================================== function learn(model_path, data_queue, evalue_queue, actor_queues) % 初始化模型 args = Args(); env_params = args.env_params; train_params = args.train_params; policy = actor_model(env_params); target_policy = actor_model(env_params); % 目标网络 target_policy = copy_model(policy, target_policy); % 初始同步 % 经验回放缓冲区 replay_buffer = init_replay_buffer(train_params.buffer_size, env_params); step = 0; while step < train_params.learner_step % 从队列获取数据 if ~isempty(data_queue) % 取出所有样本 batch_data = data_queue; data_queue = {}; % 存入回放缓冲区 for i = 1:length(batch_data) replay_buffer = add_to_buffer(replay_buffer, batch_data{i}); end end % 训练步骤 if length(replay_buffer.obs) >= train_params.batch_size % 采样批次 batch = sample_buffer(replay_buffer, 
train_params.batch_size); % 更新策略(简化实现) policy = update_policy(policy, target_policy, batch, train_params); % 软更新目标网络 if mod(step, train_params.update_tar_interval) == 0 target_policy = soft_update(target_policy, policy, train_params.polyak); end % 定期保存模型 if mod(step, train_params.store_interval) == 0 model_save_path = fullfile(model_path, sprintf('%d_%d_model.mat', args.seed, step)); save(model_save_path, 'policy', 'target_policy'); logger('info', sprintf('Model saved to %s', model_save_path)); end % 发送模型参数给Actor policy_params = get_model_params(policy); for i = 1:length(actor_queues) actor_queues{i}{end+1} = policy_params; end % 发送模型给评估器 evalue_queue{end+1} = policy_params; step = step + 1; end pause(0.01); % 控制训练节奏 end end % ============================================== % 辅助函数:日志输出 % ============================================== function logger(level, msg) timestamp = datestr(now, 'yyyy-mm-dd HH:MM:SS'); fprintf('[%s] [%s] %s\n', timestamp, level, msg); end % ============================================== % 辅助函数:设置随机种子 % ============================================== function setup_seed(seed) rng(seed); % MATLAB随机种子 randn('seed', seed); % 正态分布种子 end % ============================================== % 环境类(对应Env.env.Gridworld) % ============================================== classdef Gridworld properties obstacles is_display max_timesteps grid_size n_agents current_obs end methods function obj = Gridworld(obstacles, is_display) obj.obstacles = obstacles; obj.is_display = is_display; obj.max_timesteps = 500; obj.grid_size = 30; obj.n_agents = 3; end function obs = reset(obj) % 重置环境,返回初始观测 obj.current_obs = rand(1, 35); % 35维观测 obs = obj.current_obs; end function [next_obs, reward, done, info] = step(obj, actions) % 执行动作,返回环境反馈 next_obs = rand(1, 35); % 新观测 reward = rand(1, obj.n_agents); % 每个智能体的奖励 done = (rand < 0.05); % 5%概率结束 info = {done}; % 信息字典 obj.current_obs = next_obs; % 如果需要显示则更新可视化 if obj.is_display obj.render(reward, done); end end function render(obj, 
reward, done) % 简单可视化 if obj.is_display clf; hold on; % 绘制障碍物 for i = 1:size(obj.obstacles, 1) plot(obj.obstacles(i,1), obj.obstacles(i,2), 'ks', 'MarkerSize', 8); end % 绘制智能体 plot(rand(3,1)*obj.grid_size, rand(3,1)*obj.grid_size, 'ro', 'MarkerSize', 10); axis([0 obj.grid_size 0 obj.grid_size]); title(sprintf('Reward: %.2f, Done: %d', sum(reward), done)); drawnow; end end end end % ============================================== % 策略网络模型(对应core.model.actor) % ============================================== function model = actor_model(env_params) model = struct(); % 简单网络结构:输入35 -> 256 -> 5 model.fc1_weights = randn(env_params.observation_dim, 256) * 0.01; model.fc1_bias = zeros(1, 256); model.fc2_weights = randn(256, env_params.action_dim) * 0.01; model.fc2_bias = zeros(1, env_params.action_dim); end % ============================================== % 动作选择函数(对应core.util.select_action) % ============================================== function actions = select_action(model, obs, explore) % 前向传播 h = relu(obs * model.fc1_weights + model.fc1_bias); logits = h * model.fc2_weights + model.fc2_bias; % 选择动作(argmax) [~, acts] = max(logits, [], 2); % 探索:添加随机噪声 if explore args = Args(); if rand < args.train_params.random_eps acts = randi(size(logits, 2), size(acts)); else % 添加高斯噪声 acts = acts + round(normrnd(0, args.train_params.noise_eps, size(acts))); acts = max(1, min(acts, size(logits, 2))); % 裁剪到有效范围 end end actions = acts; end % ============================================== % 辅助函数:合并结构体 % ============================================== function merged = merge_structs(a, b) merged = a; fields = fieldnames(b); for i = 1:length(fields) merged.(fields{i}) = b.(fields{i}); end end % ============================================== % 其他辅助函数(模型更新、缓冲区操作等) % ============================================== function new_model = update_model(old_model, params) new_model = old_model; fields = fieldnames(params); for i = 1:length(fields) new_model.(fields{i}) = params.(fields{i}); end end 
function target = copy_model(source, target) fields = fieldnames(source); for i = 1:length(fields) target.(fields{i}) = source.(fields{i}); end return target; end function target = soft_update(target, source, polyak) fields = fieldnames(source); for i = 1:length(fields) target.(fields{i}) = polyak * target.(fields{i}) + (1 - polyak) * source.(fields{i}); end return target; end function params = get_model_params(model) params = model; % 返回模型所有参数 end function buffer = init_replay_buffer(size, env_params) buffer = struct(); buffer.obs = {}; buffer.next_obs = {}; buffer.acts = {}; buffer.r = {}; buffer.max_size = size; end function buffer = add_to_buffer(buffer, data) % 添加数据到缓冲区 buffer.obs = [buffer.obs; data.obs]; buffer.next_obs = [buffer.next_obs; data.next_obs]; buffer.acts = [buffer.acts; data.acts]; buffer.r = [buffer.r; data.r]; % 超出最大容量则截断 if length(buffer.obs) > buffer.max_size excess = length(buffer.obs) - buffer.max_size; buffer.obs(1:excess) = []; buffer.next_obs(1:excess) = []; buffer.acts(1:excess) = []; buffer.r(1:excess) = []; end end function batch = sample_buffer(buffer, batch_size) % 从缓冲区采样批次 indices = randi(length(buffer.obs), batch_size, 1); batch.obs = buffer.obs(indices); batch.next_obs = buffer.next_obs(indices); batch.acts = buffer.acts(indices); batch.r = buffer.r(indices); end function policy = update_policy(policy, target_policy, batch, train_params) % 简化的策略更新(实际应包含梯度下降) % 这里仅作为框架示例 policy = policy; % 实际训练中应更新参数 return policy; end % ============================================== % 启动训练 % ============================================== train();检查有没有问题,并给我解释框架
最新发布
09-28
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值