**Python 元组** tf.shape() x.get_shape().as_list()

部署运行你感兴趣的模型镜像

Python 元组
Python的元组与列表类似,不同之处在于元组的元素不能修改。

元组使用小括号,列表使用方括号。

元组创建很简单,只需要在括号中添加元素,并使用逗号隔开即可。

创建空元组
tup1 = ()
元组中只包含一个元素时,需要在元素后面添加逗号
tup1 = (50,)
元组与字符串类似,下标索引从0开始,可以进行截取,组合等。

(1) tf.shape()
先说tf.shape()很显然这个是获取张量的大小的,用法无需多说,直接上例子吧!

import tensorflow as tf

import numpy as np

a_array=np.array([[1,2,3],[4,5,6]])
b_list=[[1,2,3],[3,4,5]]
c_tensor=tf.constant([[1,2,3],[4,5,6]])

with tf.Session() as sess:
print(sess.run(tf.shape(a_array)))
print(sess.run(tf.shape(b_list)))
print(sess.run(tf.shape(c_tensor)))
结果:

(2)x.get_shape().as_list()
这个简单说明一下,x.get_shape(),只有tensor才可以使用这种方法,返回的是一个元组。

import tensorflow as tf

import numpy as np

a_array=np.array([[1,2,3],[4,5,6]])
b_list=[[1,2,3],[3,4,5]]
c_tensor=tf.constant([[1,2,3],[4,5,6]])

print(c_tensor.get_shape())
print(c_tensor.get_shape().as_list())

with tf.Session() as sess:
print(sess.run(tf.shape(a_array)))
print(sess.run(tf.shape(b_list)))
print(sess.run(tf.shape(c_tensor)))
结果:可见只能用于tensor来返回shape,但是是一个元组,需要通过as_list()的操作转换成list.


作者:爱抠脚的coder
来源:优快云
原文:https://blog.youkuaiyun.com/m0_37393514/article/details/82226754
版权声明:本文为博主原创文章,转载请附上博文链接!

您可能感兴趣的与本文相关的镜像

TensorFlow-v2.15

TensorFlow-v2.15

TensorFlow

TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型

def Data_load_SUMIMO_test_sparse_SNR(Config, mean_real, std_real, mean_imag, std_imag, do_predict=True): results = [] bit_map = {'4QAM': 2, '16QAM': 4, '64QAM': 6} for i, data_path in enumerate(Config.data_path_list): print(f"\n📌 处理路径: {data_path}") seg_num = Config.segment_num[i] try: # 1. 加载接收信号 x1 = load_feature_data(data_path, seg_num, 'group_freqdata', 'RxFreqData') print(f"RxFreqData shape: {x1.shape}") # 2. 加载导频估计 h_ls = load_feature_data(data_path, seg_num, 'group_lsresult', 'LsEstimation') print(f"LsEstimation shape: {h_ls.shape}") # 3. 插值 pilot_mask = [f for f in range(120) if f % 6 == 0] if h_ls.ndim == 5: B, S, F, Nr, Nt = h_ls.shape h_pilot = h_ls[:, :, pilot_mask, :, :] else: raise ValueError(f"Unexpected shape: {h_ls.shape}") h_interp = fast_batch_interpolation(h_pilot, pilot_indices=pilot_mask, F=120) # 4. 特征合并 h_flat = h_interp.reshape(*h_interp.shape[:3], -1) # (B,S,F,Nr*Nt) x_combined = np.concatenate([x1, h_flat], axis=-1) # (B,S,F,C_total) x_input = np.concatenate([x_combined.real, x_combined.imag], axis=-1) # (B,S,F,2*C) # 5. 标准化 eps = 1e-8 x_real_norm = (x_input[..., :x_input.shape[-1]//2] - mean_real) / (std_real + eps) x_imag_norm = (x_input[..., x_input.shape[-1]//2:] - mean_imag) / (std_imag + eps) x_normalized = x_real_norm + 1j * x_imag_norm # complex input # 6. 加载标签 y_true = load_feature_data(data_path, seg_num, 'group_txcwbits', 'Label').astype(np.float32) print(f"Label shape: {y_true.shape}") # 7. 推理 y_pred_logits = None raw_ber = None if do_predict: print("🚀 执行模型预测...") with tf.device('/GPU:0' if tf.config.list_physical_devices('GPU') else '/CPU:0'): preds = loaded_model.predict(x_normalized, batch_size=1, verbose=1) n_bits = bit_map.get(Config.Modulation, 2) y_pred_selected = preds[..., :n_bits] y_pred_hard = (y_pred_selected > 0.5).astype(np.float32) correct = np.sum(y_pred_hard == y_true) total = y_true.size raw_ber = 1 - correct / total print(f"✅ Raw BER: {raw_ber:.6f}") results.append({ 'path': data_path, 'x_normalized': x_normalized, 'y_true': y_true, 'y_pred_logits': preds if do_predict else None, 'ber': raw_ber, 'shape': x_normalized.shape }) except Exception as e: print(f"❌ 处理 {data_path} 失败: {str(e)}") continue return results 如何调用这个函数
最新发布
10-29
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值