Question

$$f(x) = x \tanh(\log(1 + e^x))$$

The function (mish activation) can be easily implemented using a stable log1pexp without any significant loss of precision. Unfortunately, this is computationally heavy.

Is it possible to write a more direct numerically stable implementation which is faster?

Accuracy as good as x * std::tanh(std::log1p(std::exp(x))) would be nice. There are no strict constraints, but it should be reasonably accurate for use in neural networks.
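
For reference, a minimal sketch of that baseline (my illustration, simply spelling out the expression above in single precision; not code from the question):

#include <cmath>

// Baseline: x * tanh(log1p(exp(x))). For large x, exp overflows to +inf,
// which propagates through log1p and tanh to yield x * 1 = x, so the
// positive tail is still well behaved; the cost is one exp, one log1p,
// and one tanh per call.
float mish_reference (float x)
{
    return x * std::tanh (std::log1p (std::exp (x)));
}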

The distribution of inputs spans $(-\infty, \infty)$; it should work everywhere.


Solution

OP points to a particular implementation of the mish activation function for accuracy specifications, so I had to characterize this first. That implementation uses single precision (float), and is stable and accurate in the positive half-plane. In the negative half-plane, because it uses logf instead of log1pf, relative error quickly grows as $x\to-\infty$. Loss of accuracy starts around $-1$, and at $-16.6355324$ the implementation falsely returns $0$, because at that point $\exp(x) \approx 2^{-24}$, so $1 + \exp(x)$ rounds to $1$ and the logarithm evaluates to zero.
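
To illustrate the failure mode (my sketch, not the actual implementation referenced by OP): computing the softplus as logf(1.0f + expf(x)) throws away the information in expf(x) once it is small compared to 1, whereas log1pf(expf(x)) keeps it:

__device__ float my_mishf_naive (float x) // illustrative only
{
    // For x << 0, expf(x) is tiny, so 1.0f + expf(x) rounds toward 1.0f and
    // precision is lost before the logarithm is taken. Once expf(x) drops to
    // about 2^-24 (around x = -16.6355324), the sum rounds to exactly 1.0f,
    // the logarithm is 0, and the function falsely returns 0.
    return x * tanhf (logf (1.0f + expf (x)));
}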

The same accuracy and behavior can be achieved by using a simple mathematical transformation that eliminates $\tanh$, and by considering that GPUs typically offer a fused multiply-add (FMA) as well as a fast reciprocal, both of which one would want to utilize. Exemplary CUDA code looks as follows:

__device__ float my_mishf (float x)
{
    float r;
    float e = expf (x);
    // with e = exp(x): mish(x) = x - 2x/(e*e + 2*e + 2) = x + x/(-0.5*e*e - e - 1)
    r = 1.0f / fmaf (fmaf (-0.5f, e, -1.0f), e, -1.0f); // r = -2/(e*e + 2*e + 2)
    r = fmaf (r, x, x);                                 // r = x + r*x
    return r;
}

As with the reference implementation pointed to by OP, this has excellent accuracy in the positive half-plane, while in the negative half-plane the error increases rapidly, so that at $-16.6355324$ the implementation likewise falsely returns $0$.

If there is a desire to address these accuracy issues, we can apply the following observations. For sufficiently small $x$, $f(x) = x \exp(x)$ to within floating-point accuracy; for float computation this holds for $x < -15$. For the interval $[-15,-1]$, we can use a rational approximation $R(x)$ and compute $f(x) \approx R(x)\,x\exp(x)$. Exemplary CUDA code looks as follows:

__device__ float my_mishf (float x)
{
    float r;
    if (x >= -1.0f) {
        float e = expf (x);
        r = 1.0f / fmaf (fmaf (-0.5f, e, -1.0f), e, -1.0f);
        r = fmaf (r, x, x);
    } else {
        float eh = expf (0.5f * x);
        float p =        1.03628484e-3f;  //  0x1.0fa7e6p-10
        p = fmaf (p, x, -7.28869531e-3f); // -0x1.ddac04p-8
        p = fmaf (p, x,  3.47027816e-2f); //  0x1.1c4902p-5
        p = fmaf (p, x, -3.54762226e-1f); // -0x1.6b46cap-2
        p = fmaf (p, x,  8.58785570e-1f); //  0x1.b7b2bep-1
        p = fmaf (p, x, -1.38065982e+0f); // -0x1.6172ecp+0
        p = fmaf (p, x,  5.97694337e-1f); //  0x1.3204fep-1
        float q =        1.03527203e-3f;  //  0x1.0f63eep-10
        q = fmaf (q, x, -7.35638570e-3f); // -0x1.e21bacp-8
        q = fmaf (q, x,  3.28683928e-2f); //  0x1.0d4204p-5
        q = fmaf (q, x, -3.79927397e-1f); // -0x1.850bb0p-2 
        q = fmaf (q, x,  6.86127126e-1f); //  0x1.5f4c0ep-1
        q = fmaf (q, x, -1.81509292e+0f); // -0x1.d0a9eep+0
        q = fmaf (q, x,  1.00000000e+0f); //  0x1.000000p+0
        r = (1.0f / q) * p;
        if (x < -15.0f) r = 1.0f;
        r = r * x * eh * eh;
    }
    return r;
}

Unfortunately, this accurate solution is achieved at the cost of a significant drop in performance. If one is willing to accept reduced accuracy while still achieving a smoothly decaying left tail, the following interpolation scheme, again based on $f(x) \approx x\exp(x)$, recovers much of the performance:

__device__ float my_mishf (float x)
{
    float r;
    float e = expf (x);
    if (x >= -6.0625f) {
        r = 1.0f / fmaf (fmaf (-0.5f, e, -1.0f), e, -1.0f);
        r = fmaf (r, x, x);
    } else {
        r = fmaf (-0.5f, e, 1.0f); // 1 - e/2, so r*x*e ~= x*exp(x)*(1 - exp(x)/2)
        r = r * x * e;
    }
    return r;
}

As a machine-specific performance enhancement, expf() could be replaced by the device intrinsic __expf().

OTHER TIPS

With some algebraic manipulation (as pointed out in @orlp's answer), we can deduce the following:

$$f(x) = x \tanh(\log(1+e^x)) \tag{1}$$
$$f(x) = x\frac{(1+e^x)^2 - 1}{(1+e^x)^2 + 1} = x\frac{e^{2x} + 2e^x}{e^{2x} + 2e^x + 2}\tag{2}$$
$$f(x) = x - \frac{2x}{(1 + e^x)^2 + 1} \tag{3}$$

Expression $(3)$ works great when $x$ is positive, with very little loss of precision. Expression $(2)$ is not suitable for large values of $x$, since both the numerator and the denominator blow up.

The function $(1)$ asymptotically approaches zero as $x \to-\infty$. As $x$ becomes larger in magnitude on the negative side, expression $(3)$ suffers from catastrophic cancellation: two nearly equal terms cancel to give a very small number. Expression $(2)$ is more suitable in this range.

This works fairly well down to about $-18$, beyond which you lose multiple significant figures.

Let's take a closer look at the function and try to approximate $f(x)$ as $x \to-\infty$.

$$f(x) = x \frac{e^{2x} + 2e^x}{e^{2x} + 2e^x + 2}$$

For very negative $x$, $e^{2x}$ is orders of magnitude smaller than $e^x$, and $e^x$ in turn is orders of magnitude smaller than $1$. Using these two facts, we can approximate $f(x)$ as:

$f(x) \approx x\frac{e^x}{e^x+1}\approx xe^x$

Result:

$f(x) \approx \begin{cases} xe^x, & \text{if $x \le -18$} \\ x\frac{e^{2x} + 2e^x}{e^{2x} + 2e^x + 2} & \text{if $-18 \lt x \le -0.6$} \\ x - \frac{2x}{(1 + e^x)^2 + 1}, & \text{otherwise} \end{cases} $

Fast CUDA implementation:

__device__ float mish(float x)
{
    auto e = __expf(x);
    if (x <= -18.0f)
        return x * e;    

    auto n = e * e + 2 * e;
    if (x <= -0.6f)
        return x * __fdividef(n, n + 2);

    return x - 2 * __fdividef(x, n + 2);
}

EDIT:

An even faster and more accurate version:

$f(x) \approx \begin{cases} x\frac{e^{2x} + 2e^x}{e^{2x} + 2e^x + 2} & \text{$x \le -0.6$} \\ x - \frac{2x}{(1 + e^x)^2 + 1}, & \text{otherwise} \end{cases} $

__device__ float mish(float x)
{
    auto e = __expf(x);
    auto n = e * e + 2 * e;
    if (x <= -0.6f)
        return x * __fdividef(n, n + 2);

    return x - 2 * __fdividef(x, n + 2);
}

Code: https://gist.github.com/YashasSamaga/8ad0cd3b30dbd0eb588c1f4c035db28c

$$\begin{array}{c|c|c|c|}
 & \text{Time (float)} & \text{Time (float4)} & \text{L2 norm of error vector} \\ \hline
\text{mish} & 1.49\,\text{ms} & 1.39\,\text{ms} & 2.4583 \times 10^{-5} \\ \hline
\text{relu} & 1.47\,\text{ms} & 1.39\,\text{ms} & \text{N/A} \\ \hline
\end{array}$$

There's no need to perform the logarithm. If you let $p = 1+\exp(x)$ then we have $f(x) = x\cdot\dfrac{p^2-1}{p^2+1}$ or alternatively $f(x) = x - \dfrac{2x}{p^2+1}$.
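
A minimal single-precision sketch of this rearrangement (my own translation, not code from the answer; as noted above, the subtractive form still cancels catastrophically for very negative $x$):

__device__ float mish_via_p (float x)
{
    float p = 1.0f + expf (x);            // p = 1 + e^x
    return x - 2.0f * x / (p * p + 1.0f); // x * (p*p - 1) / (p*p + 1)
}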

My impression is that someone wanted to multiply $x$ by a function that goes smoothly from 0 to 1, and experimented until they found an expression using elementary functions that did this, with no mathematical reason behind the choice of functions.

After choosing a parameter $t$, let $p_t(x) = 1/2 + \frac{3}{4t}x - \frac{x^3}{4t^3}$; then $p_t(0) = 1/2$, $p_t(t) = 1$, $p_t(-t) = 0$, and $p_t'(t) = p_t'(-t) = 0$. Let $g(x) = 0$ if $x < -t$, $1$ if $x > +t$, and $p_t(x)$ if $-t \le x \le +t$. This is a function that changes smoothly from 0 to 1. Choose another parameter $s$, and instead of $f(x)$ calculate $x \cdot g(x - s)$.

t = 3.0 and s = -0.3 matches the given function quite reasonably and can be calculated an awful lot faster (which seems important). It is different, of course. As this function is used as a tool in some problem, I'd want to see a mathematical reason why the original function is better.
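
A hedged sketch of this piecewise-cubic alternative with the suggested parameters $t = 3.0$ and $s = -0.3$ (my own reading of the description above; it is an approximation to mish, not mish itself):

__device__ float mish_cubic_approx (float x)
{
    const float t = 3.0f;
    const float s = -0.3f;
    float y = x - s;                 // shifted argument for g
    float g;
    if (y < -t) {
        g = 0.0f;
    } else if (y > t) {
        g = 1.0f;
    } else {
        // p_t(y) = 1/2 + (3/(4t)) y - y^3 / (4 t^3)
        g = 0.5f + (0.75f / t) * y - (y * y * y) / (4.0f * t * t * t);
    }
    return x * g;
}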

The context here is computer vision and the activation function for training neural nets.

Chances are this code is going to be executed on a GPU. While performance is going to depend on the distribution of typical inputs, generally speaking it is important to avoid branches in GPU code. Warp divergence can significantly degrade performance of your code. For example, the CUDA Toolkit Documentation says:

Note: High Priority: Avoid different execution paths within the same warp. Flow control instructions (if, switch, do, for, while) can significantly affect the instruction throughput by causing threads of the same warp to diverge; that is, to follow different execution paths. If this happens, the different execution paths must be executed separately; this increases the total number of instructions executed for this warp. ... For branches including just a few instructions, warp divergence generally results in marginal performance losses. For example, the compiler may use predication to avoid an actual branch. Instead, all instructions are scheduled, but a per-thread condition code or predicate controls which threads execute the instructions. Threads with a false predicate do not write results, and also do not evaluate addresses or read operands.

Two branch-free implementations

OP's answer does have short branches so branch predication may happen with some compilers. One other thing I noticed is that it appears to be acceptable to compute the exponential once per call. That is, I understand OP's answer to say one call to the exponential is not "expensive" or "slow".

In that case, I would suggest the following simple code:

__device__ float mish(float x)
{
    float expx = __expf(x);
    return x / (1.0f + 2.0f / (expx * (2.0f + expx)));
}

It has no branches, one exponential, one multiplication, and two divisions. Divisions are often more expensive than multiplications so I also tried out this code:

__device__ float mish(float x)
{
    float expx = __expf(x);
    float psi = expx * (2.0f + expx);
    return x * (psi / (2.0f + psi));
}

This has no branches, one exponential, two multiplications, and one division.

Relative error

I computed the log10 relative accuracy of these two implementations plus OP's answer. I computed over the interval (-100,100) with an increment of 1/1024, then computed a running maximum over 51 values (to reduce the visual clutter but still give the correct impression). Computing the first implementation with double precision suffices as a reference. The exponential is accurate to within one ULP, and there are only a handful of arithmetic operations; the rest of the bits are more than sufficient to make a table maker's dilemma very unlikely. Thus we're very likely to be able to compute correctly rounded single-precision reference values.

[Plot: log10 relative error of the three implementations]

Green: first implementation. Red: second implementation. Blue: OP's implementation. The blue and red curves overlap through most of their range (left of about -20).

Note to OP: you'll want to change the cutoff to greater than -5 if you want to maintain full precision.
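
The scan described above could be reproduced with a small host-side harness along these lines (my sketch, not the code actually used; it takes the first branch-free implementation evaluated in double precision as the reference and omits the running-maximum smoothing used for the plot):

#include <cmath>
#include <cstdio>
#include <algorithm>

// Host copy of the first branch-free implementation (expf instead of __expf).
static float mish_float (float x)
{
    float expx = std::exp (x);
    return x / (1.0f + 2.0f / (expx * (2.0f + expx)));
}

// The same formula evaluated in double precision serves as the reference.
static double mish_double (double x)
{
    double e = std::exp (x);
    return x / (1.0 + 2.0 / (e * (2.0 + e)));
}

int main (void)
{
    double worst = 0.0;
    for (double x = -100.0; x <= 100.0; x += 1.0 / 1024.0) {
        double ref = mish_double (x);
        double tst = (double) mish_float ((float) x);
        if (ref != 0.0) {
            worst = std::max (worst, std::fabs (tst - ref) / std::fabs (ref));
        }
    }
    printf ("max relative error over (-100,100): %.3e (log10 = %.2f)\n",
            worst, std::log10 (worst));
    return 0;
}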

Performance

You'll have to test these two implementations to see which is faster. They should be at least as fast as OP's, and I suspect they'll be much faster because of the lack of branches. However, if they are not fast enough for you, there's more you can do.

An important question:

What is the distribution of typical input values you expect to see? Are values going to be uniformly distributed over the entire range where the function is effectively computable? Or are they going to be clustered around 0 almost all the time? If so, with what variance/spread?

The asymptotics can be improved.

On the left, OP uses x * expx with a cutoff of -18. This cutoff can be increased to about -15.5625 with no loss of precision. With the cost of one extra multiplication, you could use x * expx * (1.0f - 0.5f * expx) and a cutoff of about -4.875. Note: the multiplication by 0.5 can be optimized to a subtraction of 1 from the exponent so I'm not counting that here.

On the right, you can introduce another asymptotic. If x > 8.75, simply return x. With a little more cost, you could do x * (1.0f - 2.0f * __expf(-2.0f * x)) when x > 6.0.
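
As a sketch, those asymptotic refinements could be combined as follows (cutoffs taken from the text above; the branch ordering and its performance impact are my assumptions and untested). The center branch is just the second branch-free implementation; only the tails change:

__device__ float mish_with_asymptotics (float x)
{
    if (x > 8.75f)
        return x;                                      // right tail: f(x) ~= x
    if (x > 6.0f)
        return x * (1.0f - 2.0f * __expf (-2.0f * x)); // refined right tail
    float expx = __expf (x);
    if (x < -4.875f)
        return x * expx * (1.0f - 0.5f * expx);        // refined left tail
    float psi = expx * (2.0f + expx);                  // center: exact rearrangement
    return x * (psi / (2.0f + psi));
}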

Interpolation

For the center part of the range (-4.875, 6.0), you can use a table of interpolants. If the subintervals are equally spaced, you can use one division to compute a direct index into the table (without branching). Calculating such a table would take some effort, but depending on your needs it may be worth it: a handful of multiplies and adds might be less expensive than the exponential. That said, the implementers of the exponential in the library have probably spent a lot of time and effort getting theirs correct and fast. Also, the "mish" function doesn't present any opportunities for range reduction.
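
A hypothetical sketch of what such a table lookup could look like (the table size, range, and coefficient arrays here are placeholders; they would have to be generated offline, for example by fitting a linear piece to mish on each subinterval):

#define MISH_TABLE_SIZE 64

// Hypothetical per-interval coefficients for piecewise-linear interpolation
// of mish over (-4.875, 6.0); contents would be precomputed offline.
__constant__ float mish_c0[MISH_TABLE_SIZE];   // intercepts
__constant__ float mish_c1[MISH_TABLE_SIZE];   // slopes

__device__ float mish_center_table (float x)
{
    const float lo = -4.875f, hi = 6.0f;
    float t = (x - lo) * (MISH_TABLE_SIZE / (hi - lo));  // scaled position
    int i = min (max ((int) t, 0), MISH_TABLE_SIZE - 1); // clamped table index
    return fmaf (mish_c1[i], x, mish_c0[i]);             // c1*x + c0
}

Whether such a lookup actually beats a single hardware __expf() would have to be measured; constant-memory access patterns matter here as well.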

Licensed under: CC-BY-SA with attribution
Not affiliated with cs.stackexchange