diff --git a/pkg/azurestore/azureservice.go b/pkg/azurestore/azureservice.go index e410892ec..38c56bd42 100644 --- a/pkg/azurestore/azureservice.go +++ b/pkg/azurestore/azureservice.go @@ -21,7 +21,9 @@ import ( "errors" "fmt" "io" + "net/http" "sort" + "strconv" "strings" "github.com/Azure/azure-sdk-for-go/sdk/azcore" @@ -68,6 +70,8 @@ type AzBlob interface { Upload(ctx context.Context, body io.ReadSeeker) error // Download returns a readcloser to download the contents of the blob Download(ctx context.Context) (io.ReadCloser, error) + // Serves the contents of the blob directly handling special HTTP headers like Range, if set + ServeContent(ctx context.Context, w http.ResponseWriter, r *http.Request) error // Get the offset of the blob and its indexes GetOffset(ctx context.Context) (int64, error) // Commit the uploaded blocks to the BlockBlob @@ -199,6 +203,64 @@ func (blockBlob *BlockBlob) Download(ctx context.Context) (io.ReadCloser, error) return resp.Body, nil } +// Serve content respecting range header +func (blockBlob *BlockBlob) ServeContent(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + var downloadOptions, err = ParseDownloadOptions(r) + if err != nil { + return err + } + result, err := blockBlob.BlobClient.DownloadStream(ctx, downloadOptions) + if err != nil { + return err + } + defer result.Body.Close() + + statusCode := http.StatusOK + if result.ContentRange != nil { + // Use 206 Partial Content for range requests + statusCode = http.StatusPartialContent + } else if result.ContentLength != nil && *result.ContentLength == 0 { + statusCode = http.StatusNoContent + } + + // Add Accept-Ranges,Content-*, Cache-Control, ETag, Expires, Last-Modified headers if present in azure response + if result.AcceptRanges != nil { + w.Header().Set("Accept-Ranges", *result.AcceptRanges) + } + if result.ContentDisposition != nil { + w.Header().Set("Content-Disposition", *result.ContentDisposition) + } + if result.ContentEncoding != nil { + w.Header().Set("Content-Encoding", *result.ContentEncoding) + } + if result.ContentLanguage != nil { + w.Header().Set("Content-Language", *result.ContentLanguage) + } + if result.ContentLength != nil { + w.Header().Set("Content-Length", strconv.FormatInt(*result.ContentLength, 10)) + } + if result.ContentRange != nil { + w.Header().Set("Content-Range", *result.ContentRange) + } + if result.ContentType != nil { + w.Header().Set("Content-Type", *result.ContentType) + } + if result.CacheControl != nil { + w.Header().Set("Cache-Control", *result.CacheControl) + } + if result.ETag != nil && *result.ETag != "" { + w.Header().Set("ETag", string(*result.ETag)) + } + if result.LastModified != nil { + w.Header().Set("Last-Modified", result.LastModified.Format(http.TimeFormat)) + } + + w.WriteHeader(statusCode) + + _, err = io.Copy(w, result.Body) + return err +} + func (blockBlob *BlockBlob) GetOffset(ctx context.Context) (int64, error) { // Get the offset of the file from azure storage // For the blob, show each block (ID and size) that is a committed part of it. @@ -260,6 +322,11 @@ func (infoBlob *InfoBlob) Download(ctx context.Context) (io.ReadCloser, error) { return resp.Body, nil } +// ServeContent is not needed for infoBlob +func (infoBlob *InfoBlob) ServeContent(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + return errors.New("azurestore: ServeContent is not implemented for InfoBlob") +} + // infoBlob does not utilise offset, so just return 0, nil func (infoBlob *InfoBlob) GetOffset(ctx context.Context) (int64, error) { return 0, nil @@ -316,3 +383,47 @@ func checkForNotFoundError(err error) error { } return err } + +// parse the Range, If-Match, If-None-Match, If-Unmodified-Since, If-Modified-Since headers if present +func ParseDownloadOptions(r *http.Request) (*azblob.DownloadStreamOptions, error) { + input := azblob.DownloadStreamOptions{AccessConditions: &azblob.AccessConditions{}} + + if val := r.Header.Get("Range"); val != "" { + // zero value count indicates from the offset to the resource's end, suffix-length is not required + input.Range = azblob.HTTPRange{Offset: 0, Count: 0} + bytesEnd := 0 + if _, err := fmt.Sscanf(val, "bytes=%d-%d", &input.Range.Offset, &bytesEnd); err != nil { + if _, err := fmt.Sscanf(val, "bytes=%d-", &input.Range.Offset); err != nil { + return nil, err + } + } + if bytesEnd != 0 { + input.Range.Count = int64(bytesEnd) - input.Range.Offset + 1 + } + } + if val := r.Header.Get("If-Match"); val != "" { + etagIfMatch := azcore.ETag(val) + input.AccessConditions.ModifiedAccessConditions.IfMatch = &etagIfMatch + } + if val := r.Header.Get("If-None-Match"); val != "" { + etagIfNoneMatch := azcore.ETag(val) + input.AccessConditions.ModifiedAccessConditions.IfNoneMatch = &etagIfNoneMatch + } + if val := r.Header.Get("If-Modified-Since"); val != "" { + t, err := http.ParseTime(val) + if err != nil { + return nil, err + } + input.AccessConditions.ModifiedAccessConditions.IfModifiedSince = &t + + } + if val := r.Header.Get("If-Unmodified-Since"); val != "" { + t, err := http.ParseTime(val) + if err != nil { + return nil, err + } + input.AccessConditions.ModifiedAccessConditions.IfUnmodifiedSince = &t + } + + return &input, nil +} diff --git a/pkg/azurestore/azurestore.go b/pkg/azurestore/azurestore.go index 0768bb421..dc85d9f2a 100644 --- a/pkg/azurestore/azurestore.go +++ b/pkg/azurestore/azurestore.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "io/fs" + "net/http" "os" "strings" @@ -47,6 +48,7 @@ func (store AzureStore) UseIn(composer *handler.StoreComposer) { composer.UseCore(store) composer.UseTerminater(store) composer.UseLengthDeferrer(store) + composer.UseContentServer(store) } func (store AzureStore) NewUpload(ctx context.Context, info handler.FileInfo) (handler.Upload, error) { @@ -149,6 +151,10 @@ func (store AzureStore) AsLengthDeclarableUpload(upload handler.Upload) handler. return upload.(*AzUpload) } +func (store AzureStore) AsServableUpload(upload handler.Upload) handler.ServableUpload { + return upload.(*AzUpload) +} + func (upload *AzUpload) WriteChunk(ctx context.Context, offset int64, src io.Reader) (int64, error) { // Create a temporary file for holding the uploaded data file, err := os.CreateTemp(upload.tempDir, "tusd-az-tmp-") @@ -214,6 +220,11 @@ func (upload *AzUpload) GetReader(ctx context.Context) (io.ReadCloser, error) { return upload.BlockBlob.Download(ctx) } +// Serves the contents of the blob directly handling special HTTP headers like Range, if set +func (upload *AzUpload) ServeContent(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + return upload.BlockBlob.ServeContent(ctx, w, r) +} + // Finish the file upload and commit the block list func (upload *AzUpload) FinishUpload(ctx context.Context) error { return upload.BlockBlob.Commit(ctx) diff --git a/pkg/azurestore/azurestore_mock_test.go b/pkg/azurestore/azurestore_mock_test.go index 48000a9ca..6ae735165 100644 --- a/pkg/azurestore/azurestore_mock_test.go +++ b/pkg/azurestore/azurestore_mock_test.go @@ -7,6 +7,7 @@ package azurestore_test import ( context "context" io "io" + http "net/http" reflect "reflect" gomock "github.com/golang/mock/gomock" @@ -132,6 +133,20 @@ func (mr *MockAzBlobMockRecorder) GetOffset(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOffset", reflect.TypeOf((*MockAzBlob)(nil).GetOffset), arg0) } +// ServeContent mocks base method. +func (m *MockAzBlob) ServeContent(arg0 context.Context, arg1 http.ResponseWriter, arg2 *http.Request) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ServeContent", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// ServeContent indicates an expected call of ServeContent. +func (mr *MockAzBlobMockRecorder) ServeContent(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeContent", reflect.TypeOf((*MockAzBlob)(nil).ServeContent), arg0, arg1, arg2) +} + // Upload mocks base method. func (m *MockAzBlob) Upload(arg0 context.Context, arg1 io.ReadSeeker) error { m.ctrl.T.Helper() diff --git a/pkg/azurestore/azurestore_test.go b/pkg/azurestore/azurestore_test.go index 5b64e7290..a866357a7 100644 --- a/pkg/azurestore/azurestore_test.go +++ b/pkg/azurestore/azurestore_test.go @@ -6,8 +6,11 @@ import ( "encoding/json" "errors" "io" + "net/http" + "net/http/httptest" "testing" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" @@ -431,6 +434,130 @@ func TestDeclareLength(t *testing.T) { cancel() } +func TestAzureStoreAsServerDataStore(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + assert := assert.New(t) + + service := NewMockAzService(mockCtrl) + store := azurestore.New(service) + + mockUpload := &azurestore.AzUpload{} + servableUpload := store.AsServableUpload(mockUpload) + + assert.NotNil(servableUpload) + assert.IsType(&azurestore.AzUpload{}, servableUpload) +} + +func TestAZServableUploadServeContent(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + assert := assert.New(t) + ctx := context.Background() + + blockBlob := NewMockAzBlob(mockCtrl) + assert.NotNil(blockBlob) + + // Create a test HTTP request and response recorder + req := httptest.NewRequest("GET", "/", nil) + rec := httptest.NewRecorder() + + // Expected response headers and body + expectedHeaders := map[string]string{ + "Content-Type": "text/plain", + "Content-Length": "12", + "ETag": "bytes", + "CacheControl": "max-age=3600", + } + expectedBody := "test content" + + // Mock ServeContent call + blockBlob.EXPECT().ServeContent(ctx, gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + // Add headers to response + for key, value := range expectedHeaders { + w.Header().Set(key, value) + } + w.WriteHeader(http.StatusOK) + + // Write response body + _, err := w.Write([]byte(expectedBody)) + return err + }, + ).Times(1) + + err := blockBlob.ServeContent(ctx, rec, req) + + assert.Nil(err) + assert.Equal(http.StatusOK, rec.Code) + for key, value := range expectedHeaders { + assert.Equal(value, rec.Header().Get(key)) + } + assert.Equal(expectedBody, rec.Body.String()) +} + +func TestParseDownloadOptions(t *testing.T) { + tests := []struct { + name string + headers map[string]string + expected *azblob.DownloadStreamOptions + expectErr bool + }{ + { + name: "Valid Range header", + headers: map[string]string{ + "Range": "bytes=10-20", + }, + expected: &azblob.DownloadStreamOptions{ + Range: azblob.HTTPRange{ + Offset: 10, + Count: 11, + }, + }, + expectErr: false, + }, + { + name: "Valid Range header", + headers: map[string]string{ + "Range": "bytes=10-", + }, + expected: &azblob.DownloadStreamOptions{ + Range: azblob.HTTPRange{ + Offset: 10, + Count: 0, + }, + }, + expectErr: false, + }, + { + name: "Valid Range header", + headers: map[string]string{ + "Range": "bytes=zZ-", + }, + expected: &azblob.DownloadStreamOptions{}, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + for key, value := range tt.headers { + req.Header.Set(key, value) + } + + options, err := azurestore.ParseDownloadOptions(req) + if tt.expectErr { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + options.AccessConditions = nil + assert.Equal(t, tt.expected, options) + } + }) + } +} + func newReadCloser(b []byte) io.ReadCloser { return io.NopCloser(bytes.NewReader(b)) }