-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstate_torch.py
More file actions
101 lines (80 loc) · 2.69 KB
/
state_torch.py
File metadata and controls
101 lines (80 loc) · 2.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""
GPU-accelerated state management for optimization.
Keeps all image data on GPU to avoid CPU<->GPU transfers.
"""
import torch
import numpy as np
from typing import TYPE_CHECKING
if TYPE_CHECKING:
pass
class StateTorch:
    """
    Optimization state with GPU-accelerated distance computation.

    Holds the target image and the current canvas as torch tensors on a
    single device, together with the distance between them.

    Attributes:
        device: Device that both tensors are placed on.
        target: Target image tensor, indexed as (H, W, C).
        current: Current canvas tensor, indexed as (H, W, C).
        distance: Distance between ``current`` and ``target``. It is computed
            once in ``__init__``; nothing in this class refreshes it if the
            tensors are modified afterwards.
    """

    def __init__(self, target: torch.Tensor, current: torch.Tensor, device: torch.device):
        """
        Initialize state.

        Args:
            target: Target image tensor (H, W, C)
            current: Current canvas tensor (H, W, C)
            device: torch device both tensors are moved to
        """
        self.device = device

        # Tensor.to() hands back the very same tensor when it already lives on
        # `device`, so calling it unconditionally costs nothing. Skipping the
        # move for any CUDA tensor would leave data on the wrong device when
        # `device` is the CPU or a different GPU.
        self.target = target.to(device)
        self.current = current.to(device)

        # Compute initial distance
        self.distance = self._compute_distance()

    def _compute_distance(self) -> float:
        """
        Compute RMS distance between current and target on their device.

        Only the first three channels are compared; any further channel
        (e.g. alpha) is ignored. The divisor is fixed at three channels per
        pixel of ``current``.

        Returns:
            sqrt(sum of squared differences / (3 * H * W)) / 255, or 0.0 when
            ``current`` has no pixels.
        """
        pixels = self.current.shape[0] * self.current.shape[1]
        if pixels == 0:
            # Nothing to compare, and it avoids dividing by zero below.
            return 0.0

        # The result is a plain float, so no autograd graph is needed even if
        # the tensors happen to require grad.
        with torch.no_grad():
            diff = self.current[:, :, :3].float() - self.target[:, :, :3].float()
            sum_squared = torch.sum(diff * diff).item()

        # NOTE(review): the division by 255 presumably maps 8-bit-range values
        # onto [0, 1] -- confirm the value range used by callers.
        return (sum_squared / (3 * pixels)) ** 0.5 / 255.0

    def copy(self) -> 'StateTorch':
        """
        Create a copy of this state.

        Both tensors are cloned and the distance is recomputed by the
        constructor.
        """
        return StateTorch(
            self.target.clone(),
            self.current.clone(),
            self.device
        )

    def to_numpy(self) -> tuple:
        """
        Convert state tensors to numpy arrays (for compatibility).

        Tensors are detached first so the conversion also works when they
        require grad.

        Returns:
            (target_numpy, current_numpy)
        """
        return (
            self.target.detach().cpu().numpy(),
            self.current.detach().cpu().numpy()
        )

    @staticmethod
    def _tensor_from_array(array: np.ndarray, device: torch.device) -> torch.Tensor:
        """Convert a numpy array to a tensor on ``device``."""
        # torch.from_numpy() rejects negatively-strided arrays (for example a
        # channel flip via ``img[:, :, ::-1]``), so those are copied into a
        # contiguous buffer first. Other arrays are passed through untouched.
        if any(stride < 0 for stride in array.strides):
            array = np.ascontiguousarray(array)
        return torch.from_numpy(array).to(device)

    @staticmethod
    def from_numpy(
        target: np.ndarray,
        current: np.ndarray,
        device: torch.device
    ) -> 'StateTorch':
        """
        Create StateTorch from numpy arrays.

        Args:
            target: Target image numpy array
            current: Current canvas numpy array
            device: torch device

        Returns:
            StateTorch instance
        """
        target_tensor = StateTorch._tensor_from_array(target, device)
        current_tensor = StateTorch._tensor_from_array(current, device)
        return StateTorch(target_tensor, current_tensor, device)