【无线电控制与数据链探测系统】书籍配套源码-python 版本 第二章至第六章

第二章代码

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Comprehensive executable Python examples for Chapter 2 (Sections 2.1 - 2.5)
- Target audience: graduate students
- Purpose: runnable, in-depth code samples (no theoretical exposition; comments only)

Dependencies:
  numpy, scipy, matplotlib
  install via: pip install numpy scipy matplotlib

Usage:
  python Chapter2_Sections_2.1-2.5_full_python_examples.py --all
  or choose sections: --s2_1 --s2_2 --s2_3 --s2_4 --s2_5

Each section exposes functions that can be imported individually.
"""

import argparse
import numpy as np
import matplotlib.pyplot as plt
from scipy.constants import c, k as k_B
from scipy.special import erfc
from scipy.signal import fftconvolve

# =========================
# 2.1 Frequency & Spectrum Utilities
# =========================

# Simple regulatory table (example, not exhaustive)
REG_TABLE = [
    # (name, f_min_Hz, f_max_Hz, category)
    # Bounds are compared inclusively by classify_frequency (f_min <= f <= f_max)
    # and ranges may overlap: ISM_2.4GHz and Cellular_low both lie inside UHF,
    # so a single frequency can match several rows.
    ("HF", 3e6, 30e6, "licensed/managed"),
    ("VHF", 30e6, 300e6, "licensed/managed"),
    ("UHF", 300e6, 3e9, "licensed/managed"),
    ("ISM_2.4GHz", 2.400e9, 2.4835e9, "unlicensed"),
    ("ISM_5GHz", 5.150e9, 5.875e9, "unlicensed"),
    ("Cellular_low", 600e6, 960e6, "licensed"),
]


def classify_frequency(freq_hz, table=None):
    """Return the regulatory band rows that contain a frequency.

    Parameters
    ----------
    freq_hz : float or array_like
        Frequency in Hz, or a sequence of frequencies.
    table : list of (name, f_min_Hz, f_max_Hz, category), optional
        Band table to search; defaults to the module-level ``REG_TABLE``.

    Returns
    -------
    For exactly one frequency: the list of matching rows (bounds inclusive),
    or ``None`` when no band matches.  For several frequencies: a list with
    one such entry per frequency.  An empty input returns ``[]`` (it used to
    raise IndexError).
    """
    bands = REG_TABLE if table is None else table
    result = []
    for fi in np.atleast_1d(freq_hz):
        matches = [row for row in bands if row[1] <= fi <= row[2]]
        result.append(matches or None)
    # Historical contract: a single frequency is returned unwrapped.
    return result[0] if len(result) == 1 else result


def simulate_spectrum_occupancy(center_hz=2.442e9, bw_hz=20e6, n_channels=20, seed=1):
    """Create a toy spectrum-occupancy trace (PSD in dBm/Hz) for visualisation.

    A flat -120 dBm/Hz noise floor is combined with 1-5 Gaussian-shaped
    transmissions.  Each transmission has a peak PSD drawn uniformly from
    [-60, -30] dBm/Hz.  Contributions are summed in linear power (mW/Hz) and
    converted back to dB, so active transmissions rise *above* the floor.
    The previous version added the negative dB amplitudes straight onto the
    dB trace, which produced notches below the noise floor instead of peaks.

    Returns
    -------
    freqs : ndarray, shape (n_channels * 50,)
        Frequency grid spanning center_hz +/- bw_hz.
    psd : ndarray, same shape
        Power spectral density in dBm/Hz.
    """
    rng = np.random.default_rng(seed)
    freqs = np.linspace(center_hz - bw_hz, center_hz + bw_hz, n_channels * 50)
    floor_dbm = -120.0
    psd_lin = np.full(freqs.shape, 10 ** (floor_dbm / 10))  # mW/Hz
    for _ in range(rng.integers(1, 6)):
        center = rng.uniform(freqs.min(), freqs.max())
        width = rng.uniform(bw_hz * 0.01, bw_hz * 0.2)
        peak_dbm = rng.uniform(-60, -30)
        shape = np.exp(-0.5 * ((freqs - center) / (width / 2)) ** 2)
        psd_lin += 10 ** (peak_dbm / 10) * shape
    psd = 10 * np.log10(psd_lin)
    return freqs, psd

# =========================
# 2.2 Link Budget, Noise, SNR, BER
# =========================


def fspl_db(f_hz, d_m):
    """Return free-space path loss in dB: 20*log10(4*pi*d / lambda).

    ``d_m`` may be a scalar or a NumPy array (the result takes its shape).
    A distance of 0 yields ``-inf`` without emitting a divide-by-zero warning.
    """
    wavelength = c / f_hz
    ratio = 4 * np.pi * d_m / wavelength
    with np.errstate(divide='ignore'):
        return 20 * np.log10(ratio)


def noise_floor_dbm(b_hz, t_k=290.0, nf_db=0.0):
    """Return the thermal noise floor k*T*B*F in dBm.

    b_hz : bandwidth (Hz); t_k : temperature (K); nf_db : noise figure (dB).
    """
    noise_factor = 10 ** (nf_db / 10)
    power_w = k_B * t_k * b_hz * noise_factor
    return 10 * np.log10(power_w) + 30  # W -> dBm


def link_budget(Pt_dbm, Gt_dbi, Gr_dbi, f_hz, d_m, Ls_dB=0.0, noise_fig_dB=3.0, bandwidth_hz=1e6):
    """Free-space link budget: received power, SNR and BPSK BER.

    Pr = Pt + Gt + Gr - FSPL(f, d) - Ls, and the noise is the kTBF floor at
    290 K.  BER_BPSK = 0.5*erfc(sqrt(SNR)) treats the in-band SNR as Eb/N0,
    which implicitly assumes the bit rate equals ``bandwidth_hz``.

    Returns a dict whose array entries follow the shape of ``d_m``:
    'Pr_dbm', 'SNR_db', 'SNR_lin', 'BER_BPSK', 'PL_db' and the scalar
    'Noise_dbm'.
    """
    path_loss = fspl_db(f_hz, d_m)
    rx_dbm = Pt_dbm + Gt_dbi + Gr_dbi - path_loss - Ls_dB
    floor_dbm = noise_floor_dbm(bandwidth_hz, t_k=290.0, nf_db=noise_fig_dB)
    snr_db = rx_dbm - floor_dbm
    snr_lin = 10**(snr_db/10)
    return {
        'Pr_dbm': rx_dbm,
        'SNR_db': snr_db,
        'SNR_lin': snr_lin,
        'BER_BPSK': 0.5 * erfc(np.sqrt(snr_lin)),
        'PL_db': path_loss,
        'Noise_dbm': floor_dbm,
    }

# Additional useful helper: compute max distance for required sensitivity

def max_distance_for_sensitivity(Pt_dbm, Gt_dbi, Gr_dbi, f_hz, sensitivity_dbm, Ls_dB=0.0):
    """Distance (m) at which free-space received power equals the sensitivity.

    Inverts Pr = Pt + Gt + Gr - 20*log10(4*pi*d/lambda) - Ls for d:
    d = lambda/(4*pi) * 10**(margin/20), where margin is the allowed path
    loss in dB.
    """
    margin_db = Pt_dbm + Gt_dbi + Gr_dbi - Ls_dB - sensitivity_dbm
    scale = (c / f_hz) / (4 * np.pi)
    return scale * (10**(margin_db / 20.0))

# =========================
# 2.3 Modulation, OFDM, FHSS, DSSS
# =========================

# Basic modulators/demodulators (waveform-level)

def bpsk_mod(bits, fc, fs, baud):
    """Modulate {0,1} bits onto a real BPSK passband carrier.

    Each bit becomes ``int(fs / baud)`` samples of -1 (bit 0) or +1 (bit 1),
    multiplied by cos(2*pi*fc*t).

    Returns
    -------
    (waveform, fs)
    """
    levels = 2*np.asarray(bits).astype(int) - 1
    samples_per_symbol = int(fs / baud)
    baseband = np.repeat(levels, samples_per_symbol)
    t = np.arange(baseband.size) / fs
    return baseband * np.cos(2*np.pi*fc*t), fs


def add_awgn(signal, snr_db, seed=None):
    """Add white Gaussian noise so that mean(|signal|^2) / noise_power = SNR.

    Real-valued input receives real noise of variance P/SNR.  Complex input
    receives circular complex noise of total variance P/SNR, i.e. P/(2*SNR)
    on each of I and Q.  The previous version always added *real* noise of
    variance P/(2*SNR): real signals ended up ~3 dB cleaner than requested,
    and complex signals (e.g. OFDM) got noise on the in-phase rail only.

    Parameters
    ----------
    signal : array_like, real or complex
    snr_db : float
        Target SNR in dB.
    seed : int or None
        Seed for ``numpy.random.default_rng``.
    """
    x = np.asarray(signal)
    rng = np.random.default_rng(seed)
    noise_pow = np.mean(np.abs(x)**2) / 10**(snr_db/10)
    if np.iscomplexobj(x):
        noise = np.sqrt(noise_pow / 2) * (rng.normal(size=x.shape) + 1j * rng.normal(size=x.shape))
    else:
        noise = np.sqrt(noise_pow) * rng.normal(size=x.shape)
    return x + noise

# QAM mapper/demapper

def qam16_map(bits):
    """Map bits to unit-average-energy, Gray-coded 16-QAM symbols.

    Bits are consumed four at a time.  (b0, b1) select the in-phase level and
    (b2, b3) the quadrature level, using the Gray code 00->-3, 01->-1,
    11->+1, 10->+3.  Symbols are divided by sqrt(10), the average energy of
    the {+-1, +-3}^2 constellation, which is the same scaling ``qam16_demod``
    uses.  Previously the scale was the empirical power of the supplied
    bits.  That scale varies with the data and can push symbols into the
    wrong decision region (e.g. an all-'1111' block demapped as 3+3j).

    Parameters
    ----------
    bits : array_like of {0, 1}
        Length must be a multiple of 4.

    Returns
    -------
    ndarray of complex, length len(bits) // 4.

    Raises
    ------
    ValueError
        If the length is not a multiple of 4 or a value is not 0/1.
    """
    raw = np.asarray(bits).reshape(-1, 4)
    if not np.isin(raw, (0, 1)).all():
        raise ValueError("bits must contain only 0/1 values")
    b = raw.astype(int)
    # index = 2*MSB + LSB  ->  Gray level (00:-3, 01:-1, 10:+3, 11:+1)
    gray_level = np.array([-3.0, -1.0, 3.0, 1.0])
    i_part = gray_level[2 * b[:, 0] + b[:, 1]]
    q_part = gray_level[2 * b[:, 2] + b[:, 3]]
    return (i_part + 1j * q_part) / np.sqrt(10.0)


def qam16_demod(symbols):
    """Hard-decision (nearest-point) demapper for 16-QAM symbols.

    The reference constellation {+-1, +-3}^2 is normalised to unit average
    energy.  Each input symbol is replaced by the 4-bit Gray label of the
    closest reference point.  Returns a flat array with 4 bits per symbol.
    """
    # Reference points, listed in the same order as the Gray labels below.
    ref = np.array([-3-3j, -3-1j, -3+1j, -3+3j, -1-3j, -1-1j, -1+1j, -1+3j,
                    1-3j, 1-1j, 1+1j, 1+3j, 3-3j, 3-1j, 3+1j, 3+3j], dtype=complex)
    labels = [(0,0,0,0), (0,0,0,1), (0,0,1,1), (0,0,1,0),
              (0,1,0,0), (0,1,0,1), (0,1,1,1), (0,1,1,0),
              (1,1,0,0), (1,1,0,1), (1,1,1,1), (1,1,1,0),
              (1,0,0,0), (1,0,0,1), (1,0,1,1), (1,0,1,0)]
    ref = ref / np.sqrt((np.abs(ref)**2).mean())  # unit average energy
    return np.array([bit
                     for sym in symbols
                     for bit in labels[np.argmin(np.abs(sym - ref))]])

# OFDM transmitter/receiver (simple CP-based)

def ofdm_tx(bits, n_subcarriers=64, cp_len=16, qam_order=16):
    """Build a complex-baseband CP-OFDM waveform from a bit stream.

    Processing steps:
    1. Map the bits to 16-QAM symbols with ``qam16_map``.
    2. Place ``n_subcarriers`` symbols on each OFDM symbol.
    3. IFFT each OFDM symbol to the time domain.
    4. Prefix each OFDM symbol with a copy of its last ``cp_len`` samples.

    Parameters
    ----------
    bits : array_like of {0, 1}
        Length must be a multiple of log2(qam_order) * n_subcarriers.
    n_subcarriers : int
    cp_len : int
        Cyclic-prefix length, 0 <= cp_len <= n_subcarriers.
    qam_order : int
        Only 16 is implemented.

    Returns
    -------
    tx : ndarray of complex
        Length n_ofdm_symbols * (n_subcarriers + cp_len).
    params : dict
        Keys 'n_subcarriers', 'cp_len', 'q_order'; pass it to ``ofdm_rx``.

    Raises
    ------
    ValueError
        Unsupported ``qam_order`` (it used to crash with AttributeError on
        None), out-of-range ``cp_len``, or a bit count that does not fill
        whole OFDM symbols (previously an ``assert``, which ``-O`` strips).
    """
    if qam_order != 16:
        raise ValueError(f"qam_order={qam_order} is not supported (only 16-QAM)")
    if not 0 <= cp_len <= n_subcarriers:
        raise ValueError("cp_len must be in [0, n_subcarriers]")
    k = int(np.log2(qam_order))
    if len(bits) % (k * n_subcarriers) != 0:
        raise ValueError("bits length must fit OFDM frames")
    freq_domain = qam16_map(bits).reshape(-1, n_subcarriers)
    time_domain = np.fft.ifft(freq_domain, axis=1)
    # Slice from n - cp so that cp_len == 0 gives an empty prefix
    # (time_domain[-0:] would have copied the whole symbol).
    cp = time_domain[:, n_subcarriers - cp_len:]
    tx = np.concatenate([cp, time_domain], axis=1).ravel()
    return tx, {'n_subcarriers': n_subcarriers, 'cp_len': cp_len, 'q_order': qam_order}


def ofdm_rx(rx, params):
    """Demodulate a CP-OFDM waveform produced by ``ofdm_tx``.

    Processing steps:
    1. Cut the waveform into whole (n_subcarriers + cp_len)-sample symbols;
       a trailing partial symbol is discarded.
    2. Strip the cyclic prefix and FFT each symbol.
    3. Hard-demap every subcarrier with ``qam16_demod``.

    No channel estimation or equalisation is performed.

    Parameters
    ----------
    rx : array_like of complex
    params : dict
        Keys 'n_subcarriers', 'cp_len', 'q_order' as returned by ``ofdm_tx``.

    Returns
    -------
    ndarray of bits.  An input shorter than one symbol yields an empty array
    (it previously raised inside ``np.concatenate``).

    Raises
    ------
    ValueError
        If 'q_order' is not 16 (previously the function returned None).
    """
    n = params['n_subcarriers']
    cp = params['cp_len']
    q = params['q_order']
    if q != 16:
        raise ValueError(f"q_order={q} is not supported (only 16-QAM)")
    sym_len = n + cp
    rx = np.asarray(rx)
    nsym = len(rx) // sym_len
    frames = rx[:nsym * sym_len].reshape(nsym, sym_len)
    freq = np.fft.fft(frames[:, cp:], axis=1)
    return qam16_demod(freq.ravel())

# FHSS simple simulator

def fhss_transmit(bits, carrier_list, hop_rate, fs, symbol_rate, seed=0):
    """Generate a real BPSK frequency-hopping waveform.

    Parameters
    ----------
    bits : array_like of {0, 1}
        Mapped 0 -> -1, 1 -> +1; each bit is held for int(fs/symbol_rate)
        samples.
    carrier_list : sequence of float
        Carrier frequencies (Hz); one is drawn with replacement per hop.
    hop_rate : int
        Number of *symbols per hop* (dwell length), not hops per second.
    fs : float
        Sampling rate (Hz).
    symbol_rate : float
        Symbols per second.
    seed : int
        Seed for the hop-pattern generator.
    """
    rng = np.random.default_rng(seed)
    sps = int(fs / symbol_rate)
    base = np.repeat(2*np.array(bits) - 1, sps)
    n_samples = len(base)
    n_hops = int(np.ceil(n_samples / (sps*hop_rate)))
    hops = rng.choice(carrier_list, size=n_hops, replace=True)
    sample_idx = np.arange(n_samples)
    t = sample_idx / fs
    dwell = hop_rate * sps  # samples per hop
    carrier_hz = hops[sample_idx // dwell]  # carrier in force at each sample
    return base * np.cos(2*np.pi*carrier_hz*t)

# DSSS simple: spread with m-sequence

def pn_sequence(length, seed=1):
    """Return a reproducible pseudo-random +/-1 chip sequence.

    Chips are drawn i.i.d. from ``numpy.random.default_rng(seed)``.  Despite
    the "m-sequence" wording in the section comment, this is not an LFSR
    maximal-length sequence and has no guaranteed autocorrelation properties.
    """
    generator = np.random.default_rng(seed)
    return generator.choice([1, -1], size=length)


def dsss_tx(bits, pn, sps=8):
    """Direct-sequence spread BPSK.

    Each bit (0 -> -1, 1 -> +1) multiplies the whole PN chip sequence, and
    every chip is then held for ``sps`` samples.
    Output length: len(bits) * len(pn) * sps.
    """
    symbols = 2*np.array(bits) - 1
    chips = np.kron(symbols, pn)  # symbol-major: s0*pn, s1*pn, ...
    return np.repeat(chips, sps)


def dsss_rx(rx, pn, sps=8):
    """Despread a DSSS waveform produced by ``dsss_tx`` and slice bits.

    Processing steps:
    1. Average the samples in groups of ``sps`` (one value per chip).
    2. Group the chips into PN-length blocks.
    3. Correlate each block with ``pn``; a positive correlation decides 1,
       otherwise 0.

    Trailing samples or chips that do not fill a whole chip or symbol are
    ignored.  Previously, a length not divisible by ``sps`` made ``reshape``
    raise ValueError.

    Returns
    -------
    ndarray of int bits.
    """
    pn = np.asarray(pn)
    rx = np.asarray(rx)
    n_chips = len(rx) // sps
    chip_vals = rx[:n_chips * sps].reshape(n_chips, sps).mean(axis=1)
    n_symbols = n_chips // len(pn)
    blocks = chip_vals[:n_symbols * len(pn)].reshape(n_symbols, len(pn))
    corr = np.dot(blocks, pn)
    return (corr > 0).astype(int)

# =========================
# 2.4 Data Link: framing, CRC, ARQ, handshake
# =========================

# CRC-16-CCITT

def crc16_ccitt(data: bytes, poly=0x1021, init=0xFFFF):
    """Bitwise, MSB-first CRC-16 (no input/output reflection, no final XOR).

    With the default poly/init this is CRC-16/CCITT-FALSE, whose check value
    is crc16_ccitt(b"123456789") == 0x29B1.
    """
    reg = init
    for byte in data:
        reg ^= byte << 8
        for _ in range(8):
            top_bit_set = reg & 0x8000
            reg = (reg << 1) & 0xFFFF
            if top_bit_set:
                reg = (reg ^ poly) & 0xFFFF
    return reg


def make_frame(payload: bytes, seq: int, ctrl: int = 0):
    """Serialise one link-layer frame.

    Layout:
        0x7E | CTRL(1) | SEQ(1) | LEN(2, big-endian) | PAYLOAD | CRC(2, big-endian) | 0x7E

    The CRC is ``crc16_ccitt`` over CTRL..PAYLOAD.  ctrl and seq are
    truncated to 8 bits.  No byte stuffing is applied, so 0x7E may also occur
    inside the frame; ``parse_frame`` locates fields by fixed offsets instead.
    """
    flag = bytes([0x7E])
    body = bytes([ctrl & 0xFF]) + bytes([seq & 0xFF]) + len(payload).to_bytes(2, 'big') + payload
    checksum = crc16_ccitt(body).to_bytes(2, 'big')
    return flag + body + checksum + flag


def parse_frame(frame: bytes):
    """Parse and validate a frame built by ``make_frame``.

    Returns
    -------
    dict with 'ctrl' (int), 'seq' (int) and 'payload' (bytes).

    Raises
    ------
    ValueError
        Raised when any of these checks fails:
        * the frame is shorter than the 8-byte minimum
          (SOF + CTRL + SEQ + LEN(2) + CRC(2) + EOF);
        * the 0x7E markers are missing;
        * the CRC does not match;
        * the LEN field disagrees with the number of payload bytes present.
        Previously, short frames raised IndexError and a wrong LEN silently
        truncated the payload.
    """
    min_frame_len = 8
    if len(frame) < min_frame_len:
        raise ValueError('Frame too short')
    if frame[0] != 0x7E or frame[-1] != 0x7E:
        raise ValueError('Frame markers missing')
    core = frame[1:-3]  # CTRL | SEQ | LEN | PAYLOAD
    crc_received = int.from_bytes(frame[-3:-1], 'big')
    if crc_received != crc16_ccitt(core):
        raise ValueError('CRC mismatch')
    length = int.from_bytes(core[2:4], 'big')
    payload = core[4:]
    if len(payload) != length:
        raise ValueError('Length field mismatch')
    return {'ctrl': core[0], 'seq': core[1], 'payload': payload}

# Stop-and-wait ARQ simulator

def stop_and_wait_simulate(tx_payloads, p_loss=0.1, max_retries=5, seed=0):
    """Simulate stop-and-wait ARQ over a lossy channel.

    Each payload gets sequence number = its index.  The data frame and its
    ACK are each lost independently with probability ``p_loss``, and any
    loss triggers a retransmission.  Once the failure count exceeds
    ``max_retries``, the payload is abandoned.

    Returns
    -------
    acked : list of int
        Sequence numbers that were acknowledged.
    events : list of tuple
        ('tx', seq, failures_so_far, frame_lost),
        ('ack', seq, failures_so_far, ack_lost),
        ('fail', seq, failures, True).
    """
    rng = np.random.default_rng(seed)
    acked, events = [], []
    for seq, payload in enumerate(tx_payloads):
        failures = 0
        delivered = False
        while not delivered and failures <= max_retries:
            make_frame(payload, seq)  # frame is built but its bytes are never inspected
            data_lost = rng.uniform() < p_loss
            events.append(('tx', seq, failures, data_lost))
            if data_lost:
                failures += 1
                continue
            ack_lost = rng.uniform() < p_loss
            events.append(('ack', seq, failures, ack_lost))
            if ack_lost:
                failures += 1
            else:
                acked.append(seq)
                delivered = True
        if failures > max_retries:
            events.append(('fail', seq, failures, True))
    return acked, events

# =========================
# 2.5 Antenna & Front-End Utilities
# =========================


def dBi_to_linear(dbi):
    """Convert a gain in dBi to a linear power ratio, 10**(dBi/10)."""
    exponent = dbi / 10
    return 10 ** exponent


def linear_to_dBi(g):
    """Convert a linear power gain ratio to dBi, 10*log10(g)."""
    decades = np.log10(g)
    return 10 * decades

# Simple idealised isotropic pattern modifier: create pattern for given dBi and beamwidth

def antenna_pattern(dbi, beamwidth_deg=60, theta_res=361):
    """Toy azimuth gain pattern with a single main lobe at 0 deg.

    Let ``off`` be the azimuth offset from boresight, wrapped to
    [-180, 180) deg.  The relative power pattern is
    cos(pi/3 * off / (beamwidth_deg/2)) inside the main lobe.  This gives:
    * exactly -3 dB at +/- beamwidth_deg/2, so ``beamwidth_deg`` is the
      half-power beamwidth;
    * nulls at +/- 0.75 * beamwidth_deg;
    * a floor of about -120 dB (relative to 1) outside the main lobe.

    The previous expression cos(deg2rad(theta) / (bw/180) * pi/2) mixed
    degrees and radians.  Because the cosine is periodic over 0..360 deg, it
    also produced repeated full-gain lobes around the circle.

    Returns
    -------
    theta : ndarray
        Azimuth in degrees, ``theta_res`` points over [0, 360].
    pattern_db : ndarray
        Gain in dBi; the peak equals ``dbi``.
    """
    theta = np.linspace(0, 360, theta_res)
    off = (theta + 180.0) % 360.0 - 180.0  # wrap to [-180, 180)
    arg = (np.pi / 3) * off / (beamwidth_deg / 2.0)
    main = np.where(np.abs(arg) <= np.pi / 2, np.cos(arg), 0.0)
    main = main / (main.max() + 1e-12)
    # Scale the relative shape to the requested peak gain; 1e-12 avoids log(0).
    peak_lin = dBi_to_linear(dbi)
    pattern_db = linear_to_dBi(main * peak_lin + 1e-12)
    return theta, pattern_db

# Polarization loss (simplified): mismatch between two linear polarizations (angles in deg)

def polarization_loss_db(tx_angle_deg, rx_angle_deg):
    """Mismatch loss (dB) between two linear polarisations.

    Loss = -20*log10(|cos(delta)|), where delta is the angle difference.  A
    1e-12 floor keeps the cross-polar (90 deg) case finite at about 240 dB.
    """
    mismatch = np.abs(np.cos(np.deg2rad(tx_angle_deg - rx_angle_deg)))
    return -20*np.log10(mismatch + 1e-12)

# Impedance match calculations: VSWR and return loss

def return_loss_db(z_load, z0=50.0):
    """Return loss in dB: -20*log10(|Gamma|), where Gamma = (ZL - Z0)/(ZL + Z0).

    A 1e-12 floor caps a perfect match at 240 dB instead of +inf.  Complex
    load impedances are accepted.
    """
    reflection = (z_load - z0) / (z_load + z0)
    magnitude = np.abs(reflection) + 1e-12
    return -20*np.log10(magnitude)


def vswr(z_load, z0=50.0):
    """Voltage standing-wave ratio (1 + |Gamma|) / (1 - |Gamma|).

    Returns inf for a full reflection (|Gamma| = 1).
    """
    rho = np.abs((z_load - z0) / (z_load + z0))
    numerator = 1 + rho
    denominator = 1 - rho
    return numerator / denominator

# Simple LNA model (gain, noise figure, 1dB compression)
class LNA:
    """Toy LNA power model: linear gain with a soft compression knee.

    Attributes:
        gain_dB: small-signal gain (dB).
        nf_dB: noise figure (dB); stored, but not used by ``amplify_dbm``.
        p1dB_dBm: compression reference; the knee sits 10 dB below it.
    """

    def __init__(self, gain_dB=20.0, nf_dB=1.0, p1dB_dBm=-5.0):
        self.gain_dB, self.nf_dB, self.p1dB_dBm = gain_dB, nf_dB, p1dB_dBm

    def amplify_dbm(self, pin_dbm):
        """Return output power (dBm) for a scalar input power (dBm).

        At or below the knee (p1dB - 10 dB) the full gain applies.  Above
        it, the gain is scaled by max(0.2, 1 - overdrive/50).
        """
        overdrive = pin_dbm - (self.p1dB_dBm - 10)
        if overdrive <= 0:
            return pin_dbm + self.gain_dB
        gain_scale = max(0.2, 1.0 - overdrive/50.0)
        return pin_dbm + self.gain_dB * gain_scale

# =========================
# Demo / CLI
# =========================

def demo_2_1():
    """Demo 2.1: classify three example frequencies and plot the toy spectrum."""
    print('=== Demo 2.1: Frequency classification & spectrum occupancy ===')
    for f in (2.45e9, 900e6, 28e6):
        print(f'{f/1e6:.2f} MHz ->', classify_frequency(f))
    ff, psd = simulate_spectrum_occupancy()
    offset_mhz = (ff - ff.mean()) / 1e6
    plt.figure(figsize=(8, 3))
    plt.plot(offset_mhz, psd)
    plt.xlabel('Offset from center (MHz)')
    plt.ylabel('PSD (dBm/Hz)')
    plt.title('Toy spectrum occupancy')
    plt.grid(True)
    plt.tight_layout()
    plt.show()


def demo_2_2():
    """Demo 2.2: sweep distance 10 m - 10 km, plot Pr and SNR, report BER > 1e-3."""
    print('=== Demo 2.2: Link budget sweep ===')
    Pt, Gt, Gr, f, Ls = 30.0, 15.0, 15.0, 2.4e9, 2.0
    dists = np.logspace(1, 4, 300)
    res = link_budget(Pt, Gt, Gr, f, dists, Ls_dB=Ls, noise_fig_dB=5.0, bandwidth_hz=1e6)
    plt.figure(figsize=(8, 6))
    panels = (('Pr_dbm', 'Pr (dBm)'), ('SNR_db', 'SNR (dB)'))
    for row, (key, label) in enumerate(panels, start=1):
        plt.subplot(2, 1, row)
        plt.semilogx(dists, res[key])
        plt.grid(True)
        plt.ylabel(label)
    plt.xlabel('Distance (m)')  # applies to the lower (SNR) panel
    plt.tight_layout()
    plt.show()
    # First swept distance at which BPSK BER exceeds 1e-3.
    over = np.where(res['BER_BPSK'] > 1e-3)[0]
    if over.size == 0:
        print('BER acceptable in range')
    else:
        print('BER exceed 1e-3 beyond approx distance (m):', dists[over[0]])


def demo_2_3():
    """Demo 2.3: OFDM AWGN round trip, an FHSS burst, and DSSS with a tone interferer."""
    print('=== Demo 2.3: OFDM + FHSS + DSSS mini-simulations ===')
    rng = np.random.default_rng(0)

    # OFDM: one 64-subcarrier 16-QAM symbol (256 bits) through 20 dB AWGN.
    ofdm_bits = rng.integers(0, 2, size=64*4)
    tx, params = ofdm_tx(ofdm_bits, n_subcarriers=64, cp_len=16, qam_order=16)
    rx_bits = ofdm_rx(add_awgn(tx, snr_db=20, seed=1), params)
    print('OFDM BER (AWGN 20 dB):', np.mean(ofdm_bits != rx_bits))

    # FHSS: NOTE these carriers are integer multiples of fs = 1 MHz, so the
    # sampled carriers alias to DC; only the output length is reported.
    fh_bits = rng.integers(0, 2, size=256)
    tx_fh = fhss_transmit(fh_bits, [2.412e9, 2.437e9, 2.462e9], hop_rate=4, fs=1e6, symbol_rate=1e3)
    print('FHSS tx length samples:', len(tx_fh))

    # DSSS: 32 bits spread by a 16-chip code plus a 500 Hz tone over the whole
    # burst (time axis uses a nominal 10 kHz sample rate).
    pn = pn_sequence(16, seed=2)
    tx_dsss = dsss_tx(fh_bits[:32], pn, sps=4)
    t = np.arange(len(tx_dsss)) / 1e4
    rx_dsss = tx_dsss + 0.5*np.sin(2*np.pi*500*t)
    print('DSSS recovered symbols:', dsss_rx(rx_dsss, pn, sps=4).shape[0])


def demo_2_4():
    """Demo 2.4: run the stop-and-wait ARQ simulator on five short payloads."""
    print('=== Demo 2.4: Framing, CRC and Stop-and-Wait ARQ simulation ===')
    acked, events = stop_and_wait_simulate(
        [b'hello', b'world', b'this', b'is', b'test'], p_loss=0.2, seed=42)
    print('Acked sequences:', acked)
    print('Event sample (first 8):', events[:8])


def demo_2_5():
    """Demo 2.5: plot a 12 dBi toy pattern; print polarisation, return-loss and LNA examples."""
    print('=== Demo 2.5: Antenna patterns, polarization, LNA behavior ===')
    theta, pat = antenna_pattern(12.0, beamwidth_deg=45, theta_res=361)
    plt.figure(figsize=(6, 3))
    plt.plot(theta, pat)
    plt.xlabel('Azimuth (deg)')
    plt.ylabel('Gain (dBi)')
    plt.title('Antenna pattern (toy)')
    plt.grid(True)
    plt.tight_layout()
    plt.show()
    print('Polarization loss (0 vs 45 deg):', polarization_loss_db(0, 45))
    print('Return loss for Z=30 Ohm:', return_loss_db(30.0))
    amp = LNA(gain_dB=20, nf_dB=1, p1dB_dBm=-2)
    print('LNA amplify -40 dBm ->', amp.amplify_dbm(-40), 'dBm')
    print('LNA amplify -5 dBm ->', amp.amplify_dbm(-5), 'dBm')


def main():
    """CLI entry point: run the selected Chapter 2 demos (all when no flag is given)."""
    demos = (('--s2_1', demo_2_1), ('--s2_2', demo_2_2), ('--s2_3', demo_2_3),
             ('--s2_4', demo_2_4), ('--s2_5', demo_2_5))
    parser = argparse.ArgumentParser()
    parser.add_argument('--all', action='store_true')
    for option, _ in demos:
        parser.add_argument(option, action='store_true')
    args = parser.parse_args()
    chosen = [run for option, run in demos if getattr(args, option[2:])]
    if args.all or not chosen:
        chosen = [run for _, run in demos]
    for run in chosen:
        run()


if __name__ == '__main__':
    main()

第三章代码

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Executable Python examples for Chapter 3: Classic Signal Detection Methods & Statistical Foundations
Sections: 3.1-3.5
- Target: graduate students
- Purely implementation code with detailed comments
- Dependencies: numpy, scipy, matplotlib, sklearn
"""

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import chi2, norm
from scipy.signal import correlate
from sklearn.metrics import roc_curve, auc

# =========================
# 3.1 Hypothesis testing & Neyman-Pearson
# =========================

def np_test_statistic(y, sigma2=1.0, mu0=0.0, mu1=1.0):
    """Likelihood ratio f(y|H1) / f(y|H0) for a Gaussian mean-shift problem.

    H0: y ~ N(mu0, sigma2);  H1: y ~ N(mu1, sigma2).  The ratio is evaluated
    as exp of the log-likelihood difference.  This is algebraically identical
    to dividing the two Gaussian kernels, but it avoids 0/0 = nan when both
    kernels underflow for large |y| (e.g. y = 40).  Vectorised over ``y``.
    """
    y = np.asarray(y)
    log_lr = ((y - mu0)**2 - (y - mu1)**2) / (2*sigma2)
    return np.exp(log_lr)

def np_threshold(alpha, sigma2=1.0, mu0=0.0, mu1=1.0):
    """Neyman-Pearson threshold on the observation y for false-alarm rate alpha.

    For a Gaussian mean shift with mu1 > mu0, the likelihood ratio increases
    monotonically in y.  The NP test therefore reduces to "decide H1 if
    y > gamma" with P(y > gamma | H0) = alpha, i.e.
    gamma = mu0 + sqrt(sigma2) * Q^{-1}(alpha).

    ``mu1`` does not affect gamma (only the resulting P_D); it is kept for
    API symmetry.  For mu1 < mu0 the test direction flips (y < ...), which
    this helper does not handle.

    ``norm.isf(alpha)`` replaces ``norm.ppf(1 - alpha)``.  For very small
    alpha (e.g. 1e-20), 1 - alpha rounds to 1.0 and ppf returned inf.
    """
    return mu0 + np.sqrt(sigma2) * norm.isf(alpha)

# =========================
# 3.2 Energy detection & matched filter
# =========================

def energy_detector(y, threshold=None):
    """Energy statistic sum(|y|^2) taken along axis 0.

    For 2-D input shaped (samples, trials), one statistic is returned per
    trial.  When ``threshold`` is given, the boolean decision
    ``energy > threshold`` is returned instead.
    """
    energy = (np.abs(y)**2).sum(axis=0)
    return energy if threshold is None else energy > threshold


def matched_filter(y, s, sigma2=1.0, threshold=None):
    """Matched-filter statistic Re(s^H y) / sigma2, optionally thresholded.

    ``s`` must be a NumPy array (``.conj()`` is called on it).  With 2-D
    ``y`` shaped (samples, trials), one statistic per trial is returned.
    """
    correlation = np.dot(s.conj(), y)
    statistic = np.real(correlation) / sigma2
    if threshold is None:
        return statistic
    return statistic > threshold

# Compute threshold for given Pfa

def threshold_energy(N, sigma2, Pfa, real_valued=False):
    """Energy-detector threshold that yields false-alarm probability ``Pfa``.

    Default (complex) model: the real and imaginary parts of each of the N
    noise samples have variance sigma2, so E / sigma2 ~ chi2(2N).
    ``real_valued=True``: N real samples of variance sigma2, so
    E / sigma2 ~ chi2(N).

    ``chi2.isf(Pfa)`` replaces ``chi2.ppf(1 - Pfa)``, which became inf once
    1 - Pfa rounded to 1.0 for tiny Pfa.
    """
    dof = N if real_valued else 2 * N
    return sigma2 * chi2.isf(Pfa, dof)

def threshold_matched_filter(sigma2, Pfa, s_norm2):
    """Threshold for the ``matched_filter`` statistic at false-alarm rate ``Pfa``.

    ``matched_filter`` returns Re(s^H y) / sigma2.  With real white Gaussian
    noise of variance sigma2, s^T n ~ N(0, s_norm2 * sigma2), so the
    normalised statistic ~ N(0, s_norm2 / sigma2) under H0.  The threshold
    is therefore Q^{-1}(Pfa) * sqrt(s_norm2 / sigma2).

    The previous sqrt(s_norm2 * sigma2) belongs to an *un*-normalised
    correlator and agreed with ``matched_filter`` only when sigma2 = 1.
    ``norm.isf`` replaces ``norm.ppf(1 - Pfa)`` for accuracy at tiny Pfa.
    """
    return norm.isf(Pfa) * np.sqrt(s_norm2 / sigma2)

# =========================
# 3.3 Cyclostationarity detection (pseudo code / example)
# =========================

def cyclic_autocorr(x, alpha, maxlag=10):
    """Estimate the cyclic autocorrelation R_x(alpha, tau) for |tau| <= maxlag.

        R_x(alpha, tau) = mean_n  x[n] * conj(x[n - tau]) * exp(-j*2*pi*alpha*n)

    The mean runs over the indices n for which both x[n] and x[n - tau]
    exist.  ``alpha`` is the cycle frequency in cycles/sample, as in the
    formula above.  The previous version also divided alpha*n by len(x), so
    the effective cycle frequency was alpha/len(x) rather than alpha.

    Returns
    -------
    lags : ndarray of int
        -maxlag .. maxlag.
    R_alpha : ndarray of complex
        One estimate per lag.
    """
    x = np.asarray(x)
    N = len(x)
    lags = np.arange(-maxlag, maxlag + 1)
    R_alpha = np.zeros(lags.size, dtype=complex)
    for idx, tau in enumerate(lags):
        n_idx = np.arange(max(0, tau), min(N, N + tau))  # keep n and n - tau in range
        R_alpha[idx] = np.mean(x[n_idx] * np.conj(x[n_idx - tau])
                               * np.exp(-2j * np.pi * alpha * n_idx))
    return lags, R_alpha

# =========================
# 3.4 Multi-antenna covariance detection
# =========================

def sample_covariance_matrix(X):
    """Biased (1/n) sample covariance; rows of X are samples, columns are antennas."""
    samples = np.asarray(X)
    return np.cov(samples, rowvar=False, bias=True)

def eigenvalue_detection(X, threshold=None):
    """Maximum-eigenvalue detector on the sample covariance of X.

    X is laid out as samples x antennas, as expected by
    ``sample_covariance_matrix``.  Returns the largest eigenvalue, or the
    decision ``lambda_max > threshold`` when ``threshold`` is given.
    """
    lambda_max = np.linalg.eigvalsh(sample_covariance_matrix(X)).max()
    if threshold is None:
        return lambda_max
    return lambda_max > threshold

# =========================
# 3.5 Performance evaluation: ROC, AUC, min detectable SNR
# =========================

def simulate_detection(detector_func, H0_samples, H1_samples, **kwargs):
    """Score H0/H1 sample sets with a detector and compute the ROC curve and AUC.

    ``detector_func(x, **kwargs)`` is called once per sample.  Its outputs
    are used as scores (H0 labelled 0, H1 labelled 1) for sklearn's
    ``roc_curve``.

    Returns
    -------
    (fpr, tpr, auc)
    """
    def _scores(batch):
        # One detector output per sample in the batch.
        return np.array([detector_func(sample, **kwargs) for sample in batch])

    scores_h0 = _scores(H0_samples)
    scores_h1 = _scores(H1_samples)
    labels = np.concatenate([np.zeros_like(scores_h0), np.ones_like(scores_h1)])
    fpr, tpr, _ = roc_curve(labels, np.concatenate([scores_h0, scores_h1]))
    return fpr, tpr, auc(fpr, tpr)

# Minimum detectable SNR for target Pd

def min_detectable_snr_energy(N, sigma2, Pfa, Pd_target):
    """Approximate minimum per-sample SNR (linear) for an energy detector.

    Model: N real-valued samples, noise variance sigma2, signal power P, and
    SNR = P / sigma2.  Use the central-limit (Gaussian) approximation with
    equal H0/H1 variances, which is accurate at low SNR and large N:

        E | H0 ~ N(N*sigma2,           2N*sigma2^2)
        E | H1 ~ N(N*sigma2*(1 + SNR), 2N*sigma2^2)

    Solving P_D(threshold(Pfa)) = Pd_target gives

        SNR_min = sqrt(2/N) * (Q^{-1}(Pfa) - Q^{-1}(Pd_target)).

    ``sigma2`` cancels out; it is kept for interface compatibility.  The
    previous implementation mixed a chi2(2N) threshold with an N*sigma2 mean
    and squared the result.  For N=32, Pfa=0.05, Pd=0.9 it returned about 27
    instead of about 0.73.
    """
    return np.sqrt(2.0 / N) * (norm.isf(Pfa) - norm.isf(Pd_target))

# =========================
# Demo for all sections
# =========================

def demo_3_1():
    """Demo 3.1: empirical threshold at the (1 - alpha) quantile of H0-only data.

    All 1000 samples are N(0, 1) noise, so the printed count is close to
    alpha * 1000 by construction.
    """
    print('=== Demo 3.1: Neyman-Pearson simple test ===')
    y = np.random.normal(0, 1, 1000)
    alpha = 0.05
    threshold = np.quantile(y, 1 - alpha)
    print('Decisions count (H1 detections):', np.sum(y > threshold))


def demo_3_2():
    """Demo 3.2: Monte-Carlo check of the energy detector.

    Runs 1000 trials of N real Gaussian noise samples (variance sigma2),
    with and without an all-ones signal.  Because the noise is real-valued,
    E / sigma2 ~ chi2(N) under H0, so the threshold is
    sigma2 * chi2.isf(Pfa, N).  ``threshold_energy``'s default assumes
    complex noise (chi2(2N)); using it here pushed the empirical
    false-alarm rate to ~0 instead of the requested Pfa.
    """
    print('=== Demo 3.2: Energy & matched filter detection ===')
    N = 64
    sigma2 = 1.0
    Pfa = 0.01
    # H0 and H1 signals: columns are independent trials.
    H0 = np.random.normal(0, np.sqrt(sigma2), (N, 1000))
    s = np.ones(N)
    H1 = H0 + s[:, None]
    th_energy = sigma2 * chi2.isf(Pfa, N)  # real-valued noise: chi2 with N DOF
    det_H0 = energy_detector(H0, threshold=th_energy)
    det_H1 = energy_detector(H1, threshold=th_energy)
    print('Energy detector - H0 false alarms:', np.sum(det_H0)/det_H0.size)
    print('Energy detector - H1 detections:', np.sum(det_H1)/det_H1.size)


def demo_3_3():
    """Demo 3.3: stem plot of |R_x(alpha, tau)| for a noisy 5-cycles-per-100-samples sine."""
    print('=== Demo 3.3: Cyclic autocorrelation ===')
    n = np.arange(100)
    x = np.sin(2*np.pi*5*n/100) + np.random.randn(100)*0.1
    lags, R_alpha = cyclic_autocorr(x, alpha=5/100, maxlag=10)
    plt.figure()
    plt.stem(lags, np.abs(R_alpha))
    plt.xlabel('Lag')
    plt.ylabel('|R_alpha|')
    plt.title('Cyclic autocorrelation')
    plt.grid(True)
    plt.show()


def demo_3_4():
    """Demo 3.4: max-eigenvalue statistic for noise-only data (100 samples x 4 antennas)."""
    print('=== Demo 3.4: Multi-antenna eigenvalue detection ===')
    samples = np.random.randn(100, 4)
    print('Max eigenvalue test statistic:', eigenvalue_detection(samples))


def demo_3_5():
    """Demo 3.5: energy-detector ROC for a unit mean shift (N=32) and the minimum-SNR estimate."""
    print('=== Demo 3.5: ROC and min detectable SNR ===')
    N, sigma2, Pfa = 32, 1.0, 0.05
    H0_samples = [np.random.randn(N) for _ in range(500)]
    H1_samples = [np.random.randn(N) + 1.0 for _ in range(500)]
    fpr, tpr, roc_auc = simulate_detection(energy_detector, H0_samples, H1_samples)
    print('ROC AUC:', roc_auc)
    plt.figure()
    plt.plot(fpr, tpr, lw=2)
    plt.xlabel('FPR')
    plt.ylabel('TPR')
    plt.title('ROC Curve')
    plt.grid(True)
    plt.show()
    snr_min = min_detectable_snr_energy(N, sigma2, Pfa, Pd_target=0.9)
    print('Approx. minimum detectable SNR (energy detector):', snr_min)

# =========================
# Main CLI
# =========================

def main():
    """Run every Chapter 3 demo in order."""
    for demo in (demo_3_1, demo_3_2, demo_3_3, demo_3_4, demo_3_5):
        demo()


if __name__ == '__main__':
    main()

第四章代码

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Enhanced Python examples for Chapter 4: Spectrum Sensing & Wideband Detection (Sections 4.1-4.5)
- Graduate-level detailed implementation
- Full visualizations, iterative fusion, performance statistics
- Dependencies: numpy, scipy, matplotlib, cvxpy
"""

import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import stft, get_window
import cvxpy as cp

# =========================
# 4.1 Narrowband & Wideband detection with adaptive thresholding
# =========================

def narrowband_detector(y, threshold=None):
    """Single-band energy detector.

    Returns (decision, energy, threshold).  Without an explicit threshold,
    one is derived from the *same* samples: 1.5 * N * var(y).  Since
    sum|y|^2 = N * (var(y) + |mean(y)|^2), this default fires only when
    |mean(y)|^2 > 0.5 * var(y).  It therefore responds to a mean/DC offset
    rather than to added noise-like power.
    """
    if threshold is None:
        threshold = np.var(y) * len(y) * 1.5  # simple adaptive factor
    energy = np.sum(np.abs(y)**2)
    return energy > threshold, energy, threshold

def wideband_detector(y, band_indices, thresholds=None):
    """Per-band energy detector over index ranges [start, end) of ``y``.

    Default per-band thresholds follow the same rule as
    ``narrowband_detector``: 1.5 * len(band) * var(band).

    Returns
    -------
    (decisions, energies, thresholds)
        Lists, one entry per band; supplied thresholds are returned as given.
    """
    if thresholds is None:
        thresholds = [np.var(y[lo:hi]) * len(y[lo:hi]) * 1.5 for lo, hi in band_indices]
    energies = [np.sum(np.abs(y[lo:hi])**2) for lo, hi in band_indices]
    results = [energy > thresholds[i] for i, energy in enumerate(energies)]
    return results, energies, thresholds

# =========================
# 4.2 FFT / STFT detection and window comparison
# =========================

def stft_analysis(y, fs=1e3, nperseg=128, windows=('hann', 'hamming', 'blackman')):
    """Plot STFT energy spectrograms of ``y``, one subplot per analysis window.

    Each panel shows 10*log10(|Zxx|^2 + 1e-12); the floor avoids log(0).

    Parameters
    ----------
    y : array_like
        Signal to analyse.
    fs : float
        Sampling rate passed to ``scipy.signal.stft``.
    nperseg : int
        STFT segment length.
    windows : iterable of str
        scipy window names.  The default is now an immutable tuple instead
        of a mutable list shared across calls; any iterable is accepted.
    """
    windows = list(windows)  # allow generators; len() is needed for the grid
    plt.figure(figsize=(12, 8))
    for row, w in enumerate(windows, start=1):
        f, t, Zxx = stft(y, fs=fs, nperseg=nperseg, window=w)
        power_db = 10 * np.log10(np.abs(Zxx)**2 + 1e-12)
        plt.subplot(len(windows), 1, row)
        plt.pcolormesh(t, f, power_db)
        plt.title(f'STFT Energy Spectrogram ({w} window)')
        plt.xlabel('Time')
        plt.ylabel('Frequency')
        plt.colorbar(label='dB')
    plt.tight_layout()
    plt.show()

# =========================
# 4.3 Compressive sensing & sparse reconstruction
# =========================

def cs_reconstruction_l1(Phi, y):
    """Basis pursuit: minimise ||x||_1 subject to Phi @ x == y (CVXPY, SCS solver).

    Returns the solver's ``x.value``.  NOTE(review): this may be None if the
    solve fails; callers may want to check the CVXPY problem status.
    """
    x = cp.Variable(Phi.shape[1])
    problem = cp.Problem(cp.Minimize(cp.norm1(x)), [Phi @ x == y])
    problem.solve(solver=cp.SCS)
    return x.value

def cs_reconstruction_omp(Phi, y, k):
    """Orthogonal Matching Pursuit: greedy k-sparse solution of Phi @ x ~= y.

    Each iteration:
    1. Picks the not-yet-chosen column most correlated with the residual.
    2. Re-fits all chosen columns by least squares.
    3. Updates the residual.

    Fixes relative to the earlier version:
    * Chosen columns are excluded from later picks.  Previously a
      (near-)zero residual could re-pick a column; the duplicate split the LS
      coefficient and ``x_hat[idx] = ...`` kept only the last share.
    * k == 0 returns the zero vector instead of raising UnboundLocalError.
    * k > Phi.shape[1] stops once every column has been used.
    * Correlations use the conjugate transpose, so complex Phi works.

    Parameters
    ----------
    Phi : array_like, shape (M, N)
    y : array_like, shape (M,)
    k : int
        Target sparsity (maximum number of selected columns).

    Returns
    -------
    x_hat : ndarray, shape (N,)
    """
    Phi = np.asarray(Phi)
    y = np.asarray(y)
    n_cols = Phi.shape[1]
    x_hat = np.zeros(n_cols, dtype=np.result_type(Phi, y, float))
    residual = y.copy()
    selected = []
    coef = None
    for _ in range(min(k, n_cols)):
        scores = np.abs(Phi.conj().T @ residual)
        if selected:
            scores[selected] = -np.inf  # never re-pick a column
        selected.append(int(np.argmax(scores)))
        coef = np.linalg.lstsq(Phi[:, selected], y, rcond=None)[0]
        residual = y - Phi[:, selected] @ coef
    if selected:
        x_hat[selected] = coef
    return x_hat

def generate_sparse_spectrum(N=1024, k=10, seed=None):
    """Length-N vector with k nonzero N(0, 1) entries at distinct random positions.

    Positions are drawn before values from ``numpy.random.default_rng(seed)``.
    """
    rng = np.random.default_rng(seed)
    support = rng.choice(N, k, replace=False)
    spectrum = np.zeros(N)
    spectrum[support] = rng.normal(0, 1, k)
    return spectrum

# =========================
# 4.4 Cooperative sensing with iterative fusion
# =========================

def cooperative_fusion_iterative(results_matrix, iterations=3, method='majority'):
    """Iteratively fuse binary per-sensor, per-band decisions.

    Sensor 0's row seeds the fused vector.  Each iteration stacks the
    current fused vector on top of all sensor rows and re-votes:

    * 'majority' : fused = mean(stack) >= 0.5
    * 'weighted' : weights = linspace(0.5, 1.0) over the stacked rows
      (fused row 0.5, last sensor 1.0); fused = weighted sum > half the
      total weight.  Previously only n_sensors weights were built for
      n_sensors + 1 stacked rows, so this branch always failed with a
      broadcasting ValueError.
    * 'bayesian' : fused = mean(stack) > 0.5, a strict majority (no priors
      are modelled).

    Parameters
    ----------
    results_matrix : array_like, shape (n_sensors, n_bands), 0/1 or bool
    iterations : int
    method : {'majority', 'weighted', 'bayesian'}

    Returns
    -------
    fused : ndarray of bool
    history : list of ndarray
        history[0] is sensor 0's row; history[i] is the fused vector after
        iteration i.

    Raises
    ------
    ValueError
        For an unknown ``method`` (previously ignored silently, which
        returned sensor 0's row unchanged).
    """
    if method not in ('majority', 'weighted', 'bayesian'):
        raise ValueError(f"unknown fusion method: {method!r}")
    results_matrix = np.asarray(results_matrix)
    fused = np.copy(results_matrix[0])
    history = [fused.copy()]
    for _ in range(iterations):
        stacked = np.vstack([fused, results_matrix])
        if method == 'majority':
            fused = stacked.mean(axis=0) >= 0.5
        elif method == 'weighted':
            weights = np.linspace(0.5, 1.0, stacked.shape[0])[:, None]
            fused = (stacked * weights).sum(axis=0) > 0.5 * weights.sum()
        else:  # 'bayesian'
            fused = stacked.mean(axis=0) > 0.5
        history.append(fused.copy())
    return fused, history

# =========================
# 4.5 Real-time system analysis: latency & energy
# =========================

def processing_load(samples_per_frame, fft_size, sensors, fs):
    """Rough operation rate (ops/s).

    Assumes one N*log2(N) FFT per frame per sensor, with fs/samples_per_frame
    frames per second.
    """
    fft_ops = fft_size * np.log2(fft_size)
    frame_rate = fs / samples_per_frame
    return fft_ops * frame_rate * sensors

def energy_latency_tradeoff(frame_sizes, fft_size=128, sensors=4, fs=1e3, energy_per_op=1e-9):
    """Per-frame-size latency and processing energy rate.

    latency = frame_size / fs (seconds).  energy = processing_load(...) *
    energy_per_op.  Because processing_load is in ops/s, this is energy per
    second of operation (J/s).  ``frame_sizes`` must support elementwise
    division by ``fs``, e.g. a NumPy array.
    """
    latencies = frame_sizes / fs
    ops_per_s = np.array([processing_load(size, fft_size, sensors, fs) for size in frame_sizes])
    return latencies, ops_per_s * energy_per_op

# =========================
# Demonstrations
# =========================

def demo_4_1_4_2():
    """Demos 4.1-4.2: narrowband/wideband detection on white noise, then an STFT window comparison."""
    print('=== Demo 4.1-4.2 ===')
    y = np.random.randn(1024)
    nb_res, _, th_nb = narrowband_detector(y[:128])
    print('Narrowband detection:', nb_res, 'Threshold:', th_nb)
    wb_res, _, th_wb = wideband_detector(y, [(0, 128), (128, 256), (256, 512)])
    print('Wideband detection results:', wb_res, 'Thresholds:', th_wb)
    stft_analysis(y, nperseg=64, windows=['hann', 'hamming', 'blackman'])

def demo_4_3():
    """Demo 4.3: recover an 8-sparse length-512 vector from 128 Gaussian measurements (L1 vs OMP)."""
    print('=== Demo 4.3 Compressive Sensing ===')
    N, k, M = 512, 8, 128
    x_true = generate_sparse_spectrum(N=N, k=k)
    Phi = np.random.normal(0, 1, (M, N))
    y = Phi @ x_true
    x_l1 = cs_reconstruction_l1(Phi, y)
    x_omp = cs_reconstruction_omp(Phi, y, k)
    plt.figure()
    for series, marker, label in ((x_true, 'o', 'Original'),
                                  (x_l1, 'x', 'L1 reconstruction'),
                                  (x_omp, '+', 'OMP reconstruction')):
        plt.plot(series, marker, label=label)
    plt.legend()
    plt.show()

def demo_4_4():
    """Demo 4.4: five rounds of majority fusion over random 5-sensor x 3-band decisions."""
    print('=== Demo 4.4 Cooperative Fusion ===')
    sensors, bands = 5, 3
    results_matrix = np.random.randint(0, 2, (sensors, bands))
    fused, history = cooperative_fusion_iterative(results_matrix, iterations=5, method='majority')
    print('Results matrix:\n', results_matrix)
    print('Final fused result:', fused)
    for step, decision in enumerate(history):
        print(f'Iteration {step}: {decision}')

def demo_4_5():
    """Demo 4.5: plot the modelled processing energy against frame latency."""
    print('=== Demo 4.5 Real-time System Analysis ===')
    lat, energy = energy_latency_tradeoff(np.array([64, 128, 256, 512]))
    plt.figure()
    # NOTE: ``energy`` is a rate (J per second of operation), scaled to milli-units.
    plt.plot(lat*1e3, energy*1e3, '-o')
    plt.xlabel('Latency (ms)')
    plt.ylabel('Energy (mJ)')
    plt.title('Energy vs Latency Tradeoff')
    plt.grid(True)
    plt.show()

# =========================
# Main
# =========================

def main():
    """Run every Chapter 4 demo in order."""
    for demo in (demo_4_1_4_2, demo_4_3, demo_4_4, demo_4_5):
        demo()


if __name__ == '__main__':
    main()

第五章代码

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
signal_suite.py
A single-file, runnable toolkit that demonstrates a complete pipeline for
Signal Classification & Protocol Recognition (from features to end-to-end).
Includes:
  - Synthetic IQ dataset generator (multiple modulations + simple "protocol" label)
  - Feature extraction: time-domain, frequency-domain, spectrogram (STFT),
    continuous wavelet transform (CWT), and higher-order statistics
  - Traditional ML pipeline: feature selection, scaling, PCA, cross-validated classifiers
  - Deep learning models: Spectrogram-CNN, Raw-IQ 1D-CNN, LSTM, Transformer
  - Evaluation: accuracy, confusion matrices, SNR curves, mixed-interference testing
  - Interpretability: t-SNE visualization and confusion matrix heatmaps
All code in one file; execute `python signal_suite.py`.
"""

import os
import sys
import math
import time
import random
import argparse
from functools import partial

import numpy as np
import scipy.signal as sps
import scipy.stats as stats
from sklearn.model_selection import train_test_split, GridSearchCV, StratifiedKFold
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.decomposition import PCA
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset

# ---------------------------
# Reproducibility / device
# ---------------------------
SEED = 2025
# Seed every RNG the suite draws from (stdlib, numpy, torch) so runs repeat.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Prefer the GPU when one is visible; seed its generator as well.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if DEVICE.type == "cuda":
    torch.cuda.manual_seed(SEED)

# ---------------------------
# Utilities
# ---------------------------
def ensure_dir(d):
    """Create directory ``d`` (with any missing parents) unless the path already exists."""
    if os.path.exists(d):
        return
    os.makedirs(d)

# Output directory for every figure written by run_full_experiment.
# Note: the directory is created at import time, not when the experiment starts.
PLOT_DIR = "plots_signal_suite"
ensure_dir(PLOT_DIR)

# ---------------------------
# Synthetic signal generator
# ---------------------------
def awgn(signal, snr_db):
    """Return ``signal`` plus circular complex white Gaussian noise at ``snr_db``.

    The noise power is mean(|signal|^2) / 10^(snr_db/10). It is split equally
    between the real and imaginary parts; the real part is drawn first.
    """
    shape = signal.shape
    noise_std = np.sqrt(np.mean(np.abs(signal) ** 2) / 10 ** (snr_db / 10.0) / 2)
    real_part = np.random.randn(*shape)
    imag_part = np.random.randn(*shape)
    return signal + noise_std * (real_part + 1j * imag_part)

def generate_baseband(mod_type, num_samples=1024, sps_rate=8, freq_offset=0.0, phase_offset=0.0):
    """
    Generate a baseband complex IQ burst for a given modulation type.

    Supported mod_type: 'BPSK','QPSK','8PSK','16QAM','AM','FM','GFSK'.
    Any other value yields a pure complex tone at 0.1 cycles per unit of ``t``.

    Parameters
    ----------
    mod_type : str
        Modulation name (see above).
    num_samples : int
        Length of the returned burst. Digital modes draw one random symbol per
        sample (there is no symbol repetition).
    sps_rate : float
        Divisor that maps the sample index to the time axis ``t``. ``t`` is used
        by the carrier-based modes (AM/FM/default) and by ``freq_offset``.
    freq_offset : float
        Frequency offset applied as exp(j*2*pi*freq_offset*t).
    phase_offset : float
        Carrier phase rotation in radians, applied to every modulation type.

    Returns
    -------
    numpy.ndarray
        complex64 array of length ``num_samples``, smoothed by a 7-tap Hann
        kernel and normalised to unit mean power.
    """
    t = np.arange(num_samples) / sps_rate
    if mod_type == 'BPSK':
        # random bits mapped to +/-1
        bits = np.random.choice([1, -1], size=num_samples)
        s = bits.astype(np.complex64) * np.exp(1j * phase_offset)
    elif mod_type == 'QPSK':
        bits = np.random.randint(0, 4, size=num_samples)
        # vectorised lookup table (index -> constellation point)
        constellation = np.array([1+1j, -1+1j, -1-1j, 1-1j], dtype=np.complex64)
        s = constellation[bits]
        s /= np.sqrt(2)
        # apply the carrier phase like every other branch (it was ignored before)
        s = s * np.exp(1j * phase_offset)
    elif mod_type == '8PSK':
        bits = np.random.randint(0, 8, size=num_samples)
        s = np.exp(1j * (2 * np.pi * bits / 8.0 + phase_offset))
    elif mod_type == '16QAM':
        # square 16-QAM, normalised so E|s|^2 = 1 (E[re^2] + E[im^2] = 5 + 5)
        re = np.random.choice([-3, -1, 1, 3], size=num_samples)
        im = np.random.choice([-3, -1, 1, 3], size=num_samples)
        s = (re + 1j * im).astype(np.complex64)
        s /= np.sqrt(10)
        # apply the carrier phase like every other branch (it was ignored before)
        s = s * np.exp(1j * phase_offset)
    elif mod_type == 'AM':
        # random real envelope on a complex carrier
        carrier = np.exp(1j * (2 * np.pi * 0.1 * t + phase_offset))
        base = 0.7 + 0.3 * np.random.randn(num_samples)
        s = base * carrier
    elif mod_type == 'FM':
        # integrate random instantaneous-frequency deviations around a 0.1 carrier
        freq_dev = 2 * np.pi * 0.05 * np.random.randn(num_samples)
        inst_phase = np.cumsum(freq_dev) + 2 * np.pi * 0.1 * t + phase_offset
        s = np.exp(1j * inst_phase)
    elif mod_type == 'GFSK':
        # FSK-like: phase integrates +/-delta per bit (smoothed below)
        bits = np.random.choice([0, 1], size=num_samples)
        delta = 0.3
        freq = (2 * bits - 1) * delta
        phase = np.cumsum(freq)
        s = np.exp(1j * (phase + phase_offset))
    else:
        # fallback: pure tone (not noise)
        s = np.exp(1j * (2 * np.pi * 0.1 * t + phase_offset))
    if freq_offset != 0.0:
        s *= np.exp(1j * 2 * np.pi * freq_offset * t)
    # simple pulse shaping: 7-tap Hann kernel applied to I and Q separately
    window = np.hanning(7)
    s_real = np.convolve(np.real(s), window, mode='same') / np.sum(window)
    s_imag = np.convolve(np.imag(s), window, mode='same') / np.sum(window)
    s = s_real + 1j * s_imag
    # normalise power to 1
    s /= np.sqrt(np.mean(np.abs(s) ** 2) + 1e-12)
    return s.astype(np.complex64)

# Simple protocol label generator given a modulation type -> simulate 'protocol' by framing patterns
def generate_protocol_label(mod_type):
    """Map a modulation name to its toy "protocol" class; unknown names give 'UNKNOWN'."""
    # protocol -> modulations that carry it
    groups = {
        'CTRL': ('BPSK', 'GFSK'),
        'DATA': ('QPSK',),
        'VOICE': ('8PSK',),
        'VIDEO': ('16QAM',),
        'ANALOG': ('AM', 'FM'),
    }
    lookup = {mod: proto for proto, mods in groups.items() for mod in mods}
    return lookup.get(mod_type, 'UNKNOWN')

# ---------------------------
# Dataset builder
# ---------------------------
# Modulation classes and their integer label indices.
MODS = ['BPSK','QPSK','8PSK','16QAM','AM','FM','GFSK']
MOD_TO_IDX = dict(zip(MODS, range(len(MODS))))
# Protocol classes (see generate_protocol_label) and their indices.
PROTOCOLS = ['CTRL','DATA','VOICE','VIDEO','ANALOG','UNKNOWN']
PROTO_TO_IDX = dict(zip(PROTOCOLS, range(len(PROTOCOLS))))

def create_dataset(n_per_class=200, num_samples=1024, sps=8, snr_db=20, snr_jitter=2, interference_prob=0.2):
    """
    Build a class-balanced synthetic IQ dataset over every modulation in MODS.

    For each modulation, ``n_per_class`` bursts are generated:
      - each burst gets a random frequency offset in [-0.05, 0.05) and a random
        phase in [0, 2*pi);
      - with probability ``interference_prob``, a burst of a randomly chosen
        modulation (no offsets) is added, scaled by a factor in [0.1, 0.8);
      - AWGN is added at ``snr_db`` plus a uniform jitter in [-snr_jitter, snr_jitter) dB.

    ``sps`` is forwarded as ``sps_rate``. Inside this function it shadows the
    module's ``scipy.signal`` alias, which is not used here.

    Returns:
        X_complex: shape (N, num_samples), N = len(MODS) * n_per_class. The
            dtype is complex128, because awgn adds float64 noise.
        y_mod: shape (N,) int modulation labels (MOD_TO_IDX)
        y_proto: shape (N,) int protocol labels (PROTO_TO_IDX)
    """
    bursts, mod_labels, proto_labels = [], [], []
    for mod in MODS:
        for _ in range(n_per_class):
            # draw offsets in the same order as the keyword arguments consume them
            f_off = np.random.uniform(-0.05, 0.05)
            p_off = np.random.uniform(0, 2*np.pi)
            burst = generate_baseband(mod, num_samples=num_samples, sps_rate=sps,
                                      freq_offset=f_off, phase_offset=p_off)
            if np.random.rand() < interference_prob:
                other_mod = np.random.choice(MODS)
                interferer = generate_baseband(other_mod, num_samples=num_samples, sps_rate=sps)
                burst = burst + np.random.uniform(0.1, 0.8) * interferer
            snr_here = snr_db + np.random.uniform(-snr_jitter, snr_jitter)
            bursts.append(awgn(burst, snr_here))
            mod_labels.append(MOD_TO_IDX[mod])
            proto_labels.append(PROTO_TO_IDX[generate_protocol_label(mod)])
    return np.stack(bursts, axis=0), np.array(mod_labels), np.array(proto_labels)

# ---------------------------
# Feature extraction
# ---------------------------
def extract_time_features(x_complex):
    """
    Time-domain features of one complex burst (1-D array).

    Returns a float32 vector of 21 values:
      - mean, std, skew and kurtosis of I, Q, |x| and the instantaneous
        frequency (gradient of the unwrapped phase) — 16 values;
      - max, min and median of |x|;
      - RMS amplitude;
      - crest factor (peak / RMS).
    """
    amplitude = np.abs(x_complex)
    inst_freq = np.gradient(np.unwrap(np.angle(x_complex)))
    channels = (np.real(x_complex), np.imag(x_complex), amplitude, inst_freq)
    statistics = (np.mean, np.std, stats.skew, stats.kurtosis)
    feats = [stat(chan) for chan in channels for stat in statistics]
    peak = np.max(amplitude)
    rms = np.sqrt(np.mean(amplitude ** 2))
    feats += [peak, np.min(amplitude), np.median(amplitude), rms, peak / (rms + 1e-12)]
    return np.asarray(feats, dtype=np.float32)

def extract_freq_features(x_complex, fs=1.0):
    """
    Frequency-domain features from a Welch PSD estimate.

    Returns a float32 vector of exactly 7 values:
      [spectral centroid, RMS bandwidth, 85 % rolloff frequency,
       top-3 PSD peak values (descending, zero-padded when fewer peaks exist),
       sum of squared PSD]

    For complex input the PSD is two-sided. Welch returns it in FFT order
    (DC first, negative frequencies last), so it is re-ordered with fftshift
    into ascending frequency. The cumulative rolloff search and peak detection
    are only meaningful on that ordering. Real input keeps the one-sided PSD,
    which is already ascending.
    """
    x = x_complex
    n = len(x)
    two_sided = np.iscomplexobj(x)
    # Passing return_onesided explicitly also avoids scipy's "switching to
    # return_onesided=False" warning for complex data.
    f, pxx = sps.welch(x, fs=fs, nperseg=min(256, n), return_onesided=not two_sided)
    if two_sided:
        f = np.fft.fftshift(f)
        pxx = np.fft.fftshift(pxx)
    p = np.abs(pxx) + 1e-12
    p_norm = p / np.sum(p)
    centroid = np.sum(f * p_norm)
    bw = np.sqrt(np.sum(((f - centroid) ** 2) * p_norm))
    # spectral rolloff: first frequency where 85 % of the power is reached
    rolloff_idx = np.searchsorted(np.cumsum(p_norm), 0.85)
    rolloff_freq = f[min(rolloff_idx, len(f) - 1)]
    # top-3 peaks, always padded to 3 so every burst yields the same length
    peaks, _ = sps.find_peaks(p, height=np.max(p) * 0.05)
    top = np.sort(p[peaks])[::-1][:3]
    peak_vals = list(top) + [0.0] * (3 - len(top))
    feats = [centroid, bw, rolloff_freq] + peak_vals + [np.sum(p ** 2)]
    return np.asarray(feats, dtype=np.float32)

def extract_spectrogram(x_complex, nfft=256, noverlap=128, fs=1.0):
    """
    Log-magnitude STFT of one burst, standardised to zero mean and unit std.

    Returns a float32 array of shape (freq_bins, time_bins).
    NOTE(review): for complex input scipy's STFT falls back to a two-sided
    spectrum. Its rows follow FFT order (DC first, negative frequencies in the
    second half), not ascending frequency. Confirm that downstream cropping
    expects this layout.
    """
    _, _, stft_mat = sps.stft(x_complex, fs=fs, nperseg=nfft, noverlap=noverlap, padded=False)
    log_mag = np.log10(np.abs(stft_mat) + 1e-12)
    centred = log_mag - log_mag.mean()
    return (centred / (log_mag.std() + 1e-12)).astype(np.float32)

def _ricker_wavelet(points, a):
    """Ricker (Mexican-hat) wavelet with ``points`` samples and width ``a``.

    Uses the same formula as the former ``scipy.signal.ricker``.
    """
    amp = 2.0 / (np.sqrt(3.0 * a) * np.pi ** 0.25)
    vec = np.arange(points) - (points - 1.0) / 2
    xsq = vec ** 2
    wsq = a ** 2
    return amp * (1 - xsq / wsq) * np.exp(-xsq / (2 * wsq))


def _cwt_ricker(data, widths):
    """Ricker-wavelet CWT of 1-D ``data``, returning one row per width.

    Follows the former ``scipy.signal.cwt``: each kernel has
    min(10*width, len(data)) taps, is time-reversed, and is convolved in 'same' mode.
    """
    data = np.asarray(data, dtype=np.float64)
    out = np.empty((len(widths), len(data)), dtype=np.float64)
    for row, width in enumerate(widths):
        n_taps = int(min(10 * width, len(data)))
        kernel = _ricker_wavelet(n_taps, width)[::-1]
        out[row] = np.convolve(data, kernel, mode='same')
    return out


def extract_cwt(x_complex, widths=np.arange(1,64)):
    """
    Fixed-length Ricker-wavelet CWT summary of the real part of one burst.

    The CWT is computed for every width. The per-scale mean and std over time
    are then concatenated, giving 2 * len(widths) values. Vectors longer than
    128 are reduced to 128 entries by evenly spaced index selection.
    Returns a float32 array.

    ``scipy.signal.cwt``/``ricker`` were removed in SciPy 1.15, and the former
    fallback called the same missing functions. A local numpy implementation
    is therefore used, so results do not depend on the SciPy version.
    """
    x = np.real(x_complex)  # use real part for CWT
    cwtmat = _cwt_ricker(x, widths)
    # summarize by mean/std across time for each scale
    means = cwtmat.mean(axis=1)
    stds = cwtmat.std(axis=1)
    feats = np.concatenate([means, stds]).astype(np.float32)
    # reduce dimensionality if too big by simple downsampling
    if feats.size > 128:
        idx = np.linspace(0, feats.size - 1, 128).astype(int)
        feats = feats[idx]
    return feats

def extract_higher_order(x_complex):
    """
    Simple higher-order moment features of one complex burst.

    Returns a float32 vector of 8 values:
      - E[v^3] and E[v^4] for v in (I, Q, |x|);
      - the cross moments E[I*Q] and E[I^2*Q].
    """
    re_part = np.real(x_complex)
    im_part = np.imag(x_complex)
    components = (re_part, im_part, np.abs(x_complex))
    moments = [np.mean(comp ** power) for comp in components for power in (3, 4)]
    moments += [np.mean(re_part * im_part), np.mean(re_part ** 2 * im_part)]
    return np.asarray(moments, dtype=np.float32)

def extract_features_all(X_complex, fs=1.0):
    """
    Compute per-burst features for a batch of complex bursts.

    X_complex must be 2-D (N, L); it is unpacked as such.
    Returns:
      feat_matrix: float32 (N, F). Each row concatenates the time, frequency,
        higher-order and CWT feature vectors.
      specs_list: list of N spectrogram arrays (extract_spectrogram)
      cwt_list: list of N CWT feature vectors (also contained in feat_matrix)
    ``fs`` is only forwarded to the frequency-domain features.
    """
    _, _ = X_complex.shape
    rows, spectrograms, cwt_feats = [], [], []
    for burst in X_complex:
        cwt_vec = extract_cwt(burst)
        rows.append(np.concatenate([
            extract_time_features(burst),
            extract_freq_features(burst, fs=fs),
            extract_higher_order(burst),
            cwt_vec,
        ]))
        spectrograms.append(extract_spectrogram(burst))
        cwt_feats.append(cwt_vec)
    return np.vstack(rows).astype(np.float32), spectrograms, cwt_feats

# ---------------------------
# Traditional ML pipeline
# ---------------------------
def traditional_pipeline(X_feat, y, n_select=40, do_pca=True, pca_dim=10, cv_folds=5, random_state=SEED):
    """
    Classical ML pipeline: standardise, keep the k best ANOVA-F features,
    optionally apply PCA, then grid-search an RBF SVC and a RandomForest.

    All transforms are fitted on the full (X_feat, y) before cross-validation.
    The estimator with the higher mean CV accuracy is returned; the SVC wins ties.
    Returns (best_estimator, scaler, selector, pca_or_None, history), where
    history = {'svc': GridSearchCV, 'rf': GridSearchCV}.
    """
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X_feat)

    selector = SelectKBest(score_func=f_classif, k=min(n_select, X_scaled.shape[1]))
    X_model = selector.fit_transform(X_scaled, y)

    pca = PCA(n_components=min(pca_dim, X_model.shape[1]), random_state=random_state) if do_pca else None
    if pca is not None:
        X_model = pca.fit_transform(X_model)

    searches = {
        'svc': GridSearchCV(SVC(),
                            {'C': [0.1, 1, 10], 'kernel': ['rbf'], 'gamma': ['scale']},
                            cv=cv_folds, scoring='accuracy', n_jobs=1),
        'rf': GridSearchCV(RandomForestClassifier(random_state=random_state),
                           {'n_estimators': [50, 100], 'max_depth': [10, 20, None]},
                           cv=cv_folds, scoring='accuracy', n_jobs=1),
    }
    for search in searches.values():
        search.fit(X_model, y)
    svc_search, rf_search = searches['svc'], searches['rf']
    if svc_search.best_score_ >= rf_search.best_score_:
        best = svc_search.best_estimator_
    else:
        best = rf_search.best_estimator_
    return best, scaler, selector, pca, searches

# ---------------------------
# PyTorch datasets & models
# ---------------------------
class IQDataset(Dataset):
    """Map-style dataset of complex bursts.

    Each item is (float32 array stacking [real; imag] along a new axis 0,
    int label). An optional ``transform`` is applied to that stacked array.
    """

    def __init__(self, X_complex, y, transform=None):
        self.X = X_complex
        self.y = y
        self.transform = transform

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        burst = self.X[idx]
        iq = np.stack((np.real(burst), np.imag(burst)), axis=0).astype(np.float32)
        if self.transform:
            iq = self.transform(iq)
        return iq, int(self.y[idx])

class SpecDataset(Dataset):
    """Map-style dataset of 2-D spectrograms for image-style CNNs.

    Each item is (float32 array of shape (1, H, W), int label). If given,
    ``transform`` is applied to that channel-expanded array. This mirrors
    IQDataset; previously the argument was stored but silently ignored.
    """

    def __init__(self, specs, y, transform=None):
        # specs: list of 2D arrays (freq_bins, time_bins)
        self.specs = specs
        self.y = y
        self.transform = transform

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        s = self.specs[idx].astype(np.float32)
        # add channel dim -> (1, H, W)
        s = np.expand_dims(s, 0)
        if self.transform:
            s = self.transform(s)
        return s, int(self.y[idx])

# Models
class RawIQ_1DCNN(nn.Module):
    """1-D CNN on raw IQ: three Conv-BN-ReLU stages, global average pooling, two FC layers.

    Input (B, in_channels, L) -> logits (B, n_classes). ``seq_len`` is accepted
    but unused, because the adaptive pool makes the network independent of L.
    Sub-modules are created in a fixed order (conv1, bn1, ...), which fixes
    parameter initialisation under a given seed.
    """

    def __init__(self, in_channels=2, seq_len=1024, n_classes=7):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, 32, kernel_size=7, padding=3)
        self.bn1 = nn.BatchNorm1d(32)
        self.conv2 = nn.Conv1d(32, 64, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(64)
        self.conv3 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(128)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Linear(128, 128)
        self.fc2 = nn.Linear(128, n_classes)

    def forward(self, x):
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, norm in stages:
            x = F.relu(norm(conv(x)))
        embedding = self.pool(x).squeeze(-1)  # (B, 128)
        return self.fc2(F.relu(self.fc1(embedding)))

class SpecCNN(nn.Module):
    """Small 2-D CNN for single-channel spectrogram images.

    Three Conv-ReLU-MaxPool(2) stages reduce a 64x64 input to 64 x 8 x 8, which
    is the size ``fc1`` expects. The original applied only two pools, giving
    64 x 16 x 16 and a shape error in ``fc1``. A parameter-free adaptive
    average pool to 8x8 also lets other input sizes (H, W >= 8) work; for a
    64x64 input it is an identity. Input (B, in_ch, H, W) -> logits (B, n_classes).
    """

    def __init__(self, in_ch=1, n_classes=7):
        super().__init__()
        # small image CNN
        self.conv1 = nn.Conv2d(in_ch, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        # no parameters, so state_dict keys are unchanged
        self.adapt = nn.AdaptiveAvgPool2d((8, 8))
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, n_classes)

    def forward(self, x):
        # x: (B, in_ch, H, W)
        x = self.pool(F.relu(self.conv1(x)))  # H/2
        x = self.pool(F.relu(self.conv2(x)))  # H/4
        x = self.pool(F.relu(self.conv3(x)))  # H/8 (8x8 for 64x64 input)
        x = self.adapt(x)                     # force 8x8 for other input sizes
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

class LSTMClassifier(nn.Module):
    """Bidirectional multi-layer LSTM classifier for IQ sequences.

    Input (B, in_ch, L) -> logits (B, n_classes). The sequence summary
    concatenates two top-layer states:
      - the forward direction's final hidden state, after reading left-to-right;
      - the backward direction's final hidden state, after reading right-to-left.
    """

    def __init__(self, in_ch=2, hidden=128, n_layers=2, n_classes=7):
        super().__init__()
        self.rnn = nn.LSTM(input_size=in_ch, hidden_size=hidden, num_layers=n_layers, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden * 2, n_classes)

    def forward(self, x):
        # x: (B, C, L) -> (B, L, C) for the batch_first LSTM
        x = x.permute(0, 2, 1)
        _, (h_n, _) = self.rnn(x)
        # h_n: (num_layers * 2, B, hidden); the last two rows are the top
        # layer's forward and backward final states. Taking out[:, -1, :]
        # instead would give the backward direction only one time step.
        h = torch.cat([h_n[-2], h_n[-1]], dim=1)  # (B, hidden*2)
        return self.fc(h)

class TransformerClassifier(nn.Module):
    """Transformer-encoder classifier for IQ sequences.

    Input (B, in_ch, L) -> logits (B, n_classes). Processing steps:
      1. Each time step is linearly projected to ``d_model``.
      2. A fixed sinusoidal positional encoding is added. Without it,
         self-attention followed by mean pooling would ignore the order of
         the time steps.
      3. The sequence passes through ``num_layers`` encoder layers in
         (L, B, d_model) layout.
      4. The time-averaged representation is classified.
    The encoding has no parameters, so ``state_dict`` keys are unchanged.
    """

    def __init__(self, in_ch=2, d_model=64, n_heads=4, num_layers=3, n_classes=7):
        super().__init__()
        self.input_proj = nn.Linear(in_ch, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, dim_feedforward=128)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(d_model, n_classes)

    @staticmethod
    def _positional_encoding(seq_len, d_model, device, dtype):
        """Return the (seq_len, d_model) sinusoidal encoding of Vaswani et al. (2017)."""
        position = torch.arange(seq_len, device=device, dtype=dtype).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, device=device, dtype=dtype)
                             * (-math.log(10000.0) / d_model))
        pe = torch.zeros(seq_len, d_model, device=device, dtype=dtype)
        pe[:, 0::2] = torch.sin(position * div_term)
        # odd d_model has one fewer cosine column than sine columns
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pe

    def forward(self, x):
        # x: (B, C, L) -> (L, B, C)
        x = x.permute(2, 0, 1)
        x = self.input_proj(x)  # (L, B, d_model)
        pe = self._positional_encoding(x.size(0), x.size(2), x.device, x.dtype)
        x = x + pe.unsqueeze(1)  # broadcast over the batch axis
        x = self.transformer(x)
        x = x.mean(dim=0)  # (B, d_model)
        return self.fc(x)

# ---------------------------
# Training & evaluation helpers
# ---------------------------
def train_torch_model(model, train_loader, val_loader, n_epochs=20, lr=1e-3, weight_decay=1e-4, device=DEVICE, verbose=True):
    """
    Train ``model`` with Adam and cross-entropy, keeping the best-validation weights.

    Returns (model, history), where history holds the per-epoch lists
    'train_loss' (sample-weighted mean) and 'val_acc'.

    The returned model is loaded with a snapshot of the weights from the first
    epoch that reached the highest validation accuracy. If no epoch exceeds
    0.0 accuracy, the final weights are kept. An empty ``train_loader`` records
    a NaN training loss instead of raising ZeroDivisionError.
    """
    model = model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()
    best_val = 0.0
    best_state = None
    history = {'train_loss': [], 'val_acc': []}
    for epoch in range(1, n_epochs + 1):
        model.train()
        loss_accum = 0.0
        n_total = 0
        for xb, yb in train_loader:
            xb = xb.to(device)
            yb = yb.to(device)
            opt.zero_grad()
            logits = model(xb)
            loss = criterion(logits, yb)
            loss.backward()
            opt.step()
            loss_accum += loss.item() * xb.size(0)
            n_total += xb.size(0)
        # e.g. drop_last=True with fewer samples than one batch -> no batches
        train_loss = loss_accum / n_total if n_total > 0 else float('nan')
        # validation
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb = xb.to(device); yb = yb.to(device)
                logits = model(xb)
                preds = torch.argmax(logits, dim=1)
                correct += (preds == yb).sum().item()
                total += yb.size(0)
        val_acc = correct / total if total > 0 else 0.0
        history['train_loss'].append(train_loss)
        history['val_acc'].append(val_acc)
        if verbose:
            print(f"Epoch {epoch}/{n_epochs} - train_loss {train_loss:.4f} - val_acc {val_acc:.4f}")
        if val_acc > best_val:
            best_val = val_acc
            # state_dict() returns references to the live tensors; clone them so
            # later optimizer steps do not overwrite the "best" snapshot.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
    if best_state is not None:
        model.load_state_dict(best_state)
    return model, history

def evaluate_model_torch(model, loader, device=DEVICE):
    """Run ``model`` in eval mode over ``loader`` and return (y_true, y_pred) numpy arrays."""
    model = model.to(device)
    model.eval()
    truths, predictions = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            logits = model(inputs.to(device))
            predictions.extend(logits.argmax(dim=1).cpu().numpy().tolist())
            truths.extend(labels.numpy().tolist())
    return np.array(truths), np.array(predictions)

# ---------------------------
# Visualization helpers
# ---------------------------
def plot_confusion(y_true, y_pred, labels, title, fname):
    """Save a row-normalised confusion-matrix heatmap to ``fname``.

    Class indices 0..len(labels)-1 are used in order. ``labels`` supplies the
    tick names.
    """
    counts = confusion_matrix(y_true, y_pred, labels=range(len(labels)))
    row_totals = counts.sum(axis=1, keepdims=True)
    normalised = counts.astype('float') / (row_totals + 1e-12)
    plt.figure(figsize=(8, 6))
    sns.heatmap(normalised, annot=True, fmt='.2f', xticklabels=labels, yticklabels=labels, cmap='Blues')
    plt.ylabel('True')
    plt.xlabel('Pred')
    plt.title(title)
    plt.tight_layout()
    plt.savefig(fname)
    plt.close()

def plot_snr_curve(snr_list, acc_list_dict, fname):
    """Plot one accuracy-vs-SNR line (circle markers) per dict entry and save it to ``fname``."""
    plt.figure(figsize=(8, 5))
    for model_name, accuracies in acc_list_dict.items():
        plt.plot(snr_list, accuracies, marker='o', label=model_name)
    plt.xlabel('SNR (dB)')
    plt.ylabel('Accuracy')
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(fname)
    plt.close()

def plot_tsne(features, labels, label_names, fname, perplexity=30):
    """Embed ``features`` in 2-D with t-SNE (seeded by SEED) and save a per-label scatter to ``fname``.

    ``labels`` is compared element-wise, so it should be a numpy array of
    indices into ``label_names``.
    """
    embedding = TSNE(n_components=2, perplexity=perplexity, random_state=SEED).fit_transform(features)
    plt.figure(figsize=(8, 6))
    for cls in np.unique(labels):
        selected = labels == cls
        plt.scatter(embedding[selected, 0], embedding[selected, 1],
                    label=label_names[cls], alpha=0.6, s=20)
    plt.legend()
    plt.tight_layout()
    plt.savefig(fname)
    plt.close()

# ---------------------------
# End-to-end experiment orchestration
# ---------------------------
def run_full_experiment():
    """
    Run the complete Chapter 5 demo end to end and write figures to PLOT_DIR.

    Steps:
      - create a small synthetic dataset: 120 bursts per modulation of 1024
        samples, training SNR 10 dB +/- 6 dB, interference probability 0.25
      - do a 70/30 stratified train/test split on the modulation label
      - extract handcrafted features and spectrograms for both splits
      - train the traditional ML pipeline (SVC / RandomForest grid search, 3 folds)
      - train three DL models (RawIQ 1D-CNN, Spectrogram CNN, bidirectional
        LSTM) for 12 epochs each; TransformerClassifier is not used here
      - evaluate on the test split; sweep SNR over -5..20 dB in 5 dB steps on
        freshly generated data
      - save confusion matrices, SNR curves, t-SNE plots and sample spectrograms
    Returns None; progress is printed to stdout.
    """
    start = time.time()
    print("===== Creating dataset =====")
    n_per = 120  # per modulation - small so demo runs quickly
    num_samples = 1024
    snr_for_train = 10  # train at moderate SNR with jitter internally
    X, y_mod, y_proto = create_dataset(n_per_class=n_per, num_samples=num_samples, sps=8,
                                       snr_db=snr_for_train, snr_jitter=6, interference_prob=0.25)
    N = X.shape[0]
    print("Dataset shape:", X.shape, "Mod labels:", np.unique(y_mod).shape)

    # split train/test (protocol labels y_proto are not used below)
    X_train, X_test, y_train, y_test = train_test_split(X, y_mod, test_size=0.3, stratify=y_mod, random_state=SEED)
    print("Train/test sizes:", X_train.shape[0], X_test.shape[0])

    print("===== Extracting features (may take a bit) =====")
    feat_train, specs_train, cwts_train = extract_features_all(X_train, fs=1.0)
    feat_test, specs_test, cwts_test = extract_features_all(X_test, fs=1.0)
    print("Feature dims:", feat_train.shape, feat_test.shape)

    # Traditional ML pipeline
    print("===== Training traditional ML pipeline (RandomForest/SVC) =====")
    best_model, scaler, selector, pca, history = traditional_pipeline(feat_train, y_train, n_select=40, do_pca=True, pca_dim=10, cv_folds=3)
    # apply the transforms fitted on the training features to the test features
    Xs_test = scaler.transform(feat_test)
    Xs_test2 = selector.transform(Xs_test)
    if pca is not None:
        Xs_test2 = pca.transform(Xs_test2)
    y_pred_trad = best_model.predict(Xs_test2)
    acc_trad = accuracy_score(y_test, y_pred_trad)
    print("Traditional model test acc:", acc_trad)
    plot_confusion(y_test, y_pred_trad, MODS, "Trad ML Confusion", os.path.join(PLOT_DIR, "conf_trad.png"))

    # Deep Learning: prepare datasets
    print("===== Preparing PyTorch datasets =====")
    # Raw IQ datasets (the test split doubles as the validation set for model selection)
    ds_train_iq = IQDataset(X_train, y_train)
    ds_test_iq = IQDataset(X_test, y_test)
    batch = 64
    dl_train_iq = DataLoader(ds_train_iq, batch_size=batch, shuffle=True, drop_last=True)
    dl_val_iq = DataLoader(ds_test_iq, batch_size=batch, shuffle=False)

    # Spectrogram datasets (need to resize spectrograms to fixed small size)
    # We'll pad/crop spectrograms to shape (freq_bins=64, time_bins=64) for SpecCNN
    def spec_preproc(spec):
        # Fit spec into a 64x64 float32 canvas: a dimension smaller than 64 is
        # zero-padded symmetrically (centred); a larger one is cropped to its
        # FIRST 64 rows/columns. This is a top-left crop, not a centre crop.
        H,W = spec.shape
        Ht, Wt = 64,64
        out = np.zeros((Ht,Wt), dtype=np.float32)
        h0 = max(0,(Ht-H)//2)
        w0 = max(0,(Wt-W)//2)
        h1 = min(H, Ht); w1 = min(W, Wt)
        out[h0:h0+h1, w0:w0+w1] = spec[:h1, :w1]
        return out
    specs_train_proc = [spec_preproc(s) for s in specs_train]
    specs_test_proc = [spec_preproc(s) for s in specs_test]
    ds_train_spec = SpecDataset(specs_train_proc, y_train)
    ds_test_spec = SpecDataset(specs_test_proc, y_test)
    dl_train_spec = DataLoader(ds_train_spec, batch_size=64, shuffle=True, drop_last=True)
    dl_val_spec = DataLoader(ds_test_spec, batch_size=64, shuffle=False)

    # Define models
    print("===== Building models =====")
    raw_model = RawIQ_1DCNN(in_channels=2, seq_len=num_samples, n_classes=len(MODS))
    spec_model = SpecCNN(in_ch=1, n_classes=len(MODS))
    lstm_model = LSTMClassifier(in_ch=2, hidden=128, n_layers=2, n_classes=len(MODS))

    # Train lightweight for demo
    print("===== Training RawIQ CNN =====")
    raw_model, hist_raw = train_torch_model(raw_model, dl_train_iq, dl_val_iq, n_epochs=12, lr=1e-3, verbose=True)
    print("===== Training Spec CNN =====")
    spec_model, hist_spec = train_torch_model(spec_model, dl_train_spec, dl_val_spec, n_epochs=12, lr=1e-3, verbose=True)
    print("===== Training LSTM =====")
    lstm_model, hist_lstm = train_torch_model(lstm_model, dl_train_iq, dl_val_iq, n_epochs=12, lr=1e-3, verbose=True)

    print("===== Evaluating DL models =====")
    y_t_raw, y_p_raw = evaluate_model_torch(raw_model, dl_val_iq)
    print("RawIQ CNN acc:", accuracy_score(y_t_raw, y_p_raw))
    plot_confusion(y_t_raw, y_p_raw, MODS, "RawIQ CNN Confusion", os.path.join(PLOT_DIR, "conf_raw.png"))

    y_t_spec, y_p_spec = evaluate_model_torch(spec_model, dl_val_spec)
    print("Spec CNN acc:", accuracy_score(y_t_spec, y_p_spec))
    plot_confusion(y_t_spec, y_p_spec, MODS, "Spec CNN Confusion", os.path.join(PLOT_DIR, "conf_spec.png"))

    y_t_lstm, y_p_lstm = evaluate_model_torch(lstm_model, dl_val_iq)
    print("LSTM acc:", accuracy_score(y_t_lstm, y_p_lstm))
    plot_confusion(y_t_lstm, y_p_lstm, MODS, "LSTM Confusion", os.path.join(PLOT_DIR, "conf_lstm.png"))

    # -------------------------------
    # SNR curve: evaluate every trained model on a fresh dataset per SNR
    # (30 bursts per class, +/-1 dB jitter)
    # -------------------------------
    print("===== Computing SNR curves (this may take a while) =====")
    snr_vals = list(range(-5, 21, 5))  # -5,0,5,...20 dB
    accs = {'Traditional':[], 'RawCNN':[], 'SpecCNN':[], 'LSTM':[]}
    for s in snr_vals:
        print("Evaluating SNR:", s)
        X_snr, y_snr, _ = create_dataset(n_per_class=30, num_samples=num_samples, sps=8, snr_db=s, snr_jitter=1, interference_prob=0.25)
        # traditional model eval (use feature pipeline)
        feat_s, specs_s, _ = extract_features_all(X_snr, fs=1.0)
        Xs_s = scaler.transform(feat_s)
        Xs_s2 = selector.transform(Xs_s)
        if pca is not None:
            Xs_s2 = pca.transform(Xs_s2)
        y_pred_t = best_model.predict(Xs_s2)
        accs['Traditional'].append(accuracy_score(y_snr, y_pred_t))
        # raw cnn eval
        ds_s_iq = IQDataset(X_snr, y_snr)
        dl_s_iq = DataLoader(ds_s_iq, batch_size=64, shuffle=False)
        _, y_p_raw_s = evaluate_model_torch(raw_model, dl_s_iq)
        accs['RawCNN'].append(accuracy_score(y_snr, y_p_raw_s))
        # spec cnn eval
        specs_s_proc = [spec_preproc(sx) for sx in specs_s]
        ds_s_spec = SpecDataset(specs_s_proc, y_snr)
        dl_s_spec = DataLoader(ds_s_spec, batch_size=64, shuffle=False)
        _, y_p_spec_s = evaluate_model_torch(spec_model, dl_s_spec)
        accs['SpecCNN'].append(accuracy_score(y_snr, y_p_spec_s))
        # lstm eval
        _, y_p_lstm_s = evaluate_model_torch(lstm_model, dl_s_iq)
        accs['LSTM'].append(accuracy_score(y_snr, y_p_lstm_s))

    plot_snr_curve(snr_vals, accs, os.path.join(PLOT_DIR, "snr_curves.png"))
    print("SNR curve saved to", os.path.join(PLOT_DIR, "snr_curves.png"))

    # -------------------------------
    # Interpretability: t-SNE on traditional features & CNN embeddings
    # -------------------------------
    print("===== t-SNE visualization =====")
    # traditional features t-SNE (train + test, at most the first 1000 rows)
    X_feat_all = np.vstack([feat_train, feat_test])
    y_all = np.concatenate([y_train, y_test])
    plot_tsne(X_feat_all[:1000], y_all[:1000], MODS, os.path.join(PLOT_DIR, "tsne_features.png"))
    print("t-SNE saved (features)")

    # get embeddings from raw_model (before final FC)
    def get_raw_embeddings(model, X_complex):
        # Re-runs RawIQ_1DCNN's conv/BN/pool stages by hand in chunks of 64 and
        # returns the pooled (N, 128) embeddings, i.e. the input to fc1.
        model = model.to(DEVICE)
        model.eval()
        embs = []
        with torch.no_grad():
            for i in range(0, len(X_complex), 64):
                batch = X_complex[i:i+64]
                xb = []
                for x in batch:
                    xb.append(np.stack([np.real(x), np.imag(x)], axis=0))
                xb = np.array(xb).astype(np.float32)
                xb_t = torch.tensor(xb, device=DEVICE)
                # forward until pool layer
                xout = F.relu(model.bn1(model.conv1(xb_t)))
                xout = F.relu(model.bn2(model.conv2(xout)))
                xout = F.relu(model.bn3(model.conv3(xout)))
                xout = model.pool(xout).squeeze(-1).cpu().numpy()
                embs.append(xout)
        embs = np.vstack(embs)
        return embs
    emb_raw = get_raw_embeddings(raw_model, X_test[:1000])
    plot_tsne(emb_raw, y_test[:1000], MODS, os.path.join(PLOT_DIR, "tsne_rawcnn.png"))
    print("t-SNE saved (rawcnn)")

    # -------------------------------
    # Confusion matrix for final comparisons saved above
    # -------------------------------
    # Save the first six preprocessed training spectrograms in a 2x3 grid
    plt.figure(figsize=(10,6))
    for i in range(6):
        plt.subplot(2,3,i+1)
        plt.imshow(specs_train_proc[i], aspect='auto', origin='lower')
        plt.title(f"Spec: {MODS[y_train[i]]}")
        plt.axis('off')
    plt.tight_layout()
    plt.savefig(os.path.join(PLOT_DIR, "sample_specs.png"))
    plt.close()

    end = time.time()
    print("Experiment finished in {:.1f}s. Plots saved to {}".format(end-start, PLOT_DIR))

# ---------------------------
# CLI
# ---------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Signal Classification & Protocol Recognition Suite")
    parser.add_argument('--run', type=str, default='full', help='Which experiment to run: full (default)')
    args = parser.parse_args()
    # Only one experiment exists: every --run value, recognised or not, runs the
    # full pipeline.
    run_full_experiment()

第六章代码

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
rf_fingerprint.py
第 6 章 无线电指纹 (RF Fingerprinting) 与抗欺骗 实验代码示例。

包含:
 - 多发射器合成 IQ 信号(模拟硬件非理想性:瞬态、IQ 不平衡、噪声、相噪)
 - 指纹特征提取(稳态特征 + 瞬态特征 + 融合)
 - 深度学习分类网络(1D-CNN / LSTM)
 - 对抗样本攻击 (FGSM) + 对抗训练 (认证)
 - 基本性能评测、混淆矩阵

注意:仅为示范用途,可扩展到真实硬件数据。
"""

import os
import numpy as np
import random
import math
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from scipy.signal import butter, lfilter

# -----------------------
# 固定随机种子和设备
# -----------------------
SEED = 2025
# Fix every random generator (stdlib, numpy, torch) for repeatable experiments.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Use the GPU when available and seed its generator too.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if DEVICE.type == "cuda":
    torch.cuda.manual_seed(SEED)

# -----------------------
# 模拟多个发射器 (devices) 的 IQ 信号生成
# -----------------------
def generate_device_signal(device_id, length=2048, fs=1.0):
    """
    Synthesize one transmitter burst carrying device-specific hardware impairments.

    Per-device parameters, all linear in ``device_id``:
      amplitude error      1 + 0.01 * (device_id - 2)
      fixed phase error    0.01 * device_id rad
      I/Q gain imbalance   0.005 * (device_id - 3)  (I gain 1+e, Q gain 1-e)
      phase-noise step     std 0.001 * device_id (random-walk phase noise)

    The carrier is exp(j*(2*pi*0.1*t + phase error)) with t = n / fs. The first
    length // 10 samples are scaled by a ramp from 1.5 down to 1.0 (turn-on
    transient). Complex AWGN of power 0.001 is then added, and the burst is
    normalised to unit mean power.

    Args:
        device_id: integer selecting the impairment parameters above.
        length: number of samples.
        fs: sample rate (normalised).
    Returns:
        complex64 array of length ``length``.
    """
    t = np.arange(length) / fs
    fc = 0.1
    amp_error = 1.0 + 0.01 * (device_id - 2)
    phase_error = 0.01 * device_id
    iq_imbalance = 0.005 * (device_id - 3)
    pn_sigma = 0.001 * device_id

    carrier = amp_error * np.exp(1j * (2 * np.pi * fc * t + phase_error))
    # IQ imbalance: different gains on the I and Q rails
    imbalanced = np.real(carrier) * (1 + iq_imbalance) + 1j * (np.imag(carrier) * (1 - iq_imbalance))

    # random-walk phase noise (drawn before the additive noise)
    phase_walk = np.cumsum(pn_sigma * np.random.randn(length))
    burst = imbalanced * np.exp(1j * phase_walk)

    # turn-on transient: extra gain decaying over the first 10 % of samples
    n_transient = length // 10
    burst[:n_transient] *= np.linspace(1.5, 1.0, n_transient)

    # complex AWGN with total power 0.001 (real part drawn first)
    noise_scale = np.sqrt(0.001 / 2)
    burst += noise_scale * (np.random.randn(length) + 1j * np.random.randn(length))

    burst = burst / np.sqrt(np.mean(np.abs(burst) ** 2) + 1e-12)
    return burst.astype(np.complex64)

def generate_dataset(n_devices=5, samples_per_device=200, length=2048):
    """
    Generate ``samples_per_device`` bursts for each of ``n_devices`` transmitters.

    Bursts are grouped by device: device 0 first, then device 1, and so on.
    Returns:
      X: complex64 array of shape (n_devices * samples_per_device, length)
      y: int64 labels equal to the device id of each row
    """
    bursts = [generate_device_signal(dev, length=length)
              for dev in range(n_devices)
              for _ in range(samples_per_device)]
    labels = [dev for dev in range(n_devices) for _ in range(samples_per_device)]
    return np.stack(bursts, axis=0), np.array(labels, dtype=np.int64)

# -----------------------
# 特征提取:稳态 + 瞬态 + 融合
# -----------------------
def extract_steady_features(x, fs=1.0):
    """
    Steady-state statistics computed on the second half of a burst (x[len(x)//2:]).

    Returns a float32 vector of 8 values: the mean and std of I, Q, |x| and
    the wrapped phase, in that order. ``fs`` is accepted but unused.
    """
    tail = x[len(x) // 2:]
    components = (np.real(tail), np.imag(tail), np.abs(tail), np.angle(tail))
    return np.array([stat(comp) for comp in components for stat in (np.mean, np.std)],
                    dtype=np.float32)

def extract_transient_features(x, fs=1.0):
    """
    Turn-on transient features computed on the first 10 % of a burst (x[:len(x)//10]).

    Returns a float32 vector of 6 values:
      [slope of |x|, slope of the unwrapped phase (both least-squares per
       sample), mean |x|, std |x|, mean I, mean Q].
    ``fs`` is accepted but unused.
    """
    head = x[:len(x) // 10]
    amp = np.abs(head)
    phase = np.unwrap(np.angle(head))
    sample_idx = np.arange(len(amp))
    amp_slope = np.polyfit(sample_idx, amp, 1)[0]
    phase_slope = np.polyfit(sample_idx, phase, 1)[0]
    return np.array([amp_slope, phase_slope,
                     np.mean(amp), np.std(amp),
                     np.mean(np.real(head)), np.mean(np.imag(head))],
                    dtype=np.float32)

def extract_features(X):
    """
    Build the fused (steady-state + transient) feature matrix for a batch of signals.

    Args:
      X: iterable of 1-D complex signals (e.g. the rows of a 2-D complex array).

    Returns:
      ndarray of shape (N, D): for each signal, its steady-state features
      followed by its transient features, stacked row-wise.
    """
    sample_rate = 1.0  # passed through to both extractors
    rows = [
        np.concatenate([
            extract_steady_features(sig, sample_rate),
            extract_transient_features(sig, sample_rate),
        ])
        for sig in X
    ]
    return np.vstack(rows)

# -----------------------
# PyTorch 数据集与模型
# -----------------------
class RFDataset(Dataset):
    """
    Map-style dataset serving complex signals as 2-channel real arrays.

    Each item is ``(channels, label)``. ``channels`` is a float32 array of
    shape (2, L): row 0 is the real (I) part and row 1 the imaginary (Q) part.
    """

    def __init__(self, X, y):
        # X: indexable collection of complex signals; y: labels aligned with X.
        self.X = X
        self.y = y

    def __len__(self):
        # Dataset length is taken from the label array.
        return len(self.y)

    def __getitem__(self, idx):
        sample = self.X[idx]
        channels = np.array([np.real(sample), np.imag(sample)], dtype=np.float32)
        return channels, self.y[idx]

class CNN1DFingerprint(nn.Module):
    """
    1-D CNN classifier for 2-channel (I/Q) signals.

    Three Conv1d -> BatchNorm1d -> ReLU stages (32, 64, 128 channels; kernel
    sizes 31, 15, 7 with length-preserving padding), global average pooling
    over time, then a 128 -> 64 -> n_classes MLP. ``forward`` returns logits.

    Args:
      in_ch: number of input channels.
      seq_len: accepted for API compatibility; unused, since adaptive pooling
        makes the network independent of the input length.
      n_classes: number of output logits.
    """

    def __init__(self, in_ch=2, seq_len=2048, n_classes=5):
        super().__init__()
        # Attribute names and construction order are kept fixed so state_dict
        # keys and seeded parameter initialisation stay unchanged.
        self.conv1 = nn.Conv1d(in_ch, 32, kernel_size=31, padding=31 // 2)
        self.bn1 = nn.BatchNorm1d(32)
        self.conv2 = nn.Conv1d(32, 64, kernel_size=15, padding=15 // 2)
        self.bn2 = nn.BatchNorm1d(64)
        self.conv3 = nn.Conv1d(64, 128, kernel_size=7, padding=7 // 2)
        self.bn3 = nn.BatchNorm1d(128)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Linear(128, 64)
        self.fc2 = nn.Linear(64, n_classes)

    def forward(self, x):
        # x: (batch, in_ch, length)
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, bn in stages:
            x = torch.relu(bn(conv(x)))
        pooled = torch.flatten(self.pool(x), start_dim=1)  # (batch, 128)
        hidden = torch.relu(self.fc1(pooled))
        return self.fc2(hidden)

# -----------------------
# 对抗攻击 (FGSM) + 认证训练
# -----------------------
def fgsm_attack(x, grads, eps=0.01):
    """
    Apply a Fast Gradient Sign Method (FGSM) perturbation.

    Args:
      x: input tensor (batches here are (B, 2, L)).
      grads: gradient of the loss w.r.t. x, same shape as x.
      eps: step size applied to the sign of the gradient.

    Returns:
      x + eps * sign(grads), detached from the autograd graph.
    """
    step = eps * torch.sign(grads)
    return torch.add(x, step).detach()

# -----------------------
# 训练 & 对抗训练函数
# -----------------------
def train_classifier(model, loader, optimizer, criterion, device):
    """
    Run one standard training epoch.

    Args:
      model: network to train (switched to train mode).
      loader: iterable of (inputs, targets) batches with a ``.dataset`` attribute.
      optimizer: optimizer over the model's parameters.
      criterion: loss function taking (logits, targets).
      device: device the batches are moved to.

    Returns:
      Sample-weighted mean training loss over the epoch.
    """
    model.train()
    loss_sum = 0.0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()
        # Weight by batch size so a short final batch is not over-counted.
        loss_sum += batch_loss.item() * inputs.size(0)
    return loss_sum / len(loader.dataset)

def train_with_adversary(model, loader, optimizer, criterion, device, eps=0.01):
    """
    Run one epoch of FGSM adversarial training.

    For every minibatch:
      1. compute the clean loss;
      2. take its gradient w.r.t. the input and build an FGSM sample with
         fgsm_attack;
      3. update the model on 0.5 * clean_loss + 0.5 * adversarial_loss.

    Args:
      model: network to train (switched to train mode).
      loader: iterable of (inputs, targets) batches with a ``.dataset`` attribute.
      optimizer: optimizer over the model's parameters.
      criterion: loss function taking (logits, targets).
      device: device the batches are moved to.
      eps: FGSM step size.

    Returns:
      Sample-weighted mean of the combined loss over the epoch.
    """
    model.train()
    running_loss = 0.0
    for xb, yb in loader:
        # The input must require grad for the loss to be differentiable w.r.t. it;
        # detach() gives a fresh leaf so requires_grad_ is always permitted.
        xb = xb.to(device).detach().requires_grad_(True)
        yb = yb.to(device)

        # Clean forward pass.
        logits = model(xb)
        loss = criterion(logits, yb)

        # Gradient w.r.t. the input only (parameter .grad buffers are untouched).
        # retain_graph keeps the clean graph alive for loss_total.backward() below.
        (input_grad,) = torch.autograd.grad(loss, xb, retain_graph=True)
        xb_adv = fgsm_attack(xb, input_grad, eps)

        # Adversarial forward pass.
        logits_adv = model(xb_adv)
        loss_adv = criterion(logits_adv, yb)

        # Single parameter update on the averaged objective.
        loss_total = 0.5 * loss + 0.5 * loss_adv
        optimizer.zero_grad()
        loss_total.backward()
        optimizer.step()

        running_loss += loss_total.item() * xb.size(0)
    return running_loss / len(loader.dataset)

def evaluate(model, loader, device):
    """
    Evaluate a classifier on every batch of a loader.

    The model is switched to eval mode and run without gradient tracking.

    Args:
      model: classifier returning logits of shape (B, n_classes).
      loader: iterable of (inputs, targets) batches.
      device: device the batches are moved to.

    Returns:
      (accuracy, y_true, y_pred): accuracy as a float, plus NumPy arrays of
      true and predicted labels in loader order.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    true_labels, predicted_labels = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predicted = model(inputs).argmax(dim=1)
            n_correct += (predicted == targets).sum().item()
            n_seen += inputs.size(0)
            predicted_labels.extend(predicted.cpu().tolist())
            true_labels.extend(targets.cpu().tolist())
    return n_correct / n_seen, np.array(true_labels), np.array(predicted_labels)

# -----------------------
# 主流程
# -----------------------
def _fgsm_accuracy(model, loader, criterion, device, eps):
    """Return the model's accuracy on FGSM-perturbed versions of every batch in `loader`."""
    model.eval()
    correct = 0
    total = 0
    for xb, yb in loader:
        xb = xb.to(device).detach().requires_grad_(True)
        yb = yb.to(device)
        loss = criterion(model(xb), yb)
        # Gradient w.r.t. the input only; the model's parameter .grad buffers
        # are not polluted by evaluation.
        (input_grad,) = torch.autograd.grad(loss, xb)
        xb_adv = fgsm_attack(xb, input_grad, eps=eps)
        with torch.no_grad():
            preds_adv = torch.argmax(model(xb_adv), dim=1)
        correct += (preds_adv == yb).sum().item()
        total += yb.size(0)
    return correct / total


def main():
    """
    End-to-end demo:
      1. synthetic dataset;
      2. handcrafted features + random forest;
      3. CNN (standard and FGSM adversarial training);
      4. FGSM robustness test;
      5. confusion-matrix plot.
    """
    n_epochs = 10
    adv_eps = 0.02

    print("== 生成指纹数据集 ==")
    X, y = generate_dataset(n_devices=5, samples_per_device=300, length=2048)
    print("数据集大小:", X.shape, y.shape)

    # Random 70/30 train/test split (rows of X are grouped by device).
    idx = np.arange(len(y))
    np.random.shuffle(idx)
    split = int(0.7 * len(y))
    train_idx = idx[:split]
    test_idx = idx[split:]
    X_train, y_train = X[train_idx], y[train_idx]
    X_test, y_test = X[test_idx], y[test_idx]

    print("提取特征 (稳态 + 瞬态) 用于传统分类")
    feats_train = extract_features(X_train)
    feats_test = extract_features(X_test)

    print("传统分类 (随机森林) ")
    from sklearn.ensemble import RandomForestClassifier
    clf = RandomForestClassifier(n_estimators=100, random_state=SEED)
    clf.fit(feats_train, y_train)
    acc_trad = clf.score(feats_test, y_test)
    print("传统随机森林分类器准确率:", acc_trad)

    # Deep learning: PyTorch datasets / loaders over the raw I/Q signals.
    ds_train = RFDataset(X_train, y_train)
    ds_test = RFDataset(X_test, y_test)
    loader_train = DataLoader(ds_train, batch_size=32, shuffle=True)
    loader_test = DataLoader(ds_test, batch_size=32, shuffle=False)
    n_classes = len(np.unique(y))

    model = CNN1DFingerprint(in_ch=2, seq_len=2048, n_classes=n_classes)
    model = model.to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    print("训练指纹分类器 (普通训练) …")
    for epoch in range(n_epochs):
        loss = train_classifier(model, loader_train, optimizer, criterion, DEVICE)
        acc, _, _ = evaluate(model, loader_test, DEVICE)
        print(f"Epoch {epoch+1}/{n_epochs}, Loss {loss:.4f}, Test Acc {acc:.4f}")

    print("评估训练后模型 …")
    acc, y_true, y_pred = evaluate(model, loader_test, DEVICE)
    print("测试准确率 (干净样本):", acc)

    # Adversarial (FGSM) training of a fresh model.
    print("对抗训练 (FGSM) …")
    model_adv = CNN1DFingerprint(in_ch=2, seq_len=2048, n_classes=n_classes)
    model_adv = model_adv.to(DEVICE)
    optimizer_adv = optim.Adam(model_adv.parameters(), lr=1e-3)
    for epoch in range(n_epochs):
        loss_adv = train_with_adversary(model_adv, loader_train, optimizer_adv, criterion, DEVICE, eps=adv_eps)
        acc_adv, _, _ = evaluate(model_adv, loader_test, DEVICE)
        # evaluate() runs on the clean test set, so label the metric accordingly.
        print(f"Adv Epoch {epoch+1}/{n_epochs}, Loss {loss_adv:.4f}, Test Acc (clean) {acc_adv:.4f}")

    print("评估对抗训练模型 …")
    acc_clean, _, _ = evaluate(model_adv, loader_test, DEVICE)
    print("对抗训练模型测试准确率 (干净):", acc_clean)

    # Robustness: FGSM-perturb each test batch and measure accuracy.
    print("对抗测试:生成对抗样本并评估 …")
    acc_adv_test = _fgsm_accuracy(model_adv, loader_test, criterion, DEVICE, adv_eps)
    print("对抗样本测试准确率:", acc_adv_test)

    # Confusion matrix of the normally trained model on clean test samples.
    from sklearn.metrics import confusion_matrix
    import seaborn as sns
    cm = confusion_matrix(y_true, y_pred)
    fig = plt.figure(figsize=(6,5))
    sns.heatmap(cm, annot=True, fmt='d')
    plt.xlabel("预测标签"); plt.ylabel("真实标签"); plt.title("混淆矩阵 (干净样本)")
    plt.savefig("confusion_clean.png")
    plt.close(fig)  # release the figure; it is only written to disk
    print("混淆矩阵 (干净样本) 已保存到 confusion_clean.png")

    print("全部流程结束。")

if __name__ == '__main__':
    main()

# --- End of listing. ---
# Trailing blog-page UI text captured when this code was copied from the web
# page (comment box, red-envelope / payment widgets, balance notes) has been
# removed: it was not part of the program and made the file unparsable.