diff --git a/mypy/checker.py b/mypy/checker.py index 0ccd6cd2b409..4cb284d18d9a 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -26,7 +26,7 @@ ARG_POS, ARG_STAR, LITERAL_TYPE, LDEF, MDEF, GDEF, CONTRAVARIANT, COVARIANT, INVARIANT, TypeVarExpr, AssignmentExpr, is_final_node, - ARG_NAMED) + ARG_NAMED, BinOp) from mypy import nodes from mypy import operators from mypy.literals import literal, literal_hash, Key @@ -2076,7 +2076,7 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None: self.fail(message_registry.DEPENDENT_FINAL_IN_CLASS_BODY, s) def check_type_alias_rvalue(self, s: AssignmentStmt) -> None: - if not (self.is_stub and isinstance(s.rvalue, OpExpr) and s.rvalue.op == '|'): + if not (self.is_stub and isinstance(s.rvalue, OpExpr) and s.rvalue.op == BinOp.BitOr): # We do this mostly for compatibility with old semantic analyzer. # TODO: should we get rid of this? alias_type = self.expr_checker.accept(s.rvalue) @@ -2086,7 +2086,7 @@ def check_type_alias_rvalue(self, s: AssignmentStmt) -> None: alias_type = AnyType(TypeOfAny.special_form) def accept_items(e: Expression) -> None: - if isinstance(e, OpExpr) and e.op == '|': + if isinstance(e, OpExpr) and e.op == BinOp.BitOr: accept_items(e.left) accept_items(e.right) else: @@ -4491,7 +4491,7 @@ def has_no_custom_eq_checks(t: Type) -> bool: (None if if_assignment_map is None or if_condition_map is None else if_map), (None if else_assignment_map is None or else_condition_map is None else else_map), ) - elif isinstance(node, OpExpr) and node.op == 'and': + elif isinstance(node, OpExpr) and node.op == BinOp.And: left_if_vars, left_else_vars = self.find_isinstance_check(node.left) right_if_vars, right_else_vars = self.find_isinstance_check(node.right) @@ -4499,7 +4499,7 @@ def has_no_custom_eq_checks(t: Type) -> bool: # and false if at least one of e1 and e2 is false. return (and_conditional_maps(left_if_vars, right_if_vars), or_conditional_maps(left_else_vars, right_else_vars)) - elif isinstance(node, OpExpr) and node.op == 'or': + elif isinstance(node, OpExpr) and node.op == BinOp.Or: left_if_vars, left_else_vars = self.find_isinstance_check(node.left) right_if_vars, right_else_vars = self.find_isinstance_check(node.right) @@ -5563,7 +5563,7 @@ def flatten_types(t: Type) -> List[Type]: def get_isinstance_type(expr: Expression, type_map: Dict[Expression, Type]) -> Optional[List[TypeRange]]: - if isinstance(expr, OpExpr) and expr.op == '|': + if isinstance(expr, OpExpr) and expr.op == BinOp.BitOr: left = get_isinstance_type(expr.left, type_map) right = get_isinstance_type(expr.right, type_map) if left is None or right is None: @@ -5762,7 +5762,7 @@ def is_same_arg_prefix(t: CallableType, s: CallableType) -> bool: ignore_pos_arg_names=True) -def infer_operator_assignment_method(typ: Type, operator: str) -> Tuple[bool, str]: +def infer_operator_assignment_method(typ: Type, operator: BinOp) -> Tuple[bool, str]: """Determine if operator assignment on given value type is in-place, and the method name. For example, if operator is '+', return (True, '__iadd__') or (False, '__add__') diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 280e9a35d537..cb71e8c2fd0d 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -32,7 +32,7 @@ DictionaryComprehension, ComplexExpr, EllipsisExpr, StarExpr, AwaitExpr, YieldExpr, YieldFromExpr, TypedDictExpr, PromoteExpr, NewTypeExpr, NamedTupleExpr, TypeVarExpr, TypeAliasExpr, BackquoteExpr, EnumCallExpr, TypeAlias, SymbolNode, PlaceholderNode, - ParamSpecExpr, + ParamSpecExpr, BinOp, ArgKind, ARG_POS, ARG_NAMED, ARG_STAR, ARG_STAR2, LITERAL_TYPE, REVEAL_TYPE, ) from mypy.literals import literal @@ -2156,12 +2156,12 @@ def visit_ellipsis(self, e: EllipsisExpr) -> Type: def visit_op_expr(self, e: OpExpr) -> Type: """Type check a binary operator expression.""" - if e.op == 'and' or e.op == 'or': + if e.op.is_boolean(): return self.check_boolean_op(e, e) - if e.op == '*' and isinstance(e.left, ListExpr): + if e.op == BinOp.Mul and isinstance(e.left, ListExpr): # Expressions of form [...] * e get special type inference. return self.check_list_multiply(e) - if e.op == '%': + if e.op == BinOp.Mod: pyversion = self.chk.options.python_version if pyversion[0] == 3: if isinstance(e.left, BytesExpr) and pyversion[1] >= 5: @@ -2174,7 +2174,7 @@ def visit_op_expr(self, e: OpExpr) -> Type: left_type = self.accept(e.left) proper_left_type = get_proper_type(left_type) - if isinstance(proper_left_type, TupleType) and e.op == '+': + if isinstance(proper_left_type, TupleType) and e.op == BinOp.Add: left_add_method = proper_left_type.partial_fallback.type.get('__add__') if left_add_method and left_add_method.fullname == 'builtins.tuple.__add__': proper_right_type = get_proper_type(self.accept(e.right)) @@ -2371,8 +2371,8 @@ def dangerous_comparison(self, left: Type, right: Type, return False return not is_overlapping_types(left, right, ignore_promotions=False) - def get_operator_method(self, op: str) -> str: - if op == '/' and self.chk.options.python_version[0] == 2: + def get_operator_method(self, op: BinOp) -> str: + if op == BinOp.Div and self.chk.options.python_version[0] == 2: # TODO also check for "from __future__ import division" return '__div__' else: @@ -2788,15 +2788,15 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type: ctx = self.type_context[-1] left_type = self.accept(e.left, ctx) - assert e.op in ('and', 'or') # Checked by visit_op_expr + assert e.op.is_boolean() # Checked by visit_op_expr if e.right_always: left_map, right_map = None, {} # type: mypy.checker.TypeMap, mypy.checker.TypeMap elif e.right_unreachable: left_map, right_map = {}, None - elif e.op == 'and': + elif e.op == BinOp.And: right_map, left_map = self.chk.find_isinstance_check(e.left) - elif e.op == 'or': + elif e.op == BinOp.Or: left_map, right_map = self.chk.find_isinstance_check(e.left) # If left_map is None then we know mypy considers the left expression @@ -2832,10 +2832,10 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type: assert right_map is not None # find_isinstance_check guarantees this return right_type - if e.op == 'and': + if e.op == BinOp.And: restricted_left_type = false_only(left_type) result_is_left = not left_type.can_be_true - elif e.op == 'or': + elif e.op == BinOp.Or: restricted_left_type = true_only(left_type) result_is_left = not left_type.can_be_false diff --git a/mypy/exprtotype.py b/mypy/exprtotype.py index 8f6f6c11f346..8ff576577f5f 100644 --- a/mypy/exprtotype.py +++ b/mypy/exprtotype.py @@ -13,6 +13,7 @@ RawExpressionType, ProperType, UnionType ) from mypy.options import Options +from mypy.operators import BinOp class TypeTranslationError(Exception): @@ -86,7 +87,7 @@ def expr_to_unanalyzed_type(expr: Expression, else: raise TypeTranslationError() elif (isinstance(expr, OpExpr) - and expr.op == '|' + and expr.op == BinOp.BitOr and ((options and options.python_version >= (3, 10)) or allow_new_syntax)): return UnionType([expr_to_unanalyzed_type(expr.left, options, allow_new_syntax), expr_to_unanalyzed_type(expr.right, options, allow_new_syntax)]) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index 2f4122bf3bfa..aae22691457e 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -38,6 +38,7 @@ from mypy import message_registry, errorcodes as codes from mypy.errors import Errors from mypy.options import Options +from mypy.operators import BinOp from mypy.reachability import mark_block_unreachable try: @@ -402,12 +403,12 @@ def translate_type_comment(self, ast3.FloorDiv: '//' } - def from_operator(self, op: ast3.operator) -> str: + def from_operator(self, op: ast3.operator) -> BinOp: op_name = ASTConverter.op_map.get(type(op)) if op_name is None: raise RuntimeError('Unknown operator ' + str(type(op))) else: - return op_name + return BinOp(op_name) comp_op_map: Final[Dict[typing.Type[AST], str]] = { ast3.Gt: '>', @@ -422,12 +423,12 @@ def from_operator(self, op: ast3.operator) -> str: ast3.NotIn: 'not in' } - def from_comp_operator(self, op: ast3.cmpop) -> str: + def from_comp_operator(self, op: ast3.cmpop) -> BinOp: op_name = ASTConverter.comp_op_map.get(type(op)) if op_name is None: raise RuntimeError('Unknown comparison operator ' + str(type(op))) else: - return op_name + return BinOp(op_name) def as_block(self, stmts: List[ast3.stmt], lineno: int) -> Optional[Block]: b = None @@ -966,9 +967,9 @@ def visit_BoolOp(self, n: ast3.BoolOp) -> OpExpr: def group(self, op: str, vals: List[Expression], n: ast3.expr) -> OpExpr: if len(vals) == 2: - e = OpExpr(op, vals[0], vals[1]) + e = OpExpr(BinOp(op), vals[0], vals[1]) else: - e = OpExpr(op, vals[0], self.group(op, vals[1:], n)) + e = OpExpr(BinOp(op), vals[0], self.group(op, vals[1:], n)) return self.set_line(e, n) # BinOp(expr left, operator op, expr right) @@ -978,7 +979,7 @@ def visit_BinOp(self, n: ast3.BinOp) -> OpExpr: if op is None: raise RuntimeError('cannot translate BinOp ' + str(type(n.op))) - e = OpExpr(op, self.visit(n.left), self.visit(n.right)) + e = OpExpr(BinOp(op), self.visit(n.left), self.visit(n.right)) return self.set_line(e, n) # UnaryOp(unaryop op, expr operand) diff --git a/mypy/fastparse2.py b/mypy/fastparse2.py index 2d288bf158e5..35155a1bdcfb 100644 --- a/mypy/fastparse2.py +++ b/mypy/fastparse2.py @@ -51,6 +51,7 @@ TYPE_IGNORE_PATTERN, INVALID_TYPE_IGNORE ) from mypy.options import Options +from mypy.operators import BinOp from mypy.reachability import mark_block_unreachable try: @@ -258,14 +259,14 @@ def translate_type_comment(self, n: ast27.stmt, ast27.FloorDiv: '//' } - def from_operator(self, op: ast27.operator) -> str: + def from_operator(self, op: ast27.operator) -> BinOp: op_name = ASTConverter.op_map.get(type(op)) if op_name is None: raise RuntimeError('Unknown operator ' + str(type(op))) elif op_name == '@': raise RuntimeError('mypy does not support the MatMult operator') else: - return op_name + return BinOp(op_name) comp_op_map: Final[Dict[typing.Type[AST], str]] = { ast27.Gt: '>', @@ -280,12 +281,12 @@ def from_operator(self, op: ast27.operator) -> str: ast27.NotIn: 'not in' } - def from_comp_operator(self, op: ast27.cmpop) -> str: + def from_comp_operator(self, op: ast27.cmpop) -> BinOp: op_name = ASTConverter.comp_op_map.get(type(op)) if op_name is None: raise RuntimeError('Unknown comparison operator ' + str(type(op))) else: - return op_name + return BinOp(op_name) def as_block(self, stmts: List[ast27.stmt], lineno: int) -> Optional[Block]: b = None @@ -813,9 +814,9 @@ def visit_BoolOp(self, n: ast27.BoolOp) -> OpExpr: def group(self, vals: List[Expression], op: str) -> OpExpr: if len(vals) == 2: - return OpExpr(op, vals[0], vals[1]) + return OpExpr(BinOp(op), vals[0], vals[1]) else: - return OpExpr(op, vals[0], self.group(vals[1:], op)) + return OpExpr(BinOp(op), vals[0], self.group(vals[1:], op)) # BinOp(expr left, operator op, expr right) def visit_BinOp(self, n: ast27.BinOp) -> OpExpr: @@ -824,7 +825,7 @@ def visit_BinOp(self, n: ast27.BinOp) -> OpExpr: if op is None: raise RuntimeError('cannot translate BinOp ' + str(type(n.op))) - e = OpExpr(op, self.visit(n.left), self.visit(n.right)) + e = OpExpr(BinOp(op), self.visit(n.left), self.visit(n.right)) return self.set_line(e, n) # UnaryOp(unaryop op, expr operand) diff --git a/mypy/messages.py b/mypy/messages.py index 086f3d22aee1..6f9660cf8a09 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -33,7 +33,7 @@ ReturnStmt, NameExpr, Var, CONTRAVARIANT, COVARIANT, SymbolNode, CallExpr, IndexExpr, StrExpr, SymbolTable, TempNode, SYMBOL_FUNCBASE_TYPES ) -from mypy.operators import op_methods, op_methods_to_symbols +from mypy.operators import BinOp, op_methods, op_methods_to_symbols from mypy.subtypes import ( is_subtype, find_member, get_member_flags, IS_SETTABLE, IS_CLASSVAR, IS_CLASS_OR_STATIC, @@ -437,10 +437,10 @@ def incompatible_argument(self, base = extract_type(name) for method, op in op_methods_to_symbols.items(): - for variant in method, '__r' + method[2:]: + for variant in (method, '__r' + method[2:]): # FIX: do not rely on textual formatting if name.startswith('"{}" of'.format(variant)): - if op == 'in' or variant != method: + if op == BinOp.In or variant != method: # Reversed order of base/argument. self.unsupported_operand_types(op, arg_type, base, context, code=codes.OPERATOR) diff --git a/mypy/nodes.py b/mypy/nodes.py index 435ffa9293cb..1f926b987b69 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -14,7 +14,7 @@ import mypy.strconv from mypy.util import short_type from mypy.visitor import NodeVisitor, StatementVisitor, ExpressionVisitor - +from mypy.operators import BinOp from mypy.bogus_type import Bogus @@ -1134,11 +1134,11 @@ class OperatorAssignmentStmt(Statement): __slots__ = ('op', 'lvalue', 'rvalue') - op: str # TODO: Enum? + op: BinOp lvalue: Lvalue rvalue: Expression - def __init__(self, op: str, lvalue: Lvalue, rvalue: Expression) -> None: + def __init__(self, op: BinOp, lvalue: Lvalue, rvalue: Expression) -> None: super().__init__() self.op = op self.lvalue = lvalue @@ -1796,7 +1796,7 @@ class OpExpr(Expression): __slots__ = ('op', 'left', 'right', 'method_type', 'right_always', 'right_unreachable') - op: str # TODO: Enum? + op: BinOp left: Expression right: Expression # Inferred type for the operator method type (when relevant). @@ -1806,7 +1806,7 @@ class OpExpr(Expression): # Per static analysis only: Is the right side unreachable? right_unreachable: bool - def __init__(self, op: str, left: Expression, right: Expression) -> None: + def __init__(self, op: BinOp, left: Expression, right: Expression) -> None: super().__init__() self.op = op self.left = left @@ -1824,18 +1824,18 @@ class ComparisonExpr(Expression): __slots__ = ('operators', 'operands', 'method_types') - operators: List[str] + operators: List[BinOp] operands: List[Expression] # Inferred type for the operator methods (when relevant; None for 'is'). method_types: List[Optional["mypy.types.Type"]] - def __init__(self, operators: List[str], operands: List[Expression]) -> None: + def __init__(self, operators: List[BinOp], operands: List[Expression]) -> None: super().__init__() self.operators = operators self.operands = operands self.method_types = [] - def pairwise(self) -> Iterator[Tuple[str, Expression, Expression]]: + def pairwise(self) -> Iterator[Tuple[BinOp, Expression, Expression]]: """If this comparison expr is "a < b is c == d", yields the sequence ("<", a, b), ("is", b, c), ("==", c, d) """ diff --git a/mypy/operators.py b/mypy/operators.py index aa26cb2ec6e9..ef0cecde6098 100644 --- a/mypy/operators.py +++ b/mypy/operators.py @@ -1,34 +1,104 @@ """Information about Python operators""" +import enum from typing_extensions import Final +@enum.unique +class BinOp(str, enum.Enum): + """Represents all possible operators in Python. + + Copies the same names as ``ast`` module does. + + Note, that some operators cannot be used in some Python versions. + For example, ``@`` does not exist in Python2. + """ + + # boolops: + And = 'and' + Or = 'or' + + # operators: + Add = '+' + BitAnd = '&' + BitOr = '|' + BitXor = '^' + Div = '/' + DivMod = 'divmod' + FloorDiv = '//' + LShift = '<<' + Mod = '%' + Mul = '*' + MatMult = '@' + Pow = '**' + RShift = '>>' + Sub = '-' + + # cmpops: + Eq = '==' + Gt = '>' + GtE = '>=' + In = 'in' + Is = 'is' + IsNot = 'is not' + Lt = '<' + LtE = '<=' + NotEq = '!=' + NotIn = 'not in' + + def is_numeric_compare(self) -> bool: + return self in {BinOp.Eq, BinOp.NotEq, BinOp.LtE, BinOp.Lt, BinOp.GtE, BinOp.Gt} + + def is_boolean(self) -> bool: + return self in {BinOp.And, BinOp.Or} + + def is_equality(self) -> bool: + return self in {BinOp.Eq, BinOp.NotEq} + + def is_contains(self) -> bool: + return self in {BinOp.In, BinOp.NotIn} + + def is_identity(self) -> bool: + return self in {BinOp.Is, BinOp.IsNot} + + +# Map reverse binary numberic operators, if possible: +reverse_op: Final = { + BinOp.Eq: BinOp.Eq, + BinOp.NotEq: BinOp.NotEq, + BinOp.Lt: BinOp.Gt, + BinOp.Gt: BinOp.Lt, + BinOp.LtE: BinOp.GtE, + BinOp.GtE: BinOp.LtE, +} + + # Map from binary operator id to related method name (in Python 3). op_methods: Final = { - '+': '__add__', - '-': '__sub__', - '*': '__mul__', - '/': '__truediv__', - '%': '__mod__', - 'divmod': '__divmod__', - '//': '__floordiv__', - '**': '__pow__', - '@': '__matmul__', - '&': '__and__', - '|': '__or__', - '^': '__xor__', - '<<': '__lshift__', - '>>': '__rshift__', - '==': '__eq__', - '!=': '__ne__', - '<': '__lt__', - '>=': '__ge__', - '>': '__gt__', - '<=': '__le__', - 'in': '__contains__', + BinOp.Add: '__add__', + BinOp.Sub: '__sub__', + BinOp.Mul: '__mul__', + BinOp.Div: '__truediv__', + BinOp.Mod: '__mod__', + BinOp.DivMod: '__divmod__', + BinOp.FloorDiv: '__floordiv__', + BinOp.Pow: '__pow__', + BinOp.MatMult: '__matmul__', + BinOp.BitAnd: '__and__', + BinOp.BitOr: '__or__', + BinOp.BitXor: '__xor__', + BinOp.LShift: '__lshift__', + BinOp.RShift: '__rshift__', + BinOp.Eq: '__eq__', + BinOp.NotEq: '__ne__', + BinOp.Lt: '__lt__', + BinOp.LtE: '__le__', + BinOp.Gt: '__gt__', + BinOp.GtE: '__ge__', + BinOp.In: '__contains__', } -op_methods_to_symbols: Final = {v: k for (k, v) in op_methods.items()} +op_methods_to_symbols: Final = {v: k.value for (k, v) in op_methods.items()} op_methods_to_symbols['__div__'] = '/' comparison_fallback_method: Final = "__cmp__" @@ -36,19 +106,19 @@ ops_with_inplace_method: Final = { - "+", - "-", - "*", - "/", - "%", - "//", - "**", - "@", - "&", - "|", - "^", - "<<", - ">>", + BinOp.Add, + BinOp.Sub, + BinOp.Mul, + BinOp.Div, + BinOp.Mod, + BinOp.FloorDiv, + BinOp.Pow, + BinOp.MatMult, + BinOp.BitAnd, + BinOp.BitOr, + BinOp.BitXor, + BinOp.LShift, + BinOp.RShift, } inplace_operator_methods: Final = set("__i" + op_methods[op][2:] for op in ops_with_inplace_method) diff --git a/mypy/reachability.py b/mypy/reachability.py index 44a21b993cfc..c68f798b9274 100644 --- a/mypy/reachability.py +++ b/mypy/reachability.py @@ -9,6 +9,7 @@ ImportAll, LITERAL_YES ) from mypy.options import Options +from mypy.operators import BinOp, reverse_op from mypy.traverser import TraverserVisitor from mypy.literals import literal @@ -27,15 +28,6 @@ MYPY_FALSE: MYPY_TRUE, } -reverse_op: Final = { - "==": "==", - "!=": "!=", - "<": ">", - ">": "<", - "<=": ">=", - ">=": "<=", -} - def infer_reachability_of_if_statement(s: IfStmt, options: Options) -> None: for i in range(len(s.expr)): @@ -87,10 +79,10 @@ def infer_condition_value(expr: Expression, options: Options) -> int: name = expr.name elif isinstance(expr, MemberExpr): name = expr.name - elif isinstance(expr, OpExpr) and expr.op in ('and', 'or'): + elif isinstance(expr, OpExpr) and expr.op.is_boolean(): left = infer_condition_value(expr.left, options) - if ((left in (ALWAYS_TRUE, MYPY_TRUE) and expr.op == 'and') or - (left in (ALWAYS_FALSE, MYPY_FALSE) and expr.op == 'or')): + if ((left in (ALWAYS_TRUE, MYPY_TRUE) and expr.op == BinOp.And) or + (left in (ALWAYS_FALSE, MYPY_FALSE) and expr.op == BinOp.Or)): # Either `True and ` or `False or `: the result will # always be the right-hand-side. return infer_condition_value(expr.right, options) @@ -134,7 +126,7 @@ def consider_sys_version_info(expr: Expression, pyversion: Tuple[int, ...]) -> i if len(expr.operators) > 1: return TRUTH_VALUE_UNKNOWN op = expr.operators[0] - if op not in ('==', '!=', '<=', '>=', '<', '>'): + if not op.is_numeric_compare(): return TRUTH_VALUE_UNKNOWN index = contains_sys_version_info(expr.operands[0]) @@ -157,7 +149,7 @@ def consider_sys_version_info(expr: Expression, pyversion: Tuple[int, ...]) -> i hi = 2 if 0 <= lo < hi <= 2: val = pyversion[lo:hi] - if len(val) == len(thing) or len(val) > len(thing) and op not in ('==', '!='): + if len(val) == len(thing) or len(val) > len(thing) and not op.is_equality(): return fixed_comparison(val, op, thing) return TRUTH_VALUE_UNKNOWN @@ -176,7 +168,7 @@ def consider_sys_platform(expr: Expression, platform: str) -> int: if len(expr.operators) > 1: return TRUTH_VALUE_UNKNOWN op = expr.operators[0] - if op not in ('==', '!='): + if not op.is_equality(): return TRUTH_VALUE_UNKNOWN if not is_sys_attr(expr.operands[0], 'platform'): return TRUTH_VALUE_UNKNOWN @@ -204,19 +196,19 @@ def consider_sys_platform(expr: Expression, platform: str) -> int: Targ = TypeVar('Targ', int, str, Tuple[int, ...]) -def fixed_comparison(left: Targ, op: str, right: Targ) -> int: +def fixed_comparison(left: Targ, op: BinOp, right: Targ) -> int: rmap = {False: ALWAYS_FALSE, True: ALWAYS_TRUE} - if op == '==': + if op == BinOp.Eq: return rmap[left == right] - if op == '!=': + if op == BinOp.NotEq: return rmap[left != right] - if op == '<=': + if op == BinOp.LtE: return rmap[left <= right] - if op == '>=': + if op == BinOp.GtE: return rmap[left >= right] - if op == '<': + if op == BinOp.Lt: return rmap[left < right] - if op == '>': + if op == BinOp.Gt: return rmap[left > right] return TRUTH_VALUE_UNKNOWN diff --git a/mypy/semanal.py b/mypy/semanal.py index 49c24cde0447..ab2d85fd4747 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -103,6 +103,7 @@ ) from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError from mypy.options import Options +from mypy.operators import BinOp from mypy.plugin import ( Plugin, ClassDefContext, SemanticAnalyzerPluginInterface, DynamicClassDefContext @@ -2134,7 +2135,7 @@ def can_be_type_alias(self, rv: Expression, allow_none: bool = False) -> bool: if allow_none and isinstance(rv, NameExpr) and rv.fullname == 'builtins.None': return True if (isinstance(rv, OpExpr) - and rv.op == '|' + and rv.op == BinOp.BitOr and self.can_be_type_alias(rv.left, allow_none=True) and self.can_be_type_alias(rv.right, allow_none=True)): return True @@ -3791,7 +3792,7 @@ def visit_call_expr(self, expr: CallExpr) -> None: elif refers_to_fullname(expr.callee, 'builtins.divmod'): if not self.check_fixed_args(expr, 2, 'divmod'): return - expr.analyzed = OpExpr('divmod', expr.args[0], expr.args[1]) + expr.analyzed = OpExpr(BinOp.DivMod, expr.args[0], expr.args[1]) expr.analyzed.line = expr.line expr.analyzed.accept(self) else: @@ -3894,14 +3895,14 @@ def visit_member_expr(self, expr: MemberExpr) -> None: def visit_op_expr(self, expr: OpExpr) -> None: expr.left.accept(self) - if expr.op in ('and', 'or'): + if expr.op.is_boolean(): inferred = infer_condition_value(expr.left, self.options) - if ((inferred in (ALWAYS_FALSE, MYPY_FALSE) and expr.op == 'and') or - (inferred in (ALWAYS_TRUE, MYPY_TRUE) and expr.op == 'or')): + if ((inferred in (ALWAYS_FALSE, MYPY_FALSE) and expr.op == BinOp.And) or + (inferred in (ALWAYS_TRUE, MYPY_TRUE) and expr.op == BinOp.Or)): expr.right_unreachable = True return - elif ((inferred in (ALWAYS_TRUE, MYPY_TRUE) and expr.op == 'and') or - (inferred in (ALWAYS_FALSE, MYPY_FALSE) and expr.op == 'or')): + elif ((inferred in (ALWAYS_TRUE, MYPY_TRUE) and expr.op == BinOp.And) or + (inferred in (ALWAYS_FALSE, MYPY_FALSE) and expr.op == BinOp.Or)): expr.right_always = True expr.right.accept(self) diff --git a/mypy/server/deps.py b/mypy/server/deps.py index f80673fdb7d4..7fffe887fff7 100644 --- a/mypy/server/deps.py +++ b/mypy/server/deps.py @@ -92,7 +92,7 @@ class 'mod.Cls'. This can also refer to an attribute inherited from a LDEF, MDEF, GDEF, TypeAliasExpr, NewTypeExpr, ImportAll, EnumCallExpr, AwaitExpr ) from mypy.operators import ( - op_methods, reverse_op_methods, ops_with_inplace_method, unary_op_methods + BinOp, op_methods, reverse_op_methods, ops_with_inplace_method, unary_op_methods ) from mypy.traverser import TraverserVisitor from mypy.types import ( @@ -709,14 +709,14 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> None: left = e.operands[i] right = e.operands[i + 1] self.process_binary_op(op, left, right) - if self.python2 and op in ('==', '!=', '<', '<=', '>', '>='): + if self.python2 and op.is_numeric_compare(): self.add_operator_method_dependency(left, '__cmp__') self.add_operator_method_dependency(right, '__cmp__') - def process_binary_op(self, op: str, left: Expression, right: Expression) -> None: + def process_binary_op(self, op: BinOp, left: Expression, right: Expression) -> None: method = op_methods.get(op) if method: - if op == 'in': + if op == BinOp.In: self.add_operator_method_dependency(right, method) else: self.add_operator_method_dependency(left, method) diff --git a/mypy/strconv.py b/mypy/strconv.py index c63063af0776..4daba5dbf384 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -214,7 +214,7 @@ def visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt') -> str: return self.dump(a, o) def visit_operator_assignment_stmt(self, o: 'mypy.nodes.OperatorAssignmentStmt') -> str: - return self.dump([o.op, o.lvalue, o.rvalue], o) + return self.dump([o.op.value, o.lvalue, o.rvalue], o) def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> str: a: List[Any] = [o.expr, o.body] @@ -410,10 +410,10 @@ def visit_call_expr(self, o: 'mypy.nodes.CallExpr') -> str: return self.dump(a + extra, o) def visit_op_expr(self, o: 'mypy.nodes.OpExpr') -> str: - return self.dump([o.op, o.left, o.right], o) + return self.dump([o.op.value, o.left, o.right], o) def visit_comparison_expr(self, o: 'mypy.nodes.ComparisonExpr') -> str: - return self.dump([o.operators, o.operands], o) + return self.dump([[op.value for op in o.operators], o.operands], o) def visit_cast_expr(self, o: 'mypy.nodes.CastExpr') -> str: return self.dump([o.expr, o.type], o) diff --git a/mypyc/irbuild/builder.py b/mypyc/irbuild/builder.py index 57baa8dbf574..c21566cfebee 100644 --- a/mypyc/irbuild/builder.py +++ b/mypyc/irbuild/builder.py @@ -30,6 +30,7 @@ from mypy.maptype import map_instance_to_supertype from mypy.visitor import ExpressionVisitor, StatementVisitor from mypy.util import split_target +from mypy.operators import BinOp from mypyc.common import TEMP_ATTR_NAME, SELF_NAME from mypyc.irbuild.prebuildvisitor import PreBuildVisitor @@ -244,7 +245,7 @@ def new_set_op(self, values: List[Value], line: int) -> Value: def translate_is_op(self, lreg: Value, rreg: Value, - expr_op: str, + expr_op: BinOp, line: int) -> Value: return self.builder.translate_is_op(lreg, rreg, expr_op, line) @@ -283,10 +284,10 @@ def call_c(self, desc: CFunctionDescription, args: List[Value], line: int) -> Va def int_op(self, type: RType, lhs: Value, rhs: Value, op: int, line: int) -> Value: return self.builder.int_op(type, lhs, rhs, op, line) - def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value: + def compare_tagged(self, lhs: Value, rhs: Value, op: BinOp, line: int) -> Value: return self.builder.compare_tagged(lhs, rhs, op, line) - def compare_tuples(self, lhs: Value, rhs: Value, op: str, line: int) -> Value: + def compare_tuples(self, lhs: Value, rhs: Value, op: BinOp, line: int) -> Value: return self.builder.compare_tuples(lhs, rhs, op, line) def builtin_len(self, val: Value, line: int) -> Value: @@ -339,7 +340,7 @@ def check_if_module_loaded(self, id: str, line: int, needs_import: the BasicBlock that is run if the module has not been loaded yet out: the BasicBlock that is run if the module has already been loaded""" first_load = self.load_module(id) - comparison = self.translate_is_op(first_load, self.none_object(), 'is not', line) + comparison = self.translate_is_op(first_load, self.none_object(), BinOp.IsNot, line) self.add_bool_branch(comparison, out, needs_import) def get_module(self, module: str, line: int) -> Value: @@ -643,7 +644,7 @@ def process_iterator_tuple_assignment(self, iter_list = self.call_c(to_list, [iterator], line) iter_list_len = self.builtin_len(iter_list, line) post_star_len = Integer(len(post_star_vals)) - condition = self.binary_op(post_star_len, iter_list_len, '<=', line) + condition = self.binary_op(post_star_len, iter_list_len, BinOp.LtE, line) error_block, ok_block = BasicBlock(), BasicBlock() self.add(Branch(condition, ok_block, error_block, Branch.BOOL)) @@ -901,8 +902,8 @@ def shortcircuit_expr(self, expr: OpExpr) -> Value: # Conditional expressions def process_conditional(self, e: Expression, true: BasicBlock, false: BasicBlock) -> None: - if isinstance(e, OpExpr) and e.op in ['and', 'or']: - if e.op == 'and': + if isinstance(e, OpExpr) and e.op.is_boolean(): + if e.op == BinOp.And: # Short circuit 'and' in a conditional context. new = BasicBlock() self.process_conditional(e.left, new, false) @@ -945,7 +946,7 @@ def maybe_process_conditional_comparison(self, if not is_tagged(ltype) or not is_tagged(rtype): return False op = e.operators[0] - if op not in ('==', '!=', '<', '<=', '>', '>='): + if not op.is_numeric_compare(): return False left = self.accept(e.operands[0]) right = self.accept(e.operands[1]) diff --git a/mypyc/irbuild/callable_class.py b/mypyc/irbuild/callable_class.py index 0261332800ae..e59d94e71ebd 100644 --- a/mypyc/irbuild/callable_class.py +++ b/mypyc/irbuild/callable_class.py @@ -6,6 +6,8 @@ from typing import List +from mypy.operators import BinOp + from mypyc.common import SELF_NAME, ENV_ATTR_NAME from mypyc.ir.ops import BasicBlock, Return, Call, SetAttr, Value, Register from mypyc.ir.rtypes import RInstance, object_rprimitive @@ -115,7 +117,7 @@ def add_get_to_callable_class(builder: IRBuilder, fn_info: FuncInfo) -> None: # instance method object. instance_block, class_block = BasicBlock(), BasicBlock() comparison = builder.translate_is_op( - builder.read(instance), builder.none_object(), 'is', line + builder.read(instance), builder.none_object(), BinOp.Is, line ) builder.add_bool_branch(comparison, class_block, instance_block) diff --git a/mypyc/irbuild/classdef.py b/mypyc/irbuild/classdef.py index 80cce0bbd35f..2326a9ecf807 100644 --- a/mypyc/irbuild/classdef.py +++ b/mypyc/irbuild/classdef.py @@ -7,6 +7,7 @@ ClassDef, FuncDef, OverloadedFuncDef, PassStmt, AssignmentStmt, NameExpr, StrExpr, ExpressionStmt, TempNode, Decorator, Lvalue, RefExpr, is_class_var ) +from mypy.operators import BinOp from mypyc.ir.ops import ( Value, Register, Call, LoadErrorValue, LoadStatic, InitStatic, TupleSet, SetAttr, Return, BasicBlock, Branch, MethodCall, NAMESPACE_TYPE, LoadAddress @@ -435,7 +436,7 @@ def gen_glue_ne_method(builder: IRBuilder, cls: ClassIR, line: int) -> None: not_implemented = builder.add(LoadAddress(not_implemented_op.type, not_implemented_op.src, line)) builder.add(Branch( - builder.translate_is_op(eqval, not_implemented, 'is', line), + builder.translate_is_op(eqval, not_implemented, BinOp.Is, line), not_implemented_block, regular_block, Branch.BOOL)) diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py index 225c12eeea9b..4c4810099a3e 100644 --- a/mypyc/irbuild/expression.py +++ b/mypyc/irbuild/expression.py @@ -15,6 +15,7 @@ Var, RefExpr, MypyFile, TypeInfo, TypeApplication, LDEF, ARG_POS ) from mypy.types import TupleType, Instance, TypeType, ProperType, get_proper_type +from mypy.operators import BinOp from mypyc.common import MAX_SHORT_INT from mypyc.ir.ops import ( @@ -382,11 +383,12 @@ def transform_unary_expr(builder: IRBuilder, expr: UnaryExpr) -> Value: def transform_op_expr(builder: IRBuilder, expr: OpExpr) -> Value: - if expr.op in ('and', 'or'): + if expr.op.is_boolean(): return builder.shortcircuit_expr(expr) # Special case for string formatting - if expr.op == '%' and (isinstance(expr.left, StrExpr) or isinstance(expr.left, BytesExpr)): + if (expr.op == BinOp.Mod + and (isinstance(expr.left, StrExpr) or isinstance(expr.left, BytesExpr))): ret = translate_printf_style_formatting(builder, expr.left, expr.right) if ret is not None: return ret @@ -479,7 +481,7 @@ def transform_conditional_expr(builder: IRBuilder, expr: ConditionalExpr) -> Val def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value: # x in (...)/[...] # x not in (...)/[...] - if (e.operators[0] in ['in', 'not in'] + if (e.operators[0].is_contains() and len(e.operators) == 1 and isinstance(e.operands[1], (TupleExpr, ListExpr))): items = e.operands[1].items @@ -488,12 +490,12 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value: # x not in y -> x != y[0] and ... and x != y[n] # 16 is arbitrarily chosen to limit code size if 1 < n_items < 16: - if e.operators[0] == 'in': - bin_op = 'or' - cmp_op = '==' + if e.operators[0] == BinOp.In: + bin_op = BinOp.Or + cmp_op = BinOp.Eq else: - bin_op = 'and' - cmp_op = '!=' + bin_op = BinOp.And + cmp_op = BinOp.NotEq lhs = e.operands[0] mypy_file = builder.graph['builtins'].tree assert mypy_file is not None @@ -512,16 +514,16 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value: # x in [y]/(y) -> x == y # x not in [y]/(y) -> x != y elif n_items == 1: - if e.operators[0] == 'in': - cmp_op = '==' + if e.operators[0] == BinOp.In: + cmp_op = BinOp.Eq else: - cmp_op = '!=' + cmp_op = BinOp.NotEq e.operators = [cmp_op] e.operands[1] = items[0] # x in []/() -> False # x not in []/() -> True elif n_items == 0: - if e.operators[0] == 'in': + if e.operators[0] == BinOp.In: return builder.false() else: return builder.true() @@ -551,7 +553,7 @@ def go(i: int, prev: Value) -> Value: def transform_basic_comparison(builder: IRBuilder, - op: str, + op: BinOp, left: Value, right: Value, line: int) -> Value: @@ -559,10 +561,10 @@ def transform_basic_comparison(builder: IRBuilder, and op in int_comparison_op_mapping.keys()): return builder.compare_tagged(left, right, op, line) negate = False - if op == 'is not': - op, negate = 'is', True - elif op == 'not in': - op, negate = 'in', True + if op == BinOp.IsNot: + op, negate = BinOp.Is, True + elif op == BinOp.NotIn: + op, negate = BinOp.In, True target = builder.binary_op(left, right, op, line) diff --git a/mypyc/irbuild/for_helpers.py b/mypyc/irbuild/for_helpers.py index ae592ae91087..96cfe4c36f8f 100644 --- a/mypyc/irbuild/for_helpers.py +++ b/mypyc/irbuild/for_helpers.py @@ -11,6 +11,7 @@ from mypy.nodes import ( Lvalue, Expression, TupleExpr, CallExpr, RefExpr, GeneratorExpr, ARG_POS, MemberExpr, TypeAlias ) +from mypy.operators import BinOp from mypyc.ir.ops import ( Value, BasicBlock, Integer, Branch, Register, TupleGet, TupleSet, IntOp ) @@ -534,7 +535,7 @@ def init(self, expr_reg: Value, target_type: RType, reverse: bool) -> None: index_reg: Value = Integer(0) else: index_reg = builder.binary_op(self.load_len(self.expr_target), - Integer(1), '-', self.line) + Integer(1), BinOp.Sub, self.line) self.index_target = builder.maybe_spill_assignable(index_reg) self.target_type = target_type @@ -548,14 +549,15 @@ def gen_condition(self) -> None: # obviously we still need to check against the length, # since it could shrink out from under us. comparison = builder.binary_op(builder.read(self.index_target, line), - Integer(0), '>=', line) + Integer(0), BinOp.GtE, line) second_check = BasicBlock() builder.add_bool_branch(comparison, second_check, self.loop_exit) builder.activate_block(second_check) # For compatibility with python semantics we recalculate the length # at every iteration. len_reg = self.load_len(self.expr_target) - comparison = builder.binary_op(builder.read(self.index_target, line), len_reg, '<', line) + comparison = builder.binary_op(builder.read(self.index_target, line), + len_reg, BinOp.Lt, line) builder.add_bool_branch(comparison, self.body_block, self.loop_exit) def begin_body(self) -> None: @@ -743,7 +745,7 @@ def gen_condition(self) -> None: builder = self.builder line = self.line # Add loop condition check. - cmp = '<' if self.step > 0 else '>' + cmp = BinOp.Lt if self.step > 0 else BinOp.Gt comparison = builder.binary_op(builder.read(self.index_reg, line), builder.read(self.end_target, line), cmp, line) builder.add_bool_branch(comparison, self.body_block, self.loop_exit) @@ -762,7 +764,7 @@ def gen_step(self) -> None: else: new_val = builder.binary_op( - builder.read(self.index_reg, line), Integer(self.step), '+', line) + builder.read(self.index_reg, line), Integer(self.step), BinOp.Add, line) builder.assign(self.index_reg, new_val, line) builder.assign(self.index_target, new_val, line) diff --git a/mypyc/irbuild/function.py b/mypyc/irbuild/function.py index bdd4ed992f2f..292f1382788d 100644 --- a/mypyc/irbuild/function.py +++ b/mypyc/irbuild/function.py @@ -18,6 +18,7 @@ ClassDef, FuncDef, OverloadedFuncDef, Decorator, Var, YieldFromExpr, AwaitExpr, YieldExpr, FuncItem, LambdaExpr, SymbolNode, ArgKind, TypeInfo ) +from mypy.operators import BinOp from mypy.types import CallableType, get_proper_type from mypyc.ir.ops import ( @@ -842,7 +843,7 @@ def generate_singledispatch_dispatch_function( ) call_find_impl, use_cache, call_func = BasicBlock(), BasicBlock(), BasicBlock() get_result = builder.call_c(dict_get_method_with_none, [dispatch_cache, arg_type], line) - is_not_none = builder.translate_is_op(get_result, builder.none_object(), 'is not', line) + is_not_none = builder.translate_is_op(get_result, builder.none_object(), BinOp.IsNot, line) impl_to_use = Register(object_rprimitive) builder.add_bool_branch(is_not_none, use_cache, call_find_impl) @@ -897,7 +898,7 @@ def gen_native_func_call_and_return(fdef: FuncDef) -> None: builder.builder.compare_tagged_condition( passed_id, current_id, - '==', + BinOp.Eq, call_impl, next_impl, line, diff --git a/mypyc/irbuild/generator.py b/mypyc/irbuild/generator.py index 39d30cf74eeb..f62eb78c4773 100644 --- a/mypyc/irbuild/generator.py +++ b/mypyc/irbuild/generator.py @@ -11,6 +11,7 @@ from typing import List from mypy.nodes import Var, ARG_OPT +from mypy.operators import BinOp from mypyc.common import SELF_NAME, NEXT_LABEL_ATTR_NAME, ENV_ATTR_NAME from mypyc.ir.ops import ( @@ -86,7 +87,7 @@ def populate_switch_for_generator_class(builder: IRBuilder) -> None: for label, true_block in enumerate(cls.continuation_blocks): false_block = BasicBlock() comparison = builder.binary_op( - cls.next_label_reg, Integer(label), '==', line + cls.next_label_reg, Integer(label), BinOp.Eq, line ) builder.add_bool_branch(comparison, true_block, false_block) builder.activate_block(false_block) @@ -109,7 +110,7 @@ def add_raise_exception_blocks_to_generator_class(builder: IRBuilder, line: int) # Check to see if an exception was raised. error_block = BasicBlock() ok_block = BasicBlock() - comparison = builder.translate_is_op(exc_type, builder.none_object(), 'is not', line) + comparison = builder.translate_is_op(exc_type, builder.none_object(), BinOp.IsNot, line) builder.add_bool_branch(comparison, error_block, ok_block) builder.activate_block(error_block) diff --git a/mypyc/irbuild/ll_builder.py b/mypyc/irbuild/ll_builder.py index b6cf990d025d..38b1cbb2888c 100644 --- a/mypyc/irbuild/ll_builder.py +++ b/mypyc/irbuild/ll_builder.py @@ -15,7 +15,7 @@ from typing_extensions import Final from mypy.nodes import ArgKind, ARG_POS, ARG_STAR, ARG_STAR2 -from mypy.operators import op_methods +from mypy.operators import BinOp, op_methods from mypy.types import AnyType, TypeOfAny from mypy.checkexpr import map_actuals_to_formals @@ -845,35 +845,38 @@ def load_native_type_object(self, fullname: str) -> Value: # Other primitive operations def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value: + # We don't use `BinOp` methods here, because some operators might be + # containing extra `=` sign, like `^=`. ltype = lreg.type rtype = rreg.type + is_eq_op = op in (BinOp.Eq, BinOp.NotEq) # Special case tuple comparison here so that nested tuples can be supported - if isinstance(ltype, RTuple) and isinstance(rtype, RTuple) and op in ('==', '!='): - return self.compare_tuples(lreg, rreg, op, line) + if isinstance(ltype, RTuple) and isinstance(rtype, RTuple) and is_eq_op: + return self.compare_tuples(lreg, rreg, BinOp(op), line) # Special case == and != when we can resolve the method call statically - if op in ('==', '!='): - value = self.translate_eq_cmp(lreg, rreg, op, line) + if is_eq_op: + value = self.translate_eq_cmp(lreg, rreg, BinOp(op), line) if value is not None: return value # Special case various ops - if op in ('is', 'is not'): - return self.translate_is_op(lreg, rreg, op, line) + if op in (BinOp.Is, BinOp.IsNot): + return self.translate_is_op(lreg, rreg, BinOp(op), line) # TODO: modify 'str' to use same interface as 'compare_bytes' as it avoids # call to PyErr_Occurred() - if is_str_rprimitive(ltype) and is_str_rprimitive(rtype) and op in ('==', '!='): - return self.compare_strings(lreg, rreg, op, line) - if is_bytes_rprimitive(ltype) and is_bytes_rprimitive(rtype) and op in ('==', '!='): - return self.compare_bytes(lreg, rreg, op, line) + if is_str_rprimitive(ltype) and is_str_rprimitive(rtype) and is_eq_op: + return self.compare_strings(lreg, rreg, BinOp(op), line) + if is_bytes_rprimitive(ltype) and is_bytes_rprimitive(rtype) and is_eq_op: + return self.compare_bytes(lreg, rreg, BinOp(op), line) if is_tagged(ltype) and is_tagged(rtype) and op in int_comparison_op_mapping: - return self.compare_tagged(lreg, rreg, op, line) + return self.compare_tagged(lreg, rreg, BinOp(op), line) if is_bool_rprimitive(ltype) and is_bool_rprimitive(rtype) and op in ( '&', '&=', '|', '|=', '^', '^='): - return self.bool_bitwise_op(lreg, rreg, op[0], line) - if isinstance(rtype, RInstance) and op in ('in', 'not in'): - return self.translate_instance_contains(rreg, lreg, op, line) + return self.bool_bitwise_op(lreg, rreg, BinOp(op[0]), line) + if isinstance(rtype, RInstance) and op in (BinOp.In, BinOp.NotIn): + return self.translate_instance_contains(rreg, lreg, BinOp(op), line) call_c_ops_candidates = binary_ops.get(op, []) target = self.matching_call_c(call_c_ops_candidates, [lreg, rreg], line) @@ -892,7 +895,7 @@ def check_tagged_short_int(self, val: Value, line: int, negated: bool = False) - check = self.comparison_op(bitwise_and, zero, op, line) return check - def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value: + def compare_tagged(self, lhs: Value, rhs: Value, op: BinOp, line: int) -> Value: """Compare two tagged integers using given operator (value context).""" # generate fast binary logic ops on short ints if is_short_int_rprimitive(lhs.type) and is_short_int_rprimitive(rhs.type): @@ -901,7 +904,7 @@ def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value: result = Register(bool_rprimitive) short_int_block, int_block, out = BasicBlock(), BasicBlock(), BasicBlock() check_lhs = self.check_tagged_short_int(lhs, line) - if op in ("==", "!="): + if op.is_equality(): check = check_lhs else: # for non-equality logical ops (less/greater than, etc.), need to check both sides @@ -930,7 +933,7 @@ def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value: def compare_tagged_condition(self, lhs: Value, rhs: Value, - op: str, + op: BinOp, true: BasicBlock, false: BasicBlock, line: int) -> None: @@ -945,7 +948,7 @@ def compare_tagged_condition(self, true: Branch target if comparison is true false: Branch target if comparison is false """ - is_eq = op in ("==", "!=") + is_eq = op.is_equality() if ((is_short_int_rprimitive(lhs.type) and is_short_int_rprimitive(rhs.type)) or (is_eq and (is_short_int_rprimitive(lhs.type) or is_short_int_rprimitive(rhs.type)))): @@ -981,7 +984,7 @@ def compare_tagged_condition(self, eq = self.comparison_op(lhs, rhs, op_type, line) self.add(Branch(eq, true, false, Branch.BOOL)) - def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value: + def compare_strings(self, lhs: Value, rhs: Value, op: BinOp, line: int) -> Value: """Compare two strings""" compare_result = self.call_c(unicode_compare, [lhs, rhs], line) error_constant = Integer(-1, c_int_rprimitive, line) @@ -1003,25 +1006,25 @@ def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value: self.call_c(keep_propagating_op, [], line) self.goto(final_compare) self.activate_block(final_compare) - op_type = ComparisonOp.EQ if op == '==' else ComparisonOp.NEQ + op_type = ComparisonOp.EQ if op == BinOp.Eq else ComparisonOp.NEQ return self.add(ComparisonOp(compare_result, Integer(0, c_int_rprimitive), op_type, line)) - def compare_bytes(self, lhs: Value, rhs: Value, op: str, line: int) -> Value: + def compare_bytes(self, lhs: Value, rhs: Value, op: BinOp, line: int) -> Value: compare_result = self.call_c(bytes_compare, [lhs, rhs], line) - op_type = ComparisonOp.EQ if op == '==' else ComparisonOp.NEQ + op_type = ComparisonOp.EQ if op == BinOp.Eq else ComparisonOp.NEQ return self.add(ComparisonOp(compare_result, Integer(1, c_int_rprimitive), op_type, line)) def compare_tuples(self, lhs: Value, rhs: Value, - op: str, + op: BinOp, line: int = -1) -> Value: """Compare two tuples item by item""" # type cast to pass mypy check assert isinstance(lhs.type, RTuple) and isinstance(rhs.type, RTuple) - equal = True if op == '==' else False + equal = op == BinOp.Eq result = Register(bool_rprimitive) # empty tuples if len(lhs.type.types) == 0 and len(rhs.type.types) == 0: @@ -1063,15 +1066,15 @@ def compare_tuples(self, self.goto_and_activate(out) return result - def translate_instance_contains(self, inst: Value, item: Value, op: str, line: int) -> Value: + def translate_instance_contains(self, inst: Value, item: Value, op: BinOp, line: int) -> Value: res = self.gen_method_call(inst, '__contains__', [item], None, line) if not is_bool_rprimitive(res.type): res = self.call_c(bool_op, [res], line) - if op == 'not in': - res = self.bool_bitwise_op(res, Integer(1, rtype=bool_rprimitive), '^', line) + if op == BinOp.NotIn: + res = self.bool_bitwise_op(res, Integer(1, rtype=bool_rprimitive), BinOp.BitXor, line) return res - def bool_bitwise_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value: + def bool_bitwise_op(self, lreg: Value, rreg: Value, op: BinOp, line: int) -> Value: if op == '&': code = IntOp.AND elif op == '|': @@ -1229,7 +1232,7 @@ def shortcircuit_helper(self, op: str, def add_bool_branch(self, value: Value, true: BasicBlock, false: BasicBlock) -> None: if is_runtime_subtype(value.type, int_rprimitive): zero = Integer(0, short_int_rprimitive) - self.compare_tagged_condition(value, zero, '!=', true, false, value.line) + self.compare_tagged_condition(value, zero, BinOp.NotEq, true, false, value.line) return elif is_same_type(value.type, str_rprimitive): value = self.call_c(str_check_if_true, [value], value.line) @@ -1237,7 +1240,7 @@ def add_bool_branch(self, value: Value, true: BasicBlock, false: BasicBlock) -> or is_same_type(value.type, dict_rprimitive)): length = self.builtin_len(value, value.line) zero = Integer(0) - value = self.binary_op(length, zero, '!=', value.line) + value = self.binary_op(length, zero, BinOp.NotEq, value.line) elif (isinstance(value.type, RInstance) and value.type.class_ir.is_ext_class and value.type.class_ir.has_method('__bool__')): # Directly call the __bool__ method on classes that have it. @@ -1245,7 +1248,7 @@ def add_bool_branch(self, value: Value, true: BasicBlock, false: BasicBlock) -> else: value_type = optional_value_type(value.type) if value_type is not None: - is_none = self.translate_is_op(value, self.none_object(), 'is not', value.line) + is_none = self.translate_is_op(value, self.none_object(), BinOp.IsNot, value.line) branch = Branch(is_none, true, false, Branch.BOOL) self.add(branch) always_truthy = False @@ -1391,7 +1394,7 @@ def builtin_len(self, val: Value, line: int, use_pyssize_t: bool = False) -> Val length = self.gen_method_call(val, '__len__', [], int_rprimitive, line) length = self.coerce(length, int_rprimitive, line) ok, fail = BasicBlock(), BasicBlock() - self.compare_tagged_condition(length, Integer(0), '>=', ok, fail, line) + self.compare_tagged_condition(length, Integer(0), BinOp.GtE, ok, fail, line) self.activate_block(fail) self.add(RaiseStandardError(RaiseStandardError.VALUE_ERROR, "__len__() should return >= 0", @@ -1509,7 +1512,7 @@ def translate_special_method_call(self, def translate_eq_cmp(self, lreg: Value, rreg: Value, - expr_op: str, + expr_op: BinOp, line: int) -> Optional[Value]: """Add a equality comparison operation. @@ -1539,7 +1542,7 @@ def translate_eq_cmp(self, if not class_ir.has_method('__eq__'): # There's no __eq__ defined, so just use object identity. - identity_ref_op = 'is' if expr_op == '==' else 'is not' + identity_ref_op = BinOp.Is if expr_op == BinOp.Eq else BinOp.IsNot return self.translate_is_op(lreg, rreg, identity_ref_op, line) return self.gen_method_call( @@ -1553,14 +1556,14 @@ def translate_eq_cmp(self, def translate_is_op(self, lreg: Value, rreg: Value, - expr_op: str, + expr_op: BinOp, line: int) -> Value: """Create equality comparison operation between object identities Args: expr_op: either 'is' or 'is not' """ - op = ComparisonOp.EQ if expr_op == 'is' else ComparisonOp.NEQ + op = ComparisonOp.EQ if expr_op == BinOp.Is else ComparisonOp.NEQ lhs = self.coerce(lreg, object_rprimitive, line) rhs = self.coerce(rreg, object_rprimitive, line) return self.add(ComparisonOp(lhs, rhs, op, line)) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 6a744781ee50..4fc9a9c0dc39 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -120,7 +120,7 @@ def transform_operator_assignment_stmt(builder: IRBuilder, stmt: OperatorAssignm target_value = builder.read(target, stmt.line) rreg = builder.accept(stmt.rvalue) # the Python parser strips the '=' from operator assignment statements, so re-add it - op = stmt.op + '=' + op = stmt.op.value + '=' res = builder.binary_op(target_value, rreg, op, stmt.line) # usually operator assignments are done in-place # but when target doesn't support that we need to manually assign diff --git a/mypyc/test/test_emitfunc.py b/mypyc/test/test_emitfunc.py index 139923aa57c6..1cca9a59aa61 100644 --- a/mypyc/test/test_emitfunc.py +++ b/mypyc/test/test_emitfunc.py @@ -3,6 +3,7 @@ from typing import List, Optional from mypy.backports import OrderedDict +from mypy.operators import BinOp from mypy.test.helpers import assert_string_arrays_equal @@ -102,12 +103,12 @@ def test_assign_int(self) -> None: def test_int_add(self) -> None: self.assert_emit_binary_op( - '+', self.n, self.m, self.k, + BinOp.Add, self.n, self.m, self.k, "cpy_r_r0 = CPyTagged_Add(cpy_r_m, cpy_r_k);") def test_int_sub(self) -> None: self.assert_emit_binary_op( - '-', self.n, self.m, self.k, + BinOp.Sub, self.n, self.m, self.k, "cpy_r_r0 = CPyTagged_Subtract(cpy_r_m, cpy_r_k);") def test_int_neg(self) -> None: @@ -293,7 +294,7 @@ def test_new_dict(self) -> None: def test_dict_contains(self) -> None: self.assert_emit_binary_op( - 'in', self.b, self.o, self.d, + BinOp.In, self.b, self.o, self.d, """cpy_r_r0 = PyDict_Contains(cpy_r_d, cpy_r_o);""") def test_int_op(self) -> None: @@ -402,7 +403,7 @@ def assert_emit(self, op: Op, expected: str, next_block: Optional[BasicBlock] = msg='Generated code unexpected') def assert_emit_binary_op(self, - op: str, + op: BinOp, dest: Value, left: Value, right: Value,