|
1 | | -from typing import Optional |
| 1 | +from typing import Callable, Optional, Sequence |
2 | 2 |
3 | 3 | import torch as th |
| 4 | +from torch import nn |
4 | 5 |
5 | 6 |
6 | 7 | def quantile_huber_loss( |
@@ -67,3 +68,96 @@ def quantile_huber_loss( |
67 | 68 | else: |
68 | 69 | loss = loss.mean() |
69 | 70 | return loss |
| 71 | + |
| 72 | + |
| 73 | +def conjugate_gradient_solver( |
| 74 | + matrix_vector_dot_fn: Callable[[th.Tensor], th.Tensor], |
| 75 | +    b: th.Tensor, |
| 76 | +    max_iter: int = 10, |
| 77 | +    residual_tol: float = 1e-10, |
| 78 | +) -> th.Tensor: |
| 79 | + """ |
| 80 | + Finds an approximate solution to a set of linear equations Ax = b |
| 81 | +
| 82 | + Sources: |
| 83 | + - https://github.com/ajlangley/trpo-pytorch/blob/master/conjugate_gradient.py |
| 84 | + - https://github.com/joschu/modular_rl/blob/master/modular_rl/trpo.py#L122 |
| 85 | +
| 86 | + Reference: |
| 87 | + - https://epubs.siam.org/doi/abs/10.1137/1.9781611971446.ch6 |
| 88 | +
| 89 | + :param matrix_vector_dot_fn: |
| 90 | +        a function that right-multiplies the matrix A by a vector v (i.e. computes A @ v) |
| 91 | + :param b: |
| 92 | +        the right-hand side of the linear system Ax = b |
| 93 | + :param max_iter: |
| 94 | + the maximum number of iterations (default is 10) |
| 95 | + :param residual_tol: |
| 96 | +        residual tolerance for early stopping of the solver (default is 1e-10) |
| 97 | + :return x: |
| 98 | + the approximate solution to the system of equations defined by `matrix_vector_dot_fn` |
| 99 | + and b |
| 100 | + """ |
| 101 | + |
| 102 | +    # The vector is not initialized at 0 because of instability issues when the gradient becomes small. |
| 103 | +    # A small random Gaussian noise is used for the initialization instead. |
| 104 | + x = 1e-4 * th.randn_like(b) |
| 105 | + residual = b - matrix_vector_dot_fn(x) |
| 106 | + # Equivalent to th.linalg.norm(residual) ** 2 (L2 norm squared) |
| 107 | + residual_squared_norm = th.matmul(residual, residual) |
| 108 | + |
| 109 | + if residual_squared_norm < residual_tol: |
| 110 | +        # If the gradient is already extremely small, |
| 111 | +        # the denominator of alpha would become zero, |
| 112 | +        # leading to a division by zero, so return early |
| 113 | + return x |
| 114 | + |
| 115 | + p = residual.clone() |
| 116 | + |
| 117 | + for i in range(max_iter): |
| 118 | + # A @ p (matrix vector multiplication) |
| 119 | + A_dot_p = matrix_vector_dot_fn(p) |
| 120 | + |
| 121 | + alpha = residual_squared_norm / p.dot(A_dot_p) |
| 122 | + x += alpha * p |
| 123 | + |
| 124 | + if i == max_iter - 1: |
| 125 | + return x |
| 126 | + |
| 127 | + residual -= alpha * A_dot_p |
| 128 | + new_residual_squared_norm = th.matmul(residual, residual) |
| 129 | + |
| 130 | + if new_residual_squared_norm < residual_tol: |
| 131 | + return x |
| 132 | + |
| 133 | + beta = new_residual_squared_norm / residual_squared_norm |
| 134 | + residual_squared_norm = new_residual_squared_norm |
| 135 | + p = residual + beta * p |
| 136 | + |
| 136 | +
| 137 | +
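
For illustration, a minimal usage sketch of `conjugate_gradient_solver` on a tiny hand-made symmetric positive-definite system (the matrix `A` and vector `b` below are made-up example values; the solver only ever sees `A` through the matrix-vector product callback):

    import torch as th

    # Example data: a small symmetric positive-definite matrix and a right-hand side.
    A = th.tensor([[4.0, 1.0], [1.0, 3.0]])
    b = th.tensor([1.0, 2.0])

    # The solver only needs the product A @ v, never the matrix A itself.
    x = conjugate_gradient_solver(lambda v: A @ v, b, max_iter=10)

    # x should be close to the exact solution.
    print(x, th.linalg.solve(A, b))
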
| 138 | +def flat_grad( |
| 139 | + output, |
| 140 | + parameters: Sequence[nn.parameter.Parameter], |
| 141 | + create_graph: bool = False, |
| 142 | + retain_graph: bool = False, |
| 143 | +) -> th.Tensor: |
| 144 | + """ |
| 145 | +    Returns the gradients of ``output`` with respect to the passed sequence of parameters, |
| 146 | +    flattened into a single tensor. The order of the parameters is preserved. |
| 147 | +
| 148 | + :param output: functional output to compute the gradient for |
| 149 | + :param parameters: sequence of ``Parameter`` |
| 150 | +    :param create_graph: if ``True``, the graph of the derivative will be constructed, |
| 151 | +        allowing higher-order derivative products to be computed. Default: ``False``. |
| 152 | +    :param retain_graph: if ``False``, the graph used to compute the grad will be freed |
| 153 | +        after the call. Default: ``False``. |
| 154 | + :return: Tensor containing the flattened gradients |
| 155 | + """ |
| 156 | + grads = th.autograd.grad( |
| 157 | + output, |
| 158 | + parameters, |
| 159 | + create_graph=create_graph, |
| 160 | + retain_graph=retain_graph, |
| 161 | + allow_unused=True, |
| 162 | + ) |
| 163 | + return th.cat([th.ravel(grad) for grad in grads if grad is not None]) |
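
For illustration, a sketch of how `flat_grad` can be used, including the Hessian-vector-product pattern that the TRPO implementations linked in the sources above combine with the conjugate gradient solver (the toy `nn.Linear` model and loss here are made up for the example):

    import torch as th
    from torch import nn

    # Toy model and scalar loss, purely for the example.
    model = nn.Linear(3, 1)
    loss = model(th.randn(5, 3)).pow(2).mean()
    params = list(model.parameters())

    # Flat gradient of the loss w.r.t. all parameters, as a single 1-D tensor.
    # create_graph=True keeps recording operations so this gradient can itself be differentiated.
    grad = flat_grad(loss, params, create_graph=True)

    # Hessian-vector product: differentiate grad.dot(v) a second time.
    # retain_graph=True allows repeated calls, e.g. inside conjugate_gradient_solver.
    v = th.randn_like(grad)
    hessian_vector_product = flat_grad(grad.dot(v), params, retain_graph=True)
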