Butterfly Detection with TensorFlow
https://www.cnblogs.com/caffeaoto/p/8758962.html

This post shares my experience from a butterfly detection competition done with TensorFlow: how the dataset was split, how the model was chosen and adapted, how the evaluation metrics were computed, and the problems I ran into during evaluation.

I signed up for a butterfly detection competition. We were given a little over 700 images covering 94 butterfly species, and the task is to detect the butterflies in each image and classify them correctly.

1. After getting the data, the first step was to split the 700-odd images into 483 training samples and 238 test samples. (Since 15 of the species have only a single image each, the test set ends up covering only 79 of the 94 species.)
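For reference, a minimal sketch of how such a split could be scripted; the folder layout, file-naming scheme, and the roughly 2:1 ratio are assumptions for illustration, not necessarily what I actually did:

import os, random, shutil
from collections import defaultdict

random.seed(0)
by_class = defaultdict(list)               # species name -> image file names
for name in os.listdir('images'):          # assumed flat folder of annotated images
    species = name.split('_')[0]           # assumed "<species>_<index>.jpg" naming
    by_class[species].append(name)

for species, files in by_class.items():
    random.shuffle(files)
    n_test = 0 if len(files) < 2 else len(files) // 3   # singleton species stay in train only
    for i, f in enumerate(files):
        split = 'test' if i < n_test else 'train'
        os.makedirs(os.path.join(split, species), exist_ok=True)
        shutil.copy(os.path.join('images', f), os.path.join(split, species, f))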

2. Use an existing model whose label set already contains a butterfly class to detect butterflies in the test set directly (essentially a binary butterfly-vs-background task). I chose the "faster_rcnn_inception_resnet_v2_atrous_oid_2018_01_28" model, which was trained on the Open Images dataset and covers 545 object classes.

I started from object_detection/object_detection_tutorial.ipynb, but it imports a frozen model directly, so the detection threshold could not be changed. My senior labmate therefore switched to a different way of loading the model, which lets us lower the detection threshold and pick up more butterflies.
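For context, here is a minimal sketch of running an exported frozen detection graph and filtering its outputs with a lower score threshold; the tensor names are the standard ones produced by the Object Detection API export script, and the file path and the 0.1 threshold are placeholders:

import numpy as np
import tensorflow as tf
from PIL import Image

PATH_TO_GRAPH = 'faster_rcnn_inception_resnet_v2_atrous_oid_2018_01_28/frozen_inference_graph.pb'

detection_graph = tf.Graph()
with detection_graph.as_default():
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_GRAPH, 'rb') as fid:
        graph_def.ParseFromString(fid.read())
    tf.import_graph_def(graph_def, name='')

with tf.Session(graph=detection_graph) as sess:
    image = np.expand_dims(np.array(Image.open('test.jpg')), 0)      # shape [1, H, W, 3]
    boxes, scores, classes = sess.run(
        [detection_graph.get_tensor_by_name('detection_boxes:0'),
         detection_graph.get_tensor_by_name('detection_scores:0'),
         detection_graph.get_tensor_by_name('detection_classes:0')],
        feed_dict={detection_graph.get_tensor_by_name('image_tensor:0'): image})

    keep = scores[0] > 0.1    # keep low-confidence boxes so more butterflies survive
    print(boxes[0][keep], scores[0][keep], classes[0][keep])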

What I then had to do was evaluate the detection results (compute precision), and here I took a bit of a detour.

The workflow in the TensorFlow tutorial is to convert the test set to TFRecord format first and then, after inference, append the detection results to the same records. But I hit a problem at the model input: the original frozen model takes a tensor as input, while my labmate's modified model takes an image directly, and I didn't know how to convert between the two (something I still need to learn). So I went the other way: starting from his code, I pull one image at a time from the folder and run inference on it; using the image's name I parse the corresponding XML annotation and write it into a tf_example, and then append the model's output detections to that same example. This solved the problem and let me save annotation + detection results together as a TFRecord.
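A minimal sketch of what each of these examples can look like; the string field names follow standard_fields.TfExampleFields (which is what TfExampleDetectionAndGTParser reads), while the helper functions and variable names are my own:

import tensorflow as tf

def _floats(v):  return tf.train.Feature(float_list=tf.train.FloatList(value=v))
def _ints(v):    return tf.train.Feature(int64_list=tf.train.Int64List(value=v))
def _bytes(v):   return tf.train.Feature(bytes_list=tf.train.BytesList(value=[v.encode('utf8')]))

def make_example(image_id, gt_boxes, gt_labels, det_boxes, det_scores, det_labels):
    """Boxes are [ymin, xmin, ymax, xmax] lists, normalized to [0, 1]."""
    feature = {
        'image/source_id': _bytes(image_id),
        # groundtruth parsed from the XML annotation
        'image/object/bbox/ymin': _floats([b[0] for b in gt_boxes]),
        'image/object/bbox/xmin': _floats([b[1] for b in gt_boxes]),
        'image/object/bbox/ymax': _floats([b[2] for b in gt_boxes]),
        'image/object/bbox/xmax': _floats([b[3] for b in gt_boxes]),
        'image/object/class/label': _ints(gt_labels),       # 111 = Butterfly in the OID label map
        # detections produced by the model
        'image/detection/bbox/ymin': _floats([b[0] for b in det_boxes]),
        'image/detection/bbox/xmin': _floats([b[1] for b in det_boxes]),
        'image/detection/bbox/ymax': _floats([b[2] for b in det_boxes]),
        'image/detection/bbox/xmax': _floats([b[3] for b in det_boxes]),
        'image/detection/score': _floats(det_scores),
        'image/detection/label': _ints(det_labels),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))

# writer = tf.python_io.TFRecordWriter('validation_detections.tfrecord')
# writer.write(make_example(...).SerializeToString())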

The next step was to run object_detection/metrics/offline_eval_map_corloc.py for the evaluation, but two problems left me stuck for a while. The first was the error "ground_truth_group_of.size: NoneType has no attribute 'size'". I assumed my TFRecord was broken, but the real cause was here:

decoded_dict = data_parser.parse(example)

When parsing, since my TFRecord never wrote the standard_fields.TfExampleFields.object_group_of field, the parser fills that entry with None; None has no .size, which produces the error above.

self.optional_items_to_handlers = {
    fields.InputDataFields.groundtruth_difficult:
        Int64Parser(fields.TfExampleFields.object_difficult),
    fields.InputDataFields.groundtruth_group_of:
        Int64Parser(fields.TfExampleFields.object_group_of),
}

Let's check what the Open Images-specific group_of flag means: "Indicates that the box spans a group of objects (e.g., a bed of flowers or a crowd of people). We asked annotators to use this tag for cases with more than 5 instances which are heavily occluding each other and are physically touching."

In other words, a box tagged group_of contains more than 5 objects, e.g. a crowd of people or a bed covered in flowers.
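Since none of my butterfly boxes are group_of boxes, one way to sidestep the NoneType error is to write an explicit group_of flag of 0 for every groundtruth box when building the tf_example (the field name follows TfExampleFields.object_group_of; this is a sketch, not necessarily the exact fix I applied):

# In the example-building sketch above: mark every groundtruth box as "not a group",
# so that data_parser.parse() returns an int64 array instead of filling in None.
feature['image/object/group_of'] = tf.train.Feature(
    int64_list=tf.train.Int64List(value=[0] * len(gt_boxes)))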

 

The other problem: my ground truth contains only two classes (butterfly and background), while the model's label_map contains 545 classes, so all the other classes have no GT at all. The evaluation code contains this check:

# object_detection/utils/object_detection_evaluation.py
if (self.num_gt_instances_per_class == 0).any():
  logging.warn(
      'The following classes have no ground truth examples: %s',
      np.squeeze(np.argwhere(self.num_gt_instances_per_class == 0)) +
      self.label_id_offset)

Once I found it, I simply commented that warning out, and the evaluation finally ran to completion.

With this model, the detection precision for the single butterfly class is 0.728.

 



 

Below I walk through the evaluation code, mainly so that I don't forget it later. Assume the model output validation_detections.tfrecord has already been saved under models/research/butterfly.

The first step is to generate the two config files:

# From models/research/butterfly
SPLIT=validation  # or test

mkdir -p ${SPLIT}_eval_metrics

echo "
label_map_path: '../object_detection/data/oid_bbox_trainable_label_map.pbtxt'
tf_record_input_reader: { input_path: '${SPLIT}_detections.tfrecord' }
" > ${SPLIT}_eval_metrics/${SPLIT}_input_config.pbtxt

echo "
metrics_set: 'open_images_detection_metrics'
" > ${SPLIT}_eval_metrics/${SPLIT}_eval_config.pbtxt

Then run the evaluation script:

# From tensorflow/models/research/butterfly
SPLIT=validation  # or test

PYTHONPATH=$PYTHONPATH:$(readlink -f ..)
python -m object_detection.metrics.offline_eval_map_corloc \
  --eval_dir=${SPLIT}_eval_metrics \
  --eval_config_path=${SPLIT}_eval_metrics/${SPLIT}_eval_config.pbtxt \
  --input_config_path=${SPLIT}_eval_metrics/${SPLIT}_input_config.pbtxt
# --eval_dir is where the resulting metrics are written;
# --input_config_path points at the input config generated above.

First, the main program, models/research/object_detection/metrics/offline_eval_map_corloc.py:

import csv
import os
import re
import tensorflow as tf

from object_detection import evaluator
from object_detection.core import standard_fields
from object_detection.metrics import tf_example_parser
from object_detection.utils import config_util
from object_detection.utils import label_map_util

...

def read_data_and_evaluate(input_config, eval_config):
  ...

def write_metrics(metrics, output_dir):
  ...

def main(argv):
  del argv
  required_flags = ['input_config_path', 'eval_config_path', 'eval_dir']  # the three required arguments
  for flag_name in required_flags:
    if not getattr(FLAGS, flag_name):
      raise ValueError('Flag --{} is required'.format(flag_name))

  configs = config_util.get_configs_from_multiple_files(
      eval_input_config_path=FLAGS.input_config_path,
      eval_config_path=FLAGS.eval_config_path)

  eval_config = configs['eval_config']
  input_config = configs['eval_input_config']

  metrics = read_data_and_evaluate(input_config, eval_config)    # the main work happens here

  # Save metrics
  write_metrics(metrics, FLAGS.eval_dir)

Now let's look at read_data_and_evaluate(input_config, eval_config) in more detail:

def read_data_and_evaluate(input_config, eval_config):
  """Reads pre-computed object detections and groundtruth from tf_record.

  Args:
    input_config: input config proto of type
      object_detection.protos.InputReader.
    eval_config: evaluation config proto of type
      object_detection.protos.EvalConfig.

  Returns:
    Evaluated detections metrics.

  Raises:
    ValueError: if input_reader type is not supported or metric type is unknown.
  """
  if input_config.WhichOneof('input_reader') == 'tf_record_input_reader':
    input_paths = input_config.tf_record_input_reader.input_path

    label_map = label_map_util.load_labelmap(input_config.label_map_path)  # load the label_map
    max_num_classes = max([item.id for item in label_map.item])            # largest class id (545)
    categories = label_map_util.convert_label_map_to_categories(
        label_map, max_num_classes)   # a list, e.g. categories[110] = {'id': 111, 'name': 'Butterfly'}

    object_detection_evaluators = evaluator.get_evaluators(
        eval_config, categories)
    # Support a single evaluator
    object_detection_evaluator = object_detection_evaluators[0]  # here: object_detection_evaluation.OpenImagesDetectionEvaluator

    skipped_images = 0
    processed_images = 0
    for input_path in _generate_filenames(input_paths):
      tf.logging.info('Processing file: {0}'.format(input_path))

      record_iterator = tf.python_io.tf_record_iterator(path=input_path)  # read validation_detections.tfrecord
      data_parser = tf_example_parser.TfExampleDetectionAndGTParser()

      for string_record in record_iterator:   # iterate over the 238 test samples, one detection result at a time
        tf.logging.log_every_n(tf.logging.INFO, 'Processed %d images...', 1000,
                               processed_images)
        processed_images += 1

        example = tf.train.Example()
        example.ParseFromString(string_record)    # parse the TFRecord -> example.features.feature holds the data as a dict
        decoded_dict = data_parser.parse(example) # decode further, recovering groundtruth_boxes, groundtruth_classes,
                                                  # detection_boxes, detection_classes, detection_scores

        if decoded_dict:
          # These call into class OpenImagesDetectionEvaluator() in
          # object_detection/utils/object_detection_evaluation.py, default iou_threshold=0.5
          object_detection_evaluator.add_single_ground_truth_image_info(
              decoded_dict[standard_fields.DetectionResultFields.key],
              decoded_dict)
          object_detection_evaluator.add_single_detected_image_info(
              decoded_dict[standard_fields.DetectionResultFields.key],
              decoded_dict)
        else:
          skipped_images += 1
          tf.logging.info('Skipped images: {0}'.format(skipped_images))

    return object_detection_evaluator.evaluate()

  raise ValueError('Unsupported input_reader_config.')

As you can see, the main evaluation logic in turn lives in

  object_detection/utils/object_detection_evaluation.py

The first class is OpenImagesDetectionEvaluator(ObjectDetectionEvaluator), which inherits from class ObjectDetectionEvaluator(DetectionEvaluator), which in turn inherits from class DetectionEvaluator(object).

So let's read these classes from the top down, starting with the base class DetectionEvaluator(object) (Line 42):

class DetectionEvaluator(object):
  """Interface for object detection evaluation classes.

  Example usage of the Evaluator:
  ------------------------------
  evaluator = DetectionEvaluator(categories)
  # i.e. add GT and detections one image at a time, then call evaluate() once at the end

  # Detections and groundtruth for image 1.
  evaluator.add_single_groundtruth_image_info(...)
  evaluator.add_single_detected_image_info(...)

  # Detections and groundtruth for image 2.
  evaluator.add_single_groundtruth_image_info(...)
  evaluator.add_single_detected_image_info(...)

  metrics_dict = evaluator.evaluate()
  """
  __metaclass__ = ABCMeta

  def __init__(self, categories):
    """Constructor.

    Args:
      categories: A list of dicts, each of which has the following keys -
        'id': (required) an integer id uniquely identifying this category.
        'name': (required) string representing category name e.g., 'cat', 'dog'.
    """
    self._categories = categories

  @abstractmethod
  def add_single_ground_truth_image_info(self, image_id, groundtruth_dict):
    """Adds groundtruth for a single image to be used for evaluation.

    Args:
      image_id: A unique string/integer identifier for the image.
      groundtruth_dict: A dictionary of groundtruth numpy arrays required
        for evaluations.
    """
    pass

  @abstractmethod
  def add_single_detected_image_info(self, image_id, detections_dict):
    """Adds detections for a single image to be used for evaluation.

    Args:
      image_id: A unique string/integer identifier for the image.
      detections_dict: A dictionary of detection numpy arrays required
        for evaluation.
    """
    pass

  @abstractmethod
  def evaluate(self):
    """Evaluates detections and returns a dictionary of metrics."""
    pass

  @abstractmethod
  def clear(self):
    """Clears the state to prepare for a fresh evaluation."""
    pass

Next comes class ObjectDetectionEvaluator(DetectionEvaluator) (Line 104):

class ObjectDetectionEvaluator(DetectionEvaluator):
  """A class to evaluate detections."""

  def __init__(self,
               categories,
               matching_iou_threshold=0.5,
               evaluate_corlocs=False,
               metric_prefix=None,
               use_weighted_mean_ap=False,
               evaluate_masks=False):
    """Constructor.
    Args:
      xxxxxx
    Raises:
      ValueError: If the category ids are not 1-indexed.
    """
    ...
    # This is the key part; it is used by everything that follows
    self._evaluation = ObjectDetectionEvaluation(
        num_groundtruth_classes=self._num_classes,
        matching_iou_threshold=self._matching_iou_threshold,
        use_weighted_mean_ap=self._use_weighted_mean_ap,
        label_id_offset=self._label_id_offset)

  def add_single_ground_truth_image_info(self, image_id, groundtruth_dict):
    """Adds groundtruth for a single image to be used for evaluation."""
    ...  # abridged
    self._evaluation.add_single_ground_truth_image_info(xxx)

  def add_single_detected_image_info(self, image_id, detections_dict):
    """Adds detections for a single image to be used for evaluation."""
    ...  # abridged
    self._evaluation.add_single_detected_image_info(xxx)

  def evaluate(self):
    """Compute evaluation result."""
    ...
    (per_class_ap, mean_ap, _, _, per_class_corloc, mean_corloc) = (
        self._evaluation.evaluate())
    ...

  def clear(self):
    """Clears the state to prepare for a fresh evaluation."""
    self._evaluation = ObjectDetectionEvaluation(
        num_groundtruth_classes=self._num_classes,
        matching_iou_threshold=self._matching_iou_threshold,
        use_weighted_mean_ap=self._use_weighted_mean_ap,
        label_id_offset=self._label_id_offset)
    self._image_ids.clear()

Finally, OpenImagesDetectionEvaluator(ObjectDetectionEvaluator) (Line 376):

class OpenImagesDetectionEvaluator(ObjectDetectionEvaluator):
  """A class to evaluate detections using Open Images V2 metrics.

  Open Images V2 introduce group_of type of bounding boxes and this metric
  handles those boxes appropriately.
  """

  def __init__(self,
               categories,
               matching_iou_threshold=0.5,
               evaluate_corlocs=False):
    """Constructor.

    Args:
      categories: A list of dicts, each of which has the following keys -
        'id': (required) an integer id uniquely identifying this category.
        'name': (required) string representing category name e.g., 'cat', 'dog'.
      matching_iou_threshold: IOU threshold to use for matching groundtruth
        boxes to detection boxes.
      evaluate_corlocs: if True, additionally evaluates and returns CorLoc.
    """
    super(OpenImagesDetectionEvaluator, self).__init__(
        categories,
        matching_iou_threshold,
        evaluate_corlocs,
        metric_prefix='OpenImagesV2')

  def add_single_ground_truth_image_info(self, image_id, groundtruth_dict):
    """Adds groundtruth for a single image to be used for evaluation."""
    if image_id in self._image_ids:
      raise ValueError('Image with id {} already added.'.format(image_id))

    groundtruth_classes = (
        groundtruth_dict[standard_fields.InputDataFields.groundtruth_classes] -
        self._label_id_offset)
    # If the key is not present in the groundtruth_dict or the array is empty
    # (unless there are no annotations for the groundtruth on this image)
    # use values from the dictionary or insert None otherwise.
    if (standard_fields.InputDataFields.groundtruth_group_of in
        groundtruth_dict.keys() and
        (groundtruth_dict[standard_fields.InputDataFields.groundtruth_group_of]
         .size or not groundtruth_classes.size)):
      groundtruth_group_of = groundtruth_dict[
          standard_fields.InputDataFields.groundtruth_group_of]
    else:
      groundtruth_group_of = None
      if not len(self._image_ids) % 1000:
        logging.warn(
            'image %s does not have groundtruth group_of flag specified',
            image_id)
    self._evaluation.add_single_ground_truth_image_info(
        image_id,
        groundtruth_dict[standard_fields.InputDataFields.groundtruth_boxes],
        groundtruth_classes,
        groundtruth_is_difficult_list=None,
        groundtruth_is_group_of_list=groundtruth_group_of)
    self._image_ids.update([image_id])

As you can see, only add_single_ground_truth_image_info is overridden here; everything else is unchanged. Its parent class, in turn, hands the main work off to class ObjectDetectionEvaluation(object), so the structure of the code is gradually becoming clear. A sketch of the containment relationships follows right below, which may make it easier to understand.
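Here is that rough sketch of how the classes contain and call each other (my own summary of the code above, not something from the library):

offline_eval_map_corloc.py : read_data_and_evaluate()
  └── evaluator.get_evaluators()  ->  OpenImagesDetectionEvaluator
        (inherits ObjectDetectionEvaluator, which inherits DetectionEvaluator)
        └── self._evaluation = ObjectDetectionEvaluation
              ├── self.per_image_eval = per_image_evaluation.PerImageEvaluation
              │     (per-image TP/FP decisions: NMS, IoU matching)
              └── metrics.compute_precision_recall / compute_average_precision / compute_cor_loc
                    (aggregation into mAP and CorLoc)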

The class below is the part that actually stores the GT and detection results and accumulates the statistics!

class ObjectDetectionEvaluation(object):
  """Internal implementation of Pascal object detection metrics."""

  def __init__(self,
               num_groundtruth_classes,
               matching_iou_threshold=0.5,
               nms_iou_threshold=1.0,
               nms_max_output_boxes=10000,
               use_weighted_mean_ap=False,
               label_id_offset=0):
    if num_groundtruth_classes < 1:
      raise ValueError('Need at least 1 groundtruth class for evaluation.')

    self.per_image_eval = per_image_evaluation.PerImageEvaluation(
        num_groundtruth_classes=num_groundtruth_classes,
        matching_iou_threshold=matching_iou_threshold,
        nms_iou_threshold=nms_iou_threshold,
        nms_max_output_boxes=nms_max_output_boxes)

  def clear_detections(self):
    self._initialize_detections()

  def add_single_ground_truth_image_info(self, image_key, groundtruth_boxes,
                                         groundtruth_class_labels,
                                         groundtruth_is_difficult_list=None,
                                         groundtruth_is_group_of_list=None,
                                         groundtruth_masks=None):
    ...

  def add_single_detected_image_info(self, image_key, detected_boxes,
                                     detected_scores, detected_class_labels,
                                     detected_masks=None):
    ...
    scores, tp_fp_labels, is_class_correctly_detected_in_image = (
        self.per_image_eval.compute_object_detection_metrics(
            detected_boxes=detected_boxes,
            detected_scores=detected_scores,
            detected_class_labels=detected_class_labels,
            groundtruth_boxes=groundtruth_boxes,
            groundtruth_class_labels=groundtruth_class_labels,
            groundtruth_is_difficult_list=groundtruth_is_difficult_list,
            groundtruth_is_group_of_list=groundtruth_is_group_of_list,
            detected_masks=detected_masks,
            groundtruth_masks=groundtruth_masks))

    for i in range(self.num_class):
      if scores[i].shape[0] > 0:
        self.scores_per_class[i].append(scores[i])
        self.tp_fp_labels_per_class[i].append(tp_fp_labels[i])
    (self.num_images_correctly_detected_per_class
    ) += is_class_correctly_detected_in_image

  def evaluate(self):
    """Compute evaluation result.

    Returns:
      A named tuple with the following fields -
        average_precision: float numpy array of average precision for
            each class.
        mean_ap: mean average precision of all classes, float scalar
        precisions: List of precisions, each precision is a float numpy
            array
        recalls: List of recalls, each recall is a float numpy array
        corloc: numpy float array
        mean_corloc: Mean CorLoc score for each class, float scalar
    """
    ...
    # (inside a loop over class_index, abridged)
    scores = np.concatenate(self.scores_per_class[class_index])
    tp_fp_labels = np.concatenate(self.tp_fp_labels_per_class[class_index])
    precision, recall = metrics.compute_precision_recall(
        scores, tp_fp_labels, self.num_gt_instances_per_class[class_index])
    self.precisions_per_class.append(precision)
    self.recalls_per_class.append(recall)
    average_precision = metrics.compute_average_precision(precision, recall)
    self.average_precision_per_class[class_index] = average_precision

    self.corloc_per_class = metrics.compute_cor_loc(
        self.num_gt_imgs_per_class,
        self.num_images_correctly_detected_per_class)

    mean_ap = np.nanmean(self.average_precision_per_class)
    mean_corloc = np.nanmean(self.corloc_per_class)
    return ObjectDetectionEvalMetrics(
        self.average_precision_per_class, mean_ap, self.precisions_per_class,
        self.recalls_per_class, self.corloc_per_class, mean_corloc)

I have trimmed out the unimportant parts. Two pieces do the real work:

1. object_detection/utils/per_image_evaluation.py, which computes precision and recall inputs (TP/FP decisions) for a single image;

2. object_detection/utils/metrics.py, which aggregates those results and computes mAP and the other numbers (a sketch follows the call chain below).

1. object_detection/utils/per_image_evaluation.py: per-image precision and recall

 

scores, tp_fp_labels, is_class_correctly_detected_in_image = compute_object_detection_metrics(...)
--> scores, tp_fp_labels = self._compute_tp_fp(...)
    --> for i in range(self.num_groundtruth_classes):
            scores, tp_fp_labels = self._compute_tp_fp_for_single_class(...)
            --> (iou, ioa, scores, num_detected_boxes) = self._get_overlaps_and_scores_box_mode(...)
                --> detected_boxlist = np_box_list_ops.non_max_suppression(...)
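2. object_detection/utils/metrics.py: aggregating the per-class scores and tp/fp labels collected above into precision/recall and AP. The library's own implementation differs in detail; the sketch below only shows the standard computation that compute_precision_recall and compute_average_precision perform:

import numpy as np

def compute_precision_recall(scores, tp_fp_labels, num_gt):
    """scores: detection confidences; tp_fp_labels: 1 for TP, 0 for FP; num_gt: number of GT boxes."""
    order = np.argsort(-scores)                 # sort detections by descending confidence
    labels = tp_fp_labels[order].astype(float)
    tp_cum = np.cumsum(labels)
    fp_cum = np.cumsum(1 - labels)
    precision = tp_cum / (tp_cum + fp_cum)
    recall = tp_cum / num_gt
    return precision, recall

def compute_average_precision(precision, recall):
    """Area under the interpolated precision-recall curve."""
    recall = np.concatenate([[0.0], recall, [1.0]])
    precision = np.concatenate([[0.0], precision, [0.0]])
    for i in range(len(precision) - 2, -1, -1):          # make precision monotonically decreasing
        precision[i] = max(precision[i], precision[i + 1])
    idx = np.where(recall[1:] != recall[:-1])[0] + 1
    return np.sum((recall[idx] - recall[idx - 1]) * precision[idx])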

 

posted on 2018-04-10 16:11 by caffeauto

Comments

#1 Zoe_启 (2018-06-11 19:12): Hi, what should go into the eval_config.pbtxt that gets generated in the first step? Could you give me some guidance?

#2 caffeauto (2018-06-12 15:57, author): @Zoe_启 It is just the content shown above: metrics_set: 'open_images_detection_metrics'
http://pic.cnblogs.com/face/945479/20170302164420.png

2018-06-12 15:57 | caffeauto  

""" mlperf inference benchmarking tool """ from __future__ import division from __future__ import print_function from __future__ import unicode_literals # from memory_profiler import profile import argparse import array import collections import json import logging import os import sys import threading import time from multiprocessing import JoinableQueue import sklearn import star_loadgen as lg import numpy as np from mindspore import context from mindspore.train.model import Model from mindspore.train.serialization import load_checkpoint, load_param_into_net from src.model_utils.device_adapter import get_device_num, get_rank_id,get_device_id import os from mindspore import Model, context from mindspore.train.serialization import load_checkpoint, load_param_into_net,\ build_searched_strategy, merge_sliced_parameter from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel from src.callbacks import LossCallBack from src.datasets import create_dataset, DataType from src.metrics import AUCMetric from src.model_utils.moxing_adapter import moxing_wrapper from src.metrics import AUCMetric from src.callbacks import EvalCallBack import src.wide_and_deep as wide_deep import src.datasets as datasets from src.model_utils.config import config # from pygcbs_client.task import Task # task = Task() # config = task.config context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target , device_id = 1) print(config.device_id) batch_size = config.batch_size def add_write(file_path, print_str): with open(file_path, 'a+', encoding='utf-8') as file_out: file_out.write(print_str + '\n') def get_WideDeep_net(config): """ Get network of wide&deep model. """ WideDeep_net = WideDeepModel(config) loss_net = NetWithLossClass(WideDeep_net, config) train_net = TrainStepWrap(loss_net) eval_net = PredictWithSigmoid(WideDeep_net) return train_net, eval_net class ModelBuilder(): """ Wide and deep model builder """ def __init__(self): pass def get_hook(self): pass def get_train_hook(self): hooks = [] callback = LossCallBack() hooks.append(callback) if int(os.getenv('DEVICE_ID')) == 0: pass return hooks def get_net(self, config): return get_WideDeep_net(config) logging.basicConfig(level=logging.INFO) log = logging.getLogger("main") NANO_SEC = 1e9 MILLI_SEC = 1000 # pylint: disable=missing-docstring # the datasets we support SUPPORTED_DATASETS = { "debug": (datasets.Dataset, wide_deep.pre_process_criteo_wide_deep, wide_deep.WideDeepPostProcess(), {"randomize": 'total', "memory_map": True}), "multihot-criteo-sample": (wide_deep.WideDeepModel, wide_deep.pre_process_criteo_wide_deep, wide_deep.WideDeepPostProcess(), {"randomize": 'total', "memory_map": True}), "kaggle-criteo": (wide_deep.WideDeepModel, wide_deep.pre_process_criteo_wide_deep, wide_deep.WideDeepPostProcess(), {"randomize": 'total', "memory_map": True}), } # pre-defined command line options so simplify things. 
They are used as defaults and can be # overwritten from command line SUPPORTED_PROFILES = { "defaults": { "dataset": "multihot-criteo", "inputs": "continuous and categorical features", "outputs": "probability", "backend": "mindspore-native", "model": "wide_deep", "max-batchsize": 2048, }, "dlrm-debug-mindspore": { "dataset": "debug", "inputs": "continuous and categorical features", "outputs": "probability", "backend": "pytorch-native", "model": "dlrm", "max-batchsize": 128, }, "dlrm-multihot-sample-mindspore": { "dataset": "multihot-criteo-sample", "inputs": "continuous and categorical features", "outputs": "probability", "backend": "pytorch-native", "model": "dlrm", "max-batchsize": 2048, }, "dlrm-multihot-mindspore": { "dataset": "multihot-criteo", "inputs": "continuous and categorical features", "outputs": "probability", "backend": "pytorch-native", "model": "dlrm", "max-batchsize": 2048, } } SCENARIO_MAP = { "SingleStream": lg.TestScenario.SingleStream, "MultiStream": lg.TestScenario.MultiStream, "Server": lg.TestScenario.Server, "Offline": lg.TestScenario.Offline, } last_timeing = [] import copy class Item: """An item that we queue for processing by the thread pool.""" def __init__(self, query_id, content_id, features, batch_T=None, idx_offsets=None): self.query_id = query_id self.content_id = content_id self.data = features self.batch_T = batch_T self.idx_offsets = idx_offsets self.start = time.time() import mindspore.dataset as mds lock = threading.Lock() class RunnerBase: def __init__(self, model, ds, threads, post_proc=None): self.take_accuracy = False self.ds = ds self.model = model self.post_process = post_proc self.threads = threads self.result_timing = [] def handle_tasks(self, tasks_queue): pass def start_run(self, result_dict, take_accuracy): self.result_dict = result_dict self.result_timing = [] self.take_accuracy = take_accuracy self.post_process.start() def run_one_item(self, qitem): # run the prediction processed_results = [] try: lock.acquire() data = datasets._get_mindrecord_dataset(*qitem.data) t1 = time.time() results = self.model.eval(data) self.result_timing.append(time.time() - t1) print('##################',time.time()) lock.release() processed_results = self.post_process(results, qitem.batch_T, self.result_dict) # self.post_process.add_results(, ) # g_lables.extend(self.model.auc_metric.true_labels) # g_predicts.extend(self.model.auc_metric.pred_probs) g_lables.extend(self.model.auc_metric.true_labels) g_predicts.extend(self.model.auc_metric.pred_probs) except Exception as ex: # pylint: disable=broad-except log.error("thread: failed, %s", ex) # since post_process will not run, fake empty responses processed_results = [[]] * len(qitem.query_id) finally: response_array_refs = [] response = [] for idx, query_id in enumerate(qitem.query_id): # NOTE: processed_results returned by DlrmPostProcess store both # result = processed_results[idx][0] and target = processed_results[idx][1] # also each idx might be a query of samples, rather than a single sample # depending on the --samples-to-aggregate* arguments. 
# debug prints # print("s,e:",s_idx,e_idx, len(processed_results)) response_array = array.array("B", np.array(processed_results[0:1], np.float32).tobytes()) response_array_refs.append(response_array) bi = response_array.buffer_info() response.append(lg.QuerySampleResponse(query_id, bi[0], bi[1])) lg.QuerySamplesComplete(response) def enqueue(self, query_samples): idx = [q.index for q in query_samples] query_id = [q.id for q in query_samples] # print(idx) query_len = len(query_samples) # if query_len < self.max_batchsize: # samples = self.ds.get_samples(idx) # # batch_T = [self.ds.get_labels(sample) for sample in samples] # self.run_one_item(Item(query_id, idx, samples)) # else: bs = 1 for i in range(0, query_len, bs): ie = min(i + bs, query_len) samples = self.ds.get_samples(idx[i:ie]) # batch_T = [self.ds.get_labels(sample) for sample in samples] self.run_one_item(Item(query_id[i:ie], idx[i:ie], samples)) def finish(self): pass import threading class MyQueue: def __init__(self, *args, **kwargs): self.lock = threading.Lock() self._data = [] self.status = True def put(self, value): # self.lock.acquire() self._data.append(value) # self.lock.release() def get(self): if self.status and self._data: return self._data.pop(0) if self.status: while self._data: time.sleep(0.1) return self._data.pop(0) return None # self.lock.acquire() # return self._data.pop(0) # self.lock.release() def task_done(self, *args, **kwargs): self.status = False class QueueRunner(RunnerBase): def __init__(self, model, ds, threads, post_proc=None): super().__init__(model, ds, threads, post_proc) queue_size_multiplier = 4 # (args.samples_per_query_offline + max_batchsize - 1) // max_batchsize) self.tasks = JoinableQueue(maxsize=threads * queue_size_multiplier) self.workers = [] self.result_dict = {} for _ in range(self.threads): worker = threading.Thread(target=self.handle_tasks, args=(self.tasks,)) worker.daemon = True self.workers.append(worker) worker.start() def handle_tasks(self, tasks_queue): """Worker thread.""" while True: qitem = tasks_queue.get() if qitem is None: # None in the queue indicates the parent want us to exit tasks_queue.task_done() break self.run_one_item(qitem) tasks_queue.task_done() def enqueue(self, query_samples): idx = [q.index for q in query_samples] query_id = [q.id for q in query_samples] query_len = len(query_samples) # print(idx) # if query_len < self.max_batchsize: # samples = self.ds.get_samples(idx) # # batch_T = [self.ds.get_labels(sample) for sample in samples] # data = Item(query_id, idx, samples) # self.tasks.put(data) # else: bs = 1 for i in range(0, query_len, bs): ie = min(i + bs, query_len) samples = self.ds.get_samples(idx) # batch_T = [self.ds.get_labels(sample) for sample in samples] self.tasks.put(Item(query_id[i:ie], idx[i:ie], samples)) def finish(self): # exit all threads for _ in self.workers: self.tasks.put(None) for worker in self.workers: worker.join() def add_results(final_results, name, result_dict, result_list, took, show_accuracy=False): percentiles = [50., 80., 90., 95., 99., 99.9] buckets = np.percentile(result_list, percentiles).tolist() buckets_str = ",".join(["{}:{:.4f}".format(p, b) for p, b in zip(percentiles, buckets)]) if result_dict["total"] == 0: result_dict["total"] = len(result_list) # this is what we record for each run result = { "took": took, "mean": np.mean(result_list), "percentiles": {str(k): v for k, v in zip(percentiles, buckets)}, "qps": len(result_list) / took, "count": len(result_list), "good_items": result_dict["good"], "total_items": 
result_dict["total"], } acc_str = "" if show_accuracy: result["accuracy"] = 100. * result_dict["good"] / result_dict["total"] acc_str = ", acc={:.3f}%".format(result["accuracy"]) if "roc_auc" in result_dict: result["roc_auc"] = 100. * result_dict["roc_auc"] acc_str += ", auc={:.3f}%".format(result["roc_auc"]) # add the result to the result dict final_results[name] = result # to stdout print("{} qps={:.2f}, mean={:.4f}, time={:.3f}{}, queries={}, tiles={}".format( name, result["qps"], result["mean"], took, acc_str, len(result_list), buckets_str)) lock = threading.Lock() def append_file(file_path, data): with lock: my_string = ','.join(str(f) for f in data) with open(file_path, 'a+') as file: file.write(my_string + '\n') def read_file(file_path): lines_as_lists = [] with open(file_path, 'r') as file: for line in file: # 去除行尾的换行符,并将行分割成列表 lines_as_lists.extend([float(num) for num in line.strip().split(',')]) return lines_as_lists def get_score(model, quality: float, performance: float): print(model, quality,performance) try: score =0 if model["scenario"] == 'SingleStream'or model["scenario"]== 'MultiStream': if quality >= model["accuracy"]: score = model["baseline_performance"] / (performance + 1e-9) * model["base_score"] else: score ==0 elif model["scenario"]== 'Server' or model["scenario"] == 'Offline': if quality >= model["accuracy"]: score = performance * model["base_score"] / model["baseline_performance"] print(model["baseline_performance"]) else: score ==0 except Exception as e: score ==0 finally: return score def main(): global last_timeing # 初始化时清空文件 global g_lables global g_predicts # args = get_args() g_lables=[] g_predicts=[] # # dataset to use wanted_dataset, pre_proc, post_proc, kwargs = SUPPORTED_DATASETS["debug"] # # --count-samples can be used to limit the number of samples used for testing ds = wanted_dataset(directory=config.dataset_path, train_mode=False, epochs=15, line_per_sample=1000, batch_size=config.test_batch_size, data_type=DataType.MINDRECORD,total_size = config.total_size) # # load model to backend # model = backend.load(args.model_path, inputs=args.inputs, outputs=args.outputs) net_builder = ModelBuilder() train_net, eval_net = net_builder.get_net(config) # ckpt_path = config.ckpt_path param_dict = load_checkpoint(config.ckpt_path) load_param_into_net(eval_net, param_dict) train_net.set_train() eval_net.set_train(False) # acc_metric = AccMetric() # model = Model(train_net, eval_network=eval_net, metrics={"acc": acc_metric}) auc_metric1 = AUCMetric() model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric1}) model.auc_metric = auc_metric1 # res = model.eval(ds_eval) final_results = { "runtime": "wide_deep_mindspore", "version": "v2", "time": int(time.time()), "cmdline": str(config), } mlperf_conf = os.path.abspath(config.mlperf_conf) if not os.path.exists(mlperf_conf): log.error("{} not found".format(mlperf_conf)) sys.exit(1) user_conf = os.path.abspath(config.user_conf) if not os.path.exists(user_conf): log.error("{} not found".format(user_conf)) sys.exit(1) if config.output: output_dir = os.path.abspath(config.output) os.makedirs(output_dir, exist_ok=True) os.chdir(output_dir) # # make one pass over the dataset to validate accuracy # count = ds.get_item_count() count = 5 base_score = config.config.base_score #task. accuracy = config.config.accuracy #task. baseline_performance = config.baseline_performance scenario_str = config["scenario"] #task. 
scenario = SCENARIO_MAP[scenario_str] runner_map = { lg.TestScenario.SingleStream: RunnerBase, lg.TestScenario.MultiStream: QueueRunner, lg.TestScenario.Server: QueueRunner, lg.TestScenario.Offline: QueueRunner } runner = runner_map[scenario](model, ds, config.threads_count, post_proc=post_proc) def issue_queries(query_samples): runner.enqueue(query_samples) def flush_queries(): pass settings = lg.TestSettings() settings.FromConfig(mlperf_conf, config.model_path, config.scenario) settings.FromConfig(user_conf, config.model_path, config.scenario) settings.scenario = scenario settings.mode = lg.TestMode.AccuracyOnly sut = lg.ConstructSUT(issue_queries, flush_queries) qsl = lg.ConstructQSL(count, config.performance_count, ds.load_query_samples, ds.unload_query_samples) log.info("starting {}".format(scenario)) result_dict = {"good": 0, "total": 0, "roc_auc": 0, "scenario": str(scenario)} runner.start_run(result_dict, config.accuracy) lg.StartTest(sut, qsl, settings) result_dict["good"] = runner.post_process.good result_dict["total"] = runner.post_process.total last_timeing = runner.result_timing post_proc.finalize(result_dict) add_results(final_results, "{}".format(scenario), result_dict, last_timeing, time.time() - ds.last_loaded, config.accuracy) runner.finish() lg.DestroyQSL(qsl) lg.DestroySUT(sut) # If multiple subprocesses are running the model send a signal to stop them if (int(os.environ.get("WORLD_SIZE", 1)) > 1): model.eval(None) from sklearn.metrics import roc_auc_score # labels = read_file(labels_path) # predicts = read_file(predicts_path) final_results['auc'] = sklearn.metrics.roc_auc_score(g_lables, g_predicts) print("auc+++++", final_results['auc']) NormMetric= { 'scenario': scenario_str, 'accuracy': accuracy, 'baseline_performance': baseline_performance, 'performance_unit': 's', 'base_score': base_score } # 打开文件 reprot_array=[] test_suit_array=[] test_suit_obj={} test_cases_array=[] test_cases_obj={} test_cases_obj["Name"] = config["task_name"] test_cases_obj["Performance Unit"] = config["performance_unit"] test_cases_obj["Total Duration"] = time.time() - ds.last_loaded test_cases_obj["Train Duration"] = None test_cases_obj["Training Info"] = { "Real Quality" :None, "Learning Rate" :None, "Base Quality" :None, "Epochs" :None, "Optimizer" :None } test_cases_obj["Software Versions"] = { "Python" :3.8, "Framework" :"Mindspore 2.2.14" } percentiles = [50., 80., 90., 95., 99., 99.9] buckets = np.percentile(last_timeing, percentiles).tolist() took = time.time() - ds.last_loaded qps = len(last_timeing) / took print(buckets) if scenario_str=="SingleStream": test_cases_obj["Performance Metric"] = buckets[2] if scenario_str=="MultiStream": test_cases_obj["Performance Metric"] = buckets[4] if scenario_str=="Server": test_cases_obj["Performance Metric"] = qps if scenario_str=="Offline": test_cases_obj["Performance Metric"] = qps score = get_score(NormMetric, final_results["auc"], test_cases_obj["Performance Metric"]) test_cases_obj["Score"] = score test_cases_array.append(test_cases_obj) test_suit_obj["Test Cases"] = test_cases_array test_suit_obj["Name"] = "Inference Suite" test_suit_array.append(test_suit_obj) test_obj = {"Test Suites": test_suit_array} reprot_array.append(test_obj) test_suit_result_obj = {"Name":"Inference Suite","Description":"inference model","Score":score } test_suit_result_array = [] test_suit_result_array.append(test_suit_result_obj) test_suit_result_obj1 = {"Test Suites Results":test_suit_result_array} reprot_array.append(test_suit_result_obj1) epoch_obj={ 
"epoch":None, "epoch_time":None, "train_loss":None, "metric":None, "metric_name":None } reprot_array.append(epoch_obj) result_final = {"Report Info": reprot_array } print("result_final", result_final) # task.save(result_final) if __name__ == "__main__": # try: # if task.config["is_run_infer"] == True: # main() # task.close() # else: # task.close() # raise ValueError # except Exception as e: # task.logger.error(e) # task.close(e) # task.logger.info("Finish ") # if task.config["is_run_infer"] == True: # print(config) main() # task.close() 这段代码的目的是什么?
06-17
# pygcbs: # app_name: 'APP1' # master: '192.168.0.123' # port: 6789 # level: 'DEBUG' # interval: 1 # checklist: [ "System","CPU", "GPU","Mem","NPU", ] # save_path: "./" # docker: # pygcbs_image: nvidia-pygcbs:v1.0 # worker_image: nvidia-mindspore1.8.1:v1.0 # python_path: /opt/miniconda/bin/python # workers: # - '192.168.0.123:1' # socket_ifname: # - enp4s0 # tasks: #--------------------wide_deep--------------------------------- # - application_domain: "推荐" # task_framework: "Mindspore" # task_type: "推理" # task_name: "wide_deep_infer" # scenario: "SingleStream" # is_run_infer: True # project_path: '/home/gcbs/infer/wide_deep_infer' # main_path: "main.py" # dataset_path: '/home/gcbs/Dataset/wide_deep_data/' # times: 1 # 重试次数 #distribute do_eval: True is_distributed: False is_mhost: False exp_value: 0.501 #model log name: "wide_deep" Metrics: "AUC" request_auc: 0.74 dataset_name: "Criteo 1TB Click Logs Dataset" application: "推荐" standard_time: 3600 python_version: 3.8 mindspore_version: 1.8.1 # Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing) enable_modelarts: False data_url: "" train_url: "" checkpoint_url: "" data_path: "./data" dataset_path: "/home/gcbs/Dataset/wide_deep_data/" output_path: "/cache/train" load_path: "/cache/checkpoint_path" device_target: GPU enable_profiling: False data_format: 1 total_size: 10000000 performance_count: 10 # argparse_init 'WideDeep' epochs: 15 full_batch: False batch_size: 16000 eval_batch_size: 16000 test_batch_size: 16000 field_size: 39 vocab_size: 200000 vocab_cache_size: 0 emb_dim: 80 deep_layer_dim: [1024, 512, 256, 128] deep_layer_act: 'relu' keep_prob: 1.0 dropout_flag: False ckpt_path: "./check_points" stra_ckpt: "./check_points" eval_file_name: "./output/eval.log" loss_file_name: "./output/loss.log" host_device_mix: 0 dataset_type: "mindrecord" parameter_server: 0 field_slice: False sparse: False use_sp: True deep_table_slice_mode: "column_slice" #star_logen config mlperf_conf: './test.conf' user_conf: './user.conf' output: '/tmp/code/' scenario: 'Offline' max_batchsize: 16000 threads: 4 model_path: "./check_points/widedeep_train-12_123328.ckpt" is_accuracy: False find_peak_performance: False duration: False target_qps: False count_queries: False samples_per_query_multistream: False max_latency: False samples_per_query_offline: 500 # WideDeepConfig #data_path: "./test_raw_data/" #vocab_cache_size: 100000 #stra_ckpt: './checkpoints/strategy.ckpt' weight_bias_init: ['normal', 'normal'] emb_init: 'normal' init_args: [-0.01, 0.01] l2_coef: 0.00008 # 8e-5 manual_shape: None # wide_and_deep export device_id: 1 ckpt_file: "./check_points/widedeep_train-12_123328.ckpt" file_name: "wide_and_deep" file_format: "MINDIR" # src/process_data.py "Get and Process datasets" raw_data_path: "./raw_data" # src/preprocess_data.py "Recommendation dataset" dense_dim: 13 slot_dim: 26 threshold: 100 train_line_count: 45840617 skip_id_convert: 0 # src/generate_synthetic_data.py 'Generate Synthetic Data' output_file: "./train.txt" label_dim: 2 number_examples: 4000000 vocabulary_size: 400000000 random_slot_values: 0 #get_score threads_count: 4 base_score: 1 accuracy: 0.72 baseline_performance: 1文件中的这些是什么?、
06-17
import os import sys import time import wave import struct import random import datetime import threading import tkinter as tk from tkinter import ttk, messagebox import numpy as np import pyaudio import webrtcvad import torch import torchaudio import pyttsx3 from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC # 检查必要库是否安装 try: import pyaudio import webrtcvad import torch import torchaudio from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC import pyttsx3 except ImportError as e: messagebox.showerror("缺少依赖库", f"请安装必要的依赖库:\n{e}\n" "pip install pyaudio webrtcvad torch torchaudio transformers pyttsx3") sys.exit(1) # 检查设备 device = "cuda" if torch.cuda.is_available() else "cpu" print(f"使用设备: {device}") class OfflineVoiceAssistant: def __init__(self, root): self.root = root self.root.title("离线语音陪伴助手") self.root.geometry("800x600") self.root.configure(bg="#f0f8ff") # 淡蓝色背景 # 状态变量 self.listening = False self.processing = False self.model_loaded = False self.capture_complete = False # 音频参数 - 优化不抢话功能 self.FORMAT = pyaudio.paInt16 self.CHANNELS = 1 self.RATE = 16000 self.CHUNK = 480 self.VAD_AGGRESSIVENESS = 1 # 降低VAD敏感度 (0-3, 0最不敏感) self.SILENCE_DURATION = 1.5 # 增加静音检测时间 (1.5秒) self.POST_SPEECH_DELAY = 0.5 # 语音结束后额外等待时间 self.MAX_RECORDING_TIME = 15 # 最大录音时间(秒) # 模型路径 - 使用中文模型 self.model_dir = "C://Users//24596//offline_models//chinese_model" print(f"模型目录: {self.model_dir}") # 确保模型目录存在 if not os.path.exists(self.model_dir): os.makedirs(self.model_dir) messagebox.showinfo("创建目录", f"已创建模型目录: {self.model_dir}") # 初始化TTS引擎 self.tts_engine = self.init_tts() # 创建UI self.create_widgets() # 加载模型 self.load_models() def init_tts(self): """初始化文本转语音引擎""" try: engine = pyttsx3.init() engine.setProperty('rate', 150) # 语速 engine.setProperty('volume', 0.9) # 音量 # 设置中文语音 voices = engine.getProperty('voices') chinese_voices = [v for v in voices if 'chinese' in v.name.lower() or 'zh' in v.id.lower()] if chinese_voices: engine.setProperty('voice', chinese_voices[0].id) print(f"使用中文语音: {chinese_voices[0].name}") else: print("未找到中文语音,将使用默认语音") return engine except Exception as e: messagebox.showerror("TTS初始化失败", f"错误: {str(e)}") return None def create_widgets(self): """创建用户界面""" # 主框架 main_frame = ttk.Frame(self.root, padding=20) main_frame.pack(fill=tk.BOTH, expand=True) # 标题 title_label = ttk.Label(main_frame, text="离线语音陪伴助手", font=("微软雅黑", 20, "bold"), foreground="#2c3e50") title_label.pack(pady=10) # 状态区域 status_frame = ttk.Frame(main_frame) status_frame.pack(fill=tk.X, pady=10) self.status_label = ttk.Label(status_frame, text="状态: 初始化中...", font=("微软雅黑", 12)) self.status_label.pack(side=tk.LEFT) # 模型状态 self.model_status = ttk.Label(status_frame, text="模型: 未加载", font=("微软雅黑", 10), foreground="#e74c3c") self.model_status.pack(side=tk.RIGHT, padx=10) # 指示灯 light_frame = ttk.Frame(main_frame) light_frame.pack(pady=5) self.light_canvas = tk.Canvas(light_frame, width=30, height=30, bg="#f0f8ff", highlightthickness=0) self.light_canvas.pack() self.light = self.light_canvas.create_oval(5, 5, 25, 25, fill="gray") # 对话区域 conv_frame = ttk.LabelFrame(main_frame, text="对话记录", padding=10) conv_frame.pack(fill=tk.BOTH, expand=True, pady=10) self.conversation = tk.Text(conv_frame, height=15, width=70, font=("微软雅黑", 10), wrap=tk.WORD, bg="white") scrollbar = ttk.Scrollbar(conv_frame, command=self.conversation.yview) scrollbar.pack(side=tk.RIGHT, fill=tk.Y) self.conversation.config(yscrollcommand=scrollbar.set) self.conversation.pack(fill=tk.BOTH, expand=True) self.conversation.tag_config("user", foreground="#2980b9") 
self.conversation.tag_config("assistant", foreground="#27ae60") self.conversation.tag_config("system", foreground="#7f8c8d") self.conversation.insert(tk.END, "系统: 初始化完成,准备加载模型...\n", "system") self.conversation.config(state=tk.DISABLED) # 控制按钮 btn_frame = ttk.Frame(main_frame) btn_frame.pack(pady=10) self.listen_btn = ttk.Button(btn_frame, text="开始聆听", width=12, command=self.toggle_listening, state=tk.DISABLED) self.listen_btn.pack(side=tk.LEFT, padx=5) self.clear_btn = ttk.Button(btn_frame, text="清空记录", width=12, command=self.clear_conversation) self.clear_btn.pack(side=tk.LEFT, padx=5) self.quit_btn = ttk.Button(btn_frame, text="退出", width=12, command=self.on_closing) self.quit_btn.pack(side=tk.RIGHT, padx=5) # 底部信息 bottom_frame = ttk.Frame(main_frame) bottom_frame.pack(fill=tk.X, pady=5) device_info = f"设备: {'GPU加速' if device == 'cuda' else 'CPU运行'}" ttk.Label(bottom_frame, text=device_info, font=("微软雅黑", 9), foreground="#7f8c8d").pack(side=tk.LEFT) ttk.Label(bottom_frame, text="完全离线 | 隐私安全 | 中文识别", font=("微软雅黑", 9), foreground="#7f8c8d").pack(side=tk.RIGHT) def update_status(self, message, color="gray"): """更新状态指示""" self.status_label.config(text=f"状态: {message}") colors = { "gray": "#95a5a6", "green": "#2ecc71", "red": "#e74c3c", "orange": "#f39c12", "blue": "#3498db" } self.light_canvas.itemconfig(self.light, fill=colors.get(color, "gray")) def add_to_conversation(self, speaker, text): """添加消息到对话记录""" timestamp = datetime.datetime.now().strftime("%H:%M:%S") self.conversation.config(state=tk.NORMAL) self.conversation.insert(tk.END, f"[{timestamp}] {speaker}: {text}\n", speaker.lower()) self.conversation.see(tk.END) self.conversation.config(state=tk.DISABLED) def toggle_listening(self): """切换监听状态 - 每次点击只识别一次""" if not self.model_loaded: self.add_to_conversation("系统", "模型尚未加载完成,无法开始监听") return if not self.listening: self.listening = True self.capture_complete = False self.listen_btn.config(text="停止聆听") self.update_status("聆听中...", "green") threading.Thread(target=self.start_listening, daemon=True).start() else: self.listening = False self.listen_btn.config(text="开始聆听") self.update_status("就绪", "blue") def clear_conversation(self): """清空对话记录""" self.conversation.config(state=tk.NORMAL) self.conversation.delete(1.0, tk.END) self.conversation.insert(tk.END, "系统: 对话记录已清空\n", "system") self.conversation.config(state=tk.DISABLED) def load_models(self): """加载中文语音识别模型""" try: # 检查模型目录是否存在 if not os.path.exists(self.model_dir): self.add_to_conversation("系统", f"模型目录不存在: {self.model_dir}") self.model_status.config(text="模型: 目录缺失", foreground="#e74c3c") return # 检查模型文件是否存在 required_files = ["config.json", "preprocessor_config.json", "pytorch_model.bin"] model_files = os.listdir(self.model_dir) missing_files = [f for f in required_files if f not in model_files] if missing_files: self.add_to_conversation("系统", f"缺少模型文件: {', '.join(missing_files)}") self.model_status.config(text="模型: 文件缺失", foreground="#e74c3c") return # 加载中文模型 self.add_to_conversation("系统", "开始加载中文语音识别模型...") self.update_status("加载模型中...", "orange") # 使用中文模型 self.asr_processor = Wav2Vec2Processor.from_pretrained(self.model_dir) self.asr_model = Wav2Vec2ForCTC.from_pretrained( self.model_dir, ignore_mismatched_sizes=True ).to(device) self.model_loaded = True self.listen_btn.config(state=tk.NORMAL) self.update_status("就绪", "blue") self.model_status.config(text="模型: 中文已加载", foreground="#27ae60") self.add_to_conversation("系统", "中文模型加载成功!") self.add_to_conversation("助手", "您好!我是您的离线语音助手,请点击'开始聆听'按钮与我对话。") except Exception as e: 
self.add_to_conversation("系统", f"模型加载失败: {str(e)}") self.model_status.config(text="模型: 加载失败", foreground="#e74c3c") self.update_status("错误", "red") def start_listening(self): """开始监听音频输入 - 优化不抢话功能""" self.add_to_conversation("系统", "启动音频监听...") # 初始化VAD和PyAudio vad = webrtcvad.Vad(self.VAD_AGGRESSIVENESS) p = pyaudio.PyAudio() # 打开音频流 stream = p.open( format=self.FORMAT, channels=self.CHANNELS, rate=self.RATE, input=True, frames_per_buffer=self.CHUNK ) # 音频缓冲区 frames = [] silence_count = 0 speech_detected = False max_silence = int(self.SILENCE_DURATION * self.RATE / self.CHUNK) max_frames = int(self.MAX_RECORDING_TIME * self.RATE / self.CHUNK) frame_count = 0 try: # 检测语音开始 self.add_to_conversation("系统", "请开始说话...") start_time = time.time() while self.listening and not self.capture_complete and frame_count < max_frames: # 读取音频数据 data = stream.read(self.CHUNK, exception_on_overflow=False) frame_count += 1 # 检测语音活动 if vad.is_speech(data, self.RATE): # 第一次检测到语音 if not speech_detected: self.add_to_conversation("系统", "检测到语音,正在聆听...") start_time = time.time() # 重置超时计时 speech_detected = True silence_count = 0 frames.append(data) elif speech_detected: silence_count += 1 frames.append(data) # 检测到静音结束 if silence_count > max_silence: # 语音结束后额外等待一段时间,确保用户说完话 time.sleep(self.POST_SPEECH_DELAY) # 保存录音 audio_file = os.path.join(self.model_dir, "recording.wav") with wave.open(audio_file, 'wb') as wf: wf.setnchannels(self.CHANNELS) wf.setsampwidth(p.get_sample_size(self.FORMAT)) wf.setframerate(self.RATE) wf.writeframes(b''.join(frames)) # 标记捕获完成 self.capture_complete = True # 处理音频 self.add_to_conversation("系统", "检测到语音结束,开始处理...") threading.Thread(target=self.process_audio, args=(audio_file,), daemon=True).start() break else: # 没有检测到语音,清空缓冲区 frames = [] silence_count = 0 # 超时检测 if time.time() - start_time > self.MAX_RECORDING_TIME: self.add_to_conversation("系统", "超时未检测到语音,停止监听") self.listening = False break # 如果用户手动停止 if not self.listening: self.add_to_conversation("系统", "用户手动停止监听") except Exception as e: self.add_to_conversation("系统", f"监听错误: {str(e)}") finally: # 清理资源 stream.stop_stream() stream.close() p.terminate() self.add_to_conversation("系统", "音频监听已停止") # 重置状态 if self.capture_complete: self.listening = False self.listen_btn.config(text="开始聆听") self.update_status("处理中", "orange") else: self.listening = False self.listen_btn.config(text="开始聆听") self.update_status("就绪", "blue") def process_audio(self, audio_file): """处理录制的音频""" if not self.model_loaded: return try: self.processing = True self.add_to_conversation("系统", f"处理音频文件: {os.path.basename(audio_file)}") # 记录开始时间 start_time = time.time() # 加载音频文件 waveform, sample_rate = torchaudio.load(audio_file) audio_duration = len(waveform[0]) / sample_rate self.add_to_conversation("系统", f"音频时长: {audio_duration:.2f}秒, 采样率: {sample_rate}Hz") # 确保采样率正确 if sample_rate != 16000: self.add_to_conversation("系统", f"重新采样: {sample_rate} -> 16000Hz") resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000) waveform = resampler(waveform) # 预处理音频 self.add_to_conversation("系统", "预处理音频中...") input_values = self.asr_processor( waveform.squeeze().numpy(), return_tensors="pt", sampling_rate=16000 ).input_values.to(device) # 语音识别 self.add_to_conversation("系统", "进行语音识别...") with torch.no_grad(): logits = self.asr_model(input_values).logits # 解码识别结果 predicted_ids = torch.argmax(logits, dim=-1) text = self.asr_processor.batch_decode(predicted_ids)[0] if text: self.add_to_conversation("您", text) self.process_command(text) else: self.add_to_conversation("系统", 
"未识别到有效语音") # 计算处理时间 process_time = time.time() - start_time self.add_to_conversation("系统", f"处理完成,耗时: {process_time:.2f}秒") # 删除临时文件 try: os.remove(audio_file) self.add_to_conversation("系统", "已删除临时音频文件") except: pass except Exception as e: self.add_to_conversation("系统", f"处理错误: {str(e)}") finally: self.processing = False self.capture_complete = False self.update_status("就绪", "blue") def process_command(self, text): """处理用户命令并生成响应(中文)""" text = text.lower() response = "" # 问候 if any(word in text for word in ["你好", "嗨", "哈喽", "您好", "喂"]): greetings = ["您好!", "你好呀!", "很高兴为您服务!", "有什么我可以帮忙的吗?"] response = random.choice(greetings) # 时间 elif any(word in text for word in ["时间", "几点", "现在几点", "当前时间"]): now = datetime.datetime.now() response = f"当前时间是 {now.strftime('%H:%M')}" # 日期 elif any(word in text for word in ["日期", "今天", "今天日期", "几号"]): now = datetime.datetime.now() response = f"今天是 {now.strftime('%Y年%m月%d日')}" # 感谢 elif any(word in text for word in ["谢谢", "感谢", "多谢", "谢谢你"]): responses = ["不客气!", "很高兴能帮到您!", "这是我的荣幸!", "随时为您服务!"] response = random.choice(responses) # 退出 elif any(word in text for word in ["退出", "再见", "拜拜", "关闭"]): response = "再见!感谢您使用离线语音助手。" self.speak(response) time.sleep(1) self.on_closing() return response # 自我介绍 elif any(word in text for word in ["你是谁", "你叫什么", "你的名字"]): response = "我是一个完全离线的语音陪伴助手,保护您的隐私是我的首要任务。" # 天气 elif any(word in text for word in ["天气", "天气预报", "今天天气"]): response = "抱歉,作为一个离线助手,我无法获取实时天气信息。" # 讲笑话 elif any(word in text for word in ["笑话", "讲笑话", "说个笑话", "幽默"]): jokes = [ "为什么程序员喜欢黑暗模式?因为光会吸引bug!", "为什么计算机永远不会感冒?因为它有Windows!", "为什么科学家不相信原子?因为它们构成了一切!", "我告诉我的电脑我需要休息,它说:'没问题,我会在这里缓存。'" ] response = random.choice(jokes) # 默认响应 else: default_responses = [ "我在听,您可以说点别的吗?", "抱歉,我不太明白您的意思。", "能再说一遍吗?", "您需要什么帮助吗?", "我还在学习理解人类语言,请多包涵。" ] response = random.choice(default_responses) self.speak(response) def speak(self, text): """使用TTS引擎播放语音""" if not text or not self.tts_engine: return try: # 添加到对话记录 self.add_to_conversation("助手", text) # 播放语音 self.tts_engine.say(text) self.tts_engine.runAndWait() except Exception as e: self.add_to_conversation("系统", f"语音播放失败: {str(e)}") # 蜂鸣提示 try: import winsound winsound.Beep(440, 300) except: pass def on_closing(self): """关闭程序时的清理工作""" self.listening = False self.root.destroy() print("程序已安全退出") if __name__ == "__main__": # 创建主窗口 root = tk.Tk() # 设置窗口图标(可选) try: root.iconbitmap("icon.ico") except: pass # 创建应用 app = OfflineVoiceAssistant(root) # 启动主循环 root.mainloop() 这段代码是一个基础的语音助手 我想让他改变声音使用我下载好的模型,pth格式 你能帮我修改么,并把完整的代码给我
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license import contextlib import pickle import re import types from copy import deepcopy from pathlib import Path from .AddModules import * import thop import torch import torch.nn as nn from ultralytics.nn.modules import ( AIFI, C1, C2, C2PSA, C3, C3TR, ELAN1, OBB, PSA, SPP, SPPELAN, SPPF, AConv, ADown, Bottleneck, BottleneckCSP, C2f, C2fAttn, C2fCIB, C2fPSA, C3Ghost, C3k2, C3x, CBFuse, CBLinear, Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, ImagePoolingAttn, Index, Pose, RepC3, RepConv, RepNCSPELAN4, RepVGGDW, ResNetLayer, RTDETRDecoder, SCDown, Segment, TorchVision, WorldDetect, v10Detect, A2C2f, ) from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml from ultralytics.utils.loss import ( E2EDetectLoss, v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss, ) from ultralytics.utils.ops import make_divisible from ultralytics.utils.plotting import feature_visualization from ultralytics.utils.torch_utils import ( fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights, intersect_dicts, model_info, scale_img, time_sync, ) from .AddModules import * class BaseModel(nn.Module): """The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family.""" def forward(self, x, *args, **kwargs): """ Perform forward pass of the model for either training or inference. If x is a dict, calculates and returns the loss for training. Otherwise, returns predictions for inference. Args: x (torch.Tensor | dict): Input tensor for inference, or dict with image tensor and labels for training. *args (Any): Variable length argument list. **kwargs (Any): Arbitrary keyword arguments. Returns: (torch.Tensor): Loss if x is a dict (training), or network predictions (inference). """ if isinstance(x, dict): # for cases of training and validating while training. return self.loss(x, *args, **kwargs) return self.predict(x, *args, **kwargs) def predict(self, x, profile=False, visualize=False, augment=False, embed=None): """ Perform a forward pass through the network. Args: x (torch.Tensor): The input tensor to the model. profile (bool): Print the computation time of each layer if True, defaults to False. visualize (bool): Save the feature maps of the model if True, defaults to False. augment (bool): Augment image during prediction, defaults to False. embed (list, optional): A list of feature vectors/embeddings to return. Returns: (torch.Tensor): The last output of the model. 
""" if augment: return self._predict_augment(x) return self._predict_once(x, profile, visualize, embed) def _predict_once(self, x, profile=False, visualize=False, embed=None): y, dt, embeddings = [], [], [] # outputs for m in self.model: if m.f != -1: # if not from previous layer x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers if profile: self._profile_one_layer(m, x, dt) if hasattr(m, 'backbone'): x = m(x) if len(x) != 5: # 0 - 5 x.insert(0, None) for index, i in enumerate(x): if index in self.save: y.append(i) else: y.append(None) x = x[-1] # 最后一个输出传给下一层 else: x = m(x) # run y.append(x if m.i in self.save else None) # save output if visualize: feature_visualization(x, m.type, m.i, save_dir=visualize) if embed and m.i in embed: embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten if m.i == max(embed): return torch.unbind(torch.cat(embeddings, 1), dim=0) return x def _predict_augment(self, x): """Perform augmentations on input image x and return augmented inference.""" LOGGER.warning( f"WARNING ⚠️ {self.__class__.__name__} does not support 'augment=True' prediction. " f"Reverting to single-scale prediction." ) return self._predict_once(x) def _profile_one_layer(self, m, x, dt): """ Profile the computation time and FLOPs of a single layer of the model on a given input. Appends the results to the provided list. Args: m (nn.Module): The layer to be profiled. x (torch.Tensor): The input data to the layer. dt (list): A list to store the computation time of the layer. Returns: None """ c = m == self.model[-1] and isinstance(x, list) # is final layer list, copy input as inplace fix flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs t = time_sync() for _ in range(10): m(x.copy() if c else x) dt.append((time_sync() - t) * 100) if m == self.model[0]: LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module") LOGGER.info(f"{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f} {m.type}") if c: LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total") def fuse(self, verbose=True): """ Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer, in order to improve the computation efficiency. Returns: (nn.Module): The fused model is returned. """ if not self.is_fused(): for m in self.model.modules(): if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, "bn"): if isinstance(m, Conv2): m.fuse_convs() m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv delattr(m, "bn") # remove batchnorm m.forward = m.forward_fuse # update forward if isinstance(m, ConvTranspose) and hasattr(m, "bn"): m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn) delattr(m, "bn") # remove batchnorm m.forward = m.forward_fuse # update forward if isinstance(m, RepConv): m.fuse_convs() m.forward = m.forward_fuse # update forward if isinstance(m, RepVGGDW): m.fuse() m.forward = m.forward_fuse self.info(verbose=verbose) return self def is_fused(self, thresh=10): """ Check if the model has less than a certain threshold of BatchNorm layers. Args: thresh (int, optional): The threshold number of BatchNorm layers. Default is 10. Returns: (bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise. """ bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. 
BatchNorm2d() return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model def info(self, detailed=False, verbose=True, imgsz=640): """ Prints model information. Args: detailed (bool): if True, prints out detailed information about the model. Defaults to False verbose (bool): if True, prints out the model information. Defaults to False imgsz (int): the size of the image that the model will be trained on. Defaults to 640 """ return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz) def _apply(self, fn): """ Applies a function to all the tensors in the model that are not parameters or registered buffers. Args: fn (function): the function to apply to the model Returns: (BaseModel): An updated BaseModel object. """ self = super()._apply(fn) m = self.model[-1] # Detect() if isinstance(m, Detect): # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect m.stride = fn(m.stride) m.anchors = fn(m.anchors) m.strides = fn(m.strides) return self def load(self, weights, verbose=True): """ Load the weights into the model. Args: weights (dict | torch.nn.Module): The pre-trained weights to be loaded. verbose (bool, optional): Whether to log the transfer progress. Defaults to True. """ model = weights["model"] if isinstance(weights, dict) else weights # torchvision models are not dicts csd = model.float().state_dict() # checkpoint state_dict as FP32 csd = intersect_dicts(csd, self.state_dict()) # intersect self.load_state_dict(csd, strict=False) # load if verbose: LOGGER.info(f"Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights") def loss(self, batch, preds=None): """ Compute loss. Args: batch (dict): Batch to compute loss on preds (torch.Tensor | List[torch.Tensor]): Predictions. """ if getattr(self, "criterion", None) is None: self.criterion = self.init_criterion() preds = self.forward(batch["img"]) if preds is None else preds return self.criterion(preds, batch) def init_criterion(self): """Initialize the loss criterion for the BaseModel.""" raise NotImplementedError("compute_loss() needs to be implemented by task heads") class DetectionModel(BaseModel): """YOLOv8 detection model.""" def __init__(self, cfg="yolov8n.yaml", ch=3, nc=None, verbose=True): # model, input channels, number of classes """Initialize the YOLOv8 detection model with the given config and parameters.""" super().__init__() self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict if self.yaml["backbone"][0][2] == "Silence": LOGGER.warning( "WARNING ⚠️ YOLOv9 `Silence` module is deprecated in favor of nn.Identity. " "Please delete local *.pt file and re-download the latest model checkpoint." 
) self.yaml["backbone"][0][2] = "nn.Identity" # Define model ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels if nc and nc != self.yaml["nc"]: LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}") self.yaml["nc"] = nc # override YAML value self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict self.inplace = self.yaml.get("inplace", True) self.end2end = getattr(self.model[-1], "end2end", False) # Build strides m = self.model[-1] # Detect() if isinstance(m, Detect): # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect s = 256 # 2x min stride m.inplace = self.inplace def _forward(x): """Performs a forward pass through the model, handling different Detect subclass types accordingly.""" if self.end2end: return self.forward(x)["one2many"] return self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x) m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))]) # forward self.stride = m.stride m.bias_init() # only run once else: self.stride = torch.Tensor([32]) # default stride for i.e. RTDETR # Init weights, biases initialize_weights(self) if verbose: self.info() LOGGER.info("") def _predict_augment(self, x): """Perform augmentations on input image x and return augmented inference and train outputs.""" if getattr(self, "end2end", False) or self.__class__.__name__ != "DetectionModel": LOGGER.warning("WARNING ⚠️ Model does not support 'augment=True', reverting to single-scale prediction.") return self._predict_once(x) img_size = x.shape[-2:] # height, width s = [1, 0.83, 0.67] # scales f = [None, 3, None] # flips (2-ud, 3-lr) y = [] # outputs for si, fi in zip(s, f): xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max())) yi = super().predict(xi)[0] # forward yi = self._descale_pred(yi, fi, si, img_size) y.append(yi) y = self._clip_augmented(y) # clip augmented tails return torch.cat(y, -1), None # augmented inference, train @staticmethod def _descale_pred(p, flips, scale, img_size, dim=1): """De-scale predictions following augmented inference (inverse operation).""" p[:, :4] /= scale # de-scale x, y, wh, cls = p.split((1, 1, 2, p.shape[dim] - 4), dim) if flips == 2: y = img_size[0] - y # de-flip ud elif flips == 3: x = img_size[1] - x # de-flip lr return torch.cat((x, y, wh, cls), dim) def _clip_augmented(self, y): """Clip YOLO augmented inference tails.""" nl = self.model[-1].nl # number of detection layers (P3-P5) g = sum(4**x for x in range(nl)) # grid points e = 1 # exclude layer count i = (y[0].shape[-1] // g) * sum(4**x for x in range(e)) # indices y[0] = y[0][..., :-i] # large i = (y[-1].shape[-1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices y[-1] = y[-1][..., i:] # small return y def init_criterion(self): """Initialize the loss criterion for the DetectionModel.""" return E2EDetectLoss(self) if getattr(self, "end2end", False) else v8DetectionLoss(self) class OBBModel(DetectionModel): """YOLOv8 Oriented Bounding Box (OBB) model.""" def __init__(self, cfg="yolov8n-obb.yaml", ch=3, nc=None, verbose=True): """Initialize YOLOv8 OBB model with given config and parameters.""" super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) def init_criterion(self): """Initialize the loss criterion for the model.""" return v8OBBLoss(self) class SegmentationModel(DetectionModel): """YOLOv8 segmentation model.""" def __init__(self, cfg="yolov8n-seg.yaml", ch=3, 
nc=None, verbose=True): """Initialize YOLOv8 segmentation model with given config and parameters.""" super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) def init_criterion(self): """Initialize the loss criterion for the SegmentationModel.""" return v8SegmentationLoss(self) class PoseModel(DetectionModel): """YOLOv8 pose model.""" def __init__(self, cfg="yolov8n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True): """Initialize YOLOv8 Pose model.""" if not isinstance(cfg, dict): cfg = yaml_model_load(cfg) # load model YAML if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg["kpt_shape"]): LOGGER.info(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}") cfg["kpt_shape"] = data_kpt_shape super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) def init_criterion(self): """Initialize the loss criterion for the PoseModel.""" return v8PoseLoss(self) class ClassificationModel(BaseModel): """YOLOv8 classification model.""" def __init__(self, cfg="yolov8n-cls.yaml", ch=3, nc=None, verbose=True): """Init ClassificationModel with YAML, channels, number of classes, verbose flag.""" super().__init__() self._from_yaml(cfg, ch, nc, verbose) def _from_yaml(self, cfg, ch, nc, verbose): """Set YOLOv8 model configurations and define the model architecture.""" self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict # Define model ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels if nc and nc != self.yaml["nc"]: LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}") self.yaml["nc"] = nc # override YAML value elif not nc and not self.yaml.get("nc", None): raise ValueError("nc not specified. Must specify nc in model.yaml or function arguments.") self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist self.stride = torch.Tensor([1]) # no stride constraints self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict self.info() @staticmethod def reshape_outputs(model, nc): """Update a TorchVision classification model to class count 'n' if required.""" name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1] # last module if isinstance(m, Classify): # YOLO Classify() head if m.linear.out_features != nc: m.linear = nn.Linear(m.linear.in_features, nc) elif isinstance(m, nn.Linear): # ResNet, EfficientNet if m.out_features != nc: setattr(model, name, nn.Linear(m.in_features, nc)) elif isinstance(m, nn.Sequential): types = [type(x) for x in m] if nn.Linear in types: i = len(types) - 1 - types[::-1].index(nn.Linear) # last nn.Linear index if m[i].out_features != nc: m[i] = nn.Linear(m[i].in_features, nc) elif nn.Conv2d in types: i = len(types) - 1 - types[::-1].index(nn.Conv2d) # last nn.Conv2d index if m[i].out_channels != nc: m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None) def init_criterion(self): """Initialize the loss criterion for the ClassificationModel.""" return v8ClassificationLoss() class RTDETRDetectionModel(DetectionModel): """ RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class. This class is responsible for constructing the RTDETR architecture, defining loss functions, and facilitating both the training and inference processes. RTDETR is an object detection and tracking model that extends from the DetectionModel base class. Attributes: cfg (str): The configuration file path or preset string. Default is 'rtdetr-l.yaml'. 
ch (int): Number of input channels. Default is 3 (RGB). nc (int, optional): Number of classes for object detection. Default is None. verbose (bool): Specifies if summary statistics are shown during initialization. Default is True. Methods: init_criterion: Initializes the criterion used for loss calculation. loss: Computes and returns the loss during training. predict: Performs a forward pass through the network and returns the output. """ def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True): """ Initialize the RTDETRDetectionModel. Args: cfg (str): Configuration file name or path. ch (int): Number of input channels. nc (int, optional): Number of classes. Defaults to None. verbose (bool, optional): Print additional information during initialization. Defaults to True. """ super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) def init_criterion(self): """Initialize the loss criterion for the RTDETRDetectionModel.""" from ultralytics.models.utils.loss import RTDETRDetectionLoss return RTDETRDetectionLoss(nc=self.nc, use_vfl=True) def loss(self, batch, preds=None): """ Compute the loss for the given batch of data. Args: batch (dict): Dictionary containing image and label data. preds (torch.Tensor, optional): Precomputed model predictions. Defaults to None. Returns: (tuple): A tuple containing the total loss and main three losses in a tensor. """ if not hasattr(self, "criterion"): self.criterion = self.init_criterion() img = batch["img"] # NOTE: preprocess gt_bbox and gt_labels to list. bs = len(img) batch_idx = batch["batch_idx"] gt_groups = [(batch_idx == i).sum().item() for i in range(bs)] targets = { "cls": batch["cls"].to(img.device, dtype=torch.long).view(-1), "bboxes": batch["bboxes"].to(device=img.device), "batch_idx": batch_idx.to(img.device, dtype=torch.long).view(-1), "gt_groups": gt_groups, } preds = self.predict(img, batch=targets) if preds is None else preds dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if self.training else preds[1] if dn_meta is None: dn_bboxes, dn_scores = None, None else: dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta["dn_num_split"], dim=2) dn_scores, dec_scores = torch.split(dec_scores, dn_meta["dn_num_split"], dim=2) dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes]) # (7, bs, 300, 4) dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores]) loss = self.criterion( (dec_bboxes, dec_scores), targets, dn_bboxes=dn_bboxes, dn_scores=dn_scores, dn_meta=dn_meta ) # NOTE: There are like 12 losses in RTDETR, backward with all losses but only show the main three losses. return sum(loss.values()), torch.as_tensor( [loss[k].detach() for k in ["loss_giou", "loss_class", "loss_bbox"]], device=img.device ) def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None): """ Perform a forward pass through the model. Args: x (torch.Tensor): The input tensor. profile (bool, optional): If True, profile the computation time for each layer. Defaults to False. visualize (bool, optional): If True, save feature maps for visualization. Defaults to False. batch (dict, optional): Ground truth data for evaluation. Defaults to None. augment (bool, optional): If True, perform data augmentation during inference. Defaults to False. embed (list, optional): A list of feature vectors/embeddings to return. Returns: (torch.Tensor): Model's output tensor. 
""" y, dt, embeddings = [], [], [] # outputs for m in self.model[:-1]: # except the head part if m.f != -1: # if not from previous layer x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers if profile: self._profile_one_layer(m, x, dt) x = m(x) # run y.append(x if m.i in self.save else None) # save output if visualize: feature_visualization(x, m.type, m.i, save_dir=visualize) if embed and m.i in embed: embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten if m.i == max(embed): return torch.unbind(torch.cat(embeddings, 1), dim=0) head = self.model[-1] x = head([y[j] for j in head.f], batch) # head inference return x class WorldModel(DetectionModel): """YOLOv8 World Model.""" def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True): """Initialize YOLOv8 world model with given config and parameters.""" self.txt_feats = torch.randn(1, nc or 80, 512) # features placeholder self.clip_model = None # CLIP model placeholder super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) def set_classes(self, text, batch=80, cache_clip_model=True): """Set classes in advance so that model could do offline-inference without clip model.""" try: import clip except ImportError: check_requirements("git+https://github.com/ultralytics/CLIP.git") import clip if ( not getattr(self, "clip_model", None) and cache_clip_model ): # for backwards compatibility of models lacking clip_model attribute self.clip_model = clip.load("ViT-B/32")[0] model = self.clip_model if cache_clip_model else clip.load("ViT-B/32")[0] device = next(model.parameters()).device text_token = clip.tokenize(text).to(device) txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)] txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0) txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True) self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]) self.model[-1].nc = len(text) def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None): """ Perform a forward pass through the model. Args: x (torch.Tensor): The input tensor. profile (bool, optional): If True, profile the computation time for each layer. Defaults to False. visualize (bool, optional): If True, save feature maps for visualization. Defaults to False. txt_feats (torch.Tensor): The text features, use it if it's given. Defaults to None. augment (bool, optional): If True, perform data augmentation during inference. Defaults to False. embed (list, optional): A list of feature vectors/embeddings to return. Returns: (torch.Tensor): Model's output tensor. 
""" txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype) if len(txt_feats) != len(x): txt_feats = txt_feats.repeat(len(x), 1, 1) ori_txt_feats = txt_feats.clone() y, dt, embeddings = [], [], [] # outputs for m in self.model: # except the head part if m.f != -1: # if not from previous layer x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers if profile: self._profile_one_layer(m, x, dt) if isinstance(m, C2fAttn): x = m(x, txt_feats) elif isinstance(m, WorldDetect): x = m(x, ori_txt_feats) elif isinstance(m, ImagePoolingAttn): txt_feats = m(x, txt_feats) else: x = m(x) # run y.append(x if m.i in self.save else None) # save output if visualize: feature_visualization(x, m.type, m.i, save_dir=visualize) if embed and m.i in embed: embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten if m.i == max(embed): return torch.unbind(torch.cat(embeddings, 1), dim=0) return x def loss(self, batch, preds=None): """ Compute loss. Args: batch (dict): Batch to compute loss on. preds (torch.Tensor | List[torch.Tensor]): Predictions. """ if not hasattr(self, "criterion"): self.criterion = self.init_criterion() if preds is None: preds = self.forward(batch["img"], txt_feats=batch["txt_feats"]) return self.criterion(preds, batch) class Ensemble(nn.ModuleList): """Ensemble of models.""" def __init__(self): """Initialize an ensemble of models.""" super().__init__() def forward(self, x, augment=False, profile=False, visualize=False): """Function generates the YOLO network's final layer.""" y = [module(x, augment, profile, visualize)[0] for module in self] # y = torch.stack(y).max(0)[0] # max ensemble # y = torch.stack(y).mean(0) # mean ensemble y = torch.cat(y, 2) # nms ensemble, y shape(B, HW, C) return y, None # inference, train output # Functions ------------------------------------------------------------------------------------------------------------ @contextlib.contextmanager def temporary_modules(modules=None, attributes=None): """ Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`). This function can be used to change the module paths during runtime. It's useful when refactoring code, where you've moved a module from one location to another, but you still want to support the old import paths for backwards compatibility. Args: modules (dict, optional): A dictionary mapping old module paths to new module paths. attributes (dict, optional): A dictionary mapping old module attributes to new module attributes. Example: ```python with temporary_modules({"old.module": "new.module"}, {"old.module.attribute": "new.module.attribute"}): import old.module # this will now import new.module from old.module import attribute # this will now import new.module.attribute ``` Note: The changes are only in effect inside the context manager and are undone once the context manager exits. Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger applications or libraries. Use this function with caution. 
""" if modules is None: modules = {} if attributes is None: attributes = {} import sys from importlib import import_module try: # Set attributes in sys.modules under their old name for old, new in attributes.items(): old_module, old_attr = old.rsplit(".", 1) new_module, new_attr = new.rsplit(".", 1) setattr(import_module(old_module), old_attr, getattr(import_module(new_module), new_attr)) # Set modules in sys.modules under their old name for old, new in modules.items(): sys.modules[old] = import_module(new) yield finally: # Remove the temporary module paths for old in modules: if old in sys.modules: del sys.modules[old] class SafeClass: """A placeholder class to replace unknown classes during unpickling.""" def __init__(self, *args, **kwargs): """Initialize SafeClass instance, ignoring all arguments.""" pass def __call__(self, *args, **kwargs): """Run SafeClass instance, ignoring all arguments.""" pass class SafeUnpickler(pickle.Unpickler): """Custom Unpickler that replaces unknown classes with SafeClass.""" def find_class(self, module, name): """Attempt to find a class, returning SafeClass if not among safe modules.""" safe_modules = ( "torch", "collections", "collections.abc", "builtins", "math", "numpy", # Add other modules considered safe ) if module in safe_modules: return super().find_class(module, name) else: return SafeClass def torch_safe_load(weight, safe_only=False): """ Attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the error, logs a warning message, and attempts to install the missing module via the check_requirements() function. After installation, the function again attempts to load the model using torch.load(). Args: weight (str): The file path of the PyTorch model. safe_only (bool): If True, replace unknown classes with SafeClass during loading. Example: ```python from ultralytics.nn.tasks import torch_safe_load ckpt, file = torch_safe_load("path/to/best.pt", safe_only=True) ``` Returns: ckpt (dict): The loaded model checkpoint. file (str): The loaded filename """ from ultralytics.utils.downloads import attempt_download_asset check_suffix(file=weight, suffix=".pt") file = attempt_download_asset(weight) # search online if missing locally try: with temporary_modules( modules={ "ultralytics.yolo.utils": "ultralytics.utils", "ultralytics.yolo.v8": "ultralytics.models.yolo", "ultralytics.yolo.data": "ultralytics.data", }, attributes={ "ultralytics.nn.modules.block.Silence": "torch.nn.Identity", # YOLOv9e "ultralytics.nn.tasks.YOLOv10DetectionModel": "ultralytics.nn.tasks.DetectionModel", # YOLOv10 "ultralytics.utils.loss.v10DetectLoss": "ultralytics.utils.loss.E2EDetectLoss", # YOLOv10 }, ): if safe_only: # Load via custom pickle module safe_pickle = types.ModuleType("safe_pickle") safe_pickle.Unpickler = SafeUnpickler safe_pickle.load = lambda file_obj: SafeUnpickler(file_obj).load() with open(file, "rb") as f: ckpt = torch.load(f, pickle_module=safe_pickle) else: ckpt = torch.load(file, map_location="cpu") except ModuleNotFoundError as e: # e.name is missing module name if e.name == "models": raise TypeError( emojis( f"ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained " f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with " f"YOLOv8 at https://github.com/ultralytics/ultralytics." f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to " f"run a command with an official Ultralytics model, i.e. 
'yolo predict model=yolov8n.pt'" ) ) from e LOGGER.warning( f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in Ultralytics requirements." f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future." f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to " f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolov8n.pt'" ) check_requirements(e.name) # install missing module ckpt = torch.load(file, map_location="cpu") if not isinstance(ckpt, dict): # File is likely a YOLO instance saved with i.e. torch.save(model, "saved_model.pt") LOGGER.warning( f"WARNING ⚠️ The file '{weight}' appears to be improperly saved or formatted. " f"For optimal results, use model.save('filename.pt') to correctly save YOLO models." ) ckpt = {"model": ckpt.model} return ckpt, file def attempt_load_weights(weights, device=None, inplace=True, fuse=False): """Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a.""" ensemble = Ensemble() for w in weights if isinstance(weights, list) else [weights]: ckpt, w = torch_safe_load(w) # load ckpt args = {**DEFAULT_CFG_DICT, **ckpt["train_args"]} if "train_args" in ckpt else None # combined args model = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model # Model compatibility updates model.args = args # attach args to model model.pt_path = w # attach *.pt file path to model model.task = guess_model_task(model) if not hasattr(model, "stride"): model.stride = torch.tensor([32.0]) # Append ensemble.append(model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval()) # model in eval mode # Module updates for m in ensemble.modules(): if hasattr(m, "inplace"): m.inplace = inplace elif isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"): m.recompute_scale_factor = None # torch 1.11.0 compatibility # Return model if len(ensemble) == 1: return ensemble[-1] # Return ensemble LOGGER.info(f"Ensemble created with {weights}\n") for k in "names", "nc", "yaml": setattr(ensemble, k, getattr(ensemble[0], k)) ensemble.stride = ensemble[int(torch.argmax(torch.tensor([m.stride.max() for m in ensemble])))].stride assert all(ensemble[0].nc == m.nc for m in ensemble), f"Models differ in class counts {[m.nc for m in ensemble]}" return ensemble def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False): """Loads a single model weights.""" ckpt, weight = torch_safe_load(weight) # load ckpt args = {**DEFAULT_CFG_DICT, **(ckpt.get("train_args", {}))} # combine model and default args, preferring model args model = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model # Model compatibility updates model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model model.pt_path = weight # attach *.pt file path to model model.task = guess_model_task(model) if not hasattr(model, "stride"): model.stride = torch.tensor([32.0]) model = model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval() # model in eval mode # Module updates for m in model.modules(): if hasattr(m, "inplace"): m.inplace = inplace elif isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"): m.recompute_scale_factor = None # torch 1.11.0 compatibility # Return model and ckpt return model, ckpt def parse_model(d, ch, verbose=True): # model_dict, input_channels(3) """Parse a YOLO model.yaml dictionary into a PyTorch model.""" import ast # Args legacy = True # backward 
compatibility for v3/v5/v8/v9 models max_channels = float("inf") nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales")) depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape")) if scales: scale = d.get("scale") if not scale: scale = tuple(scales.keys())[0] LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.") depth, width, max_channels = scales[scale] if act: Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU() if verbose: LOGGER.info(f"{colorstr('activation:')} {act}") # print if verbose: LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}") ch = [ch] layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out backbone = False for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args t=m m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m] # get module for j, a in enumerate(args): if isinstance(a, str): with contextlib.suppress(ValueError): args[j] = locals()[a] if a in locals() else ast.literal_eval(a) n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain if m in { Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, C2fPSA, C2PSA, DWConv, Focus, BottleneckCSP, C1, C2, C2f, C3k2, RepNCSPELAN4, ELAN1, ADown, AConv, SPPELAN, C2fAttn, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3, PSA, SCDown, C2fCIB, A2C2f, }: c1, c2 = ch[f], args[0] if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output) c2 = make_divisible(min(c2, max_channels) * width, 8) if m is C2fAttn: args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8) # embed channels args[2] = int( max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2] ) # num heads args = [c1, c2, *args[1:]] if m in { BottleneckCSP, C1, C2, C2f, C3k2, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3, C2fPSA, C2fCIB, C2PSA, A2C2f, }: args.insert(2, n) # number of repeats n = 1 if m is C3k2: # for M/L/X sizes legacy = False if scale in "mlx": args[3] = True if m is A2C2f: legacy = False if scale in "lx": # for L/X sizes args.append(True) args.append(1.2) elif m is AIFI: args = [ch[f], *args] elif m in {HGStem, HGBlock}: c1, cm, c2 = ch[f], args[0], args[1] args = [c1, cm, c2, *args[2:]] if m is HGBlock: args.insert(4, n) # number of repeats n = 1 elif m is ResNetLayer: c2 = args[1] if args[3] else args[1] * 4 elif m is nn.BatchNorm2d: args = [ch[f]] #LSKNet elif m in {LSKNET_T,LSKNET_S}: m = m(*args) c2 = m.width_list backbone =True #BiFPN elif m is BiFPN: length = len([ch[x] for x in f]) args = [length] elif m is Concat: c2 = sum(ch[x] for x in f) elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}: args.append([ch[x] for x in f]) if m is Segment: args[2] = make_divisible(min(args[2], max_channels) * width, 8) if m in {Detect, Segment, Pose, OBB}: m.legacy = legacy elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1 args.insert(1, [ch[x] for x in f]) elif m in {CBLinear, TorchVision, Index}: c2 = args[0] c1 = ch[f] args = [c1, c2, *args[1:]] elif m is CBFuse: c2 = ch[f[-1]] else: c2 = ch[f] # m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module # t = str(m)[8:-2].replace("__main__.", "") # module type # m_.np = sum(x.numel() for x in m_.parameters()) # number params # m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type # if verbose: # 
def yaml_model_load(path):
    """Load a YOLOv8 model from a YAML file."""
    path = Path(path)
    if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
        new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
        LOGGER.warning(f"WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.")
        path = path.with_name(new_stem + path.suffix)

    unified_path = re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", str(path))  # i.e. yolov8x.yaml -> yolov8.yaml
    yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
    d = yaml_load(yaml_file)  # model dict
    d["scale"] = guess_model_scale(path)
    d["yaml_file"] = str(path)
    return d


def guess_model_scale(model_path):
    """
    Takes a path to a YOLO model's YAML file as input and extracts the size character of the model's scale. The
    function uses regular expression matching to find the pattern of the model scale in the YAML file name, which is
    denoted by n, s, m, l, or x. The function returns the size character of the model scale as a string.

    Args:
        model_path (str | Path): The path to the YOLO model's YAML file.

    Returns:
        (str): The size character of the model's scale, which can be n, s, m, l, or x.
    """
    try:
        return re.search(r"yolo[v]?\d+([nslmx])", Path(model_path).stem).group(1)  # noqa, returns n, s, m, l, or x
    except AttributeError:
        return ""


def guess_model_task(model):
    """
    Guess the task of a PyTorch model from its architecture or configuration.

    Args:
        model (nn.Module | dict): PyTorch model or model configuration in YAML format.

    Returns:
        (str): Task of the model ('detect', 'segment', 'classify', 'pose').

    Raises:
        SyntaxError: If the task of the model could not be determined.
    """

    def cfg2task(cfg):
        """Guess from YAML dictionary."""
        m = cfg["head"][-1][-2].lower()  # output module name
        if m in {"classify", "classifier", "cls", "fc"}:
            return "classify"
        if "detect" in m:
            return "detect"
        if m == "segment":
            return "segment"
        if m == "pose":
            return "pose"
        if m == "obb":
            return "obb"

    # Guess from model cfg
    if isinstance(model, dict):
        with contextlib.suppress(Exception):
            return cfg2task(model)

    # Guess from PyTorch model
    if isinstance(model, nn.Module):  # PyTorch model
        for x in "model.args", "model.model.args", "model.model.model.args":
            with contextlib.suppress(Exception):
                return eval(x)["task"]
        for x in "model.yaml", "model.model.yaml", "model.model.model.yaml":
            with contextlib.suppress(Exception):
                return cfg2task(eval(x))
        for m in model.modules():
            if isinstance(m, Segment):
                return "segment"
            elif isinstance(m, Classify):
                return "classify"
            elif isinstance(m, Pose):
                return "pose"
            elif isinstance(m, OBB):
                return "obb"
            elif isinstance(m, (Detect, WorldDetect, v10Detect)):
                return "detect"

    # Guess from model filename
    if isinstance(model, (str, Path)):
        model = Path(model)
        if "-seg" in model.stem or "segment" in model.parts:
            return "segment"
        elif "-cls" in model.stem or "classify" in model.parts:
            return "classify"
        elif "-pose" in model.stem or "pose" in model.parts:
            return "pose"
        elif "-obb" in model.stem or "obb" in model.parts:
            return "obb"
        elif "detect" in model.parts:
            return "detect"

    # Unable to determine task from model
    LOGGER.warning(
        "WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
        "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'."
    )
    return "detect"  # assume detect

The code above is the content of the task.py file. Below is the content of the LSKNet.py file.

import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair as to_2tuple
from timm.models.layers import DropPath, to_2tuple
from functools import partial
import warnings

__all__ = ['LSKNET_T', 'LSKNET_S']


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.dwconv(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class LSKblock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        self.conv_spatial = nn.Conv2d(dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
        self.conv1 = nn.Conv2d(dim, dim // 2, 1)
        self.conv2 = nn.Conv2d(dim, dim // 2, 1)
        self.conv_squeeze = nn.Conv2d(2, 2, 7, padding=3)
        self.conv = nn.Conv2d(dim // 2, dim, 1)

    def forward(self, x):
        attn1 = self.conv0(x)
        attn2 = self.conv_spatial(attn1)

        attn1 = self.conv1(attn1)
        attn2 = self.conv2(attn2)

        attn = torch.cat([attn1, attn2], dim=1)
        avg_attn = torch.mean(attn, dim=1, keepdim=True)
        max_attn, _ = torch.max(attn, dim=1, keepdim=True)
        agg = torch.cat([avg_attn, max_attn], dim=1)
        sig = self.conv_squeeze(agg).sigmoid()
        attn = attn1 * sig[:, 0, :, :].unsqueeze(1) + attn2 * sig[:, 1, :, :].unsqueeze(1)
        attn = self.conv(attn)
        return x * attn
class Attention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.proj_1 = nn.Conv2d(d_model, d_model, 1)
        self.activation = nn.GELU()
        self.spatial_gating_unit = LSKblock(d_model)
        self.proj_2 = nn.Conv2d(d_model, d_model, 1)

    def forward(self, x):
        shortcut = x.clone()
        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        x = x + shortcut
        return x


class Block(nn.Module):
    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_cfg=None):
        super().__init__()
        if norm_cfg:
            self.norm1 = nn.BatchNorm2d(norm_cfg, dim)
            self.norm2 = nn.BatchNorm2d(norm_cfg, dim)
        else:
            self.norm1 = nn.BatchNorm2d(dim)
            self.norm2 = nn.BatchNorm2d(dim)
        self.attn = Attention(dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        layer_scale_init_value = 1e-2
        self.layer_scale_1 = nn.Parameter(
            layer_scale_init_value * torch.ones((dim)), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(
            layer_scale_init_value * torch.ones((dim)), requires_grad=True)

    def forward(self, x):
        x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.attn(self.norm1(x)))
        x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)))
        return x


class OverlapPatchEmbed(nn.Module):
    """ Image to Patch Embedding """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768, norm_cfg=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        if norm_cfg:
            self.norm = nn.BatchNorm2d(norm_cfg, embed_dim)
        else:
            self.norm = nn.BatchNorm2d(embed_dim)

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = self.norm(x)
        return x, H, W


class LSKNet(nn.Module):
    def __init__(self, img_size=224, in_chans=3, dim=None, embed_dims=[64, 128, 256, 512],
                 mlp_ratios=[8, 8, 4, 4], drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 depths=[3, 4, 6, 3], num_stages=4, pretrained=None, init_cfg=None, norm_cfg=None):
        super().__init__()

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is not None:
            raise TypeError('pretrained must be a str or None')

        self.depths = depths
        self.num_stages = num_stages

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0

        for i in range(num_stages):
            patch_embed = OverlapPatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)),
                                            patch_size=7 if i == 0 else 3,
                                            stride=4 if i == 0 else 2,
                                            in_chans=in_chans if i == 0 else embed_dims[i - 1],
                                            embed_dim=embed_dims[i], norm_cfg=norm_cfg)

            block = nn.ModuleList([Block(
                dim=embed_dims[i], mlp_ratio=mlp_ratios[i], drop=drop_rate, drop_path=dpr[cur + j], norm_cfg=norm_cfg)
                for j in range(depths[i])])
            norm = norm_layer(embed_dims[i])
            cur += depths[i]

            setattr(self, f"patch_embed{i + 1}", patch_embed)
            setattr(self, f"block{i + 1}", block)
            setattr(self, f"norm{i + 1}", norm)
        self.width_list = [i.size(1) for i in self.forward(torch.randn(1, 3, 640, 640))]

    def freeze_patch_emb(self):
        self.patch_embed1.requires_grad = False

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'}  # has pos_embed may be better

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        B = x.shape[0]
        outs = []
        for i in range(self.num_stages):
            patch_embed = getattr(self, f"patch_embed{i + 1}")
            block = getattr(self, f"block{i + 1}")
            norm = getattr(self, f"norm{i + 1}")
            x, H, W = patch_embed(x)
            for blk in block:
                x = blk(x)
            x = x.flatten(2).transpose(1, 2)
            x = norm(x)
            x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
            outs.append(x)
        return outs

    def forward(self, x):
        x = self.forward_features(x)
        # x = self.head(x)
        return x
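A quick check of the width_list attribute (illustrative only, not part of LSKNet.py): the dry forward pass in __init__ records the channel width of each of the four stage outputs, and this is exactly the list that the modified task.py reads as c2 when the whole backbone is one module.

# Illustrative check, assuming the LSKNet class above is importable.
import torch

net = LSKNet(depths=[2, 2, 2, 2])  # same configuration that LSKNET_T() builds
print(net.width_list)              # [64, 128, 256, 512] with the default embed_dims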
class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x):
        x = self.dwconv(x)
        return x


def _conv_filter(state_dict, patch_size=16):
    """ convert patch embedding weight from manual patchify + linear proj to conv"""
    out_dict = {}
    for k, v in state_dict.items():
        if 'patch_embed.proj.weight' in k:
            v = v.reshape((v.shape[0], 3, patch_size, patch_size))
        out_dict[k] = v
    return out_dict


def LSKNET_T():
    model = LSKNet(depths=[2, 2, 2, 2])
    return model


def LSKNET_S():
    model = LSKNet()
    return model


if __name__ == '__main__':
    model = LSKNet()
    inputs = torch.randn((1, 3, 640, 640))
    for i in model(inputs):
        print(i.size())

At the very bottom is the BiFPN.py file. Please combine these three files with the runtime error shown above to solve this problem.

import torch.nn as nn
import torch


class swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class BiFPN(nn.Module):
    def __init__(self, length):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(length, dtype=torch.float32), requires_grad=True)
        self.swish = swish()
        self.epsilon = 0.0001

    def forward(self, x):
        weights = self.weight / (torch.sum(self.swish(self.weight), dim=0) + self.epsilon)
        weighted_feature_maps = [weights[i] * x[i] for i in range(len(x))]
        stacked_feature_maps = torch.stack(weighted_feature_maps, dim=0)
        result = torch.sum(stacked_feature_maps, dim=0)
        return result
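To make the BiFPN module's contract explicit, a small illustrative usage (not part of BiFPN.py, assuming the class above is importable): it performs a learned, normalized weighted sum over a list of feature maps, so all inputs passed to it must already share the same shape.

# Illustrative usage of the BiFPN class above.
import torch

fpn = BiFPN(length=3)
feats = [torch.randn(1, 256, 40, 40) for _ in range(3)]  # inputs must share shape
fused = fpn(feats)
print(fused.shape)  # torch.Size([1, 256, 40, 40])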
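One integration detail worth making explicit (a sketch under assumptions, not a confirmed fix for the error above): the model-parsing code in task.py resolves module names with globals()[m] and compares them directly against LSKNET_T, LSKNET_S and BiFPN, so those names must be imported into task.py's namespace. The import paths below are assumptions that depend on where LSKNet.py and BiFPN.py were actually placed in the project tree.

# Sketch only: add near the other imports at the top of task.py.
# The module paths are assumptions -- adjust them to wherever LSKNet.py and BiFPN.py live.
from ultralytics.nn.LSKNet import LSKNET_T, LSKNET_S
from ultralytics.nn.BiFPN import BiFPN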