# Source code for idp_engine.Simplify

# Copyright 2019-2023 Ingmar Dasseville, Pierre Carbonnelle
#
# This file is part of IDP-Z3.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

"""

Methods to simplify a logic expression.

This module monkey-patches the Expression class and sub-classes.


"""
from __future__ import annotations

from copy import deepcopy
import sys
from typing import List, Tuple, Optional, Generator

from .Expression import (Constructor, Expression, AIfExpr, IF,
                         AQuantification, Operator, AEquivalence, AImplication,
                         ADisjunction, AConjunction, AComparison, EQUALS,
                         ASumMinus, AMultDiv, APower, AUnary, AAggregate,
                         SymbolExpr, AppliedSymbol, UnappliedSymbol, Variable,
                         Number, Date, Brackets, TRUE, FALSE, NOT, AND, OR)
from .Parse import Symbol, Enumeration, TupleIDP
from .Assignments import Status as S, Assignment
from .utils import BOOL, INT, DATE, ABS


# class Expression  ###########################################################

def _change(self: Expression,
            sub_exprs: Optional[List[Expression]] = None,
            ops : Optional[List[str]] = None,
            simpler : Optional[Expression] = None,
            co_constraint : Optional[Expression] = None
            ) -> Expression:
    """Update attributes of an expression in place, or hand back a replacement.

    Args:
        sub_exprs: new value for `self.sub_exprs`, if not None
        ops: new value for `self.operator`, if not None
        simpler: expression to return instead of `self`, if not None;
            in that case the other arguments are ignored
        co_constraint: new value for `self.co_constraint`, if not None

    Returns:
        Expression: `simpler` when it is given (after copying `original`,
        `is_type_constraint_for` and, for an `AppliedSymbol`, `in_head`
        from `self` onto it); otherwise `self`, updated in place.
    """
    if simpler is None:
        updates = (("sub_exprs", sub_exprs),
                   ("operator", ops),
                   ("co_constraint", co_constraint))
        for attribute, value in updates:
            if value is not None:
                setattr(self, attribute, value)

        # reset derived attributes: the cached string must reflect the update
        self.str = sys.intern(str(self))
        return self

    # the replacement inherits the bookkeeping of the expression it replaces
    simpler.original = self.original
    simpler.is_type_constraint_for = self.is_type_constraint_for
    if type(self) == AppliedSymbol:
        simpler.in_head = self.in_head
    return simpler
Expression._change = _change


def update_exprs(self: Expression,
                 new_exprs: Generator[Expression, None, None]
                 ) -> Expression:
    """Replace the sub-expressions of `self` by `new_exprs`.

    This is the default implementation: the new sub-expressions are stored
    through `_change`, without any simplification.
    """
    materialized = [*new_exprs]
    return self._change(sub_exprs=materialized)
Expression.update_exprs = update_exprs


def simplify1(self: Expression) -> Expression:
    """Re-run `update_exprs` on the current sub-expressions of `self`."""
    current = self.sub_exprs
    return self.update_exprs(current)
Expression.simplify1 = simplify1



# Class AIfExpr  ###############################################################

def update_exprs(self: AIfExpr,
                 new_exprs: Generator[Expression, None, None]
                 ) -> Expression:
    """Simplify `if c then a else b` given its new sub-expressions.

    Rewrites applied, in this order:
      * c is true                   ->  a
      * c is false                  ->  b
      * a same as b                 ->  a
      * a is true,  b is false      ->  c
      * a is true                   ->  c | b
      * b is true,  a is false      ->  ~c
      * b is true                   ->  ~c | a
    Otherwise the sub-expressions are simply updated.
    """
    branches = list(new_exprs)
    condition, when_true, when_false = branches[0], branches[1], branches[2]

    def replaced_by(replacement):
        return self._change(simpler=replacement, sub_exprs=branches)

    if condition.same_as(TRUE):
        return replaced_by(when_true)
    if condition.same_as(FALSE):
        return replaced_by(when_false)
    if when_true.same_as(when_false):
        return replaced_by(when_true)
    if when_true.same_as(TRUE):
        if when_false.same_as(FALSE):
            return replaced_by(condition)
        return replaced_by(OR([condition, when_false]))
    if when_false.same_as(TRUE):
        if when_true.same_as(FALSE):
            return replaced_by(NOT(condition))
        return replaced_by(OR([NOT(condition), when_true]))
    return self._change(sub_exprs=branches)
AIfExpr.update_exprs = update_exprs


# Class Quantee  #######################################################


# Class AQuantification  ######################################################

def update_exprs(self: AQuantification,
                 new_exprs: Generator[Expression, None, None]
                 ) -> Expression:
    """Simplify a quantification like the matching junction.

    A '∀' quantification is handled by `AConjunction.update_exprs`, any other
    quantifier by `ADisjunction.update_exprs`; both are called with
    `replace=False`.
    """
    junction = AConjunction if self.q == '∀' else ADisjunction
    return junction.update_exprs(self, new_exprs, replace=False)
AQuantification.update_exprs = update_exprs


# Class AImplication  #######################################################

def update_exprs(self: AImplication,
                 new_exprs: Generator[Expression, None, None]
                 ) -> Expression:
    """Simplify `p => q` given its new sub-expressions.

    Args:
        new_exprs: the new antecedent and consequent, as any iterable
            (generator, list, tuple, ...).  The consequent is only consumed
            when the antecedent is not false.

    Returns:
        Expression: `TRUE` for `false => q`, `p => true` and `p => p`;
        `q` for `true => q`; `~p` for `p => false`; otherwise `self` with
        its sub-expressions updated.
    """
    # iter() returns a generator unchanged and wraps any other iterable,
    # so lists, tuples and generators are all accepted
    operands = iter(new_exprs)
    antecedent = next(operands)
    simpler = None
    if antecedent.same_as(FALSE):  # (false => p) is true
        return TRUE
    elif antecedent.same_as(TRUE):  # (true => p) is p
        consequent = next(operands)
        simpler = consequent
    else:
        consequent = next(operands)
        if consequent.same_as(TRUE):  # (p => true) is true
            return TRUE
        elif consequent.same_as(FALSE):  # (p => false) is ~p
            simpler = NOT(antecedent)
        elif consequent.same_as(antecedent):  # (p => p) is true
            return TRUE
    return self._change(simpler=simpler,
                        sub_exprs=[antecedent, consequent])
AImplication.update_exprs = update_exprs


# Class AEquivalence  #######################################################

def update_exprs(self: AEquivalence,
                 new_exprs: Generator[Expression, None, None]
                 ) -> Expression:
    """Simplify a (possibly chained) equivalence given its new operands.

    Returns:
        Expression: the single operand when there is only one; the
        conjunction of all operands as soon as one of them is true; the
        conjunction of their negations as soon as one of them is false;
        otherwise `self` with its sub-expressions updated.
    """
    exprs = list(new_exprs)
    if len(exprs) == 1:
        # index 0 is the only valid index of a one-element list
        return self._change(simpler=exprs[0], sub_exprs=exprs)
    for e in exprs:
        if e.same_as(TRUE):  # they must all be true
            return self._change(simpler=AND(exprs),
                                sub_exprs=exprs)
        if e.same_as(FALSE):  # they must all be false
            return self._change(simpler=AND([NOT(operand) for operand in exprs]),
                                sub_exprs=exprs)
    return self._change(sub_exprs=exprs)
AEquivalence.update_exprs = update_exprs


# Class ADisjunction  #######################################################

def update_exprs(self: Expression, new_exprs: Expression, replace=True) -> Expression:
    """Simplify a disjunction given its new disjuncts.

    Returns `TRUE` as soon as a disjunct is true, and `FALSE` when every
    disjunct is false.  When `replace` is set and exactly one disjunct is not
    false, that disjunct replaces the disjunction.  Otherwise `self` is
    updated with all the disjuncts (including the false ones).
    """
    all_disjuncts, undecided = [], []
    for disjunct in new_exprs:
        if disjunct.same_as(TRUE):
            return TRUE
        all_disjuncts.append(disjunct)
        if disjunct.same_as(FALSE):
            continue
        undecided.append(disjunct)

    if not undecided:  # all disjuncts are False
        return FALSE
    lone = undecided[0] if (replace and len(undecided) == 1) else None
    return self._change(simpler=lone, sub_exprs=all_disjuncts)
ADisjunction.update_exprs = update_exprs


# Class AConjunction  #######################################################

# same as ADisjunction, with TRUE and FALSE swapped
def update_exprs(self: Expression, new_exprs: Expression, replace=True) -> Expression:
    """Simplify a conjunction given its new conjuncts.

    Returns `FALSE` as soon as a conjunct is false, and `TRUE` when every
    conjunct is true.  When `replace` is set and exactly one conjunct is not
    true, that conjunct replaces the conjunction.  Otherwise `self` is
    updated with all the conjuncts (including the true ones).
    """
    all_conjuncts, undecided = [], []
    for conjunct in new_exprs:
        if conjunct.same_as(FALSE):
            return FALSE
        all_conjuncts.append(conjunct)
        if conjunct.same_as(TRUE):
            continue
        undecided.append(conjunct)

    if not undecided:  # all conjuncts are True
        return TRUE
    lone = undecided[0] if (replace and len(undecided) == 1) else None
    return self._change(simpler=lone, sub_exprs=all_conjuncts)
AConjunction.update_exprs = update_exprs


# Class AComparison  #######################################################

def update_exprs(self: AComparison,
                 new_exprs: Generator[Expression, None, None]
                 ) -> Expression:
    """Simplify a (possibly chained) comparison given its new operands.

    For a binary equality, `a = a` becomes `TRUE` and an if-expression on the
    left is pushed outwards.  Then every pair of adjacent operands that are
    both values is evaluated: the comparison is `FALSE` as soon as one pair
    fails, and `TRUE` when all operands are values and no pair failed.
    Otherwise `self` is updated with the new operands.
    """
    operands = list(new_exprs)

    if len(operands) == 2 and self.operator == ["="]:
        left, right = operands

        # a = a
        if left.same_as(right):
            return TRUE

        # (if c then a else b) = d  ->  (if c then a=d else b=d)
        if type(left) == AIfExpr:
            then_eq = EQUALS([left.sub_exprs[1], right]).simplify1()
            else_eq = EQUALS([left.sub_exprs[2], right]).simplify1()
            pushed = IF(left.sub_exprs[0], then_eq, else_eq).simplify1()
            return self._change(simpler=pushed, sub_exprs=operands)

    previous = operands[0]
    assert len(self.operator) == len(operands[1:]), "Internal error"
    for op, current in zip(self.operator, operands[1:]):
        if previous.is_value() and current.is_value():
            if op == "≠":  # issue #246: decided with same_as only
                if previous.same_as(current):
                    return FALSE
            else:
                if op in ["<", ">"] and previous.same_as(current):
                    return FALSE
                if op == "=" and not previous.same_as(current):
                    return FALSE
                if not (Operator.MAP[op]) (previous.py_value, current.py_value):
                    return FALSE
        previous = current

    if all(e.is_value() for e in operands):
        return TRUE
    return self._change(sub_exprs=operands)
AComparison.update_exprs = update_exprs

def as_set_condition(self: AComparison) -> Tuple[Optional[AppliedSymbol], Optional[bool], Optional[Enumeration]]:
    """Express the comparison as a set-condition `(term, belongs, enumeration)`.

    Returns `(None, None, None)` unless `self.is_assignment()` holds; else
    the first operand, `True`, and an enumeration containing a single tuple
    made of the second operand.
    """
    if not self.is_assignment():
        return (None, None, None)
    term, value = self.sub_exprs[0], self.sub_exprs[1]
    singleton = Enumeration(tuples=[TupleIDP(args=[value])])
    return (term, True, singleton)
AComparison.as_set_condition = as_set_condition

#############################################################

def update_arith(self: Expression, operands: List[Expression]) -> Expression:
    """Fold an arithmetic expression whose operands are all values.

    Args:
        operands: the new operands (any iterable; it is copied into a list)

    Returns:
        Expression: when every operand is a value, the result of applying
        `self.operator` from left to right, as a `Number` (or as a `Date`
        when the first operand has type `DATE`); otherwise `self`, updated
        with the new operands.
    """
    operands = list(operands)
    if not all(e.is_value() for e in operands):
        return self._change(sub_exprs=operands)

    self.check(all(hasattr(e, 'py_value') for e in operands),
            f"Incorrect numeric type in {self}")
    result = operands[0].py_value

    assert len(self.operator) == len(operands[1:]), "Internal error"
    for op, operand in zip(self.operator, operands[1:]):
        function = Operator.MAP[op]

        if op == '/' and self.type == INT:  # integer division
            result //= operand.py_value
        else:
            result = function(result, operand.py_value)

    if operands[0].type != DATE:
        return Number(number=str(result))
    return Date.make(result)


# Class ASumMinus  #######################################################

def update_exprs(self: ASumMinus,
                 new_exprs: Generator[Expression, None, None]
                 ) -> Expression:
    """Simplify a sum/difference by delegating to `update_arith`."""
    simplified = update_arith(self, new_exprs)
    return simplified
ASumMinus.update_exprs = update_exprs


# Class AMultDiv  #######################################################

def update_exprs(self: AMultDiv,
                 new_exprs: Generator[Expression, None, None]
                 ) -> Expression:
    """Simplify a product/division/modulo given its new operands.

    Without a '%' operator the work is delegated to `update_arith`.  With a
    '%' operator, the expression is only folded (into a `Number`) when it has
    exactly two operands that are both values; otherwise `self` is updated
    with the new operands.
    """
    operands = list(new_exprs)
    if '%' not in self.operator:
        return update_arith(self, operands)

    # special case: modulo
    if len(operands) == 2 and all(e.is_value() for e in operands):
        dividend, divisor = operands
        remainder = dividend.py_value % divisor.py_value
        return Number(number=str(remainder))
    return self._change(sub_exprs=operands)
AMultDiv.update_exprs = update_exprs


# Class APower  #######################################################

def update_exprs(self: APower,
                 new_exprs: Generator[Expression, None, None]
                 ) -> Expression:
    """Simplify an exponentiation given its new operands.

    With exactly two operands that are both values, the power is computed and
    returned as a `Number`; otherwise `self` is updated with the new operands.
    """
    operands = list(new_exprs)
    foldable = len(operands) == 2 and all(e.is_value() for e in operands)
    if not foldable:
        return self._change(sub_exprs=operands)
    base, exponent = operands
    return Number(number=str(base.py_value ** exponent.py_value))
APower.update_exprs = update_exprs


# Class AUnary  #######################################################

def update_exprs(self: AUnary,
                 new_exprs: Generator[Expression, None, None]
                 ) -> Expression:
    """Simplify a unary expression given its new operand.

    For the '¬' operator, a true operand yields `FALSE` and a false operand
    yields `TRUE`.  For any other operator (i.e. '-'), an operand that is a
    value of class `Number` is negated into a new `Number`.  Otherwise `self`
    is updated with the new operand.
    """
    operand = list(new_exprs)[0]
    if self.operator != '¬':  # '-'
        if operand.is_value() and type(operand) == Number:
            return Number(number=f"{-operand.py_value}")
    elif operand.same_as(TRUE):
        return FALSE
    elif operand.same_as(FALSE):
        return TRUE
    return self._change(sub_exprs=[operand])
AUnary.update_exprs = update_exprs

def as_set_condition(self: AUnary) -> Tuple[Optional[AppliedSymbol], Optional[bool], Optional[Enumeration]]:
    """Return the set-condition of the operand, with membership flipped.

    Returns `(None, None, None)` when the operand has no set-condition
    (i.e. its term is None).
    """
    term, belongs, enumeration = self.sub_exprs[0].as_set_condition()
    if term is None:
        return (None, None, None)
    return (term, not belongs, enumeration)
AUnary.as_set_condition = as_set_condition


# Class AAggregate  #######################################################

def update_exprs(self: AAggregate,
                 new_exprs: Generator[Expression, None, None]
                 ) -> Expression:
    """Simplify an aggregate given its new operands.

    When the aggregate is annotated, has no quantees left, and all operands
    are values, the operand values are summed into a `Number`.  Otherwise
    `self` is updated with the new operands.
    """
    operands = list(new_exprs)
    expanded = self.annotated and not self.quantees
    if expanded and all(e.is_value() for e in operands):
        # NOTE(review): the values are always summed here — confirm that no
        # other kind of aggregate can reach this branch.
        total = sum(e.py_value for e in operands)
        return Number(number=str(total))
    return self._change(sub_exprs=operands)
AAggregate.update_exprs = update_exprs


# Class AppliedSymbol  #######################################################

def update_exprs(self: AppliedSymbol,
                 new_exprs: Generator[Expression, None, None]
                 ) -> Expression:
    """Simplify an applied symbol given its new arguments.

    Side effects: fills in `self.decl` from `self.symbol` when it is missing
    and the symbol is a `Symbol`, and (re)computes `self.type`.

    Simplifications:
      * `abs(v)` with a value argument is computed into a `Number`;
      * an accessor applied to a constructor term returns the matching
        argument of that term, e.g. `x(pos(0,0))` becomes `0`;
      * the tester of a constructor applied to a term built with that
        constructor becomes `TRUE`, e.g. `is_pos(pos(0,0))`.
    Otherwise `self` is updated with the new arguments.
    """
    args = list(new_exprs)
    if not self.decl and type(self.symbol) == Symbol:
        self.decl = self.symbol.decl

    if self.is_enumerated or self.in_enumeration:
        self.type = BOOL
    elif self.decl:
        self.type = self.decl.type
    else:
        self.type = None

    # a constructor applied to values needs no further simplification
    if self.decl and type(self.decl) == Constructor:
        if all(e.is_value() for e in args):
            return self._change(sub_exprs=args)

    # simplify abs()
    if (self.decl and self.decl.name == ABS and len(args) == 1
        and args[0].is_value()):
        return Number(number=str(abs(args[0].py_value)))

    # simplify x(pos(0,0)) to 0,  is_pos(pos(0,0)) to True
    if len(args) == 1:
        arg = args[0]
        if (hasattr(arg, 'decl')
            and type(arg.decl) == Constructor
            and arg.decl.tester
            and self.decl):
            accessors = arg.decl.parent.accessors
            if self.decl.name in accessors:
                i = accessors[self.decl.name]
                self.check(i < len(arg.sub_exprs),
                           f"Incorrect expression: {self}")
                return self._change(simpler=arg.sub_exprs[i], sub_exprs=args)
            if self.decl.name == arg.decl.tester.name:
                return TRUE

    return self._change(sub_exprs=args)
AppliedSymbol.update_exprs = update_exprs

def as_set_condition(self: AppliedSymbol) -> Tuple[Optional[AppliedSymbol], Optional[bool], Optional[Enumeration]]:
    """Express a membership test as a set-condition `(term, belongs, enumeration)`.

    Returns `(None, None, None)` when `self.in_enumeration` is not set.
    Otherwise the term is a fresh applied symbol built from `self.symbol` and
    a deep copy of the arguments, and `belongs` is True unless
    `self.is_enumeration` contains 'not'.
    """
    # determine core after substitutions
    core = AppliedSymbol.make(self.symbol, deepcopy(self.sub_exprs))

    if not self.in_enumeration:
        return (None, None, None)
    belongs = 'not' not in self.is_enumeration
    return (core, belongs, self.in_enumeration)
AppliedSymbol.as_set_condition = as_set_condition


# Class SymbolExpr  #######################################################

def update_exprs(self: SymbolExpr,
                 new_exprs: Generator[Expression, None, None]
                 ) -> Expression:
    """Simplify a symbol expression given its new sub-expression.

    When `self.eval` is empty, the sub-expression itself is returned.
    Otherwise, when the sub-expression is an `UnappliedSymbol` with a
    declaration, the symbol of that declaration is returned.  In both cases
    the result is checked not to be a `Variable`.  When neither applies,
    `self` is updated with the new sub-expression.
    """
    symbol = list(new_exprs)[0]
    if self.eval == '':
        value = symbol
    elif type(symbol) == UnappliedSymbol and symbol.decl:
        value = symbol.decl.symbol
    else:
        value = None

    if value is None:
        return self._change(sub_exprs=[symbol])
    self.check(type(value) != Variable,
            f"Variable `{value}` cannot be applied to argument(s).")
    return value
SymbolExpr.update_exprs = update_exprs


# Class Brackets  #######################################################

def update_exprs(self: Brackets,
                 new_exprs: Generator[Expression, None, None]
                 ) -> Expression:
    """Drop the brackets: return the first of the new sub-expressions."""
    inner = list(new_exprs)
    return inner[0]
Brackets.update_exprs = update_exprs


# set conditions  #######################################################

def join_set_conditions(assignments: List[Assignment]) -> List[Assignment]:
    """In a list of assignments, merge assignments that are set-conditions on the same term.

    An equality and a membership predicate (`in` operator) are both set-conditions.

    The list is modified in place while merging: the earlier of two merged
    assignments is replaced by the merged one, and the later one by a
    trivially true assignment that is filtered out of the result.

    Args:
        assignments (List[Assignment]): the list of assignments to make more compact

    Returns:
        List[Assignment]: the compacted list of assignments
    """
    for i, c in enumerate(assignments):
        (x, belongs, y) = c.as_set_condition()
        if x:
            # compare with every earlier assignment
            for j in range(i):
                (x1, belongs1, y1) = assignments[j].as_set_condition()
                if x1 and x.same_as(x1):
                    # NOTE(review): `^` is labelled "difference" below; on
                    # built-in sets it is the symmetric difference — confirm
                    # the semantics of the type of `tuples`.
                    if belongs and belongs1:
                        new_tuples = (y.tuples & y1.tuples)  # intersect
                    elif belongs and not belongs1:
                        new_tuples = (y.tuples ^ y1.tuples)  # difference
                    elif not belongs and belongs1:
                        belongs = belongs1
                        new_tuples = (y1.tuples ^ y.tuples)
                    else:
                        new_tuples = y.tuples | y1.tuples  # union
                    # sort again
                    new_tuples = list(new_tuples.values())

                    out = AppliedSymbol.make(
                        symbol=x.symbol, args=x.sub_exprs,
                        is_enumeration='in',
                        in_enumeration=Enumeration(tuples=new_tuples)
                    )

                    core = deepcopy(AppliedSymbol.make(out.symbol, out.sub_exprs))
                    out.as_disjunction = out.in_enumeration.contains([core], False)

                    out = Assignment(out,
                                     TRUE if belongs else FALSE,
                                     S.UNKNOWN)

                    assignments[j] = out  # keep the first one
                    assignments[i] = Assignment(TRUE, TRUE,
                                                S.UNKNOWN)
    return [c for c in assignments if c.sentence != TRUE]
# Module-level flag assigned as the last statement of this file, after all
# the definitions above.
# NOTE(review): presumably a marker that the module was fully loaded — confirm.
Done = True