Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 94 additions & 2 deletions mlforecast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,9 +862,99 @@ def load(path: Union[str, Path], protocol: Optional[str] = None) -> "TimeSeries"
with fsspec.open(path, "rb", protocol=protocol) as f:
ts = cloudpickle.load(f)
return ts

def _validate_new_df(self, df: DataFrame) -> None:
    """Check that an update continues the stored series without gaps.

    Two checks are run, in this order:

    1. Start date: for every id in ``df`` that is also in ``self.uids``, the
       earliest timestamp in ``df`` must equal that id's entry in
       ``self.last_dates`` offset by one ``self.freq`` step. Ids that are not
       in ``self.uids`` have no expected start and skip this check.
    2. Continuity: within ``df``, each timestamp offset by one step must
       equal the next timestamp of the same id. This rejects both gaps and
       duplicated timestamps.

    Args:
        df: New observations (pandas or polars) containing ``self.id_col``
            and ``self.time_col``.

    Raises:
        ValueError: If any series fails the start-date check, or if any
            series has a gap or duplicate timestamp. The message lists the
            offending ids (as strings for the start-date check).
    """
    # The block already builds frames through the ``pl_DataFrame`` alias, so
    # use it for the type check as well. NOTE(review): presumably ``pl`` is a
    # placeholder (e.g. None) when polars is not installed, in which case
    # ``pl.DataFrame`` would raise before reaching the pandas path -- confirm
    # against the compat module that provides these names.
    if isinstance(df, pl_DataFrame):
        df = df.sort([self.id_col, self.time_col])

        # Earliest new timestamp per id; sorted so the reported ids come out
        # in a deterministic order (group_by alone does not guarantee one).
        first_times = (
            df.group_by(self.id_col)
            .agg(pl.col(self.time_col).min().alias("_min"))
            .sort(self.id_col)
        )
        last_dates_df = pl_DataFrame(
            {self.id_col: self.uids, "_last": self.last_dates}
        )
        expected_start = ufp.offset_times(last_dates_df["_last"], self.freq, 1)
        expected_df = last_dates_df.with_columns(
            pl.Series(name="_expected_start", values=expected_start)
        ).select([self.id_col, "_expected_start"])

        # Both sides are joined on the string form of the id; presumably this
        # keeps the join independent of the id dtype on either side.
        first_times = first_times.with_columns(
            pl.col(self.id_col).cast(pl.Utf8)
        )
        expected_df = expected_df.with_columns(
            pl.col(self.id_col).cast(pl.Utf8)
        )
        first_times = first_times.join(expected_df, on=self.id_col, how="left")

        # A null expected start means the id is not stored yet: nothing to
        # compare against.
        bad_starts = first_times.filter(
            pl.col("_expected_start").is_not_null()
            & (pl.col("_min") != pl.col("_expected_start"))
        )
        if bad_starts.height:
            bad_ids = bad_starts[self.id_col].to_list()
            raise ValueError(
                "Series have invalid start dates. "
                f"Expected start at last_date + freq for: {bad_ids}."
            )

        expected_next = ufp.offset_times(df[self.time_col], self.freq, 1)
        df_check = df.with_columns(
            pl.Series(name="_expected_next", values=expected_next)
        ).with_columns(
            pl.col(self.time_col).shift(-1).over(self.id_col).alias("_next")
        )
        # The last row of each id has no successor (null ``_next``) and is
        # excluded from the comparison.
        gaps = df_check.filter(
            pl.col("_next").is_not_null()
            & (pl.col("_expected_next") != pl.col("_next"))
        )
        if gaps.height:
            bad_ids = gaps[self.id_col].unique().to_list()
            raise ValueError(
                "Found gaps or duplicate timestamps in the update for: "
                f"{bad_ids}."
            )
        return

    # pandas path: same two checks as above.
    df = df.sort_values([self.id_col, self.time_col])

    # Earliest new timestamp per id.
    first_times = (
        df.groupby(self.id_col, observed=True)[self.time_col]
        .min()
        .rename("_min")
        .reset_index()
    )
    last_dates_df = pd.DataFrame(
        {self.id_col: self.uids, "_last": self.last_dates}
    )
    expected_start = ufp.offset_times(last_dates_df["_last"], self.freq, 1)
    expected_df = pd.DataFrame(
        {
            self.id_col: last_dates_df[self.id_col],
            "_expected_start": expected_start,
        }
    )

    # Merge on the string form of the id (see the polars branch).
    first_times[self.id_col] = first_times[self.id_col].astype(str)
    expected_df[self.id_col] = expected_df[self.id_col].astype(str)
    first_times = first_times.merge(expected_df, on=self.id_col, how="left")

    # Ids without a stored last date get NaT after the left merge and are
    # skipped.
    start_mismatch = first_times["_expected_start"].notna() & (
        first_times["_min"] != first_times["_expected_start"]
    )
    if start_mismatch.any():
        bad_ids = first_times.loc[start_mismatch, self.id_col].tolist()
        raise ValueError(
            "Series have invalid start dates. "
            f"Expected start at last_date + freq for: {bad_ids}."
        )

    expected_next = ufp.offset_times(df[self.time_col], self.freq, 1)
    next_time = df.groupby(self.id_col, observed=True)[self.time_col].shift(-1)
    # The last row of each id has no successor (NaT) and is excluded.
    gaps = next_time.notna() & (expected_next != next_time)
    if gaps.any():
        bad_ids = df.loc[gaps, self.id_col].unique().tolist()
        raise ValueError(
            "Found gaps or duplicate timestamps in the update for: "
            f"{bad_ids}."
        )

def update(self, df: DataFrame) -> None:
"""Update the values of the stored series."""
def update(self, df: DataFrame, validate_input: bool = False) -> None:
"""Update the values of the stored series.

Args:
df: New observations to append.
validate_input: If True, validate continuity, start dates, and frequency.
"""
validate_format(df, self.id_col, self.time_col, self.target_col)
uids = self.uids
if isinstance(uids, pd.Index):
Expand All @@ -875,6 +965,8 @@ def update(self, df: DataFrame) -> None:
df = ufp.sort(df, by=[self.id_col, self.time_col])
values = df[self.target_col].to_numpy()
values = values.astype(self.ga.data.dtype, copy=False)
if validate_input:
Comment thread
nasaul marked this conversation as resolved.
self._validate_new_df(df=df)
id_counts = ufp.counts_by_id(df, self.id_col)
try:
sizes = ufp.join(uids, id_counts, on=self.id_col, how="outer_coalesce")
Expand Down
5 changes: 3 additions & 2 deletions mlforecast/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,10 +961,11 @@ def load(path: Union[str, Path]) -> "MLForecast":
fcst._cs_df = intervals["scores"]
return fcst

def update(self, df: DataFrame, validate_input: bool = False) -> None:
    """Update the values of the stored series.

    Delegates to ``self.ts.update``.

    Args:
        df (pandas or polars DataFrame): Dataframe with new observations.
        validate_input (bool): If True, validate continuity, start dates, and frequency.
            Defaults to False.
    """
    # Forward the flag by keyword so the call does not depend on the
    # positional order of the underlying update's parameters.
    self.ts.update(df, validate_input=validate_input)
Loading