Question

$$f(x) = x anh(\log(1 + e^x))$$

La fonction (mish activation) peut être facilement mis en œuvre à l'aide d'un stable log1pexp sans aucune perte de précision.Malheureusement, c'est mathématiquement lourde.

Est-il possible d'écrire un rapport plus direct numériquement stable de mise en œuvre qui est le plus rapide?

Une précision aussi bonne que x * std::tanh(std::log1p(std::exp(x))) serait bien.Il n'y a pas de strictes contraintes, mais il doit être suffisamment précis pour une utilisation dans les réseaux de neurones.

De la distribution des intrants est de $[-\infty, \infty]$.Il devrait fonctionner partout.

Était-ce utile?

La solution

OP points pour un particulier la mise en œuvre de la mish fonction d'activation pour les spécifications de précision, j'ai donc eu pour caractériser cette première.Que la mise en œuvre utilise simple précision (float), et est stable et précise dans le positif demi-plan.Dans la négative demi-plan, car il utilise logf au lieu de log1pf, l'erreur relative grandit vite un $x o-\infty$.La perte de précision commence autour de $-1$ et déjà à l' $-16.6355324$ la mise en œuvre faussement retourne $0$, parce que $\exp(-16.6355324) = 2^{-24}$.

La même précision et le comportement peut être obtenue en utilisant une simple transformation mathématique qui élimine $\mathrm{tahn}$, et considérant que les Gpu offrent généralement un fused multiply-add (FMA), ainsi qu'un rapide réciproque, que l'on voudrait utiliser.Exemplaire CUDA code se présente comme suit:

__device__ float my_mishf (float x)
{
    float r;
    float e = expf (x);
    r = 1.0f / fmaf (fmaf (-0.5f, e, -1.0f), e, -1.0f);
    r = fmaf (r, x, x);
    return r;
}

Comme avec l'implémentation de référence indiqué par l'OP, ce qui a une excellente précision dans le positif demi-plan, et par la négative à la moitié du plan d'erreur augmente rapidement de manière à $-16.6355324$ la mise en œuvre faussement retourne $0$.

Si il ya un désir de répondre à ces questions relatives à l'exactitude, nous pouvons appliquer les observations suivantes.Pour suffisamment petit $x$, $f(x) = x \exp(x)$ à l'intérieur de virgule flottante de précision.Pour float le calcul de ce secteur pour $x < -15$.Pour l'intervalle $[-15,-1]$, on peut utiliser une approximation rationnelle $R(x)$ pour calculer $f(x) := R(x)x\exp(x)$.Exemplaire CUDA code se présente comme suit:

__device__ float my_mishf (float x)
{
    float r;
    if (x >= -1.0f) {
        float e = expf (x);
        r = 1.0f / fmaf (fmaf (-0.5f, e, -1.0f), e, -1.0f);
        r = fmaf (r, x, x);
    } else {
        float eh = expf (0.5f * x);
        float p =        1.03628484e-3f;  //  0x1.0fa7e6p-10
        p = fmaf (p, x, -7.28869531e-3f); // -0x1.ddac04p-8
        p = fmaf (p, x,  3.47027816e-2f); //  0x1.1c4902p-5
        p = fmaf (p, x, -3.54762226e-1f); // -0x1.6b46cap-2
        p = fmaf (p, x,  8.58785570e-1f); //  0x1.b7b2bep-1
        p = fmaf (p, x, -1.38065982e+0f); // -0x1.6172ecp+0
        p = fmaf (p, x,  5.97694337e-1f); //  0x1.3204fep-1
        float q =        1.03527203e-3f;  //  0x1.0f63eep-10
        q = fmaf (q, x, -7.35638570e-3f); // -0x1.e21bacp-8
        q = fmaf (q, x,  3.28683928e-2f); //  0x1.0d4204p-5
        q = fmaf (q, x, -3.79927397e-1f); // -0x1.850bb0p-2 
        q = fmaf (q, x,  6.86127126e-1f); //  0x1.5f4c0ep-1
        q = fmaf (q, x, -1.81509292e+0f); // -0x1.d0a9eep+0
        q = fmaf (q, x,  1.00000000e+0f); //  0x1.000000p+0
        r = (1.0f / q) * p;
        if (x < -15.0f) r = 1.0f;
        r = r * x * eh * eh;
    }
    return r;
}

Malheureusement, cette solution exacte est obtenue au prix d'une baisse significative des performances.Si l'on est disposé à accepter une réduction de la précision tout de même d'atteindre une douceur décomposition gauche de la queue, le suivant schéma d'interpolation, sur la base $f(x) \approx x\exp(x)$, récupère une grande partie de la performance:

__device__ float my_mishf (float x)
{
    float r;
    float e = expf (x);
    if (x >= -6.0625f) {
        r = 1.0f / fmaf (fmaf (-0.5f, e, -1.0f), e, -1.0f);
        r = fmaf (r, x, x);
    } else {
        r = fmaf (-0.5f, e, 1.0f);
        r = r * x * e;
    }
    return r;
}

Comme une machine spécifique à l'amélioration de la performance, expf() pourrait être remplacé par le dispositif intrinsèque __expf().

Autres conseils

Avec une certaine manipulation algébrique (comme l'a souligné @orlp de réponse), on peut en déduire les éléments suivants:

$$f(x) = x anh(\log(1+e^x)) ag{1}$$ $$ = x\frac{(1+e^x)^2 - 1}{(1+e^x)^2 + 1} = x\frac{e^{2x} + 2e^x}{e^{2x} + 2e^x + 2} ag{2}$$ $$ = x - \frac{2x}{(1 + e^x)^2 + 1} ag{3}$$

Expression $(3)$ fonctionne très bien lorsque $x$ est négatif avec très peu de perte de précision.Expression $(2)$ n'est pas adapté pour les grandes valeurs de $x$ puisque les conditions sont sur le point d'exploser à la fois au numérateur et au dénominateur.

La fonction $(1)$ asymptotiquement atteint zéro, comme $x o-\infty$.Maintenant que $x$ devient de plus grande ampleur, l'expression $(3)$ à souffrir de catastrophique d'annulation:deux grandes modalités annuler les uns les autres pour donner un très petit nombre.L'expression $(2)$ est plus approprié dans cette gamme.

Cela fonctionne assez bien jusqu'à ce que $-18$ et au-delà de laquelle vous perdez plusieurs chiffres significatifs.

Prenons regarder de plus près la fonction et essayer de rapprocher $f(x)$ comme $x o-\infty$.

$$f(x) = x - \frac{e^{2x} + 2e^x}{e^{2x} + 2e^x + 2}$$

L' $e^{2x}$ sera ordres de grandeur plus petite que $e^x$. $e^x$ sera ordres de grandeur plus petite que $1$.À l'aide de ces deux faits, nous pouvons rapprocher $f(x)$ pour:

$f(x) \approx x\frac{e^x}{e^x+1}\approx xe^x$

Résultat:

$f(x) \approx \begin{cas} xe^x & ext{si $x \le -18$} \\ x\frac{e^{2x} + 2e^x}{e^{2x} + 2e^x + 2} & ext{si $-18 \lt x \le -0.6$} \\ x - \frac{2x}{(1 + e^x)^2 + 1}, & ext{sinon} \end{cas} $

Rapide CUDA mise en œuvre:

__device__ float mish(float x)
{
    auto e = __expf(x);
    if (x <= -18.0f)
        return x * e;    

    auto n = e * e + 2 * e;
    if (x <= -0.6f)
        return x * __fdividef(n, n + 2);

    return x - 2 * __fdividef(x, n + 2);
}

EDIT:

Encore plus rapide et précis version:

$f(x) \approx \begin{cas} x\frac{e^{2x} + 2e^x}{e^{2x} + 2e^x + 2} & ext{$x \le -0.6$} \\ x - \frac{2x}{(1 + e^x)^2 + 1}, & ext{sinon} \end{cas} $

__device__ float mish(float x)
{
    auto e = __expf(value);
    auto n = e * e + 2 * e;
    if (value <= -0.6f)
        return value * __fdividef(n, n + 2);

    return value - 2 * __fdividef(value, n + 2);
}

Code: https://gist.github.com/YashasSamaga/8ad0cd3b30dbd0eb588c1f4c035db28c

$$\begin{array}{c|c|c|c|} & ext{Temps (float)} & ext{Temps (float4)} & ext{L2 norme du vecteur d'erreur} \\ \hline ext{méli} & 1.49 ms & 1.39 ms & 2.4583 e-05 \\ \hline ext{relu} & 1.47 ms & 1.39 ms & ext{N/A} \\ \hline \end{array}$$

Il n'y a pas besoin d'effectuer le logarithme.Si vous laissez $p = 1+\exp(x)$ ensuite, nous avons $f(x) = x\cdot\dfrac{p^2-1}{p^2+1}$ ou sinon $f(x) = x - \dfrac{2x}{p^2+1}$.

Mon impression est que quelqu'un a voulu multiplier x par une fonction f(x) qui va en douceur, de 0 à 1, et expérimenté jusqu'à ce qu'ils trouvent une expression à l'aide de fonctions élémentaires qui l'a fait, sans raison mathématique derrière le choix des fonctions.

Après le choix d'un paramètre t, laissez $p_t(x) = 1/2 + (3 / 4t)x - x^3 / (4t^3)$, puis $p_t(0) = 1/2$, $p_t(t) = 1$, $p_t(-t) = 0$, et $p_t'(t) = p_t'(-t) = 0$.Soit g(x) = 0 si x < -t, 1 si x > +1, et $p_t(x)$ si -t ≤ x ≤ +t.C'est une fonction que douceur des changements de 0 à 1.Choisir un autre paramètre s, et au lieu de f(x) calculer x * g (x - s).

t = 3.0 et s = -0.3 correspond à la fonction donnée tout à fait raisonnable, et elle est calculée beaucoup plus rapide (ce qui semble important).C'est différent bien sûr.Comme cette fonction est utilisée comme un outil de problème, je veux voir un mathemtical raison pour laquelle la fonction d'origine est mieux.

Le contexte ici est la vision par ordinateur et de la fonction d'activation pour la formation des réseaux de neurones.

Les Chances sont ce code va être exécuté sur un GPU.Alors que la performance va dépendre de la répartition des entrées typiques, généralement parlant, il est important d'éviter les branches de code GPU.Warp divergence peut considérablement dégrader les performances de votre code.Par exemple, l' CUDA Toolkit Documentation dit:

Remarque:Haute Priorité:Éviter les différents chemins d'exécution dans la même courbure.Flux des instructions de contrôle (if, switch, n', for, while) peuvent affecter de manière significative l'instruction de débit en causant des threads d'un même warp à diverger;c'est-à-suivre les différents chemins d'exécution.Si cela se produit, les différents chemins d'exécution doit être exécutée séparément;cela augmente le nombre total d'instructions exécutées pour le compte de cette chaine....Pour les branches, y compris quelques instructions, de la chaîne de divergence entraîne généralement des pertes de rendement marginal.Par exemple, le compilateur peut utiliser la prédication pour éviter une branche.Au lieu de cela, toutes les instructions sont prévues, mais par thread condition de code ou de prédicat contrôles qui threads d'exécuter les instructions.Des discussions avec un faux prédicat ne pas écrire les résultats, et aussi ne pas évaluer adresses ou de lire des opérandes.

Deux de la branche d'implémentations libres

OP répondre n'ont de courtes branches afin de direction de la prévision peut se produire avec certains compilateurs.Une autre chose que j'ai remarqué c'est qu'il semble être acceptable pour calculer l'exponentielle d'une fois par appel.C'est, je comprends OP réponse-à-dire un appel à l'exponentielle n'est pas "cher" ou "lent".

Dans ce cas, je vous suggère le code simple suivant:

__device__ float mish(float x)
{
    float expx = __expf(x);
    return x / (1.0f + 2.0f / (expx * (2.0f + expx)));
}

Il n'a pas de branches, l'une exponentielle, une multiplication, et les deux divisions.Les Divisions sont souvent plus cher que les multiplications donc j'ai essayé ce code:

__device__ float mish(float x)
{
    float expx = __expf(x);
    float psi = expx * (2.0f + expx);
    return x * (psi / (2.0f + psi));
}

Cela n'a pas de branches, l'une exponentielle, deux multiplications et une division.

L'erreur Relative

J'ai calculé la log10 précision relative de ces deux implémentations plus OP de réponse.J'ai calculé l'intervalle (-100,100) avec un incrément de 1/1024, alors calculée à l'exécution d'un maximum de plus de 51 valeurs (afin de réduire l'encombrement visuel, mais encore de donner l'impression correcte).Le calcul de la première mise en œuvre avec la double précision suffit comme référence.L'exponentielle est précis à l'intérieur d'une pratique de travail déloyale, et il y a seulement une poignée d'opérations arithmétiques;le reste des bits sont plus que suffisants pour faire un dilemme du fabricant de tables très peu probable.Ainsi, nous sommes très susceptibles d'être en mesure de calculer correctement arrondie simple précision des valeurs de référence.

Log10 relative error

Vert:première mise en œuvre.Rouge:deuxième mise en oeuvre.Bleu:OP de mise en œuvre.Le bleu et le rouge se chevauchent plus grande partie de leur gamme (à gauche de -20).

Note à l'OP:vous aurez envie de changer la fréquence de coupure supérieure à -5 si vous souhaitez conserver l'intégralité de la précision.

Performance

Vous aurez à tester ces deux implémentations pour voir qui est plus rapide.Ils devraient être au moins aussi rapide que l'OP, et je suppose qu'ils vont être beaucoup plus rapide en raison de l'absence de branches.Toutefois, si elles ne sont pas assez rapide pour vous, il ya plus que vous pouvez faire.

Une question importante:

Quelle est la répartition typique des valeurs d'entrée vous vous attendez à voir?Sont des valeurs va être distribués de manière uniforme sur l'ensemble de la gamme, la fonction est effectivement calculable?Ou vont-ils être regroupés autour de 0 presque tout le temps?Si oui, avec quel écart/propagation?

L'asymptotique peut être amélioré.

Sur la gauche, OP utilise x * expx avec une coupure de -18.Cette coupure peut être porté à environ -15.5625 sans perte de précision.Avec le coût d'une multiplication, vous pouvez utiliser x * expx * (1.0f - 0.5f * expx) et une coupure d'environ -4.875.Note:la multiplication de 0,5 peuvent être optimisés pour une soustraction de 1 à partir de l'exposant donc je ne suis pas compter qu'ici.

Sur la droite, vous pouvez introduire un autre asymptotique.Si x > 8.75, tout simplement return x.Avec un peu plus de coût, vous pouvez le faire x * (1.0f - 2.0f * __expf(-2.0f * x)) lorsque x > 6.0.

Interpolation

Pour la partie centrale de la plage (-4.875, 6.0), vous pouvez utiliser une table de pour les interpolants.Si leurs gammes sont régulièrement espacés, vous pouvez utiliser l'une division pour calculer un indice direct dans la table (sans branchement).Le calcul d'un tel tableau serait prendre un certain effort, mais en fonction de vos besoins peuvent être en vaut la peine:une poignée de multiplie les et ajoute peut être moins cher que l'exponentielle.Cela dit, la mise en œuvre de l'exponentielle dans la bibliothèque probablement avez passé beaucoup de temps et d'efforts à obtenir leur correcte et rapide.Aussi, le "méli" la fonction n'est pas de présenter toutes les possibilités de réduction de la portée.

Licencié sous: CC-BY-SA avec attribution
Non affilié à cs.stackexchange
scroll top