Pergunta

$$f(x) = x anh(\log(1 + e^x))$$

A função (ativação mish) pode ser facilmente implementada usando um log1pexp estável sem qualquer perda significativa de precisão.Infelizmente, isso é computacionalmente pesado.

É possível escrever uma implementação numericamente estável mais direta e mais rápida?

Precisão tão boa quanto x * std::tanh(std::log1p(std::exp(x))) seria bom.Não há restrições estritas, mas deve ser razoavelmente preciso para uso em redes neurais.

A distribuição de insumos é de $[-\infty, \infty]$.Deve funcionar em qualquer lugar.

Foi útil?

Solução

OP aponta para um determinado implementação do mish função de ativação para especificações de precisão, então tive que caracterizá-la primeiro.Essa implementação usa precisão única (float) e é estável e preciso no semiplano positivo.No semiplano negativo, porque usa logf em vez de log1pf, o erro relativo cresce rapidamente $x o-\infty$.A perda de precisão começa por volta $-1$ e já em $-16.6355324$ a implementação retorna falsamente $0$, porque $\exp(-16,6355324) = 2^{-24}$.

A mesma precisão e comportamento podem ser alcançados usando uma transformação matemática simples que elimina $\mathrm{tahn}$, e considerando que as GPUs normalmente oferecem uma adição múltipla fundida (FMA), bem como uma recíproca rápida, que seria desejável utilizar.O código CUDA exemplar é o seguinte:

__device__ float my_mishf (float x)
{
    float r;
    float e = expf (x);
    r = 1.0f / fmaf (fmaf (-0.5f, e, -1.0f), e, -1.0f);
    r = fmaf (r, x, x);
    return r;
}

Tal como acontece com a implementação de referência apontada pelo OP, esta tem excelente precisão no semiplano positivo e no semiplano negativo o erro aumenta rapidamente, de modo que em $-16.6355324$ a implementação retorna falsamente $0$.

Se desejarmos abordar essas questões de precisão, podemos aplicar as seguintes observações.Para suficientemente pequeno $x$, $f(x) = x\exp(x)$ com precisão de ponto flutuante.Para float cálculo isso vale para $ x < -15$.Para o intervalo $[-15,-1]$, podemos usar uma aproximação racional $R(x)$ calcular $f(x) := R(x)x\exp(x)$.O código CUDA exemplar é o seguinte:

__device__ float my_mishf (float x)
{
    float r;
    if (x >= -1.0f) {
        float e = expf (x);
        r = 1.0f / fmaf (fmaf (-0.5f, e, -1.0f), e, -1.0f);
        r = fmaf (r, x, x);
    } else {
        float eh = expf (0.5f * x);
        float p =        1.03628484e-3f;  //  0x1.0fa7e6p-10
        p = fmaf (p, x, -7.28869531e-3f); // -0x1.ddac04p-8
        p = fmaf (p, x,  3.47027816e-2f); //  0x1.1c4902p-5
        p = fmaf (p, x, -3.54762226e-1f); // -0x1.6b46cap-2
        p = fmaf (p, x,  8.58785570e-1f); //  0x1.b7b2bep-1
        p = fmaf (p, x, -1.38065982e+0f); // -0x1.6172ecp+0
        p = fmaf (p, x,  5.97694337e-1f); //  0x1.3204fep-1
        float q =        1.03527203e-3f;  //  0x1.0f63eep-10
        q = fmaf (q, x, -7.35638570e-3f); // -0x1.e21bacp-8
        q = fmaf (q, x,  3.28683928e-2f); //  0x1.0d4204p-5
        q = fmaf (q, x, -3.79927397e-1f); // -0x1.850bb0p-2 
        q = fmaf (q, x,  6.86127126e-1f); //  0x1.5f4c0ep-1
        q = fmaf (q, x, -1.81509292e+0f); // -0x1.d0a9eep+0
        q = fmaf (q, x,  1.00000000e+0f); //  0x1.000000p+0
        r = (1.0f / q) * p;
        if (x < -15.0f) r = 1.0f;
        r = r * x * eh * eh;
    }
    return r;
}

Infelizmente, esta solução precisa é alcançada ao custo de uma queda significativa no desempenho.Se alguém estiver disposto a aceitar uma precisão reduzida e ao mesmo tempo obter uma cauda esquerda com decaimento suave, o seguinte esquema de interpolação, novamente baseado em $f(x) \aprox x\exp(x)$, recupera grande parte do desempenho:

__device__ float my_mishf (float x)
{
    float r;
    float e = expf (x);
    if (x >= -6.0625f) {
        r = 1.0f / fmaf (fmaf (-0.5f, e, -1.0f), e, -1.0f);
        r = fmaf (r, x, x);
    } else {
        r = fmaf (-0.5f, e, 1.0f);
        r = r * x * e;
    }
    return r;
}

Como um aprimoramento de desempenho específico da máquina, expf() poderia ser substituído pelo dispositivo intrínseco __expf().

Outras dicas

Com alguma manipulação algébrica (conforme apontado na resposta do @orlp), podemos deduzir o seguinte:

$$f(x) = x anh(\log(1+e^x)) ag{1}$$ $$ = x\frac{(1+e^x)^2 - 1}{(1+e^x)^2 + 1} = x\frac{e^{2x} + 2e^x}{e^ {2x} + 2e^x + 2} ag{2}$$ $$ = x - \frac{2x}{(1 + e^x)^2 + 1} ag{3}$$

Expressão $(3)$ funciona muito bem quando $x$ é negativo com muito pouca perda de precisão.Expressão $(2)$ não é adequado para grandes valores de $x$ já que os termos vão explodir tanto no numerador quanto no denominador.

A função $(1)$ atinge assintoticamente zero como $x o-\infty$.Agora como $x$ torna-se maior em magnitude, a expressão $(3)$ sofrerá um cancelamento catastrófico:dois termos grandes que se cancelam para dar um número realmente pequeno.A expressão $(2)$ é mais adequado nesta faixa.

Isso funciona razoavelmente bem até $-18$ e além do qual você perde vários algarismos significativos.

Vamos dar uma olhada mais de perto na função e tentar aproximar $f(x)$ como $x o-\infty$.

$$f(x) = x \frac{e^{2x} + 2e^x}{e^{2x} + 2e^x + 2}$$

O $e^{2x}$ serão ordens de grandeza menores que $e^x$. $e^x$ serão ordens de grandeza menores que $1$.Usando esses dois fatos, podemos aproximar $f(x)$ para:

$f(x) \aprox x\frac{e^x}{e^x+1}\aprox xe^x$

Resultado:

$f(x) \approx \begin{cases} xe^x, & ext{if $x \le -18$} \\ x\frac{e^{2x} + 2e^x}{e^{2x} + 2e^x + 2} & ext{if $-18 \lt x \le -0.6$} \\ x - \frac{2x}{(1 + e^x)^2 + 1}, & ext{otherwise} \end{casos} $

Implementação rápida de CUDA:

__device__ float mish(float x)
{
    auto e = __expf(x);
    if (x <= -18.0f)
        return x * e;    

    auto n = e * e + 2 * e;
    if (x <= -0.6f)
        return x * __fdividef(n, n + 2);

    return x - 2 * __fdividef(x, n + 2);
}

EDITAR:

Uma versão ainda mais rápida e precisa:

$ f (x) aprox begin {casos} x frac {e^{2x} + 2e^x} {e^{2x} + 2e^x + 2} & text {$ x le -0.6 $ } x - frac {2x} {(1 + e^x)^2 + 1}, & text {Caso contrário} end {casos} $

__device__ float mish(float x)
{
    auto e = __expf(value);
    auto n = e * e + 2 * e;
    if (value <= -0.6f)
        return value * __fdividef(n, n + 2);

    return value - 2 * __fdividef(value, n + 2);
}

Código: https://gist.github.com/YashasSamaga/8ad0cd3b30dbd0eb588c1f4c035db28c

$$\begin{array}{c|c|c|c|} & ext{Time (float)} & ext{Time (float4)} & ext{L2 norma do vetor de erro} \\ \hline ext{mish} & 1.49ms & 1.39ms & 2.4583e-05 \\ \hline ext{relu} & 1.47ms & 1.39ms & ext{N/A} \\ \hline \end{matriz}$$

Não há necessidade de executar o logaritmo.Se você deixar $ p= 1+ \ exp (x) $ então temos $ f (x)= x \ cdot \ dfac {p ^ 2-1} {p ^ 2 + 1} $ ou alternativamente $ f (x)= x - \ dfac {2x} {p ^ 2 + 1} $ .

Minha impressão é que alguém queria multiplicar x por uma função f (x) que vai suavemente de 0 a 1, e experimentado até encontrar uma expressão usando funções elementares que fizeram isso, sem motivo matemático por trás da escolha de funções .

Após escolher um parâmetro t, deixe $ p_t (x)= 1/2 + (3 / 4T) x - x ^ 3 / (4T ^ 3) $ , então $ p_t (0)= 1/2 $ , $ p_t (t)= 1 $ , $ p_t (-t)= 0 $ e $ p_t '(t)= p_t' (- t)= 0 $ . Seja g (x)= 0 se x <-T, 1 se x> +1, e $ p_t (x) $ se -t ≤ x ≤ + t. Esta é uma função que muda suavemente de 0 para 1. Escolha outro parâmetro s, e em vez de f (x) calcular x * g (x - s).

t= 3,0 e s= -0.3 corresponde à função dada bastante razoavelmente e é calculada um terrível muito mais rápido (que parece importante). É diferente do curso. Como esta função é usada como uma ferramenta em algum problema, eu gostaria de ver uma razão matemática por que a função original é melhor .

O contexto aqui é a visão computacional e a função de ativação para treinar redes neurais.

Provavelmente, esse código será executado em uma GPU.Embora o desempenho dependa da distribuição de insumos típicos, de modo geral, é importante evitar ramificações no código GPU.A divergência de warp pode degradar significativamente o desempenho do seu código.Por exemplo, o Documentação do kit de ferramentas CUDA diz:

Nota: Alta Prioridade:Evite diferentes caminhos de execução dentro do mesmo warp.As instruções de controle de fluxo (if, switch, do, for, while) podem afetar significativamente o rendimento da instrução, fazendo com que threads do mesmo warp diverjam;isto é, seguir diferentes caminhos de execução.Se isso acontecer, os diferentes caminhos de execução deverão ser executados separadamente;isso aumenta o número total de instruções executadas para este warp....Para ramificações que incluem apenas algumas instruções, a divergência de warp geralmente resulta em perdas marginais de desempenho.Por exemplo, o compilador pode usar predicação para evitar uma ramificação real.Em vez disso, todas as instruções são agendadas, mas um código de condição ou predicado por thread controla quais threads executam as instruções.Threads com predicado falso não gravam resultados e também não avaliam endereços ou lêem operandos.

Duas implementações sem ramificação

Resposta do OP tem ramificações curtas, portanto a predicação de ramificação pode acontecer com alguns compiladores.Outra coisa que notei é que parece aceitável calcular o exponencial uma vez por chamada.Ou seja, entendo a resposta do OP de dizer que uma chamada para o exponencial não é "cara" ou "lenta".

Nesse caso, eu sugeriria o seguinte código simples:

__device__ float mish(float x)
{
    float expx = __expf(x);
    return x / (1.0f + 2.0f / (expx * (2.0f + expx)));
}

Não tem ramificações, uma exponencial, uma multiplicação e duas divisões.As divisões costumam ser mais caras que as multiplicações, então também experimentei este código:

__device__ float mish(float x)
{
    float expx = __expf(x);
    float psi = expx * (2.0f + expx);
    return x * (psi / (2.0f + psi));
}

Não tem ramificações, uma exponencial, duas multiplicações e uma divisão.

Erro relativo

Calculei a precisão relativa log10 dessas duas implementações mais a resposta do OP.Calculei o intervalo (-100.100) com um incremento de 1/1024 e, em seguida, calculei um máximo em execução acima de 51 valores (para reduzir a confusão visual, mas ainda dar a impressão correta).Calcular a primeira implementação com dupla precisão é suficiente como referência.O exponencial tem precisão de um ULP e há apenas algumas operações aritméticas;o restante dos bits é mais que suficiente para tornar muito improvável o dilema do fabricante de mesas.Portanto, é muito provável que consigamos calcular valores de referência de precisão simples arredondados corretamente.

Log10 relative error

Verde:primeira implementação.Vermelho:segunda implementação.Azul:Implementação do OP.O azul e o vermelho se sobrepõem na maior parte de seu alcance (à esquerda de cerca de -20).

Nota para OP:você desejará alterar o corte para maior que -5 se quiser manter a precisão total.

Desempenho

Você terá que testar essas duas implementações para ver qual é mais rápida.Eles deveriam ser pelo menos tão rápidos quanto os OPs, e suspeito que serão muito mais rápidos devido à falta de filiais.No entanto, se eles não forem rápidos o suficiente para você, há mais que você pode fazer.

Uma questão importante:

Qual é a distribuição dos valores de entrada típicos que você espera ver?Os valores serão distribuídos uniformemente por todo o intervalo em que a função é efetivamente computável?Ou eles ficarão agrupados em torno de 0 quase o tempo todo?Se sim, com que variação/spread?

A assintótica pode ser melhorada.

À esquerda, OP usa x * expx com um corte de -18.Este corte pode ser aumentado para cerca de -15,5625 sem perda de precisão.Com o custo de uma multiplicação extra, você poderia usar x * expx * (1.0f - 0.5f * expx) e um corte de cerca de -4,875.Observação:a multiplicação por 0,5 pode ser otimizada para uma subtração de 1 do expoente, então não estou contando isso aqui.

À direita, você pode introduzir outro assintótico.Se x > 8.75, simplesmente return x.Com um pouco mais de custo, você poderia fazer x * (1.0f - 2.0f * __expf(-2.0f * x)) quando x > 6.0.

Interpolação

Para a parte central do intervalo (-4,875, 6,0), você pode usar uma tabela de interpolantes.Se seus intervalos estiverem igualmente espaçados, você poderá usar uma divisão para calcular um índice direto na tabela (sem ramificação).Calcular tal tabela exigiria algum esforço, mas dependendo de suas necessidades pode valer a pena:um punhado de multiplicações e somas poder ser menos caro que o exponencial.Dito isto, os implementadores do exponencial na biblioteca provavelmente gastaram muito tempo e esforço para acertar e rapidamente.Além disso, a função “mish” não apresenta nenhuma oportunidade de redução de alcance.

Licenciado em: CC-BY-SA com atribuição
Não afiliado a cs.stackexchange
scroll top