Skip to content
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,23 @@ else if (codePoint == 0x03C2) {
}
}

private static final int COMBINED_LOWERCASE_I_DOT = 0x69 << 16 | 0x307;

private static int getLowercaseCodePoint(final int codePoint) {
if (codePoint == 0x0130) {
// Latin capital letter I with dot above is mapped to 2 lowercase characters.
return COMBINED_LOWERCASE_I_DOT;
}
else if (codePoint == 0x03C2) {
// Greek final and non-final capital letter sigma should be mapped the same.
return 0x03C3;
}
else {
// All other characters should follow context-unaware ICU single-code point case mapping.
return UCharacter.toLowerCase(codePoint);
}
}

/**
* Converts an entire string to lowercase using ICU rules, code point by code point, with
* special handling for one-to-many case mappings (i.e. characters that map to multiple
Expand Down Expand Up @@ -621,37 +638,69 @@ public static UTF8String lowercaseSubStringIndex(final UTF8String string,
}
}

public static Map<String, String> getCollationAwareDict(UTF8String string,
Map<String, String> dict, int collationId) {
String srcStr = string.toString();

private static Map<Integer, String> getLowercaseDict(final Map<String, String> dict) {
// Replace all the keys in the dict with lowercased code points.
Map<Integer, String> lowercaseDict = new HashMap<>();
for (Map.Entry<String, String> entry : dict.entrySet()) {
int codePoint = entry.getKey().codePointAt(0);
lowercaseDict.putIfAbsent(getLowercaseCodePoint(codePoint), entry.getValue());
}
return lowercaseDict;
}
private static Map<String, String> getCollationAwareDict(final Map<String, String> dict,
int collationId) {
// Replace all the keys in the dict with collation keys.
Map<String, String> collationAwareDict = new HashMap<>();
for (String key : dict.keySet()) {
StringSearch stringSearch =
CollationFactory.getStringSearch(string, UTF8String.fromString(key), collationId);

int pos = 0;
while ((pos = stringSearch.next()) != StringSearch.DONE) {
int codePoint = srcStr.codePointAt(pos);
int charCount = Character.charCount(codePoint);
String newKey = srcStr.substring(pos, pos + charCount);

boolean exists = false;
for (String existingKey : collationAwareDict.keySet()) {
if (stringSearch.getCollator().compare(existingKey, newKey) == 0) {
collationAwareDict.put(newKey, collationAwareDict.get(existingKey));
exists = true;
break;
}
}
for (Map.Entry<String, String> entry : dict.entrySet()) {
String collationKey = CollationFactory.getCollationKey(entry.getKey(), collationId);
collationAwareDict.putIfAbsent(collationKey, entry.getValue());
}
return collationAwareDict;
}

if (!exists) {
collationAwareDict.put(newKey, dict.get(key));
}
private static String lowercaseTranslate(final String input, final Map<Integer, String> dict) {
StringBuilder sb = new StringBuilder();
int charCount = 0;
for (int k = 0; k < input.length(); k += charCount) {
int codePoint = input.codePointAt(k);
charCount = Character.charCount(codePoint);
String translated = dict.get(getLowercaseCodePoint(codePoint));
if (null == translated) {
sb.appendCodePoint(codePoint);
} else if (!"\0".equals(translated)) {
sb.append(translated);
}
}
return sb.toString();
}
private static String translate(final String input, final Map<String, String> dict,
final int collationId) {
StringBuilder sb = new StringBuilder();
int charCount = 0;
for (int k = 0; k < input.length(); k += charCount) {
int codePoint = input.codePointAt(k);
charCount = Character.charCount(codePoint);
String subStr = input.substring(k, k + charCount);
String collationKey = CollationFactory.getCollationKey(subStr, collationId);
String translated = dict.get(collationKey);
if (null == translated) {
sb.append(subStr);
} else if (!"\0".equals(translated)) {
sb.append(translated);
}
}
return sb.toString();
}

return collationAwareDict;
public static UTF8String lowercaseTranslate(final UTF8String input,
final Map<String, String> dict) {
Map<Integer, String> lowercaseDict = getLowercaseDict(dict);
return UTF8String.fromString(lowercaseTranslate(input.toString(), lowercaseDict));
}
public static UTF8String translate(final UTF8String input, final Map<String, String> dict,
final int collationId) {
Map<String, String> collationAwareDict = getCollationAwareDict(dict,collationId);
return UTF8String.fromString(translate(input.toString(), collationAwareDict, collationId));
}

public static UTF8String lowercaseTrim(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -805,12 +805,24 @@ public static String[] getICULocaleNames() {
return Collation.CollationSpecICU.ICULocaleNames;
}

public static UTF8String getCollationKey(UTF8String input, int collationId) {
public static String getCollationKey(String input, int collationId) {
Collation collation = fetchCollation(collationId);
if (collation.supportsBinaryEquality) {
return input;
} else if (collation.supportsLowercaseEquality) {
return input.toLowerCase();
} else {
CollationKey collationKey = collation.collator.getCollationKey(input);
return Arrays.toString(collationKey.toByteArray());
}
}

public static UTF8String getCollationKey(UTF8String input, int collationId) {
Collation collation = fetchCollation(collationId);
if (collation.supportsBinaryEquality) {
return input;
} else if (collation.supportsLowercaseEquality) {
return CollationAwareUTF8String.toLowerCase(input);
} else {
CollationKey collationKey = collation.collator.getCollationKey(input.toString());
return UTF8String.fromBytes(collationKey.toByteArray());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -509,34 +509,19 @@ public static String genCode(final String source, final String dict, final int c
return String.format(expr + "Binary(%s, %s)", source, dict);
} else if (collation.supportsLowercaseEquality) {
return String.format(expr + "Lowercase(%s, %s)", source, dict);
} else {
} else {
return String.format(expr + "ICU(%s, %s, %d)", source, dict, collationId);
}
}
public static UTF8String execBinary(final UTF8String source, Map<String, String> dict) {
return source.translate(dict);
}
public static UTF8String execLowercase(final UTF8String source, Map<String, String> dict) {
String srcStr = source.toString();
StringBuilder sb = new StringBuilder();
int charCount = 0;
for (int k = 0; k < srcStr.length(); k += charCount) {
int codePoint = srcStr.codePointAt(k);
charCount = Character.charCount(codePoint);
String subStr = srcStr.substring(k, k + charCount);
String translated = dict.get(subStr.toLowerCase());
if (null == translated) {
sb.append(subStr);
} else if (!"\0".equals(translated)) {
sb.append(translated);
}
}
return UTF8String.fromString(sb.toString());
return CollationAwareUTF8String.lowercaseTranslate(source, dict);
}
public static UTF8String execICU(final UTF8String source, Map<String, String> dict,
final int collationId) {
return source.translate(CollationAwareUTF8String.getCollationAwareDict(
source, dict, collationId));
return CollationAwareUTF8String.translate(source, dict, collationId);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER}
import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, CollationSupport, GenericArrayData, TypeUtils}
import org.apache.spark.sql.catalyst.util.{ArrayData, CollationSupport, GenericArrayData, TypeUtils}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation, StringTypeBinaryLcase}
Expand Down Expand Up @@ -859,13 +859,9 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len:

object StringTranslate {

def buildDict(matchingString: UTF8String, replaceString: UTF8String, collationId: Int)
def buildDict(matchingString: UTF8String, replaceString: UTF8String)
: JMap[String, String] = {
val matching = if (CollationFactory.fetchCollation(collationId).supportsLowercaseEquality) {
matchingString.toString().toLowerCase()
} else {
matchingString.toString()
}
val matching = matchingString.toString()

val replace = replaceString.toString()
val dict = new HashMap[String, String]()
Expand Down Expand Up @@ -923,7 +919,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac
if (matchingEval != lastMatching || replaceEval != lastReplace) {
lastMatching = matchingEval.asInstanceOf[UTF8String].clone()
lastReplace = replaceEval.asInstanceOf[UTF8String].clone()
dict = StringTranslate.buildDict(lastMatching, lastReplace, collationId)
dict = StringTranslate.buildDict(lastMatching, lastReplace)
}

CollationSupport.StringTranslate.exec(srcEval.asInstanceOf[UTF8String], dict, collationId)
Expand All @@ -947,7 +943,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac
$termLastMatching = $matching.clone();
$termLastReplace = $replace.clone();
$termDict = org.apache.spark.sql.catalyst.expressions.StringTranslate
.buildDict($termLastMatching, $termLastReplace, $collationId);
.buildDict($termLastMatching, $termLastReplace);
}
${ev.value} = CollationSupport.StringTranslate.
exec($src, $termDict, $collationId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,8 @@ class CollationStringExpressionsSuite
}
assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
}
test("TRANSLATE check result on explicitly collated string") {

test("Support StringTranslate string expression with collation") {
// Supported collations
case class TranslateTestCase[R](input: String, matchExpression: String,
replaceExpression: String, collation: String, result: R)
Expand All @@ -260,33 +261,27 @@ class CollationStringExpressionsSuite
TranslateTestCase("TRanslate", "rnlt", "XxXx", "UTF8_BINARY_LCASE", "xXaxsXaxe"),
TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UTF8_BINARY_LCASE", "xxaxsXaxex"),
TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UTF8_BINARY_LCASE", "xXaxsXaxeX"),
// scalastyle:off
TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UTF8_BINARY_LCASE", "test大千世AB大千世A"),
TranslateTestCase("大千世界test大千世界", "TEST", "abcd", "UTF8_BINARY_LCASE", "大千世界abca大千世界"),
TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UTF8_BINARY_LCASE", "oeso大千世界大千世界"),
TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UTF8_BINARY_LCASE", "大千世界大千世界OesO"),
TranslateTestCase("大千世界大千世界tesT", "大千", "世世", "UTF8_BINARY_LCASE", "世世世界世世世界tesT"),
// scalastyle:on
TranslateTestCase("Translate", "Rnlt", "1234", "UNICODE", "Tra2s3a4e"),
TranslateTestCase("TRanslate", "rnlt", "XxXx", "UNICODE", "TRaxsXaxe"),
TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UNICODE", "TxaxsXaxeX"),
TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UNICODE", "TXaxsXaxex"),
// scalastyle:off
TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UNICODE", "test大千世AX大千世A"),
TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UNICODE", "Oeso大千世界大千世界"),
TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UNICODE", "大千世界大千世界oesO"),
// scalastyle:on
TranslateTestCase("Translate", "Rnlt", "1234", "UNICODE_CI", "41a2s3a4e"),
TranslateTestCase("TRanslate", "rnlt", "XxXx", "UNICODE_CI", "xXaxsXaxe"),
TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UNICODE_CI", "xxaxsXaxex"),
TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UNICODE_CI", "xXaxsXaxeX"),
// scalastyle:off
TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UNICODE_CI", "test大千世AB大千世A"),
TranslateTestCase("大千世界test大千世界", "TEST", "abcd", "UNICODE_CI", "大千世界abca大千世界"),
TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UNICODE_CI", "oeso大千世界大千世界"),
TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UNICODE_CI", "大千世界大千世界OesO"),
TranslateTestCase("大千世界大千世界tesT", "大千", "世世", "UNICODE_CI", "世世世界世世世界tesT"),
// scalastyle:on
TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UTF8_BINARY_LCASE", "14234e"),
TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UNICODE_CI", "14234e"),
TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UNICODE", "Tr4234e"),
Expand All @@ -298,7 +293,22 @@ class CollationStringExpressionsSuite
TranslateTestCase("abcdef", "abcde", "123", "UTF8_BINARY", "123f"),
TranslateTestCase("abcdef", "abcde", "123", "UTF8_BINARY_LCASE", "123f"),
TranslateTestCase("abcdef", "abcde", "123", "UNICODE", "123f"),
TranslateTestCase("abcdef", "abcde", "123", "UNICODE_CI", "123f")
TranslateTestCase("abcdef", "abcde", "123", "UNICODE_CI", "123f"),
// Case mapping edge cases
TranslateTestCase("İ", "i\u0307", "xy", "UTF8_BINARY", "İ"),
TranslateTestCase("İi\u0307", "İi\u0307", "123", "UTF8_BINARY", "123"),
TranslateTestCase("İi\u0307", "İyz", "123", "UTF8_BINARY", "1i\u0307"),
TranslateTestCase("İi\u0307", "xi\u0307", "123", "UTF8_BINARY", "İ23"),
TranslateTestCase("İ", "i\u0307", "xy", "UTF8_BINARY_LCASE", "İ"),
TranslateTestCase("İi\u0307", "İi\u0307", "123", "UTF8_BINARY_LCASE", "123"),
TranslateTestCase("İi\u0307", "İyz", "123", "UTF8_BINARY_LCASE", "1i\u0307"),
TranslateTestCase("İi\u0307", "xi\u0307", "123", "UTF8_BINARY_LCASE", "İ23"),
TranslateTestCase("İi\u0307", "İi\u0307", "123", "UNICODE", "123"),
TranslateTestCase("İi\u0307", "İyz", "123", "UNICODE", "1i\u0307"),
TranslateTestCase("İi\u0307", "xi\u0307", "123", "UNICODE", "İ23"),
TranslateTestCase("İi\u0307", "İi\u0307", "123", "UNICODE_CI", "123"),
TranslateTestCase("İi\u0307", "İyz", "123", "UNICODE_CI", "1i\u0307"),
TranslateTestCase("İi\u0307", "xi\u0307", "123", "UNICODE_CI", "İ23")
)

testCases.foreach(t => {
Expand Down