Skip to content

Commit e5e9815

Browse files
authored
[DynamoDB] Add multi-attribute composite key support for GSIs (#9710)
* [DynamoDB] Add multi-attribute composite key support for GSIs DynamoDB announced support for multi-attribute composite keys in GSIs (Nov 2025), allowing up to 4 attributes each for partition and sort keys. This adds support for creating and querying GSIs with multiple range keys, including validation of key ordering rules, composite key sorting, sparse GSI exclusion, pagination, and both KeyConditionExpression and legacy KeyConditions APIs. * Fix ruff formatting
1 parent c6bece7 commit e5e9815

7 files changed

Lines changed: 838 additions & 142 deletions

File tree

moto/dynamodb/models/__init__.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ def get_item(
355355
def query(
356356
self,
357357
table_name: str,
358-
hash_key_dict: dict[str, Any],
358+
hash_key_dict: Optional[dict[str, Any]],
359359
range_comparison: Optional[str],
360360
range_value_dicts: list[dict[str, Any]],
361361
limit: int,
@@ -367,13 +367,26 @@ def query(
367367
expr_names: Optional[dict[str, str]] = None,
368368
expr_values: Optional[dict[str, dict[str, str]]] = None,
369369
filter_expression: Optional[str] = None,
370+
hash_key_conditions: Optional[list[tuple[str, dict[str, Any]]]] = None,
371+
range_key_conditions: Optional[
372+
list[tuple[str, str, list[dict[str, Any]]]]
373+
] = None,
370374
**filter_kwargs: Any,
371375
) -> tuple[list[Item], int, Optional[dict[str, Any]]]:
372376
table = self.get_table(table_name)
373377

374-
hash_key = DynamoType(hash_key_dict)
378+
hash_key = DynamoType(hash_key_dict) if hash_key_dict else None
375379
range_values = [DynamoType(range_value) for range_value in range_value_dicts]
376380

381+
# Convert key conditions to DynamoType
382+
hash_key_conditions_typed = [
383+
(name, DynamoType(value)) for name, value in (hash_key_conditions or [])
384+
]
385+
range_key_conditions_typed = [
386+
(name, comparison, [DynamoType(v) for v in values])
387+
for name, comparison, values in (range_key_conditions or [])
388+
]
389+
377390
filter_expression_op = get_filter_expression(
378391
filter_expression, expr_names, expr_values
379392
)
@@ -389,6 +402,8 @@ def query(
389402
index_name,
390403
consistent_read,
391404
filter_expression_op,
405+
hash_key_conditions=hash_key_conditions_typed,
406+
range_key_conditions=range_key_conditions_typed,
392407
**filter_kwargs,
393408
)
394409

moto/dynamodb/models/table.py

Lines changed: 125 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,7 @@ def delete_item(
748748

749749
def query(
750750
self,
751-
hash_key: DynamoType,
751+
hash_key: Optional[DynamoType],
752752
range_comparison: Optional[str],
753753
range_objs: list[DynamoType],
754754
limit: int,
@@ -758,9 +758,19 @@ def query(
758758
index_name: Optional[str] = None,
759759
consistent_read: bool = False,
760760
filter_expression: Any = None,
761+
hash_key_conditions: Optional[list[tuple[str, DynamoType]]] = None,
762+
range_key_conditions: Optional[list[tuple[str, str, list[DynamoType]]]] = None,
761763
**filter_kwargs: Any,
762764
) -> tuple[list[Item], int, Optional[dict[str, Any]]]:
763765
# FIND POSSIBLE RESULTS
766+
# Initialize variables for range key handling
767+
index_range_key: Optional[dict[str, str]] = None
768+
last_range_key_name: Optional[str] = None
769+
770+
# Extract last_range_key_name from range_key_conditions if present
771+
if range_key_conditions:
772+
last_range_key_name = range_key_conditions[-1][0]
773+
764774
if index_name:
765775
all_indexes = self.all_indexes()
766776
indexes_by_name = {i.name: i for i in all_indexes}
@@ -777,20 +787,19 @@ def query(
777787
"Consistent reads are not supported on global secondary indexes"
778788
)
779789

780-
try:
781-
index_hash_key = [
782-
key for key in index.schema if key["KeyType"] == "HASH"
783-
][0]
784-
except IndexError:
790+
# Get ALL hash keys from schema (multi-attribute support)
791+
index_hash_keys = [key for key in index.schema if key["KeyType"] == "HASH"]
792+
if not index_hash_keys:
785793
raise MockValidationException(
786794
f"Missing Hash Key. KeySchema: {index.name}"
787795
)
788796

789-
try:
790-
index_range_key = [
791-
key for key in index.schema if key["KeyType"] == "RANGE"
792-
][0]
793-
except IndexError:
797+
# Get ALL range keys from schema (multi-attribute support)
798+
index_range_keys = [
799+
key for key in index.schema if key["KeyType"] == "RANGE"
800+
]
801+
802+
if not index_range_keys:
794803
if isinstance(index, GlobalSecondaryIndex) and self.range_key_attr:
795804
# If we're querying a GSI that does not have a range key, the main range key acts as a range key
796805
index_range_key = {"AttributeName": self.range_key_attr}
@@ -801,28 +810,81 @@ def query(
801810
raise ValueError(
802811
f"Range Key comparison but no range key found for index: {index_name}"
803812
)
813+
else:
814+
# For backward compatibility with single range key
815+
index_range_key = index_range_keys[0]
804816

805-
hash_attrs = [index_hash_key["AttributeName"], self.hash_key_attr]
806-
if index_range_key:
807-
range_attrs = [
808-
index_range_key["AttributeName"],
809-
self.range_key_attr,
817+
# Build hash_attrs for sorting: all index hash keys + table hash key
818+
hash_attrs = [k["AttributeName"] for k in index_hash_keys] + [
819+
self.hash_key_attr
820+
]
821+
# Build range_attrs for sorting: all index range keys + table range key
822+
# Note: For backward compatibility with _generate_attr_to_sort_by, we always
823+
# include table range key (even if None) when there's only one GSI range key
824+
if index_range_keys:
825+
range_attrs: list[Optional[str]] = [
826+
k["AttributeName"] for k in index_range_keys
810827
]
828+
# Always append table range key for backward compatibility with sorting
829+
range_attrs.append(self.range_key_attr)
830+
elif index_range_key:
831+
range_attrs = [index_range_key["AttributeName"], self.range_key_attr]
811832
else:
812833
range_attrs = [self.range_key_attr]
813834

835+
# Build a dict of all hash key conditions from new interface
836+
all_hash_conditions: dict[str, DynamoType] = {}
837+
if hash_key_conditions:
838+
for attr_name, value in hash_key_conditions:
839+
all_hash_conditions[attr_name] = value
840+
elif hash_key:
841+
# Backward compatibility: use legacy hash_key parameter
842+
all_hash_conditions[index_hash_keys[0]["AttributeName"]] = hash_key
843+
844+
# Build a dict of range key equalities (all but last in range_key_conditions)
845+
range_equality_conditions: dict[str, DynamoType] = {}
846+
if range_key_conditions and len(range_key_conditions) > 1:
847+
# All but the last range key condition are equalities
848+
for attr_name, _comparison, values in range_key_conditions[:-1]:
849+
range_equality_conditions[attr_name] = values[0]
850+
814851
possible_results = []
815852
for item in self.all_items():
816853
if not isinstance(item, Item):
817854
continue
818-
item_hash_key = item.attrs.get(hash_attrs[0])
819-
if len(range_attrs) == 1:
820-
if item_hash_key and item_hash_key == hash_key:
821-
possible_results.append(item)
822-
else:
823-
item_range_key = item.attrs.get(range_attrs[0]) # type: ignore
824-
if item_hash_key and item_hash_key == hash_key and item_range_key:
825-
possible_results.append(item)
855+
856+
# Check ALL hash key conditions
857+
hash_match = True
858+
for attr_name, expected_value in all_hash_conditions.items():
859+
item_value = item.attrs.get(attr_name)
860+
if not item_value or item_value != expected_value:
861+
hash_match = False
862+
break
863+
864+
if not hash_match:
865+
continue
866+
867+
# Check range key equality conditions (for multi-attribute range keys)
868+
range_equality_match = True
869+
for attr_name, expected_value in range_equality_conditions.items():
870+
item_value = item.attrs.get(attr_name)
871+
if not item_value or item_value != expected_value:
872+
range_equality_match = False
873+
break
874+
875+
if not range_equality_match:
876+
continue
877+
878+
# For GSI, ensure item has ALL range key attributes (DynamoDB only indexes
879+
# items that have all key attributes present)
880+
if index_range_keys:
881+
has_all_range_keys = all(
882+
item.attrs.get(key["AttributeName"]) for key in index_range_keys
883+
)
884+
if not has_all_range_keys:
885+
continue
886+
887+
possible_results.append(item)
826888
else:
827889
hash_attrs = [self.hash_key_attr]
828890
range_attrs = [self.range_key_attr]
@@ -882,18 +944,26 @@ def query(
882944
scanned_count += 1
883945

884946
if range_comparison:
885-
if (
886-
index_name
887-
and index_range_key
888-
and result.attrs.get(index_range_key["AttributeName"])
947+
# Determine which attribute to apply the range comparison to
948+
range_attr_for_comparison: Optional[str] = None
949+
if last_range_key_name:
950+
# Multi-attribute key: use the specific range key from the query
951+
range_attr_for_comparison = last_range_key_name
952+
elif index_name and index_range_key:
953+
# Single range key GSI: use the index range key
954+
range_attr_for_comparison = index_range_key["AttributeName"]
955+
956+
if range_attr_for_comparison and result.attrs.get(
957+
range_attr_for_comparison
889958
):
890-
if result.attrs.get(index_range_key["AttributeName"]).compare( # type: ignore
959+
if result.attrs.get(range_attr_for_comparison).compare( # type: ignore
891960
range_comparison, range_objs
892961
):
893962
results.append(result)
894963
result_size += result.size()
895964
scanned_count += 1
896-
else:
965+
elif not index_name:
966+
# Table query (not GSI): use the table's range key
897967
if result.range_key.compare(range_comparison, range_objs): # type: ignore[union-attr]
898968
results.append(result)
899969
result_size += result.size()
@@ -1137,21 +1207,31 @@ def sorted_items(
11371207
def _generate_attr_to_sort_by(
11381208
self, hash_key_attrs: list[str], range_key_attrs: list[Optional[str]]
11391209
) -> list[str]:
1140-
gsi_hash_key = hash_key_attrs[0] if len(hash_key_attrs) == 2 else None
1141-
table_hash_key = str(
1142-
hash_key_attrs[0] if gsi_hash_key is None else hash_key_attrs[1]
1143-
)
1144-
gsi_range_key = range_key_attrs[0] if len(range_key_attrs) == 2 else None
1145-
table_range_key = str(
1146-
range_key_attrs[0] if gsi_range_key is None else range_key_attrs[1]
1147-
)
1148-
# Gets the GSI and table hash and range keys in the order to try sorting by
1149-
attrs_to_sort_by = [
1150-
gsi_hash_key,
1151-
gsi_range_key,
1152-
table_hash_key,
1153-
table_range_key,
1154-
]
1210+
# For GSI queries, hash_key_attrs = [gsi_hash_keys..., table_hash_key]
1211+
# and range_key_attrs = [gsi_range_keys..., table_range_key]
1212+
# For table queries, hash_key_attrs = [table_hash_key]
1213+
# and range_key_attrs = [table_range_key]
1214+
1215+
# Extract GSI keys (all but last) and table keys (last)
1216+
if len(hash_key_attrs) > 1:
1217+
# GSI query
1218+
gsi_hash_keys = hash_key_attrs[:-1]
1219+
table_hash_key = hash_key_attrs[-1]
1220+
gsi_range_keys = [k for k in range_key_attrs[:-1] if k is not None]
1221+
table_range_key = range_key_attrs[-1]
1222+
else:
1223+
# Table query
1224+
gsi_hash_keys = []
1225+
table_hash_key = hash_key_attrs[0]
1226+
gsi_range_keys = []
1227+
table_range_key = range_key_attrs[0] if range_key_attrs else None
1228+
1229+
# Sort order: GSI hash keys, GSI range keys, table hash key, table range key
1230+
attrs_to_sort_by: list[Optional[str]] = []
1231+
attrs_to_sort_by.extend(gsi_hash_keys)
1232+
attrs_to_sort_by.extend(gsi_range_keys)
1233+
attrs_to_sort_by.append(table_hash_key)
1234+
attrs_to_sort_by.append(table_range_key)
11551235
return [
11561236
attr for attr in attrs_to_sort_by if attr is not None and attr != "None"
11571237
]

0 commit comments

Comments
 (0)