argparse load data from file

gflags有个--flagfile选项,python的argparse中默认不支持,以下方法可以实现


class LoadFromFile (argparse.Action):

   def __call__ (self, parser, namespace, values, option_string = None):

     with values as f: parser.parse_args(f.read().split(), namespace)




parser = argparse.ArgumentParser()# other arguments

parser.add_argument('-flagfile', type=open, action=LoadFromFile)

args = parser.parse_args()

Traceback (most recent call last): File "/home/wmy/miniconda3/lib/python3.12/site-packages/conda/exception_handler.py", line 18, in __call__ return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ File "/home/wmy/miniconda3/lib/python3.12/site-packages/conda/cli/main.py", line 44, in main_subshell context.__init__(argparse_args=pre_args) File "/home/wmy/miniconda3/lib/python3.12/site-packages/conda/base/context.py", line 518, in __init__ self._set_search_path( File "/home/wmy/miniconda3/lib/python3.12/site-packages/conda/common/configuration.py", line 1432, in _set_search_path self._set_raw_data(dict(self._load_search_path(self._search_path))) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/wmy/miniconda3/lib/python3.12/site-packages/conda/common/configuration.py", line 1421, in _load_search_path yield path, YamlRawParameter.make_raw_parameters_from_file(path) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/wmy/miniconda3/lib/python3.12/site-packages/conda/common/configuration.py", line 398, in make_raw_parameters_from_file yaml_obj = yaml_round_trip_load(fh) ^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/wmy/miniconda3/lib/python3.12/site-packages/conda/common/serialize.py", line 34, in yaml_round_trip_load return _yaml_round_trip().load(string) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/wmy/miniconda3/lib/python3.12/site-packages/ruamel/yaml/main.py", line 451, in load return constructor.get_single_data() ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/wmy/miniconda3/lib/python3.12/site-packages/ruamel/yaml/constructor.py", line 116, in get_single_data return self.construct_document(node) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/wmy/miniconda3/lib/python3.12/site-packages/ruamel/yaml/constructor.
最新发布
03-17
``` import argparse import collections import numpy as np import torch import torch.nn as nn from parse_config import ConfigParser from trainer import Trainer from utils.util import * from data_loader.data_loaders import * import model.loss as module_loss import model.metric as module_metric import model.model as module_arch # 固定随机种子以提高可重复性 SEED = 123 torch.manual_seed(SEED) torch.backends.cudnn.deterministic = False torch.backends.cudnn.benchmark = False np.random.seed(SEED) def weights_init_normal(m): if isinstance(m, (nn.Conv1d, nn.Conv2d)): nn.init.normal_(m.weight.data, 0.0, 0.02) elif isinstance(m, nn.BatchNorm1d): nn.init.normal_(m.weight.data, 1.0, 0.02) nn.init.constant_(m.bias.data, 0.0) def main(config, fold_id): batch_size = config["data_loader"]["args"]["batch_size"] logger = config.get_logger('train') # 构建模型并初始化权重 model = config.init_obj('arch', module_arch) model.apply(weights_init_normal) logger.info(model) # 获取损失函数和评估指标 criterion = getattr(module_loss, config['loss']) metrics = [getattr(module_metric, met) for met in config['metrics']] # 构建优化器 trainable_params = filter(lambda p: p.requires_grad, model.parameters()) optimizer = config.init_obj('optimizer', torch.optim, trainable_params) # 加载数据 data_loader, valid_data_loader, data_count = data_generator_np( folds_data[fold_id][0], folds_data[fold_id][1], batch_size ) weights_for_each_class = calc_class_weight(data_count) # 初始化训练器并开始训练 trainer = Trainer( model=model, criterion=criterion, metrics=metrics, optimizer=optimizer, config=config, data_loader=data_loader, fold_id=fold_id, valid_data_loader=valid_data_loader, class_weights=weights_for_each_class ) trainer.train() if __name__ == '__main__': args = argparse.ArgumentParser(description='PyTorch Template') args.add_argument('-c', '--config', default="config.json", type=str, help='config file path (default: None)') args.add_argument('-r', '--resume', default=None, type=str, help='path to latest checkpoint (default: None)') args.add_argument('-d', '--device', default="0", type=str, help='indices of GPUs to enable (default: all)') args.add_argument('-f', '--fold_id', type=str, help='fold_id') args.add_argument('-da', '--np_data_dir', type=str, help='Directory containing numpy files') CustomArgs = collections.namedtuple('CustomArgs', 'flags type target') options = [] args2 = args.parse_args() fold_id = int(args2.fold_id) config = ConfigParser.from_args(args, fold_id, options) if "shhs" in args2.np_data_dir: folds_data = load_folds_data_shhs(args2.np_data_dir, config["data_loader"]["args"]["num_folds"]) else: folds_data = load_folds_data(args2.np_data_dir, config["data_loader"]["args"]["num_folds"]) main(config, fold_id)```请帮我逐行解释这段Python代码:特别关注核心算法逻辑、特定语法结构、函数方法的用途、潜在错误排查、代码优化建议,解释每行的基础功能,我是新手需要基础模式,还要将一些上面部分代码有没有调用其他部分代码的指令?如果没有就不需要讲解,如果有的话调用命令是哪个语句?调用了哪部分代码?这部分代码在项目中起到的作用是什么?
03-09
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值