diff --git a/pkg/migrate/migrator_v1_v1fast.go b/pkg/migrate/migrator_v1_v1fast.go
index 564663551d..3422a6def8 100644
--- a/pkg/migrate/migrator_v1_v1fast.go
+++ b/pkg/migrate/migrator_v1_v1fast.go
@@ -240,7 +240,13 @@ func (g *MigratorV1ToV1Fast) flushConfig(cfg *config.Config, dir string, configF
 	}
 	defer file.Close()
 	console.Infof("Writing config changes to %s.", configFilepath)
-	_, err = file.WriteString(configStr)
+
+	mergedCfgData, err := OverwrightYAML(data, content)
+	if err != nil {
+		return util.WrapError(err, "Failed to merge config changes")
+	}
+
+	_, err = file.Write(mergedCfgData)
 	if err != nil {
 		return util.WrapError(err, "Failed to write config changes")
 	}
diff --git a/pkg/migrate/migrator_v1_v1fast_test.go b/pkg/migrate/migrator_v1_v1fast_test.go
index 691f2dcb73..6e1f6ee685 100644
--- a/pkg/migrate/migrator_v1_v1fast_test.go
+++ b/pkg/migrate/migrator_v1_v1fast_test.go
@@ -66,10 +66,10 @@ class Predictor(BasePredictor):
 	content, err := io.ReadAll(file)
 	require.NoError(t, err)
 	require.Equal(t, `build:
-  python_version: "3.11"
-  python_requirements: requirements.txt
-  fast: true
-predict: predict.py:Predictor
+    python_version: "3.11"
+    python_requirements: requirements.txt
+    fast: true
+predict: "predict.py:Predictor"
 `, string(content))
 	err = file.Close()
 	require.NoError(t, err)
@@ -155,11 +155,78 @@ class Predictor(BasePredictor):
 	content, err := io.ReadAll(file)
 	require.NoError(t, err)
 	require.Equal(t, `build:
-  gpu: true
+    gpu: true
+    python_version: "3.11"
+    python_requirements: requirements.txt
+    fast: true
+predict: "predict.py:Predictor"
+`, string(content))
+	err = file.Close()
+	require.NoError(t, err)
+}
+
+func TestMigrateYAMLComments(t *testing.T) {
+	// Set our new working directory to a temp directory
+	originalDir, err := os.Getwd()
+	require.NoError(t, err)
+	defer func() {
+		err := os.Chdir(originalDir)
+		require.NoError(t, err)
+	}()
+	dir := t.TempDir()
+	err = os.Chdir(dir)
+	require.NoError(t, err)
+
+	// Write our test configs/code
+	configFilepath := filepath.Join(dir, "cog.yaml")
+	file, err := os.Create(configFilepath)
+	require.NoError(t, err)
+	_, err = file.WriteString(`# Here we have a YAML comment
+build:
   python_version: "3.11"
-  python_requirements: requirements.txt
   fast: true
-predict: predict.py:Predictor
+  python_packages:
+    - "pillow"
+  run:
+    - command: curl -o /usr/local/bin/pget -L \"https://github.com/replicate/pget/releases/latest/download/pget_$(uname -s)_$(uname -m)\" && chmod +x /usr/local/bin/pget
+  gpu: true
+predict: "predict.py:Predictor"
+`)
+	require.NoError(t, err)
+	err = file.Close()
+	require.NoError(t, err)
+	pythonFilepath := filepath.Join(dir, "predict.py")
+	file, err = os.Create(pythonFilepath)
+	require.NoError(t, err)
+	_, err = file.WriteString(`from cog import BasePredictor, Input
+
+
+class Predictor(BasePredictor):
+    def predict(self, s: str = Input(description="My Input Description", default=None)) -> str:
+        return "hello " + s
+`)
+	require.NoError(t, err)
+	err = file.Close()
+	require.NoError(t, err)
+
+	// Perform the migration
+	logCtx := coglog.NewMigrateLogContext(true)
+	migrator := NewMigratorV1ToV1Fast(false, logCtx)
+	err = migrator.Migrate(t.Context(), "cog.yaml")
+	require.NoError(t, err)
+
+	// Check config output
+	file, err = os.Open(configFilepath)
+	require.NoError(t, err)
+	content, err := io.ReadAll(file)
+	require.NoError(t, err)
+	require.Equal(t, `# Here we have a YAML comment
+build:
+    gpu: true
+    python_version: "3.11"
+    python_requirements: requirements.txt
+    fast: true
+predict: "predict.py:Predictor"
 `, string(content))
 	err = file.Close()
 	require.NoError(t, err)
diff --git a/pkg/migrate/overwrite_yaml.go b/pkg/migrate/overwrite_yaml.go
new file mode 100644
index 0000000000..7e7a1a89b2
--- /dev/null
+++ b/pkg/migrate/overwrite_yaml.go
@@ -0,0 +1,172 @@
+package migrate
+
+import (
+	"fmt"
+	"sort"
+
+	"gopkg.in/yaml.v3"
+)
+
+// OverwrightYAML applies the values of sourceYaml onto destinationYaml and
+// returns the re-serialised destination document.
+//
+// Both inputs are parsed into yaml.Node trees. Where the two documents have the
+// same shape the destination nodes are kept and only their values are updated,
+// so comments and scalar styles of the destination survive; where they diverge
+// the source subtree is used instead.
+//
+// If either document is empty there is nothing to merge and sourceYaml is
+// returned unchanged. An error is returned when an input is not valid YAML or
+// when the documents hold different kinds of node at the same path.
+func OverwrightYAML(sourceYaml []byte, destinationYaml []byte) ([]byte, error) {
+	var sourceNode yaml.Node
+	if err := yaml.Unmarshal(sourceYaml, &sourceNode); err != nil {
+		return nil, fmt.Errorf("parsing source YAML: %w", err)
+	}
+
+	var destinationNode yaml.Node
+	if err := yaml.Unmarshal(destinationYaml, &destinationNode); err != nil {
+		return nil, fmt.Errorf("parsing destination YAML: %w", err)
+	}
+
+	// An empty input leaves its node without children, so Content[0] below
+	// would panic.
+	if len(sourceNode.Content) == 0 || len(destinationNode.Content) == 0 {
+		return sourceYaml, nil
+	}
+
+	err := traverseAndCompare(sourceNode.Content[0], destinationNode.Content[0], "")
+	if err != nil {
+		return nil, err
+	}
+
+	return yaml.Marshal(&destinationNode)
+}
+
+// traverseAndCompare walks sourceNode and destinationNode in lockstep and
+// mutates destinationNode in place so that it carries the data of sourceNode.
+// path is the dotted location of the pair inside the document and is only used
+// in error messages.
+//
+// It returns an error when the two nodes are of different kinds.
+func traverseAndCompare(sourceNode, destinationNode *yaml.Node, path string) error {
+	if sourceNode.Kind != destinationNode.Kind {
+		return fmt.Errorf("type mismatch at %s: %s vs %s", path, nodeKindToString(sourceNode.Kind), nodeKindToString(destinationNode.Kind))
+	}
+
+	switch sourceNode.Kind {
+	case yaml.ScalarNode:
+		// A changed tag (e.g. 3.11 becoming "3.11") means the type changed, so
+		// tag and quoting must follow the source as well; copying only the
+		// value would emit the new value under the old type. With equal tags
+		// the destination keeps its own style.
+		if sourceNode.Tag != destinationNode.Tag {
+			destinationNode.Tag = sourceNode.Tag
+			destinationNode.Style = sourceNode.Style
+		}
+		destinationNode.Value = sourceNode.Value
+
+	case yaml.MappingNode:
+		sourceChildren := mapNodeToMap(sourceNode)
+		destinationChildren := mapNodeToMap(destinationNode)
+		allKeys := getAllKeys(sourceChildren, destinationChildren)
+
+		// When the key sets differ the whole mapping is taken from the source.
+		// Stop there: the previous destination children are no longer part of
+		// the document, so comparing them would be wasted work and could
+		// report mismatches for nodes that were just discarded.
+		for _, key := range allKeys {
+			_, inSource := sourceChildren[key]
+			_, inDestination := destinationChildren[key]
+			if !inSource || !inDestination {
+				destinationNode.Content = sourceNode.Content
+				return nil
+			}
+		}
+
+		for _, key := range allKeys {
+			childPath := key
+			if path != "" {
+				childPath = path + "." + key
+			}
+
+			err := traverseAndCompare(sourceChildren[key], destinationChildren[key], childPath)
+			if err != nil {
+				return err
+			}
+		}
+
+	case yaml.SequenceNode:
+		sourceLen := len(sourceNode.Content)
+
+		// Items that no longer exist in the source are dropped.
+		if len(destinationNode.Content) > sourceLen {
+			destinationNode.Content = destinationNode.Content[:sourceLen]
+		}
+		sharedLen := len(destinationNode.Content)
+
+		for i := 0; i < sharedLen; i++ {
+			childPath := fmt.Sprintf("%s[%d]", path, i)
+
+			err := traverseAndCompare(sourceNode.Content[i], destinationNode.Content[i], childPath)
+			if err != nil {
+				return err
+			}
+		}
+
+		// Items that only exist in the source are appended unchanged.
+		destinationNode.Content = append(destinationNode.Content, sourceNode.Content[sharedLen:]...)
+	}
+	return nil
+}
+
+// mapNodeToMap indexes the value nodes of a mapping node by the text of their
+// keys. The Content of a mapping alternates key and value nodes; a trailing
+// key without a value is skipped.
+func mapNodeToMap(node *yaml.Node) map[string]*yaml.Node {
+	result := make(map[string]*yaml.Node, len(node.Content)/2)
+	for i := 0; i+1 < len(node.Content); i += 2 {
+		keyNode := node.Content[i]
+		valueNode := node.Content[i+1]
+		result[keyNode.Value] = valueNode
+	}
+	return result
+}
+
+// getAllKeys returns the union of the keys of map1 and map2 in sorted order,
+// so that callers visit them deterministically.
+func getAllKeys(map1, map2 map[string]*yaml.Node) []string {
+	keys := make(map[string]struct{}, len(map1)+len(map2))
+	for key := range map1 {
+		keys[key] = struct{}{}
+	}
+	for key := range map2 {
+		keys[key] = struct{}{}
+	}
+
+	keyList := make([]string, 0, len(keys))
+	for key := range keys {
+		keyList = append(keyList, key)
+	}
+	sort.Strings(keyList)
+	return keyList
+}
+
+// nodeKindToString returns a human-readable name for kind, used in error
+// messages.
+func nodeKindToString(kind yaml.Kind) string {
+	switch kind {
+	case yaml.DocumentNode:
+		return "Document"
+	case yaml.ScalarNode:
+		return "Scalar"
+	case yaml.MappingNode:
+		return "Mapping"
+	case yaml.SequenceNode:
+		return "Sequence"
+	case yaml.AliasNode:
+		return "Alias"
+	default:
+		return "Unknown"
+	}
+}
diff --git a/pkg/migrate/overwrite_yaml_test.go b/pkg/migrate/overwrite_yaml_test.go
new file mode 100644
index 0000000000..280781c72b
--- /dev/null
+++ b/pkg/migrate/overwrite_yaml_test.go
@@ -0,0 +1,78 @@
+package migrate
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestOverwrightYAML(t *testing.T) {
+	var yamlData1 = `build:
+    command: "build.sh"
+image: "my-image"
+predict: "predict.py"
+train: "train.py"
+concurrency:
+    max: 5
+environment:
+    - "VAR1=value1"
+    - "VAR2=value2"
+`
+
+	var yamlData2 = `build:
+    command: "build_new.sh"
+image: "new-image"
+predict: "new_predict.py"
+concurrency:
+    max: 10
+environment:
+    - "VAR1=new_value1"
+    - "VAR3=value3"
+`
+	content, err := OverwrightYAML([]byte(yamlData1), []byte(yamlData2))
+	require.NoError(t, err)
+	require.Equal(t, yamlData1, string(content))
+}
+
+func TestOverwrightYAMLWithComments(t *testing.T) {
+	var yamlData1 = `# This here is a YAML Comment
+build:
+    command: "build.sh"
+image: "my-image"
+predict: "predict.py"
+train: "train.py"
+concurrency:
+    max: 5
+environment:
+    - "VAR1=value1"
+    - "VAR2=value2"
+`
+
+	var yamlData2 = `build:
+    command: "build_new.sh"
+image: "new-image"
+predict: "new_predict.py"
+concurrency:
+    max: 10
+environment:
+    - "VAR1=new_value1"
+    - "VAR3=value3"
+`
+	content, err := OverwrightYAML([]byte(yamlData1), []byte(yamlData2))
+	require.NoError(t, err)
+	require.Equal(t, yamlData1, string(content))
+}
+
+func TestOverwrightYAMLDropsRemovedSequenceItems(t *testing.T) {
+	var yamlData1 = `environment:
+    - "VAR1=value1"
+`
+
+	var yamlData2 = `environment:
+    - "VAR1=old_value1"
+    - "VAR2=value2"
+`
+	content, err := OverwrightYAML([]byte(yamlData1), []byte(yamlData2))
+	require.NoError(t, err)
+	require.Equal(t, yamlData1, string(content))
+}