Wide & Deep Model: From Google to Huawei

This post takes a close look at Google's Wide & Deep Learning model and its use in recommender systems: the wide part memorizes historical interactions, while the deep part improves generalization. Huawei's DeepFM goes a step further and simplifies feature engineering by sharing parameters between the two parts. Experiments show that DeepFM outperforms competing models on CTR prediction.


In an earlier post, 基于深度学习的推荐系统(二)MLP based, I briefly mentioned the wide & deep model. Here I want to discuss it on its own, because this model is the foundation of many industrial recommender systems. Since Google published the paper in 2016, more and more companies have adopted the wide & deep model and its variants. Among these variants, Huawei's DeepFM is a fairly representative one; I touched on it before and will also cover it in this post.

Wide & Deep Learning for Recommender Systems

Wide & Deep Learning for Recommender Systems is a paper Google published in 2016. It is only 4 pages long, extremely short, yet its content has had an enormous impact. The gist of the paper is as follows:

Generalized linear models with nonlinear feature transformations are widely used for large-scale regression and classification problems with sparse inputs. Through a set of feature transformations we can achieve memorization of historical interactions, and the generalized linear models built on these features are effective and interpretable. However, improving the generalization of such models requires a great deal of feature engineering. Deep neural networks can generalize better to unseen feature combinations through low-dimensional dense embeddings learned for the sparse features, and therefore require less feature engineering. But when interactions are sparse, they tend to overfit and learn correlations that do not actually exist. The former component is called wide, the latter deep; combining the two gives the wide & deep model.
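
To make the architecture concrete, here is a minimal sketch of the idea: the wide part is a linear model over raw and crossed features, the deep part is an MLP over learned embeddings, and their outputs are summed into a single logit that is trained jointly. The sketch is in PyTorch purely for illustration; the layer sizes, feature counts, and the use of two id fields are my own assumptions, not the paper's production setup.

```python
import torch
import torch.nn as nn

class WideDeep(nn.Module):
    """Minimal illustrative Wide & Deep: one logit = wide linear part + deep MLP part."""

    def __init__(self, n_wide_features, n_ids, n_id_fields=2, emb_dim=8, hidden=(64, 32)):
        super().__init__()
        # Wide part: a plain linear model over hand-crafted (one-hot / cross-product) features.
        self.wide = nn.Linear(n_wide_features, 1)
        # Deep part: an embedding table for sparse id features, followed by an MLP.
        self.embedding = nn.Embedding(n_ids, emb_dim)
        layers, in_dim = [], n_id_fields * emb_dim
        for h in hidden:
            layers += [nn.Linear(in_dim, h), nn.ReLU()]
            in_dim = h
        layers.append(nn.Linear(in_dim, 1))
        self.deep = nn.Sequential(*layers)

    def forward(self, wide_x, id_x):
        deep_in = self.embedding(id_x).flatten(start_dim=1)   # (batch, n_id_fields * emb_dim)
        logit = self.wide(wide_x) + self.deep(deep_in)        # joint training: the two logits are summed
        return torch.sigmoid(logit)                           # P(click) for CTR-style training

# Toy usage: batch of 4, 10 wide features, two id fields (e.g. user id, item id).
model = WideDeep(n_wide_features=10, n_ids=100)
probs = model(torch.rand(4, 10), torch.randint(0, 100, (4, 2)))
print(probs.shape)  # torch.Size([4, 1])
```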

Personally, I feel the paper's definitions of memorization and generalization differ somewhat from how these terms are usually used, so for convenience I reproduce the original definitions here:

One challenge in recommender systems, similar to the general search ranking problem, is to achieve both memorization and generalization. Memorization can be loosely defined as learning the frequent co-occurrence of items or features and exploiting the correlation available in the historical data. Generalization, on the other hand, is based on transitivity of correlation and explores new feature combinations that have never or rarely occurred in the past. Recommendations based on memorization are usually more topical and directly relevant to the items on which users have already performed actions. Compared with memorization, generalization tends to improve the diversity of the recommended items.

The wide model refers to models such as logistic regression that use hand-crafted features, typically one-hot encoded. These models are simple and interpretable, but they cannot model feature combinations that never appear in the training set. The deep model refers to embedding-based models, including FM and deep neural networks. They can learn dense feature embeddings from sparse data, but when the input matrix is sparse yet high-rank, they may learn many correlations that do not actually exist.
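
A toy contrast of the two behaviors (my own example, not from the paper): the wide side simply has no weight for a (user, item) cross feature it never saw, while the deep side always produces a dot-product score from the learned embeddings, even for an unseen pair, which is exactly where spurious correlations can creep in when the data is sparse.

```python
import numpy as np

# Wide side: weights exist only for cross features observed in training.
cross_weights = {("user_1", "item_A"): 0.9, ("user_2", "item_B"): 0.7}
print(cross_weights.get(("user_1", "item_B"), 0.0))  # unseen pair -> no memorized signal, 0.0

# Deep side: learned embeddings (made-up numbers) still score the unseen pair.
user_emb = {"user_1": np.array([0.3, 0.8]), "user_2": np.array([0.5, -0.2])}
item_emb = {"item_A": np.array([0.4, 0.6]), "item_B": np.array([0.1, 0.9])}
print(float(user_emb["user_1"] @ item_emb["item_B"]))  # 0.75: may or may not reflect a real correlation
```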

The wide part is given by $y = W_{wide}^T\{x, \phi(x)\} + b$, where $W_{wide}$ is the weight vector, $x$ is the vector of input features produced by feature engineering, and $\phi(x)$ denotes the transformed versions of these features. The most common transformation is the cross-product transformation, defined as $\phi_k(x) = \prod_{i=1}^{d} x_i^{c_{ki}},\; c_{ki} \in \{0, 1\}$.
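
For binary features this is simply a logical AND over the features selected by $c_{ki}$: $\phi_k(x)$ is 1 only when every selected $x_i$ is 1, so the cross feature fires only for that exact combination (in the spirit of the paper's AND(gender=female, language=en) example). A small sketch, with feature names chosen purely for illustration:

```python
def cross_product(x, c_k):
    """phi_k(x) = prod_i x_i^{c_ki} for binary features: 1 iff every selected feature is 1."""
    return int(all(xi == 1 for xi, cki in zip(x, c_k) if cki == 1))

# features: [gender=female, language=en, installed_app=netflix]
x = [1, 1, 0]
print(cross_product(x, [1, 1, 0]))  # AND(gender=female, language=en) -> 1
print(cross_product(x, [1, 0, 1]))  # AND(gender=female, installed_app=netflix) -> 0
```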
