Skip to content
Open

1 #319

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .gitignore

This file was deleted.

110 changes: 92 additions & 18 deletions create_heatmaps.py

Large diffs are not rendered by default.

37 changes: 24 additions & 13 deletions create_patches_fp.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# internal imports
from wsi_core.WholeSlideImage import WholeSlideImage
from wsi_core.wsi_utils import StitchCoords
from wsi_core.batch_process_utils import initialize_df
# ========== internal imports ==========
from wsi_core.WholeSlideImage import WholeSlideImage # 封装 OpenSlide 的 WSI 操作类
from wsi_core.wsi_utils import StitchCoords # 把 patch 结果画回全图的小工具
from wsi_core.batch_process_utils import initialize_df # 生成/初始化 csv 清单
# other imports
# ========== general-purpose imports ==========
import os
import numpy as np
import time
Expand All @@ -11,13 +13,15 @@
import pandas as pd
from tqdm import tqdm

# ------------------ 1. Stitch patch results into a downscaled overview heatmap ------------------
def stitching(file_path, wsi_object, downscale = 64):
    """Stitch extracted patch coordinates back into one downscaled overview image.

    Args:
        file_path: path to the patch-coordinate file for this slide
            (presumably the .h5 produced by patching — confirm against StitchCoords).
        wsi_object: the WholeSlideImage the patches were extracted from.
        downscale: shrink factor of the overview relative to the slide.

    Returns:
        (heatmap, total_time): the stitched image and elapsed seconds.
    """
    t0 = time.time()
    # Black background, alpha=-1 (no blending), and no grid lines between patches.
    stitched = StitchCoords(
        file_path,
        wsi_object,
        downscale=downscale,
        bg_color=(0, 0, 0),
        alpha=-1,
        draw_grid=False,
    )
    elapsed = time.time() - t0
    return stitched, elapsed

# ------------------ 2. Tissue foreground segmentation (find tissue contours) ------------------
def segment(WSI_object, seg_params = None, filter_params = None, mask_file = None):
### Start Seg Timer
start_time = time.time()
Expand All @@ -32,6 +36,7 @@ def segment(WSI_object, seg_params = None, filter_params = None, mask_file = Non
seg_time_elapsed = time.time() - start_time
return WSI_object, seg_time_elapsed

# ------------------ 3. Extract patches within the tissue contours ------------------
def patching(WSI_object, **kwargs):
### Start Patch Timer
start_time = time.time()
Expand All @@ -44,7 +49,7 @@ def patching(WSI_object, **kwargs):
patch_time_elapsed = time.time() - start_time
return file_path, patch_time_elapsed


# ------------------ 4. Main pipeline: segmentation + patching + optional stitching ------------------
def seg_and_patch(source, save_dir, patch_save_dir, mask_save_dir, stitch_save_dir,
patch_size = 256, step_size = 256,
seg_params = {'seg_level': -1, 'sthresh': 8, 'mthresh': 7, 'close': 4, 'use_otsu': False,
Expand All @@ -58,8 +63,14 @@ def seg_and_patch(source, save_dir, patch_save_dir, mask_save_dir, stitch_save_d
stitch= False,
patch = False, auto_skip=True, process_list = None):



"""
source : directory containing the WSI files
save_dir : top-level output directory
    (patches/masks/stitches sub-directories are created under save_dir)
seg/patch/stitch : bools that independently enable segmentation,
    patch extraction, and overview stitching
process_list : optional csv specifying per-slide parameters
"""
# --------------- 4.1 Build the list of slides to process ---------------
slides = sorted(os.listdir(source))
slides = [slide for slide in slides if os.path.isfile(os.path.join(source, slide))]
if process_list is None:
Expand All @@ -86,7 +97,7 @@ def seg_and_patch(source, save_dir, patch_save_dir, mask_save_dir, stitch_save_d
seg_times = 0.
patch_times = 0.
stitch_times = 0.

# --------------- 4.2 Main loop over slides ---------------
for i in tqdm(range(total)):
df.to_csv(os.path.join(save_dir, 'process_list_autogen.csv'), index=False)
idx = process_stack.index[i]
Expand All @@ -102,10 +113,10 @@ def seg_and_patch(source, save_dir, patch_save_dir, mask_save_dir, stitch_save_d
df.loc[idx, 'status'] = 'already_exist'
continue

# Inialize WSI
# Initialize WSI  # load the slide
full_path = os.path.join(source, slide)
WSI_object = WholeSlideImage(full_path)

# --------------- 4.3 Parameter precedence: csv overrides defaults ---------------
if use_default_params:
current_vis_params = vis_params.copy()
current_filter_params = filter_params.copy()
Expand Down Expand Up @@ -175,10 +186,10 @@ def seg_and_patch(source, save_dir, patch_save_dir, mask_save_dir, stitch_save_d
current_seg_params['exclude_ids'] = []

w, h = WSI_object.level_dim[current_seg_params['seg_level']]
if w * h > 1e8:
print('level_dim {} x {} is likely too large for successful segmentation, aborting'.format(w, h))
df.loc[idx, 'status'] = 'failed_seg'
continue
#if w * h > 5e8:#1e8
# print('level_dim {} x {} is likely too large for successful segmentation, aborting'.format(w, h))
# df.loc[idx, 'status'] = 'failed_seg'
# continue

df.loc[idx, 'vis_level'] = current_vis_params['vis_level']
df.loc[idx, 'seg_level'] = current_seg_params['seg_level']
Expand Down
24 changes: 20 additions & 4 deletions create_splits_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import argparse
import numpy as np

# ---------- 1. Command-line arguments ----------
parser = argparse.ArgumentParser(description='Creating splits for whole slide classification')
parser.add_argument("--csv_path", type=str, required=True)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument('--label_frac', type=float, default= 1.0,
help='fraction of labels (default: 1)')
parser.add_argument('--seed', type=int, default=1,
Expand All @@ -20,9 +23,11 @@

args = parser.parse_args()

# ---------- 2. Select the dataset according to the task ----------
if args.task == 'task_1_tumor_vs_normal':
args.n_classes=2
dataset = Generic_WSI_Classification_Dataset(csv_path = 'dataset_csv/tumor_vs_normal_dummy_clean.csv',
dataset = Generic_WSI_Classification_Dataset(csv_path = '/data/cuiping/NSCLC/labels_digital.csv',
#dataset = Generic_WSI_Classification_Dataset(csv_path = 'dataset_csv/tumor_vs_normal_dummy_clean.csv',
shuffle = False,
seed = args.seed,
print_info = True,
Expand All @@ -31,23 +36,30 @@
ignore=[])

elif args.task == 'task_2_tumor_subtyping':
args.n_classes=3
dataset = Generic_WSI_Classification_Dataset(csv_path = 'dataset_csv/tumor_subtyping_dummy_clean.csv',
args.n_classes=2
dataset = Generic_WSI_Classification_Dataset(csv_path = '/data/cuiping/NSCLC/labels_digital.csv',
#dataset = Generic_WSI_Classification_Dataset(csv_path = '/data/cuiping/RCC/labels_digital.csv',
shuffle = False,
seed = args.seed,
print_info = True,
label_dict = {'subtype_1':0, 'subtype_2':1, 'subtype_3':2},
label_dict = {'subtype_1':0, 'subtype_2':1},
#label_dict = {'subtype_1':0, 'subtype_2':1, 'subtype_3':2},
patient_strat= True,
patient_voting='maj',
ignore=[])

else:
raise NotImplementedError

# ---------- 3. Compute per-class validation/test counts ----------
# Count patients per class (with patient_strat=True the split is per patient)
num_slides_cls = np.array([len(cls_ids) for cls_ids in dataset.patient_cls_ids])
# Scale class counts by the requested fractions (np.round, then cast to int)
val_num = np.round(num_slides_cls * args.val_frac).astype(int)
test_num = np.round(num_slides_cls * args.test_frac).astype(int)


# ---------- 4. Main flow: generate the k folds and save them ----------
if __name__ == '__main__':
if args.label_frac > 0:
label_fracs = [args.label_frac]
Expand All @@ -62,6 +74,10 @@
dataset.set_splits()
descriptor_df = dataset.test_split_gen(return_descriptor=True)
splits = dataset.return_splits(from_id=True)
# Three files are saved per fold:
#  (1) splits_{i}.csv            -> the train/val/test split assignments
#  (2) splits_{i}_bool.csv       -> same split, in boolean (0/1) style
#  (3) splits_{i}_descriptor.csv -> detailed description of each fold
save_splits(splits, ['train', 'val', 'test'], os.path.join(split_dir, 'splits_{}.csv'.format(i)))
save_splits(splits, ['train', 'val', 'test'], os.path.join(split_dir, 'splits_{}_bool.csv'.format(i)), boolean_style=True)
descriptor_df.to_csv(os.path.join(split_dir, 'splits_{}_descriptor.csv'.format(i)))
Expand Down
Loading