diff --git a/pkg/api/handlers/libpod/images_pull.go b/pkg/api/handlers/libpod/images_pull.go index b88b8442bf6..d34b6bd5643 100644 --- a/pkg/api/handlers/libpod/images_pull.go +++ b/pkg/api/handlers/libpod/images_pull.go @@ -23,6 +23,9 @@ import ( "go.podman.io/image/v5/types" ) +// The duration for which we are willing to wait before starting the stream, to be able to decide the HTTP status code more accurately. +const maximalStreamingDelay = 5 * time.Second + // ImagesPull is the v2 libpod endpoint for pulling images. Note that the // mandatory `reference` must be a reference to a registry (i.e., of docker // transport or be normalized to one). Other transports are rejected as they @@ -116,10 +119,12 @@ func ImagesPull(w http.ResponseWriter, r *http.Request) { // Let's keep thing simple when running in quiet mode and pull directly. if query.Quiet { images, err := runtime.LibimageRuntime().Pull(r.Context(), query.Reference, pullPolicy, pullOptions) - var report entities.ImagePullReport if err != nil { - report.Error = err.Error() + utils.Error(w, utils.HTTPStatusFromRegistryError(err), err) + return } + + var report entities.ImagePullReport for _, image := range images { report.Images = append(report.Images, image.ID()) // Pull last ID from list and publish in 'id' stanza. This maintains previous API contract @@ -138,6 +143,9 @@ func ImagesPull(w http.ResponseWriter, r *http.Request) { defer writer.Close() pullOptions.Writer = writer + progress := make(chan types.ProgressProperties) + pullOptions.Progress = progress + var pulledImages []*libimage.Image var pullError error runCtx, cancel := context.WithCancel(r.Context()) @@ -152,22 +160,58 @@ func ImagesPull(w http.ResponseWriter, r *http.Request) { } } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - flush() - enc := json.NewEncoder(w) enc.SetEscapeHTML(true) + + streamingStarted := false + var reportBuffer []entities.ImagePullReport + + // This timer ensures that streaming is not delayed indefinitely. + streamingDelayTimer := time.NewTimer(maximalStreamingDelay) + + // Streaming is delayed to choose the HTTP status code more accurately (e.g. if the image does not exist at all). + // Once a message is streamed, it is no longer possible to change the status code. + startStreaming := func() { + if !streamingStarted { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + for _, report := range reportBuffer { + if err := enc.Encode(report); err != nil { + logrus.Warnf("Failed to encode json: %v", err) + } + flush() + } + streamingStarted = true + } + } + for { - var report entities.ImagePullReport select { + case <-progress: + startStreaming() // We are reporting progress working with the image contents, so it presumably exists and it is accessible. + case <-streamingDelayTimer.C: + startStreaming() // At some point, give up on the precise error code and let the caller show an “operation in progress, no data available yet” UI. case s := <-writer.Chan(): - report.Stream = string(s) - if err := enc.Encode(report); err != nil { - logrus.Warnf("Failed to encode json: %v", err) + report := entities.ImagePullReport{ + Stream: string(s), + } + if streamingStarted { + if err := enc.Encode(report); err != nil { + logrus.Warnf("Failed to encode json: %v", err) + } + flush() + } else { + reportBuffer = append(reportBuffer, report) } - flush() case <-runCtx.Done(): + if !streamingStarted && pullError != nil { + utils.Error(w, utils.HTTPStatusFromRegistryError(pullError), pullError) + return + } + + startStreaming() + var report entities.ImagePullReport for _, image := range pulledImages { report.Images = append(report.Images, image.ID()) // Pull last ID from list and publish in 'id' stanza. This maintains previous API contract diff --git a/pkg/api/handlers/swagger/errors.go b/pkg/api/handlers/swagger/errors.go index 8e7d443d04c..9956d666fb1 100644 --- a/pkg/api/handlers/swagger/errors.go +++ b/pkg/api/handlers/swagger/errors.go @@ -44,6 +44,13 @@ type artifactBadAuth struct { Body errorhandling.ErrorModel } +// Error from registry +// swagger:response +type errorFromRegistry struct { + // in:body + Body errorhandling.ErrorModel +} + // No such network // swagger:response type networkNotFound struct { diff --git a/pkg/api/handlers/swagger/responses.go b/pkg/api/handlers/swagger/responses.go index 997da00d599..525dc6a2c42 100644 --- a/pkg/api/handlers/swagger/responses.go +++ b/pkg/api/handlers/swagger/responses.go @@ -59,7 +59,7 @@ type imagesImportResponseLibpod struct { Body entities.ImageImportReport } -// Image Pull +// Image Pull. Errors may be detected later even if this returns HTTP status 200, and in that case, the error description will be in the `error` field. // swagger:response type imagesPullResponseLibpod struct { // in:body diff --git a/pkg/api/handlers/utils/images.go b/pkg/api/handlers/utils/images.go index 99a4d40aaf1..b1c750049b2 100644 --- a/pkg/api/handlers/utils/images.go +++ b/pkg/api/handlers/utils/images.go @@ -252,12 +252,7 @@ loop: // break out of for/select infinite loop case pullRes := <-pullResChan: err := pullRes.err if err != nil { - var errcd errcode.ErrorCoder - if errors.As(err, &errcd) { - writeStatusCode(errcd.ErrorCode().Descriptor().HTTPStatusCode) - } else { - writeStatusCode(http.StatusInternalServerError) - } + writeStatusCode(HTTPStatusFromRegistryError(err)) msg := err.Error() report.Error = &jsonstream.Error{ Message: msg, @@ -305,3 +300,14 @@ loop: // break out of for/select infinite loop } } } + +func HTTPStatusFromRegistryError(err error) int { + if err == nil { + return http.StatusOK + } + var errcd errcode.ErrorCoder + if errors.As(err, &errcd) { + return errcd.ErrorCode().Descriptor().HTTPStatusCode + } + return http.StatusInternalServerError +} diff --git a/pkg/api/server/register_images.go b/pkg/api/server/register_images.go index 0c4fa1257b6..5802457cd69 100644 --- a/pkg/api/server/register_images.go +++ b/pkg/api/server/register_images.go @@ -1103,7 +1103,7 @@ func (s *APIServer) registerImagesHandlers(r *mux.Router) error { // tags: // - images // summary: Pull images - // description: Pull one or more images from a container registry. + // description: Pull one or more images from a container registry. Error status codes can come either from the API or from the registry. Errors may be detected later even if the HTTP status 200 is returned, and in that case, the error description will be in the `error` field. // parameters: // - in: query // name: reference @@ -1157,6 +1157,8 @@ func (s *APIServer) registerImagesHandlers(r *mux.Router) error { // $ref: "#/responses/badParamError" // 500: // $ref: '#/responses/internalError' + // default: + // $ref: "#/responses/errorFromRegistry" r.Handle(VersionedPath("/libpod/images/pull"), s.APIHandler(libpod.ImagesPull)).Methods(http.MethodPost) // swagger:operation POST /libpod/images/prune libpod ImagePruneLibpod // --- diff --git a/test/apiv2/python/rest_api/test_v2_0_0_image.py b/test/apiv2/python/rest_api/test_v2_0_0_image.py index bd37d7a61fd..40f84324dc9 100644 --- a/test/apiv2/python/rest_api/test_v2_0_0_image.py +++ b/test/apiv2/python/rest_api/test_v2_0_0_image.py @@ -68,47 +68,87 @@ def test_delete(self): self.assertEqual(r.status_code, 409, r.text) def test_pull(self): - r = requests.post(self.uri("/images/pull?reference=alpine"), timeout=15) - self.assertEqual(r.status_code, 200, r.status_code) - text = r.text - keys = { - "error": False, - "id": False, - "images": False, - "stream": False, - } - # Read and record stanza's from pull - for line in str.splitlines(text): - obj = json.loads(line) - key_list = list(obj.keys()) - for k in key_list: - keys[k] = True - - self.assertFalse(keys["error"], "Expected no errors") - self.assertTrue(keys["id"], "Expected to find id stanza") - self.assertTrue(keys["images"], "Expected to find images stanza") - self.assertTrue(keys["stream"], "Expected to find stream progress stanza's") - - r = requests.post(self.uri("/images/pull?reference=alpine&quiet=true"), timeout=15) - self.assertEqual(r.status_code, 200, r.status_code) - text = r.text - keys = { - "error": False, - "id": False, - "images": False, - "stream": False, - } - # Read and record stanza's from pull - for line in str.splitlines(text): - obj = json.loads(line) - key_list = list(obj.keys()) - for k in key_list: - keys[k] = True - - self.assertFalse(keys["error"], "Expected no errors") - self.assertTrue(keys["id"], "Expected to find id stanza") - self.assertTrue(keys["images"], "Expected to find images stanza") - self.assertFalse(keys["stream"], "Expected to find stream progress stanza's") + def check_response_keys(r, keys_expected): + text = r.text + keys_found = set() + + # Read and record stanza's from pull + for line in str.splitlines(text): + obj = json.loads(line) + key_list = list(obj.keys()) + for k in key_list: + keys_found.add(k) + + for key, expected in keys_expected.items(): + if expected: + negation = "" + else: + negation = "not " + self.assertEqual( + key in keys_found, + expected, + f'Expected {negation}to find "{key}" stanza in response', + ) + + existing_reference = "alpine" + non_existing_reference = "quay.io/f4ee35641334/f6fda4bb" + cases = [ + dict( + quiet_postfix="&quiet=True", + reference=existing_reference, + timeout=15, + assert_function=self.assertEqual, + expected_keys={ + "error": False, + "id": True, + "images": True, + "stream": False, + }, + ), + dict( + quiet_postfix="", + reference=existing_reference, + timeout=15, + assert_function=self.assertEqual, + expected_keys={ + "error": False, + "id": True, + "images": True, + "stream": True, + }, + ), + dict( + quiet_postfix="&quiet=True", + reference=non_existing_reference, + timeout=None, + assert_function=self.assertNotEqual, + expected_keys={ + "cause": True, + "message": True, + "response": True, + }, + ), + dict( + quiet_postfix="", + reference=non_existing_reference, + timeout=None, + assert_function=self.assertNotEqual, + expected_keys={ + "cause": True, + "message": True, + "response": True, + }, + ), + ] + + for case in cases: + with self.subTest(case=case): + r = requests.post( + self.uri(f"/images/pull?reference={case['reference']}{case['quiet_postfix']}"), + timeout=case["timeout"], + ) + case["assert_function"](r.status_code, 200, r.status_code) + check_response_keys(r, case["expected_keys"]) def test_create(self): r = requests.post(