|
33 | 33 | from huggingface_hub import snapshot_download |
34 | 34 | from PIL import Image |
35 | 35 | from pypdf import PdfReader as pdf2_read |
| 36 | +from sklearn.cluster import KMeans |
| 37 | +from sklearn.metrics import silhouette_score |
36 | 38 |
|
37 | 39 | from common.file_utils import get_project_base_directory |
38 | 40 | from common.misc_utils import pip_install_torch |
@@ -353,69 +355,87 @@ def _layouts_rec(self, ZM, drop=True): |
353 | 355 | def _assign_column(self, boxes, zoomin=3): |
354 | 356 | if not boxes: |
355 | 357 | return boxes |
356 | | - |
357 | 358 | if all("col_id" in b for b in boxes): |
358 | 359 | return boxes |
359 | 360 |
|
360 | 361 | by_page = defaultdict(list) |
361 | 362 | for b in boxes: |
362 | 363 | by_page[b["page_number"]].append(b) |
363 | 364 |
|
364 | | - page_info = {} # pg -> dict(page_w, left_edge, cand_cols) |
365 | | - counter = Counter() |
| 365 | + page_cols = {} |
366 | 366 |
|
367 | 367 | for pg, bxs in by_page.items(): |
368 | 368 | if not bxs: |
369 | | - page_info[pg] = {"page_w": 1.0, "left_edge": 0.0, "cand": 1} |
370 | | - counter[1] += 1 |
| 369 | + page_cols[pg] = 1 |
371 | 370 | continue |
372 | 371 |
|
373 | | - if hasattr(self, "page_images") and self.page_images and len(self.page_images) >= pg: |
374 | | - page_w = self.page_images[pg - 1].size[0] / max(1, zoomin) |
375 | | - left_edge = 0.0 |
376 | | - else: |
377 | | - xs0 = [box["x0"] for box in bxs] |
378 | | - xs1 = [box["x1"] for box in bxs] |
379 | | - left_edge = float(min(xs0)) |
380 | | - page_w = max(1.0, float(max(xs1) - left_edge)) |
381 | | - |
382 | | - widths = [max(1.0, (box["x1"] - box["x0"])) for box in bxs] |
383 | | - median_w = float(np.median(widths)) if widths else 1.0 |
| 372 | + x0s_raw = np.array([b["x0"] for b in bxs], dtype=float) |
384 | 373 |
|
385 | | - raw_cols = int(page_w / max(1.0, median_w)) |
| 374 | + min_x0 = np.min(x0s_raw) |
| 375 | + max_x1 = np.max([b["x1"] for b in bxs]) |
| 376 | + width = max_x1 - min_x0 |
386 | 377 |
|
387 | | - # cand = raw_cols if (raw_cols >= 2 and median_w < page_w / raw_cols * 0.8) else 1 |
388 | | - cand = raw_cols |
| 378 | + INDENT_TOL = width * 0.12  # treat x0 values within 12% of the text width as the same left margin |
| 379 | + x0s = [] |
| 380 | + for x in x0s_raw: |
| 381 | + if abs(x - min_x0) < INDENT_TOL: |
| 382 | + x0s.append([min_x0]) |
| 383 | + else: |
| 384 | + x0s.append([x]) |
| 385 | + x0s = np.array(x0s, dtype=float) |
| 386 | + |
| 387 | + max_try = min(4, len(bxs)) |
| 388 | + if max_try < 2: |
| 389 | + max_try = 1 |
| 390 | + best_k = 1 |
| 391 | + best_score = -1 |
| 392 | + |
| 393 | + for k in range(1, max_try + 1): |
| 394 | + km = KMeans(n_clusters=k, n_init="auto") |
| 395 | + labels = km.fit_predict(x0s) |
| 396 | + |
| 397 | + centers = np.sort(km.cluster_centers_.flatten()) |
| 398 | + if len(centers) > 1: |
| 399 | + try: |
| 400 | + score = silhouette_score(x0s, labels) |
| 401 | + except ValueError: |
| 402 | + continue |
| 403 | + else: |
| 404 | + score = 0 |
| 405 | + logging.debug(f"[Page {pg}] k={k}, silhouette={score:.3f}") |
| 406 | + if score > best_score: |
| 407 | + best_score = score |
| 408 | + best_k = k |
389 | 409 |
|
390 | | - page_info[pg] = {"page_w": page_w, "left_edge": left_edge, "cand": cand} |
391 | | - counter[cand] += 1 |
| 410 | + page_cols[pg] = best_k |
| 411 | + logging.info(f"[Page {pg}] best_score={best_score:.2f}, best_k={best_k}") |
392 | 412 |
|
393 | | - logging.info(f"[Page {pg}] median_w={median_w:.2f}, page_w={page_w:.2f}, raw_cols={raw_cols}, cand={cand}") |
394 | 413 |
|
395 | | - global_cols = counter.most_common(1)[0][0] |
| 414 | + global_cols = Counter(page_cols.values()).most_common(1)[0][0] |
396 | 415 | logging.info(f"Global column_num decided by majority: {global_cols}") |
397 | 416 |
|
| 417 | + |
398 | 418 | for pg, bxs in by_page.items(): |
399 | 419 | if not bxs: |
400 | 420 | continue |
401 | | - |
402 | | - page_w = page_info[pg]["page_w"] |
403 | | - left_edge = page_info[pg]["left_edge"] |
404 | | - |
405 | | - if global_cols == 1: |
406 | | - for box in bxs: |
407 | | - box["col_id"] = 0 |
408 | | - continue |
409 | | - |
410 | | - for box in bxs: |
411 | | - w = box["x1"] - box["x0"] |
412 | | - if w >= 0.8 * page_w: |
413 | | - box["col_id"] = 0 |
414 | | - continue |
415 | | - cx = 0.5 * (box["x0"] + box["x1"]) |
416 | | - norm_cx = (cx - left_edge) / page_w |
417 | | - norm_cx = max(0.0, min(norm_cx, 0.999999)) |
418 | | - box["col_id"] = int(min(global_cols - 1, norm_cx * global_cols)) |
| 421 | + k = page_cols[pg] |
| 422 | + if len(bxs) < k: |
| 423 | + k = 1 |
| 424 | + x0s = np.array([[b["x0"]] for b in bxs], dtype=float) |
| 425 | + km = KMeans(n_clusters=k, n_init="auto") |
| 426 | + labels = km.fit_predict(x0s) |
| 427 | + |
| 428 | + centers = km.cluster_centers_.flatten() |
| 429 | + order = np.argsort(centers) |
| 430 | + |
| 431 | + remap = {orig: new for new, orig in enumerate(order)}  # relabel clusters so col_id runs left to right |
| 432 | + |
| 433 | + for b, lb in zip(bxs, labels): |
| 434 | + b["col_id"] = remap[lb] |
| 435 | + |
| 436 | + grouped = defaultdict(list) |
| 437 | + for b in bxs: |
| 438 | + grouped[b["col_id"]].append(b) |
419 | 439 |
|
420 | 440 | return boxes |
421 | 441 |
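For reviewers, a self-contained sketch of the per-page column estimation this hunk introduces: cluster the boxes' x0 coordinates with KMeans for k = 1..4 and keep the k with the best silhouette score, after snapping small paragraph indents back onto the left margin. The helper name `estimate_columns`, the sample boxes, and the final `print` are illustrative only; the 4-column cap, the 12% indent tolerance, and `n_init="auto"` (scikit-learn >= 1.2) mirror the diff.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

def estimate_columns(boxes, max_cols=4, indent_ratio=0.12):
    """Guess the number of text columns on a page from the boxes' x0 positions."""
    if len(boxes) < 2:
        return 1
    x0s = np.array([b["x0"] for b in boxes], dtype=float)
    min_x0 = x0s.min()
    max_x1 = max(b["x1"] for b in boxes)
    tol = (max_x1 - min_x0) * indent_ratio
    # Snap slightly indented lines back to the left margin so paragraph
    # indents do not masquerade as extra columns.
    snapped = np.where(np.abs(x0s - min_x0) < tol, min_x0, x0s).reshape(-1, 1)
    best_k, best_score = 1, 0.0
    for k in range(2, min(max_cols, len(boxes)) + 1):
        labels = KMeans(n_clusters=k, n_init="auto").fit_predict(snapped)
        try:
            score = silhouette_score(snapped, labels)
        except ValueError:
            # silhouette is undefined when the clustering degenerates
            # (e.g. a single label, or every box in its own cluster)
            continue
        if score > best_score:
            best_k, best_score = k, score
    return best_k

two_col = [{"x0": 50, "x1": 280}, {"x0": 62, "x1": 285},
           {"x0": 320, "x1": 560}, {"x0": 323, "x1": 556}]
print(estimate_columns(two_col))  # -> 2 for this two-column layout
```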
|
@@ -1303,7 +1323,10 @@ def crop(self, text, ZM=3, need_position=False): |
1303 | 1323 |
|
1304 | 1324 | positions = [] |
1305 | 1325 | for ii, (pns, left, right, top, bottom) in enumerate(poss): |
1306 | | - right = left + max_width |
| 1326 | + if 0 < ii < len(poss) - 1: |
| 1327 | + right = max(left + 10, right) |
| 1328 | + else: |
| 1329 | + right = left + max_width |
1307 | 1330 | bottom *= ZM |
1308 | 1331 | for pn in pns[1:]: |
1309 | 1332 | if 0 <= pn - 1 < page_count: |
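A toy illustration of the `crop` change above: only the first and last position segments are stretched to `max_width`, while interior segments keep their measured right edge, floored at `left + 10`. The tuples and the `max_width` value below are made up for demonstration.

```python
# Illustrative only: replays the right-edge rule from the hunk above on fake data.
poss = [(["1"], 40.0, 60.0, 100.0, 120.0),   # first segment  -> stretched
        (["1"], 45.0, 300.0, 130.0, 150.0),  # middle segment -> keeps its own right edge
        (["2"], 40.0, 55.0, 20.0, 40.0)]     # last segment   -> stretched
max_width = 400.0

for ii, (pns, left, right, top, bottom) in enumerate(poss):
    if 0 < ii < len(poss) - 1:
        right = max(left + 10, right)   # interior: at least 10 units wide
    else:
        right = left + max_width        # first/last: span the full crop width
    print(ii, left, right)
```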
|