Question

I have seen many examples of this syntax that is being used for the loss function specifically:

loss = nn.BCEWithLogitsLoss()(pred, y)

Can anyone explain what the trailing (pred, y) does exactly? Why does it compute the loss directly, instead of first creating the loss object and then calling it as a function of these two arguments?


Solution

This is an example of Python's special __call__ method (see the data model section of the Python documentation). In short: BCEWithLogitsLoss is a class. The first set of parentheses (empty, in your case) passes any needed arguments to the class initializer, __init__, and creates an instance. The arguments in the second set of parentheses, (pred, y), are then passed to that instance's __call__ method. So this is convenient syntax that lets you instantiate the class and call the resulting object in one line; the instance itself is not kept afterward.
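The same two-step mechanism can be seen with a plain Python class, without PyTorch. This is a minimal sketch; the class name Scaler and its factor parameter are hypothetical, chosen only to show that Cls(init_args)(call_args) is equivalent to creating the instance first and calling it afterward.

```python
class Scaler:
    """Hypothetical example class: stores a factor at construction time."""

    def __init__(self, factor=1.0):
        # First set of parentheses: Scaler(...) runs __init__
        self.factor = factor

    def __call__(self, x):
        # Second set of parentheses: instance(...) runs __call__
        return self.factor * x


# Two-step form: create the instance, then call it
scale = Scaler(3)
two_step = scale(5)

# One-line form: create a throwaway instance and call it immediately
one_line = Scaler(3)(5)

print(two_step, one_line)  # prints "15 15"
```

In a training loop it is common to create the loss object once (criterion = nn.BCEWithLogitsLoss()) and reuse it, since the one-line form builds a new instance on every call.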

Confirming this in the source code is a bit difficult. The source for BCEWithLogitsLoss (in torch/nn/modules/loss.py) confirms it is a class, but it defines only __init__ and forward; so where is __call__? The answer is inheritance: BCEWithLogitsLoss inherits from _Loss (BCELoss, by contrast, goes through _WeightedLoss first), and _Loss inherits from nn.Module. It is this base class, Module, that implements __call__: it runs any registered hooks and then calls forward with the same arguments.
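The __call__-to-forward dispatch can be mimicked in a few lines. This is a heavily simplified sketch, not PyTorch's actual implementation: MiniModule and MiniBCEWithLogitsLoss are hypothetical stand-ins, hooks are omitted, and plain Python lists replace tensors. The loss uses the numerically stable form of binary cross-entropy on logits and averages the terms, like the real class's default reduction='mean'.

```python
import math


class MiniModule:
    """Hypothetical stand-in for torch.nn.Module (greatly simplified)."""

    def __call__(self, *args, **kwargs):
        # The real Module.__call__ also runs registered hooks; this sketch
        # only passes the arguments through to the subclass's forward().
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError


class MiniBCEWithLogitsLoss(MiniModule):
    """Hypothetical loss class: defines forward() only, like the real one."""

    def forward(self, pred, y):
        # Stable per-element BCE on logits: max(x, 0) - x*y + log(1 + exp(-|x|))
        terms = [
            max(x, 0.0) - x * t + math.log1p(math.exp(-abs(x)))
            for x, t in zip(pred, y)
        ]
        # Average over elements, mirroring reduction='mean'
        return sum(terms) / len(terms)


# No __call__ is defined on the loss class, yet calling the instance works,
# because MiniModule.__call__ forwards to forward()
loss = MiniBCEWithLogitsLoss()([0.0, 2.0], [1.0, 0.0])
print(loss)
```
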

Licensed under: CC-BY-SA with attribution
Not affiliated with datascience.stackexchange