From a97db92279303ee9a8d7e34a6ce4568cd24aee79 Mon Sep 17 00:00:00 2001 From: andreasjansson Date: Mon, 12 Apr 2021 09:13:56 -0700 Subject: [PATCH] Make content-addressable ID ignore timestamps Plus fix a bug where IDs were generated before resolving cached content Signed-off-by: andreasjansson --- pkg/server/build.go | 54 +++++++++++++++++++++++++++++++++++---------- pkg/zip/hash.go | 2 +- pkg/zip/read.go | 6 +++-- pkg/zip/zip_test.go | 8 +++---- 4 files changed, 51 insertions(+), 19 deletions(-) diff --git a/pkg/server/build.go b/pkg/server/build.go index 77a4afed21..78c7ad56fb 100644 --- a/pkg/server/build.go +++ b/pkg/server/build.go @@ -1,7 +1,10 @@ package server import ( + "bytes" "crypto/sha1" + "encoding/hex" + "encoding/json" "fmt" "io" "net/http" @@ -10,6 +13,8 @@ import ( "strings" "time" + "github.com/mholt/archiver/v3" + "github.com/replicate/cog/pkg/console" "github.com/replicate/cog/pkg/docker" "github.com/replicate/cog/pkg/global" @@ -17,7 +22,6 @@ import ( "github.com/replicate/cog/pkg/model" "github.com/replicate/cog/pkg/serving" "github.com/replicate/cog/pkg/zip" - "encoding/json" ) func (s *Server) ReceiveFile(w http.ResponseWriter, r *http.Request) { @@ -39,20 +43,14 @@ func (s *Server) ReceiveModel(r *http.Request, logWriter logger.Logger, user str if err := r.ParseMultipartForm(5 << 30); err != nil { return nil, fmt.Errorf("Failed to parse request: %w", err) } - file, header, err := r.FormFile("file") + inputFile, header, err := r.FormFile("file") if err != nil { return nil, fmt.Errorf("Failed to read input file: %w", err) } - defer file.Close() + defer inputFile.Close() logWriter.WriteStatus("Received model") - hasher := sha1.New() - if _, err := io.Copy(hasher, file); err != nil { - return nil, fmt.Errorf("Failed to compute hash: %w", err) - } - id := fmt.Sprintf("%x", hasher.Sum(nil)) - parentDir, err := os.MkdirTemp("/tmp", "unzip") if err != nil { return nil, fmt.Errorf("Failed to make tempdir: %w", err) @@ -66,9 +64,13 @@ func (s *Server) ReceiveModel(r *http.Request, logWriter logger.Logger, user str return nil, err } z := zip.NewCachingZip() - if err := z.ReaderUnarchive(file, header.Size, dir, zipCache); err != nil { + if err := z.ReaderUnarchive(inputFile, header.Size, dir, zipCache); err != nil { return nil, fmt.Errorf("Failed to unzip: %w", err) } + id, err := computeID(dir) + if err != nil { + return nil, err + } configRaw, err := os.ReadFile(filepath.Join(dir, global.ConfigFilename)) if err != nil { @@ -83,9 +85,13 @@ func (s *Server) ReceiveModel(r *http.Request, logWriter logger.Logger, user str return nil, err } - if _, err := file.Seek(0, io.SeekStart); err != nil { - return nil, fmt.Errorf("Failed to rewind file: %w", err) + // make zip file for storage + file := new(bytes.Buffer) + z2 := &archiver.Zip{ImplicitTopLevelFolder: false} + if err := z2.WriterArchive([]string{dir + "/"}, file); err != nil { + return nil, fmt.Errorf("Failed to zip directory: %w", err) } + if err := s.store.Upload(user, name, id, file); err != nil { return nil, fmt.Errorf("Failed to upload to storage: %w", err) } @@ -260,3 +266,27 @@ func validateServingExampleInput(help *serving.HelpResponse, input map[string]st } return nil } + +func computeID(dir string) (string, error) { + hasher := sha1.New() + err := filepath.WalkDir(dir, func(path string, d os.DirEntry, err error) error { + if err != nil { + return err + } + if !d.Type().IsRegular() { + return nil + } + file, err := os.Open(path) + if err != nil { + return fmt.Errorf("Failed to open %s: %w", path, err) + } + if _, err := io.Copy(hasher, file); err != nil { + return fmt.Errorf("Failed to read %s: %w", path, err) + } + return nil + }) + if err != nil { + return "", err + } + return hex.EncodeToString(hasher.Sum(nil)), nil +} diff --git a/pkg/zip/hash.go b/pkg/zip/hash.go index 0be2c545df..2dd40da544 100644 --- a/pkg/zip/hash.go +++ b/pkg/zip/hash.go @@ -3,9 +3,9 @@ package zip import ( "crypto/sha256" "encoding/hex" + "fmt" "io" "os" - "fmt" ) func getFileHash(path string) (string, error) { diff --git a/pkg/zip/read.go b/pkg/zip/read.go index 4a03bb69f0..58f48bcb0e 100644 --- a/pkg/zip/read.go +++ b/pkg/zip/read.go @@ -24,11 +24,13 @@ func (z *CachingZip) ReaderUnarchive(source io.Reader, size int64, destination s prefixBuf := make([]byte, len(cachePrefix)) file, err := os.Open(fpath) if err != nil { - return fmt.Errorf("Failed to open %s: %v", fpath, err) + return fmt.Errorf("Failed to open in zip %s: %v", fpath, err) } defer file.Close() if _, err := file.Read(prefixBuf); err != nil { - return fmt.Errorf("Failed to read %s: %v", fpath, err) + if err != io.EOF { + return fmt.Errorf("Failed to read in zip %s: %v", fpath, err) + } } if string(prefixBuf) == cachePrefix { hashBuf := make([]byte, hashLength) diff --git a/pkg/zip/zip_test.go b/pkg/zip/zip_test.go index 3b44f64052..ff3bbd7362 100644 --- a/pkg/zip/zip_test.go +++ b/pkg/zip/zip_test.go @@ -47,7 +47,7 @@ func TestCachingZip(t *testing.T) { out, err := os.Create(outPath) require.NoError(t, err) - err = z.WriterArchive(dataDir + "/", out, []string{}) + err = z.WriterArchive(dataDir+"/", out, []string{}) require.NoError(t, err) require.NoError(t, out.Close()) @@ -67,7 +67,7 @@ func TestCachingZip(t *testing.T) { require.NoError(t, err) stat, err := file.Stat() require.NoError(t, err) - err = z.ReaderUnarchive(file, stat.Size(), unzipDir2 + "/", fs) + err = z.ReaderUnarchive(file, stat.Size(), unzipDir2+"/", fs) require.NoError(t, err) requireUnzippedCorrectly(t, unzipDir2, "foo", "bar", "baz") @@ -87,7 +87,7 @@ func TestCachingZip(t *testing.T) { require.NoError(t, os.WriteFile(filepath.Join(dataDir, "anotherdir/baz.txt"), []byte("changed-baz"), 0644)) - err = z.WriterArchive(dataDir + "/", out2, hashes) + err = z.WriterArchive(dataDir+"/", out2, hashes) err = new(archiver.Zip).Unarchive(outPath2, unzipDir3) require.NoError(t, err) @@ -100,7 +100,7 @@ func TestCachingZip(t *testing.T) { require.NoError(t, err) stat, err = file.Stat() require.NoError(t, err) - err = z.ReaderUnarchive(file, stat.Size(), unzipDir4 + "/", fs) + err = z.ReaderUnarchive(file, stat.Size(), unzipDir4+"/", fs) require.NoError(t, err) requireUnzippedCorrectly(t, unzipDir4, "foo", "bar", "changed-baz")