Dev #463 (Merged)

Commits (51)
4fcc094
Fresh tpu start
VikParuchuri Aug 4, 2025
ab9eff4
Static cache impl
VikParuchuri Aug 4, 2025
dd7b127
Add in original codepath
VikParuchuri Aug 4, 2025
d9f6e4c
Fix issues with GPU codepaths
VikParuchuri Aug 4, 2025
a1aa155
Fix embedding with a static scatter
VikParuchuri Aug 5, 2025
2c60d24
Cleanup debug logs
VikParuchuri Aug 5, 2025
3b30120
Enable compile
VikParuchuri Aug 5, 2025
d4461c6
Fix mark steps
VikParuchuri Aug 6, 2025
768d8d5
Move layout
VikParuchuri Aug 6, 2025
523bd66
Merge branch 'vik/layout' into vik/tpu3
VikParuchuri Aug 6, 2025
0600fc5
Enable re-embedding bboxes
VikParuchuri Aug 6, 2025
8d1ef85
Merge remote-tracking branch 'origin/vik/tpu-layout' into vik/tpu-layout
VikParuchuri Aug 6, 2025
185b57a
Cleanup
VikParuchuri Aug 6, 2025
e1df24c
Cleanup embedding
VikParuchuri Aug 6, 2025
669ce48
Patch clamp issue
VikParuchuri Aug 7, 2025
d55f00a
Integrate table rec predictor
VikParuchuri Aug 8, 2025
f03b58b
Fix table rec
VikParuchuri Aug 8, 2025
eee29d4
Fix beacon issue
VikParuchuri Aug 11, 2025
8367a63
Accuracy fixes
VikParuchuri Aug 11, 2025
2748109
Fix encoder chunking
VikParuchuri Aug 11, 2025
de94700
Fix text lengths
VikParuchuri Aug 11, 2025
fc6657e
Use fixed-length index
VikParuchuri Aug 12, 2025
9e5fa29
Wire in table structure
VikParuchuri Aug 12, 2025
a511a09
Pad image embeddings
VikParuchuri Aug 12, 2025
609caf4
Fix tensor creation
VikParuchuri Aug 12, 2025
f2eecf1
Properly pad
VikParuchuri Aug 12, 2025
14e7ee6
Avoid truncating layout and table
VikParuchuri Aug 15, 2025
cbe23fa
Tables can have a lot of cells
VikParuchuri Aug 15, 2025
a73eee6
Force bf16
VikParuchuri Aug 15, 2025
053f13c
Fix padding on tpu
VikParuchuri Aug 15, 2025
4613f45
Prefill fix
VikParuchuri Aug 18, 2025
a95b6ca
Fix layout and table rec image bbox
VikParuchuri Aug 19, 2025
e1aa09d
Set disable tqdm
VikParuchuri Aug 19, 2025
d6f3515
feat: new unified tokenizer
zanussbaum Aug 25, 2025
a8a0214
Merge pull request #442 from datalab-to/new-tokenizer
VikParuchuri Aug 30, 2025
3b1e8dc
Add bbox head
VikParuchuri Sep 8, 2025
0a4068b
Iterate on bbox head
VikParuchuri Sep 8, 2025
2d5dd9b
Move back to old table_rec model for now
tarun-menta Sep 17, 2025
9673ec6
Merge in recognition predictor changes from dev
tarun-menta Sep 17, 2025
74e790c
Tokenizer fix
tarun-menta Sep 17, 2025
20f9179
Pin sliding window for layout
tarun-menta Sep 18, 2025
9ab25b3
Merge branch 'dev' into layout-release
tarun-menta Sep 19, 2025
42e016f
Update tqdm desc string based on foundation model mode
tarun-menta Sep 19, 2025
c1719e9
Update loaders with `dtype` instead of `torch_dtype` -- transformers
tarun-menta Sep 22, 2025
5811d07
Separate models for layout and OCR
tarun-menta Sep 23, 2025
4d7be66
Models moved to S3
tarun-menta Sep 23, 2025
1d09025
Bump foundation checkpoint
tarun-menta Sep 23, 2025
9bee27c
Fix tests
tarun-menta Sep 23, 2025
eb179cc
Update layout batch sizes
tarun-menta Sep 23, 2025
d3aecc0
Pick correct dtype on T4 GPUs
tarun-menta Sep 23, 2025
466aba7
Merge pull request #461 from datalab-to/layout-release
tarun-menta Sep 23, 2025
README.md (4 changes: 3 additions & 1 deletion)
@@ -219,10 +219,12 @@ Setting the `LAYOUT_BATCH_SIZE` env var properly will make a big difference when
 
 ```python
 from PIL import Image
+from surya.foundation import FoundationPredictor
 from surya.layout import LayoutPredictor
+from surya.settings import settings
 
 image = Image.open(IMAGE_PATH)
-layout_predictor = LayoutPredictor()
+layout_predictor = LayoutPredictor(FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT))
 
 # layout_predictions is a list of dicts, one per image
 layout_predictions = layout_predictor([image])
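Note on the README change: layout now runs on the shared foundation model, so the predictor is built in two steps. A minimal usage sketch of the new pattern, assuming the prediction objects expose `.bboxes` entries with `.label` and `.bbox` attributes (as the benchmark code below uses); the input path is hypothetical:

```python
from PIL import Image

from surya.foundation import FoundationPredictor
from surya.layout import LayoutPredictor
from surya.settings import settings

# Load the layout-specific checkpoint into the shared foundation model,
# then hand that predictor to the layout predictor.
foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
layout_predictor = LayoutPredictor(foundation_predictor)

image = Image.open("page.png")  # hypothetical input image
layout_predictions = layout_predictor([image])

for pred in layout_predictions:
    for box in pred.bboxes:
        print(box.label, box.bbox)
```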
benchmark/layout.py (84 changes: 65 additions & 19 deletions)
@@ -5,6 +5,7 @@
 import click
 
 from benchmark.utils.metrics import precision_recall
+from surya.foundation import FoundationPredictor
 from surya.layout import LayoutPredictor
 from surya.input.processing import convert_if_not_rgb
 from surya.debug.draw import draw_bboxes_on_image
@@ -16,15 +17,28 @@
 
 
 @click.command(help="Benchmark surya layout model.")
-@click.option("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
-@click.option("--max_rows", type=int, help="Maximum number of images to run benchmark on.", default=100)
+@click.option(
+    "--results_dir",
+    type=str,
+    help="Path to JSON file with OCR results.",
+    default=os.path.join(settings.RESULT_DIR, "benchmark"),
+)
+@click.option(
+    "--max_rows",
+    type=int,
+    help="Maximum number of images to run benchmark on.",
+    default=100,
+)
 @click.option("--debug", is_flag=True, help="Run in debug mode.", default=False)
 def main(results_dir: str, max_rows: int, debug: bool):
-    layout_predictor = LayoutPredictor()
+    foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
+    layout_predictor = LayoutPredictor(foundation_predictor)
 
     pathname = "layout_bench"
     # These have already been shuffled randomly, so sampling from the start is fine
-    dataset = datasets.load_dataset(settings.LAYOUT_BENCH_DATASET_NAME, split=f"train[:{max_rows}]")
+    dataset = datasets.load_dataset(
+        settings.LAYOUT_BENCH_DATASET_NAME, split=f"train[:{max_rows}]"
+    )
     images = list(dataset["image"])
     images = convert_if_not_rgb(images)
 
@@ -39,12 +53,23 @@ def main(results_dir: str, max_rows: int, debug: bool):
     result_path = os.path.join(results_dir, folder_name)
     os.makedirs(result_path, exist_ok=True)
 
-    label_alignment = { # First is publaynet, second is surya
+    label_alignment = {  # First is publaynet, second is surya
         "Image": [["Figure"], ["Picture", "Figure"]],
         "Table": [["Table"], ["Table", "Form", "TableOfContents"]],
-        "Text": [["Text"], ["Text", "Formula", "Footnote", "Caption", "TextInlineMath", "Code", "Handwriting"]],
+        "Text": [
+            ["Text"],
+            [
+                "Text",
+                "Formula",
+                "Footnote",
+                "Caption",
+                "TextInlineMath",
+                "Code",
+                "Handwriting",
+            ],
+        ],
         "List": [["List"], ["ListItem"]],
-        "Title": [["Title"], ["SectionHeader", "Title"]]
+        "Title": [["Title"], ["SectionHeader", "Title"]],
     }
 
     page_metrics = collections.OrderedDict()
@@ -54,55 +79,76 @@ def main(results_dir: str, max_rows: int, debug: bool):
         page_results = {}
         for label_name in label_alignment:
             correct_cats, surya_cats = label_alignment[label_name]
-            correct_bboxes = [b for b, l in zip(row["bboxes"], row["labels"]) if l in correct_cats]
+            correct_bboxes = [
+                b
+                for b, category in zip(row["bboxes"], row["labels"])
+                if category in correct_cats
+            ]
             all_correct_bboxes.extend(correct_bboxes)
             pred_bboxes = [b.bbox for b in pred.bboxes if b.label in surya_cats]
 
-            metrics = precision_recall(pred_bboxes, correct_bboxes, penalize_double=False)
+            metrics = precision_recall(
+                pred_bboxes, correct_bboxes, penalize_double=False
+            )
             weight = len(correct_bboxes)
             metrics["weight"] = weight
             page_results[label_name] = metrics
 
         page_metrics[idx] = page_results
 
         if debug:
-            bbox_image = draw_bboxes_on_image(all_correct_bboxes, copy.deepcopy(images[idx]))
+            bbox_image = draw_bboxes_on_image(
+                all_correct_bboxes, copy.deepcopy(images[idx])
+            )
             bbox_image.save(os.path.join(result_path, f"{idx}_layout.png"))
 
     mean_metrics = collections.defaultdict(dict)
     layout_types = sorted(page_metrics[0].keys())
     metric_types = sorted(page_metrics[0][layout_types[0]].keys())
     metric_types.remove("weight")
-    for l in layout_types:
+    for label in layout_types:
         for m in metric_types:
             metric = []
             total = 0
             for page in page_metrics:
-                metric.append(page_metrics[page][l][m] * page_metrics[page][l]["weight"])
-                total += page_metrics[page][l]["weight"]
+                metric.append(
+                    page_metrics[page][label][m] * page_metrics[page][label]["weight"]
+                )
+                total += page_metrics[page][label]["weight"]
 
             value = sum(metric)
             if value > 0:
                 value /= total
-            mean_metrics[l][m] = value
+            mean_metrics[label][m] = value
 
     out_data = {
         "time": surya_time,
         "metrics": mean_metrics,
-        "page_metrics": page_metrics
+        "page_metrics": page_metrics,
     }
 
     with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
         json.dump(out_data, f, indent=4)
 
-    table_headers = ["Layout Type", ] + metric_types
+    table_headers = [
+        "Layout Type",
+    ] + metric_types
     table_data = []
     for layout_type in layout_types:
-        table_data.append([layout_type, ] + [f"{mean_metrics[layout_type][m]:.5f}" for m in metric_types])
+        table_data.append(
+            [
+                layout_type,
+            ]
+            + [f"{mean_metrics[layout_type][m]:.5f}" for m in metric_types]
+        )
 
     print(tabulate(table_data, headers=table_headers, tablefmt="github"))
-    print(f"Took {surya_time / len(images):.5f} seconds per image, and {surya_time:.5f} seconds total.")
-    print("Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold.")
+    print(
+        f"Took {surya_time / len(images):.5f} seconds per image, and {surya_time:.5f} seconds total."
+    )
+    print(
+        "Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold."
+    )
     print(f"Wrote results to {result_path}")
 
 
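For clarity on the aggregation in the hunk above: per-label precision/recall are averaged across pages with each page weighted by its ground-truth box count, so pages with more boxes of a label count proportionally more. A standalone sketch of that weighting (hypothetical helper `weighted_mean`, for illustration only, not part of the diff):

```python
# Weighted mean of per-page metrics: each page contributes metric * weight,
# where weight is the number of ground-truth boxes for the label.
# Labels with zero total weight yield 0.0, matching the benchmark's guard.
def weighted_mean(per_page: list[tuple[float, int]]) -> float:
    total = sum(weight for _, weight in per_page)
    if total == 0:
        return 0.0
    return sum(metric * weight for metric, weight in per_page) / total

# e.g. weighted_mean([(0.9, 10), (0.5, 2)]) == (9.0 + 1.0) / 12
```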
benchmark/ordering.py (29 changes: 20 additions & 9 deletions)
@@ -3,6 +3,7 @@
 
 import click
 
+from surya.foundation import FoundationPredictor
 from surya.input.processing import convert_if_not_rgb
 from surya.layout import LayoutPredictor
 from surya.common.polygon import PolygonBox
@@ -14,10 +15,21 @@
 
 
 @click.command(help="Benchmark surya layout for reading order.")
-@click.option("--results_dir", type=str, help="Path to JSON file with benchmark results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
-@click.option("--max_rows", type=int, help="Maximum number of images to run benchmark on.", default=None)
+@click.option(
+    "--results_dir",
+    type=str,
+    help="Path to JSON file with benchmark results.",
+    default=os.path.join(settings.RESULT_DIR, "benchmark"),
+)
+@click.option(
+    "--max_rows",
+    type=int,
+    help="Maximum number of images to run benchmark on.",
+    default=None,
+)
 def main(results_dir: str, max_rows: int):
-    layout_predictor = LayoutPredictor()
+    foundation_predictor = FoundationPredictor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
+    layout_predictor = LayoutPredictor(foundation_predictor)
     pathname = "order_bench"
     # These have already been shuffled randomly, so sampling from the start is fine
     split = "train"
@@ -53,10 +65,7 @@ def main(results_dir: str, max_rows: int):
             pred_positions.append(matching_idx)
         accuracy = rank_accuracy(pred_positions, labels)
         mean_accuracy += accuracy
-        page_results = {
-            "accuracy": accuracy,
-            "box_count": len(labels)
-        }
+        page_results = {"accuracy": accuracy, "box_count": len(labels)}
 
         page_metrics[idx] = page_results
 
@@ -65,14 +74,16 @@ def main(results_dir: str, max_rows: int):
     out_data = {
         "time": surya_time,
         "mean_accuracy": mean_accuracy,
-        "page_metrics": page_metrics
+        "page_metrics": page_metrics,
     }
 
     with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
         json.dump(out_data, f, indent=4)
 
     print(f"Mean accuracy is {mean_accuracy:.2f}.")
-    print(f"Took {surya_time / len(images):.2f} seconds per image, and {surya_time:.1f} seconds total.")
+    print(
+        f"Took {surya_time / len(images):.2f} seconds per image, and {surya_time:.1f} seconds total."
+    )
     print("Mean accuracy is the % of correct ranking pairs.")
     print(f"Wrote results to {result_path}")
 
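Context for the ordering metric: "the % of correct ranking pairs" printed above corresponds to pairwise rank accuracy. The benchmark imports its own `rank_accuracy` helper; the sketch below is an assumption about its semantics, not the actual implementation:

```python
from itertools import combinations

def pairwise_rank_accuracy(pred: list[int], truth: list[int]) -> float:
    # Fraction of element pairs whose relative order in the prediction
    # matches the ground-truth order; 1.0 means a perfect reading order.
    pairs = list(combinations(range(len(truth)), 2))
    if not pairs:
        return 1.0
    correct = sum(
        (pred[i] < pred[j]) == (truth[i] < truth[j]) for i, j in pairs
    )
    return correct / len(pairs)

# e.g. pairwise_rank_accuracy([0, 2, 1], [0, 1, 2]) == 2 / 3
```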
benchmark/recognition.py (22 changes: 22 additions & 0 deletions)
@@ -103,6 +103,10 @@ def normalize_text(text: str) -> str:
     help="Comma-separated list of languages to benchmark.",
     default=None,
 )
+@click.option(
+    "--print_results",
+    is_flag=True,
+)
 def main(
     results_dir: str,
     max_rows: int,
@@ -112,6 +116,7 @@ def main(
     tess_cpus: int,
     textract_cpus: int,
     languages: str | None,
+    print_results: bool,
 ):
     foundation_predictor = FoundationPredictor()
     rec_predictor = RecognitionPredictor(foundation_predictor)
@@ -352,6 +357,23 @@ def main(
 
     print(f"Wrote results to {result_path}")
 
+    if print_results:
+        for idx, (pred, ref_text) in enumerate(zip(predictions_by_image, line_text)):
+            print(f"Image {idx}")
+            print("----")
+            for line_idx, (pred_line, ref_line) in enumerate(
+                zip(pred.text_lines, ref_text)
+            ):
+                print(f"Sample {line_idx}")
+                print(f"Pred: {pred_line.text}")
+                print(f"Ref: {ref_line}")
+                print()
+
+    if settings.TORCH_DEVICE == "xla":
+        import torch_xla.debug.metrics as met
+
+        print(met.short_metrics_report())
+
 
 if __name__ == "__main__":
     main()
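Aside on the XLA branch added above: `torch_xla.debug.metrics.short_metrics_report()` (the call used in the diff) summarizes compile/execute counters, which is handy for spotting unwanted graph recompilations on TPU. A guarded sketch of the same pattern, with a hypothetical helper name and the device check mirroring `settings.TORCH_DEVICE`:

```python
def report_xla_metrics(device: str) -> None:
    # Only meaningful under torch_xla; the import is deferred so the
    # function is safe to call on CPU/GPU runs where torch_xla may be
    # absent.
    if device == "xla":
        import torch_xla.debug.metrics as met

        print(met.short_metrics_report())

# e.g. after a benchmark run:
# report_xla_metrics(settings.TORCH_DEVICE)
```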