문제

$$f(x) = x anh(\log(1 + e^x))$$

이 기능(미시 활성화)은 상당한 정밀도 손실 없이 안정적인 log1pexp를 사용하여 쉽게 구현할 수 있습니다.불행하게도 이것은 계산적으로 무겁습니다.

더 빠르고 직접적으로 수치적으로 안정적인 구현을 작성하는 것이 가능합니까?

정확도는 다음과 같습니다. x * std::tanh(std::log1p(std::exp(x))) 좋은 것.엄격한 제약은 없지만 신경망에서 사용하려면 합리적으로 정확해야 합니다.

입력의 분포는 다음과 같습니다. $[-\infty, \infty]$.어디에서나 작동해야 합니다.

도움이 되었습니까?

해결책

OP는 특정을 가리킨다. 구현 ~의 mish 정확도 사양을 위한 활성화 함수이므로 먼저 이를 특성화해야 했습니다.해당 구현에서는 단정밀도(float), 양의 반면에서는 안정적이고 정확합니다.음의 반면에서는 다음을 사용하기 때문에 logf 대신에 log1pf, 상대 오류가 빠르게 증가합니다. $x o-\infty$.정확도 손실은 대략 시작됩니다. $-1$ 그리고 이미 $-16.6355324$ 구현이 잘못 반환됨 $0$, 왜냐하면 $\exp(-16.6355324) = 2^{-24}$.

다음을 제거하는 간단한 수학적 변환을 사용하여 동일한 정확도와 동작을 얻을 수 있습니다. $\mathrm{tahn}$, 그리고 GPU는 일반적으로 FMA(Fused Multiply-Add)와 빠른 역수(FMA)를 제공한다는 점을 고려하면 이를 활용할 수 있습니다.예시적인 CUDA 코드는 다음과 같습니다:

__device__ float my_mishf (float x)
{
    float r;
    float e = expf (x);
    r = 1.0f / fmaf (fmaf (-0.5f, e, -1.0f), e, -1.0f);
    r = fmaf (r, x, x);
    return r;
}

OP에서 지적한 참조 구현과 마찬가지로 이는 양의 절반 평면에서 탁월한 정확도를 가지며 음의 절반 평면 오류는 급격히 증가하므로 다음과 같습니다. $-16.6355324$ 구현이 잘못 반환됨 $0$.

이러한 정확성 문제를 해결하려는 경우 다음 관찰을 적용할 수 있습니다.충분히 작은 경우 $x$, $f(x) = x \exp(x)$ 부동 소수점 정확도 내에서.을 위한 float 이것이 유지되는 계산 $x < -15$.간격의 경우 $[-15,-1]$, 우리는 합리적인 근사치를 사용할 수 있습니다 $R(x)$ 계산하기 $f(x) := R(x)x\exp(x)$.예시적인 CUDA 코드는 다음과 같습니다:

__device__ float my_mishf (float x)
{
    float r;
    if (x >= -1.0f) {
        float e = expf (x);
        r = 1.0f / fmaf (fmaf (-0.5f, e, -1.0f), e, -1.0f);
        r = fmaf (r, x, x);
    } else {
        float eh = expf (0.5f * x);
        float p =        1.03628484e-3f;  //  0x1.0fa7e6p-10
        p = fmaf (p, x, -7.28869531e-3f); // -0x1.ddac04p-8
        p = fmaf (p, x,  3.47027816e-2f); //  0x1.1c4902p-5
        p = fmaf (p, x, -3.54762226e-1f); // -0x1.6b46cap-2
        p = fmaf (p, x,  8.58785570e-1f); //  0x1.b7b2bep-1
        p = fmaf (p, x, -1.38065982e+0f); // -0x1.6172ecp+0
        p = fmaf (p, x,  5.97694337e-1f); //  0x1.3204fep-1
        float q =        1.03527203e-3f;  //  0x1.0f63eep-10
        q = fmaf (q, x, -7.35638570e-3f); // -0x1.e21bacp-8
        q = fmaf (q, x,  3.28683928e-2f); //  0x1.0d4204p-5
        q = fmaf (q, x, -3.79927397e-1f); // -0x1.850bb0p-2 
        q = fmaf (q, x,  6.86127126e-1f); //  0x1.5f4c0ep-1
        q = fmaf (q, x, -1.81509292e+0f); // -0x1.d0a9eep+0
        q = fmaf (q, x,  1.00000000e+0f); //  0x1.000000p+0
        r = (1.0f / q) * p;
        if (x < -15.0f) r = 1.0f;
        r = r * x * eh * eh;
    }
    return r;
}

불행히도 이 정확한 솔루션은 성능이 크게 저하되는 대가로 달성됩니다.부드럽게 감소하는 왼쪽 꼬리를 달성하면서 감소된 정확도를 기꺼이 수용하려는 경우 다음 보간 방식은 다시 다음을 기반으로 합니다. $f(x) \대략 x\exp(x)$, 성능의 상당 부분을 복구합니다.

__device__ float my_mishf (float x)
{
    float r;
    float e = expf (x);
    if (x >= -6.0625f) {
        r = 1.0f / fmaf (fmaf (-0.5f, e, -1.0f), e, -1.0f);
        r = fmaf (r, x, x);
    } else {
        r = fmaf (-0.5f, e, 1.0f);
        r = r * x * e;
    }
    return r;
}

기계별 성능 향상으로, expf() 장치 내장으로 대체될 수 있음 __expf().

다른 팁

일부 대수 조작 (@ orlp의 답변에서 지적 된 것처럼), 우리는 다음을 추론 할 수 있습니다 :

$$f (x)= x \ tanh (\ log (1 + e ^ x)) \ 태그 {1} $$ $$= x \ fRAC {(1 + e ^ X) ^ 2 - 1} {(1 + e ^ X) ^ 2 + 1}= x \ frac {e ^ { 2x} + 2E ^ X} {e ^ {2x} + 2e ^ x + 2} \ 태그 {2} $$ $$= x - \ frac {2x} {(1 + e ^ X) ^ 2 + 1} \ 태그 {3} $$

표현식 $ (3) $ $ x $ 이 부정적인 일 때 정도. 표현식 $ (2) $ $ x $ 의 큰 값에 적합하지 않습니다. 분자와 분모에서 모두 폭파하십시오.

함수 $ (1) $ $ x \ to-\ infty $ ...에 이제 $ x $ 이 크기가 커지면 $ (3) $ 은 치명적인 취소로 고통받을 것입니다. : 서로 취소하는 두 가지 큰 용어가 서로 정말로 적은 수를 제공합니다. 표현식 $ (2) $ 이이 범위에 더 적합합니다.

이것은 $ -8 $ 및 엑스에 여러 가지 중요한 수치를 잃어 버릴 때까지 상당히 잘 작동합니다.

함수를 자세히 살펴보고 $ f (x) $ 대략 $ x \ - \ infty $ .

$$ f (x)= x \ frac {e ^ {2x} + 2e ^ x} {e ^ {2x} + 2e ^ x + 2} $$ < / span>

$ ^ {2x} $ e ^ {2x} $ $ e ^ x $ . $ e ^ x $ $ 1 $ 보다 작은 크기의 명령이 될 것입니다. 이 두 가지 사실을 사용하여 $ f (x) $ 을 다음과 같이 대략적으로 대략적으로 대략적으로 대략적으로 대략적으로 할 수 있습니다 :

$ f (x) \ 약 x \ frac {e ^ x} {e ^ x + 1} \ 약 xe ^ x $

결과 :

$ f (x) \ 약 \ 시작 {사례} xe ^ x, & \ text {$ x \ le -18 $} \\ x \ frac {e ^ {2x} + 2e ^ x} {e ^ {2x} + 2e ^ x + 2} & \ text {$ -18 \ lt x \ le -0.6 $} \\ x - \ frac {2x} {(1 + e ^ X) ^ 2 + 1}, & \ text {그렇지 않으면} \ end {사례} $

빠른 CUDA 구현 :

__device__ float mish(float x)
{
    auto e = __expf(x);
    if (x <= -18.0f)
        return x * e;    

    auto n = e * e + 2 * e;
    if (x <= -0.6f)
        return x * __fdividef(n, n + 2);

    return x - 2 * __fdividef(x, n + 2);
}
.

편집 :

더 빠르고 정확한 버전 :

$ f (x) \ 약 \ 시작 {사례} x \ frac {e ^ {2x} + 2e ^ x} {e ^ {2x} + 2e ^ x + 2} & \ text {$ x \ le -0.6 $} \\ x - \ frac {2x} {(1 + e ^ X) ^ 2 + 1}, & \ text {그렇지 않으면} \ end {사례} $

__device__ float mish(float x)
{
    auto e = __expf(value);
    auto n = e * e + 2 * e;
    if (value <= -0.6f)
        return value * __fdividef(n, n + 2);

    return value - 2 * __fdividef(value, n + 2);
}
.

코드 : https://gist.github.com/yashassamaga/8ad0cd3b30dbd0eb588c1f4c035db28c P>

$$ \ begin {ARRAY} {C | C | C | C |} & \ text {time (float)} & \ text {time (float4)} & \ text {l2 오류 벡터의 표준} \\ \ Hline \ 텍스트 {mish} & 1.49ms & 1.39ms & 2.4583E-05 \\ \ HLINE \ text {relu} & 1.47ms & 1.39ms & \ text {n / a} \\ \ hline \ end {array} $$

로그를 수행 할 필요가 없습니다. $ p= 1+ \ exp (x) $ 을 보내면 우리는 $ F (x)= x \ cdot \ dfrac {p ^ 2-1} {p ^ 2 + 1} $ 또는 $ f (x)= x - \ dfrac {2x} {p ^ 2 + 1} $ .

내 인상은 누군가가 0에서 1까지 원활하게 진행되는 함수 f (x)에 의해 x를 곱하고 싶었고,이를 위해 한 초등 함수를 사용하여 표현식을 발견 할 때까지 실험하는 것입니다. ...에

파라미터 T를 선택한 후 $ p_t (x)= 1/2 + (3 / 4t) x - x ^ 3 / (4t ^ 3) $ , $ P_T (0)= 1/2 $ , $ p_t (t)= 1 $ $ p_t (-t)= 0 $ $ p_t '(t)= p_t'(- t)= 0 $ . x <-t, x> +1, $ p_t (x) $ ≤ x ≤ + t 인 경우 g (x)= 0으로 가정하십시오. 이것은 0에서 1까지 원활하게 변경하는 함수입니다. 다른 매개 변수 S를 선택하고 f (x) 대신 x * g (x-s)를 계산합니다.

t= 3.0 및 s= -0.3 주어진 함수와 상당히 합리적으로 일치하며 끔찍한 훨씬 더 빨리 (중요한 것처럼 보이는) 계산됩니다. 물론 다릅니다. 이 기능이 몇 가지 문제가있는 도구로 사용되므로 원래 함수가 더 나은 인 수학적 이유를보고 싶습니다.

문맥 여기 신경망 훈련을 위한 컴퓨터 비전 및 활성화 함수입니다.

이 코드는 GPU에서 실행될 가능성이 있습니다.성능은 일반적인 입력의 분포에 따라 달라지지만, 일반적으로 GPU 코드에서 분기를 피하는 것이 중요합니다..워프 발산은 코드 성능을 크게 저하시킬 수 있습니다.예를 들어, CUDA 툴킷 문서 말한다:

참고: 우선순위가 높음:동일한 워프 내에서 다른 실행 경로를 피하십시오.흐름 제어 명령(if, switch, do, for, while)은 동일한 워프의 스레드를 분기시켜 명령 처리량에 큰 영향을 미칠 수 있습니다.즉, 다른 실행 경로를 따르는 것입니다.이런 일이 발생하면 서로 다른 실행 경로를 별도로 실행해야 합니다.이는 이 워프에 대해 실행되는 명령의 총 개수를 증가시킵니다....몇 가지 명령만 포함하는 분기의 경우 워프 발산으로 인해 일반적으로 약간의 성능 손실이 발생합니다.예를 들어, 컴파일러는 실제 분기를 피하기 위해 예측을 사용할 수 있습니다.대신 모든 명령어가 예약되지만 스레드별 조건 코드 또는 조건자는 명령어를 실행하는 스레드를 제어합니다.거짓 조건자가 있는 스레드는 결과를 쓰지 않으며 주소를 평가하거나 피연산자를 읽지도 않습니다.

두 가지 분기 없는 구현

OP의 답변 짧은 분기가 있으므로 일부 컴파일러에서는 분기 예측이 발생할 수 있습니다.내가 알아차린 또 다른 점은 호출당 한 번씩 지수를 계산하는 것이 허용 가능한 것으로 보인다는 것입니다.즉, 지수에 대한 한 번의 호출이 "비싸다"거나 "느리다"는 것이 아니라는 OP의 대답을 이해합니다.

이 경우 다음과 같은 간단한 코드를 제안합니다.

__device__ float mish(float x)
{
    float expx = __expf(x);
    return x / (1.0f + 2.0f / (expx * (2.0f + expx)));
}

가지가 없고, 하나의 지수, 하나의 곱셈, 두 개의 나눗셈이 있습니다.나눗셈은 곱셈보다 비용이 더 많이 드는 경우가 많으므로 다음 코드도 시도해 보았습니다.

__device__ float mish(float x)
{
    float expx = __expf(x);
    float psi = expx * (2.0f + expx);
    return x * (psi / (2.0f + psi));
}

여기에는 가지가 없고 하나의 지수, 두 개의 곱셈, 하나의 나눗셈이 있습니다.

상대 오류

나는 이 두 구현과 OP의 답변의 log10 상대 정확도를 계산했습니다.1/1024씩 증가하면서 간격(-100,100)에 대해 계산한 다음 51개 값에 대한 실행 최대값을 계산했습니다(시각적 혼란을 줄이면서 여전히 올바른 인상을 주기 위해).배정밀도로 첫 번째 구현을 계산하는 것은 참조용으로 충분합니다.지수는 하나의 ULP 내에서 정확하며 소수의 산술 연산만 있습니다.나머지 비트는 테이블 제작자의 딜레마를 거의 발생시키지 않을 만큼 충분합니다.따라서 올바르게 반올림된 단정밀도 참조 값을 계산할 수 있을 가능성이 매우 높습니다.

Log10 relative error

녹색:첫 번째 구현.빨간색:두 번째 구현.파란색:OP의 구현.파란색과 빨간색은 대부분의 범위(약 -20 왼쪽)에서 겹칩니다.

OP 참고 사항:전체 정밀도를 유지하려면 컷오프를 -5보다 크게 변경하는 것이 좋습니다.

성능

어느 것이 더 빠른지 확인하려면 이 두 가지 구현을 테스트해야 합니다.적어도 OP만큼 빨라야 하며 분기가 부족하기 때문에 훨씬 더 빠를 것이라고 생각합니다.그러나 속도가 충분히 빠르지 않다면 더 많은 조치를 취할 수 있습니다.

중요한 질문:

예상되는 일반적인 입력 값의 분포는 어떻습니까?함수가 효과적으로 계산 가능한 전체 범위에 걸쳐 값이 균일하게 분포됩니까?아니면 거의 항상 0 주위에 모여 있게 될까요?그렇다면 어떤 분산/확산이 있습니까?

점근법은 개선될 수 있습니다.

왼쪽에서 OP는 다음을 사용합니다. x * expx 컷오프는 -18입니다.이 컷오프는 정밀도 손실 없이 약 -15.5625까지 증가할 수 있습니다.한 번의 추가 곱셈 비용으로 다음을 사용할 수 있습니다. x * expx * (1.0f - 0.5f * expx) 컷오프는 약 -4.875입니다.메모:0.5를 곱하는 것은 지수에서 1을 빼는 것으로 최적화될 수 있으므로 여기서는 계산하지 않습니다.

오른쪽에는 또 다른 점근치를 도입할 수 있습니다.만약에 x > 8.75, 단순히 return x.비용을 조금 더 들여서 할 수 있습니다 x * (1.0f - 2.0f * __expf(-2.0f * x)) 언제 x > 6.0.

보간

범위(-4.875, 6.0)의 중앙 부분에 대해서는 보간 테이블을 사용할 수 있습니다.해당 범위의 간격이 동일한 경우 하나의 나누기를 사용하여 분기 없이 테이블에 대한 직접 인덱스를 계산할 수 있습니다.이러한 테이블을 계산하려면 약간의 노력이 필요하지만 필요에 따라 그만한 가치가 있을 수 있습니다.소수의 곱셈과 덧셈 ~할 것 같다 지수보다 비용이 저렴합니다.즉, 라이브러리의 지수 구현자는 아마도 정확하고 빠른 지수를 얻기 위해 많은 시간과 노력을 소비했을 것입니다.또한 "mish" 기능은 범위 축소의 기회를 제공하지 않습니다.

라이센스 : CC-BY-SA ~와 함께 속성
제휴하지 않습니다 cs.stackexchange
scroll top