Paper reproduction: Active Learning with the Furthest Nearest Neighbor Criterion for Facial Age Estimation

The FNN_2DLDA class implements a discrepancy-based method for diversity sampling. The algorithm computes the per-class sample means, the global sample mean, and the associated scatter (covariance) matrices, and selects images after reducing the dimensionality of the features. Because the feature space keeps changing, the method maintains the model through incremental updates. Within the given budget, unlabeled samples are selected for labeling with the furthest nearest neighbor strategy.


The Furthest Nearest Neighbor method is what other papers call the discrepancy method; it is a diversity sampling strategy.
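As a quick illustration, here is a minimal sketch of the criterion on generic feature vectors. The names fnn_select, features, labeled_idx, and unlabeled_idx are illustrative only and are not part of the class below; the logic mirrors the image_select method of the implementation.

import torch

def fnn_select(features, labeled_idx, unlabeled_idx):
    # For each labeled sample, find the distance to its nearest unlabeled
    # neighbor, then return the unlabeled sample attaining the largest of
    # these nearest-neighbor distances (the "furthest nearest neighbor").
    best_dist, best_idx = -1.0, None
    for i in labeled_idx:
        nn_dist, nn_idx = min(
            (torch.norm(features[i] - features[j]).item(), j) for j in unlabeled_idx
        )
        if nn_dist > best_dist:
            best_dist, best_idx = nn_dist, nn_idx
    return best_idx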

 

Because the feature space keeps changing, applying the discrepancy method in that feature space goes against the original intent of the criterion.
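For reference, the 2DLDA quantities computed by the class below can be written as follows; the notation mirrors the variable names in the code (this is a summary of the implementation, not a quotation from the paper):

S_{bl} = \sum_{i=1}^{C} n_i (M_i - M)^{\top}(M_i - M), \qquad
S_{br} = \sum_{i=1}^{C} n_i (M_i - M)(M_i - M)^{\top}

S_{wl} = \sum_{i=1}^{C} \sum_{X_j \in \text{class } i} (X_j - M_i)^{\top}(X_j - M_i), \qquad
S_{wr} = \sum_{i=1}^{C} \sum_{X_j \in \text{class } i} (X_j - M_i)(X_j - M_i)^{\top}

where M_i is the mean image of class i, M the global mean, and n_i the number of labeled samples in class i. The columns of W_l are the K leading eigenvectors of S_{wl}^{-1} S_{bl} (the code uses the pseudo-inverse), the rows of W_r are the K leading eigenvectors of S_{wr}^{-1} S_{br}, and each image X_j is projected to the K x K feature Y_j = W_r X_j W_l.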

import os
import torch
import numpy as np
from copy import deepcopy
from collections import OrderedDict
from PIL import Image
from sklearn.model_selection import StratifiedKFold


class FNN_2DLDA(object):
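    # Active learning with the Furthest Nearest Neighbor criterion on top of
    # 2DLDA features; the projection is re-estimated after every labeled batch.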
    def __init__(self, X_train, y_train, labeled, budget, X_test, y_test):
        self.X = X_train
        self.y = y_train
        self.X_test = X_test
        self.y_test = y_test
        self.nSample = X_train.shape[0]
        print("样本个数=",self.nSample)
        self.labeled = list(deepcopy(labeled))     # 已标记样本的索引
        self.unlabeled = self.init_unlabeled_index()
        self.labels = np.sort(np.unique(y_train))  # 标签列表
        self.nClass = len(self.labels)
        self.budget = deepcopy(budget)
        self.nRow, self.nCol = self.X[0].shape     # 图像样本的行数和列数
        self.K = 10
        self.class_mean, self.global_mean, self.class_count = self.get_init_mean()
        self.S_bl, self.S_br, self.S_wl, self.S_wr = self.get_init_Sbl_Sbr_Swl_Swr()
        self.Wl, self.Wr = self.get_Wl_Wr()
        self.X_feature = self.get_feature()
        self.batch_size = 5

    def init_unlabeled_index(self):
        # ============= indices of the unlabeled samples ===============
        unlabeled = [i for i in range(self.nSample)]
        for idx in self.labeled:
            unlabeled.remove(idx)
        return unlabeled

    def get_init_mean(self):
        class_mean = torch.zeros((self.nClass, self.nRow, self.nCol))
        class_count = torch.zeros(self.nClass)
        global_mean = torch.zeros((self.nRow, self.nCol))
        # ======== compute the mean of each class ========
        for i in range(self.nClass):
            # == indices of the labeled samples belonging to class i ==
            ids = []
            for idx in self.labeled:
                if self.y[idx] == self.labels[i]:
                    ids.append(idx)
            class_count[i] = len(ids)
            class_mean[i] = torch.mean(self.X[ids], dim=0)
        # ========== compute the global sample mean ============
        for i in range(self.nClass):
            global_mean += (class_count[i] / len(self.labeled)) * class_mean[i]
        return class_mean, global_mean, class_count

    def get_init_Sbl_Sbr_Swl_Swr(self):
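        # Between-class scatter matrices (S_bl, S_br) and within-class scatter
        # matrices (S_wl, S_wr) of 2DLDA, accumulated over the labeled samples.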
        # ============= compute S_bl and S_br =================
        S_bl = torch.zeros((self.nCol, self.nCol))
        S_br = torch.zeros((self.nRow, self.nRow))
        for i in range(self.nClass):
            tmp = self.class_mean[i] - self.global_mean
            S_bl += self.class_count[i] * torch.mm(tmp.T,tmp)
            S_br += self.class_count[i] * torch.mm(tmp, tmp.T)

        # ============= compute S_wl and S_wr =================
        S_wl = torch.zeros((self.nCol, self.nCol))
        S_wr = torch.zeros((self.nRow, self.nRow))
        for i in range(self.nClass):
            for idx in self.labeled:
                if self.y[idx] == self.labels[i]:
                    tmp = self.X[idx] - self.class_mean[i]
                    S_wl += torch.mm(tmp.T, tmp)
                    S_wr += torch.mm(tmp, tmp.T)
        return S_bl, S_br, S_wl, S_wr

    def get_Wl_Wr(self):
        # 2DLDA projections: eigen-decompose pinv(S_wl) @ S_bl and pinv(S_wr) @ S_br
        Wl_eigen_val, Wl_eigen_vec = torch.linalg.eig(torch.mm(torch.linalg.pinv(self.S_wl), self.S_bl))
        Wr_eigen_val, Wr_eigen_vec = torch.linalg.eig(torch.mm(torch.linalg.pinv(self.S_wr), self.S_br))
        # sort the eigenvalues in descending order of their real part
        odx_Wl = torch.argsort(Wl_eigen_val.real, descending=True)
        odx_Wr = torch.argsort(Wr_eigen_val.real, descending=True)
        Wl = torch.zeros((self.nCol, self.K))
        Wr = torch.zeros((self.K, self.nRow))
        for i in range(self.K):
            # eigenvectors are returned as columns; keep the real part of the K leading ones
            Wr[i] = Wr_eigen_vec[:, odx_Wr[i]].real
            Wl[:, i] = Wl_eigen_vec[:, odx_Wl[i]].real
        return Wl, Wr

    def get_feature(self):
        # project every image into the K x K feature space: Y = Wr @ X @ Wl
        X_feature = torch.zeros((self.nSample, self.K, self.K))
        for idx in range(self.nSample):
            X_feature[idx] = torch.mm(torch.mm(self.Wr, self.X[idx]), self.Wl)
        return X_feature

    def incremental_update_X_feature(self, selected):
        # ======== update self.class_mean and self.class_count with the newly labeled samples ============
        for i in range(self.nClass):
            tmp_count = 0
            tmp_sum = torch.zeros((self.nRow, self.nCol))
            for idx in selected:
                if self.y[idx] == self.labels[i]:
                    tmp_count += 1
                    tmp_sum += self.X[idx]
            if tmp_count > 0:
                self.class_mean[i] = (self.class_count[i] * self.class_mean[i] + tmp_sum) / (self.class_count[i] + tmp_count)
                self.class_count[i] = self.class_count[i] + tmp_count
        # ========= update self.global_mean ===========
        self.global_mean = torch.zeros((self.nRow, self.nCol))
        for i in range(self.nClass):
            self.global_mean += (self.class_count[i] / len(self.labeled)) * self.class_mean[i]
        # ========= recompute the scatter matrices from the enlarged labeled set ===========
        self.S_bl, self.S_br, self.S_wl, self.S_wr = self.get_init_Sbl_Sbr_Swl_Swr()
        # ========= recompute the projections and the features ===========
        self.Wl, self.Wr = self.get_Wl_Wr()
        self.X_feature = self.get_feature()

    def image_select(self, batch_size):
        # for every labeled sample, record its nearest unlabeled neighbor in feature space
        metric_dict = OrderedDict()
        for idx in self.labeled:
            min_dist = np.inf
            min_index = None
            for jdx in self.unlabeled:
                dist_tmp = torch.norm(self.X_feature[idx] - self.X_feature[jdx])
                if dist_tmp < min_dist:
                    min_dist = dist_tmp
                    min_index = jdx
            metric_dict[(idx, min_index)] = min_dist
        selected = []
        for i in range(batch_size):
            # Furthest Nearest Neighbor: pick the pair whose nearest-neighbor distance is the largest
            tar_tuple = max(metric_dict, key=metric_dict.get)
            chosen = tar_tuple[1]
            selected.append(chosen)
            self.labeled.append(chosen)
            self.unlabeled.remove(chosen)
            # drop every entry that pointed to the newly labeled sample and recompute it,
            # together with a new entry for the newly labeled sample itself
            stale = [key for key in metric_dict if key[1] == chosen]
            for key in stale:
                del metric_dict[key]
            for idx in [key[0] for key in stale] + [chosen]:
                min_dist = np.inf
                min_index = None
                for jdx in self.unlabeled:
                    dist_tmp = torch.norm(self.X_feature[idx] - self.X_feature[jdx])
                    if dist_tmp < min_dist:
                        min_dist = dist_tmp
                        min_index = jdx
                metric_dict[(idx, min_index)] = min_dist
        return selected

    def start(self):
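        # Active learning loop: select batch_size samples per round with the FNN
        # criterion until the labeling budget is exhausted, updating the 2DLDA
        # features after each round.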
        while self.budget > 0:
            if self.budget > self.batch_size:
                selected = self.image_select(batch_size=self.batch_size)
                self.budget -= self.batch_size
            else:
                selected = self.image_select(batch_size=self.budget)
                self.budget = 0
            print("selected::",selected)
            # ========== if the labeling budget is not exhausted yet, update the model ============
            if self.budget > 0:
                self.incremental_update_X_feature(selected=selected)







if __name__ == '__main__':
    path_dir = r"E:\PycharmProjects\DataSets\FaceData\yalefaces"
    # =============== basic information =================
    nSample = 165
    nClass = 15        # the Yale face database contains 15 subjects with 11 images each
    labels = [i for i in np.arange(1, nClass + 1)]
    nRow = 243
    nCol = 320
    Budget = 30
    # ============== construct the labels: 11 consecutive images share one subject label ==================
    y = np.repeat(np.arange(1, nClass + 1), 11).astype(float)
    # ============ read the image data =================
    X = torch.zeros((nSample, nRow, nCol))
    index = 0
    for name in os.listdir(path_dir):
        if name.split(".")[0][:7] == "subject":
            img = np.array(Image.open(path_dir + "\\" + name))
            X[index] = torch.from_numpy(img)
            index += 1

    SKF = StratifiedKFold(n_splits=5, shuffle=True)
    for train_idx, test_idx in SKF.split(X=X,y=y):
        X_train = X[train_idx]
        y_train = y[train_idx]
        X_test = X[test_idx]
        y_test = y[test_idx]
        labeled = []
        label_dict = OrderedDict()
        for lab in np.unique(y_train):
            label_dict[lab] = []
        for idx in range(len(y_train)):
            label_dict[y_train[idx]].append(idx)
        for idxlist in label_dict.values():
            for jdx in np.random.choice(idxlist,size=2, replace=False):
                labeled.append(jdx)

        model = FNN_2DLDA(X_train=X_train,y_train=y_train,labeled=labeled,budget=Budget,X_test=X_test,y_test=y_test)
        model.start()
        break
