Skip to content

Commit 6a75d16

Browse files
authored
Merge pull request PaddlePaddle#50 from LokeZhou/autolable
autolable update blip2
2 parents 6b33637 + 7943a9a commit 6a75d16

1 file changed

Lines changed: 94 additions & 77 deletions

File tree

applications/Automatic_label/automatic_label.py

Lines changed: 94 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,12 @@
2525
import paddle.nn.functional as F
2626
from PIL import Image, ImageDraw, ImageFont
2727

28-
2928
from paddlevlp.processors.groundingdino_processing import GroudingDinoProcessor
3029
from paddlevlp.models.groundingdino.modeling import GroundingDinoModel
3130
from paddlevlp.models.sam.modeling import SamModel
3231
from paddlevlp.processors.sam_processing import SamProcessor
32+
from paddlenlp.transformers import AutoTokenizer
33+
from paddlevlp.processors.blip_processing import BlipImageProcessor, BlipTextProcessor
3334
from paddlevlp.models.blip2.modeling import Blip2ForConditionalGeneration
3435
from paddlevlp.processors.blip_processing import Blip2Processor
3536
import nltk
@@ -42,7 +43,7 @@ def show_mask(mask, ax, random_color=False):
4243
if random_color:
4344
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
4445
else:
45-
color = np.array([30/255, 144/255, 255/255, 0.6])
46+
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
4647
h, w = mask.shape[-2:]
4748
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
4849
ax.imshow(mask_image)
@@ -51,7 +52,9 @@ def show_mask(mask, ax, random_color=False):
5152
def show_box(box, ax, label):
5253
x0, y0 = box[0], box[1]
5354
w, h = box[2] - box[0], box[3] - box[1]
54-
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
55+
ax.add_patch(
56+
plt.Rectangle(
57+
(x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
5558
ax.text(x0, y0, label)
5659

5760

@@ -64,10 +67,12 @@ class DataArguments:
6467
the command line.
6568
"""
6669

67-
input_image: str = field(
68-
metadata={"help": "The name of input image."}
69-
)
70-
70+
input_image: str = field(metadata={"help": "The name of input image."})
71+
72+
prompt: str = field(
73+
default="describe the image",
74+
metadata={"help": "The prompt of the image to be generated."
75+
}) # "Question: how many cats are there? Answer:"
7176

7277

7378
@dataclass
@@ -76,152 +81,162 @@ class ModelArguments:
7681
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
7782
"""
7883
blip2_model_name_or_path: str = field(
79-
default="Salesforce/blip2-opt-2.7b",
80-
metadata={"help": "Path to pretrained model or model identifier"},
81-
)
84+
default="paddlemix/blip2-caption-opt2.7b",
85+
metadata={"help": "Path to pretrained model or model identifier"}, )
86+
text_model_name_or_path: str = field(
87+
default="facebook/opt-2.7b",
88+
metadata={"help": "The type of text model to use (OPT, T5)."}, )
8289
dino_model_name_or_path: str = field(
8390
default="GroundingDino/groundingdino-swint-ogc",
84-
metadata={"help": "Path to pretrained model or model identifier"},
85-
)
91+
metadata={"help": "Path to pretrained model or model identifier"}, )
8692
sam_model_name_or_path: str = field(
8793
default="Sam/SamVitH-1024",
88-
metadata={"help": "Path to pretrained model or model identifier"},
89-
)
94+
metadata={"help": "Path to pretrained model or model identifier"}, )
9095
box_threshold: float = field(
9196
default=0.3,
92-
metadata={
93-
"help": "box threshold."
94-
},
95-
)
97+
metadata={"help": "box threshold."}, )
9698
text_threshold: float = field(
9799
default=0.25,
98-
metadata={
99-
"help": "text threshold."
100-
},
101-
)
100+
metadata={"help": "text threshold."}, )
102101
output_dir: str = field(
103102
default="automatic_label",
104-
metadata={
105-
"help": "output directory."
106-
},
107-
)
103+
metadata={"help": "output directory."}, )
108104
visual: bool = field(
109105
default=True,
110-
metadata={
111-
"help": "save visual image."
112-
},
113-
)
106+
metadata={"help": "save visual image."}, )
107+
108+
109+
def generate_caption(raw_image, prompt, processor, blip2_model):
114110

115-
def generate_caption(raw_image, processor,blip2_model):
116-
117111
inputs = processor(
118112
images=raw_image,
119-
text=None,
113+
text=prompt,
120114
return_tensors="pd",
121115
return_attention_mask=True,
122-
mode="test",
123-
)
116+
mode="test", )
124117
generated_ids, scores = blip2_model.generate(**inputs)
125-
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[
126-
0
127-
].strip()
118+
generated_text = processor.batch_decode(
119+
generated_ids, skip_special_tokens=True)[0].strip()
128120
logger.info("Generate text: {}".format(generated_text))
129121

130122
return generated_text
131123

124+
132125
def generate_tags(caption):
133126
lemma = nltk.wordnet.WordNetLemmatizer()
134-
127+
135128
nltk.download(['punkt', 'averaged_perceptron_tagger', 'wordnet'])
136-
tags_list = [word for (word, pos) in nltk.pos_tag(nltk.word_tokenize(caption)) if pos[0] == 'N']
129+
tags_list = [
130+
word for (word, pos) in nltk.pos_tag(nltk.word_tokenize(caption))
131+
if pos[0] == 'N'
132+
]
137133
tags_lemma = [lemma.lemmatize(w) for w in tags_list]
138134
tags = ', '.join(map(str, tags_lemma))
139135

140136
return tags
141137

138+
142139
def main():
143140
parser = PdArgumentParser((ModelArguments, DataArguments))
144141
model_args, data_args = parser.parse_args_into_dataclasses()
145142
url = (data_args.input_image)
146143

147144
logger.info("blip2_model: {}".format(model_args.blip2_model_name_or_path))
148145
#build blip2 processor
149-
blip2_processor = Blip2Processor.from_pretrained(
150-
model_args.blip2_model_name_or_path
151-
) # "Salesforce/blip2-opt-2.7b"
152-
#build blip2 model
153-
blip2_model = Blip2ForConditionalGeneration.from_pretrained(model_args.blip2_model_name_or_path)
154-
146+
blip2_tokenizer_class = AutoTokenizer.from_pretrained(
147+
model_args.text_model_name_or_path, use_fast=False)
148+
blip2_image_processor = BlipImageProcessor.from_pretrained(
149+
os.path.join(model_args.blip2_model_name_or_path, "processor", "eval"))
150+
blip2_text_processor_class = BlipTextProcessor.from_pretrained(
151+
os.path.join(model_args.blip2_model_name_or_path, "processor", "eval"))
152+
blip2_processor = Blip2Processor(blip2_image_processor,
153+
blip2_text_processor_class,
154+
blip2_tokenizer_class)
155+
156+
# #build blip2 model
157+
blip2_model = Blip2ForConditionalGeneration.from_pretrained(
158+
model_args.blip2_model_name_or_path)
159+
paddle.device.cuda.empty_cache()
155160
blip2_model.eval()
156-
blip2_model.to("gpu")
161+
157162
logger.info("blip2_model build finish!")
158-
163+
159164
logger.info("dino_model: {}".format(model_args.dino_model_name_or_path))
160165
#build dino processor
161166
dino_processor = GroudingDinoProcessor.from_pretrained(
162-
model_args.dino_model_name_or_path
163-
)
167+
model_args.dino_model_name_or_path)
164168
#build dino model
165-
dino_model = GroundingDinoModel.from_pretrained(model_args.dino_model_name_or_path)
169+
dino_model = GroundingDinoModel.from_pretrained(
170+
model_args.dino_model_name_or_path)
166171
dino_model.eval()
167172
logger.info("dino_model build finish!")
168173

169174
#build sam processor
170175
sam_processor = SamProcessor.from_pretrained(
171-
model_args.sam_model_name_or_path
172-
)
176+
model_args.sam_model_name_or_path)
173177
#build model
174178
logger.info("SamModel: {}".format(model_args.sam_model_name_or_path))
175-
sam_model = SamModel.from_pretrained(model_args.sam_model_name_or_path,input_type="boxs")
179+
sam_model = SamModel.from_pretrained(
180+
model_args.sam_model_name_or_path, input_type="boxs")
176181
logger.info("SamModel build finish!")
177-
182+
178183
#read image
179184
if os.path.isfile(url):
180185
#read image
181186
image_pil = Image.open(data_args.input_image)
182187
else:
183188
image_pil = Image.open(requests.get(url, stream=True).raw)
184-
185-
caption = generate_caption(image_pil,processor=blip2_processor,blip2_model=blip2_model)
186-
prompt = generate_tags(caption)
187-
logger.info("prompt: {}".format(prompt))
189+
190+
caption = generate_caption(
191+
image_pil,
192+
prompt=data_args.prompt,
193+
processor=blip2_processor,
194+
blip2_model=blip2_model)
195+
196+
det_prompt = generate_tags(caption)
197+
logger.info("det prompt: {}".format(det_prompt))
188198

189199
image_pil = image_pil.convert("RGB")
190200

191201
#preprocess image text_prompt
192-
image_tensor,mask,tokenized_out = dino_processor(images=image_pil,text=prompt)
202+
image_tensor, mask, tokenized_out = dino_processor(
203+
images=image_pil, text=det_prompt)
193204

194205
with paddle.no_grad():
195-
outputs = dino_model(image_tensor,mask, input_ids=tokenized_out['input_ids'],
196-
attention_mask=tokenized_out['attention_mask'],text_self_attention_masks=tokenized_out['text_self_attention_masks'],
197-
position_ids=tokenized_out['position_ids'])
206+
outputs = dino_model(
207+
image_tensor,
208+
mask,
209+
input_ids=tokenized_out['input_ids'],
210+
attention_mask=tokenized_out['attention_mask'],
211+
text_self_attention_masks=tokenized_out[
212+
'text_self_attention_masks'],
213+
position_ids=tokenized_out['position_ids'])
198214

199215
logits = F.sigmoid(outputs["pred_logits"])[0] # (nq, 256)
200216
boxes = outputs["pred_boxes"][0] # (nq, 4)
201217

202-
# filter output
218+
# filter output
203219
logits_filt = logits.clone()
204220
boxes_filt = boxes.clone()
205221
filt_mask = logits_filt.max(axis=1) > model_args.box_threshold
206222
logits_filt = logits_filt[filt_mask] # num_filt, 256
207223
boxes_filt = boxes_filt[filt_mask] # num_filt, 4
208224

209-
# build pred
225+
# build pred
210226
pred_phrases = []
211227
for logit, box in zip(logits_filt, boxes_filt):
212228
pred_phrase = dino_processor.decode(logit > model_args.text_threshold)
213229
pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
214230

215-
216231
size = image_pil.size
217232
pred_dict = {
218233
"boxes": boxes_filt,
219234
"size": [size[1], size[0]], # H,W
220235
"labels": pred_phrases,
221236
}
222237
logger.info("dino output{}".format(pred_dict))
223-
224-
H,W = size[1], size[0]
238+
239+
H, W = size[1], size[0]
225240
boxes = []
226241
for box in zip(boxes_filt):
227242
box = box[0] * paddle.to_tensor([W, H, W, H])
@@ -231,12 +246,13 @@ def main():
231246
x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
232247
boxes.append([x0, y0, x1, y1])
233248
boxes = np.array(boxes)
234-
image_seg,prompt = sam_processor(image_pil,input_type="boxs",box=boxes,point_coords=None)
235-
seg_masks = sam_model(img=image_seg,prompt=prompt)
249+
image_seg, prompt = sam_processor(
250+
image_pil, input_type="boxs", box=boxes, point_coords=None)
251+
seg_masks = sam_model(img=image_seg, prompt=prompt)
236252
seg_masks = sam_processor.postprocess_masks(seg_masks)
237253

238254
logger.info("Sam finish!")
239-
255+
240256
if model_args.visual:
241257
# make dir
242258
os.makedirs(model_args.output_dir, exist_ok=True)
@@ -247,16 +263,17 @@ def main():
247263
show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
248264
for box, label in zip(boxes, pred_phrases):
249265
show_box(box, plt.gca(), label)
250-
266+
251267
plt.title(caption)
252268
plt.axis('off')
253269
plt.savefig(
254-
os.path.join(model_args.output_dir, 'mask_pred.jpg'),
255-
bbox_inches="tight", dpi=300, pad_inches=0.0
256-
)
270+
os.path.join(model_args.output_dir, 'mask_pred.jpg'),
271+
bbox_inches="tight",
272+
dpi=300,
273+
pad_inches=0.0)
257274

258275
logger.info("finish!")
259-
276+
260277

261278
if __name__ == "__main__":
262-
main()
279+
main()

0 commit comments

Comments
 (0)