Parsing: global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]

This post takes a close look at Python's split() function, covering its syntax, parameters, and usage examples, and contrasts it with os.path.split() and os.path.splitext() to help you master string splitting.


Print the model's path and name:
print(ckpt.model_checkpoint_path)

The output looks like this: MNIST_model/mnist_model-29001

The split function splits a string: it slices the string at a given separator and returns the pieces as a list.
The return value of split is that list of substrings.
list[n] selects the n-th piece; n = -1 selects the last piece (a "piece" here is simply an element of the returned list).

Recover the training step count saved in the file name:
global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
This splits the string "MNIST_model/mnist_model-29001" in two steps: split('/')[-1] keeps the part after the last '/', giving 'mnist_model-29001', and split('-')[-1] then keeps the part after the last '-'.
The result is: 29001
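
To see the two steps in isolation, here is a minimal standalone sketch; the path is the example string from above, and the os.path.basename variant at the end is an alternative I am adding, not part of the original snippet:

import os

path = "MNIST_model/mnist_model-29001"
filename = path.split('/')[-1]         # 'mnist_model-29001'
global_step = filename.split('-')[-1]  # '29001' (note: still a string, not an int)
print(global_step)                     # 29001

# Equivalent, letting os.path isolate the file name in one call:
print(os.path.basename(path).split('-')[-1])  # 29001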

Knowledge points
The split functions in Python

Python provides both split() and os.path.split(); they do the following:

split(): splits a string. It slices the string at a given separator and returns the pieces as a list.

os.path.split(): splits a path into its directory part and its file name, returned as a tuple.

1. The split() function

Syntax: str.split(sep=None, maxsplit=-1), optionally followed by [n] to select one piece

Parameter notes:
sep: the separator. It defaults to None, which means "split on any whitespace"; it may not be the empty string (''). If the separator never occurs in the string, the whole string is returned as the list's only element.
maxsplit: the maximum number of splits. If given, the string is split into at most maxsplit + 1 substrings, each of which can be assigned to its own variable.
[n]: selects the n-th piece.

Note: called with no arguments, split() treats any run of whitespace (spaces, tabs, newlines, etc.) as a separator and discards empty strings from the result.
str.split(sep, maxsplit) returns a list; str.split(sep, maxsplit)[n] returns a single string.
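
Two quick checks of the rules above (a minimal sketch; the strings are arbitrary examples):

print("a,b,c".split(sep=',', maxsplit=1))  # ['a', 'b,c']: at most maxsplit + 1 pieces
print("a,b,c".split('-'))                  # ['a,b,c']: separator absent, so the whole string is the only element
try:
    "a,b,c".split('')                      # the empty separator is rejected
except ValueError as e:
    print(e)                               # empty separator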

a = "My name is YoungFan"
b = "My\nname\tis\nYoungFan"  # \t 、\n是转义字符
c = "my\tname\tis YoungFan"

a = a.split()
b = b.split()
c = c.split(' ')  # 分隔符为一个空格

print(a)
print(b)
print(c)

Output:
['My', 'name', 'is', 'YoungFan']
['My', 'name', 'is', 'YoungFan']
['my\tname\tis', 'YoungFan']

string = "hello.world.python"
print(string.split('.'))  # 输出为:['hello', 'world', 'python']
print(string.split('.', 1))  # 输出为:['hello', 'world.python']
print(string.split('.', 1)[0])  # 输出为:hello
print(string.split('.', 1)[1])  # 输出为:world.python
string2 = "hello<python.world>and<c++>end"
print(string2.split("<", 2)[2].split(">")[0])  # 输出为:c++
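
When you need every delimited field rather than a single one, chained splits get brittle; a regular expression is often clearer. A minimal sketch using the standard-library re module (the pattern is my own illustration, not from the original post):

import re

string2 = "hello<python.world>and<c++>end"
print(re.findall(r'<(.*?)>', string2))  # prints: ['python.world', 'c++']
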
2. os.path.split()

The os.path.split() function separates a file path into its directory part and its file name.

import os

x = os.path.split("E:/GitHub收藏项目/keras-yolo3-master/yolo.py")
print(x)

Output: ('E:/GitHub收藏项目/keras-yolo3-master', 'yolo.py')
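
For new code, the standard library's pathlib (available since Python 3.4) exposes the same decomposition through attributes; a minimal sketch on the same example path:

from pathlib import PurePath

p = PurePath("E:/GitHub收藏项目/keras-yolo3-master/yolo.py")
print(p.parent)  # E:/GitHub收藏项目/keras-yolo3-master
print(p.name)    # yolo.py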

Here is one more related function:

3. os.path.splitext()

The os.path.splitext() function separates the extension from the rest of the path.

y = os.path.splitext("E:/GitHub收藏项目/keras-yolo3-master/yolo.py")
print(y)  # the return value is a tuple

Output: ('E:/GitHub收藏项目/keras-yolo3-master/yolo', '.py')
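
Combining the two functions gives a common idiom for extracting just the bare file name, with neither directory nor extension (a small sketch on the same example path):

import os

path = "E:/GitHub收藏项目/keras-yolo3-master/yolo.py"
name = os.path.splitext(os.path.basename(path))[0]
print(name)  # yolo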
The resulting pieces can be unpacked into several variables at once:
import cv2
major, minor = cv2.__version__.split('.')[:2]  # slicing a string/list/tuple excludes the end index, e.g. [0, 1, 2][0:1] gives [0]
print(major)  # prints 3
print(minor)  # prints 3
# so this is an OpenCV 3.3+ build
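
Note that the unpacked pieces are still strings; convert them before comparing numerically. A minimal sketch (the version string is a made-up literal, so cv2 itself is not required):

version = "3.3.1"
major, minor = map(int, version.split('.')[:2])
if (major, minor) >= (3, 3):
    print("OpenCV 3.3 or newer")  # prints: OpenCV 3.3 or newer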

