您的位置:首页 > 运维架构

加权交叉熵损失函数:tf.nn.weighted_cross_entropy_with_logits

2020-06-21 05:23 1951 查看

ECharts5.0版本即将上线,来说说我与ECharts的那些事吧!>>>

tf.nn.weighted_cross_entropy_with_logits函数

tf.nn.weighted_cross_entropy_with_logits(
    targets,
    logits,
    pos_weight,
    name=None
)

定义在:tensorflow/python/ops/nn_impl.py。

计算加权交叉熵。

类似于sigmoid_cross_entropy_with_logits(),除了pos_weight,允许人们通过向上或向下加权相对于负误差的正误差的成本来权衡召回率和精确度。

通常的交叉熵成本定义为:

targets * -log(sigmoid(logits)) +
    (1 - targets) * -log(1 - sigmoid(logits))

值pos_weights > 1减少了假阴性计数,从而增加了召回率。相反设置pos_weights < 1会减少假阳性计数并提高精度。从一下内容可以看出pos_weight是作为损失表达式中的正目标项的乘法系数引入的:

targets * -log(sigmoid(logits)) * pos_weight +
    (1 - targets) * -log(1 - sigmoid(logits))

为了简便起见,让x = logits,z = targets,q = pos_weight。损失是:

  qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
= qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
= qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
= qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
= (1 - z) * x + (qz +  1 - z) * log(1 + exp(-x))
= (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))

设置l = (1 + (q - 1) * z),确保稳定性并避免溢出,使用一下内容来实现:

(1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))

logits和targets必须具有相同的类型和形状。

参数:

  • targets:一个Tensor,与logits具有相同的类型和形状。
  • logits:一个Tensor,类型为float32或float64。
  • pos_weight:正样本中使用的系数。
  • name:操作的名称(可选)。

返回:

与具有分量加权逻辑损失的logits具有相同形状的Tensor。

可能引发的异常:

  • ValueError:如果logits和targets没有相同的形状。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  entropy python tensorflow