|
14 | 14 | """ |
15 | 15 |
|
16 | 16 |
|
17 | | -def solve(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # pragma: no cover |
18 | | - """ |
19 | | - Like torch.linalg.solve, tries to return X |
20 | | - such that AX=B, with A square. |
21 | | - """ |
22 | | - if hasattr(torch, "linalg") and hasattr(torch.linalg, "solve"): |
23 | | - # PyTorch version >= 1.8.0 |
24 | | - return torch.linalg.solve(A, B) |
25 | | - |
26 | | - # pyre-fixme[16]: `Tuple` has no attribute `solution`. |
27 | | - return torch.solve(B, A).solution |
28 | | - |
29 | | - |
30 | | -def lstsq(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # pragma: no cover |
31 | | - """ |
32 | | - Like torch.linalg.lstsq, tries to return X |
33 | | - such that AX=B. |
34 | | - """ |
35 | | - if hasattr(torch, "linalg") and hasattr(torch.linalg, "lstsq"): |
36 | | - # PyTorch version >= 1.9 |
37 | | - return torch.linalg.lstsq(A, B).solution |
38 | | - |
39 | | - solution = torch.lstsq(B, A).solution |
40 | | - if A.shape[1] < A.shape[0]: |
41 | | - return solution[: A.shape[1]] |
42 | | - return solution |
43 | | - |
44 | | - |
45 | | -def qr(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cover |
46 | | - """ |
47 | | - Like torch.linalg.qr. |
48 | | - """ |
49 | | - if hasattr(torch, "linalg") and hasattr(torch.linalg, "qr"): |
50 | | - # PyTorch version >= 1.9 |
51 | | - return torch.linalg.qr(A) |
52 | | - return torch.qr(A) |
53 | | - |
54 | | - |
55 | | -def eigh(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cover |
56 | | - """ |
57 | | - Like torch.linalg.eigh, assuming the argument is a symmetric real matrix. |
58 | | - """ |
59 | | - if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"): |
60 | | - return torch.linalg.eigh(A) |
61 | | - return torch.symeig(A, eigenvectors=True) |
62 | | - |
63 | | - |
64 | 17 | def meshgrid_ij( |
65 | 18 | *A: Union[torch.Tensor, Sequence[torch.Tensor]] |
66 | 19 | ) -> Tuple[torch.Tensor, ...]: # pragma: no cover |
|
0 commit comments