tensorflow 一些tips

本文介绍了一种在TensorFlow中处理不同秩张量和矩阵相乘的方法,通过tf.scan()函数实现批量矩阵乘法。此外,还提供了一个使用sklearn库轻松划分训练集和测试集的函数示例。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1. tensorflow不支持不同秩的张量和矩阵相乘 

例如:shape(3,2,3)和 (3,4)的tensor不能相乘 ,以前可以用batch_matmul() 新的版本中已被删除

但是神经网络中有时需要这类的操作,因为第一个维度往往是样本数维度。以下方法可以实现上述操作:

利用tf.scan() 实现

        # A维度(batch_size,dim_1,dim_2) B维度 (dim_2,dim_3) ——>返回(batch_size,dim_1,dim_3)
        def batch_matmul(A, B):
            # self.units 和self.fre_dim 对应(dim_1,dim_3) 
            initializer = tf.ones([self.units,self.fre_dim], dtype=self.dtype)
            C = tf.scan(lambda a, x: tf.matmul(x, B), A, initializer)
            return C

2.一个方便划分训练和测试集的函数

from sklearn.model_selection import train_test_split
 train_X, test_X, train_y, test_y = train_test_split(dataX, dataY,
                                                        test_size=0.2,
                                                        random_state= np.random.seed(1000))

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值