Pregunta

$$f(x) = x anh(\log(1 + e^x))$$

La función (activación Mish) se puede implementar fácilmente utilizando un log1pexp estable sin ninguna pérdida significativa de precisión.Desafortunadamente, esto es computacionalmente pesado.

¿Es posible escribir una implementación numéricamente estable más directa que sea más rápida?

Precisión tan buena como x * std::tanh(std::log1p(std::exp(x))) sería bueno.No existen restricciones estrictas, pero debe ser razonablemente preciso para su uso en redes neuronales.

La distribución de insumos es de $[-\infty, \infty]$.Debería funcionar en todas partes.

¿Fue útil?

Solución

OP señala un particular implementación del mish función de activación para especificaciones de precisión, así que tuve que caracterizar esto primero.Esa implementación utiliza precisión simple (float), y es estable y preciso en el semiplano positivo.En el semiplano negativo, porque utiliza logf en lugar de log1pf, el error relativo crece rápidamente $x\a-\infty$.La pérdida de precisión comienza alrededor $-1$ y ya en $-16.6355324$ la implementación regresa falsamente $0$, porque $\exp(-16.6355324) = 2^{-24}$.

Se puede lograr la misma precisión y comportamiento utilizando una transformación matemática simple que elimina $\mathrm{tahn}$, y considerando que las GPU generalmente ofrecen una suma múltiple fusionada (FMA), así como una recíproca rápida, cuál uno querría utilizar.El código CUDA de ejemplo tiene el siguiente aspecto:

__device__ float my_mishf (float x)
{
    float r;
    float e = expf (x);
    r = 1.0f / fmaf (fmaf (-0.5f, e, -1.0f), e, -1.0f);
    r = fmaf (r, x, x);
    return r;
}

Al igual que con la implementación de referencia señalada por OP, esto tiene una precisión excelente en el semiplano positivo, y en el semiplano negativo el error aumenta rápidamente, por lo que en $-16.6355324$ la implementación regresa falsamente $0$.

Si deseamos abordar estos problemas de precisión, podemos aplicar las siguientes observaciones.Para suficientemente pequeño $x$, $f(x) = x\exp(x)$ con una precisión de coma flotante.Para float cálculo esto es válido para $x<-15$.para el intervalo $[-15,-1]$, podemos usar una aproximación racional $R(x)$ computar $f(x) := R(x)x\exp(x)$.El código CUDA de ejemplo tiene el siguiente aspecto:

__device__ float my_mishf (float x)
{
    float r;
    if (x >= -1.0f) {
        float e = expf (x);
        r = 1.0f / fmaf (fmaf (-0.5f, e, -1.0f), e, -1.0f);
        r = fmaf (r, x, x);
    } else {
        float eh = expf (0.5f * x);
        float p =        1.03628484e-3f;  //  0x1.0fa7e6p-10
        p = fmaf (p, x, -7.28869531e-3f); // -0x1.ddac04p-8
        p = fmaf (p, x,  3.47027816e-2f); //  0x1.1c4902p-5
        p = fmaf (p, x, -3.54762226e-1f); // -0x1.6b46cap-2
        p = fmaf (p, x,  8.58785570e-1f); //  0x1.b7b2bep-1
        p = fmaf (p, x, -1.38065982e+0f); // -0x1.6172ecp+0
        p = fmaf (p, x,  5.97694337e-1f); //  0x1.3204fep-1
        float q =        1.03527203e-3f;  //  0x1.0f63eep-10
        q = fmaf (q, x, -7.35638570e-3f); // -0x1.e21bacp-8
        q = fmaf (q, x,  3.28683928e-2f); //  0x1.0d4204p-5
        q = fmaf (q, x, -3.79927397e-1f); // -0x1.850bb0p-2 
        q = fmaf (q, x,  6.86127126e-1f); //  0x1.5f4c0ep-1
        q = fmaf (q, x, -1.81509292e+0f); // -0x1.d0a9eep+0
        q = fmaf (q, x,  1.00000000e+0f); //  0x1.000000p+0
        r = (1.0f / q) * p;
        if (x < -15.0f) r = 1.0f;
        r = r * x * eh * eh;
    }
    return r;
}

Desafortunadamente, esta solución precisa se logra a costa de una caída significativa en el rendimiento.Si uno está dispuesto a aceptar una precisión reducida y al mismo tiempo lograr una cola izquierda que decae suavemente, el siguiente esquema de interpolación, nuevamente basado en $f(x) \aprox x\exp(x)$, recupera gran parte del rendimiento:

__device__ float my_mishf (float x)
{
    float r;
    float e = expf (x);
    if (x >= -6.0625f) {
        r = 1.0f / fmaf (fmaf (-0.5f, e, -1.0f), e, -1.0f);
        r = fmaf (r, x, x);
    } else {
        r = fmaf (-0.5f, e, 1.0f);
        r = r * x * e;
    }
    return r;
}

Como mejora del rendimiento específica de la máquina, expf() podría ser reemplazado por el dispositivo intrínseco __expf().

Otros consejos

Con alguna manipulación algebraica (como se señala en la respuesta de @orlp), podemos deducir lo siguiente:

$$f(x) = x anh(\log(1+e^x)) ag{1}$$ $$ = x\frac{(1+e^x)^2 - 1}{(1+e^x)^2 + 1} = x\frac{e^{2x} + 2e^x}{e^ {2x} + 2e^x + 2}\etiqueta{2}$$ $$ = x - \frac{2x}{(1 + e^x)^2 + 1} ag{3}$$

Expresión $(3)$ funciona muy bien cuando $x$ es negativo con muy poca pérdida de precisión.Expresión $(2)$ no es adecuado para grandes valores de $x$ ya que los términos van a explotar tanto en el numerador como en el denominador.

La función $(1)$ asintóticamente llega a cero como $x \a-\infty$.No fue $x$ aumenta en magnitud, la expresión $(3)$ sufrirá una cancelación catastrófica:dos términos grandes que se anulan entre sí para dar un número realmente pequeño.La expresion $(2)$ es más adecuado en este rango.

Esto funciona bastante bien hasta $-18$ y más allá del cual se pierden múltiples cifras significativas.

Echemos un vistazo más de cerca a la función e intentemos aproximarnos $f(x)$ como $x \a-\infty$.

$$f(x) = x \frac{e^{2x} + 2e^x}{e^{2x} + 2e^x + 2}$$

El $mi^{2x}$ será órdenes de magnitud menor que $e^x$. $e^x$ será órdenes de magnitud menor que $1$.Usando estos dos hechos, podemos aproximar $f(x)$ a:

$f(x) \aprox x\frac{e^x}{e^x+1}\aprox xe^x$

Resultado:

$ f (x) aprox begin {casos} xe^x, & text {if $ x le -18 $} x frac {e^{2x} + 2e^x} {e^{2x } + 2e^x + 2} & text {if $ -18 lt x le -0.6 $} x - frac {2x} {(1 + e^x)^2 + 1}, & texto {de lo contrario} end {casos} $

Implementación rápida de CUDA:

__device__ float mish(float x)
{
    auto e = __expf(x);
    if (x <= -18.0f)
        return x * e;    

    auto n = e * e + 2 * e;
    if (x <= -0.6f)
        return x * __fdividef(n, n + 2);

    return x - 2 * __fdividef(x, n + 2);
}

EDITAR:

Una versión aún más rápida y precisa:

$ f (x) aprox begin {casos} x frac {e^{2x} + 2e^x} {e^{2x} + 2e^x + 2} & text {$ x le -0.6 $ } x - frac {2x} {(1 + e^x)^2 + 1}, & text {de lo contrario} end {casos} $

__device__ float mish(float x)
{
    auto e = __expf(value);
    auto n = e * e + 2 * e;
    if (value <= -0.6f)
        return value * __fdividef(n, n + 2);

    return value - 2 * __fdividef(value, n + 2);
}

Código: https://gist.github.com/YashasSamaga/8ad0cd3b30dbd0eb588c1f4c035db28c

$$ begin {array} {c | c | c | c |} & text {time (float)} & text {time (float4)} & text {l2 norma de error vector} hline Texto {Mish} & 1.49ms & 1.39ms & 2.4583e-05 Hline text {relu} & 1.47ms & 1.39ms & text {n/a} hline end {matriz} $$

No es necesario realizar el logaritmo.Si tu dejas $p = 1+\exp(x)$ entonces nosotros tenemos$f(x) = x\cdot\dfrac{p^2-1}{p^2+1}$ o alternativamente $f(x) = x - \dfrac{2x}{p^2+1}$.

Mi impresión es que alguien quería multiplicar x por una función f(x) que va suavemente de 0 a 1, y experimentó hasta encontrar una expresión usando funciones elementales que hiciera esto, sin ninguna razón matemática detrás de la elección de las funciones.

Después de elegir un parámetro t, sea $p_t(x) = 1/2 + (3 / 4t)x - x^3 / (4t^3)$, entonces $p_t(0) = 1/2$, $p_t(t) = 1$, $p_t(-t) = 0$, y $p_t'(t) = p_t'(-t) = 0$.Sea g(x) = 0 si x < -t, 1 si x > +1, y $p_t(x)$ si -t ≤ x ≤ +t.Esta es una función que cambia suavemente de 0 a 1.Elija otro parámetro s y, en lugar de f(x), calcule x * g (x - s).

t = 3,0 y s = -0,3 coincide bastante razonablemente con la función dada y se calcula muchísimo más rápido (lo que parece importante).Es diferente por supuesto.Como esta función se usa como herramienta en algún problema, me gustaría ver una razón matemática por la cual la función original es mejor.

El contexto aquí es la visión por computadora y la función de activación para entrenar redes neuronales.

Lo más probable es que este código se ejecute en una GPU.Si bien el rendimiento dependerá de la distribución de los insumos típicos, En términos generales, es importante evitar bifurcaciones en el código de la GPU..La divergencia de deformación puede degradar significativamente el rendimiento de su código.Por ejemplo, el Documentación del kit de herramientas CUDA dice:

Nota: Alta prioridad:Evite diferentes rutas de ejecución dentro del mismo warp.Las instrucciones de control de flujo (if, switch, do, for, while) pueden afectar significativamente el rendimiento de las instrucciones al provocar que los hilos de la misma deformación diverjan;es decir, seguir diferentes caminos de ejecución.Si esto sucede, las diferentes rutas de ejecución deben ejecutarse por separado;esto aumenta el número total de instrucciones ejecutadas para esta deformación....Para ramas que incluyen sólo unas pocas instrucciones, la divergencia de deformación generalmente resulta en pérdidas marginales de rendimiento.Por ejemplo, el compilador puede utilizar predicación para evitar una rama real.En cambio, todas las instrucciones están programadas, pero un código de condición o predicado por subproceso controla qué subprocesos ejecutan las instrucciones.Los subprocesos con un predicado falso no escriben resultados y tampoco evalúan direcciones ni leen operandos.

Dos implementaciones sin sucursales

La respuesta del OP tiene ramas cortas, por lo que la predicción de ramas puede ocurrir con algunos compiladores.Otra cosa que noté es que parece aceptable calcular el exponencial una vez por llamada.Es decir, entiendo la respuesta de OP de decir que una llamada al exponencial no es "cara" ni "lenta".

En ese caso, sugeriría el siguiente código simple:

__device__ float mish(float x)
{
    float expx = __expf(x);
    return x / (1.0f + 2.0f / (expx * (2.0f + expx)));
}

No tiene ramas, una exponencial, una multiplicación y dos divisiones.Las divisiones suelen ser más caras que las multiplicaciones, así que también probé este código:

__device__ float mish(float x)
{
    float expx = __expf(x);
    float psi = expx * (2.0f + expx);
    return x * (psi / (2.0f + psi));
}

Este no tiene ramas, una exponencial, dos multiplicaciones y una división.

Error relativo

Calculé la precisión relativa log10 de estas dos implementaciones más la respuesta de OP.Calculé durante el intervalo (-100,100) con un incremento de 1/1024, luego calculé un máximo acumulado sobre 51 valores (para reducir el desorden visual pero aún así dar la impresión correcta).Calcular la primera implementación con doble precisión es suficiente como referencia.La exponencial tiene una precisión de un ULP y solo hay un puñado de operaciones aritméticas;el resto de los bits son más que suficientes para hacer que el dilema del fabricante de mesas sea muy improbable.Por lo tanto, es muy probable que podamos calcular valores de referencia de precisión simple redondeados correctamente.

Log10 relative error

Verde:primera implementación.Rojo:Segunda implementación.Azul:Implementación del OP.El azul y el rojo se superponen en la mayor parte de su rango (a la izquierda de aproximadamente -20).

Nota para el OP:querrás cambiar el límite a mayor que -5 si quieres mantener la precisión total.

Actuación

Tendrás que probar estas dos implementaciones para ver cuál es más rápida.Deberían ser al menos tan rápidos como los OP, y sospecho que serán mucho más rápidos debido a la falta de ramas.Sin embargo, si no son lo suficientemente rápidos para usted, puede hacer más.

Una pregunta importante:

¿Cuál es la distribución de valores de entrada típicos que espera ver?¿Los valores se distribuirán uniformemente en todo el rango en el que la función es efectivamente computable?¿O estarán agrupados alrededor de 0 casi todo el tiempo?Si es así, ¿con qué variación/diferencial?

Las asintóticas se pueden mejorar.

A la izquierda, OP usa x * expx con un límite de -18.Este límite se puede aumentar hasta aproximadamente -15,5625 sin pérdida de precisión.Con el costo de una multiplicación extra, podrías usar x * expx * (1.0f - 0.5f * expx) y un límite de aproximadamente -4,875.Nota:la multiplicación por 0,5 se puede optimizar para restar 1 del exponente, así que no lo cuento aquí.

A la derecha, puedes introducir otro asintótico.Si x > 8.75, simplemente return x.Con un poco más de costo, podrías hacer x * (1.0f - 2.0f * __expf(-2.0f * x)) cuando x > 6.0.

Interpolación

Para la parte central del rango (-4,875, 6,0), puede utilizar una tabla de interpoladores.Si sus rangos están igualmente espaciados, puede usar una división para calcular un índice directo en la tabla (sin ramificar).Calcular una tabla de este tipo requeriría algo de esfuerzo, pero dependiendo de tus necesidades puede valer la pena:un puñado de multiplicaciones y sumas podría ser menos costoso que el exponencial.Dicho esto, los implementadores del exponencial en la biblioteca probablemente hayan dedicado mucho tiempo y esfuerzo a lograr que el suyo sea correcto y rápido.Además, la función "mish" no presenta ninguna oportunidad de reducir el alcance.

Licenciado bajo: CC-BY-SA con atribución
No afiliado a cs.stackexchange
scroll top