$ $f(x)=x anh(\log(1+e^x))$ $

使用稳定的log1pexp可以轻松实现该功能(mish激活),而不会造成任何显着的精度损失。不幸的是,这在计算上很重。

是否有可能编写一个更直接的数字稳定的实现,速度更快?

精度一样好 x * std::tanh(std::log1p(std::exp(x))) 那就好了。没有严格的约束,但在神经网络中使用它应该是相当准确的。

输入的分布是从 $[-\infty,\infty]$.它应该在任何地方工作。

有帮助吗?

解决方案

OP指向特定 实施情况mish 精度规格的激活函数,所以我不得不先对此进行表征。该实现使用单精度(float),并且在正半平面内是稳定和准确的。在负半平面中,因为它使用 logf 而不是 log1pf, ,相对误差迅速增长 $x o-\infty$.准确性的丧失始于 $-1$ 而且已经在 $-16.6355324$ 执行错误地返回 $0$, ,因为 $\exp(-16.6355324) = 2^{-24}$.

相同的精度和行为可以通过使用一个简单的数学变换,消除 $\mathrm{tahn}$, ,并考虑到Gpu通常提供融合乘加(FMA)以及快速倒数,这是人们想要利用的。示例性CUDA代码如下所示:

__device__ float my_mishf (float x)
{
    float r;
    float e = expf (x);
    r = 1.0f / fmaf (fmaf (-0.5f, e, -1.0f), e, -1.0f);
    r = fmaf (r, x, x);
    return r;
}

与OP所指向的参考实现一样,这在正半平面中具有优异的精度,并且在负半平面中误差迅速增加因此在 $-16.6355324$ 执行错误地返回 $0$.

如果希望解决这些准确性问题,我们可以应用以下观察结果。足够小 $x$, $f(x)=x\exp(x)$ 到浮点精度以内。为 float 计算这适用于 $x < -15$.对于间隔 $[-15,-1]$, ,我们可以使用有理近似 <R(x)> 计算 $f(x):=R(x)x\exp(x)$.示例性CUDA代码如下所示:

__device__ float my_mishf (float x)
{
    float r;
    if (x >= -1.0f) {
        float e = expf (x);
        r = 1.0f / fmaf (fmaf (-0.5f, e, -1.0f), e, -1.0f);
        r = fmaf (r, x, x);
    } else {
        float eh = expf (0.5f * x);
        float p =        1.03628484e-3f;  //  0x1.0fa7e6p-10
        p = fmaf (p, x, -7.28869531e-3f); // -0x1.ddac04p-8
        p = fmaf (p, x,  3.47027816e-2f); //  0x1.1c4902p-5
        p = fmaf (p, x, -3.54762226e-1f); // -0x1.6b46cap-2
        p = fmaf (p, x,  8.58785570e-1f); //  0x1.b7b2bep-1
        p = fmaf (p, x, -1.38065982e+0f); // -0x1.6172ecp+0
        p = fmaf (p, x,  5.97694337e-1f); //  0x1.3204fep-1
        float q =        1.03527203e-3f;  //  0x1.0f63eep-10
        q = fmaf (q, x, -7.35638570e-3f); // -0x1.e21bacp-8
        q = fmaf (q, x,  3.28683928e-2f); //  0x1.0d4204p-5
        q = fmaf (q, x, -3.79927397e-1f); // -0x1.850bb0p-2 
        q = fmaf (q, x,  6.86127126e-1f); //  0x1.5f4c0ep-1
        q = fmaf (q, x, -1.81509292e+0f); // -0x1.d0a9eep+0
        q = fmaf (q, x,  1.00000000e+0f); //  0x1.000000p+0
        r = (1.0f / q) * p;
        if (x < -15.0f) r = 1.0f;
        r = r * x * eh * eh;
    }
    return r;
}

不幸的是,这种精确的解决方案是以性能显着下降为代价的。如果一个愿意接受降低的精度,同时仍然实现一个平稳腐烂的左尾,下面的插值方案,再次基于 $f(x)\约x\exp(x)$, ,恢复了大部分性能:

__device__ float my_mishf (float x)
{
    float r;
    float e = expf (x);
    if (x >= -6.0625f) {
        r = 1.0f / fmaf (fmaf (-0.5f, e, -1.0f), e, -1.0f);
        r = fmaf (r, x, x);
    } else {
        r = fmaf (-0.5f, e, 1.0f);
        r = r * x * e;
    }
    return r;
}

作为特定于机器的性能增强, expf() 可以由设备内部替换 __expf().

其他提示

具有一些代数操作(如@ orlp的答案所指出的),我们可以推断以下内容:

$$ f(x)= x \ tanh(\ log(1 + e ^ x))\ tag {1} $$ $$= x \ frac {(1 + e ^ x)^ 2 - 1} {(1 + e ^ x)^ 2 + 1}= x \ frac {e ^ { 2x} + 2e ^ x} {e ^ {2x} + 2e ^ x + 2} \ tag {2} $$ $$= x - \ frac {2x} {(1 + e ^ x)^ 2 + 1} \ tag {3} $$

表达式 $(3)$(3)$ $ x $ 是负的损失很少精确。表达式 $(2)$(2)$ 不适合 $ x $ 以来的大值在分子和分母中爆炸。

函数 $(1)$ 渐近地击中零作为 $ x \ to-\ idty $ 。现在作为 $ x $ 的幅度大,表达式 $(3)$ 将遭受灾难性的消除:两个大术语互相取消,给出一个非常少数的数字。表达式 $(2)$ 在此范围内更适合。

此工作相当良好,直到 $ - 18 $ 及以上,您丢失了多个重要人物。

让我们仔细看看函数,然后尝试近似 $ f(x)$ 作为 $ x \ to - \ idty $

$$ f(x)= x \ frac {e ^ {2x} + 2e ^ x} {e ^ {2x} + 2e ^ x + 2} $$ < / span>

$ e ^ {2x} $ 将是小于 $ e ^ x $ $ e ^ x $ 将数量级小于 $ 1 $ 。使用这两个事实,我们可以近似 $ f(x)$ to:

$ f(x)\ apply x \ frac {e ^ x} {e ^ x + 1} \ atthe xe ^ x $

结果:

$ f(x)\ intave \ begin {is} xe ^ x,&\ text {如果$ x \ le -18 $} \\ x \ frac {e ^ {2x} + 2e ^ x} {e ^ {2x} + 2e ^ x + 2}&\ text {if $ -18 \ lt x \ le -0.6 $} \\ x - \ frac {2x} {(1 + e ^ x)^ 2 + 1},&\ text {否则} \结束{案例} $

快速cuda实现:

__device__ float mish(float x)
{
    auto e = __expf(x);
    if (x <= -18.0f)
        return x * e;    

    auto n = e * e + 2 * e;
    if (x <= -0.6f)
        return x * __fdividef(n, n + 2);

    return x - 2 * __fdividef(x, n + 2);
}
.

编辑:

更快,更准确的版本:

$ f(x)\ intave \ begin {is} x \ frac {e ^ {2x} + 2e ^ x} {e ^ {2x} + 2e ^ x + 2}&\ text {$ x \ le -0.6 $} \\ x - \ frac {2x} {(1 + e ^ x)^ 2 + 1},&\ text {否则} \结束{案例} $

__device__ float mish(float x)
{
    auto e = __expf(value);
    auto n = e * e + 2 * e;
    if (value <= -0.6f)
        return value * __fdividef(n, n + 2);

    return value - 2 * __fdividef(value, n + 2);
}
.

代码: https://gist.github.com/yashassamaga/8aad0cd3b30dbd0eb588c1f4c035db28c

$$ \ begin {array} {c | c | c | c |} &\ text {time(float)}&\ text {time(float4)}&\ text {l2错误矢量} \\ \ hline \文本{MISH}&1.49ms&1.39ms&2.4583E-05 \\ \ hline \文本{Relu}&1.47ms&1.39ms&\ text {n / a} \\ \ hline \结束{array} $$

无需执行对数。如果你让 $ p= 1+ \ exp(x)$ 那么我们 $ f(x)= x \ cdot \ dfrac {p ^ 2-1} {p ^ 2 + 1} $ $ f(x)= x - \ dfrac {2x} {p ^ 2 + 1} $

我的印象是,某人想要通过0到1的函数f(x)乘以x乘以x,并尝试直到他们发现使用这一函数的表达式,没有函数选择背后的数学原因。

选择参数t后,让 $ p_t(x)= 1/2 +(3/4t)x - x ^ 3 /(4t ^ 3)$ ,然后 $ p_t(0)= 1/2 $ $ p_t(t)= 1 $ $ p_t(-t)= 0 $ $ p_t'(t)= p_t'( - t)= 0 $ 。设x <-t,1如果x> +1,以及 $ p_t(x)$ if-t≤x≤+ t。这是一个函数,它从0开始变为1.选择另一个参数s,而不是f(x)计算x * g(x - s)。

t= 3.0和s= -0.3相匹配给定的函数相当合理,并且计算得足够速度(似乎是重要的)。当然是不同的。由于此功能在某些问题中用作工具,我希望看到原始功能是更好的数学原因

上下文 这里 是计算机视觉和训练神经网络的激活函数.

此代码很可能会在GPU上执行。虽然性能将取决于典型输入的分布, 一般来说,避免GPU代码中的分支很重要.扭曲发散会显着降低代码的性能。例如, CUDA工具包文档 说:

注:高优先级:避免同一warp内的不同执行路径。流控制指令(if、switch、do、for、while)会导致同一warp的线程发散,从而显着影响指令吞吐量;即遵循不同的执行路径。如果发生这种情况,则必须分别执行不同的执行路径;这增加了为此warp执行的指令的总数。...对于仅包括几个指令的分支,翘曲发散通常会导致边际性能损失。例如,编译器可以使用预测来避免实际的分支。相反,所有指令都被调度,但每个线程条件代码或谓词控制哪些线程执行指令。具有false谓词的线程不会写入结果,也不会计算地址或读取操作数。

两个无分支实现

OP的回答 确实有短分支,所以一些编译器可能会发生分支预测。我注意到的另一件事是,每次调用一次计算指数似乎是可以接受的。也就是说,我理解OP的回答说对指数的一次调用不是"昂贵"或"缓慢"。

在这种情况下,我会建议以下简单的代码:

__device__ float mish(float x)
{
    float expx = __expf(x);
    return x / (1.0f + 2.0f / (expx * (2.0f + expx)));
}

它没有分支,一个指数,一个乘法和两个除法。除法通常比乘法更昂贵,所以我也尝试了这个代码:

__device__ float mish(float x)
{
    float expx = __expf(x);
    float psi = expx * (2.0f + expx);
    return x * (psi / (2.0f + psi));
}

这没有分支,一个指数,两个乘法和一个除法。

相对误差

我计算了这两个实现的log10相对精度加上OP的答案。我计算了间隔(-100,100),增量为1/1024,然后计算了超过51个值的运行最大值(以减少视觉混乱,但仍然给出正确的印象)。用双精度计算第一个实现就足够了作为参考。指数精确到一个ULP内,并且只有少数算术运算;其余的位足以使表制造商的困境非常不可能。因此,我们很可能能够计算正确舍入的单精度参考值。

Log10 relative error

绿色:首先实施。红色:第二次实施。蓝色:OP的实现。蓝色和红色通过它们的大部分范围(左约-20)重叠。

OP注意事项:如果要保持完全精度,则需要将截止值更改为大于-5。

工作表现

你必须测试这两个实现,看看哪个更快。他们应该至少和OP一样快,我怀疑他们会因为缺乏分支而快得多。但是,如果他们对你来说不够快,你还可以做更多。

一个重要的问题:

您期望看到的典型输入值的分布是什么?值是否会在函数有效可计算的整个范围内均匀分布?或者它们几乎总是聚集在0周围?如果是这样,用什么方差/传播?

可以改善渐近线。

在左边,OP使用 x * expx 截止值为-18。该截止值可以增加到-15.5625左右,而不会损失精度。用一个额外的乘法的成本,你可以使用 x * expx * (1.0f - 0.5f * expx) 和约-4.875的截止值。注:乘以0.5可以优化为从指数减去1,所以我不在这里计算。

在右边,你可以引入另一个渐近线。如果 x > 8.75, ,简单 return x.只要多花点钱,你就可以了 x * (1.0f - 2.0f * __expf(-2.0f * x)) 何时 x > 6.0.

插值法

对于范围的中心部分(-4.875,6.0),您可以使用插值表。如果它们的范围是等距的,则可以使用一个除法来计算表中的直接索引(没有分支)。计算这样的表格需要一些努力,但根据您的需求可能是值得的:少数乘法和加法 可能 比指数便宜。也就是说,库中指数的实现者可能已经花费了大量的时间和精力来正确和快速地获得他们的指数。此外,"mish"功能不提供任何范围缩小的机会。

许可以下: CC-BY-SA归因
不隶属于 cs.stackexchange
scroll top