(* DIFF_TERM = "TRUE" *)

       在 Verilog 中,DIFF_TERM 是用于 差分终端(Differential Termination) 的属性,主要应用于差分信号对。

一、DIFF_TERM 概述

  DIFF_TERM 用于在 FPGA 的差分输入端口上启用或禁用内部终端电阻,这对高速差分信号(如 LVDS、TMDS 等)的完整性至关重要。

二、语法格式

(* DIFF_TERM = "TRUE" *)        // 启用差分终端
(* DIFF_TERM = "FALSE" *)       // 禁用差分终端
(* DIFF_TERM = "FALSE", TERM_100OHM = "TRUE" *)  // 特定终端值

三、实际应用示例

1. 基本用法

module diff_signals (
    // 启用差分终端的 LVDS 输入
    (* DIFF_TERM = "TRUE" *)
    input wire LVDS_CLK_P,
    input wire LVDS_CLK_N,
    
    (* DIFF_TERM = "TRUE" *)
    input wire LVDS_DATA_P,
    input wire LVDS_DATA_N,
    
    // 禁用差分终端的输出
    (* DIFF_TERM = "FALSE" *)
    output wire diff_out_P,
    output wire diff_out_N
);

    // 差分信号处理逻辑
    // ...
    
endmodule

2. 具体接口应用

module hdmi_interface (
    // HDMI TMDS 差分输入(启用终端)
    (* DIFF_TERM = "TRUE" *)
    input wire tmds_clk_p,
    input wire tmds_clk_n,
    
    (* DIFF_TERM = "TRUE" *)
    input wire [2:0] tmds_data_p,
    input wire [2:0] tmds_data_n,
    
    // 用户逻辑
    output reg [23:0] rgb_data,
    output reg video_valid
);

    always @(*) begin
        // 差分信号处理逻辑
        // 通常使用厂商原语进行差分转单端
    end

endmodule

四、在约束文件中的使用

1. Xilinx XDC 约束

# 启用差分终端
set_property DIFF_TERM TRUE [get_ports {LVDS_*_P}]
set_property DIFF_TERM TRUE [get_ports {LVDS_*_N}]

# 或者针对特定端口
set_property DIFF_TERM TRUE [get_ports clk200_p]
set_property DIFF_TERM TRUE [get_ports clk200_n]

2. Intel QSF 约束

# Intel Quartus 中的设置
set_instance_assignment -name DIFFERENTIAL_TERMINATION ON -to clk_in_p
set_instance_assignment -name DIFFERENTIAL_TERMINATION ON -to clk_in_n

五、典型应用场景

module high_speed_interface (
    // DDR 接口时钟(启用终端)
    (* DIFF_TERM = "TRUE" *)
    input wire ddr_clk_p,
    input wire ddr_clk_n,
    
    // SATA 接口(启用终端)
    (* DIFF_TERM = "TRUE" *)
    input wire sata_rx_p,
    input wire sata_rx_n,
    
    // PCI Express(启用终端)
    (* DIFF_TERM = "TRUE" *)
    input wire pcie_refclk_p,
    input wire pcie_refclk_n,
    
    // 普通单端信号(禁用终端)
    (* DIFF_TERM = "FALSE" *)
    input wire config_data
);

    // 高速接口逻辑
    // ...
    
endmodule

六、重要注意事项

  1. 输入信号专用:DIFF_TERM 通常只用于输入差分信号

  2. 输出信号:差分输出通常不需要也不应该启用 DIFF_TERM

  3. 外部终端:如果已在 PCB 上放置外部终端电阻,应禁用 DIFF_TERM

  4. 功耗考虑:启用差分终端会增加功耗

  5. 厂商差异

    • Xilinx:支持 100Ω 差分终端

    • Intel:支持多种终端值(100Ω, 120Ω 等)

  6. Bank 约束:同一 I/O Bank 中的差分终端设置可能需要一致

七、最佳实践

// 推荐:为所有差分输入启用终端
(* DIFF_TERM = "TRUE" *)
input wire rx_clk_p, rx_clk_n;

(* DIFF_TERM = "TRUE" *)  
input wire rx_data_p, rx_data_n;

// 推荐:为差分输出禁用终端
(* DIFF_TERM = "FALSE" *)
output wire tx_clk_p, tx_clk_n;

(* DIFF_TERM = "FALSE" *)
output wire tx_data_p, tx_data_n;

  DIFF_TERM 是保证高速差分信号完整性的关键属性,正确使用可以显著改善信号质量,减少反射和振铃。

#include "BackgroundSubtractorSuBSENSE.h" #include "DistanceUtils.h" #include "RandUtils.h" #include <iostream> #include <opencv2/imgproc/imgproc.hpp> #include <opencv2/highgui/highgui.hpp> #include <iomanip> /* * * Intrinsic parameters for our method are defined here; tuning these for better * performance should not be required in most cases -- although improvements in * very specific scenarios are always possible. * */ //! defines the threshold value(s) used to detect long-term ghosting and trigger the fast edge-based absorption heuristic #define GHOSTDET_D_MAX (0.010f) // defines 'negligible' change here #define GHOSTDET_S_MIN (0.995f) // defines the required minimum local foreground saturation value //! parameter used to scale dynamic distance threshold adjustments ('R(x)') #define FEEDBACK_R_VAR (0.01f) //! parameters used to adjust the variation step size of 'v(x)' #define FEEDBACK_V_INCR (1.000f) #define FEEDBACK_V_DECR (0.100f) //! parameters used to scale dynamic learning rate adjustments ('T(x)') #define FEEDBACK_T_DECR (0.2500f) #define FEEDBACK_T_INCR (0.5000f) #define FEEDBACK_T_LOWER (2.0000f) #define FEEDBACK_T_UPPER (256.00f) //! parameters used to define 'unstable' regions, based on segm noise/bg dynamics and local dist threshold values #define UNSTABLE_REG_RATIO_MIN (0.100f) #define UNSTABLE_REG_RDIST_MIN (3.000f) //! parameters used to scale the relative LBSP intensity threshold used for internal comparisons #define LBSPDESC_NONZERO_RATIO_MIN (0.100f) #define LBSPDESC_NONZERO_RATIO_MAX (0.500f) //! 
parameters used to define model reset/learning rate boosts in our frame-level component #define FRAMELEVEL_MIN_COLOR_DIFF_THRESHOLD (m_nMinColorDistThreshold/2) #define FRAMELEVEL_ANALYSIS_DOWNSAMPLE_RATIO (8) // local define used to display debug information #define DISPLAY_SUBSENSE_DEBUG_INFO 0 // local define used to specify the default frame size (320x240 = QVGA) #define DEFAULT_FRAME_SIZE cv::Size(320,240) // local define used to specify the color dist threshold offset used for unstable regions #define STAB_COLOR_DIST_OFFSET (m_nMinColorDistThreshold/5) // local define used to specify the desc dist threshold offset used for unstable regions #define UNSTAB_DESC_DIST_OFFSET (m_nDescDistThresholdOffset) static const size_t s_nColorMaxDataRange_1ch = UCHAR_MAX; static const size_t s_nDescMaxDataRange_1ch = LBSP::DESC_SIZE*8; static const size_t s_nColorMaxDataRange_3ch = s_nColorMaxDataRange_1ch*3; static const size_t s_nDescMaxDataRange_3ch = s_nDescMaxDataRange_1ch*3; BackgroundSubtractorSuBSENSE::BackgroundSubtractorSuBSENSE( float fRelLBSPThreshold ,size_t nDescDistThresholdOffset ,size_t nMinColorDistThreshold ,size_t nBGSamples ,size_t nRequiredBGSamples ,size_t nSamplesForMovingAvgs) : BackgroundSubtractorLBSP(fRelLBSPThreshold) ,m_nMinColorDistThreshold(nMinColorDistThreshold) ,m_nDescDistThresholdOffset(nDescDistThresholdOffset) ,m_nBGSamples(nBGSamples) ,m_nRequiredBGSamples(nRequiredBGSamples) ,m_nSamplesForMovingAvgs(nSamplesForMovingAvgs) ,m_fLastNonZeroDescRatio(0.0f) ,m_bLearningRateScalingEnabled(true) ,m_fCurrLearningRateLowerCap(FEEDBACK_T_LOWER) ,m_fCurrLearningRateUpperCap(FEEDBACK_T_UPPER) ,m_nMedianBlurKernelSize(m_nDefaultMedianBlurKernelSize) ,m_bUse3x3Spread(true) ,m_defaultMorphologyKernel(cv::getStructuringElement(cv::MORPH_RECT, cv::Size(3, 3))) { CV_Assert(m_nBGSamples>0 && m_nRequiredBGSamples<=m_nBGSamples); CV_Assert(m_nMinColorDistThreshold>=STAB_COLOR_DIST_OFFSET); } BackgroundSubtractorSuBSENSE::~BackgroundSubtractorSuBSENSE() { 
// (destructor body, cont.) free the LUTs if initialize() ever allocated them
    if(m_aPxIdxLUT)
        delete[] m_aPxIdxLUT;
    if(m_aPxInfoLUT)
        delete[] m_aPxInfoLUT;
}

// Sets up all internal buffers and per-pixel state from the first frame and an
// optional region-of-interest mask; must be called before processing any frame.
//  oInitImg : first input frame (CV_8UC1 or CV_8UC3, continuous)
//  oROI     : optional binary ROI mask (0 or UCHAR_MAX, same size as oInitImg)
void BackgroundSubtractorSuBSENSE::initialize(const cv::Mat& oInitImg, const cv::Mat& oROI) {
    // == init
    CV_Assert(!oInitImg.empty() && oInitImg.cols>0 && oInitImg.rows>0);
    CV_Assert(oInitImg.isContinuous());
    CV_Assert(oInitImg.type()==CV_8UC3 || oInitImg.type()==CV_8UC1);
    if(oInitImg.type()==CV_8UC3) {
        // warn when a 3-channel frame is actually grayscale (all channels equal)
        std::vector<cv::Mat> voInitImgChannels;
        cv::split(oInitImg,voInitImgChannels);
        if(!cv::countNonZero((voInitImgChannels[0]!=voInitImgChannels[1])|(voInitImgChannels[2]!=voInitImgChannels[1])))
            std::cout << std::endl << "\tBackgroundSubtractorSuBSENSE : Warning, grayscale images should always be passed in CV_8UC1 format for optimal performance." << std::endl;
    }
    cv::Mat oNewBGROI;
    // NOTE(review): this condition tests oROI.size() right after oROI.empty()
    // passed, so it is always true when oROI is empty -- possibly meant
    // m_oROI.size()!=oInitImg.size(); verify against upstream SuBSENSE sources.
    if(oROI.empty() && (m_oROI.empty() || oROI.size()!=oInitImg.size())) {
        // no ROI provided: use the full frame
        oNewBGROI.create(oInitImg.size(),CV_8UC1);
        oNewBGROI = cv::Scalar_<uchar>(UCHAR_MAX);
    }
    else if(oROI.empty())
        oNewBGROI = m_oROI; // reuse the previously-set ROI
    else {
        // caller-provided ROI: must be strictly binary (0 / UCHAR_MAX)
        CV_Assert(oROI.size()==oInitImg.size() && oROI.type()==CV_8UC1);
        CV_Assert(cv::countNonZero((oROI<UCHAR_MAX)&(oROI>0))==0);
        oNewBGROI = oROI.clone();
        // grow a half-intensity border around the ROI so LBSP patches near the
        // edge still have data to look at
        cv::Mat oTempROI;
        cv::dilate(oNewBGROI,oTempROI,m_defaultMorphologyKernel,cv::Point(-1,-1),LBSP::PATCH_SIZE/2);
        cv::bitwise_or(oNewBGROI,oTempROI/2,oNewBGROI);
    }
    const size_t nOrigROIPxCount = (size_t)cv::countNonZero(oNewBGROI);
    CV_Assert(nOrigROIPxCount>0);
    LBSP::validateROI(oNewBGROI); // trims pixels too close to borders for LBSP
    const size_t nFinalROIPxCount = (size_t)cv::countNonZero(oNewBGROI);
    CV_Assert(nFinalROIPxCount>0);
    // cache frame geometry and reset all frame-level counters/state
    m_oROI = oNewBGROI;
    m_oImgSize = oInitImg.size();
    m_nImgType = oInitImg.type();
    m_nImgChannels = oInitImg.channels();
    m_nTotPxCount = m_oImgSize.area();
    m_nTotRelevantPxCount = nFinalROIPxCount;
    m_nFrameIndex = 0;
    m_nFramesSinceLastReset = 0;
    m_nModelResetCooldown = 0;
    m_fLastNonZeroDescRatio = 0.0f;
    const int nTotImgPixels = m_oImgSize.height*m_oImgSize.width;
    // enable the adaptive/frame-level machinery only when the ROI covers at
    // least half the frame and the frame is at least QVGA-sized
    if(nOrigROIPxCount>=m_nTotPxCount/2 && (int)m_nTotPxCount>=DEFAULT_FRAME_SIZE.area()) {
        m_bLearningRateScalingEnabled =
// (initialize cont.) large-frame configuration: dynamic learning rates, auto
// model reset, and a median-blur kernel scaled with frame area (capped, odd)
            true;
        m_bAutoModelResetEnabled = true;
        m_bUse3x3Spread = !(nTotImgPixels>DEFAULT_FRAME_SIZE.area()*2);
        const int nRawMedianBlurKernelSize = std::min((int)floor((float)nTotImgPixels/DEFAULT_FRAME_SIZE.area()+0.5f)+m_nDefaultMedianBlurKernelSize,14);
        m_nMedianBlurKernelSize = (nRawMedianBlurKernelSize%2)?nRawMedianBlurKernelSize:nRawMedianBlurKernelSize-1; // force odd kernel size
        m_fCurrLearningRateLowerCap = FEEDBACK_T_LOWER;
        m_fCurrLearningRateUpperCap = FEEDBACK_T_UPPER;
    }
    else {
        // small frame / small ROI: disable the adaptive extras, double the caps
        m_bLearningRateScalingEnabled = false;
        m_bAutoModelResetEnabled = false;
        m_bUse3x3Spread = true;
        m_nMedianBlurKernelSize = m_nDefaultMedianBlurKernelSize;
        m_fCurrLearningRateLowerCap = FEEDBACK_T_LOWER*2;
        m_fCurrLearningRateUpperCap = FEEDBACK_T_UPPER*2;
    }
    // per-pixel feedback maps: T(x) learning rate, R(x) distance threshold,
    // v(x) variation modulator, and the short/long-term moving averages
    m_oUpdateRateFrame.create(m_oImgSize,CV_32FC1);
    m_oUpdateRateFrame = cv::Scalar(m_fCurrLearningRateLowerCap);
    m_oDistThresholdFrame.create(m_oImgSize,CV_32FC1);
    m_oDistThresholdFrame = cv::Scalar(1.0f);
    m_oVariationModulatorFrame.create(m_oImgSize,CV_32FC1);
    m_oVariationModulatorFrame = cv::Scalar(10.0f); // should always be >= FEEDBACK_V_DECR
    m_oMeanLastDistFrame.create(m_oImgSize,CV_32FC1);
    m_oMeanLastDistFrame = cv::Scalar(0.0f);
    m_oMeanMinDistFrame_LT.create(m_oImgSize,CV_32FC1);
    m_oMeanMinDistFrame_LT = cv::Scalar(0.0f);
    m_oMeanMinDistFrame_ST.create(m_oImgSize,CV_32FC1);
    m_oMeanMinDistFrame_ST = cv::Scalar(0.0f);
    // downsampled buffers used by the frame-level motion analysis component
    m_oDownSampledFrameSize = cv::Size(m_oImgSize.width/FRAMELEVEL_ANALYSIS_DOWNSAMPLE_RATIO,m_oImgSize.height/FRAMELEVEL_ANALYSIS_DOWNSAMPLE_RATIO);
    m_oMeanDownSampledLastDistFrame_LT.create(m_oDownSampledFrameSize,CV_32FC((int)m_nImgChannels));
    m_oMeanDownSampledLastDistFrame_LT = cv::Scalar(0.0f);
    m_oMeanDownSampledLastDistFrame_ST.create(m_oDownSampledFrameSize,CV_32FC((int)m_nImgChannels));
    m_oMeanDownSampledLastDistFrame_ST = cv::Scalar(0.0f);
    m_oMeanRawSegmResFrame_LT.create(m_oImgSize,CV_32FC1);
    m_oMeanRawSegmResFrame_LT = cv::Scalar(0.0f);
    m_oMeanRawSegmResFrame_ST.create(m_oImgSize,CV_32FC1);
    m_oMeanRawSegmResFrame_ST = cv::Scalar(0.0f);
// (initialize cont.) post-processed segmentation averages, working masks, and
// the per-pixel background sample model (color + LBSP descriptor per sample)
    m_oMeanFinalSegmResFrame_LT.create(m_oImgSize,CV_32FC1);
    m_oMeanFinalSegmResFrame_LT = cv::Scalar(0.0f);
    m_oMeanFinalSegmResFrame_ST.create(m_oImgSize,CV_32FC1);
    m_oMeanFinalSegmResFrame_ST = cv::Scalar(0.0f);
    m_oUnstableRegionMask.create(m_oImgSize,CV_8UC1);
    m_oUnstableRegionMask = cv::Scalar_<uchar>(0);
    m_oBlinksFrame.create(m_oImgSize,CV_8UC1);
    m_oBlinksFrame = cv::Scalar_<uchar>(0);
    m_oDownSampledFrame_MotionAnalysis.create(m_oDownSampledFrameSize,CV_8UC((int)m_nImgChannels));
    m_oDownSampledFrame_MotionAnalysis = cv::Scalar_<uchar>::all(0);
    // last-seen color frame (8-bit/ch) and LBSP descriptor frame (16-bit/ch)
    m_oLastColorFrame.create(m_oImgSize,CV_8UC((int)m_nImgChannels));
    m_oLastColorFrame = cv::Scalar_<uchar>::all(0);
    m_oLastDescFrame.create(m_oImgSize,CV_16UC((int)m_nImgChannels));
    m_oLastDescFrame = cv::Scalar_<ushort>::all(0);
    // scratch foreground masks used by the post-processing pipeline
    m_oLastRawFGMask.create(m_oImgSize,CV_8UC1);
    m_oLastRawFGMask = cv::Scalar_<uchar>(0);
    m_oLastFGMask.create(m_oImgSize,CV_8UC1);
    m_oLastFGMask = cv::Scalar_<uchar>(0);
    m_oLastFGMask_dilated.create(m_oImgSize,CV_8UC1);
    m_oLastFGMask_dilated = cv::Scalar_<uchar>(0);
    m_oLastFGMask_dilated_inverted.create(m_oImgSize,CV_8UC1);
    m_oLastFGMask_dilated_inverted = cv::Scalar_<uchar>(0);
    m_oFGMask_FloodedHoles.create(m_oImgSize,CV_8UC1);
    m_oFGMask_FloodedHoles = cv::Scalar_<uchar>(0);
    m_oFGMask_PreFlood.create(m_oImgSize,CV_8UC1);
    m_oFGMask_PreFlood = cv::Scalar_<uchar>(0);
    m_oCurrRawFGBlinkMask.create(m_oImgSize,CV_8UC1);
    m_oCurrRawFGBlinkMask = cv::Scalar_<uchar>(0);
    m_oLastRawFGBlinkMask.create(m_oImgSize,CV_8UC1);
    m_oLastRawFGBlinkMask = cv::Scalar_<uchar>(0);
    // allocate the N background samples (one color mat + one descriptor mat each)
    m_voBGColorSamples.resize(m_nBGSamples);
    m_voBGDescSamples.resize(m_nBGSamples);
    for(size_t s=0; s<m_nBGSamples; ++s) {
        m_voBGColorSamples[s].create(m_oImgSize,CV_8UC((int)m_nImgChannels));
        m_voBGColorSamples[s] = cv::Scalar_<uchar>::all(0);
        m_voBGDescSamples[s].create(m_oImgSize,CV_16UC((int)m_nImgChannels));
        m_voBGDescSamples[s] = cv::Scalar_<ushort>::all(0);
    }
    // (re)allocate the pixel lookup tables; safe on re-initialization
    if(m_aPxIdxLUT)
        delete[] m_aPxIdxLUT;
    if(m_aPxInfoLUT)
        delete[] m_aPxInfoLUT;
    m_aPxIdxLUT
// (initialize cont.) model-index -> pixel-index LUT (ROI pixels only) and
// per-pixel coordinate/model info LUT (all pixels)
        = new size_t[m_nTotRelevantPxCount];
    m_aPxInfoLUT = new PxInfoBase[m_nTotPxCount];
    if(m_nImgChannels==1) {
        // layout checks: continuous 1-ch color frame, 2-byte descriptors
        CV_Assert(m_oLastColorFrame.step.p[0]==(size_t)m_oImgSize.width && m_oLastColorFrame.step.p[1]==1);
        CV_Assert(m_oLastDescFrame.step.p[0]==m_oLastColorFrame.step.p[0]*2 && m_oLastDescFrame.step.p[1]==m_oLastColorFrame.step.p[1]*2);
        // precompute per-intensity LBSP thresholds; /3 rescales the 3-ch offset
        // for single-channel use
        for(size_t t=0; t<=UCHAR_MAX; ++t)
            m_anLBSPThreshold_8bitLUT[t] = cv::saturate_cast<uchar>((m_nLBSPThresholdOffset+t*m_fRelLBSPThreshold)/3);
        // fill LUTs and seed the last-color/last-descriptor frames from oInitImg
        for(size_t nPxIter=0, nModelIter=0; nPxIter<m_nTotPxCount; ++nPxIter) {
            if(m_oROI.data[nPxIter]) {
                m_aPxIdxLUT[nModelIter] = nPxIter;
                m_aPxInfoLUT[nPxIter].nImgCoord_Y = (int)nPxIter/m_oImgSize.width;
                m_aPxInfoLUT[nPxIter].nImgCoord_X = (int)nPxIter%m_oImgSize.width;
                m_aPxInfoLUT[nPxIter].nModelIdx = nModelIter;
                m_oLastColorFrame.data[nPxIter] = oInitImg.data[nPxIter];
                const size_t nDescIter = nPxIter*2;
                LBSP::computeGrayscaleDescriptor(oInitImg,oInitImg.data[nPxIter],m_aPxInfoLUT[nPxIter].nImgCoord_X,m_aPxInfoLUT[nPxIter].nImgCoord_Y,m_anLBSPThreshold_8bitLUT[oInitImg.data[nPxIter]],*((ushort*)(m_oLastDescFrame.data+nDescIter)));
                ++nModelIter;
            }
        }
    }
    else { //m_nImgChannels==3
        // layout checks: continuous 3-ch color frame, 2 bytes per desc channel
        CV_Assert(m_oLastColorFrame.step.p[0]==(size_t)m_oImgSize.width*3 && m_oLastColorFrame.step.p[1]==3);
        CV_Assert(m_oLastDescFrame.step.p[0]==m_oLastColorFrame.step.p[0]*2 && m_oLastDescFrame.step.p[1]==m_oLastColorFrame.step.p[1]*2);
        for(size_t t=0; t<=UCHAR_MAX; ++t)
            m_anLBSPThreshold_8bitLUT[t] = cv::saturate_cast<uchar>(m_nLBSPThresholdOffset+t*m_fRelLBSPThreshold);
        for(size_t nPxIter=0, nModelIter=0; nPxIter<m_nTotPxCount; ++nPxIter) {
            if(m_oROI.data[nPxIter]) {
                m_aPxIdxLUT[nModelIter] = nPxIter;
                m_aPxInfoLUT[nPxIter].nImgCoord_Y = (int)nPxIter/m_oImgSize.width;
                m_aPxInfoLUT[nPxIter].nImgCoord_X = (int)nPxIter%m_oImgSize.width;
                m_aPxInfoLUT[nPxIter].nModelIdx = nModelIter;
                const size_t nPxRGBIter = nPxIter*3;
                const size_t nDescRGBIter = nPxRGBIter*2;
                // seed each channel's color + LBSP descriptor independently
                for(size_t c=0; c<3; ++c) {
                    m_oLastColorFrame.data[nPxRGBIter+c] =
// (initialize cont.) finish seeding the 3-channel color/descriptor frames
                        oInitImg.data[nPxRGBIter+c];
                    LBSP::computeSingleRGBDescriptor(oInitImg,oInitImg.data[nPxRGBIter+c],m_aPxInfoLUT[nPxIter].nImgCoord_X,m_aPxInfoLUT[nPxIter].nImgCoord_Y,c,m_anLBSPThreshold_8bitLUT[oInitImg.data[nPxRGBIter+c]],((ushort*)(m_oLastDescFrame.data+nDescRGBIter))[c]);
                }
                ++nModelIter;
            }
        }
    }
    m_bInitialized = true;
    // populate the full background model from the (just-seeded) last frames
    refreshModel(1.0f);
}

// Re-seeds a fraction of the background samples for every ROI pixel, drawing
// each sample's value from a random neighbor of that pixel in the last frame.
//  fSamplesRefreshFrac : fraction of the m_nBGSamples models to refresh (0..1]
//  bForceFGUpdate      : when true, also refresh pixels currently marked FG
void BackgroundSubtractorSuBSENSE::refreshModel(float fSamplesRefreshFrac, bool bForceFGUpdate) {
    // == refresh
    CV_Assert(m_bInitialized);
    CV_Assert(fSamplesRefreshFrac>0.0f && fSamplesRefreshFrac<=1.0f);
    // refresh a contiguous (wrapping) run of sample slots starting at a random
    // position, so repeated partial refreshes cover different slots
    const size_t nModelsToRefresh = fSamplesRefreshFrac<1.0f?(size_t)(fSamplesRefreshFrac*m_nBGSamples):m_nBGSamples;
    const size_t nRefreshStartPos = fSamplesRefreshFrac<1.0f?rand()%m_nBGSamples:0;
    if(m_nImgChannels==1) {
        for(size_t nModelIter=0; nModelIter<m_nTotRelevantPxCount; ++nModelIter) {
            const size_t nPxIter = m_aPxIdxLUT[nModelIter];
            // skip foreground pixels unless explicitly forced
            if(bForceFGUpdate || !m_oLastFGMask.data[nPxIter]) {
                for(size_t nCurrModelIdx=nRefreshStartPos; nCurrModelIdx<nRefreshStartPos+nModelsToRefresh; ++nCurrModelIdx) {
                    // pick a random position in the LBSP patch neighborhood and
                    // copy its last color + descriptor into this sample slot
                    int nSampleImgCoord_Y, nSampleImgCoord_X;
                    getRandSamplePosition(nSampleImgCoord_X,nSampleImgCoord_Y,m_aPxInfoLUT[nPxIter].nImgCoord_X,m_aPxInfoLUT[nPxIter].nImgCoord_Y,LBSP::PATCH_SIZE/2,m_oImgSize);
                    const size_t nSamplePxIdx = m_oImgSize.width*nSampleImgCoord_Y + nSampleImgCoord_X;
                    if(bForceFGUpdate || !m_oLastFGMask.data[nSamplePxIdx]) {
                        const size_t nCurrRealModelIdx = nCurrModelIdx%m_nBGSamples; // wrap slot index
                        m_voBGColorSamples[nCurrRealModelIdx].data[nPxIter] = m_oLastColorFrame.data[nSamplePxIdx];
                        *((ushort*)(m_voBGDescSamples[nCurrRealModelIdx].data+nPxIter*2)) = *((ushort*)(m_oLastDescFrame.data+nSamplePxIdx*2));
                    }
                }
            }
        }
    }
    else { //m_nImgChannels==3
        // same refresh scheme, copying all three channels per sample
        for(size_t nModelIter=0; nModelIter<m_nTotRelevantPxCount; ++nModelIter) {
            const size_t nPxIter = m_aPxIdxLUT[nModelIter];
            if(bForceFGUpdate || !m_oLastFGMask.data[nPxIter]) {
                for(size_t nCurrModelIdx=nRefreshStartPos; nCurrModelIdx<nRefreshStartPos+nModelsToRefresh; ++nCurrModelIdx) {
                    int
nSampleImgCoord_Y, nSampleImgCoord_X; getRandSamplePosition(nSampleImgCoord_X,nSampleImgCoord_Y,m_aPxInfoLUT[nPxIter].nImgCoord_X,m_aPxInfoLUT[nPxIter].nImgCoord_Y,LBSP::PATCH_SIZE/2,m_oImgSize); const size_t nSamplePxIdx = m_oImgSize.width*nSampleImgCoord_Y + nSampleImgCoord_X; if(bForceFGUpdate || !m_oLastFGMask.data[nSamplePxIdx]) { const size_t nCurrRealModelIdx = nCurrModelIdx%m_nBGSamples; for(size_t c=0; c<3; ++c) { m_voBGColorSamples[nCurrRealModelIdx].data[nPxIter*3+c] = m_oLastColorFrame.data[nSamplePxIdx*3+c]; *((ushort*)(m_voBGDescSamples[nCurrRealModelIdx].data+(nPxIter*3+c)*2)) = *((ushort*)(m_oLastDescFrame.data+(nSamplePxIdx*3+c)*2)); } } } } } } } void BackgroundSubtractorSuBSENSE::operator()(cv::InputArray _image, cv::OutputArray _fgmask, double learningRateOverride) { // == process CV_Assert(m_bInitialized); cv::Mat oInputImg = _image.getMat(); CV_Assert(oInputImg.type()==m_nImgType && oInputImg.size()==m_oImgSize); CV_Assert(oInputImg.isContinuous()); _fgmask.create(m_oImgSize,CV_8UC1); cv::Mat oCurrFGMask = _fgmask.getMat(); memset(oCurrFGMask.data,0,oCurrFGMask.cols*oCurrFGMask.rows); size_t nNonZeroDescCount = 0; const float fRollAvgFactor_LT = 1.0f/std::min(++m_nFrameIndex,m_nSamplesForMovingAvgs); const float fRollAvgFactor_ST = 1.0f/std::min(m_nFrameIndex,m_nSamplesForMovingAvgs/4); if(m_nImgChannels==1) { for(size_t nModelIter=0; nModelIter<m_nTotRelevantPxCount; ++nModelIter) { const size_t nPxIter = m_aPxIdxLUT[nModelIter]; const size_t nDescIter = nPxIter*2; const size_t nFloatIter = nPxIter*4; const int nCurrImgCoord_X = m_aPxInfoLUT[nPxIter].nImgCoord_X; const int nCurrImgCoord_Y = m_aPxInfoLUT[nPxIter].nImgCoord_Y; const uchar nCurrColor = oInputImg.data[nPxIter]; size_t nMinDescDist = s_nDescMaxDataRange_1ch; size_t nMinSumDist = s_nColorMaxDataRange_1ch; float* pfCurrDistThresholdFactor = (float*)(m_oDistThresholdFrame.data+nFloatIter); float* pfCurrVariationFactor = (float*)(m_oVariationModulatorFrame.data+nFloatIter); float* 
pfCurrLearningRate = ((float*)(m_oUpdateRateFrame.data+nFloatIter)); float* pfCurrMeanLastDist = ((float*)(m_oMeanLastDistFrame.data+nFloatIter)); float* pfCurrMeanMinDist_LT = ((float*)(m_oMeanMinDistFrame_LT.data+nFloatIter)); float* pfCurrMeanMinDist_ST = ((float*)(m_oMeanMinDistFrame_ST.data+nFloatIter)); float* pfCurrMeanRawSegmRes_LT = ((float*)(m_oMeanRawSegmResFrame_LT.data+nFloatIter)); float* pfCurrMeanRawSegmRes_ST = ((float*)(m_oMeanRawSegmResFrame_ST.data+nFloatIter)); float* pfCurrMeanFinalSegmRes_LT = ((float*)(m_oMeanFinalSegmResFrame_LT.data+nFloatIter)); float* pfCurrMeanFinalSegmRes_ST = ((float*)(m_oMeanFinalSegmResFrame_ST.data+nFloatIter)); ushort& nLastIntraDesc = *((ushort*)(m_oLastDescFrame.data+nDescIter)); uchar& nLastColor = m_oLastColorFrame.data[nPxIter]; const size_t nCurrColorDistThreshold = (size_t)(((*pfCurrDistThresholdFactor)*m_nMinColorDistThreshold)-((!m_oUnstableRegionMask.data[nPxIter])*STAB_COLOR_DIST_OFFSET))/2; const size_t nCurrDescDistThreshold = ((size_t)1<<((size_t)floor(*pfCurrDistThresholdFactor+0.5f)))+m_nDescDistThresholdOffset+(m_oUnstableRegionMask.data[nPxIter]*UNSTAB_DESC_DIST_OFFSET); ushort nCurrInterDesc, nCurrIntraDesc; LBSP::computeGrayscaleDescriptor(oInputImg,nCurrColor,nCurrImgCoord_X,nCurrImgCoord_Y,m_anLBSPThreshold_8bitLUT[nCurrColor],nCurrIntraDesc); m_oUnstableRegionMask.data[nPxIter] = ((*pfCurrDistThresholdFactor)>UNSTABLE_REG_RDIST_MIN || (*pfCurrMeanRawSegmRes_LT-*pfCurrMeanFinalSegmRes_LT)>UNSTABLE_REG_RATIO_MIN || (*pfCurrMeanRawSegmRes_ST-*pfCurrMeanFinalSegmRes_ST)>UNSTABLE_REG_RATIO_MIN)?1:0; size_t nGoodSamplesCount=0, nSampleIdx=0; while(nGoodSamplesCount<m_nRequiredBGSamples && nSampleIdx<m_nBGSamples) { const uchar& nBGColor = m_voBGColorSamples[nSampleIdx].data[nPxIter]; { const size_t nColorDist = L1dist(nCurrColor,nBGColor); if(nColorDist>nCurrColorDistThreshold) goto failedcheck1ch; const ushort& nBGIntraDesc = *((ushort*)(m_voBGDescSamples[nSampleIdx].data+nDescIter)); const 
size_t nIntraDescDist = hdist(nCurrIntraDesc,nBGIntraDesc); LBSP::computeGrayscaleDescriptor(oInputImg,nBGColor,nCurrImgCoord_X,nCurrImgCoord_Y,m_anLBSPThreshold_8bitLUT[nBGColor],nCurrInterDesc); const size_t nInterDescDist = hdist(nCurrInterDesc,nBGIntraDesc); const size_t nDescDist = (nIntraDescDist+nInterDescDist)/2; if(nDescDist>nCurrDescDistThreshold) goto failedcheck1ch; const size_t nSumDist = std::min((nDescDist/4)*(s_nColorMaxDataRange_1ch/s_nDescMaxDataRange_1ch)+nColorDist,s_nColorMaxDataRange_1ch); if(nSumDist>nCurrColorDistThreshold) goto failedcheck1ch; if(nMinDescDist>nDescDist) nMinDescDist = nDescDist; if(nMinSumDist>nSumDist) nMinSumDist = nSumDist; nGoodSamplesCount++; } failedcheck1ch: nSampleIdx++; } const float fNormalizedLastDist = ((float)L1dist(nLastColor,nCurrColor)/s_nColorMaxDataRange_1ch+(float)hdist(nLastIntraDesc,nCurrIntraDesc)/s_nDescMaxDataRange_1ch)/2; *pfCurrMeanLastDist = (*pfCurrMeanLastDist)*(1.0f-fRollAvgFactor_ST) + fNormalizedLastDist*fRollAvgFactor_ST; if(nGoodSamplesCount<m_nRequiredBGSamples) { // == foreground const float fNormalizedMinDist = std::min(1.0f,((float)nMinSumDist/s_nColorMaxDataRange_1ch+(float)nMinDescDist/s_nDescMaxDataRange_1ch)/2 + (float)(m_nRequiredBGSamples-nGoodSamplesCount)/m_nRequiredBGSamples); *pfCurrMeanMinDist_LT = (*pfCurrMeanMinDist_LT)*(1.0f-fRollAvgFactor_LT) + fNormalizedMinDist*fRollAvgFactor_LT; *pfCurrMeanMinDist_ST = (*pfCurrMeanMinDist_ST)*(1.0f-fRollAvgFactor_ST) + fNormalizedMinDist*fRollAvgFactor_ST; *pfCurrMeanRawSegmRes_LT = (*pfCurrMeanRawSegmRes_LT)*(1.0f-fRollAvgFactor_LT) + fRollAvgFactor_LT; *pfCurrMeanRawSegmRes_ST = (*pfCurrMeanRawSegmRes_ST)*(1.0f-fRollAvgFactor_ST) + fRollAvgFactor_ST; oCurrFGMask.data[nPxIter] = UCHAR_MAX; if(m_nModelResetCooldown && (rand()%(size_t)FEEDBACK_T_LOWER)==0) { const size_t s_rand = rand()%m_nBGSamples; *((ushort*)(m_voBGDescSamples[s_rand].data+nDescIter)) = nCurrIntraDesc; m_voBGColorSamples[s_rand].data[nPxIter] = nCurrColor; } } else { 
// == background const float fNormalizedMinDist = ((float)nMinSumDist/s_nColorMaxDataRange_1ch+(float)nMinDescDist/s_nDescMaxDataRange_1ch)/2; *pfCurrMeanMinDist_LT = (*pfCurrMeanMinDist_LT)*(1.0f-fRollAvgFactor_LT) + fNormalizedMinDist*fRollAvgFactor_LT; *pfCurrMeanMinDist_ST = (*pfCurrMeanMinDist_ST)*(1.0f-fRollAvgFactor_ST) + fNormalizedMinDist*fRollAvgFactor_ST; *pfCurrMeanRawSegmRes_LT = (*pfCurrMeanRawSegmRes_LT)*(1.0f-fRollAvgFactor_LT); *pfCurrMeanRawSegmRes_ST = (*pfCurrMeanRawSegmRes_ST)*(1.0f-fRollAvgFactor_ST); const size_t nLearningRate = learningRateOverride>0?(size_t)ceil(learningRateOverride):(size_t)ceil(*pfCurrLearningRate); if((rand()%nLearningRate)==0) { const size_t s_rand = rand()%m_nBGSamples; *((ushort*)(m_voBGDescSamples[s_rand].data+nDescIter)) = nCurrIntraDesc; m_voBGColorSamples[s_rand].data[nPxIter] = nCurrColor; } int nSampleImgCoord_Y, nSampleImgCoord_X; const bool bCurrUsing3x3Spread = m_bUse3x3Spread && !m_oUnstableRegionMask.data[nPxIter]; if(bCurrUsing3x3Spread) getRandNeighborPosition_3x3(nSampleImgCoord_X,nSampleImgCoord_Y,nCurrImgCoord_X,nCurrImgCoord_Y,LBSP::PATCH_SIZE/2,m_oImgSize); else getRandNeighborPosition_5x5(nSampleImgCoord_X,nSampleImgCoord_Y,nCurrImgCoord_X,nCurrImgCoord_Y,LBSP::PATCH_SIZE/2,m_oImgSize); const size_t n_rand = rand(); const size_t idx_rand_uchar = m_oImgSize.width*nSampleImgCoord_Y + nSampleImgCoord_X; const size_t idx_rand_flt32 = idx_rand_uchar*4; const float fRandMeanLastDist = *((float*)(m_oMeanLastDistFrame.data+idx_rand_flt32)); const float fRandMeanRawSegmRes = *((float*)(m_oMeanRawSegmResFrame_ST.data+idx_rand_flt32)); if((n_rand%(bCurrUsing3x3Spread?nLearningRate:(nLearningRate/2+1)))==0 || (fRandMeanRawSegmRes>GHOSTDET_S_MIN && fRandMeanLastDist<GHOSTDET_D_MAX && (n_rand%((size_t)m_fCurrLearningRateLowerCap))==0)) { const size_t idx_rand_ushrt = idx_rand_uchar*2; const size_t s_rand = rand()%m_nBGSamples; *((ushort*)(m_voBGDescSamples[s_rand].data+idx_rand_ushrt)) = nCurrIntraDesc; 
m_voBGColorSamples[s_rand].data[idx_rand_uchar] = nCurrColor; } } if(m_oLastFGMask.data[nPxIter] || (std::min(*pfCurrMeanMinDist_LT,*pfCurrMeanMinDist_ST)<UNSTABLE_REG_RATIO_MIN && oCurrFGMask.data[nPxIter])) { if((*pfCurrLearningRate)<m_fCurrLearningRateUpperCap) *pfCurrLearningRate += FEEDBACK_T_INCR/(std::max(*pfCurrMeanMinDist_LT,*pfCurrMeanMinDist_ST)*(*pfCurrVariationFactor)); } else if((*pfCurrLearningRate)>m_fCurrLearningRateLowerCap) *pfCurrLearningRate -= FEEDBACK_T_DECR*(*pfCurrVariationFactor)/std::max(*pfCurrMeanMinDist_LT,*pfCurrMeanMinDist_ST); if((*pfCurrLearningRate)<m_fCurrLearningRateLowerCap) *pfCurrLearningRate = m_fCurrLearningRateLowerCap; else if((*pfCurrLearningRate)>m_fCurrLearningRateUpperCap) *pfCurrLearningRate = m_fCurrLearningRateUpperCap; if(std::max(*pfCurrMeanMinDist_LT,*pfCurrMeanMinDist_ST)>UNSTABLE_REG_RATIO_MIN && m_oBlinksFrame.data[nPxIter]) (*pfCurrVariationFactor) += FEEDBACK_V_INCR; else if((*pfCurrVariationFactor)>FEEDBACK_V_DECR) { (*pfCurrVariationFactor) -= m_oLastFGMask.data[nPxIter]?FEEDBACK_V_DECR/4:m_oUnstableRegionMask.data[nPxIter]?FEEDBACK_V_DECR/2:FEEDBACK_V_DECR; if((*pfCurrVariationFactor)<FEEDBACK_V_DECR) (*pfCurrVariationFactor) = FEEDBACK_V_DECR; } if((*pfCurrDistThresholdFactor)<std::pow(1.0f+std::min(*pfCurrMeanMinDist_LT,*pfCurrMeanMinDist_ST)*2,2)) (*pfCurrDistThresholdFactor) += FEEDBACK_R_VAR*(*pfCurrVariationFactor-FEEDBACK_V_DECR); else { (*pfCurrDistThresholdFactor) -= FEEDBACK_R_VAR/(*pfCurrVariationFactor); if((*pfCurrDistThresholdFactor)<1.0f) (*pfCurrDistThresholdFactor) = 1.0f; } if(popcount(nCurrIntraDesc)>=2) ++nNonZeroDescCount; nLastIntraDesc = nCurrIntraDesc; nLastColor = nCurrColor; } } else { //m_nImgChannels==3 for(size_t nModelIter=0; nModelIter<m_nTotRelevantPxCount; ++nModelIter) { const size_t nPxIter = m_aPxIdxLUT[nModelIter]; const int nCurrImgCoord_X = m_aPxInfoLUT[nPxIter].nImgCoord_X; const int nCurrImgCoord_Y = m_aPxInfoLUT[nPxIter].nImgCoord_Y; const size_t nPxIterRGB = 
nPxIter*3; const size_t nDescIterRGB = nPxIterRGB*2; const size_t nFloatIter = nPxIter*4; const uchar* const anCurrColor = oInputImg.data+nPxIterRGB; size_t nMinTotDescDist=s_nDescMaxDataRange_3ch; size_t nMinTotSumDist=s_nColorMaxDataRange_3ch; float* pfCurrDistThresholdFactor = (float*)(m_oDistThresholdFrame.data+nFloatIter); float* pfCurrVariationFactor = (float*)(m_oVariationModulatorFrame.data+nFloatIter); float* pfCurrLearningRate = ((float*)(m_oUpdateRateFrame.data+nFloatIter)); float* pfCurrMeanLastDist = ((float*)(m_oMeanLastDistFrame.data+nFloatIter)); float* pfCurrMeanMinDist_LT = ((float*)(m_oMeanMinDistFrame_LT.data+nFloatIter)); float* pfCurrMeanMinDist_ST = ((float*)(m_oMeanMinDistFrame_ST.data+nFloatIter)); float* pfCurrMeanRawSegmRes_LT = ((float*)(m_oMeanRawSegmResFrame_LT.data+nFloatIter)); float* pfCurrMeanRawSegmRes_ST = ((float*)(m_oMeanRawSegmResFrame_ST.data+nFloatIter)); float* pfCurrMeanFinalSegmRes_LT = ((float*)(m_oMeanFinalSegmResFrame_LT.data+nFloatIter)); float* pfCurrMeanFinalSegmRes_ST = ((float*)(m_oMeanFinalSegmResFrame_ST.data+nFloatIter)); ushort* anLastIntraDesc = ((ushort*)(m_oLastDescFrame.data+nDescIterRGB)); uchar* anLastColor = m_oLastColorFrame.data+nPxIterRGB; const size_t nCurrColorDistThreshold = (size_t)(((*pfCurrDistThresholdFactor)*m_nMinColorDistThreshold)-((!m_oUnstableRegionMask.data[nPxIter])*STAB_COLOR_DIST_OFFSET)); const size_t nCurrDescDistThreshold = ((size_t)1<<((size_t)floor(*pfCurrDistThresholdFactor+0.5f)))+m_nDescDistThresholdOffset+(m_oUnstableRegionMask.data[nPxIter]*UNSTAB_DESC_DIST_OFFSET); const size_t nCurrTotColorDistThreshold = nCurrColorDistThreshold*3; const size_t nCurrTotDescDistThreshold = nCurrDescDistThreshold*3; const size_t nCurrSCColorDistThreshold = nCurrTotColorDistThreshold/2; ushort anCurrInterDesc[3], anCurrIntraDesc[3]; const size_t anCurrIntraLBSPThresholds[3] = 
{m_anLBSPThreshold_8bitLUT[anCurrColor[0]],m_anLBSPThreshold_8bitLUT[anCurrColor[1]],m_anLBSPThreshold_8bitLUT[anCurrColor[2]]}; LBSP::computeRGBDescriptor(oInputImg,anCurrColor,nCurrImgCoord_X,nCurrImgCoord_Y,anCurrIntraLBSPThresholds,anCurrIntraDesc); m_oUnstableRegionMask.data[nPxIter] = ((*pfCurrDistThresholdFactor)>UNSTABLE_REG_RDIST_MIN || (*pfCurrMeanRawSegmRes_LT-*pfCurrMeanFinalSegmRes_LT)>UNSTABLE_REG_RATIO_MIN || (*pfCurrMeanRawSegmRes_ST-*pfCurrMeanFinalSegmRes_ST)>UNSTABLE_REG_RATIO_MIN)?1:0; size_t nGoodSamplesCount=0, nSampleIdx=0; while(nGoodSamplesCount<m_nRequiredBGSamples && nSampleIdx<m_nBGSamples) { const ushort* const anBGIntraDesc = (ushort*)(m_voBGDescSamples[nSampleIdx].data+nDescIterRGB); const uchar* const anBGColor = m_voBGColorSamples[nSampleIdx].data+nPxIterRGB; size_t nTotDescDist = 0; size_t nTotSumDist = 0; for(size_t c=0;c<3; ++c) { const size_t nColorDist = L1dist(anCurrColor[c],anBGColor[c]); if(nColorDist>nCurrSCColorDistThreshold) goto failedcheck3ch; const size_t nIntraDescDist = hdist(anCurrIntraDesc[c],anBGIntraDesc[c]); LBSP::computeSingleRGBDescriptor(oInputImg,anBGColor[c],nCurrImgCoord_X,nCurrImgCoord_Y,c,m_anLBSPThreshold_8bitLUT[anBGColor[c]],anCurrInterDesc[c]); const size_t nInterDescDist = hdist(anCurrInterDesc[c],anBGIntraDesc[c]); const size_t nDescDist = (nIntraDescDist+nInterDescDist)/2; const size_t nSumDist = std::min((nDescDist/2)*(s_nColorMaxDataRange_1ch/s_nDescMaxDataRange_1ch)+nColorDist,s_nColorMaxDataRange_1ch); if(nSumDist>nCurrSCColorDistThreshold) goto failedcheck3ch; nTotDescDist += nDescDist; nTotSumDist += nSumDist; } if(nTotDescDist>nCurrTotDescDistThreshold || nTotSumDist>nCurrTotColorDistThreshold) goto failedcheck3ch; if(nMinTotDescDist>nTotDescDist) nMinTotDescDist = nTotDescDist; if(nMinTotSumDist>nTotSumDist) nMinTotSumDist = nTotSumDist; nGoodSamplesCount++; failedcheck3ch: nSampleIdx++; } const float fNormalizedLastDist = 
((float)L1dist<3>(anLastColor,anCurrColor)/s_nColorMaxDataRange_3ch+(float)hdist<3>(anLastIntraDesc,anCurrIntraDesc)/s_nDescMaxDataRange_3ch)/2; *pfCurrMeanLastDist = (*pfCurrMeanLastDist)*(1.0f-fRollAvgFactor_ST) + fNormalizedLastDist*fRollAvgFactor_ST; if(nGoodSamplesCount<m_nRequiredBGSamples) { // == foreground const float fNormalizedMinDist = std::min(1.0f,((float)nMinTotSumDist/s_nColorMaxDataRange_3ch+(float)nMinTotDescDist/s_nDescMaxDataRange_3ch)/2 + (float)(m_nRequiredBGSamples-nGoodSamplesCount)/m_nRequiredBGSamples); *pfCurrMeanMinDist_LT = (*pfCurrMeanMinDist_LT)*(1.0f-fRollAvgFactor_LT) + fNormalizedMinDist*fRollAvgFactor_LT; *pfCurrMeanMinDist_ST = (*pfCurrMeanMinDist_ST)*(1.0f-fRollAvgFactor_ST) + fNormalizedMinDist*fRollAvgFactor_ST; *pfCurrMeanRawSegmRes_LT = (*pfCurrMeanRawSegmRes_LT)*(1.0f-fRollAvgFactor_LT) + fRollAvgFactor_LT; *pfCurrMeanRawSegmRes_ST = (*pfCurrMeanRawSegmRes_ST)*(1.0f-fRollAvgFactor_ST) + fRollAvgFactor_ST; oCurrFGMask.data[nPxIter] = UCHAR_MAX; if(m_nModelResetCooldown && (rand()%(size_t)FEEDBACK_T_LOWER)==0) { const size_t s_rand = rand()%m_nBGSamples; for(size_t c=0; c<3; ++c) { *((ushort*)(m_voBGDescSamples[s_rand].data+nDescIterRGB+2*c)) = anCurrIntraDesc[c]; *(m_voBGColorSamples[s_rand].data+nPxIterRGB+c) = anCurrColor[c]; } } } else { // == background const float fNormalizedMinDist = ((float)nMinTotSumDist/s_nColorMaxDataRange_3ch+(float)nMinTotDescDist/s_nDescMaxDataRange_3ch)/2; *pfCurrMeanMinDist_LT = (*pfCurrMeanMinDist_LT)*(1.0f-fRollAvgFactor_LT) + fNormalizedMinDist*fRollAvgFactor_LT; *pfCurrMeanMinDist_ST = (*pfCurrMeanMinDist_ST)*(1.0f-fRollAvgFactor_ST) + fNormalizedMinDist*fRollAvgFactor_ST; *pfCurrMeanRawSegmRes_LT = (*pfCurrMeanRawSegmRes_LT)*(1.0f-fRollAvgFactor_LT); *pfCurrMeanRawSegmRes_ST = (*pfCurrMeanRawSegmRes_ST)*(1.0f-fRollAvgFactor_ST); const size_t nLearningRate = learningRateOverride>0?(size_t)ceil(learningRateOverride):(size_t)ceil(*pfCurrLearningRate); if((rand()%nLearningRate)==0) { const 
size_t s_rand = rand()%m_nBGSamples; for(size_t c=0; c<3; ++c) { *((ushort*)(m_voBGDescSamples[s_rand].data+nDescIterRGB+2*c)) = anCurrIntraDesc[c]; *(m_voBGColorSamples[s_rand].data+nPxIterRGB+c) = anCurrColor[c]; } } int nSampleImgCoord_Y, nSampleImgCoord_X; const bool bCurrUsing3x3Spread = m_bUse3x3Spread && !m_oUnstableRegionMask.data[nPxIter]; if(bCurrUsing3x3Spread) getRandNeighborPosition_3x3(nSampleImgCoord_X,nSampleImgCoord_Y,nCurrImgCoord_X,nCurrImgCoord_Y,LBSP::PATCH_SIZE/2,m_oImgSize); else getRandNeighborPosition_5x5(nSampleImgCoord_X,nSampleImgCoord_Y,nCurrImgCoord_X,nCurrImgCoord_Y,LBSP::PATCH_SIZE/2,m_oImgSize); const size_t n_rand = rand(); const size_t idx_rand_uchar = m_oImgSize.width*nSampleImgCoord_Y + nSampleImgCoord_X; const size_t idx_rand_flt32 = idx_rand_uchar*4; const float fRandMeanLastDist = *((float*)(m_oMeanLastDistFrame.data+idx_rand_flt32)); const float fRandMeanRawSegmRes = *((float*)(m_oMeanRawSegmResFrame_ST.data+idx_rand_flt32)); if((n_rand%(bCurrUsing3x3Spread?nLearningRate:(nLearningRate/2+1)))==0 || (fRandMeanRawSegmRes>GHOSTDET_S_MIN && fRandMeanLastDist<GHOSTDET_D_MAX && (n_rand%((size_t)m_fCurrLearningRateLowerCap))==0)) { const size_t idx_rand_uchar_rgb = idx_rand_uchar*3; const size_t idx_rand_ushrt_rgb = idx_rand_uchar_rgb*2; const size_t s_rand = rand()%m_nBGSamples; for(size_t c=0; c<3; ++c) { *((ushort*)(m_voBGDescSamples[s_rand].data+idx_rand_ushrt_rgb+2*c)) = anCurrIntraDesc[c]; *(m_voBGColorSamples[s_rand].data+idx_rand_uchar_rgb+c) = anCurrColor[c]; } } } if(m_oLastFGMask.data[nPxIter] || (std::min(*pfCurrMeanMinDist_LT,*pfCurrMeanMinDist_ST)<UNSTABLE_REG_RATIO_MIN && oCurrFGMask.data[nPxIter])) { if((*pfCurrLearningRate)<m_fCurrLearningRateUpperCap) *pfCurrLearningRate += FEEDBACK_T_INCR/(std::max(*pfCurrMeanMinDist_LT,*pfCurrMeanMinDist_ST)*(*pfCurrVariationFactor)); } else if((*pfCurrLearningRate)>m_fCurrLearningRateLowerCap) *pfCurrLearningRate -= 
FEEDBACK_T_DECR*(*pfCurrVariationFactor)/std::max(*pfCurrMeanMinDist_LT,*pfCurrMeanMinDist_ST); if((*pfCurrLearningRate)<m_fCurrLearningRateLowerCap) *pfCurrLearningRate = m_fCurrLearningRateLowerCap; else if((*pfCurrLearningRate)>m_fCurrLearningRateUpperCap) *pfCurrLearningRate = m_fCurrLearningRateUpperCap; if(std::max(*pfCurrMeanMinDist_LT,*pfCurrMeanMinDist_ST)>UNSTABLE_REG_RATIO_MIN && m_oBlinksFrame.data[nPxIter]) (*pfCurrVariationFactor) += FEEDBACK_V_INCR; else if((*pfCurrVariationFactor)>FEEDBACK_V_DECR) { (*pfCurrVariationFactor) -= m_oLastFGMask.data[nPxIter]?FEEDBACK_V_DECR/4:m_oUnstableRegionMask.data[nPxIter]?FEEDBACK_V_DECR/2:FEEDBACK_V_DECR; if((*pfCurrVariationFactor)<FEEDBACK_V_DECR) (*pfCurrVariationFactor) = FEEDBACK_V_DECR; } if((*pfCurrDistThresholdFactor)<std::pow(1.0f+std::min(*pfCurrMeanMinDist_LT,*pfCurrMeanMinDist_ST)*2,2)) (*pfCurrDistThresholdFactor) += FEEDBACK_R_VAR*(*pfCurrVariationFactor-FEEDBACK_V_DECR); else { (*pfCurrDistThresholdFactor) -= FEEDBACK_R_VAR/(*pfCurrVariationFactor); if((*pfCurrDistThresholdFactor)<1.0f) (*pfCurrDistThresholdFactor) = 1.0f; } if(popcount<3>(anCurrIntraDesc)>=4) ++nNonZeroDescCount; for(size_t c=0; c<3; ++c) { anLastIntraDesc[c] = anCurrIntraDesc[c]; anLastColor[c] = anCurrColor[c]; } } } #if DISPLAY_SUBSENSE_DEBUG_INFO std::cout << std::endl; cv::Point dbgpt(nDebugCoordX,nDebugCoordY); cv::Mat oMeanMinDistFrameNormalized; m_oMeanMinDistFrame_ST.copyTo(oMeanMinDistFrameNormalized); cv::circle(oMeanMinDistFrameNormalized,dbgpt,5,cv::Scalar(1.0f)); cv::resize(oMeanMinDistFrameNormalized,oMeanMinDistFrameNormalized,DEFAULT_FRAME_SIZE); cv::imshow("d_min(x)",oMeanMinDistFrameNormalized); std::cout << std::fixed << std::setprecision(5) << " d_min(" << dbgpt << ") = " << m_oMeanMinDistFrame_ST.at<float>(dbgpt) << std::endl; cv::Mat oMeanLastDistFrameNormalized; m_oMeanLastDistFrame.copyTo(oMeanLastDistFrameNormalized); cv::circle(oMeanLastDistFrameNormalized,dbgpt,5,cv::Scalar(1.0f)); 
cv::resize(oMeanLastDistFrameNormalized,oMeanLastDistFrameNormalized,DEFAULT_FRAME_SIZE); cv::imshow("d_last(x)",oMeanLastDistFrameNormalized); std::cout << std::fixed << std::setprecision(5) << " d_last(" << dbgpt << ") = " << m_oMeanLastDistFrame.at<float>(dbgpt) << std::endl; cv::Mat oMeanRawSegmResFrameNormalized; m_oMeanRawSegmResFrame_ST.copyTo(oMeanRawSegmResFrameNormalized); cv::circle(oMeanRawSegmResFrameNormalized,dbgpt,5,cv::Scalar(1.0f)); cv::resize(oMeanRawSegmResFrameNormalized,oMeanRawSegmResFrameNormalized,DEFAULT_FRAME_SIZE); cv::imshow("s_avg(x)",oMeanRawSegmResFrameNormalized); std::cout << std::fixed << std::setprecision(5) << " s_avg(" << dbgpt << ") = " << m_oMeanRawSegmResFrame_ST.at<float>(dbgpt) << std::endl; cv::Mat oMeanFinalSegmResFrameNormalized; m_oMeanFinalSegmResFrame_ST.copyTo(oMeanFinalSegmResFrameNormalized); cv::circle(oMeanFinalSegmResFrameNormalized,dbgpt,5,cv::Scalar(1.0f)); cv::resize(oMeanFinalSegmResFrameNormalized,oMeanFinalSegmResFrameNormalized,DEFAULT_FRAME_SIZE); cv::imshow("z_avg(x)",oMeanFinalSegmResFrameNormalized); std::cout << std::fixed << std::setprecision(5) << " z_avg(" << dbgpt << ") = " << m_oMeanFinalSegmResFrame_ST.at<float>(dbgpt) << std::endl; cv::Mat oDistThresholdFrameNormalized; m_oDistThresholdFrame.convertTo(oDistThresholdFrameNormalized,CV_32FC1,0.25f,-0.25f); cv::circle(oDistThresholdFrameNormalized,dbgpt,5,cv::Scalar(1.0f)); cv::resize(oDistThresholdFrameNormalized,oDistThresholdFrameNormalized,DEFAULT_FRAME_SIZE); cv::imshow("r(x)",oDistThresholdFrameNormalized); std::cout << std::fixed << std::setprecision(5) << " r(" << dbgpt << ") = " << m_oDistThresholdFrame.at<float>(dbgpt) << std::endl; cv::Mat oVariationModulatorFrameNormalized; cv::normalize(m_oVariationModulatorFrame,oVariationModulatorFrameNormalized,0,255,cv::NORM_MINMAX,CV_8UC1); cv::circle(oVariationModulatorFrameNormalized,dbgpt,5,cv::Scalar(255)); 
cv::resize(oVariationModulatorFrameNormalized,oVariationModulatorFrameNormalized,DEFAULT_FRAME_SIZE); cv::imshow("v(x)",oVariationModulatorFrameNormalized); std::cout << std::fixed << std::setprecision(5) << " v(" << dbgpt << ") = " << m_oVariationModulatorFrame.at<float>(dbgpt) << std::endl; cv::Mat oUpdateRateFrameNormalized; m_oUpdateRateFrame.convertTo(oUpdateRateFrameNormalized,CV_32FC1,1.0f/FEEDBACK_T_UPPER,-FEEDBACK_T_LOWER/FEEDBACK_T_UPPER); cv::circle(oUpdateRateFrameNormalized,dbgpt,5,cv::Scalar(1.0f)); cv::resize(oUpdateRateFrameNormalized,oUpdateRateFrameNormalized,DEFAULT_FRAME_SIZE); cv::imshow("t(x)",oUpdateRateFrameNormalized); std::cout << std::fixed << std::setprecision(5) << " t(" << dbgpt << ") = " << m_oUpdateRateFrame.at<float>(dbgpt) << std::endl; #endif //DISPLAY_SUBSENSE_DEBUG_INFO cv::bitwise_xor(oCurrFGMask,m_oLastRawFGMask,m_oCurrRawFGBlinkMask); cv::bitwise_or(m_oCurrRawFGBlinkMask,m_oLastRawFGBlinkMask,m_oBlinksFrame); m_oCurrRawFGBlinkMask.copyTo(m_oLastRawFGBlinkMask); oCurrFGMask.copyTo(m_oLastRawFGMask); cv::morphologyEx(oCurrFGMask,m_oFGMask_PreFlood,cv::MORPH_CLOSE, m_defaultMorphologyKernel); m_oFGMask_PreFlood.copyTo(m_oFGMask_FloodedHoles); cv::floodFill(m_oFGMask_FloodedHoles,cv::Point(0,0),UCHAR_MAX); cv::bitwise_not(m_oFGMask_FloodedHoles,m_oFGMask_FloodedHoles); cv::erode(m_oFGMask_PreFlood,m_oFGMask_PreFlood,m_defaultMorphologyKernel,cv::Point(-1,-1),3); cv::bitwise_or(oCurrFGMask,m_oFGMask_FloodedHoles,oCurrFGMask); cv::bitwise_or(oCurrFGMask,m_oFGMask_PreFlood,oCurrFGMask); cv::medianBlur(oCurrFGMask,m_oLastFGMask,m_nMedianBlurKernelSize); cv::dilate(m_oLastFGMask,m_oLastFGMask_dilated,m_defaultMorphologyKernel,cv::Point(-1,-1),3); cv::bitwise_and(m_oBlinksFrame,m_oLastFGMask_dilated_inverted,m_oBlinksFrame); cv::bitwise_not(m_oLastFGMask_dilated,m_oLastFGMask_dilated_inverted); cv::bitwise_and(m_oBlinksFrame,m_oLastFGMask_dilated_inverted,m_oBlinksFrame); m_oLastFGMask.copyTo(oCurrFGMask); 
cv::addWeighted(m_oMeanFinalSegmResFrame_LT,(1.0f-fRollAvgFactor_LT),m_oLastFGMask,(1.0/UCHAR_MAX)*fRollAvgFactor_LT,0,m_oMeanFinalSegmResFrame_LT,CV_32F); cv::addWeighted(m_oMeanFinalSegmResFrame_ST,(1.0f-fRollAvgFactor_ST),m_oLastFGMask,(1.0/UCHAR_MAX)*fRollAvgFactor_ST,0,m_oMeanFinalSegmResFrame_ST,CV_32F); const float fCurrNonZeroDescRatio = (float)nNonZeroDescCount/m_nTotRelevantPxCount; if(fCurrNonZeroDescRatio<LBSPDESC_NONZERO_RATIO_MIN && m_fLastNonZeroDescRatio<LBSPDESC_NONZERO_RATIO_MIN) { for(size_t t=0; t<=UCHAR_MAX; ++t) if(m_anLBSPThreshold_8bitLUT[t]>cv::saturate_cast<uchar>(m_nLBSPThresholdOffset+ceil(t*m_fRelLBSPThreshold/4))) --m_anLBSPThreshold_8bitLUT[t]; } else if(fCurrNonZeroDescRatio>LBSPDESC_NONZERO_RATIO_MAX && m_fLastNonZeroDescRatio>LBSPDESC_NONZERO_RATIO_MAX) { for(size_t t=0; t<=UCHAR_MAX; ++t) if(m_anLBSPThreshold_8bitLUT[t]<cv::saturate_cast<uchar>(m_nLBSPThresholdOffset+UCHAR_MAX*m_fRelLBSPThreshold)) ++m_anLBSPThreshold_8bitLUT[t]; } m_fLastNonZeroDescRatio = fCurrNonZeroDescRatio; if(m_bLearningRateScalingEnabled) { cv::resize(oInputImg,m_oDownSampledFrame_MotionAnalysis,m_oDownSampledFrameSize,0,0,cv::INTER_AREA); cv::accumulateWeighted(m_oDownSampledFrame_MotionAnalysis,m_oMeanDownSampledLastDistFrame_LT,fRollAvgFactor_LT); cv::accumulateWeighted(m_oDownSampledFrame_MotionAnalysis,m_oMeanDownSampledLastDistFrame_ST,fRollAvgFactor_ST); size_t nTotColorDiff = 0; for(int i=0; i<m_oMeanDownSampledLastDistFrame_ST.rows; ++i) { const size_t idx1 = m_oMeanDownSampledLastDistFrame_ST.step.p[0]*i; for(int j=0; j<m_oMeanDownSampledLastDistFrame_ST.cols; ++j) { const size_t idx2 = idx1+m_oMeanDownSampledLastDistFrame_ST.step.p[1]*j; nTotColorDiff += (m_nImgChannels==1)? 
(size_t)fabs((*(float*)(m_oMeanDownSampledLastDistFrame_ST.data+idx2))-(*(float*)(m_oMeanDownSampledLastDistFrame_LT.data+idx2)))/2 : //(m_nImgChannels==3) std::max((size_t)fabs((*(float*)(m_oMeanDownSampledLastDistFrame_ST.data+idx2))-(*(float*)(m_oMeanDownSampledLastDistFrame_LT.data+idx2))), std::max((size_t)fabs((*(float*)(m_oMeanDownSampledLastDistFrame_ST.data+idx2+4))-(*(float*)(m_oMeanDownSampledLastDistFrame_LT.data+idx2+4))), (size_t)fabs((*(float*)(m_oMeanDownSampledLastDistFrame_ST.data+idx2+8))-(*(float*)(m_oMeanDownSampledLastDistFrame_LT.data+idx2+8))))); } } const float fCurrColorDiffRatio = (float)nTotColorDiff/(m_oMeanDownSampledLastDistFrame_ST.rows*m_oMeanDownSampledLastDistFrame_ST.cols); if(m_bAutoModelResetEnabled) { if(m_nFramesSinceLastReset>1000) m_bAutoModelResetEnabled = false; else if(fCurrColorDiffRatio>=FRAMELEVEL_MIN_COLOR_DIFF_THRESHOLD && m_nModelResetCooldown==0) { m_nFramesSinceLastReset = 0; refreshModel(0.1f); // reset 10% of the bg model m_nModelResetCooldown = m_nSamplesForMovingAvgs/4; m_oUpdateRateFrame = cv::Scalar(1.0f); } else ++m_nFramesSinceLastReset; } else if(fCurrColorDiffRatio>=FRAMELEVEL_MIN_COLOR_DIFF_THRESHOLD*2) { m_nFramesSinceLastReset = 0; m_bAutoModelResetEnabled = true; } if(fCurrColorDiffRatio>=FRAMELEVEL_MIN_COLOR_DIFF_THRESHOLD/2) { m_fCurrLearningRateLowerCap = (float)std::max((int)FEEDBACK_T_LOWER>>(int)(fCurrColorDiffRatio/2),1); m_fCurrLearningRateUpperCap = (float)std::max((int)FEEDBACK_T_UPPER>>(int)(fCurrColorDiffRatio/2),1); } else { m_fCurrLearningRateLowerCap = FEEDBACK_T_LOWER; m_fCurrLearningRateUpperCap = FEEDBACK_T_UPPER; } if(m_nModelResetCooldown>0) --m_nModelResetCooldown; } } void BackgroundSubtractorSuBSENSE::getBackgroundImage(cv::OutputArray backgroundImage) const { CV_Assert(m_bInitialized); cv::Mat oAvgBGImg = cv::Mat::zeros(m_oImgSize,CV_32FC((int)m_nImgChannels)); for(size_t s=0; s<m_nBGSamples; ++s) { for(int y=0; y<m_oImgSize.height; ++y) { for(int x=0; x<m_oImgSize.width; 
++x) { const size_t idx_nimg = m_voBGColorSamples[s].step.p[0]*y + m_voBGColorSamples[s].step.p[1]*x; const size_t nFloatIter = idx_nimg*4; float* oAvgBgImgPtr = (float*)(oAvgBGImg.data+nFloatIter); const uchar* const oBGImgPtr = m_voBGColorSamples[s].data+idx_nimg; for(size_t c=0; c<m_nImgChannels; ++c) oAvgBgImgPtr[c] += ((float)oBGImgPtr[c])/m_nBGSamples; } } } oAvgBGImg.convertTo(backgroundImage,CV_8U); } void BackgroundSubtractorSuBSENSE::getBackgroundDescriptorsImage(cv::OutputArray backgroundDescImage) const { CV_Assert(LBSP::DESC_SIZE==2); CV_Assert(m_bInitialized); cv::Mat oAvgBGDesc = cv::Mat::zeros(m_oImgSize,CV_32FC((int)m_nImgChannels)); for(size_t n=0; n<m_voBGDescSamples.size(); ++n) { for(int y=0; y<m_oImgSize.height; ++y) { for(int x=0; x<m_oImgSize.width; ++x) { const size_t idx_ndesc = m_voBGDescSamples[n].step.p[0]*y + m_voBGDescSamples[n].step.p[1]*x; const size_t nFloatIter = idx_ndesc*2; float* oAvgBgDescPtr = (float*)(oAvgBGDesc.data+nFloatIter); const ushort* const oBGDescPtr = (ushort*)(m_voBGDescSamples[n].data+idx_ndesc); for(size_t c=0; c<m_nImgChannels; ++c) oAvgBgDescPtr[c] += ((float)oBGDescPtr[c])/m_voBGDescSamples.size(); } } } oAvgBGDesc.convertTo(backgroundDescImage,CV_16U); } void BackgroundSubtractorSuBSENSE::apply(cv::InputArray image, cv::OutputArray fgmask, double learningRateOverride) { (*this)(image, fgmask, learningRateOverride); } 结合代码分析描述符阈值是如何动态计算的?
09-23
# -- Original page text (non-code, preserved from the scraped source): --
# 我想尝试你的优化思路1:强化特征一致性惩罚添加特征级语义差异损失,避免边界穿越显著物体,
# 请根据我的原始损失函数代码进行修改

import torch
import torch.nn as nn
import torch.nn.functional as F


# def get_vgg19_FeatureMap(vgg_model, input_255, layer_index):
#     vgg_mean = torch.tensor([123.6800, 116.7790, 103.9390]).reshape((1, 3, 1, 1))
#     if torch.cuda.is_available():
#         vgg_mean = vgg_mean.cuda()
#     vgg_input = input_255 - vgg_mean
#     for i in range(0, layer_index + 1):
#         if i == 0:
#             x = vgg_model.features[0](vgg_input)
#         else:
#             x = vgg_model.features[i](x)
#     return x


def l_num_loss(img1, img2, l_num=1):
    """Mean Lp loss between two tensors.

    Args:
        img1: first input tensor.
        img2: second input tensor.
        l_num: order p of the Lp norm; defaults to 1 (mean absolute error).

    Returns:
        Scalar tensor: mean of |img1 - img2| ** l_num.
    """
    return torch.mean(torch.abs((img1 - img2) ** l_num))


def boundary_extraction(mask):
    """Extract a dilated boundary band of a binary mask.

    Args:
        mask: binary mask tensor of shape (N, 1, H, W); 1 inside, 0 outside.

    Returns:
        Tensor of the same shape: 1 on boundary pixels of `mask`, 0 elsewhere.
    """
    ones = torch.ones_like(mask)
    zeros = torch.zeros_like(mask)

    # 3x3 all-ones kernel used as a (non-trainable) dilation operator.
    in_channel = 1
    out_channel = 1
    kernel = [[1, 1, 1], [1, 1, 1], [1, 1, 1]]
    kernel = torch.FloatTensor(kernel).expand(out_channel, in_channel, 3, 3)
    if torch.cuda.is_available():
        kernel = kernel.cuda()
        ones = ones.cuda()
        zeros = zeros.cuda()
    weight = nn.Parameter(data=kernel, requires_grad=False)

    # Dilate the complement of the mask 7 times, re-binarizing after each
    # convolution (any pixel touched by a nonzero neighbor becomes 1).
    x = 1 - mask
    for _ in range(7):
        x = F.conv2d(x, weight, stride=1, padding=1)
        x = torch.where(x < 1, zeros, ones)

    # Intersect the dilated complement with the mask itself: pixels of the
    # mask that lie near its outside, i.e. its boundary band.
    return x * mask


# Boundary (seam) term — Eq. (10) of the paper.
def cal_boundary_term(inpu1_tesnor, inpu2_tesnor, mask1_tesnor, mask2_tesnor, stitched_image):
    """Boundary loss between each warped input and the stitched result.

    Args:
        inpu1_tesnor: first warped input image tensor.
        inpu2_tesnor: second warped input image tensor.
        mask1_tesnor: valid-region mask of the first input.
        mask2_tesnor: valid-region mask of the second input.
        stitched_image: composited (stitched) image tensor.

    Returns:
        (total_loss, boundary_mask1): the summed L1 boundary loss and the
        boundary mask of the first image.
    """
    # Each mask is binary; the product keeps only the part of mask_i that
    # overlaps the other image's boundary band.
    boundary_mask1 = mask1_tesnor * boundary_extraction(mask2_tesnor)
    boundary_mask2 = mask2_tesnor * boundary_extraction(mask1_tesnor)

    # L1 penalty restricted to the two boundary bands.
    loss1 = l_num_loss(inpu1_tesnor * boundary_mask1, stitched_image * boundary_mask1, 1)
    loss2 = l_num_loss(inpu2_tesnor * boundary_mask2, stitched_image * boundary_mask2, 1)

    return loss1 + loss2, boundary_mask1


# Smoothness term on the stitched image — penalizes mask transitions that
# cross high-gradient image content. Eq. (12) of the paper.
def cal_smooth_term_stitch(stitched_image, learned_mask1):
    """Smoothness loss of the learned mask w.r.t. the stitched image.

    Args:
        stitched_image: composited (stitched) image tensor.
        learned_mask1: learned blending mask.

    Returns:
        Scalar smoothness loss.
    """
    delta = 1
    # Mask variation along height and width.
    dh_mask = torch.abs(learned_mask1[:, :, 0:-1 * delta, :] - learned_mask1[:, :, delta:, :])
    dw_mask = torch.abs(learned_mask1[:, :, :, 0:-1 * delta] - learned_mask1[:, :, :, delta:])
    # Image gradient along the same directions.
    dh_diff_img = torch.abs(stitched_image[:, :, 0:-1 * delta, :] - stitched_image[:, :, delta:, :])
    dw_diff_img = torch.abs(stitched_image[:, :, :, 0:-1 * delta] - stitched_image[:, :, :, delta:])

    # Penalize mask changes that coincide with image edges.
    dh_pixel = dh_mask * dh_diff_img
    dw_pixel = dw_mask * dw_diff_img
    return torch.mean(dh_pixel) + torch.mean(dw_pixel)


# Smoothness term on the inter-image difference in the overlap region.
def cal_smooth_term_diff(img1, img2, learned_mask1, overlap):
    """Smoothness loss of the mask w.r.t. the squared content difference.

    Args:
        img1: first warped input image tensor.
        img2: second warped input image tensor.
        learned_mask1: learned blending mask.
        overlap: overlap-region mask (mask1 * mask2).

    Returns:
        Scalar smoothness loss.
    """
    # Squared difference between the two images, restricted to the overlap.
    diff_feature = torch.abs(img1 - img2) ** 2 * overlap

    delta = 1
    # Mask variation along height and width.
    dh_mask = torch.abs(learned_mask1[:, :, 0:-1 * delta, :] - learned_mask1[:, :, delta:, :])
    dw_mask = torch.abs(learned_mask1[:, :, :, 0:-1 * delta] - learned_mask1[:, :, :, delta:])
    # Note: the difference map is SUMMED (not differenced) over adjacent
    # pixels, as in the original implementation.
    dh_diff_img = torch.abs(diff_feature[:, :, 0:-1 * delta, :] + diff_feature[:, :, delta:, :])
    dw_diff_img = torch.abs(diff_feature[:, :, :, 0:-1 * delta] + diff_feature[:, :, :, delta:])

    dh_pixel = dh_mask * dh_diff_img
    dw_pixel = dw_mask * dw_diff_img
    return torch.mean(dh_pixel) + torch.mean(dw_pixel)
    # dh_zeros = torch.zeros_like(dh_pixel)
    # dw_zeros = torch.zeros_like(dw_pixel)
    # if torch.cuda.is_available():
    #     dh_zeros = dh_zeros.cuda()
    #     dw_zeros = dw_zeros.cuda()
    # loss = l_num_loss(dh_pixel, dh_zeros, 1) + l_num_loss(dw_pixel, dw_zeros, 1)
    # return loss, dh_pixel


class FeatureConsistencyLoss(nn.Module):
    """VGG feature-consistency loss focused on boundary regions.

    Penalizes feature-level differences between the predicted (stitched)
    image and a target image inside a boundary mask, using frozen VGG19
    features at several depths with geometrically decreasing weights.
    """

    def __init__(self, layer_indices=[3, 8, 15], alpha=0.5):
        super().__init__()
        # Frozen pre-trained VGG19 feature extractor.
        self.vgg = self._load_vgg().eval()
        for param in self.vgg.parameters():
            param.requires_grad = False
        self.layers = layer_indices
        self.alpha = alpha  # base weight of the first feature layer
        # VGG normalization constants (ImageNet, 0-255 scale).
        self.mean = torch.tensor([123.68, 116.779, 103.939]).view(1, 3, 1, 1)
        self.std = torch.tensor([58.395, 57.12, 57.375]).view(1, 3, 1, 1)
        if torch.cuda.is_available():
            self.mean = self.mean.cuda()
            self.std = self.std.cuda()

    def _load_vgg(self):
        """Load the pre-trained VGG19 convolutional trunk."""
        from torchvision.models import vgg19
        return vgg19(pretrained=True).features

    def forward(self, pred_img, target_img, boundary_mask):
        """Compute the boundary-weighted multi-scale feature loss.

        Args:
            pred_img: predicted image in [-1, 1].
            target_img: target image in [-1, 1].
            boundary_mask: (N, 1, H, W) mask selecting the boundary region.

        Returns:
            Scalar loss tensor.
        """
        # Map [-1, 1] inputs to the 0-255 VGG input space, then normalize.
        pred_img = (pred_img + 1) * 127.5
        target_img = (target_img + 1) * 127.5
        pred_img = (pred_img - self.mean) / self.std
        target_img = (target_img - self.mean) / self.std

        loss = 0
        # BUG FIX: the original did `self.alpha *= 0.5` inside forward(),
        # permanently halving the loss weight on every call, so the loss
        # silently decayed over training iterations. Use a local per-call
        # weight instead; deeper layers still get half the previous weight.
        layer_weight = self.alpha
        x_pred, x_target = pred_img, target_img
        for i, module in enumerate(self.vgg.children()):
            x_pred = module(x_pred)
            x_target = module(x_target)
            if i in self.layers:
                # BUG FIX: features past a pooling layer are spatially
                # downsampled, so expand_as() alone cannot match them.
                # Resize the mask to the feature resolution first (identity
                # when sizes already match, preserving prior behavior).
                mask = F.interpolate(boundary_mask, size=x_pred.shape[2:], mode='nearest')
                diff = F.mse_loss(
                    x_pred * mask.expand_as(x_pred),
                    x_target * mask.expand_as(x_target),
                    reduction='none',
                )
                loss += diff.mean() * layer_weight
                layer_weight *= 0.5  # deeper layers weighted less
        return loss

# -- Original page text (non-code, preserved from the scraped source): --
# 使函数接受的参数为源代码中的参数,并给出训练循环的使用代码
06-05
import argparse
import torch
from torch.utils.data import DataLoader
import os
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from network import build_model, Network
from dataset import TrainDataset
import glob
from loss import cal_boundary_term, cal_smooth_term_stitch, cal_smooth_term_diff, FeatureConsistencyLoss

# path of project
# FIX: __file__ was quoted as the string literal "__file__", which made
# dirname return "" and resolved every path against the CWD instead of the
# script location.
last_path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))

# path to save the summary files
SUMMARY_DIR = os.path.join(last_path, 'summary')
writer = SummaryWriter(log_dir=SUMMARY_DIR)

# path to save the model files
MODEL_DIR = os.path.join(last_path, 'model')

# create folders if they do not exist
if not os.path.exists(MODEL_DIR):
    os.makedirs(MODEL_DIR)
if not os.path.exists(SUMMARY_DIR):
    os.makedirs(SUMMARY_DIR)


def train(args):
    """Train the composition network with boundary, smoothness, and VGG
    feature-consistency losses.

    Args:
        args: parsed command-line arguments (gpu, batch_size, max_epoch,
              train_path).
    """
    # FIX: the environment variable is CUDA_DEVICE_ORDER; the original
    # misspelling CUDA_DEVICES_ORDER was a silent no-op.
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # dataset
    train_data = TrainDataset(data_path=args.train_path)
    train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size,
                              num_workers=4, shuffle=True, drop_last=True)

    # define the network
    net = Network()
    if torch.cuda.is_available():
        net = net.cuda()

    # VGG feature-consistency loss (VGG weights are frozen inside the module,
    # so it contributes no trainable parameters to the optimizer).
    feature_loss_fn = FeatureConsistencyLoss()
    if torch.cuda.is_available():
        feature_loss_fn = feature_loss_fn.cuda()

    # define the optimizer and learning rate
    optimizer = optim.Adam(net.parameters(), lr=1e-4,
                           betas=(0.9, 0.999), eps=1e-08)  # default 0.0001
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.97)

    # load the latest existing checkpoint if one exists
    ckpt_list = glob.glob(MODEL_DIR + "/*.pth")
    ckpt_list.sort()
    if len(ckpt_list) != 0:
        model_path = ckpt_list[-1]
        checkpoint = torch.load(model_path)
        net.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        glob_iter = checkpoint['glob_iter']
        scheduler.last_epoch = start_epoch
        print('load model from {}!'.format(model_path))
    else:
        start_epoch = 0
        glob_iter = 0
        print('training from scratch!')

    print("##################start training#######################")
    score_print_fre = 300

    for epoch in range(start_epoch, args.max_epoch):
        print("start epoch {}".format(epoch))
        net.train()
        sigma_total_loss = 0.
        sigma_boundary_loss = 0.
        sigma_smooth1_loss = 0.
        sigma_smooth2_loss = 0.
        sigma_feature_loss = 0.
        print(epoch, 'lr={:.6f}'.format(optimizer.state_dict()['param_groups'][0]['lr']))

        for i, batch_value in enumerate(train_loader):
            warp1_tensor = batch_value[0].float()
            warp2_tensor = batch_value[1].float()
            mask1_tensor = batch_value[2].float()
            mask2_tensor = batch_value[3].float()

            if torch.cuda.is_available():
                warp1_tensor = warp1_tensor.cuda()
                warp2_tensor = warp2_tensor.cuda()
                mask1_tensor = mask1_tensor.cuda()
                mask2_tensor = mask2_tensor.cuda()

            # forward, backward, update weights
            optimizer.zero_grad()
            batch_out = build_model(net, warp1_tensor, warp2_tensor,
                                    mask1_tensor, mask2_tensor)
            learned_mask1 = batch_out['learned_mask1']
            learned_mask2 = batch_out['learned_mask2']
            stitched_image = batch_out['stitched_image']

            # boundary term
            # Paper Eq. (9): alpha = 10000 constrains the seam endpoints,
            # beta = 1000 constrains the path.
            boundary_loss, boundary_mask1 = cal_boundary_term(
                warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor,
                stitched_image)
            boundary_loss = 10000 * boundary_loss

            # smooth term on the stitched image
            smooth1_loss = cal_smooth_term_stitch(stitched_image, learned_mask1)
            smooth1_loss = 1000 * smooth1_loss

            # smooth term on the difference image (overlap region)
            smooth2_loss = cal_smooth_term_diff(
                warp1_tensor, warp2_tensor, learned_mask1,
                mask1_tensor * mask2_tensor)
            smooth2_loss = 1000 * smooth2_loss

            # VGG feature-consistency term around the seam: compare the
            # stitched image against input 1 inside boundary_mask1.
            # NOTE(review): weight 100 is an initial guess — tune against the
            # other loss magnitudes.
            feature_loss = feature_loss_fn(stitched_image, warp1_tensor,
                                           boundary_mask1)
            feature_loss = 100 * feature_loss

            total_loss = boundary_loss + smooth1_loss + smooth2_loss + feature_loss
            total_loss.backward()

            # clip the gradient
            torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=3, norm_type=2)
            optimizer.step()

            sigma_boundary_loss += boundary_loss.item()
            sigma_smooth1_loss += smooth1_loss.item()
            sigma_smooth2_loss += smooth2_loss.item()
            sigma_feature_loss += feature_loss.item()
            sigma_total_loss += total_loss.item()
            print(glob_iter)

            # print loss etc.
            if i % score_print_fre == 0 and i != 0:
                average_total_loss = sigma_total_loss / score_print_fre
                average_boundary_loss = sigma_boundary_loss / score_print_fre
                average_smooth1_loss = sigma_smooth1_loss / score_print_fre
                average_smooth2_loss = sigma_smooth2_loss / score_print_fre
                average_feature_loss = sigma_feature_loss / score_print_fre
                sigma_total_loss = 0.
                sigma_boundary_loss = 0.
                sigma_smooth1_loss = 0.
                sigma_smooth2_loss = 0.
                sigma_feature_loss = 0.

                print("Training: Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}]/[{:0>3}] Total Loss: {:.4f} boundary loss: {:.4f} smooth loss: {:.4f} diff loss: {:.4f} feature loss: {:.4f} lr={:.8f}".format(
                    epoch + 1, args.max_epoch, i + 1, len(train_loader),
                    average_total_loss, average_boundary_loss,
                    average_smooth1_loss, average_smooth2_loss,
                    average_feature_loss,
                    optimizer.state_dict()['param_groups'][0]['lr']))

                # visualization
                writer.add_image("inpu1", (warp1_tensor[0] + 1.) / 2., glob_iter)
                writer.add_image("inpu2", (warp2_tensor[0] + 1.) / 2., glob_iter)
                writer.add_image("stitched_image", (stitched_image[0] + 1.) / 2., glob_iter)
                writer.add_image("learned_mask1", learned_mask1[0], glob_iter)
                writer.add_image("boundary_mask1", boundary_mask1[0], glob_iter)
                writer.add_scalar('lr', optimizer.state_dict()['param_groups'][0]['lr'], glob_iter)
                writer.add_scalar('total loss', average_total_loss, glob_iter)
                writer.add_scalar('average_boundary_loss', average_boundary_loss, glob_iter)
                writer.add_scalar('average_smooth1_loss', average_smooth1_loss, glob_iter)
                writer.add_scalar('average_smooth2_loss', average_smooth2_loss, glob_iter)
                writer.add_scalar('average_feature_loss', average_feature_loss, glob_iter)

            glob_iter += 1

        scheduler.step()

        # save model
        if ((epoch + 1) % 10 == 0 or (epoch + 1) == args.max_epoch):
            filename = 'epoch' + str(epoch + 1).zfill(3) + '_model.pth'
            model_save_path = os.path.join(MODEL_DIR, filename)
            state = {'model': net.state_dict(),
                     'optimizer': optimizer.state_dict(),
                     'epoch': epoch + 1,
                     "glob_iter": glob_iter}
            torch.save(state, model_save_path)

    print("##################end training#######################")


if __name__ == "__main__":
    print('<==================== setting arguments ===================>\n')

    # nl: create the argument parser
    parser = argparse.ArgumentParser()

    # nl: add arguments
    parser.add_argument('--gpu', type=str, default='0')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--max_epoch', type=int, default=50)
    parser.add_argument('--train_path', type=str,
                        default=r'D:\keti\UDIS2-main\date\UDIS-D\composition2')

    # nl: parse the arguments
    args = parser.parse_args()
    print(args)

    print('<==================== jump into training function ===================>\n')
    # nl: train
    train(args)

# NOTE: the original page here duplicated the full FeatureConsistencyLoss
# class inline (it already lives in loss.py and is imported above) and asked
# to wire it into the training loop — that integration is done in train().
06-05
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值