TensorFlow 和 Numpy 等常用函数小结

本文深入解析了TensorFlow中tf.cast、tf.variable_scope、tf.name_scope及tf.squeeze等关键函数的使用方法,同时介绍了numpy的flip函数和Python的__all__机制。通过实例展示了如何在TensorFlow中改变张量数据类型、变量命名空间以及张量维度压缩。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1.tf.cast用法

tf.cast:用于改变某个张量的数据类型

import tensorflow as tf;
import numpy as np;
 
A = tf.convert_to_tensor(np.array([[1,1,2,4], [3,4,8,5]]))
 
with tf.Session() as sess:
	print A.dtype
	b = tf.cast(A, tf.float32)
	print b.dtype

输出:

<dtype: 'int64'>
<dtype: 'float32'>

开始的时候定义A没有给出类型,采用默认类型,整形。利用tf.cast函数就改为float类型

2.tf.variable_scope和tf.name_scope的用法

tf.variable_scope可以让变量有相同的命名,包括tf.get_variable得到的变量,还有tf.Variable的变量

tf.name_scope可以让变量有相同的命名,只是限于tf.Variable的变量

import tensorflow as tf;  
import numpy as np;  
import matplotlib.pyplot as plt;  
 
with tf.variable_scope('V1'):
	a1 = tf.get_variable(name='a1', shape=[1], initializer=tf.constant_initializer(1))
	a2 = tf.Variable(tf.random_normal(shape=[2,3], mean=0, stddev=1), name='a2')

with tf.name_scope('V2'):
	a4 = tf.Variable(tf.random_normal(shape=[2,3], mean=0, stddev=1), name='a2')
  
with tf.Session() as sess:
	sess.run(tf.initialize_all_variables())
	print a1.name
	print a2.name
	print a4.name

输出:

V1/a1:0
V1/a2:0
V2/a2:0

3.tf.squeeze

从张量形状中移除大小为1的维度。

给定一个张量 input,该操作返回一个与已经移除的所有大小为1的维度具有相同类型的张量。如果您不想删除所有大小为1的维度,则可以通过指定 axis 来删除特定的大小为1的维度。

# 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
tf.shape(tf.squeeze(t))  # [2, 3]

或者,要删除特定的大小为1的维度:

# 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
tf.shape(tf.squeeze(t, [2, 4]))  # [1, 2, 3, 1]

1.np.flip

沿着指定的轴反转元素

>>> A = np.arange(8).reshape((2,2,2))
>>> A
array([[[0, 1],
        [2, 3]],
       [[4, 5],
        [6, 7]]])
        
>>> flip(A, 0)
array([[[4, 5],
        [6, 7]],
       [[0, 1],
        [2, 3]]])
        
>>> flip(A, 1)
array([[[2, 3],
        [0, 1]],
       [[6, 7],
        [4, 5]]])

>>>np.flip(A, 2)
array([[[1, 0],
        [3, 2]],
       [[5, 4],
        [7, 6]]])

2.os 路径操作

dir = '/home/123/result.png'
print os.path.dirname(dir)    # 获得文件夹的路径
print os.path.basename(dir)   # 获得文件名称
print os.path.splitext(os.path.basename(dir)) # 分离文件名和扩展名

输出:

/home/123/

result.png

result .png

3.python 中的 __all__

我越来越多的使用Python了,经常看到 __all__ 变量再各种 __init__.py 文件中,谁能解释为什么那么做呢?

它是一个string元素组成的list变量,定义了当你使用 from <module> import * 导入某个模块的时候能导出的符号(这里代表变量,函数,类等)。

举个例子,下面的代码在 foo.py 中,明确的导出了符号 barbaz

__all__ = ['bar', 'baz']

waz = 5
bar = 10
def baz(): return 'baz'

导入实现如下:

from foo import *

print bar
print baz

# The following will trigger an exception, as "waz" is not exported by the module
# 下面的代码就会抛出异常,因为 "waz"并没有从模块中导出,因为 __all__ 没有定义
print waz

如果把 foo.py 中 __all__ 给注释掉,那么上面的代码执行起来就不会有问题, import * 默认的行为是从给定的命名空间导出所有的符号(当然下划线开头的私有变量除外)。

注意

需要注意的是 __all__ 只影响到了 from <module> import * 这种导入方式,对于 from <module> import <member> 导入方式并没有影响,仍然可以从外部导入。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值