diff --git a/java/lance-jni/src/blocking_dataset.rs b/java/lance-jni/src/blocking_dataset.rs index 55b108dd9cc..b7165088328 100644 --- a/java/lance-jni/src/blocking_dataset.rs +++ b/java/lance-jni/src/blocking_dataset.rs @@ -224,20 +224,6 @@ impl BlockingDataset { Ok(branches) } - pub fn create_branch( - &mut self, - branch: &str, - version: u64, - source_branch: Option<&str>, - ) -> Result { - let reference = match source_branch { - Some(b) => Ref::from((b, version)), - None => Ref::from(version), - }; - let inner = RT.block_on(self.inner.create_branch(branch, reference, None))?; - Ok(Self { inner }) - } - pub fn delete_branch(&mut self, branch: &str) -> Result<()> { RT.block_on(self.inner.delete_branch(branch))?; Ok(()) @@ -258,17 +244,8 @@ impl BlockingDataset { Ok(Self { inner }) } - pub fn create_tag( - &mut self, - tag: &str, - version_number: u64, - branch: Option<&str>, - ) -> Result<()> { - RT.block_on( - self.inner - .tags() - .create_on_branch(tag, version_number, branch), - )?; + pub fn create_tag(&mut self, tag: &str, reference: Ref) -> Result<()> { + RT.block_on(self.inner.tags().create(tag, reference))?; Ok(()) } @@ -277,8 +254,8 @@ impl BlockingDataset { Ok(()) } - pub fn update_tag(&mut self, tag: &str, version: u64, branch: Option<&str>) -> Result<()> { - RT.block_on(self.inner.tags().update_on_branch(tag, version, branch))?; + pub fn update_tag(&mut self, tag: &str, reference: Ref) -> Result<()> { + RT.block_on(self.inner.tags().update(tag, reference))?; Ok(()) } @@ -1357,50 +1334,20 @@ fn inner_shallow_clone<'local>( env: &mut JNIEnv<'local>, java_dataset: JObject, target_path: JString, - reference: JObject, + jref: JObject, storage_options: JObject, ) -> Result> { let target_path_str = target_path.extract(env)?; - let storage_options = env.get_optional(&storage_options, |env, map_obj| { - let jmap = JMap::from_env(env, &map_obj)?; - to_rust_map(env, &jmap) - })?; - - let reference = { - let version_number = env.get_optional_u64_from_method(&reference, "getVersionNumber")?; - let tag_name = env.get_optional_string_from_method(&reference, "getTagName")?; - let branch_name = env.get_optional_string_from_method(&reference, "getBranchName")?; - match (version_number, branch_name, tag_name) { - (Some(version_number), branch_name, None) => { - Ref::Version(branch_name, Some(version_number)) - } - (None, None, Some(tag_name)) => Ref::Tag(tag_name), - _ => { - return Err(Error::input_error( - "One of (optional branch, version_number) and tag must be specified" - .to_string(), - )) - } - } - }; - + let reference = transform_jref_to_ref(jref, env)?; + let storage_opts = transform_jstorage_options(storage_options, env)?; let new_ds = { let mut dataset_guard = unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?; - RT.block_on( - dataset_guard.inner.shallow_clone( - &target_path_str, - reference, - storage_options - .map(|options| { - Some(ObjectStoreParams { - storage_options: Some(options), - ..Default::default() - }) - }) - .unwrap_or(None), - ), - )? + RT.block_on(dataset_guard.inner.shallow_clone( + target_path_str.as_str(), + reference, + storage_opts, + ))? }; BlockingDataset { inner: new_ds }.into_java(env) @@ -1904,11 +1851,17 @@ fn inner_list_tags<'local>( let array_list = env.new_object("java/util/ArrayList", "()V", &[])?; for (tag_name, tag_contents) in tag_map { + let branch_name: JObject = if let Some(branch_name) = tag_contents.branch.as_ref() { + env.new_string(branch_name)?.into() + } else { + JObject::null() + }; let java_tag = env.new_object( "org/lance/Tag", - "(Ljava/lang/String;JI)V", + "(Ljava/lang/String;Ljava/lang/String;JI)V", &[ JValue::Object(&env.new_string(tag_name)?.into()), + JValue::Object(&branch_name), JValue::Long(tag_contents.version as i64), JValue::Int(tag_contents.manifest_size as i32), ], @@ -1928,25 +1881,11 @@ pub extern "system" fn Java_org_lance_Dataset_nativeCreateTag( mut env: JNIEnv, java_dataset: JObject, jtag_name: JString, - jtag_version: jlong, -) { - ok_or_throw_without_return!( - env, - inner_create_tag(&mut env, java_dataset, jtag_name, jtag_version) - ) -} - -#[no_mangle] -pub extern "system" fn Java_org_lance_Dataset_nativeCreateTagOnBranch( - mut env: JNIEnv, - java_dataset: JObject, - jtag_name: JString, - jtag_version: jlong, - jbranch: JString, + jref: JObject, ) { ok_or_throw_without_return!( env, - inner_create_tag_on_branch(&mut env, java_dataset, jtag_name, jtag_version, jbranch) + inner_create_tag(&mut env, java_dataset, jtag_name, jref) ) } @@ -1954,27 +1893,13 @@ fn inner_create_tag( env: &mut JNIEnv, java_dataset: JObject, jtag_name: JString, - jtag_version: jlong, + jref: JObject, ) -> Result<()> { let tag = jtag_name.extract(env)?; + let reference = transform_jref_to_ref(jref, env)?; let mut dataset_guard = { unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }? }; - dataset_guard.create_tag(tag.as_str(), jtag_version as u64, None)?; - Ok(()) -} - -fn inner_create_tag_on_branch( - env: &mut JNIEnv, - java_dataset: JObject, - jtag_name: JString, - jtag_version: jlong, - jbranch: JString, -) -> Result<()> { - let tag = jtag_name.extract(env)?; - let branch = jbranch.extract(env)?; - let mut dataset_guard = - { unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }? }; - dataset_guard.create_tag(tag.as_str(), jtag_version as u64, Some(branch.as_str()))?; + dataset_guard.create_tag(tag.as_str(), reference)?; Ok(()) } @@ -1999,54 +1924,25 @@ pub extern "system" fn Java_org_lance_Dataset_nativeUpdateTag( mut env: JNIEnv, java_dataset: JObject, jtag_name: JString, - jtag_version: jlong, + jref: JObject, ) { ok_or_throw_without_return!( env, - inner_update_tag(&mut env, java_dataset, jtag_name, jtag_version) + inner_update_tag(&mut env, java_dataset, jtag_name, jref) ) } -#[no_mangle] -pub extern "system" fn Java_org_lance_Dataset_nativeUpdateTagOnBranch( - mut env: JNIEnv, - java_dataset: JObject, - jtag_name: JString, - jtag_version: jlong, - jbranch: JString, -) { - ok_or_throw_without_return!( - env, - inner_update_tag_on_branch(&mut env, java_dataset, jtag_name, jtag_version, jbranch) - ) -} - -fn inner_update_tag_on_branch( - env: &mut JNIEnv, - java_dataset: JObject, - jtag_name: JString, - jtag_version: jlong, - jbranch: JString, -) -> Result<()> { - let tag = jtag_name.extract(env)?; - let branch = jbranch.extract(env)?; - let mut dataset_guard = - { unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }? }; - dataset_guard.update_tag(tag.as_str(), jtag_version as u64, Some(branch.as_str()))?; - Ok(()) -} - fn inner_update_tag( env: &mut JNIEnv, java_dataset: JObject, jtag_name: JString, - jtag_version: jlong, + jref: JObject, ) -> Result<()> { let tag = jtag_name.extract(env)?; + let reference = transform_jref_to_ref(jref, env)?; let mut dataset_guard = { unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }? }; - dataset_guard.update_tag(tag.as_str(), jtag_version as u64, None)?; - Ok(()) + dataset_guard.update_tag(tag.as_str(), reference) } #[no_mangle] @@ -2128,12 +2024,12 @@ pub extern "system" fn Java_org_lance_Dataset_nativeCreateBranch<'local>( mut env: JNIEnv<'local>, java_dataset: JObject, jbranch: JString, - jversion: jlong, - source_branch_obj: JObject, // Optional + jref: JObject, + jstorage_options: JObject, // Optional ) -> JObject<'local> { ok_or_throw!( env, - inner_create_branch(&mut env, java_dataset, jbranch, jversion, source_branch_obj) + inner_create_branch(&mut env, java_dataset, jbranch, jref, jstorage_options) ) } @@ -2141,42 +2037,12 @@ fn inner_create_branch<'local>( env: &mut JNIEnv<'local>, java_dataset: JObject, jbranch: JString, - jversion: jlong, - source_branch_obj: JObject, // Optional + jref: JObject, + jstorage_options: JObject, // Optional ) -> Result> { let branch_name: String = jbranch.extract(env)?; - let version = jversion as u64; - let source_branch = env.get_string_opt(&source_branch_obj)?; - let new_dataset = { - let mut dataset_guard = - unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?; - dataset_guard.create_branch(&branch_name, version, source_branch.as_deref())? - }; - new_dataset.into_java(env) -} - -#[no_mangle] -pub extern "system" fn Java_org_lance_Dataset_nativeCreateBranchOnTag<'local>( - mut env: JNIEnv<'local>, - java_dataset: JObject, - jbranch: JString, - jtag_name: JString, -) -> JObject<'local> { - ok_or_throw!( - env, - inner_create_branch_on_tag(&mut env, java_dataset, jbranch, jtag_name) - ) -} - -fn inner_create_branch_on_tag<'local>( - env: &mut JNIEnv<'local>, - java_dataset: JObject, - jbranch: JString, - jtag_name: JString, -) -> Result> { - let branch_name: String = jbranch.extract(env)?; - let tag_name: String = jtag_name.extract(env)?; - let reference = Ref::from(tag_name.as_str()); + let reference = transform_jref_to_ref(jref, env)?; + let storage_opts = transform_jstorage_options(jstorage_options, env)?; let new_blocking_dataset = { let mut dataset_guard = @@ -2184,13 +2050,42 @@ fn inner_create_branch_on_tag<'local>( let inner = RT.block_on(dataset_guard.inner.create_branch( branch_name.as_str(), reference, - None, + storage_opts, ))?; BlockingDataset { inner } }; new_blocking_dataset.into_java(env) } +fn transform_jref_to_ref(jref: JObject, env: &mut JNIEnv) -> Result { + let source_tag_name = env.get_optional_string_from_method(&jref, "getTagName")?; + let source_version_number = env.get_optional_u64_from_method(&jref, "getVersionNumber")?; + let source_branch = env.get_optional_string_from_method(&jref, "getBranchName")?; + if let Some(tag_name) = source_tag_name { + Ok(Ref::Tag(tag_name)) + } else { + Ok(Ref::Version(source_branch, source_version_number)) + } +} + +fn transform_jstorage_options( + jstorage_options: JObject, + env: &mut JNIEnv, +) -> Result> { + let storage_options = env.get_optional(&jstorage_options, |env, map_obj| { + let jmap = JMap::from_env(env, &map_obj)?; + to_rust_map(env, &jmap) + })?; + Ok(storage_options + .map(|options| { + Some(ObjectStoreParams { + storage_options: Some(options), + ..Default::default() + }) + }) + .unwrap_or(None)) +} + #[no_mangle] pub extern "system" fn Java_org_lance_Dataset_nativeDeleteBranch( mut env: JNIEnv, diff --git a/java/src/main/java/org/lance/Dataset.java b/java/src/main/java/org/lance/Dataset.java index 76e854afe64..7d0f2503d8c 100644 --- a/java/src/main/java/org/lance/Dataset.java +++ b/java/src/main/java/org/lance/Dataset.java @@ -1171,6 +1171,44 @@ public Branches branches() { return new Branches(); } + /** + * Create a branch at a specified version. The returned Dataset points to the created branch's + * initial version. + * + * @param branch the branch name to create + * @param ref the reference to create branch from + * @return a new Dataset of the branch + */ + public Dataset createBranch(String branch, Ref ref) { + Preconditions.checkArgument(branch != null && ref != null, "branch and ref cannot be null"); + return innerCreateBranch(branch, ref, Optional.empty()); + } + + /** + * Create a branch at a specified version. The returned Dataset points to the created branch's + * initial version. + * + * @param branch the branch name to create + * @param ref the reference to create branch from + * @param storageOptions the storage options to create branch with + * @return a new Dataset of the branch + */ + public Dataset createBranch(String branch, Ref ref, Map storageOptions) { + Preconditions.checkArgument(branch != null && ref != null, "branch and ref cannot be null"); + Preconditions.checkArgument( + storageOptions != null && !storageOptions.isEmpty(), "storageOptions cannot be null"); + return innerCreateBranch(branch, ref, Optional.of(storageOptions)); + } + + private Dataset innerCreateBranch( + String branch, Ref ref, Optional> storageOptions) { + Preconditions.checkArgument(branch != null, "Branch cannot be null"); + try (LockManager.WriteLock writeLock = lockManager.acquireWriteLock()) { + Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); + return nativeCreateBranch(branch, ref, storageOptions); + } + } + /** * Checkout using a unified {@link Ref} which can be a tag, the latest version on main/branch or a * specified (branch_name, version_number). @@ -1204,32 +1242,44 @@ public Map getTableMetadata() { public class Tags { /** - * Create a new tag on main branch. + * Create a new tag on main branch. This is left for compatibility. We should use {@link + * #create(String, Ref)} instead. * * @param tag the tag name * @param versionNumber the version number to tag */ public void create(String tag, long versionNumber) { - try (LockManager.WriteLock writeLock = lockManager.acquireWriteLock()) { - Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); - nativeCreateTag(tag, versionNumber); - } + Preconditions.checkArgument(versionNumber > 0, "versionNumber must be greater than 0"); + create(tag, Ref.ofMain(versionNumber)); } /** * Create a new tag on a specified branch. * * @param tag the tag name - * @param versionNumber the version number to tag + * @param ref the referenced version to tag */ - public void create(String tag, long versionNumber, String targetBranch) { - Preconditions.checkArgument(targetBranch != null, "Branch cannot be null"); - try (LockManager.WriteLock writeLock = lockManager.acquireWriteLock()) { + public void create(String tag, Ref ref) { + Preconditions.checkArgument(tag != null, "Tag name cannot be null"); + Preconditions.checkArgument(ref != null, "ref cannot be null"); + try (LockManager.WriteLock readLock = lockManager.acquireWriteLock()) { Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); - nativeCreateTagOnBranch(tag, versionNumber, targetBranch); + nativeCreateTag(tag, ref); } } + /** + * Creates a new tag on the specified branch. This method will be removed in version 2.0.0. Use + * {@link #create(String, Ref)} instead. + * + * @param tag the name of the tag to create + * @param versionNumber the version number (or commit reference) to associate with the tag + */ + @Deprecated + public void create(String tag, long versionNumber, String targetBranch) { + create(tag, Ref.ofBranch(targetBranch, versionNumber)); + } + /** * Delete a tag from this dataset. * @@ -1243,29 +1293,29 @@ public void delete(String tag) { } /** - * Update a tag to a new version on main branch. + * Update a tag to a new version_number on main. This is left for compatibility. We should use + * {@link #update(String, Ref)} instead. * * @param tag the tag name - * @param versionNumber the version number to tag + * @param versionNumber the versionNumber on main. */ public void update(String tag, long versionNumber) { - try (LockManager.WriteLock writeLock = lockManager.acquireWriteLock()) { - Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); - nativeUpdateTag(tag, versionNumber); - } + Preconditions.checkArgument(versionNumber > 0, "version_number must be greater than 0"); + nativeUpdateTag(tag, Ref.ofMain(versionNumber)); } /** - * Update a tag to a new version on a specified branch. + * Update a tag to a new reference. * * @param tag the tag name - * @param version the version to tag + * @param ref the referenced version to tag */ - public void update(String tag, long version, String targetBranch) { - Preconditions.checkArgument(targetBranch != null, "Branch cannot be null"); + public void update(String tag, Ref ref) { + Preconditions.checkArgument(tag != null, "tag cannot be null"); + Preconditions.checkArgument(ref != null, "ref cannot be null"); try (LockManager.WriteLock writeLock = lockManager.acquireWriteLock()) { Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); - nativeUpdateTagOnBranch(tag, version, targetBranch); + nativeUpdateTag(tag, ref); } } @@ -1297,51 +1347,6 @@ public long getVersion(String tag) { /** Branch operations of the dataset. */ public class Branches { - /** - * Create a branch at a specified version. The returned Dataset points to the created branch's - * initial version. - * - * @param branch the branch name to create - * @param versionNumber the version number to create branch from - * @return a new Dataset of the branch - */ - public Dataset create(String branch, long versionNumber) { - try (LockManager.WriteLock writeLock = lockManager.acquireWriteLock()) { - Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); - return nativeCreateBranch(branch, versionNumber, Optional.empty()); - } - } - - /** - * Create a branch from a specific source branch and version. - * - * @param branchName the branch name to create - * @param versionNumber the version number to create branch from - * @param sourceBranch the source branch name - * @return a new Dataset of the created branch - */ - public Dataset create(String branchName, long versionNumber, String sourceBranch) { - try (LockManager.WriteLock writeLock = lockManager.acquireWriteLock()) { - Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); - Preconditions.checkNotNull(sourceBranch); - return nativeCreateBranch(branchName, versionNumber, Optional.of(sourceBranch)); - } - } - - /** - * Create a branch from a tag reference. - * - * @param branchName the branch name to create - * @param sourceTag the tag name to create branch from - * @return a new Dataset of the created branch - */ - public Dataset create(String branchName, String sourceTag) { - try (LockManager.WriteLock writeLock = lockManager.acquireWriteLock()) { - Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); - Preconditions.checkNotNull(sourceTag); - return nativeCreateBranchOnTag(branchName, sourceTag); - } - } /** * Delete a branch and its metadata. @@ -1395,7 +1400,7 @@ public SqlQuery sql(String sql) { * @return MergeInsertResult containing the new merged Dataset. */ public MergeInsertResult mergeInsert(MergeInsertParams mergeInsert, ArrowArrayStream source) { - try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { + try (LockManager.WriteLock writeLock = lockManager.acquireWriteLock()) { MergeInsertResult result = nativeMergeInsert(mergeInsert, source.memoryAddress()); Dataset newDataset = result.dataset(); @@ -1412,15 +1417,11 @@ public MergeInsertResult mergeInsert(MergeInsertParams mergeInsert, ArrowArraySt private native MergeInsertResult nativeMergeInsert( MergeInsertParams mergeInsert, long arrowStreamMemoryAddress); - private native void nativeCreateTag(String tag, long versionNumber); - - private native void nativeCreateTagOnBranch(String tag, long versionNumber, String branch); + private native void nativeCreateTag(String tag, Ref ref); private native void nativeDeleteTag(String tag); - private native void nativeUpdateTag(String tag, long versionNumber); - - private native void nativeUpdateTagOnBranch(String tag, long versionNumber, String branch); + private native void nativeUpdateTag(String tag, Ref ref); private native List nativeListTags(); @@ -1430,9 +1431,7 @@ private native MergeInsertResult nativeMergeInsert( private native Dataset nativeCheckout(Ref ref); private native Dataset nativeCreateBranch( - String branch, long versionNumber, Optional sourceBranch); - - private native Dataset nativeCreateBranchOnTag(String branch, String tagName); + String branch, Ref ref, Optional> storageOptions); private native void nativeDeleteBranch(String branch); diff --git a/java/src/main/java/org/lance/Ref.java b/java/src/main/java/org/lance/Ref.java index 61e282b22b5..111a1edd6d3 100644 --- a/java/src/main/java/org/lance/Ref.java +++ b/java/src/main/java/org/lance/Ref.java @@ -14,11 +14,11 @@ package org.lance; import com.google.common.base.MoreObjects; +import com.google.common.base.Preconditions; import java.util.Optional; public class Ref { - private final Optional versionNumber; private final Optional branchName; private final Optional tagName; @@ -42,6 +42,7 @@ public Optional getTagName() { } public static Ref ofMain(long versionNumber) { + Preconditions.checkArgument(versionNumber > 0, "versionNumber must be greater than 0"); return new Ref(Optional.of(versionNumber), Optional.empty(), Optional.empty()); } @@ -50,14 +51,20 @@ public static Ref ofMain() { } public static Ref ofBranch(String branchName) { + Preconditions.checkArgument( + branchName != null && !branchName.isEmpty(), "branchName must not be empty"); return new Ref(Optional.empty(), Optional.of(branchName), Optional.empty()); } public static Ref ofBranch(String branchName, long versionNumber) { + Preconditions.checkArgument( + branchName != null && !branchName.isEmpty(), "branchName must not be empty"); + Preconditions.checkArgument(versionNumber > 0, "versionNumber must be greater than 0"); return new Ref(Optional.of(versionNumber), Optional.of(branchName), Optional.empty()); } public static Ref ofTag(String tagName) { + Preconditions.checkArgument(tagName != null && !tagName.isEmpty(), "tagName must not be empty"); return new Ref(Optional.empty(), Optional.empty(), Optional.of(tagName)); } diff --git a/java/src/main/java/org/lance/Tag.java b/java/src/main/java/org/lance/Tag.java index a9c328bedbd..f7ce7be83cc 100644 --- a/java/src/main/java/org/lance/Tag.java +++ b/java/src/main/java/org/lance/Tag.java @@ -16,14 +16,17 @@ import com.google.common.base.MoreObjects; import java.util.Objects; +import java.util.Optional; public class Tag { private final String name; + private final Optional branch; private final long version; private final int manifestSize; - public Tag(String name, long version, int manifestSize) { + public Tag(String name, String branch, long version, int manifestSize) { this.name = name; + this.branch = Optional.ofNullable(branch); this.version = version; this.manifestSize = manifestSize; } @@ -32,6 +35,10 @@ public String getName() { return name; } + public Optional getBranch() { + return branch; + } + public long getVersion() { return version; } @@ -44,6 +51,7 @@ public int getManifestSize() { public String toString() { return MoreObjects.toStringHelper(this) .add("name", name) + .add("branch", branch) .add("version", version) .add("manifestSize", manifestSize) .toString(); @@ -59,12 +67,13 @@ public boolean equals(Object o) { } Tag tag = (Tag) o; return version == tag.version + && Objects.equals(branch, tag.branch) && manifestSize == tag.manifestSize && Objects.equals(name, tag.name); } @Override public int hashCode() { - return Objects.hash(name, version, manifestSize); + return Objects.hash(name, branch, version, manifestSize); } } diff --git a/java/src/test/java/org/lance/DatasetTest.java b/java/src/test/java/org/lance/DatasetTest.java index 4b5db975827..b29b3dcf39f 100644 --- a/java/src/test/java/org/lance/DatasetTest.java +++ b/java/src/test/java/org/lance/DatasetTest.java @@ -256,7 +256,7 @@ void testDatasetCheckoutVersion(@TempDir Path tempDir) { } @Test - void testDatasetTags(@TempDir Path tempDir) { + void testTags(@TempDir Path tempDir) { String datasetPath = tempDir.resolve("dataset_tags").toString(); try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { TestUtils.SimpleTestDataset testDataset = @@ -265,7 +265,7 @@ void testDatasetTags(@TempDir Path tempDir) { // version 1, empty dataset try (Dataset dataset = testDataset.createEmptyDataset()) { assertEquals(1, dataset.version()); - dataset.tags().create("tag1", 1); + dataset.tags().create("tag1", Ref.ofMain()); assertEquals(1, dataset.tags().list().size()); assertEquals(1, dataset.tags().list().get(0).getVersion()); assertEquals(1, dataset.tags().getVersion("tag1")); @@ -277,11 +277,11 @@ void testDatasetTags(@TempDir Path tempDir) { assertEquals(1, dataset2.tags().list().size()); assertEquals(1, dataset2.tags().list().get(0).getVersion()); assertEquals(1, dataset2.tags().getVersion("tag1")); - dataset2.tags().create("tag2", 2); + dataset2.tags().create("tag2", Ref.ofMain(2)); assertEquals(2, dataset2.tags().list().size()); assertEquals(1, dataset2.tags().getVersion("tag1")); assertEquals(2, dataset2.tags().getVersion("tag2")); - dataset2.tags().update("tag2", 1); + dataset2.tags().update("tag2", Ref.ofMain(1)); assertEquals(2, dataset2.tags().list().size()); assertEquals(1, dataset2.tags().list().get(0).getVersion()); assertEquals(1, dataset2.tags().list().get(1).getVersion()); @@ -302,6 +302,35 @@ void testDatasetTags(@TempDir Path tempDir) { assertEquals(1, checkoutV1.tags().list().get(0).getVersion()); assertEquals(1, checkoutV1.tags().getVersion("tag1")); } + + try (Dataset branch = dataset2.createBranch("branch", Ref.ofMain(2))) { + branch.tags().create("tag_on_branch", Ref.ofBranch("branch")); + assertEquals(2, dataset2.tags().getVersion("tag_on_branch")); + List tags = dataset2.tags().list(); + Optional tagOptional = + dataset2.tags().list().stream() + .filter(t -> t.getName().equals("tag_on_branch")) + .findFirst(); + assertEquals(2, tags.size()); + assertTrue(tagOptional.isPresent()); + assertEquals(2, tagOptional.get().getVersion()); + assertEquals(Optional.of("branch"), tagOptional.get().getBranch()); + + dataset2.tags().update("tag1", Ref.ofBranch("branch")); + tags = dataset2.tags().list(); + tagOptional = + dataset2.tags().list().stream() + .filter(t -> t.getName().equals("tag_on_branch")) + .findFirst(); + assertEquals(2, tags.size()); + assertTrue(tagOptional.isPresent()); + assertEquals(2, tagOptional.get().getVersion()); + assertEquals(Optional.of("branch"), tagOptional.get().getBranch()); + } + + assertEquals(2, dataset2.tags().list().size()); + dataset2.tags().delete("tag_on_branch"); + assertEquals(1, dataset2.tags().list().size()); } } } @@ -1519,7 +1548,7 @@ void testBranches(@TempDir Path tempDir) { assertEquals(5, mainV2.countRows()); // Step2. create branch2 based on main:2 - try (Dataset branch1V2 = mainV2.branches().create("branch1", 2)) { + try (Dataset branch1V2 = mainV2.createBranch("branch1", Ref.ofMain(2))) { assertEquals(2, branch1V2.version()); // Write batch B on branch1: 3 rows -> global@3 @@ -1531,15 +1560,16 @@ void testBranches(@TempDir Path tempDir) { assertEquals(8, branch1V3.countRows()); // A(5) + B(3) // Step 3. Create branch2 based on branch1's latest version (simulate tag 't1') - mainV1.tags().create("tag", 3, "branch1"); + mainV1.tags().create("tag", Ref.ofBranch("branch1", 3)); - try (Dataset branch2V3 = branch1V2.branches().create("branch2", "tag")) { + try (Dataset branch2V3 = branch1V2.createBranch("branch2", Ref.ofTag("tag"))) { assertEquals(3, branch2V3.version()); assertEquals(8, branch2V3.countRows()); // A(5) + B(3) // Step 4. Write batch C on branch2: 2 rows -> branch2:4 FragmentMetadata fragC = suite.createNewFragment(2); - Append appendC = Append.builder().fragments(Arrays.asList(fragC)).build(); + Append appendC = + Append.builder().fragments(Collections.singletonList(fragC)).build(); try (Dataset branch2V4 = branch2V3.newTransactionBuilder().operation(appendC).build().commit()) { assertEquals(4, branch2V4.version()); diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 99f44f9c6ce..38e845e6102 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -588,7 +588,7 @@ def branches(self) -> "Branches": def create_branch( self, branch: str, - reference: Optional[int | str | Tuple[str, int]] = None, + reference: Optional[int | str | Tuple[Optional[str], Optional[int]]] = None, storage_options: Optional[Dict[str, str]] = None, ) -> "LanceDataset": """Create a new branch from a version or tag. @@ -597,10 +597,11 @@ def create_branch( ---------- branch: str Name of the branch to create. - reference: Optional[int | str | Tuple[str, int]] - The reference which could be a version_number, a tag name or a tuple of - (branch_name, version_number) to create the branch from. - If None, the latest version of the current branch is used. + reference: Optional[int | str | Tuple[Optional[str], Optional[int]] + An integer specifies a version number in the current branch; a string + specifies a tag name; a Tuple[Optional[str], Optional[int]] specifies + a version number in a specified branch. (None, None) means the latest + version_number on the main branch. storage_options: Optional[Dict[str, str]] Storage options for the underlying object store. If not provided, the storage options from the current dataset will be used. @@ -621,28 +622,6 @@ def create_branch( ds._read_params = self._read_params return ds - def checkout_branch(self, branch: str) -> "LanceDataset": - """Check out the latest version of a branch. - - Parameters - ---------- - branch: str - The branch name to checkout. - - Returns - ------- - LanceDataset - A dataset instance at the latest version of the branch. - """ - inner = self._ds.checkout_branch(branch) - ds = LanceDataset.__new__(LanceDataset) - ds._ds = inner - ds._uri = inner.uri - ds._storage_options = self._storage_options - ds._default_scan_options = self._default_scan_options - ds._read_params = self._read_params - return ds - def checkout_latest(self): """Check out the latest version of the current branch.""" self._ds.checkout_latest() @@ -2223,7 +2202,9 @@ def latest_version(self) -> int: """ return self._ds.latest_version() - def checkout_version(self, version: int | str | Tuple[str, int]) -> "LanceDataset": + def checkout_version( + self, version: int | str | Tuple[Optional[str], Optional[int]] + ) -> "LanceDataset": """ Load the given version of the dataset. @@ -2233,9 +2214,11 @@ def checkout_version(self, version: int | str | Tuple[str, int]) -> "LanceDatase Parameters ---------- - version: int | str | Tuple[str, int], - The version to check out. A version number on main (`int`), a tag - (`str`) or a tuple of ('branch_name', 'version_number') can be provided. + version: int | str | Tuple[Optional[str], Optional[int]], + An integer specifies a version number in the current branch; a string + specifies a tag name; a Tuple[Optional[str], Optional[int]] specifies + a version number in a specified branch. (None, None) means the latest + version_number on the main branch. Returns ------- @@ -3457,8 +3440,8 @@ def validate(self): def shallow_clone( self, - target_path: Union[str, Path], - version: Union[int, str, Tuple[int, str]], + target_path: str | Path, + reference: int | str | Tuple[Optional[str], Optional[int]], storage_options: Optional[Dict[str, str]] = None, **kwargs, ) -> "LanceDataset": @@ -3472,10 +3455,11 @@ def shallow_clone( ---------- target_path : str or Path The URI or filesystem path to clone the dataset into. - version : int, str or Tuple[int, str] - The source version to clone. An integer specifies a version number in main; - a string specifies a tag name; a Tuple[int, str] specifies a version number - in a specified branch. + reference : int, str or Tuple[Optional[str], Optional[int]] + An integer specifies a version number in the current branch; a string + specifies a tag name; a Tuple[Optional[str], Optional[int]] specifies + a version number in a specified branch. (None, None) means the latest + version_number on the main branch. storage_options : dict, optional Object store configuration for the new dataset (e.g., credentials, endpoints). If not specified, the storage options of the source dataset @@ -3493,7 +3477,7 @@ def shallow_clone( if storage_options is None: storage_options = self._storage_options - self._ds.shallow_clone(target_uri, version, storage_options) + self._ds.shallow_clone(target_uri, reference, storage_options) # Open and return a fresh dataset at the target URI to avoid manual overrides return LanceDataset(target_uri, storage_options=storage_options, **kwargs) @@ -3920,6 +3904,7 @@ class Transaction: class Tag(TypedDict): + branch: Optional[str] version: int manifest_size: int @@ -5210,7 +5195,11 @@ def list_ordered(self, order: Optional[str] = None) -> list[str, Tag]: """ return self._ds.tags_ordered(order) - def create(self, tag: str, version: int, branch: Optional[str] = None) -> None: + def create( + self, + tag: str, + reference: Optional[int | str | Tuple[Optional[str], Optional[int]]] = None, + ) -> None: """ Create a tag for a given dataset version. @@ -5219,12 +5208,13 @@ def create(self, tag: str, version: int, branch: Optional[str] = None) -> None: tag: str, The name of the tag to create. This name must be unique among all tag names for the dataset. - version: int, - The dataset version to tag. - branch: Optional[str], - The specified branch to create the tag, None if the specified branch is main + reference : int, str or Tuple[Optional[str], Optional[int]] + An integer specifies a version number in the current branch; a string + specifies a tag name; a Tuple[Optional[str], Optional[int]] specifies + a version number in a specified branch. (None, None) means the latest + version_number on the main branch. """ - self._ds.create_tag(tag, version, branch) + self._ds.create_tag(tag, reference) def delete(self, tag: str) -> None: """ @@ -5238,7 +5228,11 @@ def delete(self, tag: str) -> None: """ self._ds.delete_tag(tag) - def update(self, tag: str, version: int, branch: Optional[str] = None) -> None: + def update( + self, + tag: str, + reference: Optional[int | str | Tuple[Optional[str], Optional[int]]] = None, + ) -> None: """ Update tag to a new version. @@ -5246,12 +5240,13 @@ def update(self, tag: str, version: int, branch: Optional[str] = None) -> None: ---------- tag: str, The name of the tag to update. - version: int, - The new dataset version to tag. - branch: Optional[str], - The specified branch to create the tag, None if the specified branch is main + reference : int, str or Tuple[Optional[str], Optional[int]] + An integer specifies a version number in the current branch; a string + specifies a tag name; a Tuple[Optional[str], Optional[int]] specifies + a version number in a specified branch. (None, None) means the latest + version_number on the main branch. """ - self._ds.update_tag(tag, version, branch) + self._ds.update_tag(tag, reference) class Branches: diff --git a/python/python/lance/lance/__init__.pyi b/python/python/lance/lance/__init__.pyi index ff8d5e4e4d5..506e4e267d5 100644 --- a/python/python/lance/lance/__init__.pyi +++ b/python/python/lance/lance/__init__.pyi @@ -292,13 +292,14 @@ class _Dataset: def versions(self) -> List[Version]: ... def version(self) -> int: ... def latest_version(self) -> int: ... - def checkout_version(self, version: int | str | Tuple[str, int]) -> _Dataset: ... - def checkout_branch(self, branch: str) -> _Dataset: ... + def checkout_version( + self, version: int | str | Tuple[Optional[str], Optional[int]] + ) -> _Dataset: ... def checkout_latest(self) -> _Dataset: ... def shallow_clone( self, target_path: str, - reference: Optional[int | str | Tuple[str, int]] = None, + reference: Optional[int | str | Tuple[Optional[str], Optional[int]]] = None, storage_options: Optional[Dict[str, str]] = None, ) -> _Dataset: ... def restore(self): ... @@ -313,17 +314,23 @@ class _Dataset: def tags(self) -> Dict[str, Tag]: ... def tags_ordered(self, order: Optional[str]) -> List[Tuple[str, Tag]]: ... def create_tag( - self, tag: str, version: int, branch: Optional[str] = None + self, + tag: str, + reference: Optional[int | str | Tuple[Optional[str], Optional[int]]] = None, ) -> Tag: ... def delete_tag(self, tag: str): ... - def update_tag(self, tag: str, version: int, branch: Optional[str] = None): ... + def update_tag( + self, + tag: str, + reference: Optional[int | str | Tuple[Optional[str], Optional[int]]] = None, + ): ... # Branch operations def branches(self) -> Dict[str, Branch]: ... def branches_ordered(self, order: Optional[str]) -> List[Tuple[str, Branch]]: ... def create_branch( self, branch: str, - reference: Optional[int | str | Tuple[str, int]] = None, + reference: Optional[int | str | Tuple[Optional[str], Optional[int]]] = None, storage_options: Optional[Dict[str, str]] = None, **kwargs, ) -> _Dataset: ... diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index 9ddb6db6881..7522c0efa8b 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -449,7 +449,7 @@ def test_tag(tmp_path: Path): ds.tags.delete("tag1") ds.tags.create("tag1", 1) - ds.tags.create("tag2", 1, None) + ds.tags.create("tag2", 1) assert len(ds.tags.list()) == 2 @@ -466,16 +466,16 @@ def test_tag(tmp_path: Path): # test tag update with pytest.raises( - ValueError, match="Version not found error: version 3 does not exist" + ValueError, match="Version not found error: version main:3 does not exist" ): ds.tags.update("tag1", 3) with pytest.raises( ValueError, match="Ref not found error: tag tag3 does not exist" ): - ds.tags.update("tag3", 1, None) + ds.tags.update("tag3", 1) - ds.tags.update("tag1", 2, None) + ds.tags.update("tag1", 2) ds = lance.dataset(base_dir, "tag1") assert ds.version == 2 @@ -486,6 +486,33 @@ def test_tag(tmp_path: Path): version = ds.tags.get_version("tag1") assert version == 1 + ds.create_branch("branch", "tag1") + ds.tags.create("tag3", ("branch", None)) + target_tag = ds.tags.list().get("tag3") + assert ds.tags.get_version("tag3") == 1 + assert len(ds.tags.list()) == 3 + assert target_tag is not None + assert target_tag["version"] == 1 + assert target_tag["branch"] == "branch" + + ds.tags.update("tag3", (None, 2)) + target_tag = ds.tags.list()["tag3"] + assert ds.tags.get_version("tag3") == 2 + assert target_tag is not None + assert target_tag["version"] == 2 + assert target_tag["branch"] is None + + ds.create_branch("branch2", 2) + ds.tags.update("tag3", ("branch2", 2)) + target_tag = ds.tags.list()["tag3"] + assert ds.tags.get_version("tag3") == 2 + assert target_tag is not None + assert target_tag["version"] == 2 + assert target_tag["branch"] == "branch2" + + ds.tags.delete("tag3") + assert len(ds.tags.list()) == 2 + def test_tag_order(tmp_path: Path): table = pa.Table.from_pydict({"colA": [1, 2, 3], "colB": [4, 5, 6]}) @@ -1127,8 +1154,8 @@ def test_cleanup_error_when_tagged_old_versions(tmp_path): lance.write_dataset(table, base_dir, mode="overwrite") dataset = lance.dataset(base_dir) - dataset.tags.create("old-tag", 1, None) - dataset.tags.create("another-old-tag", 2, None) + dataset.tags.create("old-tag", 1) + dataset.tags.create("another-old-tag", 2) with pytest.raises(OSError): dataset.cleanup_old_versions(older_than=(datetime.now() - moment)) @@ -1156,9 +1183,9 @@ def test_cleanup_around_tagged_old_versions(tmp_path): lance.write_dataset(table, base_dir, mode="overwrite") dataset = lance.dataset(base_dir) - dataset.tags.create("old-tag", 1, None) - dataset.tags.create("another-old-tag", 2, None) - dataset.tags.create("tag-latest", 3, None) + dataset.tags.create("old-tag", 1) + dataset.tags.create("another-old-tag", 2) + dataset.tags.create("tag-latest", 3) stats = dataset.cleanup_old_versions( older_than=(datetime.now() - moment), error_if_tagged_old_versions=False @@ -4778,20 +4805,28 @@ def test_shallow_clone(tmp_path: Path): ds = lance.write_dataset(table_v2, src_dir, mode="overwrite") # Create a tag pointing to version 1 - ds.tags.create("v1", 1, None) + ds.tags.create("v1", 1) # Clone by numeric version (v2) and assert equality clone_v2_dir = tmp_path / "clone_v2" - ds_clone_v2 = ds.shallow_clone(clone_v2_dir, version=2) + ds_clone_v2 = ds.shallow_clone(clone_v2_dir, 2) assert ds_clone_v2.to_table() == table_v2 assert lance.dataset(clone_v2_dir).to_table() == table_v2 # Clone by tag (v1) and assert equality clone_v1_tag_dir = tmp_path / "clone_v1_tag" - ds_clone_v1_tag = ds.shallow_clone(clone_v1_tag_dir, version="v1") + ds_clone_v1_tag = ds.shallow_clone(clone_v1_tag_dir, "v1") assert ds_clone_v1_tag.to_table() == table_v1 assert lance.dataset(clone_v1_tag_dir).to_table() == table_v1 + table_v3 = pa.table({"a": [7, 8, 9], "b": [40, 50, 60]}) + branch = ds.create_branch("branch", 2) + lance.write_dataset(table_v3, branch.uri, mode="overwrite") + clone_branch_v3 = tmp_path / "clone_branch_v3" + cloned_by_branch = branch.shallow_clone(clone_branch_v3, 3) + assert cloned_by_branch.to_table() == table_v3 + assert lance.dataset(clone_branch_v3).to_table() == table_v3 + def test_branches(tmp_path: Path): # Step 1: create branch1 from main → append to branch1 → create branch2 from tag @@ -4810,10 +4845,23 @@ def test_branches(tmp_path: Path): ) assert branch1.to_table().combine_chunks() == expected_branch1.combine_chunks() - # Step 2: tag latest of branch1 → create branch2 from that tag - tag_name = "branch1_latest" - branch1.tags.create(tag_name, branch1.latest_version, "branch1") - branch2 = branch1.create_branch("branch2", tag_name) + # Step 2: + # tag latest of branch1 → create branch2 from that tag + # test create tag on the main branch by different ways + # test create branch from the main branch by specifying "main" + branch1.tags.create("branch1_latest", ("branch1", None)) + branch1.tags.create("main_latest", (None, None)) + branch1.tags.create("main_latest2", ("main", None)) + branch1.create_branch("branch_from_main", ("main", None)) + assert branch1.tags.list()["branch1_latest"]["branch"] == "branch1" + assert branch1.tags.list()["main_latest"]["branch"] is None + assert branch1.tags.list()["main_latest2"]["branch"] is None + assert branch1.branches.list()["branch_from_main"]["parent_branch"] is None + assert branch1.branches.list()["branch_from_main"]["parent_version"] == 1 + assert branch1.checkout_version("main_latest").latest_version == 1 + assert branch1.checkout_version("main_latest2").latest_version == 1 + assert branch1.checkout_version(("branch_from_main", None)).latest_version == 1 + branch2 = branch1.create_branch("branch2", "branch1_latest") assert branch2.version == 2 # Step 3: append more data to branch2 → verify contains branch1 data + new @@ -4838,20 +4886,23 @@ def test_branches(tmp_path: Path): assert "create_at" in b1_meta try: - ds_main.branches.delete("branch1") + ds_main.checkout_version("branch_not_exists") + assert False, "Expected OSError was not raised" except OSError as e: - if "Not found" not in str(e): + if "does not exist" not in str(e): raise + + ds_main.branches.delete("branch2") branches_after = ds_main.branches.list() - assert "branch1" not in branches_after - assert "branch2" in branches_after + assert "branch2" not in branches_after + assert "branch1" in branches_after - branch2 = ds_main.checkout_branch("branch2") - assert branch2.version == 3 - assert branch2.to_table().combine_chunks() == expected_branch2.combine_chunks() - branch2 = ds_main.checkout_version(("branch2", 2)) - assert branch2.version == 2 - assert branch2.to_table().combine_chunks() == expected_branch1.combine_chunks() - branch2.checkout_latest() - assert branch2.version == 3 - assert branch2.to_table().combine_chunks() == expected_branch2.combine_chunks() + branch1 = ds_main.checkout_version(("branch1", None)) + assert branch1.version == 2 + assert branch1.to_table().combine_chunks() == expected_branch1.combine_chunks() + branch1 = ds_main.checkout_version(("branch1", 1)) + assert branch1.version == 1 + assert branch1.to_table().combine_chunks() == main_table.combine_chunks() + branch1.checkout_latest() + assert branch1.version == 2 + assert branch1.to_table().combine_chunks() == expected_branch1.combine_chunks() diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 68031cc3c70..1e3f99347ea 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -1569,9 +1569,10 @@ impl Dataset { let pytags = PyDict::new(py); for (k, v) in tags.iter() { let dict = PyDict::new(py); - dict.set_item("version", v.version).unwrap(); - dict.set_item("manifest_size", v.manifest_size).unwrap(); - pytags.set_item(k, dict.into_py_any(py)?).unwrap(); + dict.set_item("branch", v.branch.clone())?; + dict.set_item("version", v.version)?; + dict.set_item("manifest_size", v.manifest_size)?; + pytags.set_item(k, dict.into_py_any(py)?)?; } pytags.into_py_any(py) }) @@ -1591,13 +1592,11 @@ impl Dataset { }) } - fn create_tag(&mut self, tag: String, version: u64, branch: Option) -> PyResult<()> { + fn create_tag(&mut self, py: Python, tag: String, reference: Option) -> PyResult<()> { + let reference = self.transform_ref(py, reference)?; rt().block_on( None, - self.ds - .as_ref() - .tags() - .create_on_branch(tag.as_str(), version, branch.as_deref()), + self.ds.as_ref().tags().create(tag.as_str(), reference), )? .map_err(|err| match err { Error::NotFound { .. } => PyValueError::new_err(err.to_string()), @@ -1619,33 +1618,16 @@ impl Dataset { Ok(()) } - fn update_tag(&self, tag: String, version: u64, branch: Option) -> PyResult<()> { + fn update_tag(&self, py: Python, tag: String, reference: Option) -> PyResult<()> { + let reference = self.transform_ref(py, reference)?; rt().block_on( None, - self.ds - .as_ref() - .tags() - .update_on_branch(tag.as_str(), version, branch.as_deref()), + self.ds.as_ref().tags().update(tag.as_str(), reference), )? .infer_error()?; Ok(()) } - /// Check out the latest version of the given branch - fn checkout_branch(&self, branch: String) -> PyResult { - let ds = rt() - .block_on(None, self.ds.checkout_branch(branch.as_str()))? - .map_err(|err| match err { - Error::NotFound { .. } => PyValueError::new_err(err.to_string()), - _ => PyIOError::new_err(err.to_string()), - })?; - let uri_str = ds.uri().to_string(); - Ok(Self { - ds: Arc::new(ds), - uri: uri_str, - }) - } - /// Check out the latest version of the current branch fn checkout_latest(&mut self) -> PyResult<()> { let mut new_self = self.ds.as_ref().clone(); @@ -1668,7 +1650,6 @@ impl Dataset { storage_options: Option>, ) -> PyResult { let mut new_self = self.ds.as_ref().clone(); - // Build Ref from python object let reference = self.transform_ref(py, reference)?; let store_params = storage_options.map(|opts| ObjectStoreParams { storage_options: Some(opts), @@ -2855,29 +2836,18 @@ impl Dataset { if let Some(reference) = reference { if let Ok(i) = reference.downcast_bound::(py) { let version_number: u64 = i.extract()?; - Ok(Ref::from(version_number)) + Ok(version_number.into()) } else if let Ok(tag_name) = reference.downcast_bound::(py) { let tag: &str = &tag_name.to_string_lossy(); - Ok(Ref::from(tag)) + Ok(tag.into()) } else if let Ok(tuple) = reference.downcast_bound::(py) { - let len = tuple.len(); - if len == 1 { - let elem = tuple.get_item(0)?; - if let Ok(version_number) = elem.extract::() { - Ok(Ref::from(version_number)) - } else if let Ok(branch_name) = elem.extract::() { - Ok(Ref::Version(Some(branch_name), None)) - } else { - Err(PyValueError::new_err( - "Version tuple must contain integer or string", - )) - } - } else if len == 2 { - let (branch_name, version_number) = tuple.extract::<(String, u64)>()?; - Ok(Ref::Version(Some(branch_name), Some(version_number))) + if tuple.len() == 2 { + let (branch_name, version_number) = + tuple.extract::<(Option, Option)>()?; + Ok((branch_name.as_deref(), version_number).into()) } else { Err(PyValueError::new_err( - "Version tuple must have 1 or 2 elements", + "Version tuple should be Tuple[Optional[str], Optional[int]]", )) } } else { diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index 26d4b16faca..19372959a4e 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -449,14 +449,19 @@ impl Dataset { /// Check out a dataset version with a ref pub async fn checkout_version(&self, version: impl Into) -> Result { - let ref_: refs::Ref = version.into(); - match ref_ { + let reference: refs::Ref = version.into(); + match reference { refs::Ref::Version(branch, version_number) => { - self.checkout_by_ref(version_number, branch).await + self.checkout_by_ref(version_number, branch.as_deref()) + .await + } + refs::Ref::VersionNumber(version_number) => { + self.checkout_by_ref(Some(version_number), self.manifest.branch.as_deref()) + .await } refs::Ref::Tag(tag_name) => { let tag_contents = self.tags().get(tag_name.as_str()).await?; - self.checkout_by_ref(Some(tag_contents.version), tag_contents.branch) + self.checkout_by_ref(Some(tag_contents.version), tag_contents.branch.as_deref()) .await } } @@ -487,7 +492,7 @@ impl Dataset { /// Check out the latest version of the branch pub async fn checkout_branch(&self, branch: &str) -> Result { - self.checkout_by_ref(None, Some(branch.to_string())).await + self.checkout_by_ref(None, Some(branch)).await } /// This is a two-phase operation: @@ -550,14 +555,10 @@ impl Dataset { self.branches().list().await } - fn already_checked_out( - &self, - location: &ManifestLocation, - branch_name: Option, - ) -> bool { + fn already_checked_out(&self, location: &ManifestLocation, branch_name: Option<&str>) -> bool { // We check the e_tag here just in case it has been overwritten. This can // happen if the table has been dropped then re-created recently. - self.manifest.branch == branch_name + self.manifest.branch.as_deref() == branch_name && self.manifest.version == location.version && self.manifest_location.naming_scheme == location.naming_scheme && location.e_tag.as_ref().is_some_and(|e_tag| { @@ -571,17 +572,9 @@ impl Dataset { async fn checkout_by_ref( &self, version_number: Option, - branch: Option, + branch: Option<&str>, ) -> Result { - let new_location = if self.manifest.branch.as_ref() != branch.as_ref() { - if let Some(branch_name) = branch.as_deref() { - self.find_branch_location(branch_name)? - } else { - self.branch_location().find_main()? - } - } else { - self.branch_location() - }; + let new_location = self.branch_location().find_branch(branch)?; let manifest_location = if let Some(version_number) = version_number { self.commit_handler @@ -597,7 +590,7 @@ impl Dataset { .await? }; - if self.already_checked_out(&manifest_location, branch.clone()) { + if self.already_checked_out(&manifest_location, branch) { return Ok(self.clone()); } @@ -982,7 +975,7 @@ impl Dataset { uri: self.uri.clone(), branch: self.manifest.branch.clone(), }; - current_location.find_branch(Some(branch_name.to_string())) + current_location.find_branch(Some(branch_name)) } /// Get the full manifest of the dataset version. @@ -1037,7 +1030,7 @@ impl Dataset { return Ok((cached_manifest, location)); } - if self.already_checked_out(&location, self.manifest.branch.clone()) { + if self.already_checked_out(&location, self.manifest.branch.as_deref()) { return Ok((self.manifest.clone(), self.manifest_location.clone())); } let mut manifest = read_manifest(&self.object_store, &location.path, location.size).await?; @@ -2137,8 +2130,7 @@ impl Dataset { version: impl Into, store_params: Option, ) -> Result { - let ref_ = version.into(); - let (ref_name, version_number) = self.resolve_reference(ref_).await?; + let (ref_name, version_number) = self.resolve_reference(version.into()).await?; let clone_op = Operation::Clone { is_shallow: true, ref_name, @@ -2149,7 +2141,9 @@ impl Dataset { let transaction = Transaction::new(version_number, clone_op, None); let builder = CommitBuilder::new(WriteDestination::Uri(target_path)) - .with_store_params(store_params.unwrap_or_default()) + .with_store_params( + store_params.unwrap_or(self.store_params.as_deref().cloned().unwrap_or_default()), + ) .with_object_store(Arc::new(self.object_store().clone())) .with_commit_handler(self.commit_handler.clone()) .with_storage_format(self.manifest.data_storage_format.lance_file_version()?); @@ -2162,14 +2156,18 @@ impl Dataset { if let Some(version_number) = version_number { Ok((branch, version_number)) } else { + let branch_location = self.branch_location().find_branch(branch.as_deref())?; let version_number = self .commit_handler - .resolve_latest_location(&self.base, &self.object_store) + .resolve_latest_location(&branch_location.path, &self.object_store) .await? .version; Ok((branch, version_number)) } } + refs::Ref::VersionNumber(version_number) => { + Ok((self.manifest.branch.clone(), version_number)) + } refs::Ref::Tag(tag_name) => { let tag_contents = self.tags().get(tag_name.as_str()).await?; Ok((tag_contents.branch, tag_contents.version)) diff --git a/rust/lance/src/dataset/branch_location.rs b/rust/lance/src/dataset/branch_location.rs index d3bdc3ab7f1..b9c979c8920 100644 --- a/rust/lance/src/dataset/branch_location.rs +++ b/rust/lance/src/dataset/branch_location.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use crate::dataset::refs::Branches; use lance_core::{Error, Result}; use object_store::path::Path; use snafu::location; @@ -17,7 +18,7 @@ pub struct BranchLocation { impl BranchLocation { /// Find the root location pub fn find_main(&self) -> Result { - if let Some(branch_name) = self.branch.as_ref() { + if let Some(branch_name) = self.branch.as_deref() { let root_path_str = Self::get_root_path(self.path.as_ref(), branch_name)?; let root_uri = Self::get_root_path(self.uri.as_str(), branch_name)?; Ok(Self { @@ -69,13 +70,17 @@ impl BranchLocation { } /// Find the target branch location - pub fn find_branch(&self, branch_name: Option) -> Result { - if branch_name == self.branch { + pub fn find_branch(&self, branch_name: Option<&str>) -> Result { + if branch_name == self.branch.as_deref() { return Ok(self.clone()); } let root_location = self.find_main()?; - if let Some(target_branch) = branch_name.as_ref() { + if Branches::is_main_branch(branch_name) { + return Ok(root_location); + } + + if let Some(target_branch) = branch_name { let (new_path, new_uri) = { // Handle empty segment if target_branch.is_empty() { @@ -94,7 +99,7 @@ impl BranchLocation { Ok(Self { path: new_path, uri: new_uri, - branch: Some(target_branch.clone()), + branch: Some(target_branch.to_string()), }) } else { Ok(root_location) @@ -164,7 +169,7 @@ mod tests { fn test_find_branch_from_same_branch() { let root_path = TempStdDir::default().to_owned(); let location = create_branch_location(root_path); - let target_branch = location.branch.clone(); + let target_branch = location.branch.as_deref(); let new_location = location.find_branch(target_branch).unwrap(); assert_eq!(new_location.path, location.path); @@ -190,9 +195,9 @@ mod tests { fn test_find_simple_branch() { let root_path = TempStdDir::default().to_owned(); let location = create_branch_location(root_path); - let new_branch = Some("featureA".to_string()); + let new_branch = Some("featureA"); let main_location = location.find_main().unwrap(); - let new_location = location.find_branch(new_branch.clone()).unwrap(); + let new_location = location.find_branch(new_branch).unwrap(); assert_eq!( new_location.path.as_ref(), @@ -202,7 +207,7 @@ mod tests { new_location.uri, format!("{}/tree/featureA", main_location.uri) ); - assert_eq!(new_location.branch, new_branch); + assert_eq!(new_location.branch.as_deref(), new_branch); assert!(fs::create_dir_all(std::path::Path::new(new_location.uri.as_str())).is_ok()); } @@ -210,7 +215,7 @@ mod tests { fn test_find_complex_branch() { let root_path = TempStdDir::default().to_owned(); let location = create_branch_location(root_path); - let new_branch = Some("bugfix/issue-123".to_string()); + let new_branch = Some("bugfix/issue-123"); let main_location = location.find_main().unwrap(); let new_location = location.find_branch(new_branch).unwrap(); @@ -229,12 +234,12 @@ mod tests { fn test_find_empty_branch() { let root_path = TempStdDir::default().to_owned(); let location = create_branch_location(root_path); - let new_branch = Some("".to_string()); - let new_location = location.find_branch(new_branch.clone()).unwrap(); + let new_branch = Some(""); + let new_location = location.find_branch(new_branch).unwrap(); assert_eq!(new_location.path, location.path); assert_eq!(new_location.uri, location.uri); - assert_eq!(new_location.branch, new_branch); + assert_eq!(new_location.branch.as_deref(), new_branch); } #[test] @@ -258,7 +263,7 @@ mod tests { assert_eq!(main_location.branch, None); let new_branch = branch_location - .find_branch(Some("feature/nathan/A".to_string())) + .find_branch(Some("feature/nathan/A")) .unwrap(); assert_eq!( new_branch.uri, @@ -270,6 +275,6 @@ mod tests { .unwrap() .as_ref() ); - assert_eq!(new_branch.branch, Some("feature/nathan/A".to_string())); + assert_eq!(new_branch.branch.as_deref(), Some("feature/nathan/A")); } } diff --git a/rust/lance/src/dataset/builder.rs b/rust/lance/src/dataset/builder.rs index 332ba504cf9..3d463ce6ca4 100644 --- a/rust/lance/src/dataset/builder.rs +++ b/rust/lance/src/dataset/builder.rs @@ -520,6 +520,9 @@ impl DatasetBuilder { } (branch, version_number) } + // We don't have a current branch context, just specify the branch as main + // But the real branch will be specified by uri + Some(Ref::VersionNumber(version_number)) => (None, Some(version_number)), // Here we assume the uri and path is the root. // If tag not found, we need to delay checkout after loading by uri Some(Ref::Tag(tag_name)) => { @@ -564,7 +567,9 @@ impl DatasetBuilder { } if branch.as_deref() != dataset.manifest.branch.as_deref() { - return dataset.checkout_version((branch, version_number)).await; + return dataset + .checkout_version((branch.as_deref(), version_number)) + .await; } } if let Some(version_number) = version_number { diff --git a/rust/lance/src/dataset/refs.rs b/rust/lance/src/dataset/refs.rs index 6af0edf3dfc..4044da9f60e 100644 --- a/rust/lance/src/dataset/refs.rs +++ b/rust/lance/src/dataset/refs.rs @@ -12,7 +12,7 @@ use serde::{Deserialize, Serialize}; use std::sync::Arc; use crate::dataset::branch_location::BranchLocation; -use crate::dataset::refs::Ref::{Tag, Version}; +use crate::dataset::refs::Ref::{Tag, Version, VersionNumber}; use crate::{Error, Result}; use serde::de::DeserializeOwned; use snafu::location; @@ -22,9 +22,13 @@ use std::fmt; use std::fmt::Formatter; use std::io::ErrorKind; +pub const MAIN_BRANCH: &str = "main"; + /// Lance Ref #[derive(Debug, Clone)] pub enum Ref { + // Version number points of the current branch + VersionNumber(u64), // This is a global version identifier present as (branch_name, version_number) // if branch_name is None, it points to the main branch // if version_number is None, it points to the latest version @@ -34,32 +38,32 @@ pub enum Ref { } impl From for Ref { - fn from(ref_: u64) -> Self { - Version(None, Some(ref_)) + fn from(reference: u64) -> Self { + VersionNumber(reference) } } impl From<&str> for Ref { - fn from(ref_: &str) -> Self { - Tag(ref_.to_string()) + fn from(reference: &str) -> Self { + Tag(reference.to_string()) } } impl From<(&str, u64)> for Ref { - fn from(_ref: (&str, u64)) -> Self { - Version(Some(_ref.0.to_string()), Some(_ref.1)) + fn from(reference: (&str, u64)) -> Self { + Version(standardize_branch(reference.0), Some(reference.1)) } } -impl From<(Option, Option)> for Ref { - fn from(_ref: (Option, Option)) -> Self { - Version(_ref.0, _ref.1) +impl From<(Option<&str>, Option)> for Ref { + fn from(reference: (Option<&str>, Option)) -> Self { + Version(reference.0.and_then(standardize_branch), reference.1) } } impl From<(&str, Option)> for Ref { - fn from(_ref: (&str, Option)) -> Self { - Version(Some(_ref.0.to_string()), _ref.1) + fn from(reference: (&str, Option)) -> Self { + Version(standardize_branch(reference.0), reference.1) } } @@ -67,12 +71,12 @@ impl fmt::Display for Ref { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self { Version(branch, version_number) => { - let branch_name = branch.as_deref().unwrap_or("main"); let version_str = version_number .map(|v| v.to_string()) .unwrap_or_else(|| "latest".to_string()); - write!(f, "{}:{}", branch_name, version_str) + write!(f, "{}:{}", normalize_branch(branch.as_deref()), version_str) } + VersionNumber(version_number) => write!(f, "{}", version_number), Tag(tag_name) => write!(f, "{}", tag_name), } } @@ -204,24 +208,12 @@ impl Tags<'_> { } let tag_contents = TagContents::from_path(&tag_file, self.object_store()).await?; - Ok(tag_contents) } - pub async fn create(&self, tag: &str, version: u64) -> Result<()> { - self.create_on_branch(tag, version, None).await - } - - pub async fn create_on_branch( - &self, - tag: &str, - version_number: u64, - branch: Option<&str>, - ) -> Result<()> { + pub async fn create(&self, tag: &str, reference: impl Into) -> Result<()> { check_valid_tag(tag)?; - let root_location = self.refs.root()?; - let branch = branch.map(String::from); let tag_file = tag_path(&root_location.path, tag); if self.object_store().exists(&tag_file).await? { @@ -229,39 +221,7 @@ impl Tags<'_> { message: format!("tag {} already exists", tag), }); } - - let branch_location = self.refs.base_location.find_branch(branch.clone())?; - let manifest_file = self - .refs - .commit_handler - .resolve_version_location( - &branch_location.path, - version_number, - &self.refs.object_store.inner, - ) - .await?; - - if !self.object_store().exists(&manifest_file.path).await? { - return Err(Error::VersionNotFound { - message: format!( - "version {}::{} does not exist", - branch.unwrap_or("Main".to_string()), - version_number - ), - }); - } - - let manifest_size = if let Some(size) = manifest_file.size { - size as usize - } else { - self.object_store().size(&manifest_file.path).await? as usize - }; - - let tag_contents = TagContents { - branch, - version: version_number, - manifest_size, - }; + let tag_contents = self.build_tag_content_by_ref(reference).await?; self.object_store() .put( @@ -287,43 +247,60 @@ impl Tags<'_> { self.object_store().delete(&tag_file).await } - pub async fn update(&self, tag: &str, version: u64) -> Result<()> { - self.update_on_branch(tag, version, None).await - } - - /// Update a tag to a branch::version - pub async fn update_on_branch( - &self, - tag: &str, - version_number: u64, - branch: Option<&str>, - ) -> Result<()> { + pub async fn update(&self, tag: &str, reference: impl Into) -> Result<()> { check_valid_tag(tag)?; - let branch = branch.map(String::from); let root_location = self.refs.root()?; let tag_file = tag_path(&root_location.path, tag); - if !self.object_store().exists(&tag_file).await? { return Err(Error::RefNotFound { message: format!("tag {} does not exist", tag), }); } + let tag_contents = self.build_tag_content_by_ref(reference).await?; - let target_branch_location = self.refs.base_location.find_branch(branch.clone())?; - let manifest_file = self - .refs - .commit_handler - .resolve_version_location( - &target_branch_location.path, - version_number, - &self.refs.object_store.inner, + self.object_store() + .put( + &tag_file, + serde_json::to_string_pretty(&tag_contents)?.as_bytes(), ) - .await?; + .await + .map(|_| ()) + } + + async fn build_tag_content_by_ref(&self, reference: impl Into) -> Result { + let reference = reference.into(); + let (branch, version_number) = match reference { + Version(branch, version_number) => (branch, version_number), + VersionNumber(version_number) => { + (self.refs.base_location.branch.clone(), Some(version_number)) + } + Tag(tag_name) => { + let tag_content = self.get(tag_name.as_str()).await?; + (tag_content.branch, Some(tag_content.version)) + } + }; + + let branch_location = self.refs.base_location.find_branch(branch.as_deref())?; + let manifest_file = if let Some(version_number) = version_number { + self.refs + .commit_handler + .resolve_version_location( + &branch_location.path, + version_number, + &self.refs.object_store.inner, + ) + .await? + } else { + self.refs + .commit_handler + .resolve_latest_location(&branch_location.path, &self.refs.object_store) + .await? + }; if !self.object_store().exists(&manifest_file.path).await? { return Err(Error::VersionNotFound { - message: format!("version {} does not exist", version_number), + message: format!("version {} does not exist", Version(branch, version_number)), }); } @@ -335,21 +312,18 @@ impl Tags<'_> { let tag_contents = TagContents { branch, - version: version_number, + version: manifest_file.version, manifest_size, }; - - self.object_store() - .put( - &tag_file, - serde_json::to_string_pretty(&tag_contents)?.as_bytes(), - ) - .await - .map(|_| ()) + Ok(tag_contents) } } impl Branches<'_> { + pub(crate) fn is_main_branch(branch: Option<&str>) -> bool { + branch == Some(MAIN_BRANCH) + } + pub async fn fetch(&self) -> Result> { let root_location = self.refs.root()?; let base_path = base_branches_contents_path(&root_location.path); @@ -408,7 +382,8 @@ impl Branches<'_> { Ok(branch_contents) } - pub async fn create( + // Only create branch metadata + pub(crate) async fn create( &self, branch_name: &str, version_number: u64, @@ -416,7 +391,7 @@ impl Branches<'_> { ) -> Result<()> { check_valid_branch(branch_name)?; - let source_branch = source_branch.map(String::from); + let source_branch = source_branch.and_then(standardize_branch); let root_location = self.refs.root()?; let branch_file = branch_contents_path(&root_location.path, branch_name); if self.object_store().exists(&branch_file).await? { @@ -425,7 +400,10 @@ impl Branches<'_> { }); } - let branch_location = self.refs.base_location.find_branch(source_branch.clone())?; + let branch_location = self + .refs + .base_location + .find_branch(source_branch.as_deref())?; // Verify the source version exists let manifest_file = self .refs @@ -536,35 +514,97 @@ impl Branches<'_> { remaining_branches: &[&str], base_location: &BranchLocation, ) -> Result> { - let mut longest_used_length = 0; - for &candidate in remaining_branches { - let common_len = branch - .chars() - .zip(candidate.chars()) - .take_while(|(a, b)| a == b) - .count(); - - if common_len > longest_used_length { - longest_used_length = common_len; + let deleted_branch = BranchRelativePath::new(branch); + let mut related_branches = Vec::new(); + let mut relative_dir = branch.to_string(); + for branch in remaining_branches { + let branch = BranchRelativePath::new(branch); + if branch.is_parent(&deleted_branch) || branch.is_child(&deleted_branch) { + related_branches.push(branch); + } else if let Some(common_prefix) = deleted_branch.find_common_prefix(&branch) { + related_branches.push(common_prefix); } } - // Means this branch path is used as a prefix of other branches - if longest_used_length == branch.len() { - return Ok(None); + + related_branches.sort_by(|a, b| a.segments.len().cmp(&b.segments.len()).reverse()); + if let Some(branch) = related_branches.first() { + if branch.is_child(&deleted_branch) || branch == &deleted_branch { + // There are children of the deleted branch, we can't delete any directory for now + // Example: deleted_branch = "a/b/c", remaining_branches = ["a/b/c/d"], we need to delete nothing + return Ok(None); + } else { + // We pick the longest common directory between the deleted branch and the remaining branches + // Then delete the first child of this common directory + // Example: deleted_branch = "a/b/c", remaining_branches = ["a"], we need to delete "a/b" + relative_dir = format!( + "{}/{}", + branch.segments.join("/"), + deleted_branch.segments[branch.segments.len()] + ); + } + } else if !deleted_branch.segments.is_empty() { + // There are no common directories between the deleted branch and the remaining branches + // We need to delete the entire directory + // Example: deleted_branch = "a/b/c", remaining_branches = [], we need to delete "a" + relative_dir = deleted_branch.segments[0].to_string(); } - let mut used_relative_path = &branch[..longest_used_length]; - if let Some(last_slash_index) = used_relative_path.rfind('/') { - used_relative_path = &used_relative_path[..last_slash_index]; + let absolute_dir = base_location.find_branch(Some(relative_dir.as_str()))?; + Ok(Some(absolute_dir.path)) + } +} + +#[derive(Debug, PartialEq)] +struct BranchRelativePath<'a> { + segments: Vec<&'a str>, +} + +impl<'a> BranchRelativePath<'a> { + fn new(branch_name: &'a str) -> Self { + let segments = branch_name.split('/').collect_vec(); + Self { segments } + } + + fn find_common_prefix(&self, other: &Self) -> Option { + let mut common_segments = Vec::new(); + for (i, segment) in self.segments.iter().enumerate() { + if i >= other.segments.len() || other.segments[i] != *segment { + break; + } + common_segments.push(*segment); } - let unused_dir = &branch[used_relative_path.len()..].trim_start_matches('/'); - if let Some(sub_dir) = unused_dir.split('/').next() { - let relative_dir = format!("{}/{}", used_relative_path, sub_dir); - // Use base_location to generate the cleanup path - let absolute_dir = base_location.find_branch(Some(relative_dir))?; - Ok(Some(absolute_dir.path)) + if !common_segments.is_empty() { + Some(BranchRelativePath { + segments: common_segments, + }) + } else { + None + } + } + + fn is_parent(&self, other: &Self) -> bool { + if other.segments.len() <= self.segments.len() { + false } else { - Ok(None) + for (i, segment) in self.segments.iter().enumerate() { + if other.segments[i] != *segment { + return false; + } + } + true + } + } + + fn is_child(&self, other: &Self) -> bool { + if other.segments.len() >= self.segments.len() { + false + } else { + for (i, segment) in other.segments.iter().enumerate() { + if self.segments[i] != *segment { + return false; + } + } + true } } } @@ -603,6 +643,20 @@ pub fn branch_contents_path(base_path: &Path, branch: &str) -> Path { base_branches_contents_path(base_path).child(format!("{}.json", branch)) } +pub(crate) fn normalize_branch(branch: Option<&str>) -> String { + match branch { + None => MAIN_BRANCH.to_string(), + Some(name) => name.to_string(), + } +} + +pub(crate) fn standardize_branch(branch: &str) -> Option { + match branch { + MAIN_BRANCH => None, + name => Some(name.to_string()), + } +} + async fn from_path(path: &Path, object_store: &ObjectStore) -> Result where T: DeserializeOwned, @@ -859,9 +913,8 @@ mod tests { // Test From for Ref let version_ref: Ref = 42u64.into(); match version_ref { - Version(branch, v) => { - assert_eq!(v, Some(42)); - assert_eq!(branch, None) + VersionNumber(version_number) => { + assert_eq!(version_number, 42); } _ => panic!("Expected Version variant"), } @@ -930,21 +983,17 @@ mod tests { } #[rstest] - #[case("feature/auth", &["feature/login", "feature/signup"], Some("feature/auth"))] - #[case("feature/auth/module", &["feature/other"], Some("feature/auth"))] - #[case("a/b/c", &["a/b/d", "a/e"], Some("a/b/c"))] #[case("feature/auth", &["feature/auth/sub"], None)] #[case("feature", &["feature/sub1", "feature/sub2"], None)] - #[case("a/b", &["a/b/c", "a/b/d"], None)] + #[case("a/b", &["a/b/c", "b/c/d"], None)] #[case("main", &[], Some("main"))] #[case("a", &["a"], None)] - #[case("single", &["other"], Some("single"))] - #[case("feature/auth/login/oauth", &["feature/auth/login/basic", "feature/auth/signup"], Some("feature/auth/login/oauth"))] - #[case("feature/user-auth", &["feature/user-signup"], Some("feature/user-auth"))] - #[case("release/2024.01", &["release/2024.02"], Some("release/2024.01"))] - #[case("very/long/common/prefix/branch1", &["very/long/common/prefix/branch2"], Some("very/long/common/prefix/branch1"))] - #[case("feature", &["bugfix", "hotfix"], Some("feature"))] + #[case("feature/auth", &["feature/login", "feature/signup"], Some("feature/auth"))] #[case("feature/sub", &["feature", "other"], Some("feature/sub"))] + #[case("very/long/common/prefix/branch1", &["very/long/common/prefix/branch2"], Some("very/long/common/prefix/branch1"))] + #[case("feature/auth/module", &["feature/other"], Some("feature/auth"))] + #[case("feature/dev", &["bugfix", "hotfix"], Some("feature"))] + #[case("branch1", &["dev/branch2", "feature/nathan/branch3", "branch4"], Some("branch1"))] fn test_get_cleanup_path( #[case] branch_to_delete: &str, #[case] remaining_branches: &[&str], @@ -969,7 +1018,7 @@ mod tests { branch_to_delete ); let expected_full_path = base_location - .find_branch(Some(expected_relative.to_string())) + .find_branch(Some(expected_relative)) .unwrap() .path; assert_eq!(result.unwrap().as_ref(), expected_full_path.as_ref()); diff --git a/rust/lance/src/dataset/tests/dataset_versioning.rs b/rust/lance/src/dataset/tests/dataset_versioning.rs index cfefa23ab3b..2e2fcdf6601 100644 --- a/rust/lance/src/dataset/tests/dataset_versioning.rs +++ b/rust/lance/src/dataset/tests/dataset_versioning.rs @@ -267,7 +267,7 @@ async fn test_tag( let bad_tag_creation = dataset.tags().create("tag1", 3).await; assert_eq!( bad_tag_creation.err().unwrap().to_string(), - "Version not found error: version Main::3 does not exist" + "Version not found error: version main:3 does not exist" ); let bad_tag_deletion = dataset.tags().delete("tag1").await; @@ -354,7 +354,7 @@ async fn test_tag( let another_bad_tag_update = dataset.tags().update("tag1", 3).await; assert_eq!( another_bad_tag_update.err().unwrap().to_string(), - "Version not found error: version 3 does not exist" + "Version not found error: version main:3 does not exist" ); dataset.tags().update("tag1", 2).await.unwrap(); @@ -595,11 +595,7 @@ async fn test_branch() { // create branch3 based on that tag, write data batch 4 branch2_dataset .tags() - .create_on_branch( - "tag1", - branch2_dataset.version().version, - Some("dev/branch2"), - ) + .create("tag1", ("dev/branch2", branch2_dataset.version().version)) .await .unwrap();