diff --git a/crates/core/src/inline_snippets.rs b/crates/core/src/inline_snippets.rs index 0044c8d61..f231b342b 100644 --- a/crates/core/src/inline_snippets.rs +++ b/crates/core/src/inline_snippets.rs @@ -300,6 +300,27 @@ fn delete_hanging_comma( replacements: &mut [(EffectRange, String)], offset: usize, ) -> Result<(String, Vec)> { + // Replay each replacement on a scratch copy of the code; if one leaves + // behind a line holding nothing but a comma, that comma must go too + let mut temp_code = code.to_string(); + for (range, snippet) in replacements.iter_mut() { + let adjusted_range = adjust_range(&range.effective_range(), offset, &temp_code)?; + if adjusted_range.start > temp_code.len() || adjusted_range.end > temp_code.len() { + bail!("Range {:?} is out of bounds for code:\n{}\n", adjusted_range, temp_code); + } + + temp_code.replace_range(adjusted_range.clone(), snippet); + + let line_start = temp_code[..adjusted_range.start].rfind('\n').map_or(0, |pos| pos + 1); + let line_end = temp_code[adjusted_range.start..].find('\n').map_or(temp_code.len(), |pos| adjusted_range.start + pos); + let line_content = temp_code[line_start..line_end].trim(); + + if line_content == "," { + // widen this replacement's range by one to take in the "," (NOTE(review): assumes the comma sits right at the range end - confirm) + range.range.end += 1; + } + } + let deletion_ranges = replacements .iter() .filter_map(|r| { diff --git a/crates/core/src/test.rs b/crates/core/src/test.rs index a76b46d2b..a47e47bb7 100644 --- a/crates/core/src/test.rs +++ b/crates/core/src/test.rs @@ -12014,6 +12014,105 @@ fn trailing_comma_import_from_python_with_alias() { .unwrap(); } +// Removing an argument must not leave a lone comma behind; see https://github.com/getgrit/gritql/issues/416 +#[test] +fn trailing_comma_after_argument_removal() { + run_test_expected({ + TestArgExpected { + pattern: r#" + language python + `TaskMetadata($args)` where { + $args <: any { + contains `n_samples=$_` as $ns_kwarg where { + $ns_kwarg <: `n_samples = $ns_val` => . 
+ }, + contains `avg_character_length=$_` as $avg_kwarg where { + $avg_kwarg <: `avg_character_length = $avg_val` => `stats=GeneralDescriptiveStats(n_samples=$ns_val, avg_character_length=$avg_val)` + }, + }, + } + "#.to_owned(), + source: r#" + from pydantic import BaseModel + + + class TaskMetadata(BaseModel): + n_samples: dict[str, int] + avg_character_length: dict[str, float] + + + if __name__ == "__main__": + TaskMetadata( + name="TbilisiCityHallBitextMining", + dataset={ + "path": "jupyterjazz/tbilisi-city-hall-titles", + "revision": "798bb599140565cca2dab8473035fa167e5ee602", + }, + description="Parallel news titles from the Tbilisi City Hall website (https://tbilisi.gov.ge/).", + type="BitextMining", + category="s2s", + eval_splits=[_EVAL_SPLIT], + eval_langs=_EVAL_LANGS, + main_score="f1", + domains=["News"], + text_creation="created", + n_samples={_EVAL_SPLIT: 1820}, + reference="https://huggingface.co/datasets/jupyterjazz/tbilisi-city-hall-titles", + date=("2024-05-02", "2024-05-03"), + form=["written"], + task_subtypes=[], + license="Not specified", + socioeconomic_status="mixed", + annotations_creators="derived", + dialect=[], + bibtex_citation="", + avg_character_length={_EVAL_SPLIT: 78}, + ) + "# + .to_owned(), + expected: r#" + from pydantic import BaseModel + + + class TaskMetadata(BaseModel): + n_samples: dict[str, int] + avg_character_length: dict[str, float] + + + if __name__ == "__main__": + TaskMetadata( + name="TbilisiCityHallBitextMining", + dataset={ + "path": "jupyterjazz/tbilisi-city-hall-titles", + "revision": "798bb599140565cca2dab8473035fa167e5ee602", + }, + description="Parallel news titles from the Tbilisi City Hall website (https://tbilisi.gov.ge/).", + type="BitextMining", + category="s2s", + eval_splits=[_EVAL_SPLIT], + eval_langs=_EVAL_LANGS, + main_score="f1", + domains=["News"], + text_creation="created", + + reference="https://huggingface.co/datasets/jupyterjazz/tbilisi-city-hall-titles", + date=("2024-05-02", "2024-05-03"), + 
form=["written"], + task_subtypes=[], + license="Not specified", + socioeconomic_status="mixed", + annotations_creators="derived", + dialect=[], + bibtex_citation="", + stats=GeneralDescriptiveStats(n_samples={_EVAL_SPLIT: 1820}, avg_character_length={_EVAL_SPLIT: 78}), + ) + "# + .to_owned(), + } + }) + .unwrap(); +} + #[test] fn python_orphaned_from_imports() { run_test_expected({