tf.split()函数解析

本文详细介绍了TensorFlow中tf.split函数的使用方法,包括通过整数或向量指定切割数量,以及如何沿特定轴进行切割。以三维张量为例,展示了不同切割方式的效果。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

API原型(TensorFlow 1.8.0):

tf.split(
    value,
    num_or_size_splits,
    axis=0,
    num=None,
    name='split'
)
这个函数是用来切割张量的。输入切割的张量和参数,返回切割的结果。 
value传入的就是需要切割的张量。 
这个函数有两种切割的方式:

以三个维度的张量为例,比如说一个20 * 30 * 40的张量my_tensor,就如同一个长20厘米宽30厘米高40厘米的蛋糕,每立方厘米都是一个分量。

有两种切割方式: 
1. 如果num_or_size_splits传入的是一个整数,这个整数代表这个张量最后会被切成几个小张量。此时,传入axis的数值就代表切割哪个维度(从0开始计数)。调用tf.split(my_tensor, 2,0)返回两个10 * 30 * 40的小张量。 
2. 如果num_or_size_splits传入的是一个向量,那么向量有几个分量就分成几份,切割的维度还是由axis决定。比如调用tf.split(my_tensor, [10, 5, 25], 2),则返回三个张量分别大小为 20 * 30 * 10、20 * 30 * 5、20 * 30 * 25。很显然,传入的这个向量各个分量加和必须等于axis所指示原张量维度的大小 (10 + 5 + 25 = 40)。
--------------------- 
作者:SangrealLilith 
来源:优快云 
原文:https://blog.youkuaiyun.com/SangrealLilith/article/details/80272346 
版权声明:本文为博主原创文章,转载请附上博文链接!

import os import matplotlib.pylab as plt import tensorflow as tf import pathlib # 在URL上下载数据集 flowers_root = tf.keras.utils.get_file( 'flower_photos', 'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz', untar=True) # untar=True表示自动解压下载的压缩包 flowers_root = pathlib.Path(flowers_root) print("看一下flowers_root的路径:", flowers_root) # 获取每个类下的文件数据 list_ds = tf.data.Dataset.list_files(str(flowers_root / '*/*')) def parse_image(filename): parts = tf.strings.split(filename, os.sep) # 分割数据,使用 tf.strings.split 函数将文件路径 filename 按照操作系统的路径分隔符 os.sep 进行分割 label = parts[-2] image = tf.io.read_file(filename) # 读取并输出输入文件名的全部内容 image = tf.image.decode_jpeg(image) # 编码解码处理 image = tf.image.convert_image_dtype(image, tf.float32) # 转换为float类型 image = tf.image.resize(image, [128, 128]) # 尺寸调整为128*128 return image, label file_path = next(iter(list_ds)) # 找到文件路径 image, label = parse_image(file_path) print('image的内容', image) print('***********************************************') print('label的内容', label) # 自定义函数绘制图像图形 plt.figure() plt.imshow(image) plt.title(label.numpy().decode('utf-8')) plt.axis('off') plt.show() images_ds = list_ds.map(parse_image) for image, label in images_ds.take(1): plt.figure() plt.imshow(image) plt.title(label.numpy().decode('utf-8')) plt.axis('off') plt.show() 检查错误
03-20
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值