What is weight decay?

Weight decay is a regularization technique that adds a small penalty, usually the squared L2 norm of the weights (all the weights of the model), to the loss function:

loss = original loss + weight decay parameter * (squared L2 norm of the weights), where the squared L2 norm is the sum of the squares of all the weights.
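A minimal sketch of the formula above in plain Python, assuming the model's weights are given as a flat list of numbers and that the penalty is the squared L2 norm (the sum of the squared weights). The helper names `l2_penalty` and `regularized_loss` are hypothetical and not part of any library; the numbers are made up for illustration.

```python
def l2_penalty(weights):
    """Squared L2 norm: the sum of the squares of every weight."""
    return sum(w * w for w in weights)


def regularized_loss(data_loss, weights, weight_decay):
    """Original loss plus the weight decay term weight_decay * ||w||^2."""
    return data_loss + weight_decay * l2_penalty(weights)


# Hypothetical values, chosen only to illustrate the computation.
weights = [0.5, -1.0, 2.0]
print(regularized_loss(data_loss=0.8, weights=weights, weight_decay=0.01))
```

Larger weights make the penalty grow quadratically, so the optimizer is pushed toward solutions whose weights stay small.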

Some practitioners apply weight decay only to the weights and exclude the biases. By default, PyTorch optimizers apply weight decay to every parameter they are given, including biases.
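A minimal sketch of excluding biases from weight decay, assuming parameters arrive as (name, value) pairs in the style of PyTorch's `model.named_parameters()` and that bias parameters are the ones whose last name component is `bias` (a naming heuristic, not a rule). `split_decay_groups` is a hypothetical helper. PyTorch optimizers accept a list of parameter groups, each a dict with its own options such as `weight_decay`, so the two lists map onto two groups.

```python
def split_decay_groups(named_params):
    """Split (name, value) pairs into values that get weight decay and
    values that do not. Biases are excluded by a naming heuristic."""
    decay, no_decay = [], []
    for name, param in named_params:
        # Last component of a dotted name such as "fc.bias" -> "bias".
        if name.split(".")[-1] == "bias":
            no_decay.append(param)  # bias: no penalty
        else:
            decay.append(param)     # weight: penalized
    return decay, no_decay


# Illustrative (name, value) pairs shaped like named_parameters() output;
# plain lists stand in for tensors here.
named = [
    ("fc.weight", [0.3, -0.2]),
    ("fc.bias", [0.1]),
    ("out.weight", [1.5]),
    ("out.bias", [0.0]),
]
decay, no_decay = split_decay_groups(named)
param_groups = [
    {"params": decay, "weight_decay": 1e-4},   # hypothetical decay value
    {"params": no_decay, "weight_decay": 0.0},
]
# With real tensors, these groups could be passed to an optimizer,
# e.g. torch.optim.SGD(param_groups, lr=0.1).
```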

Why do we use weight decay?

  • To reduce overfitting.
  • To keep the weights small and help avoid exploding gradients. Because the L2 penalty on the weights is added to the loss, each optimization step minimizes the size of the weights in addition to the original loss. This keeps the weights small and prevents them from growing out of control, which can help avoid exploding gradients.
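A minimal sketch of why the penalty keeps weights small, using plain gradient descent on a single scalar weight with the penalty written as weight_decay * w**2. `sgd_step` is a hypothetical helper, not a library function. Each step multiplies the weight by (1 - 2 * lr * weight_decay) before applying the data gradient, which is where the name "weight decay" comes from.

```python
def sgd_step(w, grad, lr, weight_decay):
    """One plain SGD step on loss + weight_decay * w**2 for a single weight.

    The penalty's gradient is 2 * weight_decay * w, so this equals
    w * (1 - 2 * lr * weight_decay) - lr * grad: the weight is shrunk
    by a constant factor before the usual gradient update.
    """
    return w - lr * (grad + 2 * weight_decay * w)


# Hypothetical values: with no gradient from the data, only the decay acts,
# so the weight shrinks by a factor of 0.9 on every step.
w = 10.0
for _ in range(5):
    w = sgd_step(w, grad=0.0, lr=0.1, weight_decay=0.5)
print(w)
```

PyTorch's `torch.optim.SGD` instead adds `weight_decay * w` to the gradient, which corresponds to a penalty of (weight_decay / 2) * ||w||^2; the two conventions differ only by a factor of 2 in the parameter.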


Deep learning basics — weight decay