How to automatically disable register_hook when training model is in eval() phase in PyTorch?

datascience.stackexchange https://datascience.stackexchange.com/questions/80740

  •  13-12-2020

Question

I need to modify the gradients of an intermediate tensor using the register_hook method. Since the tensor isn't a leaf variable, I first call retain_grad() on it and then use register_hook to alter its gradients.

score.retain_grad()
h = score.register_hook(lambda grad: grad * torch.FloatTensor(...))

This works fine during training (model.train()). However, it raises an error during evaluation (model.eval()).

The error:

File "/home/envs/darthvader/lib/python3.6/site-packages/torch/tensor.py", line 198, in register_hook
    raise RuntimeError("cannot register a hook on a tensor that "
RuntimeError: cannot register a hook on a tensor that doesn't require gradient
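A minimal reproduction, assuming the evaluation loop runs the forward pass inside torch.no_grad() (a common pattern, though the question does not show the loop). Note that model.eval() by itself does not turn off gradient tracking; it only switches modules such as Dropout and BatchNorm to their inference behavior. It is torch.no_grad() (or torch.inference_mode()) that leaves intermediate tensors with requires_grad=False, which is exactly what register_hook rejects. The tensor w and the factor 2.0 are hypothetical stand-ins for the question's tensors.

```python
import torch

# Hypothetical leaf tensor; `score` is an intermediate (non-leaf) result.
w = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Training-style pass: autograd tracks `score`, so the hook can be registered.
# register_hook works on a non-leaf tensor without calling retain_grad().
score = w * 3.0
h = score.register_hook(lambda grad: grad * 2.0)  # 2.0 stands in for the elided factor
score.sum().backward()
print(w.grad)  # tensor([6., 6., 6.])  (1 * 2.0 from the hook * 3.0 from w * 3.0)
h.remove()     # detach the hook once it is no longer needed

# Evaluation-style pass under torch.no_grad(): `score` is not tracked by autograd.
with torch.no_grad():
    score = w * 3.0
print(score.requires_grad)  # False

raised = False
try:
    score.register_hook(lambda grad: grad * 2.0)
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```

The first pass succeeds and the second raises the same RuntimeError as in the traceback above, even though no model.eval() call is involved.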

How can I make the model automatically disable the register_hook call when it is in the eval() phase?


Solution

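Remove score.retain_grad() and guard the register_hook call with an if condition (if score.requires_grad). When score does not require gradients, as in the evaluation pass, the hook is simply skipped instead of raising the error.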

if score.requires_grad:
    h = score.register_hook(lambda grad: grad * torch.FloatTensor(...))
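A sketch of the guard placed inside a module's forward(), assuming a hypothetical ScaledScoreModel with a single nn.Linear layer and a hypothetical per-class gradient scale (the buffer grad_scale stands in for the elided torch.FloatTensor(...)). The evaluation pass is assumed to run under torch.no_grad().

```python
import torch
import torch.nn as nn


class ScaledScoreModel(nn.Module):
    """Hypothetical model whose intermediate `score` gets a gradient hook."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 3)
        # Hypothetical per-class gradient scale; moves with the model via .to()/.cuda().
        self.register_buffer("grad_scale", torch.tensor([1.0, 0.5, 2.0]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        score = self.linear(x)
        # Only attach the hook when autograd is tracking `score`;
        # under torch.no_grad() this branch is skipped instead of raising.
        if score.requires_grad:
            score.register_hook(lambda grad: grad * self.grad_scale)
        return score


model = ScaledScoreModel()
x = torch.randn(2, 4)

# Training step: the hook is registered and rescales the gradient of `score`.
model.train()
model(x).sum().backward()
print(model.linear.bias.grad)  # tensor([2., 1., 4.])  (batch of 2, each row scaled)

# Evaluation: no hook is registered, so no RuntimeError is raised.
model.eval()
with torch.no_grad():
    out = model(x)
print(out.requires_grad)  # False
```

Checking score.requires_grad rather than self.training ties the guard to the condition register_hook actually enforces, so it also behaves correctly if the model is run under torch.no_grad() while still in training mode. Because the hook belongs to the tensor created in each forward pass, hooks do not accumulate across calls.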

Originally answered by Alban D here.

Licensed under: CC-BY-SA with attribution