Bug 10190002: POOR PERFORMANCE WITH /*+ OPT_PARAM('_OPTIMIZER_PUSH_PRED_COST_BASED','TRUE')*/

Using two concrete cases, this post shows how adjusting the hidden parameter _optimizer_extend_jppd_view_types can fix a bad SQL execution plan and dramatically improve query performance.

Recently I’ve run into many long-running SQL issues caused by the hidden parameter _optimizer_extend_jppd_view_types being set to true.
I opened an SR (3-3225436011) with Oracle, and Oracle confirmed that it’s a bug:

Bug 10190002: POOR PERFORMANCE WITH /*+ OPT_PARAM('_OPTIMIZER_PUSH_PRED_COST_BASED','TRUE')*/
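
Before looking at the repro, it's worth checking how the parameter is set on your instance. The following query is a minimal sketch; it reads the X$ parameter views, so it must be run as SYS (hidden parameters do not appear in V$PARAMETER):

SELECT a.ksppinm  AS parameter_name,
       b.ksppstvl AS current_value,
       b.ksppstdf AS is_default
  FROM x$ksppi a, x$ksppcv b   -- parameter names joined to session values
 WHERE a.indx = b.indx
   AND a.ksppinm = '_optimizer_extend_jppd_view_types';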

Here is an example for this bug:

SQL:

CREATE TABLE robinson AS

SELECT PRMTN.BUS_UNIT_SKID,

PRMTN.ACCT_PRMTN_SKID,

PRMTN.PRMTN_SKID,

CASE WHEN (PRMTN.PRMTN_STTUS_CODE <> 'Completed' AND NVL(ACTVY.ACTL_MDA_LOR_AMT, 0 ) < NVL(ACTVY.ESTMT_MDA_LOR_AMT, 0))

THEN DECODE( ( PRMTN.FIXED_COST_ESTMT_AMT + PRMTN.VAR_COST_ESTMT_AMT + PRMTN.BASE_COST_AMT + PRMTN.ESTMT_SPND_NUM ), 0, 0, ( NVL(BASLN.ACTL_TMP_GIV - BASLN.BASLN_TMP_GIV, 0 ) - NVL(ACTVY.ESTMT_MDA_LOR_AMT, 0 ) ) / ( PRMTN.FIXED_COST_ESTMT_AMT + PRMTN.VAR_COST_ESTMT_AMT + PRMTN.BASE_COST_AMT + PRMTN.ESTMT_SPND_NUM ))

ELSE DECODE( ( PRMTN.ACTL_COST_TOT + PRMTN.ACTL_SPND_NUM ), 0, 0, ( NVL(BASLN.ACTL_TMP_GIV - BASLN.BASLN_TMP_GIV, 0 ) - NVL(ACTVY.ACTL_MDA_LOR_AMT, 0) ) / ( PRMTN.ACTL_COST_TOT + PRMTN.ACTL_SPND_NUM ))

END AS ACTL_EVENT_ROI, -- Actual Event ROI

/*DECODE((PRMTN.FIXED_COST_ESTMT_AMT + PRMTN.VAR_COST_ESTMT_AMT + PRMTN.BASE_COST_AMT + PRMTN.ESTMT_SPND_NUM), 0, 0, ( NVL(PRMTN.ESTMT_INCRM_NR_AMT, 0) - NVL(ACTVY.ESTMT_ROI_MDA_LOR_AMT, 0) ) / (PRMTN.FIXED_COST_ESTMT_AMT + PRMTN.VAR_COST_ESTMT_AMT + PRMTN.BASE_COST_AMT + PRMTN.ESTMT_SPND_NUM)) AS ESTMT_EVENT_ROI, -- Estimated Event ROI*/

(NVL(PRMTN_PROD.ACTL_TMP_NOS, 0 ) * ( 100 - PRMTN.PCT_CNBLN_NUM) / 100 - CASE WHEN (PRMTN.PRMTN_STTUS_CODE <> 'Completed' AND NVL(ACTVY.ACTL_MDA_LOR_AMT, 0 ) < NVL(ACTVY.ESTMT_MDA_LOR_AMT, 0 ))

THEN NVL(ACTVY.ESTMT_MDA_LOR_AMT, 0 ) ELSE NVL(ACTVY.ACTL_MDA_LOR_AMT, 0 ) END) AS ACTL_NOS_AMT, -- Actual NOS

(NVL(BASLN.ACTL_TMP_GIV - BASLN.BASLN_TMP_GIV, 0 ) - CASE WHEN (PRMTN.PRMTN_STTUS_CODE <> 'Completed' AND NVL(ACTVY.ACTL_MDA_LOR_AMT, 0 ) < NVL(ACTVY.ESTMT_MDA_LOR_AMT, 0 ))

THEN NVL(ACTVY.ESTMT_MDA_LOR_AMT, 0 ) ELSE NVL(ACTVY.ACTL_MDA_LOR_AMT, 0 ) END ) AS ACTL_INCRM_NOS_AMT, -- Actual Incremental NOS

(NVL(PRMTN_PROD.ESTMT_TMP_NOS , 0 ) * ( 100 - PRMTN.PCT_CNBLN_NUM) / 100 - NVL(ACTVY.ESTMT_MDA_LOR_AMT , 0 )) AS ESTMT_NOS_AMT, -- Estimated NOS

(NVL(PRMTN_PROD.ESTMT_TMP_INC_NOS , 0 ) * ( 100 - PRMTN.PCT_CNBLN_NUM ) / 100 - NVL(ACTVY.ESTMT_MDA_LOR_AMT , 0 )) AS ESTMT_INCRM_NOS_AMT, -- Estimated Incremental NOS

NVL(ACTVY.ESTMT_MDA_LOR_AMT , 0 ) AS ESTMT_MDA_LOR_AMT, -- Estimated MDA LOR

NVL(ACTVY.ACTL_MDA_LOR_AMT , 0 ) AS ACTL_MDA_LOR_AMT, -- Actual MDA LOR,

NVL(EPOS_FCT.ACTL_EPOS_SU_AMT , 0 ) AS ACTL_EPOS_SU_AMT, -- EPOS SU Amount

CASE WHEN (PRMTN.PRMTN_STTUS_CODE <> 'Completed' AND NVL(ACTVY.ACTL_MDA_LOR_AMT , 0 ) < NVL(ACTVY.ESTMT_MDA_LOR_AMT , 0 ))

THEN NVL(EPOS_BASLN.EPOS_ACTL_TMP_GIV - EPOS_BASLN.EPOS_BASLN_TMP_GIV , 0) - NVL(ACTVY.ESTMT_MDA_LOR_AMT , 0)

ELSE NVL(EPOS_BASLN.EPOS_ACTL_TMP_GIV - EPOS_BASLN.EPOS_BASLN_TMP_GIV , 0) - NVL(ACTVY.ACTL_MDA_LOR_AMT , 0)

END AS ACTL_EPOS_INCRM_NOS_AMT, -- Actual ePOS Incremental NOS

CASE WHEN (PRMTN.PRMTN_STTUS_CODE <> 'Completed' AND NVL(ACTVY.ACTL_MDA_LOR_AMT , 0 ) < NVL(ACTVY.ESTMT_MDA_LOR_AMT , 0 ))

THEN DECODE( ( PRMTN.FIXED_COST_ESTMT_AMT + PRMTN.VAR_COST_ESTMT_AMT + PRMTN.BASE_COST_AMT + PRMTN.ESTMT_SPND_NUM ), 0, 0, ( NVL(EPOS_BASLN.EPOS_ACTL_TMP_GIV - EPOS_BASLN.EPOS_BASLN_TMP_GIV , 0) - NVL(ACTVY.ESTMT_MDA_LOR_AMT,0) ) / ( PRMTN.FIXED_COST_ESTMT_AMT + PRMTN.VAR_COST_ESTMT_AMT + PRMTN.BASE_COST_AMT + PRMTN.ESTMT_SPND_NUM ))

ELSE DECODE( ( PRMTN.ACTL_COST_TOT + PRMTN.ACTL_SPND_NUM ), 0, 0, ( NVL(EPOS_BASLN.EPOS_ACTL_TMP_GIV - EPOS_BASLN.EPOS_BASLN_TMP_GIV , 0) - NVL(ACTVY.ACTL_MDA_LOR_AMT, 0) ) / ( PRMTN.ACTL_COST_TOT + PRMTN.ACTL_SPND_NUM ))

END AS ACTL_EPOS_EVENT_ROI, -- Actual ePOS Event ROI

------ Optima90, B010, Bean, Begin

( PRMTN.ESTMT_INCRM_NIV_AX_AMT - (PRMTN.FIXED_COST_ESTMT_AMT + PRMTN.VAR_COST_ESTMT_AMT - PRMTN.ESTMT_VAR_COST_ON_INVC_NUM) ) AS ESTMT_INCRM_NOS_NIV_AMT,

---Estimated Incremental Net Outside Sales by Net Invoice Value Amount

(DECODE((PRMTN.FIXED_COST_ESTMT_AMT + PRMTN.VAR_COST_ESTMT_AMT + PRMTN.BASE_COST_AMT + PRMTN.ESTMT_SPND_NUM),NULL,0,0,0,

( PRMTN.ESTMT_INCRM_NIV_AX_AMT - (PRMTN.FIXED_COST_ESTMT_AMT + PRMTN.VAR_COST_ESTMT_AMT - PRMTN.ESTMT_VAR_COST_ON_INVC_NUM)) / (PRMTN.FIXED_COST_ESTMT_AMT + PRMTN.VAR_COST_ESTMT_AMT + PRMTN.BASE_COST_AMT + PRMTN.ESTMT_SPND_NUM))) AS ESTMT_EVENT_ROI_NIV_RATE, ---Estimated Event Return On Investment by Net Invoice Value

CASE WHEN (PRMTN.PRMTN_STTUS_CODE <> 'Completed' AND

((((PRMTN.CALC_INDEX_NUM + PRMTN.ACTL_VAR_COST_NUM) - PRMTN.ACTL_VAR_COST_ON_INVC_NUM)) ----ACTUAL INCREMENTAL NIV - ACTUAL MDA NOT PAID ON INVOICE

< (((PRMTN.FIXED_COST_ESTMT_AMT + PRMTN.VAR_COST_ESTMT_AMT) - PRMTN.ESTMT_VAR_COST_ON_INVC_NUM )))) ----- ESTIMATED INCREMENTAL NIV - ESTIMATED MDA NOT PAID ON INVOICE

THEN ( NVL(BASLN.ACTL_TMP_NIV - BASLN.BASLN_TMP_NIV, 0 ) - NVL(((PRMTN.FIXED_COST_ESTMT_AMT + PRMTN.VAR_COST_ESTMT_AMT) - PRMTN.ESTMT_VAR_COST_ON_INVC_NUM ),0))

ELSE ( NVL(BASLN.ACTL_TMP_NIV - BASLN.BASLN_TMP_NIV, 0 ) - NVL(((PRMTN.CALC_INDEX_NUM + PRMTN.ACTL_VAR_COST_NUM) - PRMTN.ACTL_VAR_COST_ON_INVC_NUM),0))

END AS ACTL_INCRM_NOS_NIV_AMT, ---Actual Incremental Net Outside Sales by Net Invoice Value Amount

-----Actual Event Return On Investment by Net Invoice Value

CASE WHEN (PRMTN.PRMTN_STTUS_CODE <> 'Completed' AND

((((PRMTN.CALC_INDEX_NUM + PRMTN.ACTL_VAR_COST_NUM) - PRMTN.ACTL_VAR_COST_ON_INVC_NUM)) ----ACTUAL INCREMENTAL NIV - ACTUAL MDA NOT PAID ON INVOICE

< ((PRMTN.FIXED_COST_ESTMT_AMT + PRMTN.VAR_COST_ESTMT_AMT) - PRMTN.ESTMT_VAR_COST_ON_INVC_NUM ))) ----- ESTIMATED INCREMENTAL NIV - ESTIMATED MDA NOT PAID ON INVOICE

THEN DECODE((PRMTN.FIXED_COST_ESTMT_AMT + PRMTN.VAR_COST_ESTMT_AMT + PRMTN.BASE_COST_AMT + PRMTN.ESTMT_SPND_NUM),null,0,0,0,

( NVL(BASLN.ACTL_TMP_NIV - BASLN.BASLN_TMP_NIV, 0 ) - NVL(((PRMTN.FIXED_COST_ESTMT_AMT + PRMTN.VAR_COST_ESTMT_AMT) - PRMTN.ESTMT_VAR_COST_ON_INVC_NUM ),0)) /

(PRMTN.FIXED_COST_ESTMT_AMT + PRMTN.VAR_COST_ESTMT_AMT + PRMTN.BASE_COST_AMT + PRMTN.ESTMT_SPND_NUM))

ELSE DECODE(PRMTN.ACTL_COST_TOT + PRMTN.ACTL_SPND_NUM, null, 0, 0, 0, (( BASLN.ACTL_TMP_NIV - BASLN.BASLN_TMP_NIV) - ((PRMTN.CALC_INDEX_NUM + PRMTN.ACTL_VAR_COST_NUM) - PRMTN.ACTL_VAR_COST_ON_INVC_NUM))/(PRMTN.ACTL_COST_TOT + PRMTN.ACTL_SPND_NUM))

END AS ACTL_EVENT_ROI_NIV_RATE, --- Actual Event Return On Investment by Net Invoice Value

BASLN.ACTL_TMP_NIV AS ACTL_SHPMT_NIV_AMT,

BASLN.BASLN_TMP_NIV AS ACTL_SHPMT_BASLN_NIV_AMT

FROM (

SELECT BFCT.PRMTN_SKID,BFCT.ACCT_SKID,

SUM( BFCT.ACTL_GIV_AMT * (100 - NVL(BRAND_BASLN_FCT.TRADE_TERM_PCT,0)) / 100 ) AS ACTL_TMP_GIV,

SUM( BFCT.BASLN_GIV_AMT * (100 - NVL(BRAND_BASLN_FCT.TRADE_TERM_PCT,0)) / 100 ) AS BASLN_TMP_GIV,

------ Optima90, B010, Bean, Begin

SUM( BFCT.ACTL_NIV_AMT ) AS ACTL_TMP_NIV,

SUM( BFCT.BASLN_NIV_AMT ) AS BASLN_TMP_NIV

------ Optima90, B010, Bean, End

FROM ( -- Aggregate Actual GIV and Baseline GIV from Week to Month Level

SELECT BFCT.BUS_UNIT_SKID, BFCT.PROD_SKID, BFCT.ACCT_SKID, BFCT.PRMTN_SKID, D.MTH_SKID,

SUM(BFCT.ACTL_GIV_AMT) AS ACTL_GIV_AMT,

SUM(BFCT.BASLN_GIV_AMT) AS BASLN_GIV_AMT,

------ Optima90, B010, Bean, Begin

SUM(BFCT.ACTL_NIV_AMT) AS ACTL_NIV_AMT,

SUM(BFCT.BASLN_NIV_AMT) AS BASLN_NIV_AMT

------ Optima90, B010, Bean, End

FROM OPT_BASLN_FCT BFCT, OPT_CAL_MASTR_MV01 D

WHERE BFCT.WK_SKID = D.CAL_MASTR_SKID

GROUP BY BFCT.BUS_UNIT_SKID, BFCT.PROD_SKID, BFCT.ACCT_SKID, BFCT.PRMTN_SKID, D.MTH_SKID

)BFCT, OPT_BRAND_BASLN_FCT BRAND_BASLN_FCT, (SELECT DISTINCT PFCT.BASE_PRMTN_SKID AS PRMTN_SKID

FROM OPT_PRMTN_FCT PFCT

/*WHERE PFCT.REGN_CODE = 'AP'*/) PRMTN_FLT

WHERE BFCT.BUS_UNIT_SKID = BRAND_BASLN_FCT.BUS_UNIT_SKID(+)

AND BFCT.PROD_SKID = BRAND_BASLN_FCT.PROD_SKID(+)

AND BFCT.ACCT_SKID = BRAND_BASLN_FCT.PRMTN_ACCT_SKID(+)

AND BFCT.MTH_SKID = BRAND_BASLN_FCT.DATE_SKID(+)

AND BFCT.PRMTN_SKID = PRMTN_FLT.PRMTN_SKID

GROUP BY BFCT.PRMTN_SKID,BFCT.ACCT_SKID

) BASLN,

(

SELECT PFCT.BUS_UNIT_SKID,

PFCT.ACCT_PRMTN_SKID,

PFCT.BASE_PRMTN_SKID AS PRMTN_SKID,

PFCT.ACTL_SPND_NUM AS ACTL_SPND_NUM,

PFCT.CALC_INDEX_NUM + PFCT.ACTL_VAR_COST_NUM AS ACTL_COST_TOT,

PDIM.PRMTN_STTUS_CODE,

PFCT.VAR_COST_ESTMT_AMT AS VAR_COST_ESTMT_AMT,

PFCT.FIXED_COST_ESTMT_AMT AS FIXED_COST_ESTMT_AMT,

PFCT.BASE_COST_AMT AS BASE_COST_AMT,

PFCT.ESTMT_SPND_NUM AS ESTMT_SPND_NUM,

PFCT.ACTL_VAR_COST_NUM AS ACTL_VAR_COST_NUM,

PFCT.ESTMT_INCRM_NR_AMT,

PFCT.PCT_CNBLN_NUM,

-- Optima90, B010, Bean, Begin

PFCT.ESTMT_VAR_COST_ON_INVC_NUM,

PFCT.ACTL_VAR_COST_ON_INVC_NUM,

PFCT.CALC_INDEX_NUM,

PFCT.ESTMT_INCRM_NIV_AX_AMT

-- Optima90, B010, Bean, End

FROM OPT_PRMTN_FCT PFCT, OPT_PRMTN_DIM PDIM, (SELECT DISTINCT PFCT.BASE_PRMTN_SKID AS PRMTN_SKID

FROM OPT_PRMTN_FCT PFCT

/*WHERE PFCT.REGN_CODE = 'AP'*/) PRMTN_FLT

WHERE PDIM.PRMTN_SKID = PFCT.BASE_PRMTN_SKID

AND PDIM.PRMTN_SKID<>0

AND PFCT.BASE_PRMTN_SKID = PRMTN_FLT.PRMTN_SKID

) PRMTN,

(

SELECT PRMTN_PROD_FCT.PRMTN_SKID,

SUM(PRMTN_PROD_FCT.ACTL_GIV_AMT * ( 100 - PRMTN_PROD_FCT.TRADE_TERM_PCT) / 100) AS ACTL_TMP_NOS,

SUM(PRMTN_PROD_FCT.TOT_IN_GIV_AMT * ( 100 - PRMTN_PROD_FCT.TRADE_TERM_PCT) / 100) AS ESTMT_TMP_NOS, -- ESTMT_SPND_AMT -> TOT_IN_GIV_AMT, Changed by joy on Oct. 21, 2009

SUM(PRMTN_PROD_FCT.INCRM_IN_GIV_AMT * ( 100 - PRMTN_PROD_FCT.TRADE_TERM_PCT) / 100) AS ESTMT_TMP_INC_NOS

FROM OPT_PRMTN_PROD_FCT PRMTN_PROD_FCT, (SELECT DISTINCT PFCT.BASE_PRMTN_SKID AS PRMTN_SKID

FROM OPT_PRMTN_FCT PFCT

/*WHERE PFCT.REGN_CODE = 'AP'*/) PRMTN_FLT

WHERE PRMTN_PROD_FCT.PRMTN_SKID = PRMTN_FLT.PRMTN_SKID

GROUP BY PRMTN_PROD_FCT.PRMTN_SKID

) PRMTN_PROD,

(

SELECT OPT_ACTVY_FCT.PRMTN_SKID,

-- optima90, B019, Bean, begin

-- SUM(DECODE(SUBSTR(OPT_ACTVY_FDIM.COST_ELEM_CODE,1,2),'30', OPT_ACTVY_FCT.PRDCT_FIXED_COST_AMT + OPT_ACTVY_FCT.VAR_COST_ESTMT_AMT)) AS ESTMT_MDA_LOR_AMT, -- Estimated MDA LOR

-- SUM(DECODE(SUBSTR(OPT_ACTVY_FDIM.COST_ELEM_CODE,1,2),'30', OPT_ACTVY_FCT.PRDCT_FIXED_COST_AMT * ( 100 + OPT_ACTVY_FDIM.ACTVY_PCT_SPLIT_NUM) / 100 + OPT_ACTVY_FCT.VAR_COST_ESTMT_AMT)) AS ESTMT_ROI_MDA_LOR_AMT,

-- Estimated MDA LOR for Estimated Event ROI

-- SUM(DECODE(SUBSTR(OPT_ACTVY_FDIM.COST_ELEM_CODE,1,2),'30', OPT_ACTVY_FCT.CALC_INDEX_NUM + OPT_ACTVY_FCT.ACTL_VAR_COST_NUM)) AS ACTL_MDA_LOR_AMT -- Actual MDA LOR

--Bean, optima90, B019, remarked end

-- optima90, B019, Bean, begin recalculate Estimated MDA LOR, Actual MDA LOR to avoid cost double counting according to SRS B019

SUM(DECODE(SUBSTR(OPT_ACTVY_FDIM.COST_ELEM_CODE,1,2),'30',DECODE(OPT_ACTVY_FDIM.ACTVY_SPECL_PACK_IND , NULL, OPT_ACTVY_FCT.PRDCT_FIXED_COST_AMT + OPT_ACTVY_FCT.VAR_COST_ESTMT_AMT ,0))) AS ESTMT_MDA_LOR_AMT,

-- Estimated MDA LOR

-- optima90, B019, Bean, end

SUM(DECODE(SUBSTR(OPT_ACTVY_FDIM.COST_ELEM_CODE,1,2),'30', OPT_ACTVY_FCT.PRDCT_FIXED_COST_AMT * ( 100 + OPT_ACTVY_FDIM.ACTVY_PCT_SPLIT_NUM) / 100 + OPT_ACTVY_FCT.VAR_COST_ESTMT_AMT)) AS ESTMT_ROI_MDA_LOR_AMT,

-- Estimated MDA LOR for Estimated Event ROI

-- Optima90, B019, Bean Begin

SUM(DECODE(SUBSTR(OPT_ACTVY_FDIM.COST_ELEM_CODE,1,2),'30',DECODE(OPT_ACTVY_FDIM.ACTVY_SPECL_PACK_IND , NULL, OPT_ACTVY_FCT.CALC_INDEX_NUM + OPT_ACTVY_FCT.ACTL_VAR_COST_NUM, 0))) AS ACTL_MDA_LOR_AMT

-- Actual MDA LOR

-- Optima90, B019, Bean end

FROM OPT_ACTVY_FCT, OPT_ACTVY_FDIM, (SELECT DISTINCT PFCT.BASE_PRMTN_SKID AS PRMTN_SKID

FROM OPT_PRMTN_FCT PFCT

/*WHERE PFCT.REGN_CODE = 'AP'*/) PRMTN_FLT

WHERE OPT_ACTVY_FCT.ACTVY_SKID = OPT_ACTVY_FDIM.ACTVY_SKID

AND OPT_ACTVY_FCT.PRMTN_SKID = PRMTN_FLT.PRMTN_SKID

GROUP BY OPT_ACTVY_FCT.PRMTN_SKID

) ACTVY,

(

SELECT PRMTN_PROD_FCT.PRMTN_SKID AS PRMTN_SKID,

SUM(EPOS_FCT.VOL_SU) AS ACTL_EPOS_SU_AMT

FROM OPT_EPOS_FCT EPOS_FCT,

OPT_PRMTN_PROD_FCT PRMTN_PROD_FCT,

OPT_CAL_MASTR_MV01 CAL_MASTR,

OPT_PRMTN_FDIM PRMTN_FDIM,

/*

(SELECT ACCT_SKID,

CONNECT_BY_ROOT ACCT_SKID AS ROOT_ACCT_SKID,

ACCT_TYPE_DESC

FROM OPT_ACCT_FDIM

CONNECT BY PRIOR ACCT_ID = PARNT_ACCT_ID

) HIER,

*/

-- Type 2 Account denormalized dimension

OPT_ACCT_ASDN_TYPE2_DIM HIER, (SELECT DISTINCT PFCT.BASE_PRMTN_SKID AS PRMTN_SKID

FROM OPT_PRMTN_FCT PFCT

/*WHERE PFCT.REGN_CODE = 'AP'*/) PRMTN_FLT

WHERE EPOS_FCT.PROD_SKID = PRMTN_PROD_FCT.PROD_SKID

AND PRMTN_PROD_FCT.PRMTN_SKID = PRMTN_FDIM.PRMTN_SKID

AND EPOS_FCT.DATE_SKID = CAL_MASTR.CAL_MASTR_SKID

AND CAL_MASTR.DAY_DATE BETWEEN HIER.ASDN_EFF_START_DATE AND HIER.ASDN_EFF_END_DATE

AND CAL_MASTR.DAY_DATE BETWEEN PRMTN_FDIM.PGM_START_DATE AND PRMTN_FDIM.PGM_END_DATE

-- ACCT_SKID in EPOS_FCT is on lower level, so ACCT_SKID in OPT_ACCT_ASDN_TYPE2_DIM is used

AND EPOS_FCT.ACCT_SKID = HIER.ACCT_SKID

-- ACCT_SKID in PROMOTION dimension is on higher level, so ASSOC_ACCT_SKID in OPT_ACCT_ASDN_TYPE2_DIM is used

--AND HIER.ROOT_ACCT_SKID = PRMTN_FDIM.ACCT_SKID

AND HIER.ASSOC_ACCT_SKID = PRMTN_FDIM.ACCT_SKID

AND PRMTN_PROD_FCT.PRMTN_SKID = PRMTN_FLT.PRMTN_SKID

GROUP BY PRMTN_PROD_FCT.PRMTN_SKID

) EPOS_FCT,

(

SELECT EPOS_BFCT.PRMTN_SKID AS PRMTN_SKID,

SUM( EPOS_BFCT.ACTL_GIV_AMT * (100 - NVL(BRAND_BASLN_FCT.TRADE_TERM_PCT,0)) / 100 ) AS EPOS_ACTL_TMP_GIV,

SUM( EPOS_BFCT.BASLN_GIV_AMT * (100 - NVL(BRAND_BASLN_FCT.TRADE_TERM_PCT,0)) / 100 ) AS EPOS_BASLN_TMP_GIV

FROM ( -- Aggregate Actual GIV and Baseline GIV from Week to Month Level

SELECT EPOS_BFCT.BUS_UNIT_SKID, EPOS_BFCT.ACCT_SKID, EPOS_BFCT.PRMTN_SKID, EPOS_BFCT.PROD_SKID, D.MTH_SKID,

SUM(EPOS_BFCT.ACTL_GIV_AMT) AS ACTL_GIV_AMT,

SUM(EPOS_BFCT.BASLN_GIV_AMT) AS BASLN_GIV_AMT

FROM OPT_EPOS_POST_EVENT_BASLN_FCT EPOS_BFCT, OPT_CAL_MASTR_MV01 D

WHERE EPOS_BFCT.WK_SKID = D.CAL_MASTR_SKID

GROUP BY EPOS_BFCT.BUS_UNIT_SKID, EPOS_BFCT.ACCT_SKID, EPOS_BFCT.PRMTN_SKID, EPOS_BFCT.PROD_SKID, D.MTH_SKID

)EPOS_BFCT,

OPT_BRAND_BASLN_FCT BRAND_BASLN_FCT

WHERE EPOS_BFCT.BUS_UNIT_SKID = BRAND_BASLN_FCT.BUS_UNIT_SKID(+)

AND EPOS_BFCT.PROD_SKID = BRAND_BASLN_FCT.PROD_SKID(+)

AND EPOS_BFCT.ACCT_SKID = BRAND_BASLN_FCT.PRMTN_ACCT_SKID(+)

AND EPOS_BFCT.MTH_SKID = BRAND_BASLN_FCT.DATE_SKID(+)

-- AND EPOS_BFCT.PRMTN_SKID = PRMTN_FLT.PRMTN_SKID

AND EPOS_BFCT.PRMTN_SKID IN (SELECT DISTINCT PFCT.BASE_PRMTN_SKID AS PRMTN_SKID

FROM OPT_PRMTN_FCT PFCT

/*WHERE PFCT.REGN_CODE = 'AP'*/ )

GROUP BY EPOS_BFCT.PRMTN_SKID

) EPOS_BASLN

WHERE PRMTN.PRMTN_SKID = BASLN.PRMTN_SKID(+)

AND PRMTN.ACCT_PRMTN_SKID = BASLN.ACCT_SKID(+)

AND PRMTN.PRMTN_SKID = ACTVY.PRMTN_SKID(+)

AND PRMTN.PRMTN_SKID = PRMTN_PROD.PRMTN_SKID(+)

AND PRMTN.PRMTN_SKID = EPOS_FCT.PRMTN_SKID(+)

AND PRMTN.PRMTN_SKID = EPOS_BASLN.PRMTN_SKID(+);

Explain plan for the SQL; the execution plan is as below:

SQL> select * from table(dbms_xplan.display);

PLAN_TABLE_OUTPUT

---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Plan hash value: 392063744

---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

| Id | Operation | Name | Rows | Bytes | Cost (%CPU)| Time | Pstart| Pstop | TQ |IN-OUT| PQ Distrib |

---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

| 0 | CREATE TABLE STATEMENT | | 1 | 313 | 257K (1)| 00:13:27 | | | | | |

| 1 | LOAD AS SELECT | ROBINSON | | | | | | | | | |

| 2 | NESTED LOOPS | | | | | | | | | | |

| 3 | NESTED LOOPS | | 1 | 313 | 257K (1)| 00:13:27 | | | | | |

| 4 | NESTED LOOPS OUTER | | 1 | 297 | 257K (1)| 00:13:27 | | | | | |

| 5 | NESTED LOOPS OUTER | | 1 | 258 | 257K (1)| 00:13:26 | | | | | |

| 6 | NESTED LOOPS OUTER | | 1 | 232 | 257K (1)| 00:13:25 | | | | | |

| 7 | NESTED LOOPS OUTER | | 1 | 219 | 256K (1)| 00:13:23 | | | | | |

| 8 | NESTED LOOPS SEMI | | 1 | 193 | 256K (1)| 00:13:22 | | | | | |

| 9 | NESTED LOOPS OUTER | | 168K| 26M| 86690 (1)| 00:04:32 | | | | | |

| 10 | PARTITION LIST ALL | | 168K| 18M| 1279 (7)| 00:00:05 | 1 | 16 | | | |

|* 11 | TABLE ACCESS FULL | OPT_PRMTN_FCT | 168K| 18M| 1279 (7)| 00:00:05 | 1 | 16 | | | |

| 12 | VIEW PUSHED PREDICATE | | 1 | 52 | | | | | | | |

|* 13 | FILTER | | | | | | | | | | |

| 14 | SORT AGGREGATE | | 1 | 206 | | | | | | | |

|* 15 | FILTER | | | | | | | | | | |

| 16 | NESTED LOOPS OUTER | | 1 | 206 | 605 (7)| 00:00:02 | | | | | |

| 17 | NESTED LOOPS | | 1 | 182 | 56 (4)| 00:00:01 | | | | | |

| 18 | VIEW | | 1 | 13 | 19 (6)| 00:00:01 | | | | | |

| 19 | SORT UNIQUE | | 1 | 6 | 18 (6)| 00:00:01 | | | | | |

|* 20 | FILTER | | | | | | | | | | |

| 21 | PARTITION LIST ALL | | 1 | 6 | 17 (0)| 00:00:01 | 1 | 16 | | | |

|* 22 | BITMAP INDEX SINGLE VALUE | OPT_PRMTN_FCT_BX2 | 1 | 6 | 17 (0)| 00:00:01 | 1 | 16 | | | |

|* 23 | VIEW | | 1 | 169 | 37 (3)| 00:00:01 | | | | | |

| 24 | SORT GROUP BY | | 1 | 63 | 36 (3)| 00:00:01 | | | | | |

|* 25 | FILTER | | | | | | | | | | |

| 26 | NESTED LOOPS | | | | | | | | | | |

| 27 | NESTED LOOPS | | 1 | 63 | 35 (0)| 00:00:01 | | | | | |

| 28 | PARTITION LIST ALL | | 1 | 53 | 34 (0)| 00:00:01 | 1 | 16 | | | |

| 29 | TABLE ACCESS BY LOCAL INDEX ROWID | OPT_BASLN_FCT | 1 | 53 | 34 (0)| 00:00:01 | 1 | 16 | | | |

| 30 | BITMAP CONVERSION TO ROWIDS | | | | | | | | | | |

| 31 | BITMAP AND | | | | | | | | | | |

|* 32 | BITMAP INDEX SINGLE VALUE | OPT_BASLN_FCT_NX4 | | | | | 1 | 16 | | | |

|* 33 | BITMAP INDEX SINGLE VALUE | OPT_BASLN_FCT_NX3 | | | | | 1 | 16 | | | |

|* 34 | INDEX UNIQUE SCAN | OPT_CAL_MASTR_MV01_PK | 1 | | 0 (0)| 00:00:01 | | | | | |

| 35 | TABLE ACCESS BY INDEX ROWID | OPT_CAL_MASTR_MV01 | 1 | 10 | 1 (0)| 00:00:01 | | | | | |

| 36 | PX COORDINATOR | | | | | | | | | | |

| 37 | PX SEND QC (RANDOM) | :TQ10000 | 1 | 24 | 549 (7)| 00:00:02 | | | Q1,00 | P->S | QC (RAND) |

| 38 | PX BLOCK ITERATOR | | 1 | 24 | 549 (7)| 00:00:02 | KEY | KEY | Q1,00 | PCWC | |

|* 39 | TABLE ACCESS FULL | OPT_BRAND_BASLN_FCT | 1 | 24 | 549 (7)| 00:00:02 | KEY | KEY | Q1,00 | PCWP | |

| 40 | VIEW PUSHED PREDICATE | | 1 | 26 | | | | | | | |

|* 41 | FILTER | | | | | | | | | | |

| 42 | PARTITION LIST ALL | | 1 | 6 | 17 (0)| 00:00:01 | 1 | 16 | | | |

| 43 | BITMAP CONVERSION TO ROWIDS | | 1 | 6 | 17 (0)| 00:00:01 | | | | | |

|* 44 | BITMAP INDEX SINGLE VALUE | OPT_PRMTN_FCT_BX2 | | | | | 1 | 16 | | | |

| 45 | VIEW PUSHED PREDICATE | | 1 | 26 | | | | | | | |

|* 46 | FILTER | | | | | | | | | | |

| 47 | SORT AGGREGATE | | 1 | 147 | | | | | | | |

| 48 | NESTED LOOPS SEMI | | 1 | 147 | 569 (8)| 00:00:02 | | | | | |

| 49 | NESTED LOOPS OUTER | | 1 | 141 | 560 (8)| 00:00:02 | | | | | |

| 50 | VIEW | | 1 | 117 | 3 (34)| 00:00:01 | | | | | |

| 51 | SORT GROUP BY | | 1 | 101 | 2 (50)| 00:00:01 | | | | | |

|* 52 | FILTER | | | | | | | | | | |

| 53 | NESTED LOOPS | | | | | | | | | | |

| 54 | NESTED LOOPS | | 1 | 101 | 1 (0)| 00:00:01 | | | | | |

| 55 | PARTITION LIST ALL | | 1 | 91 | 1 (0)| 00:00:01 | 1 | 16 | | | |

| 56 | TABLE ACCESS BY LOCAL INDEX ROWID | OPT_EPOS_POST_EVENT_BASLN_FCT | 1 | 91 | 1 (0)| 00:00:01 | 1 | 16 | | | |

| 57 | BITMAP CONVERSION TO ROWIDS | | | | | | | | | | |

|* 58 | BITMAP INDEX SINGLE VALUE | OPT_EPOS_PST_EVNT_BSLN_FCT_BX4 | | | | | 1 | 16 | | | |

|* 59 | INDEX UNIQUE SCAN | OPT_CAL_MASTR_MV01_PK | 1 | | 0 (0)| 00:00:01 | | | | | |

| 60 | TABLE ACCESS BY INDEX ROWID | OPT_CAL_MASTR_MV01 | 1 | 10 | 0 (0)| 00:00:01 | | | | | |

| 61 | PX COORDINATOR | | | | | | | | | | |

| 62 | PX SEND QC (RANDOM) | :TQ20000 | 1 | 24 | 557 (8)| 00:00:02 | | | Q2,00 | P->S | QC (RAND) |

| 63 | PX BLOCK ITERATOR | | 1 | 24 | 557 (8)| 00:00:02 | KEY | KEY | Q2,00 | PCWC | |

|* 64 | TABLE ACCESS FULL | OPT_BRAND_BASLN_FCT | 1 | 24 | 557 (8)| 00:00:02 | KEY | KEY | Q2,00 | PCWP | |

| 65 | PARTITION LIST ALL | | 1 | 6 | 569 (8)| 00:00:02 | 1 | 16 | | | |

| 66 | BITMAP CONVERSION TO ROWIDS | | 1 | 6 | 569 (8)| 00:00:02 | | | | | |

|* 67 | BITMAP INDEX SINGLE VALUE | OPT_PRMTN_FCT_BX2 | | | | | 1 | 16 | | | |

| 68 | VIEW PUSHED PREDICATE | | 1 | 13 | | | | | | | |

|* 69 | FILTER | | | | | | | | | | |

| 70 | SORT AGGREGATE | | 1 | 135 | | | | | | | |

|* 71 | PX COORDINATOR | | | | | | | | | | |

| 72 | PX SEND QC (RANDOM) | :TQ30002 | 1 | 135 | | | | | Q3,02 | P->S | QC (RAND) |

| 73 | SORT AGGREGATE | | 1 | 135 | | | | | Q3,02 | PCWP | |

|* 74 | FILTER | | | | | | | | Q3,02 | PCWC | |

| 75 | NESTED LOOPS SEMI | | 1 | 135 | 655 (1)| 00:00:03 | | | Q3,02 | PCWP | |

| 76 | NESTED LOOPS | | 1 | 109 | 636 (1)| 00:00:02 | | | Q3,02 | PCWP | |

| 77 | NESTED LOOPS | | 1 | 81 | 634 (1)| 00:00:02 | | | Q3,02 | PCWP | |

|* 78 | HASH JOIN | | 3 | 204 | 633 (1)| 00:00:02 | | | Q3,02 | PCWP | |

|* 79 | HASH JOIN | | 1 | 46 | 334 (1)| 00:00:02 | | | Q3,02 | PCWP | |

| 80 | BUFFER SORT | | | | | | | | Q3,02 | PCWC | |

| 81 | PX RECEIVE | | 1 | 28 | 3 (0)| 00:00:01 | | | Q3,02 | PCWP | |

| 82 | PX SEND BROADCAST | :TQ30000 | 1 | 28 | 3 (0)| 00:00:01 | | | | S->P | BROADCAST |

| 83 | TABLE ACCESS BY GLOBAL INDEX ROWID| OPT_PRMTN_FDIM | 1 | 28 | 3 (0)| 00:00:01 | ROWID | ROWID | | | |

|* 84 | INDEX RANGE SCAN | OPT_PRMTN_FDIM_PK | 1 | | 2 (0)| 00:00:01 | | | | | |

| 85 | PX PARTITION LIST ALL | | 912 | 16416 | 330 (1)| 00:00:02 | 1 | 16 | Q3,02 | PCWC | |

| 86 | TABLE ACCESS BY LOCAL INDEX ROWID | OPT_PRMTN_PROD_FCT | 912 | 16416 | 330 (1)| 00:00:02 | 1 | 16 | Q3,02 | PCWP | |

| 87 | BITMAP CONVERSION TO ROWIDS | | | | | | | | Q3,02 | PCWP | |

|* 88 | BITMAP INDEX SINGLE VALUE | OPT_PRMTN_PROD_FCT_BX4 | | | | | 1 | 16 | Q3,02 | PCWP | |

| 89 | BUFFER SORT | | | | | | | | Q3,02 | PCWC | |

| 90 | PX RECEIVE | | 3039 | 66858 | 298 (2)| 00:00:01 | | | Q3,02 | PCWP | |

| 91 | PX SEND BROADCAST | :TQ30001 | 3039 | 66858 | 298 (2)| 00:00:01 | | | | S->P | BROADCAST |

| 92 | PARTITION RANGE ALL | | 3039 | 66858 | 298 (2)| 00:00:01 | 1 | 21 | | | |

| 93 | PARTITION LIST ALL | | 3039 | 66858 | 298 (2)| 00:00:01 | 1 | 5 | | | |

| 94 | TABLE ACCESS FULL | OPT_EPOS_FCT | 3039 | 66858 | 298 (2)| 00:00:01 | 1 | 105 | | | |

|* 95 | TABLE ACCESS BY INDEX ROWID | OPT_CAL_MASTR_MV01 | 1 | 13 | 1 (0)| 00:00:01 | | | Q3,02 | PCWP | |

|* 96 | INDEX UNIQUE SCAN | OPT_CAL_MASTR_MV01_PK | 1 | | 0 (0)| 00:00:01 | | | Q3,02 | PCWP | |

|* 97 | TABLE ACCESS BY GLOBAL INDEX ROWID | OPT_ACCT_ASDN_TYPE2_DIM | 1 | 28 | 3 (0)| 00:00:01 | ROWID | ROWID | Q3,02 | PCWP | |

|* 98 | INDEX RANGE SCAN | OPT_ACCT_ASDN_TYPE2_DIM_PK | 1 | | 2 (0)| 00:00:01 | | | Q3,02 | PCWP | |

| 99 | VIEW PUSHED PREDICATE | | 1 | 26 | | | | | Q3,02 | PCWP | |

|*100 | FILTER | | | | | | | | Q3,02 | PCWP | |

| 101 | PARTITION LIST ALL | | 1 | 6 | 17 (0)| 00:00:01 | 1 | 16 | Q3,02 | PCWP | |

| 102 | BITMAP CONVERSION TO ROWIDS | | 1 | 6 | 17 (0)| 00:00:01 | | | Q3,02 | PCWP | |

|*103 | BITMAP INDEX SINGLE VALUE | OPT_PRMTN_FCT_BX2 | | | | | 1 | 16 | Q3,02 | PCWP | |

| 104 | VIEW PUSHED PREDICATE | | 1 | 26 | | | | | | | |

|*105 | FILTER | | | | | | | | | | |

| 106 | SORT AGGREGATE | | 1 | 58 | | | | | | | |

|*107 | FILTER | | | | | | | | | | |

| 108 | NESTED LOOPS | | | | | | | | | | |

| 109 | NESTED LOOPS | | 1 | 58 | 39 (6)| 00:00:01 | | | | | |

|*110 | HASH JOIN | | 1 | 39 | 38 (6)| 00:00:01 | | | | | |

| 111 | VIEW | | 1 | 13 | 19 (6)| 00:00:01 | | | | | |

| 112 | SORT UNIQUE | | 1 | 6 | 18 (6)| 00:00:01 | | | | | |

|*113 | FILTER | | | | | | | | | | |

| 114 | PARTITION LIST ALL | | 1 | 6 | 17 (0)| 00:00:01 | 1 | 16 | | | |

|*115 | BITMAP INDEX SINGLE VALUE | OPT_PRMTN_FCT_BX2 | 1 | 6 | 17 (0)| 00:00:01 | 1 | 16 | | | |

| 116 | PARTITION LIST ALL | | 3 | 78 | 18 (0)| 00:00:01 | 1 | 16 | | | |

| 117 | TABLE ACCESS BY LOCAL INDEX ROWID | OPT_ACTVY_FCT | 3 | 78 | 18 (0)| 00:00:01 | 1 | 16 | | | |

| 118 | BITMAP CONVERSION TO ROWIDS | | | | | | | | | | |

|*119 | BITMAP INDEX SINGLE VALUE | OPT_ACTVY_FCT_BX6 | | | | | 1 | 16 | | | |

|*120 | INDEX UNIQUE SCAN | OPT_ACTVY_FDIM_PK | 1 | | 0 (0)| 00:00:01 | | | | | |

| 121 | TABLE ACCESS BY GLOBAL INDEX ROWID | OPT_ACTVY_FDIM | 1 | 19 | 1 (0)| 00:00:01 | ROWID | ROWID | | | |

| 122 | VIEW PUSHED PREDICATE | | 1 | 39 | | | | | | | |

|*123 | FILTER | | | | | | | | | | |

| 124 | SORT AGGREGATE | | 1 | 30 | | | | | | | |

|*125 | PX COORDINATOR | | | | | | | | | | |

| 126 | PX SEND QC (RANDOM) | :TQ40001 | 1 | 30 | | | | | Q4,01 | P->S | QC (RAND) |

| 127 | SORT AGGREGATE | | 1 | 30 | | | | | Q4,01 | PCWP | |

|*128 | FILTER | | | | | | | | Q4,01 | PCWC | |

|*129 | HASH JOIN | | 1 | 30 | 350 (1)| 00:00:02 | | | Q4,01 | PCWP | |

| 130 | BUFFER SORT | | | | | | | | Q4,01 | PCWC | |

| 131 | PX RECEIVE | | 1 | 13 | 19 (6)| 00:00:01 | | | Q4,01 | PCWP | |

| 132 | PX SEND BROADCAST | :TQ40000 | 1 | 13 | 19 (6)| 00:00:01 | | | | S->P | BROADCAST |

| 133 | VIEW | | 1 | 13 | 19 (6)| 00:00:01 | | | | | |

| 134 | SORT UNIQUE | | 1 | 6 | 18 (6)| 00:00:01 | | | | | |

|*135 | FILTER | | | | | | | | | | |

| 136 | PARTITION LIST ALL | | 1 | 6 | 17 (0)| 00:00:01 | 1 | 16 | | | |

|*137 | BITMAP INDEX SINGLE VALUE | OPT_PRMTN_FCT_BX2 | 1 | 6 | 17 (0)| 00:00:01 | 1 | 16 | | | |

| 138 | PX PARTITION LIST ALL | | 912 | 15504 | 330 (1)| 00:00:02 | 1 | 16 | Q4,01 | PCWC | |

| 139 | TABLE ACCESS BY LOCAL INDEX ROWID | OPT_PRMTN_PROD_FCT | 912 | 15504 | 330 (1)| 00:00:02 | 1 | 16 | Q4,01 | PCWP | |

| 140 | BITMAP CONVERSION TO ROWIDS | | | | | | | | Q4,01 | PCWP | |

|*141 | BITMAP INDEX SINGLE VALUE | OPT_PRMTN_PROD_FCT_BX4 | | | | | 1 | 16 | Q4,01 | PCWP | |

|*142 | INDEX RANGE SCAN | OPT_PRMTN_DIM_PK | 1 | | 1 (0)| 00:00:01 | | | | | |

| 143 | TABLE ACCESS BY GLOBAL INDEX ROWID | OPT_PRMTN_DIM | 1 | 16 | 2 (0)| 00:00:01 | ROWID | ROWID | | | |

---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Predicate Information (identified by operation id):

---------------------------------------------------

11 - filter("PFCT"."BASE_PRMTN_SKID"<>0)

13 - filter(COUNT(*)>0)

15 - filter("PFCT"."BASE_PRMTN_SKID"<>0)

20 - filter(0<>"PFCT"."BASE_PRMTN_SKID")

22 - access("PFCT"."BASE_PRMTN_SKID"="PFCT"."BASE_PRMTN_SKID")

filter("PFCT"."BASE_PRMTN_SKID"<>0)

23 - filter("BFCT"."PRMTN_SKID"="PRMTN_FLT"."PRMTN_SKID")

25 - filter(0<>"PFCT"."BASE_PRMTN_SKID")

32 - access("BFCT"."PRMTN_SKID"="PFCT"."BASE_PRMTN_SKID")

filter("BFCT"."PRMTN_SKID"<>0)

33 - access("BFCT"."ACCT_SKID"="PFCT"."ACCT_PRMTN_SKID")

34 - access("BFCT"."WK_SKID"="D"."CAL_MASTR_SKID")

39 - filter("BRAND_BASLN_FCT"."PRMTN_ACCT_SKID"(+)="PFCT"."ACCT_PRMTN_SKID" AND "BFCT"."BUS_UNIT_SKID"="BRAND_BASLN_FCT"."BUS_UNIT_SKID"(+) AND

"BFCT"."PROD_SKID"="BRAND_BASLN_FCT"."PROD_SKID"(+) AND "BFCT"."ACCT_SKID"="BRAND_BASLN_FCT"."PRMTN_ACCT_SKID"(+) AND "BFCT"."MTH_SKID"="BRAND_BASLN_FCT"."DATE_SKID"(+))

41 - filter(0<>"PFCT"."BASE_PRMTN_SKID")

44 - access("PFCT"."BASE_PRMTN_SKID"="PFCT"."BASE_PRMTN_SKID")

filter("PFCT"."BASE_PRMTN_SKID"<>0)

46 - filter(COUNT(*)>0)

52 - filter(0<>"PFCT"."BASE_PRMTN_SKID")

58 - access("EPOS_BFCT"."PRMTN_SKID"="PFCT"."BASE_PRMTN_SKID")

filter("EPOS_BFCT"."PRMTN_SKID"<>0)

59 - access("EPOS_BFCT"."WK_SKID"="D"."CAL_MASTR_SKID")

64 - filter("EPOS_BFCT"."BUS_UNIT_SKID"="BRAND_BASLN_FCT"."BUS_UNIT_SKID"(+) AND "EPOS_BFCT"."PROD_SKID"="BRAND_BASLN_FCT"."PROD_SKID"(+) AND

"EPOS_BFCT"."ACCT_SKID"="BRAND_BASLN_FCT"."PRMTN_ACCT_SKID"(+) AND "EPOS_BFCT"."MTH_SKID"="BRAND_BASLN_FCT"."DATE_SKID"(+))

67 - access("EPOS_BFCT"."PRMTN_SKID"="PFCT"."BASE_PRMTN_SKID")

filter("PFCT"."BASE_PRMTN_SKID"="PFCT"."BASE_PRMTN_SKID")

69 - filter(COUNT(SYS_OP_CSR(SYS_OP_MSR(COUNT(*),SUM("EPOS_FCT"."VOL_SU")),0))>0)

71 - filter(0<>"PFCT"."BASE_PRMTN_SKID")

74 - filter(0<>"PFCT"."BASE_PRMTN_SKID")

78 - access("EPOS_FCT"."PROD_SKID"="PRMTN_PROD_FCT"."PROD_SKID")

79 - access("PRMTN_PROD_FCT"."PRMTN_SKID"="PRMTN_FDIM"."PRMTN_SKID")

84 - access("PRMTN_FDIM"."PRMTN_SKID"="PFCT"."BASE_PRMTN_SKID")

filter("PRMTN_FDIM"."PRMTN_SKID"<>0)

88 - access("PRMTN_PROD_FCT"."PRMTN_SKID"="PFCT"."BASE_PRMTN_SKID")

filter("PRMTN_PROD_FCT"."PRMTN_SKID"<>0)

95 - filter("CAL_MASTR"."DAY_DATE">="PRMTN_FDIM"."PGM_START_DATE" AND "CAL_MASTR"."DAY_DATE"<="PRMTN_FDIM"."PGM_END_DATE")

96 - access("EPOS_FCT"."DATE_SKID"="CAL_MASTR"."CAL_MASTR_SKID")

97 - filter("CAL_MASTR"."DAY_DATE"<="HIER"."ASDN_EFF_END_DATE")

98 - access("EPOS_FCT"."ACCT_SKID"="HIER"."ACCT_SKID" AND "HIER"."ASSOC_ACCT_SKID"="PRMTN_FDIM"."ACCT_SKID" AND "CAL_MASTR"."DAY_DATE">="HIER"."ASDN_EFF_START_DATE")

100 - filter(0<>"PRMTN_PROD_FCT"."PRMTN_SKID")

103 - access("PFCT"."BASE_PRMTN_SKID"="PRMTN_PROD_FCT"."PRMTN_SKID")

filter("PFCT"."BASE_PRMTN_SKID"<>0)

105 - filter(COUNT(*)>0)

107 - filter(0<>"PFCT"."BASE_PRMTN_SKID")

110 - access("OPT_ACTVY_FCT"."PRMTN_SKID"="PRMTN_FLT"."PRMTN_SKID")

113 - filter(0<>"PFCT"."BASE_PRMTN_SKID")

115 - access("PFCT"."BASE_PRMTN_SKID"="PFCT"."BASE_PRMTN_SKID")

filter("PFCT"."BASE_PRMTN_SKID"<>0)

119 - access("OPT_ACTVY_FCT"."PRMTN_SKID"="PFCT"."BASE_PRMTN_SKID")

filter("OPT_ACTVY_FCT"."PRMTN_SKID"<>0)

120 - access("OPT_ACTVY_FCT"."ACTVY_SKID"="OPT_ACTVY_FDIM"."ACTVY_SKID")

123 - filter(COUNT(SYS_OP_CSR(SYS_OP_MSR(COUNT(*),SUM("PRMTN_PROD_FCT"."INCRM_IN_GIV_AMT"*(100-"PRMTN_PROD_FCT"."TRADE_TERM_PCT")/100),SUM("PRMTN_PROD_FCT"."TOT_IN_GIV_AMT"*(

100-"PRMTN_PROD_FCT"."TRADE_TERM_PCT")/100),SUM("PRMTN_PROD_FCT"."ACTL_GIV_AMT"*(100-"PRMTN_PROD_FCT"."TRADE_TERM_PCT")/100)),0))>0)

125 - filter(0<>"PFCT"."BASE_PRMTN_SKID")

128 - filter(0<>"PFCT"."BASE_PRMTN_SKID")

129 - access("PRMTN_PROD_FCT"."PRMTN_SKID"="PRMTN_FLT"."PRMTN_SKID")

135 - filter(0<>"PFCT"."BASE_PRMTN_SKID")

137 - access("PFCT"."BASE_PRMTN_SKID"="PFCT"."BASE_PRMTN_SKID")

filter("PFCT"."BASE_PRMTN_SKID"<>0)

141 - access("PRMTN_PROD_FCT"."PRMTN_SKID"="PFCT"."BASE_PRMTN_SKID")

filter("PRMTN_PROD_FCT"."PRMTN_SKID"<>0)

142 - access("PDIM"."PRMTN_SKID"="PFCT"."BASE_PRMTN_SKID")

filter("PDIM"."PRMTN_SKID"<>0)

217 rows selected.

Elapsed: 00:00:08.32

This is the wrong execution plan; if this plan is used, the SQL runs for a very long time (16 hours).

Please notice that the CBO uses VIEW PUSHED PREDICATE. This transformation is controlled by the hidden parameter _optimizer_extend_jppd_view_types, a new CBO feature in 11g.
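
To confirm which transformation fired for a given statement, you can dump a CBO trace while it is parsed. This is only a sketch using the classic 10053 event; the join predicate pushdown (JPPD) section of the trace shows whether the pushdown was considered and accepted:

ALTER SESSION SET EVENTS '10053 trace name context forever, level 1';
EXPLAIN PLAN FOR SELECT ...;   -- re-parse the problem statement while tracing
ALTER SESSION SET EVENTS '10053 trace name context off';
-- the optimizer trace is written to the session's trace file

Now disable the feature at the session level and explain the statement again: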

SQL> alter session set "_optimizer_extend_jppd_view_types"=false;

Session altered.

Explain plan for the SQL again; the execution plan is now as below:

SQL> select * from table(dbms_xplan.display);

PLAN_TABLE_OUTPUT

---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Plan hash value: 85794673

-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

| Id | Operation | Name | Rows | Bytes |TempSpc| Cost (%CPU)| Time | Pstart| Pstop | TQ |IN-OUT| PQ Distrib |

-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

| 0 | CREATE TABLE STATEMENT | | 168K| 50M| | 274K (11)| 00:14:17 | | | | | |

| 1 | LOAD AS SELECT | ROBINSON | | | | | | | | | | |

| 2 | PX COORDINATOR | | | | | | | | | | | |

| 3 | PX SEND QC (RANDOM) | :TQ20023 | 168K| 50M| | 273K (11)| 00:14:16 | | | Q2,23 | P->S | QC (RAND) |

|* 4 | HASH JOIN | | 168K| 50M| | 273K (11)| 00:14:16 | | | Q2,23 | PCWP | |

| 5 | BUFFER SORT | | | | | | | | | Q2,23 | PCWC | |

| 6 | PX RECEIVE | | 173K| 2706K| | 4292 (2)| 00:00:14 | | | Q2,23 | PCWP | |

| 7 | PX SEND BROADCAST | :TQ20007 | 173K| 2706K| | 4292 (2)| 00:00:14 | | | | S->P | BROADCAST |

| 8 | PARTITION LIST ALL | | 173K| 2706K| | 4292 (2)| 00:00:14 | 1 | 16 | | | |

|* 9 | TABLE ACCESS FULL | OPT_PRMTN_DIM | 173K| 2706K| | 4292 (2)| 00:00:14 | 1 | 16 | | | |

|* 10 | HASH JOIN | | 168K| 47M| | 269K (11)| 00:14:02 | | | Q2,23 | PCWP | |

| 11 | BUFFER SORT | | | | | | | | | Q2,23 | PCWC | |

| 12 | PX RECEIVE | | 168K| 2145K| | 940 (6)| 00:00:03 | | | Q2,23 | PCWP | |

| 13 | PX SEND BROADCAST | :TQ20008 | 168K| 2145K| | 940 (6)| 00:00:03 | | | | S->P | BROADCAST |

| 14 | VIEW | | 168K| 2145K| | 940 (6)| 00:00:03 | | | | | |

| 15 | HASH UNIQUE | | 168K| 990K| 2000K| 891 (6)| 00:00:03 | | | | | |

| 16 | PARTITION LIST ALL | | 168K| 990K| | 570 (2)| 00:00:02 | 1 | 16 | | | |

| 17 | BITMAP CONVERSION TO ROWIDS | | 168K| 990K| | 570 (2)| 00:00:02 | | | | | |

|* 18 | BITMAP INDEX FAST FULL SCAN | OPT_PRMTN_FCT_BX2 | | | | | | 1 | 16 | | | |

|* 19 | HASH JOIN RIGHT OUTER | | 168K| 45M| | 268K (11)| 00:13:59 | | | Q2,23 | PCWP | |

| 20 | PX RECEIVE | | 718 | 56004 | | 22203 (8)| 00:01:10 | | | Q2,23 | PCWP | |

| 21 | PX SEND BROADCAST | :TQ20020 | 718 | 56004 | | 22203 (8)| 00:01:10 | | | Q2,20 | P->P | BROADCAST |

| 22 | VIEW | | 718 | 56004 | | 22203 (8)| 00:01:10 | | | Q2,20 | PCWP | |

| 23 | HASH GROUP BY | | 718 | 78262 | | 20827 (8)| 00:01:06 | | | Q2,20 | PCWP | |

| 24 | PX RECEIVE | | 718 | 78262 | | 20827 (8)| 00:01:06 | | | Q2,20 | PCWP | |

| 25 | PX SEND HASH | :TQ20017 | 718 | 78262 | | 20827 (8)| 00:01:06 | | | Q2,17 | P->P | HASH |

| 26 | HASH GROUP BY | | 718 | 78262 | | 20827 (8)| 00:01:06 | | | Q2,17 | PCWP | |

|* 27 | HASH JOIN OUTER | | 1226K| 127M| | 20646 (8)| 00:01:05 | | | Q2,17 | PCWP | |

| 28 | PX RECEIVE | | 1226K| 99M| | 11592 (6)| 00:00:37 | | | Q2,17 | PCWP | |

| 29 | PX SEND PARTITION (KEY) | :TQ20015 | 1226K| 99M| | 11592 (6)| 00:00:37 | | | Q2,15 | P->P | PART (KEY) |

|* 30 | HASH JOIN | | 1226K| 99M| | 11592 (6)| 00:00:37 | | | Q2,15 | PCWP | |

| 31 | BUFFER SORT | | | | | | | | | Q2,15 | PCWC | |

| 32 | PX RECEIVE | | 168K| 990K| | 940 (6)| 00:00:03 | | | Q2,15 | PCWP | |

| 33 | PX SEND BROADCAST | :TQ20004 | 168K| 990K| | 940 (6)| 00:00:03 | | | | S->P | BROADCAST |

| 34 | VIEW | | 168K| 990K| | 940 (6)| 00:00:03 | | | | | |

| 35 | HASH UNIQUE | | 168K| 990K| 2000K| 891 (6)| 00:00:03 | | | | | |

| 36 | PARTITION LIST ALL | | 168K| 990K| | 570 (2)| 00:00:02 | 1 | 16 | | | |

|* 37 | BITMAP INDEX FAST FULL SCAN | OPT_PRMTN_FCT_BX2 | 168K| 990K| | 570 (2)| 00:00:02 | 1 | 16 | | | |

| 38 | VIEW | | 1849K| 139M| | 10620 (6)| 00:00:34 | | | Q2,15 | PCWP | |

| 39 | HASH GROUP BY | | 1849K| 111M| 156M| 9400 (6)| 00:00:30 | | | Q2,15 | PCWP | |

| 40 | PX RECEIVE | | 1849K| 111M| | 1909 (11)| 00:00:06 | | | Q2,15 | PCWP | |

| 41 | PX SEND HASH | :TQ20013 | 1849K| 111M| | 1909 (11)| 00:00:06 | | | Q2,13 | P->P | HASH |

|* 42 | HASH JOIN | | 1849K| 111M| | 1909 (11)| 00:00:06 | | | Q2,13 | PCWP | |

| 43 | BUFFER SORT | | | | | | | | | Q2,13 | PCWC | |

| 44 | PX RECEIVE | | 37190 | 363K| | 39 (6)| 00:00:01 | | | Q2,13 | PCWP | |

| 45 | PX SEND BROADCAST | :TQ20002 | 37190 | 363K| | 39 (6)| 00:00:01 | | | | S->P | BROADCAST |

| 46 | TABLE ACCESS FULL | OPT_CAL_MASTR_MV01 | 37190 | 363K| | 39 (6)| 00:00:01 | | | | | |

| 47 | PX BLOCK ITERATOR | | 1849K| 93M| | 1840 (10)| 00:00:06 | 1 | 16 | Q2,13 | PCWC | |

|* 48 | TABLE ACCESS FULL | OPT_BASLN_FCT | 1849K| 93M| | 1840 (10)| 00:00:06 | 1 | 16 | Q2,13 | PCWP | |

| 49 | PX PARTITION LIST ALL | | 7927K| 181M| | 8906 (8)| 00:00:28 | 1 | 16 | Q2,17 | PCWC | |

| 50 | TABLE ACCESS FULL | OPT_BRAND_BASLN_FCT | 7927K| 181M| | 8906 (8)| 00:00:28 | 1 | 16 | Q2,17 | PCWP | |

|* 51 | HASH JOIN RIGHT OUTER | | 168K| 33M| | 246K (11)| 00:12:50 | | | Q2,23 | PCWP | |

| 52 | BUFFER SORT | | | | | | | | | Q2,23 | PCWC | |

| 53 | PX RECEIVE | | 1 | 39 | | 570 (8)| 00:00:02 | | | Q2,23 | PCWP | |

| 54 | PX SEND BROADCAST | :TQ20009 | 1 | 39 | | 570 (8)| 00:00:02 | | | | S->P | BROADCAST |

| 55 | VIEW | | 1 | 39 | | 570 (8)| 00:00:02 | | | | | |

| 56 | HASH GROUP BY | | 1 | 113 | | 570 (8)| 00:00:02 | | | | | |

| 57 | NESTED LOOPS SEMI | | 1 | 113 | | 570 (8)| 00:00:02 | | | | | |

| 58 | NESTED LOOPS OUTER | | 1 | 107 | | 561 (8)| 00:00:02 | | | | | |

| 59 | VIEW | | 1 | 83 | | 4 (25)| 00:00:01 | | | | | |

| 60 | HASH GROUP BY | | 1 | 101 | | 3 (34)| 00:00:01 | | | | | |

| 61 | NESTED LOOPS | | | | | | | | | | | |

| 62 | NESTED LOOPS | | 1 | 101 | | 2 (0)| 00:00:01 | | | | | |

| 63 | PARTITION LIST ALL | | 1 | 91 | | 2 (0)| 00:00:01 | 1 | 16 | | | |

|* 64 | TABLE ACCESS FULL | OPT_EPOS_POST_EVENT_BASLN_FCT | 1 | 91 | | 2 (0)| 00:00:01 | 1 | 16 | | | |

|* 65 | INDEX UNIQUE SCAN | OPT_CAL_MASTR_MV01_PK | 1 | | | 0 (0)| 00:00:01 | | | | | |

| 66 | TABLE ACCESS BY INDEX ROWID | OPT_CAL_MASTR_MV01 | 1 | 10 | | 0 (0)| 00:00:01 | | | | | |

| 67 | PX COORDINATOR | | | | | | | | | | | |

| 68 | PX SEND QC (RANDOM) | :TQ10000 | 1 | 24 | | 557 (8)| 00:00:02 | | | Q1,00 | P->S | QC (RAND) |

| 69 | PX BLOCK ITERATOR | | 1 | 24 | | 557 (8)| 00:00:02 | KEY | KEY | Q1,00 | PCWC | |

|* 70 | TABLE ACCESS FULL | OPT_BRAND_BASLN_FCT | 1 | 24 | | 557 (8)| 00:00:02 | KEY | KEY | Q1,00 | PCWP | |

| 71 | PARTITION LIST ALL | | 1 | 6 | | 570 (8)| 00:00:02 | 1 | 16 | | | |

| 72 | BITMAP CONVERSION TO ROWIDS | | 1 | 6 | | 570 (8)| 00:00:02 | | | | | |

|* 73 | BITMAP INDEX SINGLE VALUE | OPT_PRMTN_FCT_BX2 | | | | | | 1 | 16 | | | |

|* 74 | HASH JOIN RIGHT OUTER | | 168K| 26M| | 245K (11)| 00:12:48 | | | Q2,23 | PCWP | |

| 75 | PX RECEIVE | | 1 | 26 | | 110K (7)| 00:05:45 | | | Q2,23 | PCWP | |

| 76 | PX SEND BROADCAST | :TQ20021 | 1 | 26 | | 110K (7)| 00:05:45 | | | Q2,21 | P->P | BROADCAST |

| 77 | VIEW | | 1 | 26 | | 110K (7)| 00:05:45 | | | Q2,21 | PCWP | |

| 78 | HASH GROUP BY | | 1 | 19 | | 110K (7)| 00:05:45 | | | Q2,21 | PCWP | |

| 79 | PX RECEIVE | | 1 | 19 | | 110K (7)| 00:05:45 | | | Q2,21 | PCWP | |

| 80 | PX SEND HASH | :TQ20018 | 1 | 19 | | 110K (7)| 00:05:45 | | | Q2,18 | P->P | HASH |

| 81 | HASH GROUP BY | | 1 | 19 | | 110K (7)| 00:05:45 | | | Q2,18 | PCWP | |

| 82 | VIEW | VM_NWVW_1 | 614 | 11666 | | 110K (7)| 00:05:45 | | | Q2,18 | PCWP | |

| 83 | HASH UNIQUE | | 614 | 101K| | 110K (7)| 00:05:45 | | | Q2,18 | PCWP | |

| 84 | PX RECEIVE | | 614 | 101K| | 110K (7)| 00:05:45 | | | Q2,18 | PCWP | |

| 85 | PX SEND HASH | :TQ20016 | 614 | 101K| | 110K (7)| 00:05:45 | | | Q2,16 | P->P | HASH |

|* 86 | HASH JOIN | | 614 | 101K| | 110K (7)| 00:05:45 | | | Q2,16 | PCWP | |

|* 87 | HASH JOIN | | 614 | 97K| | 109K (7)| 00:05:43 | | | Q2,16 | PCWP | |

| 88 | PX RECEIVE | | 43328 | 5881K| | 9831 (3)| 00:00:31 | | | Q2,16 | PCWP | |

| 89 | PX SEND BROADCAST | :TQ20014 | 43328 | 5881K| | 9831 (3)| 00:00:31 | | | Q2,14 | P->P | BROADCAST |

|* 90 | HASH JOIN | | 43328 | 5881K| | 9831 (3)| 00:00:31 | | | Q2,14 | PCWP | |

| 91 | PX RECEIVE | | 6345 | 613K| | 5544 (3)| 00:00:18 | | | Q2,14 | PCWP | |

| 92 | PX SEND HASH | :TQ20012 | 6345 | 613K| | 5544 (3)| 00:00:18 | | | Q2,12 | P->P | HASH |

|* 93 | HASH JOIN | | 6345 | 613K| | 5544 (3)| 00:00:18 | | | Q2,12 | PCWP | |

|* 94 | HASH JOIN | | 21392 | 1545K| | 5502 (3)| 00:00:18 | | | Q2,12 | PCWP | |

| 95 | BUFFER SORT | | | | | | | | | Q2,12 | PCWC | |

| 96 | PX RECEIVE | | 3039 | 100K| | 298 (2)| 00:00:01 | | | Q2,12 | PCWP | |

| 97 | PX SEND BROADCAST | :TQ20000 | 3039 | 100K| | 298 (2)| 00:00:01 | | | | S->P | BROADCAST |

| 98 | PARTITION RANGE ALL | | 3039 | 100K| | 298 (2)| 00:00:01 | 1 | 21 | | | |

| 99 | PARTITION LIST ALL | | 3039 | 100K| | 298 (2)| 00:00:01 | 1 | 5 | | | |

| 100 | TABLE ACCESS FULL | OPT_EPOS_FCT | 3039 | 100K| | 298 (2)| 00:00:01 | 1 | 105 | | | |

| 101 | PX BLOCK ITERATOR | | 2043K| 77M| | 5173 (2)| 00:00:17 | 1 | 16 | Q2,12 | PCWC | |

| 102 | TABLE ACCESS FULL | OPT_ACCT_ASDN_TYPE2_DIM | 2043K| 77M| | 5173 (2)| 00:00:17 | 1 | 16 | Q2,12 | PCWP | |

| 103 | BUFFER SORT | | | | | | | | | Q2,12 | PCWC | |

| 104 | PX RECEIVE | | 37190 | 907K| | 40 (8)| 00:00:01 | | | Q2,12 | PCWP | |

| 105 | PX SEND BROADCAST | :TQ20001 | 37190 | 907K| | 40 (8)| 00:00:01 | | | | S->P | BROADCAST |

| 106 | TABLE ACCESS FULL | OPT_CAL_MASTR_MV01 | 37190 | 907K| | 40 (8)| 00:00:01 | | | | | |

| 107 | BUFFER SORT | | | | | | | | | Q2,14 | PCWC | |

| 108 | PX RECEIVE | | 173K| 6783K| | 4283 (2)| 00:00:14 | | | Q2,14 | PCWP | |

| 109 | PX SEND HASH | :TQ20003 | 173K| 6783K| | 4283 (2)| 00:00:14 | | | | S->P | HASH |

| 110 | PARTITION LIST ALL | | 173K| 6783K| | 4283 (2)| 00:00:14 | 1 | 16 | | | |

|*111 | TABLE ACCESS FULL | OPT_PRMTN_FDIM | 173K| 6783K| | 4283 (2)| 00:00:14 | 1 | 16 | | | |

| 112 | PX BLOCK ITERATOR | | 55M| 1258M| | 98856 (6)| 00:05:10 | 1 | 16 | Q2,16 | PCWC | |

|*113 | TABLE ACCESS FULL | OPT_PRMTN_PROD_FCT | 55M| 1258M| | 98856 (6)| 00:05:10 | 1 | 16 | Q2,16 | PCWP | |

| 114 | BUFFER SORT | | | | | | | | | Q2,16 | PCWC | |

| 115 | PX RECEIVE | | 168K| 990K| | 570 (2)| 00:00:02 | | | Q2,16 | PCWP | |

| 116 | PX SEND BROADCAST | :TQ20005 | 168K| 990K| | 570 (2)| 00:00:02 | | | | S->P | BROADCAST |

| 117 | PARTITION LIST ALL | | 168K| 990K| | 570 (2)| 00:00:02 | 1 | 16 | | | |

| 118 | BITMAP CONVERSION TO ROWIDS | | 168K| 990K| | 570 (2)| 00:00:02 | | | | | |

|*119 | BITMAP INDEX FAST FULL SCAN| OPT_PRMTN_FCT_BX2 | | | | | | 1 | 16 | | | |

|*120 | HASH JOIN RIGHT OUTER | | 168K| 22M| | 135K (15)| 00:07:04 | | | Q2,23 | PCWP | |

| 121 | BUFFER SORT | | | | | | | | | Q2,23 | PCWC | |

| 122 | PX RECEIVE | | 1 | 39 | | 6665 (6)| 00:00:21 | | | Q2,23 | PCWP | |

| 123 | PX SEND BROADCAST | :TQ20010 | 1 | 39 | | 6665 (6)| 00:00:21 | | | | S->P | BROADCAST |

| 124 | VIEW | | 1 | 39 | | 6665 (6)| 00:00:21 | | | | | |

| 125 | HASH GROUP BY | | 1 | 51 | | 6351 (7)| 00:00:20 | | | | | |

|*126 | HASH JOIN | | 291K| 14M| 8832K| 6278 (6)| 00:00:20 | | | | | |

| 127 | PARTITION LIST ALL | | 291K| 5405K| | 3641 (5)| 00:00:12 | 1 | 16 | | | |

| 128 | TABLE ACCESS FULL | OPT_ACTVY_FDIM | 291K| 5405K| | 3641 (5)| 00:00:12 | 1 | 16 | | | |

|*129 | HASH JOIN | | 291K| 9101K| 2976K| 2292 (6)| 00:00:08 | | | | | |

| 130 | VIEW | | 168K| 990K| | 940 (6)| 00:00:03 | | | | | |

| 131 | HASH UNIQUE | | 168K| 990K| 2000K| 891 (6)| 00:00:03 | | | | | |

| 132 | PARTITION LIST ALL | | 168K| 990K| | 570 (2)| 00:00:02 | 1 | 16 | | | |

|*133 | BITMAP INDEX FAST FULL SCAN | OPT_PRMTN_FCT_BX2 | 168K| 990K| | 570 (2)| 00:00:02 | 1 | 16 | | | |

| 134 | PARTITION LIST ALL | | 291K| 7395K| | 1124 (6)| 00:00:04 | 1 | 16 | | | |

|*135 | TABLE ACCESS FULL | OPT_ACTVY_FCT | 291K| 7395K| | 1124 (6)| 00:00:04 | 1 | 16 | | | |

|*136 | HASH JOIN RIGHT OUTER | | 168K| 16M| | 128K (16)| 00:06:43 | | | Q2,23 | PCWP | |

| 137 | PX RECEIVE | | 1 | 52 | | 127K (16)| 00:06:39 | | | Q2,23 | PCWP | |

| 138 | PX SEND HASH | :TQ20022 | 1 | 52 | | 127K (16)| 00:06:39 | | | Q2,22 | P->P | HASH |

| 139 | VIEW | | 1 | 52 | | 127K (16)| 00:06:39 | | | Q2,22 | PCWP | |

| 140 | HASH GROUP BY | | 1 | 23 | | 113K (18)| 00:05:54 | | | Q2,22 | PCWP | |

| 141 | PX RECEIVE | | 1 | 23 | | 113K (18)| 00:05:54 | | | Q2,22 | PCWP | |

| 142 | PX SEND HASH | :TQ20019 | 1 | 23 | | 113K (18)| 00:05:54 | | | Q2,19 | P->P | HASH |

| 143 | HASH GROUP BY | | 1 | 23 | | 113K (18)| 00:05:54 | | | Q2,19 | PCWP | |

|*144 | HASH JOIN | | 55M| 1206M| | 103K (9)| 00:05:23 | | | Q2,19 | PCWP | |

| 145 | BUFFER SORT | | | | | | | | | Q2,19 | PCWC | |

| 146 | PX RECEIVE | | 168K| 990K| | 940 (6)| 00:00:03 | | | Q2,19 | PCWP | |

| 147 | PX SEND BROADCAST | :TQ20006 | 168K| 990K| | 940 (6)| 00:00:03 | | | | S->P | BROADCAST |

| 148 | VIEW | | 168K| 990K| | 940 (6)| 00:00:03 | | | | | |

| 149 | HASH UNIQUE | | 168K| 990K| 2000K| 891 (6)| 00:00:03 | | | | | |

| 150 | PARTITION LIST ALL | | 168K| 990K| | 570 (2)| 00:00:02 | 1 | 16 | | | |

|*151 | BITMAP INDEX FAST FULL SCAN | OPT_PRMTN_FCT_BX2 | 168K| 990K| | 570 (2)| 00:00:02 | 1 | 16 | | | |

| 152 | PX BLOCK ITERATOR | | 55M| 891M| | 101K (9)| 00:05:17 | 1 | 16 | Q2,19 | PCWC | |

|*153 | TABLE ACCESS FULL | OPT_PRMTN_PROD_FCT | 55M| 891M| | 101K (9)| 00:05:17 | 1 | 16 | Q2,19 | PCWP | |

| 154 | BUFFER SORT | | | | | | | | | Q2,23 | PCWC | |

| 155 | PX RECEIVE | | 168K| 8085K| | 1279 (7)| 00:00:05 | | | Q2,23 | PCWP | |

| 156 | PX SEND HASH | :TQ20011 | 168K| 8085K| | 1279 (7)| 00:00:05 | | | | S->P | HASH |

| 157 | PARTITION LIST ALL | | 168K| 8085K| | 1279 (7)| 00:00:05 | 1 | 16 | | | |

|*158 | TABLE ACCESS FULL | OPT_PRMTN_FCT | 168K| 8085K| | 1279 (7)| 00:00:05 | 1 | 16 | | | |

-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Predicate Information (identified by operation id):

---------------------------------------------------

4 - access("PDIM"."PRMTN_SKID"="PFCT"."BASE_PRMTN_SKID")

9 - filter("PDIM"."PRMTN_SKID"<>0)

10 - access("PFCT"."BASE_PRMTN_SKID"="PRMTN_FLT"."PRMTN_SKID")

18 - filter("PFCT"."BASE_PRMTN_SKID"<>0)

19 - access("PFCT"."BASE_PRMTN_SKID"="BASLN"."PRMTN_SKID"(+) AND "PFCT"."ACCT_PRMTN_SKID"="BASLN"."ACCT_SKID"(+))

27 - access("BFCT"."BUS_UNIT_SKID"="BRAND_BASLN_FCT"."BUS_UNIT_SKID"(+) AND "BFCT"."PROD_SKID"="BRAND_BASLN_FCT"."PROD_SKID"(+) AND

"BFCT"."ACCT_SKID"="BRAND_BASLN_FCT"."PRMTN_ACCT_SKID"(+) AND "BFCT"."MTH_SKID"="BRAND_BASLN_FCT"."DATE_SKID"(+))

30 - access("BFCT"."PRMTN_SKID"="PRMTN_FLT"."PRMTN_SKID")

37 - filter("PFCT"."BASE_PRMTN_SKID"<>0)

42 - access("BFCT"."WK_SKID"="D"."CAL_MASTR_SKID")

48 - filter("BFCT"."PRMTN_SKID"<>0)

51 - access("PFCT"."BASE_PRMTN_SKID"="EPOS_BASLN"."PRMTN_SKID"(+))

64 - filter("EPOS_BFCT"."PRMTN_SKID"<>0)

65 - access("EPOS_BFCT"."WK_SKID"="D"."CAL_MASTR_SKID")

70 - filter("EPOS_BFCT"."BUS_UNIT_SKID"="BRAND_BASLN_FCT"."BUS_UNIT_SKID"(+) AND "EPOS_BFCT"."PROD_SKID"="BRAND_BASLN_FCT"."PROD_SKID"(+) AND

"EPOS_BFCT"."ACCT_SKID"="BRAND_BASLN_FCT"."PRMTN_ACCT_SKID"(+) AND "EPOS_BFCT"."MTH_SKID"="BRAND_BASLN_FCT"."DATE_SKID"(+))

73 - access("EPOS_BFCT"."PRMTN_SKID"="PFCT"."BASE_PRMTN_SKID")

filter("PFCT"."BASE_PRMTN_SKID"<>0)

74 - access("PFCT"."BASE_PRMTN_SKID"="EPOS_FCT"."PRMTN_SKID"(+))

86 - access("PRMTN_PROD_FCT"."PRMTN_SKID"="PFCT"."BASE_PRMTN_SKID")

87 - access("EPOS_FCT"."PROD_SKID"="PRMTN_PROD_FCT"."PROD_SKID" AND "PRMTN_PROD_FCT"."PRMTN_SKID"="PRMTN_FDIM"."PRMTN_SKID")

90 - access("HIER"."ASSOC_ACCT_SKID"="PRMTN_FDIM"."ACCT_SKID")

filter("CAL_MASTR"."DAY_DATE">="PRMTN_FDIM"."PGM_START_DATE" AND "CAL_MASTR"."DAY_DATE"<="PRMTN_FDIM"."PGM_END_DATE")

93 - access("EPOS_FCT"."DATE_SKID"="CAL_MASTR"."CAL_MASTR_SKID")

filter("CAL_MASTR"."DAY_DATE">="HIER"."ASDN_EFF_START_DATE" AND "CAL_MASTR"."DAY_DATE"<="HIER"."ASDN_EFF_END_DATE")

94 - access("EPOS_FCT"."ACCT_SKID"="HIER"."ACCT_SKID")

111 - filter("PRMTN_FDIM"."PRMTN_SKID"<>0)

113 - filter("PRMTN_PROD_FCT"."PRMTN_SKID"<>0 AND SYS_OP_BLOOM_FILTER(:BF0000,"PRMTN_PROD_FCT"."PROD_SKID","PRMTN_PROD_FCT"."PRMTN_SKID"))

119 - filter("PFCT"."BASE_PRMTN_SKID"<>0)

120 - access("PFCT"."BASE_PRMTN_SKID"="ACTVY"."PRMTN_SKID"(+))

126 - access("OPT_ACTVY_FCT"."ACTVY_SKID"="OPT_ACTVY_FDIM"."ACTVY_SKID")

129 - access("OPT_ACTVY_FCT"."PRMTN_SKID"="PRMTN_FLT"."PRMTN_SKID")

133 - filter("PFCT"."BASE_PRMTN_SKID"<>0)

135 - filter("OPT_ACTVY_FCT"."PRMTN_SKID"<>0)

136 - access("PFCT"."BASE_PRMTN_SKID"="PRMTN_PROD"."PRMTN_SKID"(+))

144 - access("PRMTN_PROD_FCT"."PRMTN_SKID"="PRMTN_FLT"."PRMTN_SKID")

151 - filter("PFCT"."BASE_PRMTN_SKID"<>0)

153 - filter("PRMTN_PROD_FCT"."PRMTN_SKID"<>0)

158 - filter("PFCT"."BASE_PRMTN_SKID"<>0)

208 rows selected.

Elapsed: 00:00:07.27

This is the right execution plan; the SQL finishes within 5 minutes. I've met this issue many times; last week, during the Optima APAC cutover, I met it again, so it's necessary to point it out.
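
If you cannot issue ALTER SESSION (for example, when the SQL comes from a packaged application), the same effect can be had per statement with an OPT_PARAM hint, which is the form the bug title itself refers to. A sketch:

SELECT /*+ OPT_PARAM('_optimizer_extend_jppd_view_types','false') */ ...
  FROM ...;

Alternatively, a NO_PUSH_PRED hint on the specific view blocks the pushdown for just that query block.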

Here is another example for this bug:

SQL> alter session set "_optimizer_extend_jppd_view_types"=false;

Session altered.

SQL> explain plan for MERGE INTO (SELECT *

2 FROM opt_prmtn_fct

3 WHERE base_prmtn_skid <> 0

4 ) trgt

5 USING ( SELECT pd.prmtn_skid AS prmtn_skid

6 , SUM (pd.actl_tot_qty_amt) AS actl_tot_su_amt

7 , SUM ( (CASE

8 WHEN ( (NVL (p.su_factr, 0) * NVL (pd.actl_tot_qty_amt, 0)

9 - NVL (pd.estmt_basln_case_amt, 0)) < 0)

10 THEN

11 0

12 ELSE

13 ( (NVL (p.su_factr, 0) * NVL (pd.actl_tot_qty_amt, 0))

14 - NVL (pd.estmt_basln_case_amt, 0))

15 END

16 + CASE

17 WHEN (NVL (pd.post_prmtn_vol_adjmt_su_amt, 0) = 0)

18 THEN

19 0

20 ELSE

21 (NVL (pd.post_prmtn_vol_adjmt_su_amt, 0)

22 - NVL (pd.estmt_basln_case_amt, 0))

23 END)

24 * pd.giv_conv_factr_amt)

25 AS actl_incrm_niv_amt

26 FROM opt_prmtn_prod_fct pd

27 , opt_prod_dim p

28 WHERE pd.prod_skid = p.prod_skid

29 AND pd.prmtn_skid <> 0

30 GROUP BY pd.prmtn_skid) srce

31 ON (srce.prmtn_skid = trgt.base_prmtn_skid)

32 WHEN MATCHED

33 THEN

34 UPDATE SET trgt.actl_tot_su_amt = srce.actl_tot_su_amt

35 , trgt.actl_incrm_niv_amt = srce.actl_incrm_niv_amt;

Explained.

Elapsed: 00:00:01.65

SQL> select * from table(dbms_xplan.display);

PLAN_TABLE_OUTPUT

-------------------------------------------------------------------------------------------------------------------------------------------------------

Plan hash value: 4142491988

------------------------------------------------------------------------------------------------------------------------------------------------------

| Id | Operation | Name | Rows | Bytes | Cost (%CPU)| Time | Pstart| Pstop | TQ |IN-OUT| PQ Distrib |

------------------------------------------------------------------------------------------------------------------------------------------------------

| 0 | MERGE STATEMENT | | 1 | 38 | 113K (16)| 00:05:56 | | | | | |

| 1 | MERGE | OPT_PRMTN_FCT | | | | | | | | | |

| 2 | PX COORDINATOR | | | | | | | | | | |

| 3 | PX SEND QC (RANDOM) | :TQ10004 | 1 | 244 | 113K (16)| 00:05:56 | | | Q1,04 | P->S | QC (RAND) |

| 4 | VIEW | | | | | | | | Q1,04 | PCWP | |

|* 5 | HASH JOIN | | 1 | 244 | 113K (16)| 00:05:56 | | | Q1,04 | PCWP | |

| 6 | PX RECEIVE | | 1 | 39 | 112K (16)| 00:05:52 | | | Q1,04 | PCWP | |

| 7 | PX SEND HASH | :TQ10003 | 1 | 39 | 112K (16)| 00:05:52 | | | Q1,03 | P->P | HASH |

| 8 | VIEW | | 1 | 39 | 112K (16)| 00:05:52 | | | Q1,03 | PCWP | |

| 9 | SORT GROUP BY | | 1 | 33 | 112K (16)| 00:05:52 | | | Q1,03 | PCWP | |

| 10 | PX RECEIVE | | 1 | 33 | 112K (16)| 00:05:52 | | | Q1,03 | PCWP | |

| 11 | PX SEND HASH | :TQ10002 | 1 | 33 | 112K (16)| 00:05:52 | | | Q1,02 | P->P | HASH |

| 12 | SORT GROUP BY | | 1 | 33 | 112K (16)| 00:05:52 | | | Q1,02 | PCWP | |

|* 13 | HASH JOIN | | 55M| 1744M| 102K (8)| 00:05:20 | | | Q1,02 | PCWP | |

| 14 | BUFFER SORT | | | | | | | | Q1,02 | PCWC | |

| 15 | PX RECEIVE | | 182K| 1780K| 2159 (4)| 00:00:07 | | | Q1,02 | PCWP | |

| 16 | PX SEND BROADCAST | :TQ10000 | 182K| 1780K| 2159 (4)| 00:00:07 | | | | S->P | BROADCAST |

| 17 | PARTITION LIST ALL| | 182K| 1780K| 2159 (4)| 00:00:07 | 1 | 16 | | | |

| 18 | TABLE ACCESS FULL| OPT_PROD_DIM | 182K| 1780K| 2159 (4)| 00:00:07 | 1 | 16 | | | |

| 19 | PX BLOCK ITERATOR | | 54M| 1206M| 99315 (7)| 00:05:11 | 1 | 16 | Q1,02 | PCWC | |

|* 20 | TABLE ACCESS FULL | OPT_PRMTN_PROD_FCT | 54M| 1206M| 99315 (7)| 00:05:11 | 1 | 16 | Q1,02 | PCWP | |

| 21 | BUFFER SORT | | | | | | | | Q1,04 | PCWC | |

| 22 | PX RECEIVE | | 172K| 33M| 1286 (8)| 00:00:05 | | | Q1,04 | PCWP | |

| 23 | PX SEND HASH | :TQ10001 | 172K| 33M| 1286 (8)| 00:00:05 | | | | S->P | HASH |

| 24 | PARTITION LIST ALL | | 172K| 33M| 1286 (8)| 00:00:05 | 1 | 16 | | | |

|* 25 | TABLE ACCESS FULL | OPT_PRMTN_FCT | 172K| 33M| 1286 (8)| 00:00:05 | 1 | 16 | | | |

------------------------------------------------------------------------------------------------------------------------------------------------------

Predicate Information (identified by operation id):

---------------------------------------------------

5 - access("SRCE"."PRMTN_SKID"="OPT_PRMTN_FCT"."BASE_PRMTN_SKID")

13 - access("PD"."PROD_SKID"="P"."PROD_SKID")

20 - filter("PD"."PRMTN_SKID"<>0)

25 - filter("BASE_PRMTN_SKID"<>0)

40 rows selected.

SQL> alter session set "_optimizer_extend_jppd_view_types"=true;

Session altered.

Elapsed: 00:00:01.18

SQL> explain plan for MERGE INTO (SELECT *

2 FROM opt_prmtn_fct

3 WHERE base_prmtn_skid <> 0

4 ) trgt

5 USING ( SELECT pd.prmtn_skid AS prmtn_skid

6 , SUM (pd.actl_tot_qty_amt) AS actl_tot_su_amt

7 , SUM ( (CASE

8 WHEN ( (NVL (p.su_factr, 0) * NVL (pd.actl_tot_qty_amt, 0)

9 - NVL (pd.estmt_basln_case_amt, 0)) < 0)

10 THEN

11 0

12 ELSE

13 ( (NVL (p.su_factr, 0) * NVL (pd.actl_tot_qty_amt, 0))

14 - NVL (pd.estmt_basln_case_amt, 0))

15 END

16 + CASE

17 WHEN (NVL (pd.post_prmtn_vol_adjmt_su_amt, 0) = 0)

18 THEN

19 0

20 ELSE

21 (NVL (pd.post_prmtn_vol_adjmt_su_amt, 0)

22 - NVL (pd.estmt_basln_case_amt, 0))

23 END)

24 * pd.giv_conv_factr_amt)

25 AS actl_incrm_niv_amt

26 FROM opt_prmtn_prod_fct pd

27 , opt_prod_dim p

28 WHERE pd.prod_skid = p.prod_skid

29 AND pd.prmtn_skid <> 0

30 GROUP BY pd.prmtn_skid) srce

31 ON (srce.prmtn_skid = trgt.base_prmtn_skid)

32 WHEN MATCHED

33 THEN

34 UPDATE SET trgt.actl_tot_su_amt = srce.actl_tot_su_amt

35 , trgt.actl_incrm_niv_amt = srce.actl_incrm_niv_amt;

Explained.

Elapsed: 00:00:01.32

SQL> select * from table(dbms_xplan.display);

PLAN_TABLE_OUTPUT

-----------------------------------------------------------------------------------------------------------------------------------------

Plan hash value: 1086075059

-----------------------------------------------------------------------------------------------------------------------------------------

| Id | Operation | Name | Rows | Bytes | Cost (%CPU)| Time | Pstart| Pstop |

-----------------------------------------------------------------------------------------------------------------------------------------

| 0 | MERGE STATEMENT | | 1 | 38 | 89110 (1)| 00:04:39 | | |

| 1 | MERGE | OPT_PRMTN_FCT | | | | | | |

| 2 | VIEW | | | | | | | |

| 3 | NESTED LOOPS | | 1 | 231 | 89110 (1)| 00:04:39 | | |

| 4 | PARTITION LIST ALL | | 172K| 33M| 1286 (8)| 00:00:05 | 1 | 16 |

|* 5 | TABLE ACCESS FULL | OPT_PRMTN_FCT | 172K| 33M| 1286 (8)| 00:00:05 | 1 | 16 |

| 6 | VIEW PUSHED PREDICATE | | 1 | 26 | | | | |

|* 7 | FILTER | | | | | | | |

| 8 | SORT AGGREGATE | | 1 | 33 | | | | |

|* 9 | PX COORDINATOR | | | | | | | |

| 10 | PX SEND QC (RANDOM) | :TQ10000 | 1 | 33 | | | | |

| 11 | SORT AGGREGATE | | 1 | 33 | | | | |

|* 12 | FILTER | | | | | | | |

| 13 | NESTED LOOPS | | | | | | | |

| 14 | NESTED LOOPS | | 929 | 30657 | 1361 (1)| 00:00:05 | | |

| 15 | PX PARTITION LIST ALL | | 922 | 21206 | 333 (1)| 00:00:02 | 1 | 16 |

| 16 | TABLE ACCESS BY LOCAL INDEX ROWID| OPT_PRMTN_PROD_FCT | 922 | 21206 | 333 (1)| 00:00:02 | 1 | 16 |

| 17 | BITMAP CONVERSION TO ROWIDS | | | | | | | |

|* 18 | BITMAP INDEX SINGLE VALUE | OPT_PRMTN_PROD_FCT_BX4 | | | | | 1 | 16 |

|* 19 | INDEX RANGE SCAN | OPT_PROD_DIM_PKN | 1 | | 1 (0)| 00:00:01 | | |

| 20 | TABLE ACCESS BY GLOBAL INDEX ROWID | OPT_PROD_DIM | 1 | 10 | 2 (0)| 00:00:01 | ROWID | ROWID |

-----------------------------------------------------------------------------------------------------------------------------------------

The hidden parameter “_optimizer_extend_jppd_view_types” defaults to true in 11g R1 and R2; the parameter does not exist in releases before 11g.

This hidden parameter extends join predicate pushdown (JPPD) to GROUP BY, DISTINCT, and semi-/anti-joined views.
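
To make the transformation concrete, here is a minimal sketch with hypothetical tables t1 and t2 (they are not from the system above). With the parameter set to true, the optimizer may push the join predicate into the GROUP BY view, so the view is re-aggregated once per driving row instead of being aggregated once and hash-joined; that is exactly the VIEW PUSHED PREDICATE operation:

-- Hypothetical t1(id), t2(id). With extended JPPD the predicate
-- t1.id = v.id can be pushed inside the view, and v's GROUP BY is
-- then evaluated once for every row of t1 (VIEW PUSHED PREDICATE).
SELECT t1.id, v.cnt
  FROM t1,
       (SELECT id, COUNT(*) AS cnt
          FROM t2
         GROUP BY id) v
 WHERE t1.id = v.id(+);

Whether that is faster depends on the number of driving rows: with 168K rows on the outer side, as in the first plan above, probing the aggregate view once per row is what stretched the query to 16 hours.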

OK, let's review our SQL:

MERGE INTO (SELECT *

FROM opt_prmtn_fct

WHERE base_prmtn_skid <> 0

) trgt

USING ( SELECT pd.prmtn_skid AS prmtn_skid

, SUM (pd.actl_tot_qty_amt) AS actl_tot_su_amt

, SUM ( (CASE

WHEN ( (NVL (p.su_factr, 0) * NVL (pd.actl_tot_qty_amt, 0)

- NVL (pd.estmt_basln_case_amt, 0)) < 0)

THEN

0

ELSE

( (NVL (p.su_factr, 0) * NVL (pd.actl_tot_qty_amt, 0))

- NVL (pd.estmt_basln_case_amt, 0))

END

+ CASE

WHEN (NVL (pd.post_prmtn_vol_adjmt_su_amt, 0) = 0)

THEN

0

ELSE

(NVL (pd.post_prmtn_vol_adjmt_su_amt, 0)

- NVL (pd.estmt_basln_case_amt, 0))

END)

* pd.giv_conv_factr_amt)

AS actl_incrm_niv_amt

FROM opt_prmtn_prod_fct pd

, opt_prod_dim p

WHERE pd.prod_skid = p.prod_skid

AND pd.prmtn_skid <> 0

GROUP BY pd.prmtn_skid) srce ---- note the GROUP BY inside the view srce

ON (srce.prmtn_skid = trgt.base_prmtn_skid)

WHEN MATCHED

THEN

UPDATE SET trgt.actl_tot_su_amt = srce.actl_tot_su_amt

, trgt.actl_incrm_niv_amt = srce.actl_incrm_niv_amt;

So with the hidden parameter “_optimizer_extend_jppd_view_types” set to true, the CBO will push the join predicate into the GROUP BY view:

| 6 | VIEW PUSHED PREDICATE | | 1 | 26 | | | | |

VIEW PUSHED PREDICATE means the CBO chose to push the join predicate into the GROUP BY view.
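
If the issue affects many statements, the parameter can also be turned off instance-wide. This is only a sketch of the command; since it is an underscore parameter, confirm with Oracle Support (as we did under SR 3-3225436011) before changing it system-wide:

-- disable extended JPPD for the whole instance (check with Oracle Support first)
ALTER SYSTEM SET "_optimizer_extend_jppd_view_types" = FALSE SCOPE=BOTH;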

优化后的HDF5数据集(按需只读访问) ---------------------- class CCI31HDF5Dataset(Dataset): def __init__(self, hdf5_root, inr_list): self.hdf5_root = hdf5_root self.inr_list = inr_list self.sample_info = [] # 存储每个样本的(HDF5文件路径, 样本索引, INR值) self.hdf5_files = {} # 缓存已打开的HDF5文件句柄(避免重复打开) # 预扫描所有HDF5文件,收集样本信息(仅记录索引,不加载数据) print("扫描HDF5文件并收集样本信息...") for inr in tqdm(inr_list, desc="扫描INR档位"): hdf5_file = os.path.join(hdf5_root, f"inr_{inr}dB.h5") if not os.path.exists(hdf5_file): print(f"警告:{hdf5_file} 不存在,跳过") continue # 只读模式打开HDF5文件,获取样本数 with h5py.File(hdf5_file, "r") as f: sample_count = f.attrs.get("sample_count", 0) if sample_count == 0: continue # 记录每个样本的位置信息 for idx in range(sample_count): self.sample_info.append((hdf5_file, idx, inr)) print(f"样本收集完成:共{len(self.sample_info)}个有效样本") def __len__(self): return len(self.sample_info) def __getitem__(self, idx): # 根据索引获取样本位置信息 hdf5_file, sample_idx, inr = self.sample_info[idx] # 缓存HDF5文件句柄(避免重复打开/关闭,提升速度) if hdf5_file not in self.hdf5_files: self.hdf5_files[hdf5_file] = h5py.File(hdf5_file, "r", libver="latest", swmr=True) f = self.hdf5_files[hdf5_file] # 按需读取单个样本(HDF5支持随机访问,无需加载整个数据集) tfi = torch.tensor(f["tfi"][sample_idx], dtype=torch.float32).unsqueeze(0) # (1, H, W) fd = torch.tensor(f["fd"][sample_idx], dtype=torch.float32) # (2, 32, 32) labels = torch.tensor(f["labels"][sample_idx], dtype=torch.float32) # (5,) inr_tensor = torch.tensor([inr], dtype=torch.int) return tfi, fd, labels, inr_tensor def __del__(self): # 关闭所有打开的HDF5文件句柄 for f in self.hdf5_files.values(): if not f.closed: f.close() print("所有HDF5文件句柄已关闭") # ---------------------- 3. 数据加载与划分(优化DataLoader配置) ---------------------- def prepare_data(hdf5_root, inr_list, batch_size=64, num_workers=8): # 创建HDF5数据集(仅扫描样本信息,不加载数据) dataset = CCI31HDF5Dataset(hdf5_root, inr_list) total_size = len(dataset) train_size = int(0.5 * total_size) val_size = int(0.1 * total_size) test_size = total_size - train_size - val_size generator = torch.Generator().manual_seed(42) train_dataset, val_dataset, test_dataset = random_split( dataset, [train_size, val_size, test_size], generator=generator ) # 优化DataLoader配置: # - pin_memory=True:将数据固定在GPU内存,加速传输 # - persistent_workers=True:训练期间保持worker进程,避免重复创建 # - num_workers=8:根据CPU核心数调整(建议等于CPU核心数) # - prefetch_factor=2:提前加载2个batch train_loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True, persistent_workers=True, drop_last=True, prefetch_factor=2 ) val_loader = DataLoader( val_dataset, batch_size=batch_size, shuffle=False, num_workers=min(num_workers, 4), pin_memory=True, persistent_workers=True, prefetch_factor=2 ) test_loader = DataLoader( test_dataset, batch_size=batch_size, shuffle=False, num_workers=min(num_workers, 4), pin_memory=True, persistent_workers=True, prefetch_factor=2 ) print(f"数据划分完成:") print(f"- 训练集:{len(train_dataset)}样本") print(f"- 验证集:{len(val_dataset)}样本") print(f"- 测试集:{len(test_dataset)}样本") return train_loader, val_loader, test_loader # ---------------------- 4. 
多阈值优化算法(保持不变) ---------------------- def optimize_multithreshold(model, train_loader, device, threshold_range=[0.01, 0.99], step=0.01): model.eval() num_classes = model.num_classes jam_names = ["ST", "LSF", "CSF", "NPM", "BPSK"] all_pred_prob = [] all_true_labels = [] with torch.no_grad(): for tfi, fd, labels, _ in tqdm(train_loader, desc="收集训练集预测结果"): tfi, fd = tfi.to(device, non_blocking=True), fd.to(device, non_blocking=True) # non_blocking加速 pred_prob = model(tfi, fd) all_pred_prob.append(pred_prob.cpu().numpy()) all_true_labels.append(labels.cpu().numpy()) all_pred_prob = np.vstack(all_pred_prob) all_true_labels = np.vstack(all_true_labels) opt_thresholds = np.ones(num_classes) * 0.5 for c in range(num_classes): best_f1 = 0.0 best_th = 0.5 thresholds = np.arange(threshold_range[0], threshold_range[1] + step, step) for th in thresholds: pred_label = (all_pred_prob[:, c] >= th).astype(int) tp = np.sum((all_true_labels[:, c] == 1) & (pred_label == 1)) fp = np.sum((all_true_labels[:, c] == 0) & (pred_label == 1)) fn = np.sum((all_true_labels[:, c] == 1) & (pred_label == 0)) precision = tp / (tp + fp + 1e-8) recall = tp / (tp + fn + 1e-8) f1 = 2 * precision * recall / (precision + recall + 1e-8) if f1 > best_f1: best_f1 = f1 best_th = th opt_thresholds[c] = best_th print(f"标签{c}({jam_names[c]}):最优阈值={best_th:.3f},最优F1={best_f1:.4f}") def apply_T_rule(pred_prob): pred_label = (pred_prob >= opt_thresholds).astype(int) all_zero_mask = (np.sum(pred_label, axis=1) == 0) if np.any(all_zero_mask): max_idx = np.argmax(pred_prob[all_zero_mask], axis=1) pred_label[all_zero_mask, max_idx] = 1 return pred_label np.save("optimal_thresholds.npy", opt_thresholds) print(f"\n最优阈值已保存至:optimal_thresholds.npy") return opt_thresholds, apply_T_rule # ---------------------- 5. 
训练函数(优化数据传输) ---------------------- def train_model(model, train_loader, val_loader, device, epochs=30, lr=0.003, weight_decay=2e-4): model.to(device) criterion = nn.BCELoss() optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) scheduler = OneCycleLR( optimizer, max_lr=lr, steps_per_epoch=len(train_loader), epochs=epochs, anneal_strategy="cos" ) best_val_f1 = 0.0 train_losses = [] val_losses = [] val_f1s = [] for epoch in range(epochs): model.train() train_loss = 0.0 with tqdm(train_loader, desc=f"Epoch {epoch + 1:2d}/{epochs}") as pbar: for tfi, fd, labels, _ in pbar: # non_blocking=True 配合pin_memory=True加速GPU数据传输 tfi, fd, labels = tfi.to(device, non_blocking=True), fd.to(device, non_blocking=True), labels.to(device, non_blocking=True) optimizer.zero_grad() pred_prob = model(tfi, fd) loss = criterion(pred_prob, labels) loss.backward() optimizer.step() scheduler.step() train_loss += loss.item() * tfi.size(0) pbar.set_postfix({ "Train Loss": f"{loss.item():.4f}", "LR": f"{scheduler.get_last_lr()[0]:.6f}" }) avg_train_loss = train_loss / len(train_loader.dataset) train_losses.append(avg_train_loss) model.eval() val_loss = 0.0 val_pred_prob = [] val_true_labels = [] with torch.no_grad(): for tfi, fd, labels, _ in val_loader: tfi, fd, labels = tfi.to(device, non_blocking=True), fd.to(device, non_blocking=True), labels.to(device, non_blocking=True) pred_prob = model(tfi, fd) loss = criterion(pred_prob, labels) val_loss += loss.item() * tfi.size(0) val_pred_prob.append(pred_prob.cpu().numpy()) val_true_labels.append(labels.cpu().numpy()) avg_val_loss = val_loss / len(val_loader.dataset) val_losses.append(avg_val_loss) val_pred_prob = np.vstack(val_pred_prob) val_true_labels = np.vstack(val_true_labels) val_pred_label = (val_pred_prob >= 0.5).astype(int) val_f1 = f1_score(val_true_labels, val_pred_label, average="macro") val_f1s.append(val_f1) if val_f1 > best_val_f1: best_val_f1 = val_f1 torch.save(model.state_dict(), "best_mirnet.pth") print(f" → 保存最佳模型(Val F1: {best_val_f1:.4f})") print(f" 训练损失:{avg_train_loss:.4f} | 验证损失:{avg_val_loss:.4f} | 验证F1:{val_f1:.4f}") plt.figure(figsize=(12, 4)) plt.subplot(1, 2, 1) plt.plot(range(1, epochs + 1), train_losses, label="Train Loss") plt.plot(range(1, epochs + 1), val_losses, label="Val Loss") plt.xlabel("Epoch") plt.ylabel("Loss") plt.legend() plt.grid(alpha=0.3) plt.subplot(1, 2, 2) plt.plot(range(1, epochs + 1), val_f1s, label="Val Macro F1", color="orange") plt.xlabel("Epoch") plt.ylabel("Macro F1-Score") plt.legend() plt.grid(alpha=0.3) plt.tight_layout() plt.savefig("train_curve.png") plt.close() print(f"\n训练曲线已保存至:train_curve.png") model.load_state_dict(torch.load("best_mirnet.pth")) return model, train_losses, val_losses # ---------------------- 6. 
多INR测试与指标输出(保持不变) ---------------------- def test_by_inr(model, test_loader, apply_T_rule, device, inr_list): model.eval() inr_metrics = {inr: {"true": [], "pred": []} for inr in inr_list} all_true, all_pred = [], [] with torch.no_grad(): for tfi, fd, labels, inr in tqdm(test_loader, desc="按INR测试"): tfi, fd, labels = tfi.to(device, non_blocking=True), fd.to(device, non_blocking=True), labels.to(device, non_blocking=True) pred_prob = model(tfi, fd) pred_label = apply_T_rule(pred_prob.cpu().numpy()) inr_key = int(inr[0].item()) inr_metrics[inr_key]["true"].append(labels.cpu().numpy()) inr_metrics[inr_key]["pred"].append(pred_label) all_true.append(labels.cpu().numpy()) all_pred.append(pred_label) all_true = np.vstack(all_true) all_pred = np.vstack(all_pred) oa = accuracy_score(all_true, all_pred) * 100 macro_f1 = f1_score(all_true, all_pred, average="macro") * 100 precision = f1_score(all_true, all_pred, average="macro", zero_division=0) * 100 recall = f1_score(all_true, all_pred, average="macro", zero_division=0) * 100 hamming_score = (1 - hamming_loss(all_true, all_pred)) * 100 params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6 inr_accuracies = [] for inr in inr_list: if inr not in inr_metrics: inr_accuracies.append(0.0) continue true = np.vstack(inr_metrics[inr]["true"]) pred = np.vstack(inr_metrics[inr]["pred"]) acc = accuracy_score(true, pred) * 100 inr_accuracies.append(acc) plt.figure(figsize=(8, 6)) plt.plot(inr_list, inr_accuracies, marker="o", color="blue", label="MIRNet") plt.xlabel("INR (dB)") plt.ylabel("Accuracy (%)") plt.xticks(inr_list) plt.grid(alpha=0.3) plt.legend() plt.title("MIRNet在不同INR下的全局准确率") plt.savefig("inr_accuracy_curve.png") plt.close() print("\n" + "=" * 60) print("论文表4.3 指标结果:") print(f"OA(%)\tF1(%)\tPrecision(%)\tRecall(%)\tHM_score(%)\tParms(M)") print(f"{oa:.2f}\t{macro_f1:.2f}\t{precision:.2f}\t{recall:.2f}\t{hamming_score:.2f}\t{params:.2f}") print("=" * 60) return inr_accuracies, { "OA(%)": oa, "F1(%)": macro_f1, "Precision(%)": precision, "Recall(%)": recall, "HM_score(%)": hamming_score, "Parms(M)": params } # ---------------------- 7. 
主函数(新增HDF5转换步骤) ---------------------- if __name__ == "__main__": DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") ORIGINAL_ROOT = "D:\\MATLAB R2023b\\Project\\inter\\awgn" # 原始数据路径(.mat和.png) HDF5_ROOT = "cci31_hdf5" # HDF5文件保存路径 INR_LIST = list(range(-10, 31, 4)) BATCH_SIZE = 128 EPOCHS = 30 MAX_LR = 0.003 WEIGHT_DECAY = 2e-4 NUM_WORKERS = 8 # 根据CPU核心数调整(建议=CPU核心数,如16核设为16) print("=" * 60) print("MIRNet训练与多INR测试流程(HDF5优化版)") print(f"设备:{DEVICE} | 原始数据:{ORIGINAL_ROOT} | HDF5保存:{HDF5_ROOT}") print(f"INR档位:{INR_LIST} | 批量大小:{BATCH_SIZE} | 工作线程数:{NUM_WORKERS}") print("=" * 60) # 步骤0:检查是否需要转换HDF5(如果HDF5目录不存在或为空,则转换) if not os.path.exists(HDF5_ROOT) or len(os.listdir(HDF5_ROOT)) == 0: print("\n【步骤0:批量转换数据为HDF5格式】") HDF5_ROOT = convert_to_hdf5(ORIGINAL_ROOT, INR_LIST, HDF5_ROOT) else: print(f"\n【步骤0:HDF5文件已存在,跳过转换】") print("\n【步骤1:加载HDF5数据集】") train_loader, val_loader, test_loader = prepare_data( hdf5_root=HDF5_ROOT, inr_list=INR_LIST, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS ) print("\n【步骤2:初始化MIRNet模型】") model = MIRNet(num_classes=5) params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6 print(f"模型参数量:{params:.2f}M") print("\n【步骤3:开始训练】") model, train_losses, val_losses = train_model( model=model, train_loader=train_loader, val_loader=val_loader, device=DEVICE, epochs=EPOCHS, lr=MAX_LR, weight_decay=WEIGHT_DECAY ) print("\n【步骤4:多阈值优化】") opt_thresholds, apply_T_rule = optimize_multithreshold( model=model, train_loader=train_loader, device=DEVICE ) print("\n【步骤5:按INR测试并生成曲线】") inr_accuracies, table43 = test_by_inr( model=model, test_loader=test_loader, apply_T_rule=apply_T_rule, device=DEVICE, inr_list=INR_LIST ) print("\n" + "=" * 60) print("流程全部完成!") print(f"结果文件:") print(f"- INR-准确率曲线:inr_accuracy_curve.png") print(f"- 训练曲线:train_curve.png") print(f"- 最优阈值:optimal_thresholds.npy") print(f"- 最佳模型:best_mirnet.pth") print("=" * 60)
11-21
import os import cv2 import random import hashlib import warnings from tqdm import tqdm from PIL import Image import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import Dataset, DataLoader from torchvision import transforms from torchvision.utils import save_image from torchvision.models import vgg16 import numpy as np from torch.cuda.amp import autocast, GradScaler from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts # ===================== 基础配置 & 工具函数 ===================== try: from torchmetrics import ( F1Score as TorchMetricsF1Score, JaccardIndex as TorchMetricsJaccardIndex, StructuralSimilarityIndexMeasure ) TORCHMETRICS_AVAILABLE = True except ImportError: TORCHMETRICS_AVAILABLE = False class SimpleF1Score: def __init__(self, task="binary", average="macro", device=None): self.eps = 1e-8 self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu') def reset(self): pass def __call__(self, pred, target): pred = pred.to(self.device) target = target.to(self.device) tp = (pred * target).sum() fp = (pred * (1 - target)).sum() fn = ((1 - pred) * target).sum() precision = tp / (tp + fp + self.eps) recall = tp / (tp + fn + self.eps) f1 = 2 * precision * recall / (precision + recall + self.eps) return f1.item() class SimpleJaccardIndex: def __init__(self, task="binary", device=None): self.eps = 1e-8 self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu') def reset(self): pass def __call__(self, pred, target): pred = pred.to(self.device) target = target.to(self.device) intersection = (pred * target).sum() union = pred.sum() + target.sum() - intersection return (intersection / (union + self.eps)).item() class StructuralSimilarityIndexMeasure: def __init__(self, data_range=1.0, device=None): self.data_range = data_range self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.gaussian_kernel = self._create_gaussian_kernel() def _create_gaussian_kernel(self, kernel_size=11, sigma=1.5): kernel = torch.tensor([ [1, 1, 2, 2, 2, 1, 1], [1, 2, 2, 4, 2, 2, 1], [2, 2, 4, 8, 4, 2, 2], [2, 4, 8, 16, 8, 4, 2], [2, 2, 4, 8, 4, 2, 2], [1, 2, 2, 4, 2, 2, 1], [1, 1, 2, 2, 2, 1, 1] ], device=self.device).float() / 144 return kernel.unsqueeze(0).unsqueeze(0) def reset(self): pass def __call__(self, pred, target): pred = pred.to(self.device) target = target.to(self.device) pred = (pred + 1) / 2 # [-1,1] → [0,1] target = (target + 1) / 2 pred = nn.functional.pad(pred, (3,3,3,3), mode='reflect') target = nn.functional.pad(target, (3,3,3,3), mode='reflect') pred_mean = nn.functional.conv2d(pred, self.gaussian_kernel.repeat(3,1,1,1), groups=3) target_mean = nn.functional.conv2d(target, self.gaussian_kernel.repeat(3,1,1,1), groups=3) pred_var = nn.functional.conv2d(pred**2, self.gaussian_kernel.repeat(3,1,1,1), groups=3) - pred_mean**2 target_var = nn.functional.conv2d(target**2, self.gaussian_kernel.repeat(3,1,1,1), groups=3) - target_mean**2 cov = nn.functional.conv2d(pred*target, self.gaussian_kernel.repeat(3,1,1,1), groups=3) - pred_mean*target_mean c1 = (0.01 * self.data_range)**2 c2 = (0.03 * self.data_range)**2 ssim = ((2*pred_mean*target_mean + c1) * (2*cov + c2)) / ((pred_mean**2 + target_mean**2 + c1) * (pred_var + target_var + c2)) return torch.mean(ssim).item() warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=FutureWarning) os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' torch.backends.cudnn.benchmark = True 
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True,garbage_collection_threshold:0.6' import matplotlib.pyplot as plt import matplotlib matplotlib.use('Agg') matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei'] matplotlib.rcParams['axes.unicode_minus'] = False def get_file_hash(file_path): hasher = hashlib.sha256() try: with open(file_path, 'rb') as f: while chunk := f.read(8192): hasher.update(chunk) return hasher.hexdigest() except Exception: return hashlib.md5(file_path.encode()).hexdigest() def ensure_dir(path): if not path: return "" path = os.path.normpath(path) try: os.makedirs(path, exist_ok=True) except Exception as e: print(f"创建目录警告:{path} | {str(e)[:50]}") return path def tensor_to_np(tensor): with torch.no_grad(): if len(tensor.shape) == 4: tensor = tensor[0] img_np = (tensor.permute(1, 2, 0).cpu().numpy() + 1) / 2 * 255 img_np = np.clip(img_np, 0, 255).astype(np.uint8) if img_np.shape[-1] != 3: img_np = cv2.cvtColor(img_np, cv2.COLOR_GRAY2RGB) return img_np def check_mask_validity(mask_tensor): mask_tensor = mask_tensor.squeeze() if len(mask_tensor.shape) > 2 else mask_tensor if len(mask_tensor.shape) != 2: return 0.0, False valid_pixels = torch.sum(mask_tensor) total_pixels = mask_tensor.numel() valid_ratio = valid_pixels / (total_pixels + 1e-8) return valid_ratio.item(), valid_ratio >= 0.001 # ===================== 核心修复:损失函数(极度简化,先让生成器活下来) ===================== def calculate_vein_loss(gen_img, real_img): """暂时禁用脉络损失,先让生成器学基础结构""" return torch.tensor(0.0, device=gen_img.device) def calculate_texture_loss(gen_lesion, real_lesion): """大幅降低纹理损失权重""" device = gen_lesion.device kernel_gauss = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], device=device).float() / 16 kernel_laplacian = torch.tensor([[0, 1, 0], [1, -4, 1], [0, 1, 0]], device=device).float() kernel_gauss = kernel_gauss.unsqueeze(0).unsqueeze(0).repeat(3, 1, 1, 1) kernel_laplacian = kernel_laplacian.unsqueeze(0).unsqueeze(0).repeat(3, 1, 1, 1) gen_tex_gauss = nn.functional.conv2d(gen_lesion, kernel_gauss, padding=1, groups=3) real_tex_gauss = nn.functional.conv2d(real_lesion, kernel_gauss, padding=1, groups=3) gen_tex_laplacian = nn.functional.conv2d(gen_lesion, kernel_laplacian, padding=1, groups=3) real_tex_laplacian = nn.functional.conv2d(real_lesion, kernel_laplacian, padding=1, groups=3) return (nn.functional.l1_loss(gen_tex_gauss, real_tex_gauss) + nn.functional.l1_loss(gen_tex_laplacian, real_tex_laplacian)) * 0.1 # 0.6→0.1 def calculate_edge_loss(gen_img, real_img): """大幅降低边缘损失权重""" device = gen_img.device sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32, device=device) sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32, device=device) sobel_x = sobel_x.unsqueeze(0).unsqueeze(0).repeat(3, 1, 1, 1) sobel_y = sobel_y.unsqueeze(0).unsqueeze(0).repeat(3, 1, 1, 1) gen_edge_x = torch.abs(nn.functional.conv2d(gen_img, sobel_x, padding=1, groups=3)) real_edge_x = torch.abs(nn.functional.conv2d(real_img, sobel_x, padding=1, groups=3)) gen_edge_y = torch.abs(nn.functional.conv2d(gen_img, sobel_y, padding=1, groups=3)) real_edge_y = torch.abs(nn.functional.conv2d(real_img, sobel_y, padding=1, groups=3)) return (nn.functional.l1_loss(gen_edge_x, real_edge_x) + nn.functional.l1_loss(gen_edge_y, real_edge_y)) * 0.1 # 0.4→0.1 class PerceptualLoss(nn.Module): def __init__(self, device): super().__init__() self.device = device try: vgg = 
vgg16(pretrained=True).features.to(device) except Exception: vgg = vgg16(weights=None).features.to(device) self.layers = nn.Sequential( vgg[0], vgg[1], vgg[2], vgg[3], vgg[4] ).eval() for param in self.layers.parameters(): param.requires_grad = False self.l1 = nn.L1Loss() def forward(self, gen_img, real_img): """暂时禁用感知损失""" return torch.tensor(0.0, device=gen_img.device) def calculate_tv_loss(img): """适度降低TV损失,避免过度约束""" tv_h = torch.sum(torch.abs(img[:, :, 1:, :] - img[:, :, :-1, :])) tv_v = torch.sum(torch.abs(img[:, :, :, 1:] - img[:, :, :, :-1])) return (tv_h + tv_v) / (img.shape[0] * img.shape[2] * img.shape[3]) * 0.001 # 0.002→0.001 def calculate_color_hist_loss(gen_img, real_img, mask): """暂时禁用颜色损失""" return torch.tensor(0.0, device=gen_img.device) class LesionFocusLoss(nn.Module): def __init__(self, device): super().__init__() self.l1 = nn.L1Loss() self.device = device self.perceptual_loss = PerceptualLoss(device) if TORCHMETRICS_AVAILABLE: self.ssim_loss = StructuralSimilarityIndexMeasure(data_range=1.0).to(device) else: self.ssim_loss = StructuralSimilarityIndexMeasure(data_range=1.0, device=device) def forward(self, gen_img, mask, real_img): if torch.sum(mask) < 1e-4: mask = torch.ones_like(mask) lesion_area = mask.repeat(1, 3, 1, 1) real_lesion = real_img * lesion_area gen_lesion = gen_img * lesion_area # 极度简化内容损失:只保留核心L1,权重降到极低 lesion_l1 = self.l1(gen_lesion, real_lesion) * 0.05 # 0.3→0.05 bg_l1 = self.l1(gen_img * (1 - lesion_area), real_img * (1 - lesion_area)) * 0.01 # 0.1→0.01 texture_loss = calculate_texture_loss(gen_lesion, real_lesion) color_loss = calculate_color_hist_loss(gen_img, real_img, mask) edge_loss = calculate_edge_loss(gen_img, real_img) perceptual_loss = self.perceptual_loss(gen_img, real_img) tv_loss = calculate_tv_loss(gen_img) vein_loss = calculate_vein_loss(gen_img, real_img) real_global_edge = calculate_edge_loss(real_img, real_img) gen_global_edge = calculate_edge_loss(gen_img, gen_img) structure_loss = self.l1(gen_global_edge, real_global_edge) * 0.01 # 0.05→0.01 # 大幅降低SSIM损失权重,先让生成器出图 ssim_val = self.ssim_loss(gen_img, real_img) ssim_loss = (1 - ssim_val) * 0.1 # 0.5→0.1 total_loss = torch.clamp( lesion_l1 + bg_l1 + texture_loss + color_loss + edge_loss + perceptual_loss + tv_loss + structure_loss + ssim_loss + vein_loss, max=1.0 # 3.5→1.0,大幅降低内容损失上限 ) return total_loss # ===================== 核心修复:模型(极度简化生成器+阉割判别器) ===================== class LightweightUNet(nn.Module): def __init__(self, in_channels=3, out_channels=1): super().__init__() self.down1 = self._conv_block(in_channels, 8) self.down2 = self._conv_block(8, 16) self.down3 = self._conv_block(16, 32) self.pool = nn.MaxPool2d(2) self.bottleneck = self._conv_block(32, 64) self.up1 = nn.ConvTranspose2d(64, 32, 2, stride=2) self.conv_up1 = self._conv_block(64, 32) self.up2 = nn.ConvTranspose2d(32, 16, 2, stride=2) self.conv_up2 = self._conv_block(32, 16) self.up3 = nn.ConvTranspose2d(16, 8, 2, stride=2) self.conv_up3 = self._conv_block(16, 8) self.out = nn.Conv2d(8, out_channels, 1) self.sigmoid = nn.Sigmoid() def _conv_block(self, in_ch, out_ch): return nn.Sequential( nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True) ) def forward(self, x): d1 = self.down1(x) d2 = self.down2(self.pool(d1)) d3 = self.down3(self.pool(d2)) b = self.bottleneck(self.pool(d3)) u1 = self.up1(b) if u1.shape[2:] != d3.shape[2:]: u1 = nn.functional.interpolate(u1, size=d3.shape[2:], mode='bilinear') 
u1 = torch.cat([u1, d3], dim=1) u1 = self.conv_up1(u1) u2 = self.up2(u1) if u2.shape[2:] != d2.shape[2:]: u2 = nn.functional.interpolate(u2, size=d2.shape[2:], mode='bilinear') u2 = torch.cat([u2, d2], dim=1) u2 = self.conv_up2(u2) u3 = self.up3(u2) if u3.shape[2:] != d1.shape[2:]: u3 = nn.functional.interpolate(u3, size=d1.shape[2:], mode='bilinear') u3 = torch.cat([u3, d1], dim=1) u3 = self.conv_up3(u3) return self.sigmoid(self.out(u3)) class UNetLesionMaskGenerator: def __init__(self, device=None): self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu') if TORCHMETRICS_AVAILABLE: self.pixel_f1 = TorchMetricsF1Score(task="binary", average="macro").to(self.device) self.iou = TorchMetricsJaccardIndex(task="binary").to(self.device) self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.device) else: self.pixel_f1 = SimpleF1Score(task="binary", average="macro", device=self.device) self.iou = SimpleJaccardIndex(task="binary", device=self.device) self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0, device=self.device) self.seg_model = LightweightUNet().to(self.device) self.seg_transform = transforms.Compose([ transforms.Resize((256, 256)), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) ]) self.threshold = 0.5 def load_pretrained_weights(self, weight_path=None): if weight_path and os.path.exists(weight_path): try: self.seg_model.load_state_dict(torch.load(weight_path, map_location=self.device, weights_only=True)) print("加载预训练分割权重成功") except Exception as e: print(f"预训练权重加载失败:{str(e)[:50]}") self.seg_model.eval() @torch.no_grad() def segment_lesion(self, img_pil, return_cpu=True): img_tensor = self.seg_transform(img_pil).unsqueeze(0).to(self.device) pred = self.seg_model(img_tensor) mask_tensor = (pred > self.threshold).float().squeeze(0).squeeze(0) if return_cpu: mask_tensor = mask_tensor.cpu() mask_np = mask_tensor.cpu().numpy().astype(np.uint8) * 255 kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) mask_np = cv2.morphologyEx(mask_np, cv2.MORPH_OPEN, kernel, iterations=1) img_np = np.array(img_pil.resize((256, 256))) return mask_tensor, mask_np, img_np @torch.no_grad() def calculate_precision_metrics(self, real_imgs, gen_imgs): if TORCHMETRICS_AVAILABLE: self.pixel_f1.reset() self.iou.reset() self.ssim.reset() real_masks, gen_masks = [], [] try: if len(real_imgs) == 0 or len(gen_imgs) == 0: return {"pixel": {"f1": 0.0, "iou": 0.0, "ssim": 0.0}, "valid": False} real_img = real_imgs[0] gen_img = gen_imgs[0] real_np = tensor_to_np(real_img) real_mask, _, _ = self.segment_lesion(Image.fromarray(real_np), return_cpu=False) gen_np = tensor_to_np(gen_img) gen_mask, _, _ = self.segment_lesion(Image.fromarray(gen_np), return_cpu=False) real_ratio, _ = check_mask_validity(real_mask.unsqueeze(0)) gen_ratio, _ = check_mask_validity(gen_mask.unsqueeze(0)) if real_ratio >= 0.001 and gen_ratio >= 0.001: real_masks.append(real_mask.unsqueeze(0)) gen_masks.append(gen_mask.unsqueeze(0)) if not real_masks: return {"pixel": {"f1": 0.0, "iou": 0.0, "ssim": 0.0}, "valid": False} real_masks = torch.stack(real_masks).to(self.device) gen_masks = torch.stack(gen_masks).to(self.device) if TORCHMETRICS_AVAILABLE: f1 = self.pixel_f1(gen_masks.float(), real_masks.float()).item() iou = self.iou(gen_masks.float(), real_masks.float()).item() else: f1 = self.pixel_f1(gen_masks, real_masks) iou = self.iou(gen_masks, real_masks) real_imgs_norm = (real_imgs + 1) / 2 gen_imgs_norm = (gen_imgs + 1) / 2 if TORCHMETRICS_AVAILABLE: ssim = 
self.ssim(gen_imgs_norm.to(self.device), real_imgs_norm.to(self.device)).item() else: ssim = self.ssim(gen_imgs_norm, real_imgs_norm) return {"pixel": {"f1": f1, "iou": iou, "ssim": ssim}, "valid": True} except Exception as e: print(f"指标计算警告:{str(e)[:50]}") return {"pixel": {"f1": 0.0, "iou": 0.0, "ssim": 0.0}, "valid": False} class EnhancedResBlock(nn.Module): def __init__(self, in_channels, out_channels): super().__init__() # 极度简化残差块:单卷积+BN+ReLU self.conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, 3, 1, 1), nn.BatchNorm2d(out_channels, momentum=0.99), nn.ReLU(inplace=True), ) self.shortcut = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity() def forward(self, x, mask=None): residual = self.conv(x) return nn.ReLU(inplace=True)(x + self.shortcut(residual)) # 核心修复:极度简化上采样模块 class ConvTransposeUpsample(nn.Module): def __init__(self, in_channels, out_channels, scale_factor=2): super().__init__() self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1) self.conv = nn.Sequential( nn.Conv2d(out_channels, out_channels, 3, padding=1), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True) ) def forward(self, x): x = self.up(x) x = self.conv(x) return x class LesionFocusGenerator(nn.Module): def __init__(self, latent_dim=100, img_size=256): super().__init__() self.latent_dim = latent_dim self.init_size = img_size // 8 # 简化FC层:减少输出通道 self.fc = nn.Sequential( nn.Linear(latent_dim * 2, 128 * self.init_size * self.init_size), # 256→128 nn.ReLU(inplace=True), nn.Dropout(0.05) # 0.1→0.05 ) # 极度简化残差块:仅1个基础块 self.res_blocks = nn.Sequential( EnhancedResBlock(128, 128) # 256→128 ) self.lesion_feat = nn.Sequential( nn.Conv2d(1, 8, 3, padding=1), # 16→8 nn.BatchNorm2d(8, momentum=0.99), nn.ReLU(inplace=True), EnhancedResBlock(8, 8) ) # 简化上采样 self.upsample = nn.Sequential( ConvTransposeUpsample(128 + 8, 64), # 256+16→128+8;128→64 ConvTransposeUpsample(64, 32), # 64→32 ConvTransposeUpsample(32, 16) # 32→16 ) # 简化输出层 self.vein_refine = nn.Sequential( nn.Conv2d(16, 8, 3, padding=1), nn.BatchNorm2d(8, momentum=0.99), nn.ReLU(inplace=True), nn.Conv2d(8, 3, 3, padding=1), nn.Tanh() ) def forward(self, noise, mask): if self.training: noise = noise + 0.01 * torch.randn_like(noise) # 0.02→0.01 mask_global = torch.mean(mask, dim=(2, 3)).expand(-1, self.latent_dim) noise = torch.cat([noise, mask_global], dim=-1) x = self.fc(noise).view(-1, 128, self.init_size, self.init_size) # 256→128 for block in self.res_blocks: x = block(x, mask) mask_feat = self.lesion_feat(mask) mask_feat = nn.functional.interpolate(mask_feat, size=x.shape[2:], mode='bilinear', align_corners=False) x = torch.cat([x, mask_feat], dim=1) x = self.upsample(x) return self.vein_refine(x) class BalancedDiscriminator(nn.Module): """阉割判别器:大幅减少通道+增加Dropout+降低层数""" def __init__(self): super().__init__() self.main = nn.Sequential( nn.Conv2d(3, 8, 4, 2, 1, bias=False), # 16→8 nn.LeakyReLU(0.2, inplace=True), nn.Dropout(0.3), # 新增Dropout nn.Conv2d(8, 16, 4, 2, 1, bias=False), # 32→16 nn.BatchNorm2d(16), nn.LeakyReLU(0.2, inplace=True), nn.Dropout(0.3), # 新增Dropout nn.Conv2d(16, 32, 4, 2, 1, bias=False), # 64→32 nn.BatchNorm2d(32), nn.LeakyReLU(0.2, inplace=True), nn.Dropout(0.3), # 新增Dropout nn.Conv2d(32, 1, 4, 1, 0, bias=False), # 去掉最后一层卷积 nn.AdaptiveAvgPool2d(1) ) def forward(self, img): output = self.main(img) return output.squeeze(-1).squeeze(-1) # ===================== 数据集 ===================== class GrapeBlackRotDataset(Dataset): def __init__(self, dataset_dirs, 
transform=None, cache_masks=True): self.images_dir = dataset_dirs["images"] self.masks_dir = dataset_dirs["masks"] self.transform = transform if transform else self._default_transform() self.cache_masks = cache_masks self.mask_cache_dir = ensure_dir(os.path.join(self.masks_dir, "cache_unet")) self.samples = self._load_all_samples() self._validate_masks() def _default_transform(self): return transforms.Compose([ transforms.Resize((256, 256)), transforms.RandomHorizontalFlip(p=0.2), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) ]) def _load_all_samples(self): samples = [] valid_ext = ('.png', '.jpg', '.jpeg') try: if not os.path.exists(self.images_dir): raise RuntimeError(f"图片目录不存在:{self.images_dir}") for filename in os.listdir(self.images_dir): if filename.lower().endswith(valid_ext): img_path = os.path.join(self.images_dir, filename) samples.append(img_path) except Exception as e: raise RuntimeError(f"加载数据集失败:{str(e)}") if len(samples) == 0: raise RuntimeError("未找到有效图片样本") print(f"成功加载 {len(samples)} 个葡萄黑腐病样本") return samples def _validate_masks(self): print("\n正在用UNet校验所有样本的掩码有效性...") valid_count = 0 step = max(1, len(self.samples) // 10) if len(self.samples) > 10 else 1 sample_subset = self.samples[::step] for idx, img_path in enumerate(tqdm(sample_subset, desc="掩码校验", mininterval=1.0)): try: mask, mask_np, img_np = self._get_lesion_mask_with_vis(img_path) valid_ratio, is_valid = check_mask_validity(mask) if is_valid: valid_count += 1 except Exception as e: continue valid_count = min(valid_count * step, len(self.samples)) valid_ratio = valid_count / len(self.samples) print(f"\n掩码有效性统计:{valid_count}/{len(self.samples)} 样本有有效病斑(占比 {valid_ratio:.2f})") if valid_count == 0: raise RuntimeError("所有样本掩码无效!请检查数据集是否包含黑腐病病斑") def _get_lesion_mask_with_vis(self, img_path): img_hash = get_file_hash(img_path) cache_path = os.path.join(self.mask_cache_dir, f"{img_hash}.pt") cache_vis_path = os.path.join(self.mask_cache_dir, f"{img_hash}_vis.npz") if self.cache_masks and os.path.exists(cache_path) and os.path.exists(cache_vis_path): try: mask = torch.load(cache_path, map_location='cpu', weights_only=True) data = np.load(cache_vis_path) mask = self._unify_mask_dim(mask) return mask, data['mask_np'], data['img_np'] except Exception as e: print(f"加载缓存失败,重新生成:{str(e)[:30]}") with Image.open(img_path).convert('RGB') as img: mask_tensor, mask_np, img_np = global_mask_generator.segment_lesion(img, return_cpu=True) mask = self._unify_mask_dim(mask_tensor) try: torch.save(mask, cache_path) np.savez(cache_vis_path, mask_np=mask_np, img_np=img_np) except Exception as e: print(f"保存缓存警告:{str(e)[:30]}") return mask, mask_np, img_np def _unify_mask_dim(self, mask): mask = mask.squeeze() if len(mask.shape) == 2: mask = mask.unsqueeze(0) if mask.shape[1:] != (256, 256): mask = nn.functional.interpolate( mask.unsqueeze(0).cpu(), size=(256, 256), mode='bilinear', align_corners=False ).squeeze(0) return mask def __len__(self): return len(self.samples) def __getitem__(self, idx): if idx < 0 or idx >= len(self): raise IndexError(f"索引 {idx} 超出数据集范围 [0, {len(self)-1}]") img_path = self.samples[idx] try: with Image.open(img_path).convert('RGB') as img: img_tensor = self.transform(img).cpu() lesion_mask, _, _ = self._get_lesion_mask_with_vis(img_path) lesion_mask = self._unify_mask_dim(lesion_mask).cpu() return img_tensor, lesion_mask, img_path except Exception as e: print(f"加载样本 {img_path} 失败:{str(e)[:50]},使用备用样本") return self.__getitem__(random.randint(0, len(self)-1)) # ===================== 
可视化 ===================== @torch.no_grad() def generate_lesion_comparison(real_imgs, gen_imgs, masks, epoch, save_dir): try: save_dir = ensure_dir(save_dir) num_samples = min(2, len(real_imgs), len(gen_imgs)) if num_samples == 0: return None real_imgs_norm = (real_imgs[:num_samples] + 1) / 2.0 gen_imgs_norm = (gen_imgs[:num_samples] + 1) / 2.0 for i in range(num_samples): gen_img = gen_imgs_norm[i] min_val = gen_img.min() max_val = gen_img.max() if max_val - min_val > 1e-4: gen_img = (gen_img - min_val) / (max_val - min_val) * 1.1 - 0.05 gen_imgs_norm[i] = torch.clamp(gen_img, 0, 1) comparison_list = [] for i in range(num_samples): blank = torch.ones_like(real_imgs_norm[i]) * 1.0 row = torch.cat([real_imgs_norm[i], blank[:, :, 0:20], gen_imgs_norm[i]], dim=2) comparison_list.append(row) comparison = torch.cat(comparison_list, dim=1) save_path = os.path.join(save_dir, f"epoch_{epoch:03d}_real_vs_gen.png") save_path = os.path.normpath(save_path) save_image(comparison, save_path, padding=30, normalize=True) print(f"对比图保存成功:{save_path}") return save_path except Exception as e: print(f"保存对比图失败:{str(e)} | 保存路径:{save_dir}") return None def plot_lesion_metrics(epoch, losses, metrics, save_dir): if epoch % 50 != 0: return try: save_dir = ensure_dir(save_dir) plt.figure(figsize=(15, 5)) plt.subplot(1, 3, 1) plt.plot(losses['D'], label='判别器损失', color='red', linewidth=1.5) plt.plot(losses['G'], label='生成器损失', color='blue', linewidth=1.5) plt.axhline(y=0.6, color='orange', linestyle='--', alpha=0.6, label='D_loss目标线') plt.axhline(y=2.5, color='green', linestyle='--', alpha=0.6, label='G_loss目标线') plt.legend(fontsize=10) plt.title(f'损失曲线 (Epoch {epoch})', fontsize=12) plt.xlabel('Epoch', fontsize=10) plt.ylabel('损失', fontsize=10) plt.grid(alpha=0.3) plt.subplot(1, 3, 2) plt.plot(metrics['f1'], label='病斑F1', color='green', linewidth=1.5) plt.plot(metrics['iou'], label='病斑IoU', color='purple', linewidth=1.5) plt.legend(fontsize=10) plt.title(f'生成质量指标', fontsize=12) plt.xlabel('Epoch', fontsize=10) plt.grid(alpha=0.3) plt.subplot(1, 3, 3) plt.plot(metrics['ssim'], label='SSIM(清晰度)', color='teal', linewidth=1.5, marker='.', markersize=4) plt.axhline(y=0.82, color='red', linestyle='--', alpha=0.6, label='脉络清晰阈值') plt.legend(fontsize=10) plt.title(f'脉络清晰度指标', fontsize=12) plt.xlabel('Epoch', fontsize=10) plt.grid(alpha=0.3) plt.tight_layout() save_path = os.path.join(save_dir, f"metrics_epoch_{epoch:03d}.png") save_path = os.path.normpath(save_path) plt.savefig(save_path, dpi=150) plt.close() except Exception as e: print(f"绘制指标图失败:{str(e)[:50]}") return # ===================== 全局配置 & 主函数(终极调参) ===================== torch.manual_seed(42) np.random.seed(42) device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') print(f"使用设备: {device}") if torch.cuda.is_available(): print(f"GPU名称: {torch.cuda.get_device_name(0)} | 显存: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB") global_mask_generator = UNetLesionMaskGenerator(device=device) # 终极调参:彻底平衡D/G batch_size = 4 latent_dim = 100 epochs = 600 lr_G = 2e-4 # 大幅提高生成器学习率(1.2e-4→2e-4) lr_D = 5e-6 # 大幅降低判别器学习率(2e-5→5e-6) save_interval = 1 num_workers = 0 max_grad_norm = 1.0 # 完全放开梯度裁剪(0.6→1.0) root_output_dir = ensure_dir(os.path.join(os.getcwd(), "lesion_gan", "grape_black_rot_optimized")) output_dirs = { "images": ensure_dir(os.path.join(root_output_dir, "comparison")), "models": ensure_dir(os.path.join(root_output_dir, "models")), "metrics": ensure_dir(os.path.join(root_output_dir, "metrics")) } checkpoint_path = 
os.path.normpath(os.path.join(output_dirs["models"], "lesion_model_optimized.pth")) dataset_dirs = { "images": r"D:\code\pythonroad\root_dir\Grape___Black_rot", "masks": ensure_dir(r"D:\plant_disease\grape_black_rot_masks_v6") } def main(): try: dataset = GrapeBlackRotDataset(dataset_dirs) print(f"数据集总样本数: {len(dataset)}") dataloader = DataLoader( dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True, pin_memory=True ) print(f"\nDataLoader配置:batch_size={batch_size}, num_workers={num_workers}") print(f"DataLoader批次数量: {len(dataloader)}") print(f"训练轮数:{epochs} 轮(预计总时长:{epochs * 1.0 / 60:.1f} 小时)") generator = LesionFocusGenerator(latent_dim=latent_dim, img_size=256).to(device) discriminator = BalancedDiscriminator().to(device) # 生成器用更高的学习率,判别器用极低学习率 optimizer_G = optim.Adam(generator.parameters(), lr=lr_G, betas=(0.5, 0.999)) optimizer_D = optim.Adam(discriminator.parameters(), lr=lr_D, betas=(0.5, 0.999)) # 判别器学习率快速衰减,生成器缓慢衰减 scheduler_G = CosineAnnealingWarmRestarts(optimizer_G, T_0=200, T_mult=2, eta_min=5e-6) scheduler_D = CosineAnnealingWarmRestarts(optimizer_D, T_0=50, T_mult=2, eta_min=1e-7) scaler = GradScaler(enabled=True, init_scale=2.**16) bce = nn.BCEWithLogitsLoss() loss_fn = LesionFocusLoss(device=device) start_epoch = 0 losses = {'D': [], 'G': []} metrics = {'f1': [], 'iou': [], 'ssim': []} # 跳过加载旧检查点(结构已改) print("跳过加载旧检查点,使用全新模型从头训练") for epoch in range(start_epoch, epochs): generator.train() discriminator.train() epoch_loss_D = 0.0 epoch_loss_G = 0.0 epoch_f1 = 0.0 epoch_iou = 0.0 epoch_ssim = 0.0 batch_count = 0 pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch}/{epochs}", mininterval=2.0, ncols=100) for batch_idx, (real_imgs, masks, _) in pbar: real_imgs = real_imgs.to(device, non_blocking=True) masks = masks.to(device, non_blocking=True) batch_size_curr = real_imgs.shape[0] batch_count += 1 # 超软标签:几乎不给判别器置信度 real_label = torch.full((batch_size_curr, 1), 0.95, device=device) # 0.8→0.95 fake_label = torch.full((batch_size_curr, 1), 0.05, device=device) # 0.2→0.05 # 训练判别器:极低频(每4个batch训练1次) if batch_idx % 4 == 0: optimizer_D.zero_grad(set_to_none=True) with torch.no_grad(): noise = torch.randn(batch_size_curr, latent_dim, device=device) fake_imgs = generator(noise, masks).detach() with autocast(): real_pred = discriminator(real_imgs) fake_pred = discriminator(fake_imgs) # 判别器损失减半 loss_D = (bce(real_pred, real_label) + bce(fake_pred, fake_label)) * 0.25 scaler.scale(loss_D).backward() scaler.unscale_(optimizer_D) torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_grad_norm) scaler.step(optimizer_D) scaler.update() epoch_loss_D += loss_D.item() * 4 # 补偿训练频率 # 训练生成器:每次batch训练2次(增强生成器) for _ in range(2): optimizer_G.zero_grad(set_to_none=True) with autocast(): noise = torch.randn(batch_size_curr, latent_dim, device=device) fake_imgs = generator(noise, masks) fake_pred = discriminator(fake_imgs) # 对抗损失权重拉满 adv_loss = bce(fake_pred, real_label) * 1.0 # 0.8→1.0 content_loss = loss_fn(fake_imgs, masks, real_imgs) # 生成器损失:优先对抗,次要内容 loss_G = adv_loss * 0.8 + content_loss * 0.2 scaler.scale(loss_G).backward() scaler.unscale_(optimizer_G) torch.nn.utils.clip_grad_norm_(generator.parameters(), max_grad_norm) scaler.step(optimizer_G) scaler.update() epoch_loss_G += loss_G.item() if batch_idx % 10 == 0: avg_loss_D = epoch_loss_D / max(batch_count//4, 1) avg_loss_G = epoch_loss_G / batch_count pbar.set_postfix({ 'D_loss': f"{avg_loss_D:.3f}", 'G_loss': f"{min(avg_loss_G, 10.0):.3f}", 'F1': f"{epoch_f1:.3f}", 'IoU': 
f"{epoch_iou:.3f}", 'SSIM': f"{epoch_ssim:.3f}" }, refresh=True) seg_metrics = global_mask_generator.calculate_precision_metrics(real_imgs, fake_imgs) if seg_metrics["valid"]: epoch_f1 = seg_metrics["pixel"]["f1"] epoch_iou = seg_metrics["pixel"]["iou"] epoch_ssim = seg_metrics["pixel"]["ssim"] avg_loss_D = epoch_loss_D / max(len(dataloader)//4, 1) avg_loss_G = epoch_loss_G / len(dataloader) losses['D'].append(min(avg_loss_D, 10.0)) losses['G'].append(min(avg_loss_G, 10.0)) metrics['f1'].append(epoch_f1) metrics['iou'].append(epoch_iou) metrics['ssim'].append(epoch_ssim) print(f"\nEpoch {epoch} | D_loss: {avg_loss_D:.3f} | G_loss: {avg_loss_G:.3f}") print(f"F1: {epoch_f1:.3f} | IoU: {epoch_iou:.3f} | SSIM: {epoch_ssim:.3f}") print(f"学习率:G={scheduler_G.get_last_lr()[0]:.6f}, D={scheduler_D.get_last_lr()[0]:.7f}") scheduler_G.step() scheduler_D.step() generator.eval() with torch.no_grad(): noise = torch.randn(batch_size_curr, latent_dim, device=device) fake_imgs = generator(noise, masks) generate_lesion_comparison(real_imgs, fake_imgs, masks, epoch, output_dirs["images"]) generator.train() if (epoch + 1) % 10 == 0: plot_lesion_metrics(epoch, losses, metrics, output_dirs["metrics"]) try: checkpoint = { 'epoch': epoch, 'gen': generator.state_dict(), 'dis': discriminator.state_dict(), 'opt_g': optimizer_G.state_dict(), 'opt_d': optimizer_D.state_dict(), 'losses': losses, 'metrics': metrics } torch.save(checkpoint, checkpoint_path) if epoch > 50 and epoch_f1 > 0.8 and epoch_iou > 0.7 and epoch_ssim > 0.5: best_path = os.path.normpath( os.path.join(output_dirs["models"], f"best_model_epoch_{epoch}_f1_{epoch_f1:.3f}_ssim_{epoch_ssim:.3f}.pth")) torch.save(checkpoint, best_path) print(f"保存最佳模型至:{best_path}") except Exception as e: print(f"保存模型警告:{str(e)[:50]}") torch.cuda.empty_cache() print("\n训练完成!") print(f"对比图路径:{output_dirs['images']}") print(f"最终模型路径:{checkpoint_path}") print(f"指标图路径:{output_dirs['metrics']}") except KeyboardInterrupt: print("\n用户中断训练,保存当前进度...") torch.save({ 'epoch': epoch, 'gen': generator.state_dict(), 'dis': discriminator.state_dict(), 'losses': losses, 'metrics': metrics }, checkpoint_path.replace(".pth", "_interrupt.pth")) except Exception as e: print(f"训练异常终止:{str(e)}") torch.cuda.empty_cache() raise if __name__ == "__main__": main()代码生成的图像有网格化,不清晰的毛病
最新发布
12-01
# Ultralytics YOLOv5 🚀, AGPL-3.0 license """ Train a YOLOv5 model on a custom dataset. Models and datasets download automatically from the latest YOLOv5 release. Usage - Single-GPU training: $ python train.py --data coco128.yaml --weights yolov5s.pt --img 640 # from pretrained (recommended) $ python train.py --data coco128.yaml --weights '' --cfg yolov5s.yaml --img 640 # from scratch Usage - Multi-GPU DDP training: $ python -m torch.distributed.run --nproc_per_node 4 --master_port 1 train.py --data coco128.yaml --weights yolov5s.pt --img 640 --device 0,1,2,3 Models: https://github.com/ultralytics/yolov5/tree/master/models Datasets: https://github.com/ultralytics/yolov5/tree/master/data Tutorial: https://docs.ultralytics.com/yolov5/tutorials/train_custom_data """ import argparse import math import os os.environ["GIT_PYTHON_REFRESH"] = "quiet" os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" import random import subprocess import sys import time from copy import deepcopy from datetime import datetime, timedelta from pathlib import Path try: import comet_ml # must be imported before torch (if installed) except ImportError: comet_ml = None import numpy as np import torch import torch.distributed as dist import torch.nn as nn import yaml from torch.optim import lr_scheduler from tqdm import tqdm FILE = Path(__file__).resolve() ROOT = FILE.parents[0] # YOLOv5 root directory if str(ROOT) not in sys.path: sys.path.append(str(ROOT)) # add ROOT to PATH ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative import val as validate # for end-of-epoch mAP from models.experimental import attempt_load from models.yolo import Model from utils.autoanchor import check_anchors from utils.autobatch import check_train_batch_size from utils.callbacks import Callbacks from utils.dataloaders import create_dataloader from utils.downloads import attempt_download, is_url from utils.general import ( LOGGER, TQDM_BAR_FORMAT, check_amp, check_dataset, check_file, check_git_info, check_git_status, check_img_size, check_requirements, check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, one_cycle, print_args, print_mutation, strip_optimizer, yaml_save, ) from utils.loggers import LOGGERS, Loggers from utils.loggers.comet.comet_utils import check_comet_resume from utils.loss import ComputeLoss from utils.metrics import fitness from utils.plots import plot_evolve from utils.torch_utils import ( EarlyStopping, ModelEMA, de_parallel, select_device, smart_DDP, smart_optimizer, smart_resume, torch_distributed_zero_first, ) LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html RANK = int(os.getenv("RANK", -1)) WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1)) GIT_INFO = check_git_info() def train(hyp, opt, device, callbacks): """ Train a YOLOv5 model on a custom dataset using specified hyperparameters, options, and device, managing datasets, model architecture, loss computation, and optimizer steps. Args: hyp (str | dict): Path to the hyperparameters YAML file or a dictionary of hyperparameters. opt (argparse.Namespace): Parsed command-line arguments containing training options. device (torch.device): Device on which training occurs, e.g., 'cuda' or 'cpu'. callbacks (Callbacks): Callback functions for various training events. Returns: None Models and datasets download automatically from the latest YOLOv5 release. 
    Example:
        Single-GPU training:
        ```bash
        $ python train.py --data coco128.yaml --weights yolov5s.pt --img 640  # from pretrained (recommended)
        $ python train.py --data coco128.yaml --weights '' --cfg yolov5s.yaml --img 640  # from scratch
        ```

        Multi-GPU DDP training:
        ```bash
        $ python -m torch.distributed.run --nproc_per_node 4 --master_port 1 train.py --data coco128.yaml --weights yolov5s.pt --img 640 --device 0,1,2,3
        ```

        For more usage details, refer to:
        - Models: https://github.com/ultralytics/yolov5/tree/master/models
        - Datasets: https://github.com/ultralytics/yolov5/tree/master/data
        - Tutorial: https://docs.ultralytics.com/yolov5/tutorials/train_custom_data
    """
    save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze = (
        Path(opt.save_dir),
        opt.epochs,
        opt.batch_size,
        opt.weights,
        opt.single_cls,
        opt.evolve,
        opt.data,
        opt.cfg,
        opt.resume,
        opt.noval,
        opt.nosave,
        opt.workers,
        opt.freeze,
    )
    callbacks.run("on_pretrain_routine_start")

    # Directories
    w = save_dir / "weights"  # weights dir
    (w.parent if evolve else w).mkdir(parents=True, exist_ok=True)  # make dir
    last, best = w / "last.pt", w / "best.pt"

    # Hyperparameters
    if isinstance(hyp, str):
        with open(hyp, errors="ignore") as f:
            hyp = yaml.safe_load(f)  # load hyps dict
    LOGGER.info(colorstr("hyperparameters: ") + ", ".join(f"{k}={v}" for k, v in hyp.items()))
    opt.hyp = hyp.copy()  # for saving hyps to checkpoints

    # Save run settings
    if not evolve:
        yaml_save(save_dir / "hyp.yaml", hyp)
        yaml_save(save_dir / "opt.yaml", vars(opt))

    # Loggers
    data_dict = None
    if RANK in {-1, 0}:
        include_loggers = list(LOGGERS)
        if getattr(opt, "ndjson_console", False):
            include_loggers.append("ndjson_console")
        if getattr(opt, "ndjson_file", False):
            include_loggers.append("ndjson_file")

        loggers = Loggers(
            save_dir=save_dir,
            weights=weights,
            opt=opt,
            hyp=hyp,
            logger=LOGGER,
            include=tuple(include_loggers),
        )

        # Register actions
        for k in methods(loggers):
            callbacks.register_action(k, callback=getattr(loggers, k))

        # Process custom dataset artifact link
        data_dict = loggers.remote_dataset
        if resume:  # If resuming runs from remote artifact
            weights, epochs, hyp, batch_size = opt.weights, opt.epochs, opt.hyp, opt.batch_size

    # Config
    plots = not evolve and not opt.noplots  # create plots
    cuda = device.type != "cpu"
    init_seeds(opt.seed + 1 + RANK, deterministic=True)
    with torch_distributed_zero_first(LOCAL_RANK):
        data_dict = data_dict or check_dataset(data)  # check if None
    train_path, val_path = data_dict["train"], data_dict["val"]
    nc = 1 if single_cls else int(data_dict["nc"])  # number of classes
    names = {0: "item"} if single_cls and len(data_dict["names"]) != 1 else data_dict["names"]  # class names
    is_coco = isinstance(val_path, str) and val_path.endswith("coco/val2017.txt")  # COCO dataset

    # Model
    check_suffix(weights, ".pt")  # check weights
    pretrained = weights.endswith(".pt")
    if pretrained:
        with torch_distributed_zero_first(LOCAL_RANK):
            weights = attempt_download(weights)  # download if not found locally
        ckpt = torch.load(weights, map_location="cpu")  # load checkpoint to CPU to avoid CUDA memory leak
        model = Model(cfg or ckpt["model"].yaml, ch=3, nc=nc, anchors=hyp.get("anchors")).to(device)  # create
        exclude = ["anchor"] if (cfg or hyp.get("anchors")) and not resume else []  # exclude keys
        csd = ckpt["model"].float().state_dict()  # checkpoint state_dict as FP32
        csd = intersect_dicts(csd, model.state_dict(), exclude=exclude)  # intersect
        model.load_state_dict(csd, strict=False)  # load
        LOGGER.info(f"Transferred {len(csd)}/{len(model.state_dict())} items from {weights}")  # report
    else:
        model = Model(cfg, ch=3, nc=nc, anchors=hyp.get("anchors")).to(device)  # create
    amp = check_amp(model)  # check AMP

    # Freeze
    freeze = [f"model.{x}." for x in (freeze if len(freeze) > 1 else range(freeze[0]))]  # layers to freeze
    for k, v in model.named_parameters():
        v.requires_grad = True  # train all layers
        # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
        if any(x in k for x in freeze):
            LOGGER.info(f"freezing {k}")
            v.requires_grad = False

    # Image size
    gs = max(int(model.stride.max()), 32)  # grid size (max stride)
    imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2)  # verify imgsz is gs-multiple

    # Batch size
    if RANK == -1 and batch_size == -1:  # single-GPU only, estimate best batch size
        batch_size = check_train_batch_size(model, imgsz, amp)
        loggers.on_params_update({"batch_size": batch_size})

    # Optimizer
    nbs = 64  # nominal batch size
    accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
    hyp["weight_decay"] *= batch_size * accumulate / nbs  # scale weight_decay
    optimizer = smart_optimizer(model, opt.optimizer, hyp["lr0"], hyp["momentum"], hyp["weight_decay"])

    # Scheduler
    if opt.cos_lr:
        lf = one_cycle(1, hyp["lrf"], epochs)  # cosine 1->hyp['lrf']
    else:

        def lf(x):
            """Linear learning rate scheduler function with decay calculated by epoch proportion."""
            return (1 - x / epochs) * (1.0 - hyp["lrf"]) + hyp["lrf"]  # linear

    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)  # plot_lr_scheduler(optimizer, scheduler, epochs)

    # EMA
    ema = ModelEMA(model) if RANK in {-1, 0} else None

    # Resume
    best_fitness, start_epoch = 0.0, 0
    if pretrained:
        if resume:
            best_fitness, start_epoch, epochs = smart_resume(ckpt, optimizer, ema, weights, epochs, resume)
        del ckpt, csd

    # DP mode
    if cuda and RANK == -1 and torch.cuda.device_count() > 1:
        LOGGER.warning(
            "WARNING ⚠️ DP not recommended, use torch.distributed.run for best DDP Multi-GPU results.\n"
            "See Multi-GPU Tutorial at https://docs.ultralytics.com/yolov5/tutorials/multi_gpu_training to get started."
        )
        model = torch.nn.DataParallel(model)

    # SyncBatchNorm
    if opt.sync_bn and cuda and RANK != -1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
        LOGGER.info("Using SyncBatchNorm()")

    # Trainloader
    train_loader, dataset = create_dataloader(
        train_path,
        imgsz,
        batch_size // WORLD_SIZE,
        gs,
        single_cls,
        hyp=hyp,
        augment=True,
        cache=None if opt.cache == "val" else opt.cache,
        rect=opt.rect,
        rank=LOCAL_RANK,
        workers=workers,
        image_weights=opt.image_weights,
        quad=opt.quad,
        prefix=colorstr("train: "),
        shuffle=True,
        seed=opt.seed,
    )
    labels = np.concatenate(dataset.labels, 0)
    mlc = int(labels[:, 0].max())  # max label class
    assert mlc < nc, f"Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}"

    # Process 0
    if RANK in {-1, 0}:
        val_loader = create_dataloader(
            val_path,
            imgsz,
            batch_size // WORLD_SIZE * 2,
            gs,
            single_cls,
            hyp=hyp,
            cache=None if noval else opt.cache,
            rect=True,
            rank=-1,
            workers=workers * 2,
            pad=0.5,
            prefix=colorstr("val: "),
        )[0]

        if not resume:
            if not opt.noautoanchor:
                check_anchors(dataset, model=model, thr=hyp["anchor_t"], imgsz=imgsz)  # run AutoAnchor
            model.half().float()  # pre-reduce anchor precision

        callbacks.run("on_pretrain_routine_end", labels, names)

    # DDP mode
    if cuda and RANK != -1:
        model = smart_DDP(model)

    # Model attributes
    nl = de_parallel(model).model[-1].nl  # number of detection layers (to scale hyps)
    hyp["box"] *= 3 / nl  # scale to layers
    hyp["cls"] *= nc / 80 * 3 / nl  # scale to classes and layers
    hyp["obj"] *= (imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
    hyp["label_smoothing"] = opt.label_smoothing
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
    model.names = names

    # Start training
    t0 = time.time()
    nb = len(train_loader)  # number of batches
    nw = max(round(hyp["warmup_epochs"] * nb), 100)  # number of warmup iterations, max(3 epochs, 100 iterations)
    # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
    last_opt_step = -1
    maps = np.zeros(nc)  # mAP per class
    results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
    scheduler.last_epoch = start_epoch - 1  # do not move
    scaler = torch.cuda.amp.GradScaler(enabled=amp)
    stopper, stop = EarlyStopping(patience=opt.patience), False
    compute_loss = ComputeLoss(model)  # init loss class
    callbacks.run("on_train_start")
    LOGGER.info(
        f'Image sizes {imgsz} train, {imgsz} val\n'
        f'Using {train_loader.num_workers * WORLD_SIZE} dataloader workers\n'
        f"Logging results to {colorstr('bold', save_dir)}\n"
        f'Starting training for {epochs} epochs...'
    )
    for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
        callbacks.run("on_train_epoch_start")
        model.train()

        # Update image weights (optional, single-GPU only)
        if opt.image_weights:
            cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc  # class weights
            iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # image weights
            dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # rand weighted idx

        # Update mosaic border (optional)
        # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
        # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

        mloss = torch.zeros(3, device=device)  # mean losses
        if RANK != -1:
            train_loader.sampler.set_epoch(epoch)
        pbar = enumerate(train_loader)
        LOGGER.info(("\n" + "%11s" * 7) % ("Epoch", "GPU_mem", "box_loss", "obj_loss", "cls_loss", "Instances", "Size"))
        if RANK in {-1, 0}:
            pbar = tqdm(pbar, total=nb, bar_format=TQDM_BAR_FORMAT)  # progress bar
        optimizer.zero_grad()
        for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
            callbacks.run("on_train_batch_start")
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device, non_blocking=True).float() / 255  # uint8 to float32, 0-255 to 0.0-1.0

            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
                accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
                for j, x in enumerate(optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x["lr"] = np.interp(ni, xi, [hyp["warmup_bias_lr"] if j == 0 else 0.0, x["initial_lr"] * lf(epoch)])
                    if "momentum" in x:
                        x["momentum"] = np.interp(ni, xi, [hyp["warmup_momentum"], hyp["momentum"]])

            # Multi-scale
            if opt.multi_scale:
                sz = random.randrange(int(imgsz * 0.5), int(imgsz * 1.5) + gs) // gs * gs  # size
                sf = sz / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
                    imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)

            # Forward
            with torch.cuda.amp.autocast(amp):
                pred = model(imgs)  # forward
                loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
                if RANK != -1:
                    loss *= WORLD_SIZE  # gradient averaged between devices in DDP mode
                if opt.quad:
                    loss *= 4.0

            # Backward
            scaler.scale(loss).backward()

            # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
            if ni - last_opt_step >= accumulate:
                scaler.unscale_(optimizer)  # unscale gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)  # clip gradients
                scaler.step(optimizer)  # optimizer.step
                scaler.update()
                optimizer.zero_grad()
                if ema:
                    ema.update(model)
                last_opt_step = ni

            # Log
            if RANK in {-1, 0}:
                mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
                mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G"  # (GB)
                pbar.set_description(
                    ("%11s" * 2 + "%11.4g" * 5)
                    % (f"{epoch}/{epochs - 1}", mem, *mloss, targets.shape[0], imgs.shape[-1])
                )
                callbacks.run("on_train_batch_end", model, ni, imgs, targets, paths, list(mloss))
                if callbacks.stop_training:
                    return
            # end batch ------------------------------------------------------------------------------------------------

        # Scheduler
        lr = [x["lr"] for x in optimizer.param_groups]  # for loggers
        scheduler.step()

        if RANK in {-1, 0}:
            # mAP
            callbacks.run("on_train_epoch_end", epoch=epoch)
            ema.update_attr(model, include=["yaml", "nc", "hyp", "names", "stride", "class_weights"])
            final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
            if not noval or final_epoch:  # Calculate mAP
                results, maps, _ = validate.run(
                    data_dict,
                    batch_size=batch_size // WORLD_SIZE * 2,
                    imgsz=imgsz,
                    half=amp,
                    model=ema.ema,
                    single_cls=single_cls,
                    dataloader=val_loader,
                    save_dir=save_dir,
                    plots=False,
                    callbacks=callbacks,
                    compute_loss=compute_loss,
                )

            # Update best mAP
            fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
            stop = stopper(epoch=epoch, fitness=fi)  # early stop check
            if fi > best_fitness:
                best_fitness = fi
            log_vals = list(mloss) + list(results) + lr
            callbacks.run("on_fit_epoch_end", log_vals, epoch, best_fitness, fi)

            # Save model
            if (not nosave) or (final_epoch and not evolve):  # if save
                ckpt = {
                    "epoch": epoch,
                    "best_fitness": best_fitness,
                    "model": deepcopy(de_parallel(model)).half(),
                    "ema": deepcopy(ema.ema).half(),
                    "updates": ema.updates,
                    "optimizer": optimizer.state_dict(),
                    "opt": vars(opt),
                    "git": GIT_INFO,  # {remote, branch, commit} if a git repo
                    "date": datetime.now().isoformat(),
                }

                # Save last, best and delete
                torch.save(ckpt, last)
                if best_fitness == fi:
                    torch.save(ckpt, best)
                if opt.save_period > 0 and epoch % opt.save_period == 0:
                    torch.save(ckpt, w / f"epoch{epoch}.pt")
                del ckpt
                callbacks.run("on_model_save", last, epoch, final_epoch, best_fitness, fi)

        # EarlyStopping
        if RANK != -1:  # if DDP training
            broadcast_list = [stop if RANK == 0 else None]
            dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
            if RANK != 0:
                stop = broadcast_list[0]
        if stop:
            break  # must break all DDP ranks

        # end epoch ----------------------------------------------------------------------------------------------------
    # end training -----------------------------------------------------------------------------------------------------
    if RANK in {-1, 0}:
        LOGGER.info(f"\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.")
        for f in last, best:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
                if f is best:
                    LOGGER.info(f"\nValidating {f}...")
                    results, _, _ = validate.run(
                        data_dict,
                        batch_size=batch_size // WORLD_SIZE * 2,
                        imgsz=imgsz,
                        model=attempt_load(f, device).half(),
                        iou_thres=0.65 if is_coco else 0.60,  # best pycocotools at iou 0.65
                        single_cls=single_cls,
                        dataloader=val_loader,
                        save_dir=save_dir,
                        save_json=is_coco,
                        verbose=True,
                        plots=plots,
                        callbacks=callbacks,
                        compute_loss=compute_loss,
                    )  # val best model with plots
                    if is_coco:
                        callbacks.run("on_fit_epoch_end", list(mloss) + list(results) + lr, epoch, best_fitness, fi)

        callbacks.run("on_train_end", last, best, epoch, results)

    torch.cuda.empty_cache()
    return results


def parse_opt(known=False):
    """
    Parse command-line arguments for YOLOv5 training, validation, and testing.

    Args:
        known (bool, optional): If True, parses known arguments, ignoring the unknown. Defaults to False.

    Returns:
        (argparse.Namespace): Parsed command-line arguments containing options for YOLOv5 execution.

    Example:
        ```python
        from ultralytics.yolo import parse_opt

        opt = parse_opt()
        print(opt)
        ```

    Links:
        - Models: https://github.com/ultralytics/yolov5/tree/master/models
        - Datasets: https://github.com/ultralytics/yolov5/tree/master/data
        - Tutorial: https://docs.ultralytics.com/yolov5/tutorials/train_custom_data
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--weights", type=str, default=ROOT / r"E:/yolov5-master/yolov5m.pt", help="initial weights path")
    parser.add_argument("--cfg", type=str, default=r"models/yolov5m.yaml", help="model.yaml path")
    parser.add_argument("--data", type=str, default=ROOT / r"E:/yolov5-master/data/data.yaml", help="dataset.yaml path")
    parser.add_argument("--hyp", type=str, default=ROOT / "data/hyps/hyp.scratch-low.yaml", help="hyperparameters path")
    parser.add_argument("--epochs", type=int, default=150, help="total training epochs")
    parser.add_argument("--batch-size", type=int, default=16, help="total batch size for all GPUs, -1 for autobatch")
    parser.add_argument("--imgsz", "--img", "--img-size", type=int, default=640, help="train, val image size (pixels)")
    parser.add_argument("--rect", action="store_true", help="rectangular training")
    parser.add_argument("--resume", nargs="?", const=True, default=False, help="resume most recent training")
    parser.add_argument("--nosave", action="store_true", help="only save final checkpoint")
    parser.add_argument("--noval", action="store_true", help="only validate final epoch")
    parser.add_argument("--noautoanchor", action="store_true", help="disable AutoAnchor")
    parser.add_argument("--noplots", action="store_true", help="save no plot files")
    parser.add_argument("--evolve", type=int, nargs="?", const=300, help="evolve hyperparameters for x generations")
    parser.add_argument(
        "--evolve_population", type=str, default=ROOT / "data/hyps", help="location for loading population"
    )
    parser.add_argument("--resume_evolve", type=str, default=None, help="resume evolve from last generation")
    parser.add_argument("--bucket", type=str, default="", help="gsutil bucket")
    parser.add_argument("--cache", type=str, nargs="?", const="ram", help="image --cache ram/disk")
    parser.add_argument("--image-weights", action="store_true", help="use weighted image selection for training")
    parser.add_argument("--device", default="", help="cuda device, i.e. 0 or 0,1,2,3 or cpu")
    parser.add_argument("--multi-scale", action="store_true", help="vary img-size +/- 50%%")
    parser.add_argument("--single-cls", action="store_true", help="train multi-class data as single-class")
    parser.add_argument("--optimizer", type=str, choices=["SGD", "Adam", "AdamW"], default="SGD", help="optimizer")
    parser.add_argument("--sync-bn", action="store_true", help="use SyncBatchNorm, only available in DDP mode")
    parser.add_argument("--workers", type=int, default=8, help="max dataloader workers (per RANK in DDP mode)")
    parser.add_argument("--project", default=ROOT / "runs/train", help="save to project/name")
    parser.add_argument("--name", default="exp", help="save to project/name")
    parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment")
    parser.add_argument("--quad", action="store_true", help="quad dataloader")
    parser.add_argument("--cos-lr", action="store_true", help="cosine LR scheduler")
    parser.add_argument("--label-smoothing", type=float, default=0.0, help="Label smoothing epsilon")
    parser.add_argument("--patience", type=int, default=100, help="EarlyStopping patience (epochs without improvement)")
    parser.add_argument("--freeze", nargs="+", type=int, default=[0], help="Freeze layers: backbone=10, first3=0 1 2")
    parser.add_argument("--save-period", type=int, default=-1, help="Save checkpoint every x epochs (disabled if < 1)")
    parser.add_argument("--seed", type=int, default=0, help="Global training seed")
    parser.add_argument("--local_rank", type=int, default=-1, help="Automatic DDP Multi-GPU argument, do not modify")

    # Logger arguments
    parser.add_argument("--entity", default=None, help="Entity")
    parser.add_argument("--upload_dataset", nargs="?", const=True, default=False, help='Upload data, "val" option')
    parser.add_argument("--bbox_interval", type=int, default=-1, help="Set bounding-box image logging interval")
    parser.add_argument("--artifact_alias", type=str, default="latest", help="Version of dataset artifact to use")

    # NDJSON logging
    parser.add_argument("--ndjson-console", action="store_true", help="Log ndjson to console")
    parser.add_argument("--ndjson-file", action="store_true", help="Log ndjson to file")

    return parser.parse_known_args()[0] if known else parser.parse_args()


def main(opt, callbacks=Callbacks()):
    """
    Runs the main entry point for training or hyperparameter evolution with specified options and optional callbacks.

    Args:
        opt (argparse.Namespace): The command-line arguments parsed for YOLOv5 training and evolution.
        callbacks (ultralytics.utils.callbacks.Callbacks, optional): Callback functions for various training stages.
            Defaults to Callbacks().

    Returns:
        None

    Note:
        For detailed usage, refer to:
        https://github.com/ultralytics/yolov5/tree/master/models
    """
    if RANK in {-1, 0}:
        print_args(vars(opt))
        check_git_status()
        check_requirements(ROOT / "requirements.txt")

    # Resume (from specified or most recent last.pt)
    if opt.resume and not check_comet_resume(opt) and not opt.evolve:
        last = Path(check_file(opt.resume) if isinstance(opt.resume, str) else get_latest_run())
        opt_yaml = last.parent.parent / "opt.yaml"  # train options yaml
        opt_data = opt.data  # original dataset
        if opt_yaml.is_file():
            with open(opt_yaml, errors="ignore") as f:
                d = yaml.safe_load(f)
        else:
            d = torch.load(last, map_location="cpu")["opt"]
        opt = argparse.Namespace(**d)  # replace
        opt.cfg, opt.weights, opt.resume = "", str(last), True  # reinstate
        if is_url(opt_data):
            opt.data = check_file(opt_data)  # avoid HUB resume auth timeout
    else:
        opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = (
            check_file(opt.data),
            check_yaml(opt.cfg),
            check_yaml(opt.hyp),
            str(opt.weights),
            str(opt.project),
        )  # checks
        assert len(opt.cfg) or len(opt.weights), "either --cfg or --weights must be specified"
        if opt.evolve:
            if opt.project == str(ROOT / "runs/train"):  # if default project name, rename to runs/evolve
                opt.project = str(ROOT / "runs/evolve")
            opt.exist_ok, opt.resume = opt.resume, False  # pass resume to exist_ok and disable resume
        if opt.name == "cfg":
            opt.name = Path(opt.cfg).stem  # use model.yaml as name
        opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))

    # DDP mode
    device = select_device(opt.device, batch_size=opt.batch_size)
    if LOCAL_RANK != -1:
        msg = "is not compatible with YOLOv5 Multi-GPU DDP training"
        assert not opt.image_weights, f"--image-weights {msg}"
        assert not opt.evolve, f"--evolve {msg}"
        assert opt.batch_size != -1, f"AutoBatch with --batch-size -1 {msg}, please pass a valid --batch-size"
        assert opt.batch_size % WORLD_SIZE == 0, f"--batch-size {opt.batch_size} must be multiple of WORLD_SIZE"
        assert torch.cuda.device_count() > LOCAL_RANK, "insufficient CUDA devices for DDP command"
        torch.cuda.set_device(LOCAL_RANK)
        device = torch.device("cuda", LOCAL_RANK)
        dist.init_process_group(
            backend="nccl" if dist.is_nccl_available() else "gloo", timeout=timedelta(seconds=10800)
        )

    # Train
    if not opt.evolve:
        train(opt.hyp, opt, device, callbacks)

    # Evolve hyperparameters (optional)
    else:
        # Hyperparameter evolution metadata (including this hyperparameter True-False, lower_limit, upper_limit)
        meta = {
            "lr0": (False, 1e-5, 1e-1),  # initial learning rate (SGD=1E-2, Adam=1E-3)
            "lrf": (False, 0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
            "momentum": (False, 0.6, 0.98),  # SGD momentum/Adam beta1
            "weight_decay": (False, 0.0, 0.001),  # optimizer weight decay
            "warmup_epochs": (False, 0.0, 5.0),  # warmup epochs (fractions ok)
            "warmup_momentum": (False, 0.0, 0.95),  # warmup initial momentum
            "warmup_bias_lr": (False, 0.0, 0.2),  # warmup initial bias lr
            "box": (False, 0.02, 0.2),  # box loss gain
            "cls": (False, 0.2, 4.0),  # cls loss gain
            "cls_pw": (False, 0.5, 2.0),  # cls BCELoss positive_weight
            "obj": (False, 0.2, 4.0),  # obj loss gain (scale with pixels)
            "obj_pw": (False, 0.5, 2.0),  # obj BCELoss positive_weight
            "iou_t": (False, 0.1, 0.7),  # IoU training threshold
            "anchor_t": (False, 2.0, 8.0),  # anchor-multiple threshold
            "anchors": (False, 2.0, 10.0),  # anchors per output grid (0 to ignore)
            "fl_gamma": (False, 0.0, 2.0),  # focal loss gamma (efficientDet default gamma=1.5)
            "hsv_h": (True, 0.0, 0.1),  # image HSV-Hue augmentation (fraction)
            "hsv_s": (True, 0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
            "hsv_v": (True, 0.0, 0.9),  # image HSV-Value augmentation (fraction)
            "degrees": (True, 0.0, 45.0),  # image rotation (+/- deg)
            "translate": (True, 0.0, 0.9),  # image translation (+/- fraction)
            "scale": (True, 0.0, 0.9),  # image scale (+/- gain)
            "shear": (True, 0.0, 10.0),  # image shear (+/- deg)
            "perspective": (True, 0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
            "flipud": (True, 0.0, 1.0),  # image flip up-down (probability)
            "fliplr": (True, 0.0, 1.0),  # image flip left-right (probability)
            "mosaic": (True, 0.0, 1.0),  # image mosaic (probability)
            "mixup": (True, 0.0, 1.0),  # image mixup (probability)
            "copy_paste": (True, 0.0, 1.0),  # segment copy-paste (probability)
        }

        # GA configs
        pop_size = 50
        mutation_rate_min = 0.01
        mutation_rate_max = 0.5
        crossover_rate_min = 0.5
        crossover_rate_max = 1
        min_elite_size = 2
        max_elite_size = 5
        tournament_size_min = 2
        tournament_size_max = 10

        with open(opt.hyp, errors="ignore") as f:
            hyp = yaml.safe_load(f)  # load hyps dict
            if "anchors" not in hyp:  # anchors commented in hyp.yaml
                hyp["anchors"] = 3
        if opt.noautoanchor:
            del hyp["anchors"], meta["anchors"]
        opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir)  # only val/save final epoch
        # ei = [isinstance(x, (int, float)) for x in hyp.values()]  # evolvable indices
        evolve_yaml, evolve_csv = save_dir / "hyp_evolve.yaml", save_dir / "evolve.csv"
        if opt.bucket:
            # download evolve.csv if exists
            subprocess.run(
                [
                    "gsutil",
                    "cp",
                    f"gs://{opt.bucket}/evolve.csv",
                    str(evolve_csv),
                ]
            )

        # Delete the items in meta dictionary whose first value is False
        del_ = [item for item, value_ in meta.items() if value_[0] is False]
        hyp_GA = hyp.copy()  # Make a copy of hyp dictionary
        for item in del_:
            del meta[item]  # Remove the item from meta dictionary
            del hyp_GA[item]  # Remove the item from hyp_GA dictionary

        # Set lower_limit and upper_limit arrays to hold the search space boundaries
        lower_limit = np.array([meta[k][1] for k in hyp_GA.keys()])
        upper_limit = np.array([meta[k][2] for k in hyp_GA.keys()])

        # Create gene_ranges list to hold the range of values for each gene in the population
        gene_ranges = [(lower_limit[i], upper_limit[i]) for i in range(len(upper_limit))]

        # Initialize the population with initial_values or random values
        initial_values = []

        # If resuming evolution from a previous checkpoint
        if opt.resume_evolve is not None:
            assert os.path.isfile(ROOT / opt.resume_evolve), "evolve population path is wrong!"
            with open(ROOT / opt.resume_evolve, errors="ignore") as f:
                evolve_population = yaml.safe_load(f)
                for value in evolve_population.values():
                    value = np.array([value[k] for k in hyp_GA.keys()])
                    initial_values.append(list(value))

        # If not resuming from a previous checkpoint, generate initial values from .yaml files in opt.evolve_population
        else:
            yaml_files = [f for f in os.listdir(opt.evolve_population) if f.endswith(".yaml")]
            for file_name in yaml_files:
                with open(os.path.join(opt.evolve_population, file_name)) as yaml_file:
                    value = yaml.safe_load(yaml_file)
                    value = np.array([value[k] for k in hyp_GA.keys()])
                    initial_values.append(list(value))

        # Generate random values within the search space for the rest of the population
        if initial_values is None:
            population = [generate_individual(gene_ranges, len(hyp_GA)) for _ in range(pop_size)]
        elif pop_size > 1:
            population = [generate_individual(gene_ranges, len(hyp_GA)) for _ in range(pop_size - len(initial_values))]
            for initial_value in initial_values:
                population = [initial_value] + population

        # Run the genetic algorithm for a fixed number of generations
        list_keys = list(hyp_GA.keys())
        for generation in range(opt.evolve):
            if generation >= 1:
                save_dict = {}
                for i in range(len(population)):
                    little_dict = {list_keys[j]: float(population[i][j]) for j in range(len(population[i]))}
                    save_dict[f"gen{str(generation)}number{str(i)}"] = little_dict

                with open(save_dir / "evolve_population.yaml", "w") as outfile:
                    yaml.dump(save_dict, outfile, default_flow_style=False)

            # Adaptive elite size
            elite_size = min_elite_size + int((max_elite_size - min_elite_size) * (generation / opt.evolve))
            # Evaluate the fitness of each individual in the population
            fitness_scores = []
            for individual in population:
                for key, value in zip(hyp_GA.keys(), individual):
                    hyp_GA[key] = value
                hyp.update(hyp_GA)
                results = train(hyp.copy(), opt, device, callbacks)
                callbacks = Callbacks()
                # Write mutation results
                keys = (
                    "metrics/precision",
                    "metrics/recall",
                    "metrics/mAP_0.5",
                    "metrics/mAP_0.5:0.95",
                    "val/box_loss",
                    "val/obj_loss",
                    "val/cls_loss",
                )
                print_mutation(keys, results, hyp.copy(), save_dir, opt.bucket)
                fitness_scores.append(results[2])

            # Select the fittest individuals for reproduction using adaptive tournament selection
            selected_indices = []
            for _ in range(pop_size - elite_size):
                # Adaptive tournament size
                tournament_size = max(
                    max(2, tournament_size_min),
                    int(min(tournament_size_max, pop_size) - (generation / (opt.evolve / 10))),
                )
                # Perform tournament selection to choose the best individual
                tournament_indices = random.sample(range(pop_size), tournament_size)
                tournament_fitness = [fitness_scores[j] for j in tournament_indices]
                winner_index = tournament_indices[tournament_fitness.index(max(tournament_fitness))]
                selected_indices.append(winner_index)

            # Add the elite individuals to the selected indices
            elite_indices = [i for i in range(pop_size) if fitness_scores[i] in sorted(fitness_scores)[-elite_size:]]
            selected_indices.extend(elite_indices)
            # Create the next generation through crossover and mutation
            next_generation = []
            for _ in range(pop_size):
                parent1_index = selected_indices[random.randint(0, pop_size - 1)]
                parent2_index = selected_indices[random.randint(0, pop_size - 1)]
                # Adaptive crossover rate
                crossover_rate = max(
                    crossover_rate_min, min(crossover_rate_max, crossover_rate_max - (generation / opt.evolve))
                )
                if random.uniform(0, 1) < crossover_rate:
                    crossover_point = random.randint(1, len(hyp_GA) - 1)
                    child = population[parent1_index][:crossover_point] + population[parent2_index][crossover_point:]
                else:
                    child = population[parent1_index]
                # Adaptive mutation rate
                mutation_rate = max(
                    mutation_rate_min, min(mutation_rate_max, mutation_rate_max - (generation / opt.evolve))
                )
                for j in range(len(hyp_GA)):
                    if random.uniform(0, 1) < mutation_rate:
                        child[j] += random.uniform(-0.1, 0.1)
                        child[j] = min(max(child[j], gene_ranges[j][0]), gene_ranges[j][1])
                next_generation.append(child)
            # Replace the old population with the new generation
            population = next_generation
        # Print the best solution found
        best_index = fitness_scores.index(max(fitness_scores))
        best_individual = population[best_index]
        print("Best solution found:", best_individual)
        # Plot results
        plot_evolve(evolve_csv)
        LOGGER.info(
            f'Hyperparameter evolution finished {opt.evolve} generations\n'
            f"Results saved to {colorstr('bold', save_dir)}\n"
            f'Usage example: $ python train.py --hyp {evolve_yaml}'
        )


def generate_individual(input_ranges, individual_length):
    """
    Generate an individual with random hyperparameters within specified ranges.

    Args:
        input_ranges (list[tuple[float, float]]): List of tuples where each tuple contains the lower and upper bounds
            for the corresponding gene (hyperparameter).
        individual_length (int): The number of genes (hyperparameters) in the individual.

    Returns:
        list[float]: A list representing a generated individual with random gene values within the specified ranges.

    Example:
        ```python
        input_ranges = [(0.01, 0.1), (0.1, 1.0), (0.9, 2.0)]
        individual_length = 3
        individual = generate_individual(input_ranges, individual_length)
        print(individual)  # Output: [0.035, 0.678, 1.456] (example output)
        ```

    Note:
        The individual returned will have a length equal to `individual_length`, with each gene value being a
        floating-point number within its specified range in `input_ranges`.
    """
    individual = []
    for i in range(individual_length):
        lower_bound, upper_bound = input_ranges[i]
        individual.append(random.uniform(lower_bound, upper_bound))
    return individual


def run(**kwargs):
    """
    Execute YOLOv5 training with specified options, allowing optional overrides through keyword arguments.

    Args:
        weights (str, optional): Path to initial weights. Defaults to ROOT / 'yolov5s.pt'.
        cfg (str, optional): Path to model YAML configuration. Defaults to an empty string.
        data (str, optional): Path to dataset YAML configuration. Defaults to ROOT / 'data/coco128.yaml'.
        hyp (str, optional): Path to hyperparameters YAML configuration. Defaults to ROOT / 'data/hyps/hyp.scratch-low.yaml'.
        epochs (int, optional): Total number of training epochs. Defaults to 100.
        batch_size (int, optional): Total batch size for all GPUs. Use -1 for automatic batch size determination. Defaults to 16.
        imgsz (int, optional): Image size (pixels) for training and validation. Defaults to 640.
        rect (bool, optional): Use rectangular training. Defaults to False.
        resume (bool | str, optional): Resume most recent training with an optional path. Defaults to False.
        nosave (bool, optional): Only save the final checkpoint. Defaults to False.
        noval (bool, optional): Only validate at the final epoch. Defaults to False.
        noautoanchor (bool, optional): Disable AutoAnchor. Defaults to False.
        noplots (bool, optional): Do not save plot files. Defaults to False.
        evolve (int, optional): Evolve hyperparameters for a specified number of generations. Use 300 if provided
            without a value.
        evolve_population (str, optional): Directory for loading population during evolution. Defaults to
            ROOT / 'data/hyps'.
        resume_evolve (str, optional): Resume hyperparameter evolution from the last generation. Defaults to None.
        bucket (str, optional): gsutil bucket for saving checkpoints. Defaults to an empty string.
        cache (str, optional): Cache image data in 'ram' or 'disk'. Defaults to None.
        image_weights (bool, optional): Use weighted image selection for training. Defaults to False.
        device (str, optional): CUDA device identifier, e.g., '0', '0,1,2,3', or 'cpu'. Defaults to an empty string.
        multi_scale (bool, optional): Use multi-scale training, varying image size by ±50%. Defaults to False.
        single_cls (bool, optional): Train with multi-class data as single-class. Defaults to False.
        optimizer (str, optional): Optimizer type, choices are ['SGD', 'Adam', 'AdamW']. Defaults to 'SGD'.
        sync_bn (bool, optional): Use synchronized BatchNorm, only available in DDP mode. Defaults to False.
        workers (int, optional): Maximum dataloader workers per rank in DDP mode. Defaults to 8.
        project (str, optional): Directory for saving training runs. Defaults to ROOT / 'runs/train'.
        name (str, optional): Name for saving the training run. Defaults to 'exp'.
        exist_ok (bool, optional): Allow existing project/name without incrementing. Defaults to False.
        quad (bool, optional): Use quad dataloader. Defaults to False.
        cos_lr (bool, optional): Use cosine learning rate scheduler. Defaults to False.
        label_smoothing (float, optional): Label smoothing epsilon value. Defaults to 0.0.
        patience (int, optional): Patience for early stopping, measured in epochs without improvement. Defaults to 100.
        freeze (list, optional): Layers to freeze, e.g., backbone=10, first 3 layers = [0, 1, 2]. Defaults to [0].
        save_period (int, optional): Frequency in epochs to save checkpoints. Disabled if < 1. Defaults to -1.
        seed (int, optional): Global training random seed. Defaults to 0.
        local_rank (int, optional): Automatic DDP Multi-GPU argument. Do not modify. Defaults to -1.

    Returns:
        None: The function initiates YOLOv5 training or hyperparameter evolution based on the provided options.

    Examples:
        ```python
        import train

        train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt')
        ```

    Notes:
        - Models: https://github.com/ultralytics/yolov5/tree/master/models
        - Datasets: https://github.com/ultralytics/yolov5/tree/master/data
        - Tutorial: https://docs.ultralytics.com/yolov5/tutorials/train_custom_data
    """
    opt = parse_opt(True)
    for k, v in kwargs.items():
        setattr(opt, k, v)
    main(opt)
    return opt


if __name__ == "__main__":
    opt = parse_opt()
    main(opt)

This is the source code.
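The `run()` helper at the bottom is the simplest way to exercise both code paths in this file: it forwards keyword overrides onto `parse_opt()` and calls `main()`, which either trains once or enters the genetic-algorithm evolution branch when `evolve` is set. A minimal sketch, assuming the file is importable as `train` and using placeholder dataset and generation values:

```python
import train

# Plain training run: main() dispatches straight to train().
train.run(data="data.yaml", weights="yolov5m.pt", epochs=150, imgsz=640)

# Hyperparameter evolution: main() takes the GA branch instead and calls
# train() once per individual per generation (pop_size = 50 in the code above),
# so even 10 generations means hundreds of short training runs.
train.run(data="data.yaml", weights="yolov5m.pt", evolve=10)
```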
import os

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

from copy import deepcopy

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from joblib import dump, load
from time import time
from mealpy.swarm_based import PSO
from mealpy.evolutionary_based import GA
from sko.SA import SA as SKO_SA
from mealpy.swarm_based import ACOR
from mealpy.swarm_based import WOA
from mealpy.swarm_based import GWO

# ==================== 1. Device setup and random seed ====================
torch.manual_seed(100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ==================== 2. Data loading and preprocessing ====================
data = pd.read_csv('D:/PycharmProjects/PythonProject3/liaotou4.csv')
feature_columns = data.columns[1:-2]
target_columns = data.columns[-2:]  # supports multiple targets
X_all, y_all = data[feature_columns].values, data[target_columns].values

# Feature scaling
scaler_x = StandardScaler()
X_all = scaler_x.fit_transform(X_all)
dump(scaler_x, 'scaler_x')

# Target scaling
scaler_y = StandardScaler()
y_all = scaler_y.fit_transform(y_all)
dump(scaler_y, 'scaler_y')

# Build sliding-window sequence data
seq_len = 60
X_seq, y_seq = [], []
for i in range(len(X_all) - seq_len):
    X_seq.append(X_all[i:i + seq_len])
    y_seq.append(y_all[i + seq_len])
X_seq = torch.tensor(np.array(X_seq), dtype=torch.float32)
y_seq = torch.tensor(np.array(y_seq), dtype=torch.float32)

# Chronological train/test split
train_size = int(0.7 * len(X_seq))
train_x, test_x = X_seq[:train_size], X_seq[train_size:]
train_y, test_y = y_seq[:train_size], y_seq[train_size:]

batch_size = 64
train_loader = Data.DataLoader(Data.TensorDataset(train_x, train_y), batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = Data.DataLoader(Data.TensorDataset(test_x, test_y), batch_size=batch_size, drop_last=True)

# ==================== 3. Model definition ====================
from torch.nn import TransformerEncoder, TransformerEncoderLayer


class TransformerBiLSTM(nn.Module):
    def __init__(self, input_dim, hidden_layer_sizes, hidden_dim, num_layers, num_heads, output_dim, dropout_rate=0.5):
        super().__init__()
        self.transformer = TransformerEncoder(
            TransformerEncoderLayer(input_dim, num_heads, hidden_dim, dropout=dropout_rate, batch_first=True),
            num_layers,
        )
        self.num_layers = len(hidden_layer_sizes)
        self.bilstm_layers = nn.ModuleList()
        self.bilstm_layers.append(nn.LSTM(input_dim, hidden_layer_sizes[0], batch_first=True, bidirectional=True))
        for i in range(1, self.num_layers):
            self.bilstm_layers.append(
                nn.LSTM(hidden_layer_sizes[i - 1] * 2, hidden_layer_sizes[i], batch_first=True, bidirectional=True)
            )
        self.linear = nn.Linear(hidden_layer_sizes[-1] * 2, output_dim)

    def forward(self, input_seq):
        transformer_output = self.transformer(input_seq)
        bilstm_out = transformer_output
        for bilstm in self.bilstm_layers:
            bilstm_out, _ = bilstm(bilstm_out)
        predict = self.linear(bilstm_out[:, -1, :])  # predict from the last time step
        return predict


# ==================== 4. VPPSO implementation ====================
def vppso(func, dim, bounds, N=12, N1=6, N2=6, T=16, a=0.3, c1=2.0, c2=2.0, b=1.0, verbose=True):
    X = np.random.uniform([lo for lo, hi in bounds], [hi for lo, hi in bounds], (N, dim))
    V = np.zeros((N, dim))
    Pbest = X.copy()
    Pbest_f = np.array([func(x) for x in X])
    gbest_idx = np.argmin(Pbest_f)
    Gbest = Pbest[gbest_idx].copy()
    Gbest_f = Pbest_f[gbest_idx]
    best_curve = [Gbest_f]
    for t in range(T):
        alpha_t = np.exp(-b * (t / T) ** b)
        for i in range(N):
            if i < N1:
                if np.random.rand() < a:
                    V[i] = V[i]  # velocity pausing: keep the previous velocity
                else:
                    r1, r2, r3 = np.random.rand(3)
                    V[i] = (V[i] * r1 * alpha_t
                            + c1 * r2 * (Pbest[i] - X[i])
                            + c2 * r3 * (Gbest - X[i]))
                X[i] = X[i] + V[i]
            else:
                # Second swarm explores around the global best
                if np.random.rand() < 0.5:
                    X[i] = Gbest + alpha_t * np.random.rand(dim) * np.abs(Gbest * alpha_t)
                else:
                    X[i] = Gbest - alpha_t * np.random.rand(dim) * np.abs(Gbest * alpha_t)
            # Clip positions to the search bounds
            for d in range(dim):
                if X[i, d] < bounds[d][0]:
                    X[i, d] = bounds[d][0]
                if X[i, d] > bounds[d][1]:
                    X[i, d] = bounds[d][1]
        # Fitness evaluation
        F = np.array([func(x) for x in X])
        for i in range(N):
            if i < N1:
                if F[i] < Pbest_f[i]:
                    Pbest[i] = X[i].copy()
                    Pbest_f[i] = F[i]
                if F[i] < Gbest_f:
                    Gbest = X[i].copy()
                    Gbest_f = F[i]
            else:
                if F[i] < Gbest_f:
                    Gbest = X[i].copy()
                    Gbest_f = F[i]
        best_curve.append(Gbest_f)
        if verbose and (t % 4 == 0 or t == T - 1):
            print(f"Iter {t+1}/{T}, Best fitness: {Gbest_f}")
    return Gbest, Gbest_f, best_curve


# ==================== 5. Search space and fitness function ====================
param_bounds = [
    (32, 128),         # hidden_layer_sizes[0]
    (32, 128),         # hidden_layer_sizes[1]
    (64, 256),         # hidden_dim
    (1, 4),            # num_layers
    (1, 4),            # num_heads
    (0.05, 0.5),       # dropout_rate
    (0.00005, 0.005),  # learning rate
]


def eval_model_hyperparams(x):
    h1 = int(round(x[0]))
    h2 = int(round(x[1]))
    hidden_dim = int(round(x[2]))
    num_layers = int(round(x[3]))
    num_heads = int(round(x[4]))
    dropout = float(x[5])
    lr = float(x[6])
    try:
        model = TransformerBiLSTM(
            input_dim=X_seq.shape[2],
            hidden_layer_sizes=[h1, h2],
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            output_dim=y_seq.shape[1],
            dropout_rate=dropout,
        ).to(device)
        optimizer = optim.Adam(model.parameters(), lr)
        loss_function = nn.MSELoss(reduction='sum')
        best_mse = 1000.
        for epoch in range(4):  # train only 4 epochs during the search
            model.train()
            for seq, labels in train_loader:
                seq, labels = seq.to(device), labels.to(device)
                optimizer.zero_grad()
                y_pred = model(seq)
                loss = loss_function(y_pred, labels)
                loss.backward()
                optimizer.step()
            model.eval()
            with torch.no_grad():
                test_loss = 0.
                test_size = len(test_loader.dataset)
                for seq, labels in test_loader:
                    seq, labels = seq.to(device), labels.to(device)
                    pre = model(seq)
                    test_loss += loss_function(pre, labels).item()
                test_av_mseloss = test_loss / test_size
                if test_av_mseloss < best_mse:
                    best_mse = test_av_mseloss
        return best_mse
    except Exception as e:
        # Invalid combinations (e.g. num_heads not dividing input_dim) get a penalty score
        print("Exception in eval:", e)
        return 9999.


def run_sa():
    bounds = np.array(param_bounds)
    x0 = [(b[0] + b[1]) / 2 for b in param_bounds]
    sa = SKO_SA(
        func=lambda x: eval_model_hyperparams(np.clip(x, bounds[:, 0], bounds[:, 1])),
        x0=x0, T_max=50, T_min=1, L=30, max_stay_counter=20,
    )
    best_param, best_loss = sa.run()
    return best_param, best_loss


from mealpy import Problem

problem = Problem(
    fit_func=eval_model_hyperparams,
    bounds=param_bounds,
    minmax="min",
)

optimizer_dict = {
    'GA': lambda: GA.BaseGA(problem, epoch=16, pop_size=12).solve().solution[:2],
    'PSO': lambda: PSO.BasePSO(problem, epoch=16, pop_size=12).solve().solution[:2],
    'ACO': lambda: ACOR.BaseACOR(problem, epoch=16, pop_size=12).solve().solution[:2],
    'WOA': lambda: WOA.BaseWOA(problem, epoch=16, pop_size=12).solve().solution[:2],
    'GWO': lambda: GWO.BaseGWO(problem, epoch=16, pop_size=12).solve().solution[:2],
    'SA': run_sa,
    'VPPSO': lambda: vppso(eval_model_hyperparams, len(param_bounds), param_bounds,
                           N=12, N1=6, N2=6, T=16, a=0.3, c1=2.0, c2=2.0, b=1.0, verbose=False)[:2],
}

final_results = {}
show_n = 100  # show the first 100 test samples
alg_colors = {
    'VPPSO': 'blue',
    'GA': 'red',
    'PSO': 'green',
    'SA': 'purple',
    'ACO': 'orange',  # key fixed: must match the 'ACO' entry in optimizer_dict
    'WOA': 'deepskyblue',
    'GWO': 'brown',
}

for alg_name, alg_func in optimizer_dict.items():
    print(f"\n------ {alg_name}: optimizing Transformer-BiLSTM hyperparameters ------")
    best_param, best_loss = alg_func()

    # Decode parameters
    h1 = int(round(best_param[0]))
    h2 = int(round(best_param[1]))
    hidden_dim = int(round(best_param[2]))
    num_layers = int(round(best_param[3]))
    num_heads = int(round(best_param[4]))
    dropout = float(best_param[5])
    lr = float(best_param[6])
    print(f'{alg_name} best hyperparameters: {best_param}, validation loss: {best_loss}')

    # Train the final model with the best hyperparameters
    model = TransformerBiLSTM(
        input_dim=X_seq.shape[2],
        hidden_layer_sizes=[h1, h2],
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        num_heads=num_heads,
        output_dim=y_seq.shape[1],
        dropout_rate=dropout,
    ).to(device)
    optimizer = optim.Adam(model.parameters(), lr)
    loss_function = nn.MSELoss(reduction='sum')

    # To keep the total runtime down, train for only 40 epochs; adjust as needed
    def train_short(model, epochs):
        test_size = len(test_loader.dataset)
        minimum_mse = 1000.
        best_model_wts = deepcopy(model.state_dict())
        for epoch in range(epochs):
            model.train()
            for seq, labels in train_loader:
                seq, labels = seq.to(device), labels.to(device)
                optimizer.zero_grad()
                y_pred = model(seq)
                loss = loss_function(y_pred, labels)
                loss.backward()
                optimizer.step()
            model.eval()
            with torch.no_grad():
                test_loss = 0.
                for seq, labels in test_loader:
                    seq, labels = seq.to(device), labels.to(device)
                    pre = model(seq)
                    test_loss += loss_function(pre, labels).item()
                test_av_mseloss = test_loss / test_size
                if test_av_mseloss < minimum_mse:
                    minimum_mse = test_av_mseloss
                    # deepcopy the snapshot so later optimizer steps don't overwrite it
                    best_model_wts = deepcopy(model.state_dict())
        model.load_state_dict(best_model_wts)

    train_short(model, epochs=40)
    torch.save(model.state_dict(), f'best_model_{alg_name}.pt')

    # Predict and keep the first show_n samples
    model.eval()
    original_data = []
    pre_data = []
    with torch.no_grad():
        for x_batch, y_batch in test_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            y_pred = model(x_batch)
            original_data.append(y_batch.cpu().numpy())
            pre_data.append(y_pred.cpu().numpy())
    original_data = np.concatenate(original_data, axis=0)
    pre_data = np.concatenate(pre_data, axis=0)
    scaler_y = load('scaler_y')
    original_100 = scaler_y.inverse_transform(original_data)[:show_n, 0]
    pre_100 = scaler_y.inverse_transform(pre_data)[:show_n, 0]
    final_results[alg_name] = (original_100, pre_100)

# ======================= Visual comparison of results =======================
plt.figure(figsize=(14, 7))
plt.plot(final_results['VPPSO'][0], color='gray', label='Ground truth', linewidth=2, linestyle='--')
for alg_name, (orig, pred) in final_results.items():
    plt.plot(pred, color=alg_colors[alg_name], label=f'{alg_name} optimized', alpha=0.85)
plt.xlabel('Sample index')
plt.ylabel('Predicted output')
plt.title('Transformer-BiLSTM predictions under different metaheuristic optimizers (first 100 samples)')
plt.legend()
plt.tight_layout()
plt.show()
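Since every call to `eval_model_hyperparams` trains a network for four epochs, it helps to sanity-check the `vppso` implementation on a cheap objective before pointing it at the real fitness function. A minimal sketch on the sphere function, assuming the `vppso` function from the script above is in scope (the objective and bounds here are illustrative, not from the original post):

```python
import numpy as np

def sphere(x):
    # Cheap convex test objective; global minimum 0 at the origin.
    return float(np.sum(x ** 2))

bounds = [(-5.0, 5.0)] * 3  # 3-dimensional search box
best_x, best_f, curve = vppso(sphere, dim=3, bounds=bounds, N=12, N1=6, N2=6, T=50)

print("best x:", best_x)   # should approach [0, 0, 0]
print("best f:", best_f)   # should approach 0
print("iterations logged:", len(curve))  # T + 1 values: initial best plus one per iteration
```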