Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 119 additions & 2 deletions mlforecast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,9 +862,124 @@ def load(path: Union[str, Path], protocol: Optional[str] = None) -> "TimeSeries"
with fsspec.open(path, "rb", protocol=protocol) as f:
ts = cloudpickle.load(f)
return ts

def _validate_new_df(self, df: DataFrame) -> None:
    """Validate that `df` is a contiguous continuation of the stored series.

    For every id present in `df`, checks that:
      * its first timestamp is exactly ``last_date + freq`` (only for ids
        already stored; new ids are exempt from this check),
      * all timestamps are aligned to the configured frequency,
      * there are no gaps or duplicate timestamps.

    Args:
        df (pandas or polars DataFrame): New observations to validate.

    Raises:
        ValueError: If any series starts at the wrong date, has timestamps
            not aligned to the frequency, or contains gaps/duplicates.
    """
    if isinstance(df, pl.DataFrame):
        # Per-serie first/last timestamp and row count.
        stats = (
            df.group_by(self.id_col)
            .agg(
                pl.col(self.time_col).min().alias("_min"),
                pl.col(self.time_col).max().alias("_max"),
                pl.len().alias("_size"),
            )
            .sort(self.id_col)
        )
        last_dates_df = pl_DataFrame(
            {self.id_col: self.uids, "_last": self.last_dates}
        )
        # Each known serie must resume exactly one frequency step after
        # its stored last date.
        expected_start = ufp.offset_times(last_dates_df["_last"], self.freq, 1)
        expected_df = last_dates_df.with_columns(
            pl.Series(name="_expected_start", values=expected_start)
        ).select([self.id_col, "_expected_start"])
        # Cast ids to strings so both sides of the join share a dtype.
        stats = stats.with_columns(pl.col(self.id_col).cast(pl.Utf8))
        expected_df = expected_df.with_columns(pl.col(self.id_col).cast(pl.Utf8))
        stats = stats.join(expected_df, on=self.id_col, how="left")
        # Ids missing from the stored series get a null expected start and
        # are therefore skipped by this filter.
        bad_starts = stats.filter(
            pl.col("_expected_start").is_not_null()
            & (pl.col("_min") != pl.col("_expected_start"))
        )
        if bad_starts.height:
            bad_ids = bad_starts[self.id_col].to_list()
            raise ValueError(
                "Series have invalid start dates. "
                f"Expected start at last_date + freq for: {bad_ids}."
            )
        if isinstance(self.freq, int):
            diffs = pl.col("_max") - pl.col("_min")
            misaligned = stats.filter((diffs % self.freq) != 0)
            if misaligned.height:
                raise ValueError(
                    "Found timestamps not aligned to the configured frequency."
                )
            expected_count = diffs // self.freq + 1
        else:
            # Work in integer nanoseconds so the span can be compared
            # against the frequency with exact modular arithmetic.
            delta = pd.Timedelta(pd.tseries.frequencies.to_offset(self.freq))
            delta_ns = delta.value
            min_ns = pl.col("_min").dt.timestamp("ns")
            max_ns = pl.col("_max").dt.timestamp("ns")
            diffs_ns = max_ns - min_ns
            misaligned = stats.filter((diffs_ns % delta_ns) != 0)
            if misaligned.height:
                raise ValueError(
                    "Found timestamps not aligned to the configured frequency."
                )
            expected_count = diffs_ns // delta_ns + 1
        # A contiguous, duplicate-free serie has exactly this many rows.
        gaps = stats.filter(expected_count != pl.col("_size"))
        if gaps.height:
            bad_ids = gaps[self.id_col].to_list()
            raise ValueError(
                "Found gaps or duplicate timestamps in the update for: "
                f"{bad_ids}."
            )
        return
    # pandas path: same checks as the polars branch above.
    stats = (
        df.groupby(self.id_col, observed=True)[self.time_col]
        .agg(["min", "max", "size"])
        .rename(columns={"min": "_min", "max": "_max", "size": "_size"})
        .reset_index()
    )
    last_dates_df = pd.DataFrame(
        {self.id_col: self.uids, "_last": self.last_dates}
    )
    expected_start = ufp.offset_times(last_dates_df["_last"], self.freq, 1)
    expected_df = pd.DataFrame(
        {self.id_col: last_dates_df[self.id_col], "_expected_start": expected_start}
    )
    # Normalize id dtypes (e.g. categorical vs object) before merging.
    stats[self.id_col] = stats[self.id_col].astype(str)
    expected_df[self.id_col] = expected_df[self.id_col].astype(str)
    stats = stats.merge(expected_df, on=self.id_col, how="left")
    # New ids have NaN expected starts and are exempt from this check.
    start_mismatch = stats["_expected_start"].notna() & (
        stats["_min"] != stats["_expected_start"]
    )
    if start_mismatch.any():
        bad_ids = stats.loc[start_mismatch, self.id_col].tolist()
        raise ValueError(
            "Series have invalid start dates. "
            f"Expected start at last_date + freq for: {bad_ids}."
        )
    diffs = stats["_max"] - stats["_min"]
    if isinstance(self.freq, int):
        delta = self.freq
        remainder = diffs % delta
        if (remainder != 0).any():
            raise ValueError(
                "Found timestamps not aligned to the configured frequency."
            )
        expected_count = diffs // delta + 1
    else:
        offset = pd.tseries.frequencies.to_offset(self.freq)
        delta = pd.Timedelta(offset)
        remainder = diffs % delta
        if (remainder != pd.Timedelta(0)).any():
            raise ValueError(
                "Found timestamps not aligned to the configured frequency."
            )
        expected_count = diffs // delta + 1
    # A contiguous, duplicate-free serie has exactly this many rows.
    gaps = expected_count != stats["_size"]
    if gaps.any():
        bad_ids = stats.loc[gaps, self.id_col].tolist()
        raise ValueError(
            "Found gaps or duplicate timestamps in the update for: "
            f"{bad_ids}."
        )

def update(self, df: DataFrame, validate_input: bool = False) -> None:
"""Update the values of the stored series.

def update(self, df: DataFrame) -> None:
"""Update the values of the stored series."""
Args:
df: New observations to append.
validate_input: If True, validate continuity, start dates, and frequency.
"""
validate_format(df, self.id_col, self.time_col, self.target_col)
uids = self.uids
if isinstance(uids, pd.Index):
Expand All @@ -875,6 +990,8 @@ def update(self, df: DataFrame) -> None:
df = ufp.sort(df, by=[self.id_col, self.time_col])
values = df[self.target_col].to_numpy()
values = values.astype(self.ga.data.dtype, copy=False)
if validate_input:
Comment thread
nasaul marked this conversation as resolved.
self._validate_new_df(df=df)
id_counts = ufp.counts_by_id(df, self.id_col)
try:
sizes = ufp.join(uids, id_counts, on=self.id_col, how="outer_coalesce")
Expand Down
5 changes: 3 additions & 2 deletions mlforecast/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,10 +961,11 @@ def load(path: Union[str, Path]) -> "MLForecast":
fcst._cs_df = intervals["scores"]
return fcst

def update(self, df: DataFrame, validate_input: bool = False) -> None:
    """Update the values of the stored series.

    Args:
        df (pandas or polars DataFrame): Dataframe with new observations.
        validate_input (bool): If True, validate continuity, start dates, and frequency.
    """
    # Delegate to the underlying TimeSeries, forwarding the validation flag.
    self.ts.update(df, validate_input)
Loading
Loading