deepmd.dpmodel.utils.safe_gradient


Safe versions of functions whose gradients are problematic (infinite or NaN) at zero.

See https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where for why gradients can contain NaN when `where` is used to guard such functions, and how to avoid it.
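The failure mode described in that FAQ entry can be reproduced without an autodiff library. The sketch below hand-writes the reverse-mode chain rule for the naive guard `where(x > 0, sqrt(x), 0)`. The function name `naive_sqrt_grad` is hypothetical and is not part of DeePMD-kit.

```python
import numpy as np


def naive_sqrt_grad(x):
    """Hand-derived reverse-mode gradient of where(x > 0, sqrt(x), 0)."""
    mask = x > 0
    # The VJP of `where` routes the upstream gradient (1.0 here) to the selected
    # branch and 0.0 to the other one; it does NOT skip the unselected branch.
    upstream = np.where(mask, 1.0, 0.0)
    with np.errstate(divide="ignore", invalid="ignore"):
        # Derivative of sqrt, evaluated at the raw input: infinite at x = 0.
        local = 0.5 / np.sqrt(x)
        # At x = 0 this is 0 * inf, which is NaN.
        return upstream * local
```

For `x = np.array([0.0, 4.0])`, the result is `[nan, 0.25]`: the masked branch still poisons the gradient.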

Functions

safe_for_sqrt(→ Any)

Safe version of sqrt that has a gradient of 0 at x = 0.

safe_for_vector_norm(→ Any)

Safe version of vector_norm that has a gradient of 0 at x = 0.

Module Contents

deepmd.dpmodel.utils.safe_gradient.safe_for_sqrt(x: Any) → Any

Safe version of sqrt that has a gradient of 0 at x = 0.
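A minimal NumPy sketch of the "double where" trick recommended in the JAX FAQ, assuming a nonnegative input such as a squared distance. `safe_sqrt_sketch` and `safe_sqrt_grad_sketch` are hypothetical names. This is not the library's source, and the gradient is derived by hand rather than by autodiff.

```python
import numpy as np


def safe_sqrt_sketch(x):
    """Hypothetical sqrt whose gradient is 0 (not NaN) for x <= 0."""
    mask = x > 0.0
    # Inner where: substitute a harmless value (1.0) wherever x <= 0,
    # so sqrt (and its derivative) only ever sees positive arguments.
    x_safe = np.where(mask, x, np.ones_like(x))
    # Outer where: keep the real result where x > 0, return 0 elsewhere.
    return np.where(mask, np.sqrt(x_safe), np.zeros_like(x))


def safe_sqrt_grad_sketch(x):
    """Hand-derived reverse-mode gradient of safe_sqrt_sketch."""
    mask = x > 0.0
    x_safe = np.where(mask, x, np.ones_like(x))
    upstream = np.where(mask, 1.0, 0.0)  # outer where masks the upstream gradient
    local = 0.5 / np.sqrt(x_safe)        # finite everywhere, since x_safe > 0
    inner = np.where(mask, 1.0, 0.0)     # d(x_safe)/dx through the inner where
    return upstream * local * inner
```

The inner `where` is what removes the NaN: the derivative of `sqrt` is never evaluated at 0, so the masked branch contributes 0 instead of 0 * inf. For `x = np.array([0.0, 4.0])`, the value is `[0.0, 2.0]` and the gradient is `[0.0, 0.25]`.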

deepmd.dpmodel.utils.safe_gradient.safe_for_vector_norm(x: Any, /, *, axis: Any | None = None, keepdims: bool = False, ord: Any = 2) → Any

Safe version of vector_norm that has a gradient of 0 at x = 0, i.e. for vectors that are entirely zero along the reduced axis.
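The same idea extends to norms: mask out all-zero vectors before the reduction, so the gradient x/‖x‖ is never evaluated at a zero vector. The sketch below is hypothetical and NumPy-only. It covers only the Euclidean case (`ord=2`) and does not reproduce the documented function's support for other `ord` values.

```python
import numpy as np


def safe_vector_norm_sketch(x, axis=None, keepdims=False):
    """Hypothetical Euclidean norm that is 0, with a finite gradient, for zero vectors."""
    # A vector is "safe" if its squared norm is strictly positive.
    # keepdims=True keeps the mask broadcastable against x.
    sq = np.sum(x * x, axis=axis, keepdims=True)
    mask = sq > 0
    # Inner where: replace all-zero vectors by ones so the norm (and its
    # gradient x / ||x||) is only computed for nonzero vectors.
    x_safe = np.where(mask, x, np.ones_like(x))
    norm = np.sqrt(np.sum(x_safe * x_safe, axis=axis, keepdims=True))
    # Outer where: report 0 for the vectors that were all zero.
    out = np.where(mask, norm, np.zeros_like(norm))
    # Drop the reduced axes unless the caller asked to keep them.
    return out if keepdims else np.squeeze(out, axis=axis)
```

For `x = np.array([[0.0, 0.0, 0.0], [3.0, 4.0, 0.0]])` and `axis=-1`, this returns `[0.0, 5.0]`.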