Skip to content

Commit c8ab907

Browse files
authored
Fix: improve multi-column document detection (#11415)
### What problem does this PR solve?

Change: improve multi-column document detection.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
1 parent 0d5589b commit c8ab907

File tree

1 file changed

+65
-42
lines changed

1 file changed

+65
-42
lines changed

deepdoc/parser/pdf_parser.py

Lines changed: 65 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
from huggingface_hub import snapshot_download
3434
from PIL import Image
3535
from pypdf import PdfReader as pdf2_read
36+
from sklearn.cluster import KMeans
37+
from sklearn.metrics import silhouette_score
3638

3739
from common.file_utils import get_project_base_directory
3840
from common.misc_utils import pip_install_torch
@@ -353,69 +355,87 @@ def _layouts_rec(self, ZM, drop=True):
353355
def _assign_column(self, boxes, zoomin=3):
354356
if not boxes:
355357
return boxes
356-
357358
if all("col_id" in b for b in boxes):
358359
return boxes
359360

360361
by_page = defaultdict(list)
361362
for b in boxes:
362363
by_page[b["page_number"]].append(b)
363364

364-
page_info = {} # pg -> dict(page_w, left_edge, cand_cols)
365-
counter = Counter()
365+
page_cols = {}
366366

367367
for pg, bxs in by_page.items():
368368
if not bxs:
369-
page_info[pg] = {"page_w": 1.0, "left_edge": 0.0, "cand": 1}
370-
counter[1] += 1
369+
page_cols[pg] = 1
371370
continue
372371

373-
if hasattr(self, "page_images") and self.page_images and len(self.page_images) >= pg:
374-
page_w = self.page_images[pg - 1].size[0] / max(1, zoomin)
375-
left_edge = 0.0
376-
else:
377-
xs0 = [box["x0"] for box in bxs]
378-
xs1 = [box["x1"] for box in bxs]
379-
left_edge = float(min(xs0))
380-
page_w = max(1.0, float(max(xs1) - left_edge))
381-
382-
widths = [max(1.0, (box["x1"] - box["x0"])) for box in bxs]
383-
median_w = float(np.median(widths)) if widths else 1.0
372+
x0s_raw = np.array([b["x0"] for b in bxs], dtype=float)
384373

385-
raw_cols = int(page_w / max(1.0, median_w))
374+
min_x0 = np.min(x0s_raw)
375+
max_x1 = np.max([b["x1"] for b in bxs])
376+
width = max_x1 - min_x0
386377

387-
# cand = raw_cols if (raw_cols >= 2 and median_w < page_w / raw_cols * 0.8) else 1
388-
cand = raw_cols
378+
INDENT_TOL = width * 0.12
379+
x0s = []
380+
for x in x0s_raw:
381+
if abs(x - min_x0) < INDENT_TOL:
382+
x0s.append([min_x0])
383+
else:
384+
x0s.append([x])
385+
x0s = np.array(x0s, dtype=float)
386+
387+
max_try = min(4, len(bxs))
388+
if max_try < 2:
389+
max_try = 1
390+
best_k = 1
391+
best_score = -1
392+
393+
for k in range(1, max_try + 1):
394+
km = KMeans(n_clusters=k, n_init="auto")
395+
labels = km.fit_predict(x0s)
396+
397+
centers = np.sort(km.cluster_centers_.flatten())
398+
if len(centers) > 1:
399+
try:
400+
score = silhouette_score(x0s, labels)
401+
except ValueError:
402+
continue
403+
else:
404+
score = 0
405+
print(f"{k=},{score=}",flush=True)
406+
if score > best_score:
407+
best_score = score
408+
best_k = k
389409

390-
page_info[pg] = {"page_w": page_w, "left_edge": left_edge, "cand": cand}
391-
counter[cand] += 1
410+
page_cols[pg] = best_k
411+
logging.info(f"[Page {pg}] best_score={best_score:.2f}, best_k={best_k}")
392412

393-
logging.info(f"[Page {pg}] median_w={median_w:.2f}, page_w={page_w:.2f}, raw_cols={raw_cols}, cand={cand}")
394413

395-
global_cols = counter.most_common(1)[0][0]
414+
global_cols = Counter(page_cols.values()).most_common(1)[0][0]
396415
logging.info(f"Global column_num decided by majority: {global_cols}")
397416

417+
398418
for pg, bxs in by_page.items():
399419
if not bxs:
400420
continue
401-
402-
page_w = page_info[pg]["page_w"]
403-
left_edge = page_info[pg]["left_edge"]
404-
405-
if global_cols == 1:
406-
for box in bxs:
407-
box["col_id"] = 0
408-
continue
409-
410-
for box in bxs:
411-
w = box["x1"] - box["x0"]
412-
if w >= 0.8 * page_w:
413-
box["col_id"] = 0
414-
continue
415-
cx = 0.5 * (box["x0"] + box["x1"])
416-
norm_cx = (cx - left_edge) / page_w
417-
norm_cx = max(0.0, min(norm_cx, 0.999999))
418-
box["col_id"] = int(min(global_cols - 1, norm_cx * global_cols))
421+
k = page_cols[pg]
422+
if len(bxs) < k:
423+
k = 1
424+
x0s = np.array([[b["x0"]] for b in bxs], dtype=float)
425+
km = KMeans(n_clusters=k, n_init="auto")
426+
labels = km.fit_predict(x0s)
427+
428+
centers = km.cluster_centers_.flatten()
429+
order = np.argsort(centers)
430+
431+
remap = {orig: new for new, orig in enumerate(order)}
432+
433+
for b, lb in zip(bxs, labels):
434+
b["col_id"] = remap[lb]
435+
436+
grouped = defaultdict(list)
437+
for b in bxs:
438+
grouped[b["col_id"]].append(b)
419439

420440
return boxes
421441

@@ -1303,7 +1323,10 @@ def crop(self, text, ZM=3, need_position=False):
13031323

13041324
positions = []
13051325
for ii, (pns, left, right, top, bottom) in enumerate(poss):
1306-
right = left + max_width
1326+
if 0 < ii < len(poss) - 1:
1327+
right = max(left + 10, right)
1328+
else:
1329+
right = left + max_width
13071330
bottom *= ZM
13081331
for pn in pns[1:]:
13091332
if 0 <= pn - 1 < page_count:

0 commit comments

Comments
 (0)