1054. The Dominant Color: Solution Analysis

An algorithm for finding the most frequent color
This article presents an algorithm that uses a hash map (map) to find the color that appears most often in a given data set. The approach first maps each color value to a unique integer identifier, then counts the frequency of each color while tracking the largest count seen so far and the color it belongs to.

Problem: read a grid of color values (m columns by n rows), find the color that appears most often, and print it.
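For example (an illustrative input, not an official test case), with m = 3 and n = 2 the program reads six values; color 2 appears four times, so it is printed:

3 2
1 2 2
2 3 2

Output: 2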

My approach: use a map to assign each color value a compact index, then count occurrences in a plain array. Whenever the current color's count exceeds the running maximum, record that color as the answer; once all values have been read, the recorded color is the most frequent one.

#include <cstdio>
#include <cstring>
#include <map>

#define MAX 480005 // compact indices start at 1, so leave headroom above 480000

using namespace std;

int m, n;
int c[MAX];         // occurrence count per compact color index
int ans = 0;        // most frequent color seen so far
map<int, int> c2no; // color value -> compact index (1-based; 0 means unseen)
int no = 1;         // next compact index to assign

int main() {

	scanf("%d%d", &m, &n);

	memset(c, 0, sizeof(c));
	int t, maxCount = -1;

	for (int i = 0; i < n; i++) {
		for (int j = 0; j < m; j++) {
			scanf("%d", &t);
			if (c2no[t] == 0)       // first time this color appears
				c2no[t] = no++;
			int now = ++c[c2no[t]]; // bump this color's count
			if (now > maxCount) {   // new dominant candidate
				ans = t;
				maxCount = now;
			}
		}
	}

	printf("%d\n", ans);

	return 0;
}
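
For comparison, the color-to-index indirection above is not strictly required: a single std::map can count occurrences directly. Below is a minimal sketch of that variant (same input format as above; assumes the counts fit in an int):

#include <cstdio>
#include <map>

using namespace std;

int main() {
	int m, n;
	scanf("%d%d", &m, &n);

	map<int, int> freq; // color value -> occurrence count
	int ans = 0, best = 0;

	for (int i = 0; i < n * m; i++) {
		int t;
		scanf("%d", &t);
		int now = ++freq[t]; // operator[] inserts missing keys with value 0
		if (now > best) {    // new most-frequent color so far
			best = now;
			ans = t;
		}
	}

	printf("%d\n", ans);
	return 0;
}

On compilers with C++11 support, replacing std::map with std::unordered_map makes each lookup expected O(1) instead of O(log k), where k is the number of distinct colors; the logic is otherwise unchanged.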

