2525import paddle .nn .functional as F
2626from PIL import Image , ImageDraw , ImageFont
2727
28-
2928from paddlevlp .processors .groundingdino_processing import GroudingDinoProcessor
3029from paddlevlp .models .groundingdino .modeling import GroundingDinoModel
3130from paddlevlp .models .sam .modeling import SamModel
3231from paddlevlp .processors .sam_processing import SamProcessor
32+ from paddlenlp .transformers import AutoTokenizer
33+ from paddlevlp .processors .blip_processing import BlipImageProcessor , BlipTextProcessor
3334from paddlevlp .models .blip2 .modeling import Blip2ForConditionalGeneration
3435from paddlevlp .processors .blip_processing import Blip2Processor
3536import nltk
@@ -42,7 +43,7 @@ def show_mask(mask, ax, random_color=False):
4243 if random_color :
4344 color = np .concatenate ([np .random .random (3 ), np .array ([0.6 ])], axis = 0 )
4445 else :
45- color = np .array ([30 / 255 , 144 / 255 , 255 / 255 , 0.6 ])
46+ color = np .array ([30 / 255 , 144 / 255 , 255 / 255 , 0.6 ])
4647 h , w = mask .shape [- 2 :]
4748 mask_image = mask .reshape (h , w , 1 ) * color .reshape (1 , 1 , - 1 )
4849 ax .imshow (mask_image )
@@ -51,7 +52,9 @@ def show_mask(mask, ax, random_color=False):
def show_box(box, ax, label):
    """Draw one labelled bounding box on a matplotlib axes.

    Args:
        box: Indexable holding at least four numbers, read as
            ``(x0, y0, x1, y1)``; only indices 0-3 are accessed.
        ax: Object providing ``add_patch`` and ``text``.
        label: Text written at the ``(x0, y0)`` corner of the box.
    """
    x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
    # Rectangle is built from an anchor point plus width/height, so the
    # second corner is converted into extents here.
    outline = plt.Rectangle(
        (x0, y0),
        x1 - x0,
        y1 - y0,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),  # alpha 0: outline only, no fill
        lw=2)
    ax.add_patch(outline)
    ax.text(x0, y0, label)
5659
5760
@@ -64,10 +67,12 @@ class DataArguments:
6467 the command line.
6568 """
6669
67- input_image : str = field (
68- metadata = {"help" : "The name of input image." }
69- )
70-
70+ input_image : str = field (metadata = {"help" : "The name of input image." })
71+
72+ prompt : str = field (
73+ default = "describe the image" ,
74+ metadata = {"help" : "The prompt of the image to be generated."
75+ }) # "Question: how many cats are there? Answer:"
7176
7277
7378@dataclass
@@ -76,152 +81,162 @@ class ModelArguments:
7681 Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
7782 """
7883 blip2_model_name_or_path : str = field (
79- default = "Salesforce/blip2-opt-2.7b" ,
80- metadata = {"help" : "Path to pretrained model or model identifier" },
81- )
84+ default = "paddlemix/blip2-caption-opt2.7b" ,
85+ metadata = {"help" : "Path to pretrained model or model identifier" }, )
86+ text_model_name_or_path : str = field (
87+ default = "facebook/opt-2.7b" ,
88+ metadata = {"help" : "The type of text model to use (OPT, T5)." }, )
8289 dino_model_name_or_path : str = field (
8390 default = "GroundingDino/groundingdino-swint-ogc" ,
84- metadata = {"help" : "Path to pretrained model or model identifier" },
85- )
91+ metadata = {"help" : "Path to pretrained model or model identifier" }, )
8692 sam_model_name_or_path : str = field (
8793 default = "Sam/SamVitH-1024" ,
88- metadata = {"help" : "Path to pretrained model or model identifier" },
89- )
94+ metadata = {"help" : "Path to pretrained model or model identifier" }, )
9095 box_threshold : float = field (
9196 default = 0.3 ,
92- metadata = {
93- "help" : "box threshold."
94- },
95- )
97+ metadata = {"help" : "box threshold." }, )
9698 text_threshold : float = field (
9799 default = 0.25 ,
98- metadata = {
99- "help" : "text threshold."
100- },
101- )
100+ metadata = {"help" : "text threshold." }, )
102101 output_dir : str = field (
103102 default = "automatic_label" ,
104- metadata = {
105- "help" : "output directory."
106- },
107- )
103+ metadata = {"help" : "output directory." }, )
108104 visual : bool = field (
109105 default = True ,
110- metadata = {
111- "help" : "save visual image."
112- },
113- )
106+ metadata = {"help" : "save visual image." }, )
107+
108+
def generate_caption(raw_image, prompt, processor, blip2_model):
    """Run the captioning model on one image and return the decoded text.

    Args:
        raw_image: Image passed to ``processor`` as ``images``.
        prompt: Text passed to ``processor`` as ``text``.
        processor: Callable that builds the model inputs; it must also
            provide ``batch_decode``.
        blip2_model: Object whose ``generate(**inputs)`` returns a pair;
            only the first element (the generated ids) is used.

    Returns:
        The first decoded sequence with surrounding whitespace stripped.
    """
    inputs = processor(
        images=raw_image,
        text=prompt,
        return_tensors="pd",
        return_attention_mask=True,
        mode="test", )
    # generate() yields (ids, scores); the scores are not needed here.
    generated_ids, _ = blip2_model.generate(**inputs)
    decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)
    generated_text = decoded[0].strip()
    logger.info("Generate text: {}".format(generated_text))

    return generated_text
131123
124+
def generate_tags(caption):
    """Turn a caption into a comma-separated string of lemmatized nouns.

    The caption is tokenized and part-of-speech tagged with NLTK; every
    word whose tag begins with ``'N'`` is kept, lemmatized with the
    WordNet lemmatizer, and the results are joined with ``', '``.
    Duplicates and original word order are preserved.

    Args:
        caption: Text to extract tags from.

    Returns:
        The joined tag string (empty when no word is tagged ``N*``).
    """
    lemmatizer = nltk.wordnet.WordNetLemmatizer()

    # NOTE(review): this runs on every call; presumably it is a cheap no-op
    # once the resources are present locally -- confirm for offline use.
    nltk.download(['punkt', 'averaged_perceptron_tagger', 'wordnet'])

    tagged_words = nltk.pos_tag(nltk.word_tokenize(caption))
    nouns = [word for word, pos in tagged_words if pos.startswith('N')]

    return ', '.join(lemmatizer.lemmatize(word) for word in nouns)
141137
138+
142139def main ():
143140 parser = PdArgumentParser ((ModelArguments , DataArguments ))
144141 model_args , data_args = parser .parse_args_into_dataclasses ()
145142 url = (data_args .input_image )
146143
147144 logger .info ("blip2_model: {}" .format (model_args .blip2_model_name_or_path ))
148145 #bulid blip2 processor
149- blip2_processor = Blip2Processor .from_pretrained (
150- model_args .blip2_model_name_or_path
151- ) # "Salesforce/blip2-opt-2.7b"
152- #bulid blip2 model
153- blip2_model = Blip2ForConditionalGeneration .from_pretrained (model_args .blip2_model_name_or_path )
154-
146+ blip2_tokenizer_class = AutoTokenizer .from_pretrained (
147+ model_args .text_model_name_or_path , use_fast = False )
148+ blip2_image_processor = BlipImageProcessor .from_pretrained (
149+ os .path .join (model_args .blip2_model_name_or_path , "processor" , "eval" ))
150+ blip2_text_processor_class = BlipTextProcessor .from_pretrained (
151+ os .path .join (model_args .blip2_model_name_or_path , "processor" , "eval" ))
152+ blip2_processor = Blip2Processor (blip2_image_processor ,
153+ blip2_text_processor_class ,
154+ blip2_tokenizer_class )
155+
156+ # #bulid blip2 model
157+ blip2_model = Blip2ForConditionalGeneration .from_pretrained (
158+ model_args .blip2_model_name_or_path )
159+ paddle .device .cuda .empty_cache ()
155160 blip2_model .eval ()
156- blip2_model . to ( "gpu" )
161+
157162 logger .info ("blip2_model build finish!" )
158-
163+
159164 logger .info ("dino_model: {}" .format (model_args .dino_model_name_or_path ))
160165 #bulid dino processor
161166 dino_processor = GroudingDinoProcessor .from_pretrained (
162- model_args .dino_model_name_or_path
163- )
167+ model_args .dino_model_name_or_path )
164168 #bulid dino model
165- dino_model = GroundingDinoModel .from_pretrained (model_args .dino_model_name_or_path )
169+ dino_model = GroundingDinoModel .from_pretrained (
170+ model_args .dino_model_name_or_path )
166171 dino_model .eval ()
167172 logger .info ("dino_model build finish!" )
168173
169174 #buidl sam processor
170175 sam_processor = SamProcessor .from_pretrained (
171- model_args .sam_model_name_or_path
172- )
176+ model_args .sam_model_name_or_path )
173177 #bulid model
174178 logger .info ("SamModel: {}" .format (model_args .sam_model_name_or_path ))
175- sam_model = SamModel .from_pretrained (model_args .sam_model_name_or_path ,input_type = "boxs" )
179+ sam_model = SamModel .from_pretrained (
180+ model_args .sam_model_name_or_path , input_type = "boxs" )
176181 logger .info ("SamModel build finish!" )
177-
182+
178183 #read image
179184 if os .path .isfile (url ):
180185 #read image
181186 image_pil = Image .open (data_args .input_image )
182187 else :
183188 image_pil = Image .open (requests .get (url , stream = True ).raw )
184-
185- caption = generate_caption (image_pil ,processor = blip2_processor ,blip2_model = blip2_model )
186- prompt = generate_tags (caption )
187- logger .info ("prompt: {}" .format (prompt ))
189+
190+ caption = generate_caption (
191+ image_pil ,
192+ prompt = data_args .prompt ,
193+ processor = blip2_processor ,
194+ blip2_model = blip2_model )
195+
196+ det_prompt = generate_tags (caption )
197+ logger .info ("det prompt: {}" .format (det_prompt ))
188198
189199 image_pil = image_pil .convert ("RGB" )
190200
191201 #preprocess image text_prompt
192- image_tensor ,mask ,tokenized_out = dino_processor (images = image_pil ,text = prompt )
202+ image_tensor , mask , tokenized_out = dino_processor (
203+ images = image_pil , text = det_prompt )
193204
194205 with paddle .no_grad ():
195- outputs = dino_model (image_tensor ,mask , input_ids = tokenized_out ['input_ids' ],
196- attention_mask = tokenized_out ['attention_mask' ],text_self_attention_masks = tokenized_out ['text_self_attention_masks' ],
197- position_ids = tokenized_out ['position_ids' ])
206+ outputs = dino_model (
207+ image_tensor ,
208+ mask ,
209+ input_ids = tokenized_out ['input_ids' ],
210+ attention_mask = tokenized_out ['attention_mask' ],
211+ text_self_attention_masks = tokenized_out [
212+ 'text_self_attention_masks' ],
213+ position_ids = tokenized_out ['position_ids' ])
198214
199215 logits = F .sigmoid (outputs ["pred_logits" ])[0 ] # (nq, 256)
200216 boxes = outputs ["pred_boxes" ][0 ] # (nq, 4)
201217
202- # filter output
218+ # filter output
203219 logits_filt = logits .clone ()
204220 boxes_filt = boxes .clone ()
205221 filt_mask = logits_filt .max (axis = 1 ) > model_args .box_threshold
206222 logits_filt = logits_filt [filt_mask ] # num_filt, 256
207223 boxes_filt = boxes_filt [filt_mask ] # num_filt, 4
208224
209- # build pred
225+ # build pred
210226 pred_phrases = []
211227 for logit , box in zip (logits_filt , boxes_filt ):
212228 pred_phrase = dino_processor .decode (logit > model_args .text_threshold )
213229 pred_phrases .append (pred_phrase + f"({ str (logit .max ().item ())[:4 ]} )" )
214230
215-
216231 size = image_pil .size
217232 pred_dict = {
218233 "boxes" : boxes_filt ,
219234 "size" : [size [1 ], size [0 ]], # H,W
220235 "labels" : pred_phrases ,
221236 }
222237 logger .info ("dino output{}" .format (pred_dict ))
223-
224- H ,W = size [1 ], size [0 ]
238+
239+ H , W = size [1 ], size [0 ]
225240 boxes = []
226241 for box in zip (boxes_filt ):
227242 box = box [0 ] * paddle .to_tensor ([W , H , W , H ])
@@ -231,12 +246,13 @@ def main():
231246 x0 , y0 , x1 , y1 = int (x0 ), int (y0 ), int (x1 ), int (y1 )
232247 boxes .append ([x0 , y0 , x1 , y1 ])
233248 boxes = np .array (boxes )
234- image_seg ,prompt = sam_processor (image_pil ,input_type = "boxs" ,box = boxes ,point_coords = None )
235- seg_masks = sam_model (img = image_seg ,prompt = prompt )
249+ image_seg , prompt = sam_processor (
250+ image_pil , input_type = "boxs" , box = boxes , point_coords = None )
251+ seg_masks = sam_model (img = image_seg , prompt = prompt )
236252 seg_masks = sam_processor .postprocess_masks (seg_masks )
237253
238254 logger .info ("Sam finish!" )
239-
255+
240256 if model_args .visual :
241257 # make dir
242258 os .makedirs (model_args .output_dir , exist_ok = True )
@@ -247,16 +263,17 @@ def main():
247263 show_mask (mask .cpu ().numpy (), plt .gca (), random_color = True )
248264 for box , label in zip (boxes , pred_phrases ):
249265 show_box (box , plt .gca (), label )
250-
266+
251267 plt .title (caption )
252268 plt .axis ('off' )
253269 plt .savefig (
254- os .path .join (model_args .output_dir , 'mask_pred.jpg' ),
255- bbox_inches = "tight" , dpi = 300 , pad_inches = 0.0
256- )
270+ os .path .join (model_args .output_dir , 'mask_pred.jpg' ),
271+ bbox_inches = "tight" ,
272+ dpi = 300 ,
273+ pad_inches = 0.0 )
257274
258275 logger .info ("finish!" )
259-
276+
260277
261278if __name__ == "__main__" :
262- main ()
279+ main ()
0 commit comments