复现论文常用函数(一)tf.one_hot,tf.train.batch,tf.train.shuffle_batch,数据读取机制,获取文件路径,Bunch等

本文介绍了TensorFlow中的一些常用函数,包括tf.one_hot用于one-hot编码,tf.train.batch和tf.train.shuffle_batch分别用于有序和无序数据批处理,详细解析了数据读取机制,包括文件队列和内存队列的使用。此外,还提到了获取文件路径的方法,Bunch类的使用,以及tf.Assert、tf.control_dependencies和tf.stack等函数的用途。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1.tf.one_hot(input, len)
该函数用于将输入input转化为one-hot形式的向量
第一个参数input表示输入
第二个参数len表示one-hot的长度
如:input = [0,1,3]
output = tf.one_hot(input, 4)
#output为:
#[[1,0,0,0],[0,1,0,0],[0,0,0,1]]

2.tf.train.batch([example, label], batch_size=batch_size, capacity=capacity)
该函数用于将数据组织成为一个batch,且是有序的
第一个和第二个参数的含义字面意思就清楚了,不赘述
capacity表示的是队列容量

3.tf.train.shuffle_batch(
tensor_list, batch_size, capacity, min_after_dequeue, num_threads=1, seed=None, enqueue_many=False, shapes=None, name=None)
该函数用于将数据组织成为一个batch,且是随机的,无序的
参数含义和上述函数相同。其中多的几个参数含义在此表达一下:
min_after_dequeue:一定要保证这参数小于capacity参数的值,否则会出错。这个代表队列中的元素大于它的时候就输出乱的顺序的batch
num_threads表示线程数量
seed是随机种子

4.tf数据读取机制:
tf.train.string_input_producer()
tf.train.start_queue_runners()
在这里插入图片描述
文件队列,通过tf.train.string_input_producer()函数来创建,文件名队列不包含文件的具体内容,只是在队列中记录所有的文件名,所以可以在这个函数中对文件设置多个epoch,并对其进行shuffle。这个函数只是创建一个文件队列,并指定入队的操作由几个线程同时完成。真正的读取文件名内容是从执行了tf.train.start_queue_runners()开始的,start_queue_runners返回一个op,一旦执行这个op,文件名队列就开始被填充了。
内存队列,这个队列不需要用户手动创建,有了文件名队列后,start_queue_runners之后,Tensorflow会自己维护内存队列并保证用户时时有数据可读。
详细内容请看这篇文章

5.获取文件路径
#获取文件当前路径
current_path = os.getcwd()
#获取文件当前路径的上一路径
path = os.path.dirname(os.getcwd())
#获取文件当前路径的上一路径中的某个文件test.py的路径
file_path = os.path.join(os.path.dirname(os.getcwd()),‘test.py’)

6.Bunch:
今天阅读源码的时候看到了一个关于bunch的知识点,但是网上搜索了很长时间都没有看到有用的信息。因此在这里记下:
bunch其实类似于字典dict,接下来看看和字典的区别:
字典只能通过[]访问:
在这里插入图片描述
bunch不仅能通过[]访问,还能通过.访问,也就是说dict所有的key都变成了bunch的属性:
在这里插入图片描述
用途:当在某个json文件中写了一些配置参数的时候,python将json文件读入,都是字典,在使用的时候不如.方便,特别是可能需要添加某个属性,那么用类就非常方便了。
注意上面的代码要引入Bunch:
from sklearn.datasets.base import Bunch

7.tf.summary.merge_all
merge_all 可以将所有summary全部保存到磁盘,以便tensorboard显示。如果没有特殊要求,一般用这一句就可一显示训练时的各种信息了。
格式:tf.summaries.merge_all(key=‘summaries’)

8.tf.Assert(condition,data)
该函数用于根据条件打印数据,条件为真时,不打印,条件为假时,打印,例如:
rank_assertion = tf.Assert(
tf.equal(tf.rank(image), 3),
[‘Rank of image must be equal to 3.’])
当image的维度是3时,不打印,若不是3,则打印后面这句话

9.tf.control_dependencies()
该函数用于实现某种依赖关系。通俗点说,就是如果事件b必须在事件a发生后再执行,那么可以使用tf.control_dependencies()来实现。如:

with g.control_dependencies([a, b]):
  # Ops constructed here run after `a` and `b`.
  with g.control_dependencies([c, d]):
    # Ops constructed here run after `a`, `b`, `c`, and `d`.

10.tf.stack(values,axis=0,name=‘stack’)
将秩为 R 的张量列表堆叠成一个秩为 (R+1) 的张量。
将 values 中的张量列表打包成一个张量,该张量比 values 中的每个张量都高一个秩,通过沿 axis 维度打包。给定一个形状为(A, B, C)的张量的长度 N 的列表;
如果 axis == 0,那么 output 张量将具有形状(N, A, B, C)。如果 axis == 1,那么 output 张量将具有形状(A, N, B, C)。

x = tf.constant([1, 4])
y = tf.constant([2, 5])
z = tf.constant([3, 6])
tf.stack([x, y, z])  # [[1, 4], [2, 5], [3, 6]] (Pack along first dim.)
tf.stack([x, y, z], axis=1)  # [[1, 2, 3], [4, 5, 6]]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值