@@ -1,7 +1,6 @@
-package com.github.mfl28.boundingboxeditor.model.io;
+package com.github.mfl28.boundingboxeditor.model.io.restclients;
 
-import com.google.gson.*;
-import com.google.gson.reflect.TypeToken;
+import com.google.gson.JsonSyntaxException;
 import org.glassfish.jersey.media.multipart.FormDataMultiPart;
 import org.glassfish.jersey.media.multipart.MultiPart;
 import org.glassfish.jersey.media.multipart.MultiPartFeature;
@@ -12,13 +11,13 @@
 import javax.ws.rs.client.ClientBuilder;
 import javax.ws.rs.client.Entity;
 import javax.ws.rs.client.WebTarget;
+import javax.ws.rs.core.GenericType;
 import javax.ws.rs.core.MediaType;
 import javax.ws.rs.core.Response;
 import java.io.InputStream;
 import java.net.ConnectException;
-import java.util.HashMap;
 import java.util.List;
-import java.util.Map;
 
 public class TorchServeRestClient implements BoundingBoxPredictorClient {
     private static final String MODELS_RESOURCE_NAME = "models";
@@ -28,142 +26,93 @@ public class TorchServeRestClient implements BoundingBoxPredictorClient {
     private static final String PREDICTIONS_RESOURCE_NAME = "predictions";
     private static final String SERVER_PREDICTION_POST_ERROR_MESSAGE =
             "Could not get prediction from supplied inference server";
-    private final Client client = ClientBuilder.newBuilder().register(MultiPartFeature.class).build();
+    private final Client client = ClientBuilder.newBuilder()
+            .register(MultiPartFeature.class)
+            .register(GsonMessageBodyHandler.class)
+            .build();
     private final BoundingBoxPredictorClientConfig clientConfig;

     public TorchServeRestClient(BoundingBoxPredictorClientConfig clientConfig) {
         this.clientConfig = clientConfig;
-        client.register(MultiPartFeature.class);
     }

-    public String ping() {
-        return client.target(clientConfig.getInferenceAddress())
-                .path("ping")
-                .request(MediaType.APPLICATION_JSON)
-                .get()
-                .readEntity(String.class);
-    }
-
-    public List<ModelEntry> models() throws PredictionClientException {
-        WebTarget managementTarget;
+    @Override
+    public List<BoundingBoxPredictionEntry> predict(InputStream input) throws PredictionClientException {
+        final MultiPart multiPart = new FormDataMultiPart().bodyPart(new StreamDataBodyPart(DATA_BODY_PART_NAME,
+                input));
+        WebTarget predictionTarget;

         try {
-            managementTarget = client.target(clientConfig.getManagementAddress());
+            predictionTarget = client.target(clientConfig.getInferenceAddress());
         } catch (IllegalArgumentException | NullPointerException e) {
-            throw new PredictionClientException("Invalid torch serve management address and/or port.");
+            throw new PredictionClientException("Invalid torch serve inference address or port.");
         }

         Response response;

         try {
-            response = managementTarget
-                    .path(MODELS_RESOURCE_NAME)
-                    .request(MediaType.APPLICATION_JSON)
-                    .get();
+            response = predictionTarget.path(PREDICTIONS_RESOURCE_NAME).path(clientConfig.getInferenceModelName())
+                    .request(MediaType.APPLICATION_JSON)
+                    .post(Entity.entity(multiPart, multiPart.getMediaType()));
         } catch (ProcessingException | IllegalArgumentException | IllegalStateException e) {
             if (e.getCause() instanceof ConnectException) {
-                throw new PredictionClientException("Could not connect to supplied management server.");
+                throw new PredictionClientException("Could not connect to supplied inference server.");
             } else {
-                throw new PredictionClientException(SERVER_MODELS_READ_ERROR_MESSAGE);
+                throw new PredictionClientException(SERVER_PREDICTION_POST_ERROR_MESSAGE);
             }
         }

         if (!response.getStatusInfo().equals(Response.Status.OK)) {
-            throw new PredictionClientException(SERVER_MODELS_READ_ERROR_MESSAGE);
+            throw new PredictionClientException(SERVER_PREDICTION_POST_ERROR_MESSAGE);
         }

-        String modelsJson;
-
         try {
-            modelsJson = response.readEntity(String.class);
+            return response.readEntity(new GenericType<>() {});
         } catch (ProcessingException | IllegalStateException e) {
-            throw new PredictionClientException(SERVER_MODELS_READ_ERROR_MESSAGE);
-        }
-
-        List<ModelEntry> models;
-
-        try {
-            models = new Gson().fromJson(modelsJson, ModelsWrapper.class).getModels();
+            throw new PredictionClientException(SERVER_PREDICTION_POST_ERROR_MESSAGE);
         } catch (JsonSyntaxException e) {
-            throw new PredictionClientException("Invalid torch serve management server response format for \""
-                    + MODELS_RESOURCE_NAME + "\" resource.");
+            throw new PredictionClientException("Invalid torch serve inference server response format for \"" +
+                    PREDICTIONS_RESOURCE_NAME + "\" resource.");
         }
-
-        return models;
     }

     @Override
-    public List<BoundingBoxPredictionEntry> predict(InputStream input) throws PredictionClientException {
-        final MultiPart multiPart = new FormDataMultiPart().bodyPart(new StreamDataBodyPart(DATA_BODY_PART_NAME,
-                input));
-
-        WebTarget predictionTarget;
+    public List<ModelEntry> models() throws PredictionClientException {
+        WebTarget managementTarget;

         try {
-            predictionTarget = client.target(clientConfig.getInferenceAddress());
+            managementTarget = client.target(clientConfig.getManagementAddress());
         } catch (IllegalArgumentException | NullPointerException e) {
-            throw new PredictionClientException("Invalid torch serve inference address and/or port.");
+            throw new PredictionClientException("Invalid torch serve management address or port.");
         }

         Response response;

         try {
-            response = predictionTarget.path(PREDICTIONS_RESOURCE_NAME).path(clientConfig.getInferenceModelName())
-                    .request(MediaType.APPLICATION_JSON)
-                    .post(Entity.entity(multiPart, multiPart.getMediaType()));
+            response = managementTarget
+                    .path(MODELS_RESOURCE_NAME)
+                    .request(MediaType.APPLICATION_JSON)
+                    .get();
         } catch (ProcessingException | IllegalArgumentException | IllegalStateException e) {
             if (e.getCause() instanceof ConnectException) {
-                throw new PredictionClientException("Could not connect to supplied inference server.");
+                throw new PredictionClientException("Could not connect to supplied management server.");
             } else {
-                throw new PredictionClientException(SERVER_PREDICTION_POST_ERROR_MESSAGE);
+                throw new PredictionClientException(SERVER_MODELS_READ_ERROR_MESSAGE);
             }
         }

         if (!response.getStatusInfo().equals(Response.Status.OK)) {
-            throw new PredictionClientException(SERVER_PREDICTION_POST_ERROR_MESSAGE);
+            throw new PredictionClientException(SERVER_MODELS_READ_ERROR_MESSAGE);
         }

-        String predictionJson;
-
         try {
-            predictionJson = response.readEntity(String.class);
+            return response.readEntity(ModelsWrapper.class).getModels();
         } catch (ProcessingException | IllegalStateException e) {
-            throw new PredictionClientException(SERVER_PREDICTION_POST_ERROR_MESSAGE);
-        }
-
-        final Gson gson = new GsonBuilder()
-                .registerTypeAdapter(BoundingBoxPredictionEntry.class,
-                        (JsonDeserializer<BoundingBoxPredictionEntry>) (json, type, context) ->
-                        {
-                            final JsonObject jsonObject = json.getAsJsonObject();
-                            double score = jsonObject.get("score").getAsDouble();
-
-                            Map<String, List<Double>> categoryToBoundingBox = new HashMap<>();
-
-                            for (Map.Entry<String, JsonElement> entry : jsonObject.entrySet()) {
-                                if (!entry.getKey().equals("score") && entry.getValue().isJsonArray()) {
-                                    categoryToBoundingBox.put(entry.getKey(),
-                                            context.deserialize(entry.getValue(),
-                                                    new TypeToken<List<Double>>() {}
-                                                            .getType()));
-                                }
-                            }
-
-                            return new BoundingBoxPredictionEntry(categoryToBoundingBox, score);
-
-                        }).create();
-
-        List<BoundingBoxPredictionEntry> boundingBoxPredictionEntries;
-
-        try {
-            boundingBoxPredictionEntries =
-                    gson.fromJson(predictionJson, new TypeToken<List<BoundingBoxPredictionEntry>>() {}.getType());
+            throw new PredictionClientException(SERVER_MODELS_READ_ERROR_MESSAGE);
         } catch (JsonSyntaxException e) {
-            throw new PredictionClientException("Invalid torch serve inference server response format for \"" +
-                    PREDICTIONS_RESOURCE_NAME + "\" resource.");
+            throw new PredictionClientException("Invalid torch serve management server response format for \""
+                    + MODELS_RESOURCE_NAME + "\" resource.");
         }
-
-        return boundingBoxPredictionEntries;
     }

     public static class ModelEntry {
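Note on the refactor above: the hand-rolled Gson parsing (including the inline BoundingBoxPredictionEntry deserializer) is replaced by a GsonMessageBodyHandler provider registered on the JAX-RS client, so responses are read directly via readEntity(...). That handler's implementation is not part of this diff; the following is only a minimal sketch, assuming it is a standard Gson-backed javax.ws.rs.ext.MessageBodyReader (the real class presumably also carries the prediction-entry deserialization logic removed here).

// Hypothetical sketch only; not the class shipped with the project.
import com.google.gson.Gson;
import javax.ws.rs.Consumes;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.ext.MessageBodyReader;
import javax.ws.rs.ext.Provider;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.lang.annotation.Annotation;
import java.lang.reflect.Type;
import java.nio.charset.StandardCharsets;

@Provider
@Consumes(MediaType.APPLICATION_JSON)
public class GsonMessageBodyHandler implements MessageBodyReader<Object> {
    // A single Gson instance is thread-safe and can be reused for all reads.
    private final Gson gson = new Gson();

    @Override
    public boolean isReadable(Class<?> type, Type genericType,
                              Annotation[] annotations, MediaType mediaType) {
        // @Consumes already restricts this reader to JSON entities.
        return true;
    }

    @Override
    public Object readFrom(Class<Object> type, Type genericType, Annotation[] annotations,
                           MediaType mediaType, MultivaluedMap<String, String> httpHeaders,
                           InputStream entityStream) throws IOException {
        try (InputStreamReader reader = new InputStreamReader(entityStream, StandardCharsets.UTF_8)) {
            // Deserialize against the full generic type so List<...> targets keep their element type.
            return gson.fromJson(reader, genericType);
        }
    }
}

With such a provider registered, response.readEntity(new GenericType<>() {}) in predict() hands the reader the generic target type (inferred from the method's List<BoundingBoxPredictionEntry> return type, Java 9+ anonymous diamond), and readEntity(ModelsWrapper.class) in models() deserializes the management response without any intermediate String/Gson plumbing.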