tf.cast()的用法

定义方式:

def cast(x, dtype, name=None)

解释:

修改张量数据的数据类型

x:待修改数据类型的数据

dtype:目标数据类型

name:可选参数,即定义操作的名称

cast在tensorflow的代码位置:tensorflow/python/ops/math_ops.py

具体的定义如下所示:

def cast(x, dtype, name=None):
  """Casts a tensor to a new type.

  The operation casts `x` (in case of `Tensor`) or `x.values`
  (in case of `SparseTensor` or `IndexedSlices`) to `dtype`.

  For example:

  ```python
  x = tf.constant([1.8, 2.2], dtype=tf.float32)
  tf.cast(x, tf.int32)  # [1, 2], dtype=tf.int32
  ```

  The operation supports data types (for `x` and `dtype`) of
  `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`, `int64`,
  `float16`, `float32`, `float64`, `complex64`, `complex128`, `bfloat16`.
  In case of casting from complex types (`complex64`, `complex128`) to real
  types, only the real part of `x` is returned. In case of casting from real
  types to complex types (`complex64`, `complex128`), the imaginary part of the
  returned value is set to `0`. The handling of complex types here matches the
  behavior of numpy.

  Args:
    x: A `Tensor` or `SparseTensor` or `IndexedSlices` of numeric type. It could
      be `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`,
      `int64`, `float16`, `float32`, `float64`, `complex64`, `complex128`,
      `bfloat16`.
    dtype: The destination type. The list of supported dtypes is the same as
      `x`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` and
      same type as `dtype`.

  Raises:
    TypeError: If `x` cannot be cast to the `dtype`.
  """
  base_type = dtypes.as_dtype(dtype).base_dtype
  if isinstance(x,
                (ops.Tensor, _resource_variable_type)) and base_type == x.dtype:
    return x
  with ops.name_scope(name, "Cast", [x]) as name:
    if isinstance(x, sparse_tensor.SparseTensor):
      values_cast = cast(x.values, base_type, name=name)
      x = sparse_tensor.SparseTensor(x.indices, values_cast, x.dense_shape)
    elif isinstance(x, ops.IndexedSlices):
      values_cast = cast(x.values, base_type, name=name)
      x = ops.IndexedSlices(values_cast, x.indices, x.dense_shape)
    else:
      # TODO(josh11b): If x is not already a Tensor, we could return
      # ops.convert_to_tensor(x, dtype=dtype, ...)  here, but that
      # allows some conversions that cast() can't do, e.g. casting numbers to
      # strings.
      x = ops.convert_to_tensor(x, name="x")
      if x.dtype.base_dtype != base_type:
        x = gen_math_ops.cast(x, base_type, name=name)
    if x.dtype.is_complex and base_type.is_floating:
      logging.warn("Casting complex to real discards imaginary part.")
    return x

例子:

import tensorflow as tf


def tf_cast_1():
    t1 = tf.Variable([1, 2, 3, 4, 5, 6])
    t2 = tf.cast(t1, dtype=tf.float32)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print('t1', t1, sess.run(t1))
        print('t2', t2, sess.run(t2))


tf_cast_1()

运行的结果如下所示:

 

 

class ClassSpecificPrecision(tf.keras.metrics.Metric): def __init__(self, class_id, num_classes=8, name="class_precision", **kwargs): """ 为指定类别计算精确率 参数: class_id: 要计算精确率的类别ID (0-7) num_classes: 总类别数 (默认为8) name: 指标名称 """ super().__init__(name=name, **kwargs) self.class_id = class_id self.num_classes = num_classes # 初始化真正例(TP)和预测正例(PP)计数器 self.true_positives = self.add_weight( name="tp", initializer="zeros", dtype=tf.float32) self.predicted_positives = self.add_weight( name="pp", initializer="zeros", dtype=tf.float32) def update_state(self, y_true, y_pred, sample_weight=None): # 将标签转换为整数类型 y_true = tf.cast(y_true, tf.int32) # 获取预测类别 (形状: [batch_size]) y_pred_class = tf.argmax(y_pred, axis=-1, output_type=tf.int32) # 计算真正例 (TP): 实际为class_id且预测为class_id true_positive = tf.logical_and( tf.equal(y_true, self.class_id), tf.equal(y_pred_class, self.class_id) ) # 计算预测正例 (PP): 预测为class_id (不论实际类别) predicted_positive = tf.equal(y_pred_class, self.class_id) # 应用样本权重 (如果提供) if sample_weight is not None: true_positive = tf.cast(true_positive, tf.float32) * sample_weight predicted_positive = tf.cast(predicted_positive, tf.float32) * sample_weight else: true_positive = tf.cast(true_positive, tf.float32) predicted_positive = tf.cast(predicted_positive, tf.float32) # 更新状态变量 self.true_positives.assign_add(tf.reduce_sum(true_positive)) self.predicted_positives.assign_add(tf.reduce_sum(predicted_positive)) def result(self): # 避免除以零错误 return tf.where( self.predicted_positives > 0, self.true_positives / self.predicted_positives, 0.0 ) def reset_state(self): # 重置计数器 self.true_positives.assign(0.0) self.predicted_positives.assign(0.0) 上面自定义的指标输出为什么会超过100.0
07-26
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小杨算法屋

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值