Merged
Changes from all commits
60 commits
f6f2c9c
ag per paper
agobbifbk Dec 18, 2024
1df26a5
ag per paper
agobbifbk Dec 18, 2024
7a37cb9
ag update
agobbifbk Dec 18, 2024
f546404
ag update
agobbifbk Jan 14, 2025
8b428ac
ag update
agobbifbk Jan 14, 2025
a5ef131
ag update
agobbifbk Jan 14, 2025
c3d223f
ag debug
agobbifbk Feb 5, 2025
f3339a9
ag debug
agobbifbk Feb 5, 2025
16c80f2
ag debug
agobbifbk Feb 5, 2025
9cd5c40
ag debug
agobbifbk Feb 5, 2025
d8e8f2e
ag debug
agobbifbk Feb 5, 2025
9d0d0b6
ag debug
agobbifbk Feb 5, 2025
7a7abdb
ag debug
agobbifbk Feb 5, 2025
50adb1c
ag debug
agobbifbk Feb 5, 2025
3cec289
ag debug
agobbifbk Feb 5, 2025
1b28deb
ag debug
agobbifbk Feb 5, 2025
5389980
ag debug
agobbifbk Feb 5, 2025
84ae2ca
ag debug
agobbifbk Feb 5, 2025
a5b3120
ag debug
agobbifbk Feb 5, 2025
1730b8b
ag debug
agobbifbk Feb 5, 2025
6c0bd95
ag debug
agobbifbk Feb 5, 2025
88cbea5
ag debug
agobbifbk Feb 5, 2025
5fe1370
ag debug
agobbifbk Feb 5, 2025
840ad5f
ag debug
agobbifbk Feb 5, 2025
bb80a99
ag debug
agobbifbk Feb 5, 2025
4edc058
ag debug
agobbifbk Feb 5, 2025
098f13c
ag debug
agobbifbk Feb 5, 2025
b69e27d
ag debug
agobbifbk Feb 5, 2025
8f6945b
ag debug
agobbifbk Feb 5, 2025
ed50987
ag debug
agobbifbk Feb 5, 2025
d0a91b7
ag debug
agobbifbk Feb 5, 2025
32f3249
ag debug
agobbifbk Feb 5, 2025
be38929
ag debug
agobbifbk Feb 5, 2025
332dfaf
ag bug fixing sam!
agobbifbk Feb 5, 2025
dc61caa
ag bug fixing
agobbifbk Feb 6, 2025
b58040c
ag bug fixing
agobbifbk Feb 6, 2025
7aa0fa0
ag bug fixing
agobbifbk Feb 6, 2025
18a8f94
ag bug fixing
agobbifbk Feb 6, 2025
e5189fa
ag bug fixing
agobbifbk Feb 6, 2025
121b8e8
ag bug fixing
agobbifbk Feb 6, 2025
7155504
ag bug fixing
agobbifbk Feb 6, 2025
9fa119d
ag bug fixing
agobbifbk Feb 6, 2025
2d29585
ag bug fixing
agobbifbk Feb 6, 2025
c91f36b
ag bug fixing
agobbifbk Feb 6, 2025
397a1bd
ag bug fixing
agobbifbk Feb 6, 2025
4d4b344
ag bug fixing
agobbifbk Feb 6, 2025
35d77ff
ag bug fixing
agobbifbk Feb 6, 2025
f6c4bba
ag bug fixing
agobbifbk Feb 6, 2025
e27e7a5
ag bug fixing
agobbifbk Feb 6, 2025
5a6641b
ag bug fixing
agobbifbk Feb 6, 2025
58b4e56
ag bug fixing
agobbifbk Feb 6, 2025
4109f10
ag bug fixing
agobbifbk Feb 6, 2025
f8ded67
ag bug fixing
agobbifbk Feb 6, 2025
fb16bef
ag bug fixing
agobbifbk Feb 6, 2025
e7ce508
ag bug fixing
agobbifbk Feb 6, 2025
b9b33db
ag bug fixing
agobbifbk Feb 6, 2025
aa3e39a
ag bf
agobbifbk Feb 7, 2025
e76ccff
ag bf
agobbifbk Feb 7, 2025
a3cdaab
ag bf
agobbifbk Feb 7, 2025
154ce37
creating v.1.1.3 version
agobbifbk Feb 11, 2025
3 changes: 2 additions & 1 deletion bash_examples/compare.py
@@ -43,7 +43,8 @@ def compare(conf:DictConfig)-> None:
pdb.set_trace()

for conf_tmp in files:

if conf_tmp.endswith('yaml') is False:
continue
beauty_string(f'Processing file: {conf_tmp}','block',VERBOSE)
conf_tmp = OmegaConf.load(conf_tmp)

1 change: 0 additions & 1 deletion bash_examples/config_etth1/config_xps.yaml
@@ -33,7 +33,6 @@ split_params:
past_steps: model_configs@past_steps
future_steps: model_configs@future_steps


train_config:
dirpath: "/home/agobbi/Projects/ExpTS/etth1"
num_workers: 0
6 changes: 3 additions & 3 deletions bash_examples/config_test/config.yaml
@@ -36,7 +36,7 @@ split_params:
scaler: 'StandardScaler()' ## or sklearn.preprocessing.StandardScaler()
train_config:
dirpath: null
num_workers: 4
num_workers: 0
auto_lr_find: false
devices: [0]
seed: 42
@@ -58,8 +58,8 @@ hydra:
launcher:
n_jobs: 2
verbose: 1
pre_dispatch: 1
batch_size: 1
pre_dispatch: 2
batch_size: 2

output_subdir: null
sweeper:
8 changes: 4 additions & 4 deletions dsipts/data_structure/data_structure.py
@@ -349,7 +349,7 @@ def create_data_loader(self,data:pd.DataFrame,
starting_point (Union[None,dict], optional): a dictionary indicating if a sample must be considered. It is checked for the first lag in the future (useful in the case your model has to predict only starting from hour 12). Defaults to None.
skip_step (int, optional): list of the categortial variables (same for past and future). Usual there is a skip of one between two saples but for debugging or training time purposes you can skip some samples. Defaults to 1.
Returns:
MyDataset: class thath extends torch.utils.data.Dataset (see utils)
MyDataset: class that extends torch.utils.data.Dataset (see utils)
keys of a batch:
y : the target variable(s)
x_num_past: the numerical past variables
@@ -788,10 +788,11 @@ def train_model(self,dirpath:str,
os.remove(os.path.join(os.path.join(dirpath,f)))
if isinstance(self.losses,dict):
self.losses = pd.DataFrame()

try:
self.model = self.model.load_from_checkpoint(self.checkpoint_file_last)
except Exception as _:
beauty_string(f'There is a problem loading the weights on file {self.checkpoint_file_last}','section',self.verbose)
beauty_string(f'There is a problem loading the weights on file MAYBE CHANGED HOW WEIGHTS ARE LOADED {self.checkpoint_file_last}','section',self.verbose)

try:
val_loss = self.losses.val_loss.values[-1]
@@ -1051,8 +1052,7 @@ def load(self,model:Base, filename:str,load_last:bool=True,dirpath:Union[str,Non
directory = self.dirpath
else:
directory = dirpath



if load_last:

try:
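The train_model change above only rewords the warning printed when reloading the last checkpoint fails; the pattern it guards is the usual "try to reload, warn and keep going" fallback. Below is a minimal sketch of that pattern with generic names (this is not the dsipts API):

import pytorch_lightning as pl

def reload_last_checkpoint(model: pl.LightningModule, checkpoint_file_last: str, verbose: bool = True) -> pl.LightningModule:
    """Reload the last checkpoint if possible; warn instead of crashing when loading fails."""
    try:
        model = model.load_from_checkpoint(checkpoint_file_last)
    except Exception:
        if verbose:
            print(f"There is a problem loading the weights on file {checkpoint_file_last} "
                  "(the way weights are saved or loaded may have changed)")
    return model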
4 changes: 3 additions & 1 deletion dsipts/data_structure/utils.py
@@ -124,7 +124,7 @@ def __init__(self, data:dict,t:np.array,groups:np.array,idx_target:Union[np.arra
idx_target: index of target features in the past array
t (np.array): the time array related to the target variables
idx_target (Union[np.array,None]): you can specify the index in the past data that represent the input features (for differntial analysis or detrending strategies)
idx_target (Union[np.array,None]): you can specify the index in the future data that represent the input features (for differntial analysis or detrending strategies)
idx_target_future (Union[np.array,None]): you can specify the index in the future data that represent the input features (for differntial analysis or detrending strategies)

Returns:
torch.utils.data.Dataset: a torch Dataset to be used in a Dataloader
@@ -135,6 +135,8 @@ def __init__(self, data:dict,t:np.array,groups:np.array,idx_target:Union[np.arra
self.idx_target = np.array(idx_target) if idx_target is not None else None
self.idx_target_future = np.array(idx_target_future) if idx_target_future is not None else None



def __len__(self):

return len(self.data['y'])
67 changes: 44 additions & 23 deletions dsipts/models/base.py
@@ -102,19 +102,21 @@ def configure_optimizers(self):

:meta private:
"""

self.has_sam_optim = False
if self.optim_config is None:
self.optim_config = {'lr': 5e-05}


if self.optim is None:
optimizer = optim.Adam(self.parameters(), **self.optim_config)
self.initialize = True

else:
if self.initialize is False:
if self.optim=='SAM':
self.has_sam_optim = True
self.automatic_optimization = False
self.my_step = 0

else:
self.optim = eval(self.optim)
@@ -141,20 +143,40 @@ def training_step(self, batch, batch_idx):

:meta private:
"""
y_hat = self(batch)
loss = self.compute_loss(batch,y_hat)
if self.has_sam_optim:

#loss = self.compute_loss(batch,y_hat)
#import pdb
#pdb.set_trace()

if self.has_sam_optim:

opt = self.optimizers()
self.manual_backward(loss)
opt.first_step(zero_grad=True)

def closure():
opt.zero_grad()
y_hat = self(batch)
loss = self.compute_loss(batch,y_hat)
self.manual_backward(loss)
return loss

opt.step(closure)
y_hat = self(batch)
loss = self.compute_loss(batch, y_hat)

self.manual_backward(loss,retain_graph=True)
opt.second_step(zero_grad=True)

loss = self.compute_loss(batch,y_hat)

#opt.first_step(zero_grad=True)

#y_hat = self(batch)
#loss = self.compute_loss(batch, y_hat)
#self.my_step+=1
#self.manual_backward(loss,retain_graph=True)
#opt.second_step(zero_grad=True)
#self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
#self.log("global_step", self.my_step, on_step=True) # Correct way to log


#self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.increment("optimizer")
else:
y_hat = self(batch)
loss = self.compute_loss(batch,y_hat)
return loss


@@ -243,7 +265,7 @@ def compute_loss(self,batch,y_hat):
#import pdb
#pdb.set_trace()
mda = (1-torch.mean( torch.sign(torch.diff(x,axis=1))*torch.sign(torch.diff(batch['y'],axis=1))))
loss = torch.mean( torch.abs(x-batch['y']).mean(axis=1).flatten()) + self.persistence_weight*mda#/10
loss = torch.mean( torch.abs(x-batch['y']).mean(axis=1).flatten()) + self.persistence_weight*mda/10



@@ -255,7 +277,6 @@ def compute_loss(self,batch,y_hat):
sinkhorn = SinkhornDistance(eps=0.1, max_iter=100, reduction='mean')
loss = sinkhorn.compute(x,batch['y'])


elif self.loss_type == 'additive_iv':
std = torch.sqrt(torch.var(batch['y'], dim=(1))+ 1e-8) ##--> BSxChannel
x_std = torch.sqrt(torch.var(x, dim=(1))+ 1e-8)
@@ -278,7 +299,7 @@ def compute_loss(self,batch,y_hat):

elif self.loss_type=='triplet':
loss_fn = torch.nn.TripletMarginLoss(margin=0.01, p=1.0,swap=False)
loss = initial_loss + self.persistence_weight*loss_fn(x, batch['y'], y_persistence)
loss = initial_loss + self.persistence_weight*loss_fn(x, batch['y'], y_persistence)

elif self.loss_type=='high_order':
loss = initial_loss
@@ -291,12 +312,12 @@ def compute_loss(self,batch,y_hat):

elif self.loss_type=='dilated':
#BxLxCxMUL
#if self.persistence_weight==0.1:
# alpha = 0.25
#if self.persistence_weight==1:
# alpha = 0.5
#else:
# alpha =0.75
if self.persistence_weight==0.1:
alpha = 0.25
if self.persistence_weight==1:
alpha = 0.5
else:
alpha =0.75
alpha = self.persistence_weight
gamma = 0.01
loss = 0
@@ -307,8 +328,8 @@ def compute_loss(self,batch,y_hat):
loss+= dilate_loss( batch['y'][:,:,i:i+1],x[:,:,i:i+1], alpha, gamma, y_hat.device)

elif self.loss_type=='huber':
#loss = torch.nn.HuberLoss(reduction='mean', delta=self.persistence_weight/10)
loss = torch.nn.HuberLoss(reduction='mean', delta=self.persistence_weight)
loss = torch.nn.HuberLoss(reduction='mean', delta=self.persistence_weight/10)
#loss = torch.nn.HuberLoss(reduction='mean', delta=self.persistence_weight)
if self.use_quantiles is False:
x = y_hat[:,:,:,0]
else:
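For context on the training_step change above: with SAM the Lightning module must switch to manual optimization and hand the optimizer a closure, so that a single optimizer step can run two forward/backward passes (one at w, one at the perturbed w + e(w)). The following is only an illustrative sketch, not the dsipts code; the module, loss function, batch keys and hyperparameters are placeholder assumptions.

import torch
import pytorch_lightning as pl
from dsipts.models.samformer.utils import SAM  # SAM wrapper as defined later in this diff

class SAMSketchModule(pl.LightningModule):
    """Minimal sketch of the closure-based manual optimization used with SAM."""

    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(8, 1)
        self.automatic_optimization = False  # SAM needs two backward passes per batch

    def configure_optimizers(self):
        # Extra kwargs (here lr) are forwarded to the base optimizer.
        return SAM(self.parameters(), base_optimizer=torch.optim.Adam, rho=0.05, lr=5e-5)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()

        def closure():
            opt.zero_grad()
            loss = torch.nn.functional.mse_loss(self.net(batch["x"]), batch["y"])
            self.manual_backward(loss)
            return loss

        # SAM.step calls the closure twice: once to compute the ascent direction,
        # once at the perturbed weights before the base optimizer update.
        opt.step(closure)

        # Recompute the loss at the updated weights purely for logging.
        with torch.no_grad():
            loss = torch.nn.functional.mse_loss(self.net(batch["x"]), batch["y"])
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss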
67 changes: 27 additions & 40 deletions dsipts/models/samformer/utils.py
@@ -85,13 +85,7 @@ def _denormalize(self, x):
return x



class SAM(Optimizer):
"""
SAM: Sharpness-Aware Minimization for Efficiently Improving Generalization https://arxiv.org/abs/2010.01412
https://github.com/davda54/sam
"""

class SAM(torch.optim.Optimizer):
def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

@@ -100,6 +94,7 @@ def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):

self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
self.param_groups = self.base_optimizer.param_groups
self.defaults.update(self.base_optimizer.defaults)

@torch.no_grad()
def first_step(self, zero_grad=False):
@@ -110,13 +105,9 @@ def first_step(self, zero_grad=False):
for p in group["params"]:
if p.grad is None:
continue
e_w = (
(torch.pow(p, 2) if group["adaptive"] else 1.0)
* p.grad
* scale.to(p)
)
p.add_(e_w) # climb to the local maximum "w + e(w)"
self.state[p]["e_w"] = e_w
self.state[p]["old_p"] = p.data.clone()
e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
p.add_(e_w) # Perturb weights in the gradient direction

if zero_grad:
self.zero_grad()
@@ -127,41 +118,37 @@ def second_step(self, zero_grad=False):
for p in group["params"]:
if p.grad is None:
continue
p.sub_(self.state[p]["e_w"]) # get back to "w" from "w + e(w)"
p.data = self.state[p]["old_p"] # Restore original weights

self.base_optimizer.step() # do the actual "sharpness-aware" update
self.base_optimizer.step() # Apply the sharpness-aware update

if zero_grad:
self.zero_grad()

@torch.no_grad()
def step(self, closure=None):
assert (
closure is not None
), "Sharpness Aware Minimization requires closure, but it was not provided"
closure = torch.enable_grad()(
closure
) # the closure should do a full forward-backward pass
assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"

with torch.enable_grad():
closure() # First forward-backward pass

self.first_step(zero_grad=True)
closure()

with torch.enable_grad():
closure() # Second forward-backward pass

self.second_step()

def _grad_norm(self):
shared_device = self.param_groups[0]["params"][
0
].device # put everything on the same device, in case of model parallelism
norm = torch.norm(
torch.stack(
[
((torch.abs(p) if group["adaptive"] else 1.0) * p.grad)
.norm(p=2)
.to(shared_device)
for group in self.param_groups
for p in group["params"]
if p.grad is not None
]
),
p=2,
)
return norm
shared_device = self.param_groups[0]["params"][0].device
grads = [
((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
for group in self.param_groups for p in group["params"]
if p.grad is not None
]
return torch.norm(torch.stack(grads), p=2) if grads else torch.tensor(0.0, device=shared_device)

def load_state_dict(self, state_dict):
super().load_state_dict(state_dict)
if hasattr(self, "base_optimizer"): # Ensure base optimizer exists
self.base_optimizer.load_state_dict(state_dict["base_optimizer"])
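As a usage note on the SAM class above: the base optimizer is passed as a class, every extra keyword argument is forwarded to it, and step() must receive a closure that redoes the forward/backward pass. Below is a minimal sketch in a plain PyTorch loop; the import path is inferred from the file layout in this diff, and the model and data are placeholders.

import torch
from dsipts.models.samformer.utils import SAM  # path as it appears in this diff

model = torch.nn.Linear(16, 1)
# rho sets the radius of the ascent step; lr and momentum are forwarded to the base optimizer.
optimizer = SAM(model.parameters(), base_optimizer=torch.optim.SGD, rho=0.05, lr=0.01, momentum=0.9)

x, y = torch.randn(32, 16), torch.randn(32, 1)

for _ in range(10):
    def closure():
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        return loss

    # step() calls the closure twice: first to compute the perturbation e(w),
    # then at w + e(w) before the base optimizer applies the actual update.
    optimizer.step(closure)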
4 changes: 2 additions & 2 deletions notebooks/1-monash_timeseries.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion notebooks/2-venice_dataset.ipynb
@@ -685797,7 +685797,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.9.19"
}
},
"nbformat": 4,
2 changes: 1 addition & 1 deletion setup.py
@@ -4,7 +4,7 @@

setup(
name="dsipts",
version="1.1.2",
version="1.1.3",
author="Andrea Gobbi",
author_email="[email protected]",
packages=find_packages(exclude=("tests",)),
2 changes: 1 addition & 1 deletion setup_local.py
@@ -6,7 +6,7 @@

setup(
name="dsipts",
version="1.1.2",
version="1.1.3",
author="Andrea Gobbi",
author_email="[email protected]",
packages=find_packages(exclude=("tests",)),