# Source code for idp_engine.Interpret

# cython: binding=True

# Copyright 2019-2023 Ingmar Dasseville, Pierre Carbonnelle
#
# This file is part of IDP-Z3.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

"""

Methods to ground / interpret a theory in a data structure

* expand quantifiers
* replace symbols interpreted in the structure by their interpretation
* instantiate definitions

This module monkey-patches the ASTNode class and sub-classes.

( see docs/zettlr/Substitute.md )

"""
from __future__ import annotations

from copy import copy, deepcopy
from itertools import product
from typing import List, Callable, Optional, Tuple

from .Assignments import Status as S
from .Parse import (Import, TypeDeclaration, SymbolDeclaration,
                    SymbolInterpretation, FunctionEnum, Enumeration, TupleIDP,
                    ConstructedFrom, Definition, Rule)
from .Expression import (catch_error, RecDef, Symbol, SYMBOL, AIfExpr, IF,
                         SymbolExpr, Expression, Constructor, AQuantification,
                         Type, FORALL, IMPLIES, AND, AAggregate, EQUIV, EQUALS,
                         OR, AppliedSymbol, UnappliedSymbol, Quantee, Variable,
                         VARIABLE, TRUE, FALSE, Number, ZERO, Extension)
from .Theory import Theory
from .utils import (BOOL, INT, RESERVED_SYMBOLS, CONCEPT, OrderedSet, DEFAULT,
                    GOAL_SYMBOL, EXPAND, CO_CONSTR_RECURSION_DEPTH, Semantics)


# class Import  ###########################################################

@catch_error
def interpret(self: Import, problem: Theory):
    """An `Import` statement needs no grounding: this is a no-op.

    Args:
        problem (Theory): the theory being grounded (unused)
    """
    return None
Import.interpret = interpret


# class TypeDeclaration  ###########################################################

@catch_error
def interpret(self: TypeDeclaration, problem: Theory):
    """Ground a type declaration and record its extension.

    For the built-in types `BOOL` and `CONCEPT`, the extension is the list
    of values in the ranges of the type's constructors.

    For any other type, an interpretation with an `enumeration` must exist
    in `problem.interpretations`. It is stored in `self.interpretation`.
    The interpreted enumeration provides `self.constructors` and computes
    the extension.

    Args:
        problem (Theory): the theory being grounded;
            `problem.extensions[self.name]` is set
    """
    interpretation = problem.interpretations.get(self.name, None)

    if self.name not in [BOOL, CONCEPT]:
        self.check(interpretation is not None
                   and hasattr(interpretation, 'enumeration'),
                   f'Expected an interpretation for type {self.name}')

        interpreted_enum = interpretation.enumeration.interpret(problem)
        self.interpretation = interpretation
        self.constructors = interpreted_enum.constructors
        self.translate(problem)

        # Constructor.interpret sets the `range` of each constructor
        has_constructors = self.constructors is not None
        if has_constructors:
            for constructor in self.constructors:
                constructor.interpret(problem)

        problem.extensions[self.name] = interpreted_enum.extensionE(
            problem.interpretations, problem.extensions)

        # needed ?
        # if (isinstance(self.interpretation.enumeration, Ranges)
        # and self.interpretation.enumeration.tuples):
        #     # add condition that the interpretation is total over the infinite domain
        #     # ! x in N: type(x) <=> enum.contains(x)
        #     t = TYPE(self.type)  # INT, REAL or DATE
        #     t.decl, t.type = self, self.type
        #     var = VARIABLE(f"${self.name}!0$",t)
        #     q_vars = { f"${self.name}!0$": var}
        #     quantees = [Quantee.make(var, subtype=t)]
        #     expr1 = AppliedSymbol.make(SYMBOL(self.name), [var])
        #     expr1.decl = self
        #     expr2 = enum.contains(list(q_vars.values()), True)
        #     expr = EQUALS([expr1, expr2])
        #     constraint = FORALL(quantees, expr)
        #     constraint.annotations['reading'] = f"Enumeration of {self.name} should cover its domain"
        #     problem.constraints.append(constraint)
    else:
        # built-in type: the extension comes from the constructors' ranges
        self.translate(problem)
        constructor_ranges = [constructor.interpret(problem).range
                              for constructor in self.constructors]
        values = [[value] for rng in constructor_ranges for value in rng]
        problem.extensions[self.name] = (values, None)
TypeDeclaration.interpret = interpret


# class SymbolDeclaration  ###########################################################

@catch_error
def interpret(self: SymbolDeclaration, problem: Theory):
    """Ground a symbol declaration in `problem`.

    * for a predicate, register `(superset, filter)` as its extension
      in `problem.extensions`;
    * set `self.range` from the extension of the output type;
    * when the domain can be enumerated and the symbol is not reserved,
      fill `self.instances` with one applied symbol per argument tuple,
      each given an UNKNOWN assignment in `problem.assignments`;
    * interpret the symbol's enumeration in the structure, if any
      (except for the goal symbol);
    * for non-Boolean symbols with instances, append type constraints
      ("Possible values for ...") to `problem.constraints`.

    Args:
        problem (Theory): the theory being grounded; updated in place
    """
    assert all(isinstance(s, Type) for s in self.sorts), 'internal error'

    symbol = SYMBOL(self.name)
    symbol.decl = self
    symbol.type = symbol.decl.type

    # (superset, filter) pair for each argument sort
    sort_extensions = [sort.extension(problem.interpretations, problem.extensions)
                       for sort in self.sorts]
    if all(ext[0] is not None for ext in sort_extensions):
        superset = list(product(*([member[0] for member in ext[0]]
                                  for ext in sort_extensions)))
    else:
        superset = None

    sort_filters = [ext[1] for ext in sort_extensions]

    def in_domain(args):
        # conjunction of the sort filters applied to `args`;
        # for a predicate, the predicate must also hold for `args`
        conjuncts = [TRUE if sort_filter is None else sort_filter([deepcopy(arg)])
                     for sort_filter, arg in zip(sort_filters, args)]
        result = AND(conjuncts)
        if self.out.decl.name == BOOL:
            result = AND([result, deepcopy(AppliedSymbol.make(symbol, args))])
        return result

    if self.out.decl.name == BOOL:
        problem.extensions[self.name] = (superset, in_domain)

    out_superset, _ = self.out.extension(problem.interpretations, problem.extensions)
    self.range = [] if out_superset is None else [member[0] for member in out_superset]

    # one instance per argument tuple, with an empty (UNKNOWN) assignment
    if superset is not None and self.name not in RESERVED_SYMBOLS:
        self.instances = {}
        for args in superset:
            applied = AppliedSymbol.make(symbol, args)
            self.instances[applied.code] = applied
            problem.assignments.assert__(applied, None, S.UNKNOWN)

    # interpret the enumeration given in the structure, if any
    if self.name != GOAL_SYMBOL and self.name in problem.interpretations:
        problem.interpretations[self.name].interpret(problem)

    # type constraints:  ! (x,y) in domain: range(f(x,y))
    if type(self.instances) != dict or self.out.decl.name == BOOL:
        return
    for applied in self.instances.values():
        in_range = self.out.has_element(deepcopy(applied),
                                        problem.interpretations, problem.extensions)
        if in_range.same_as(TRUE):
            # stops at the first trivially-true range condition;
            # presumably it is then trivially true for every instance
            break
        in_range = in_range.interpret(problem, {})
        constraint = IMPLIES([in_domain(applied.sub_exprs), in_range])
        constraint.is_type_constraint_for = self.name
        constraint.annotations['reading'] = f"Possible values for {applied}"
        problem.constraints.append(constraint)
SymbolDeclaration.interpret = interpret


# class Definition  ###########################################################

@catch_error
def interpret(self: Definition, problem: Theory):
    """Expand the definition and add the resulting constraints to
    `problem.def_constraints`.

    Args:
        problem (Theory):
            contains the enumerations used for the expansion;
            is updated with the expanded definitions
    """
    self.cache = {}  # start from an empty cache
    expanded = self.get_def_constraints(problem)
    problem.def_constraints.update(expanded)
Definition.interpret = interpret


# class SymbolInterpretation  ###########################################################

@catch_error
def interpret(self: SymbolInterpretation, problem: Theory):
    """Apply the enumeration of a symbol (given in a structure) to `problem`.

    Does nothing for the goal symbol and for EXPAND. Otherwise:

    * for a predicate, records the enumerated tuples as its extension in
      `problem.extensions`;
    * checks the arity of the enumerated tuples;
    * builds `self.enumeration.lookup`, mapping comma-joined arguments to
      their value (pre-filled with the default value for every instance of
      the declaration, when a default is given);
    * for each tuple, checks that the value is in the range and that the
      arguments are in the domain. Conditions that are neither trivially
      true nor trivially false are appended to `problem.constraints`.
      It also checks for duplicate entries and asserts the value in
      `problem.assignments`;
    * fills in the default value for the remaining instances, or, for a
      `≜` interpretation without default, checks that the enumeration
      covers the whole domain. The check is done immediately when the
      domain is enumerable; otherwise a constraint is added.

    Assignments get status DEFAULT when the enclosing block is named
    DEFAULT, and STRUCTURE otherwise.

    Args:
        problem (Theory): the theory being grounded; its extensions,
            assignments and constraints are updated in place
    """
    status = S.DEFAULT if self.block.name == DEFAULT else S.STRUCTURE
    assert not self.is_type_enumeration, "Internal error"
    if not self.name in [GOAL_SYMBOL, EXPAND]:
        decl = problem.declarations[self.name]
        assert isinstance(decl, SymbolDeclaration), "Internal error"
        # update problem.extensions
        if self.symbol.decl.out.decl.name == BOOL:  # predicate
            extension = [t.args for t in self.enumeration.tuples]
            problem.extensions[self.symbol.name] = (extension, None)

        enumeration = self.enumeration  # shorthand
        # a function enumeration has one extra element per tuple: the value
        self.check(all(len(t.args) == self.symbol.decl.arity
                            + (1 if type(enumeration) == FunctionEnum else 0)
                        for t in enumeration.tuples),
            f"Incorrect arity of tuples in Enumeration of {self.symbol}.  Please check use of ',' and ';'.")

        # lookup table: arguments (comma-joined) -> value;
        # the default value (if any) is entered first, so that explicit
        # tuples override it
        lookup = {}
        if hasattr(decl, 'instances') and decl.instances and self.default:
            lookup = { ",".join(str(a) for a in applied.sub_exprs): self.default
                    for applied in decl.instances.values()}
        if type(enumeration) == FunctionEnum:
            lookup.update( (','.join(str(a) for a in t.args[:-1]), t.args[-1])
                        for t in enumeration.sorted_tuples)
        else:
            # NOTE(review): assumes t.code uses the same comma-joined format
            # as the keys above — confirm in TupleIDP
            lookup.update( (t.code, TRUE)
                            for t in enumeration.sorted_tuples)
        enumeration.lookup = lookup

        # update problem.assignments with data from enumeration
        for t in enumeration.tuples:

            # check that the values are in the range
            if type(self.enumeration) == FunctionEnum:
                args, value = t.args[:-1], t.args[-1]
                condition = decl.has_in_range(value,
                            problem.interpretations, problem.extensions)
                self.check(not condition.same_as(FALSE),
                           f"{value} is not in the range of {self.symbol.name}")
                # a condition that cannot be decided now is left to the solver
                if not condition.same_as(TRUE):
                    problem.constraints.append(condition)
            else:
                args, value = t.args, TRUE

            # check that the arguments are in the domain
            # (`a` is the textual form of the arguments, for error messages)
            a = (str(args) if 1<len(args) else
                    str(args[0]) if len(args)==1 else
                    "()")
            self.check(len(args) == decl.arity,
                       f"Incorrect arity of {a} for {self.name}")
            condition = decl.has_in_domain(args,
                            problem.interpretations, problem.extensions)
            self.check(not condition.same_as(FALSE),
                       f"{a} is not in the domain of {self.symbol.name}")
            if not condition.same_as(TRUE):
                problem.constraints.append(condition)

            # check duplicates: only UNKNOWN assignments may be overwritten
            expr = AppliedSymbol.make(self.symbol, args)
            self.check(expr.code not in problem.assignments
                or problem.assignments[expr.code].status == S.UNKNOWN,
                f"Duplicate entry in structure for '{self.name}': {str(expr)}")

            # add to problem.assignments
            e = problem.assignments.assert__(expr, value, status)
            if (status == S.DEFAULT  # for proper display in IC
                and type(self.enumeration) == FunctionEnum):
                problem.assignments.assert__(e.formula(), TRUE, status)

        if self.default is not None:
            if decl.instances is not None:
                # fill the default value in problem.assignments
                # (only for instances not yet assigned with this status)
                for code, expr in decl.instances.items():
                    if (code not in problem.assignments
                        or problem.assignments[code].status != status):
                        e = problem.assignments.assert__(expr, self.default, status)
                        if (status == S.DEFAULT  # for proper display in IC
                            and type(self.enumeration) == FunctionEnum
                            and self.default.type != BOOL):
                            problem.assignments.assert__(e.formula(), TRUE, status)

        elif self.sign == '≜':
            # add condition that the interpretation is total
            # over the domain specified by the type signature
            # ! x in domain(f): enum.contains(x)
            q_vars = { f"${sort.decl.name}!{str(i)}$":
                       VARIABLE(f"${sort.decl.name}!{str(i)}$", sort)
                       for i, sort in enumerate(decl.sorts)}
            quantees = [Quantee.make(v, sort=v.sort) for v in q_vars.values()]

            # is the domain of `self` enumerable ?
            # (get_supersets adds a filter to the FALSE body when it is not)
            constraint1 = FORALL(quantees, FALSE)
            get_supersets(constraint1, problem)
            if constraint1.sub_exprs[0] == FALSE:  # no filter added
                # the domain is enumerable => do the check immediately
                domain = set(str(flatten(d)) for d in product(*constraint1.supersets))
                # NOTE: `enumeration` is rebound here to a set of strings,
                # shadowing the shorthand for self.enumeration defined above
                if type(self.enumeration) == FunctionEnum:
                    enumeration = set(str(d.args[:-1]) for d in self.enumeration.tuples)
                else:
                    enumeration = set(str(d.args) for d in self.enumeration.tuples)
                self.check(domain == enumeration, f"Enumeration of {self.name} should cover its domain")
            else:  # add a constraint to the problem, to be solved by Z3
                # test case: tests/1240 FO{Core, Sugar, Int, PF)/LivingBeing.idp
                expr = self.enumeration.contains(list(q_vars.values()), True)
                constraint = FORALL(quantees, expr).interpret(problem, {})
                constraint.annotations['reading'] = f"Enumeration of {self.name} should cover its domain"
                problem.constraints.append(constraint)
SymbolInterpretation.interpret = interpret


# class Enumeration  ###########################################################

@catch_error
def interpret(self: Enumeration, problem: Theory) -> Enumeration:
    """An enumeration needs no grounding.

    Args:
        problem (Theory): the theory being grounded (unused)

    Returns:
        Enumeration: `self`, unchanged
    """
    interpreted = self
    return interpreted
Enumeration.interpret = interpret


# class ConstructedFrom  ###########################################################

@catch_error
def interpret(self: ConstructedFrom, problem: Theory) -> ConstructedFrom:
    """Interpret each constructor and collect its values in `self.tuples`.

    `self.tuples` becomes None as soon as one constructor has no enumerable
    range (`range is None`).

    Args:
        problem (Theory): the theory being grounded

    Returns:
        ConstructedFrom: `self`
    """
    collected = OrderedSet()
    self.tuples = collected
    for constructor in self.constructors:
        constructor.interpret(problem)
        if constructor.range is None:
            self.tuples = None
            break
        collected.extend(TupleIDP(args=[value]) for value in constructor.range)
    return self
ConstructedFrom.interpret = interpret


# class Constructor  ###########################################################

@catch_error
def interpret(self: Constructor, problem: Theory) -> Constructor:
    """Compute `self.range`, the list of values built by this constructor.

    * nullary constructor: a single value;
    * recursive data type, or an argument type without enumerable
      extension: `self.range` is None;
    * otherwise: one value per combination of argument values.

    Args:
        problem (Theory): provides the interpretations and extensions

    Returns:
        Constructor: `self`
    """
    # assert all(s.decl and isinstance(s.decl.out, Type) for s in self.sorts), 'Internal error'
    if not self.sorts:
        self.range = [UnappliedSymbol.construct(self)]
        return self

    if any(s.type == self.type for s in self.sorts):  # recursive data type
        self.range = None
        return self

    # assert all(isinstance(s.decl, SymbolDeclaration) for s in self.sorts), "Internal error"
    arg_extensions = [s.decl.out.extension(problem.interpretations, problem.extensions)
                      for s in self.sorts]
    if any(ext[0] is None for ext in arg_extensions):
        self.range = None
        return self

    self.check(all(ext[1] is None for ext in arg_extensions),  # no filter in the extension
               f"Type signature of constructor {self.name} must have a given interpretation")
    value_lists = [[member[0] for member in ext[0]] for ext in arg_extensions]
    self.range = [AppliedSymbol.construct(self, args)
                  for args in product(*value_lists)]
    return self
Constructor.interpret = interpret


# class Expression  ###########################################################

def interpret(self: Expression,
              problem: Optional[Theory],
              subs: dict[str, Expression]
              ) -> Expression:
    """expand quantifiers and replace symbols interpreted in the structure
    by their interpretation

    Args:
        self: the expression to be interpreted
        problem: the theory to be applied
        subs: a dictionary mapping variable names to their value

    Returns:
        Expression: the interpreted expression
    """
    # type constraints are returned unchanged
    # (SymbolDeclaration.interpret builds them from already-interpreted parts)
    if self.is_type_constraint_for:
        return self
    _prepare_interpret(self, problem, subs)
    return self._interpret(problem, subs)
Expression.interpret = interpret


def _prepare_interpret(self: Expression,
                       problem: Optional[Theory],
                       subs: dict[str, Expression]
                       ):
    """Prepare the interpretation by transforming quantifications and aggregates

    Sub-expressions are prepared first (recursively). Then, for a
    quantification or aggregate:

    * the type of each quantee without an explicit domain is inferred from
      the quantified formula;
    * the supersets of the quantified variables are computed
      (see `get_supersets`).

    Args:
        problem: the theory providing interpretations and extensions, if any
        subs: a dictionary mapping variable names to their value
            (only passed down the recursion)
    """
    for e in self.sub_exprs:
        _prepare_interpret(e, problem, subs)

    if isinstance(self, (AQuantification, AAggregate)):
        # type inference
        if 0 < len(self.sub_exprs):  # in case it was simplified away
            inferred = self.sub_exprs[0].type_inference()
            for q in self.quantees:
                if not q.sub_exprs:
                    assert len(q.vars) == 1 and q.arity == 1, \
                        f"Internal error: interpret {q}"
                    var = q.vars[0][0]
                    self.check(var.name in inferred,
                               f"can't infer type of {var.name}")
                    var.sort = inferred[var.name]
                    q.sub_exprs = [inferred[var.name]]
        get_supersets(self, problem)


def clone_when_necessary(func):
    """Decorator for the `_interpret` methods.

    A value is returned unchanged. When `subs` is non-empty, `func` is
    applied to a shallow copy of the expression (with its own copy of
    `annotations`). Attributes reassigned by `func` therefore do not affect
    the original, though sub-objects are still shared.
    """
    @catch_error
    def inner_function(self, problem, subs):
        if self.is_value():
            return self
        if subs:
            self = copy(self)  # shallow copy !
            self.annotations = copy(self.annotations)
        out = func(self, problem, subs)
        return out
    return inner_function


@clone_when_necessary
def _interpret(self: Expression,
               problem: Optional[Theory],
               subs: dict[str, Expression]
               ) -> Expression:
    """ uses information in the problem and its vocabulary to:

    - expand quantifiers in the expression
    - simplify the expression using known assignments and enumerations
    - instantiate definitions

    This method creates a copy when necessary.

    Args:
        problem (Theory): the Theory to apply

        subs: a dictionary holding the value of the free variables of self

    Returns:
        Expression: the resulting expression
    """
    out = self.update_exprs(e._interpret(problem, subs) for e in self.sub_exprs)
    _finalize(out, subs)
    return out
Expression._interpret = _interpret


@catch_error
def _finalize(self: Expression, subs: dict[str, Expression]):
    """update self.code and the 'reading' annotation after substitution

    Nothing is changed when `subs` is empty.

    Returns:
        Expression: `self`
    """
    if subs:
        self.code = str(self)
        self.annotations['reading'] = self.code
    return self


# class Type  ###########################################################

@catch_error
def extension(self, interpretations: dict[str, SymbolInterpretation],
              extensions: dict[str, Extension]
              ) -> Extension:
    """returns the extension of a Type, given some interpretations.

    Normally, the extension is already in `extensions`.
    However, for Concept[T->T], an additional filtering is applied.

    Args:
        interpretations (dict[str, SymbolInterpretation]):
            the known interpretations of types and symbols
        extensions (dict[str, Extension]):
            the known extensions, by type code; the extension of a
            Concept[T->T] type is computed and cached here when missing

    Returns:
        Extension: a superset of the extension of self,
        and a function that, given arguments, returns an Expression that says
        whether the arguments are in the extension of self
    """
    if self.code not in extensions:
        self.check(self.name == CONCEPT, "internal error")
        assert (self.out and extensions is not None
                and extensions[CONCEPT] is not None), "internal error"  # Concept[T->T]
        ext = extensions[CONCEPT][0]
        assert isinstance(ext, list), "Internal error"
        # keep the concepts whose symbol has signature `self.ins -> self.out`;
        # the isinstance test comes first so that `.arity`, `.out` and `.sorts`
        # are only read on symbol declarations
        out = [v for v in ext
               if type(v[0]) == UnappliedSymbol
               and isinstance(v[0].decl.symbol.decl, SymbolDeclaration)
               and v[0].decl.symbol.decl.arity == len(self.ins)
               and v[0].decl.symbol.decl.out == self.out
               and len(v[0].decl.symbol.decl.sorts) == len(self.ins)
               and all(s == q for s, q in zip(v[0].decl.symbol.decl.sorts,
                                              self.ins))]
        extensions[self.code] = (out, None)
    return extensions[self.code]
Type.extension = extension


# Class AQuantification  ######################################################
def get_supersets(self: AQuantification | AAggregate, problem: Optional[Theory]):
    """determine the extent of the variables, if possible,
    and add a filter to the quantified expression if needed.
    This is used to ground quantification over unary predicates.

    Example:
        type T := {1,2,3}
        p : T -> Bool  // p is a subset of T
        !x in p: q(x)

        The formula is equivalent to `!x in T: p(x) => q(x).`
        -> The superset of `p` is `{1,2,3}`, the filter is `p(x)`.
        The grounding is `(p(1)=>q(1)) & (p(2)=>q(2)) & (p(3)=>q(3))`

        If p is enumerated (`p:={1,2}`) in a structure, however,
        the superset is now {1,2} and there is no need for a filter.
        The grounding is `q(1) & q(2)`

    Result:
        `self.supersets` is updated to contain the supersets
        `self.sub_exprs` are updated with the appropriate filters
        `self.vars1` lists the quantified variables, in order
        `self.new_quantees` lists the quantees that could not be expanded

    The function returns early, leaving these lists partially filled, when
    a domain is an (unresolved) symbol expression `$(...)`.
    """
    self.new_quantees, self.vars1, self.supersets = [], [], []
    for q in self.quantees:
        domain = q.sub_exprs[0]

        if problem:
            if isinstance(domain, Type):  # quantification over type / Concepts
                (superset, filter_fn) = domain.extension(problem.interpretations,
                                                         problem.extensions)
            elif type(domain) == SymbolExpr:
                return
            elif type(domain) == Symbol and domain.decl:
                self.check(domain.decl.out.type == BOOL,
                           f"{domain} is not a type or predicate")
                assert domain.decl.name in problem.extensions, "internal error"
                (superset, filter_fn) = problem.extensions[domain.decl.name]
            else:
                self.check(False,
                           f"Can't resolve the domain of {str(q.vars)}")
        else:
            (superset, filter_fn) = None, None

        assert hasattr(domain, "decl"), "Internal error"
        arity = domain.decl.arity
        for var_tuple in q.vars:
            self.check(len(var_tuple) == arity, f"Incorrect arity for {domain}")
            if problem and filter_fn:
                self.sub_exprs = [_add_filter(self.q, f, filter_fn, var_tuple, problem)
                                  for f in self.sub_exprs]

        self.vars1.extend(flatten(q.vars))

        if superset is None:
            self.new_quantees.append(q)
            # replace the variable by itself
            self.supersets.extend([var_tuple] for var_tuple in q.vars)
        else:
            self.supersets.extend([superset] * len(q.vars))
def _add_filter(q: str, expr: Expression, filter: Callable, args: List[Variable],
                theory: Theory) -> Expression:
    """add `filter(args)` to `expr` quantified by `q`

    Example: `_add_filter('∀', TRUE, filter, [1], theory)` returns `filter([1]) => TRUE`

    Args:
        q: the type of quantification ('∀', '∃', or an aggregate)
        expr: the quantified expression
        filter: a function that returns an Expression for some arguments
        args: the arguments to be applied to filter
        theory: unused

    Returns:
        Expression: `expr` extended with appropriate filter
    """
    applied = filter(args)
    if q == '∀':
        out = IMPLIES([applied, expr])
    elif q == '∃':
        out = AND([applied, expr])
    else:  # aggregate
        if isinstance(expr, AIfExpr):  # cardinality
            # if a then b else 0 -> if (applied & a) then b else 0
            arg1 = AND([applied, expr.sub_exprs[0]])
            out = IF(arg1, expr.sub_exprs[1], expr.sub_exprs[2])
        else:  # sum
            out = IF(applied, expr, Number(number="0"))
    return out


def flatten(a):
    """Return a new list with the elements of the sub-lists of `a`, in order."""
    # https://stackoverflow.com/questions/952914/how-do-i-make-a-flat-list-out-of-a-list-of-lists
    return [item for sublist in a for item in sublist]


@clone_when_necessary
def _interpret(self: AQuantification | AAggregate,
               problem: Optional[Theory],
               subs: dict[str, Expression]
               ) -> Expression:
    """apply information in the problem and its vocabulary

    The formula is expanded over the cross-product of the supersets of the
    quantified variables, with each combination substituted per `subs`.

    Args:
        problem (Theory): the problem to be applied
        subs: a dictionary holding the value of the free variables of self

    Returns:
        Expression: the expanded quantifier expression
    """
    # This method is called by AAggregate._interpret !

    if not self.quantees and not subs:  # already expanded
        return Expression._interpret(self, problem, subs)

    if not self.supersets:
        # interpret quantees
        for q in self.quantees:  # for !x in $(output_domain(s,1))
            q.sub_exprs = [e._interpret(problem, subs) for e in q.sub_exprs]
        get_supersets(self, problem)

    assert self.new_quantees is not None and self.vars1 is not None, "Internal error"
    self.quantees = self.new_quantees
    # expand the formula by the cross-product of the supersets, and substitute per `subs`
    forms, subs1 = [], copy(subs)
    for f in self.sub_exprs:
        for vals in product(*self.supersets):
            vals1 = flatten(vals)
            subs1.update((var.code, val) for var, val in zip(self.vars1, vals1))
            new_f2 = f._interpret(problem, subs1)
            forms.append(new_f2)

    out = self.update_exprs(f for f in forms)
    return out
AQuantification._interpret = _interpret


# Class AAggregate  ######################################################

@clone_when_necessary
def _interpret(self: AAggregate,
               problem: Optional[Theory],
               subs: dict[str, Expression]
               ) -> Expression:
    """Interpret an aggregate, using the expansion of AQuantification."""
    assert self.annotated, "Internal error in interpret"
    return AQuantification._interpret(self, problem, subs)
AAggregate._interpret = _interpret


# Class AppliedSymbol  ##############################################

@clone_when_necessary
def _interpret(self: AppliedSymbol,
               problem: Optional[Theory],
               subs: dict[str, Expression]
               ) -> Expression:
    """Interpret an applied symbol.

    The symbol expression `$(x)` (if any) is interpreted first, then the
    arguments. When the result is not a value and a problem is given:

    * `is enumerated` / `in {...}` are turned into `as_disjunction`, and
      become the value when trivially true or false;
    * otherwise, a non-default interpretation in the structure may give the
      value;
    * outside of definition heads, the instantiated definitions become the
      co-constraint.

    Returns:
        Expression: the value, if known, else the updated applied symbol
    """
    # interpret the symbol expression, if any
    if type(self.symbol) == SymbolExpr and self.symbol.is_intentional():  # $(x)()
        self.symbol = self.symbol._interpret(problem, subs)
        if type(self.symbol) == Symbol:  # found $(x)
            self.check(len(self.sub_exprs) == len(self.symbol.decl.sorts),
                       f"Incorrect arity for {self.code}")
            kwargs = ({'is_enumerated': self.is_enumerated} if self.is_enumerated else
                      {'in_enumeration': self.in_enumeration} if self.in_enumeration else
                      {})
            out = AppliedSymbol.make(self.symbol, self.sub_exprs, **kwargs)
            out.original = self
            self = out

    # interpret the arguments
    sub_exprs = [e._interpret(problem, subs) for e in self.sub_exprs]
    out = self.update_exprs(e for e in sub_exprs)
    _finalize(out, subs)
    if out.is_value():
        return out

    # interpret the AppliedSymbol
    value, co_constraint = None, None
    if out.decl and problem:
        if out.is_enumerated:
            assert out.decl.type != BOOL, \
                f"Can't use 'is enumerated' with predicate {out.decl.name}."
            if out.decl.name in problem.interpretations:
                interpretation = problem.interpretations[out.decl.name]
                if interpretation.default is not None:
                    out.as_disjunction = TRUE
                else:
                    out.as_disjunction = interpretation.enumeration.contains(
                        sub_exprs, True,
                        interpretations=problem.interpretations,
                        extensions=problem.extensions)
                if (out.as_disjunction.same_as(TRUE)
                        or out.as_disjunction.same_as(FALSE)):
                    value = out.as_disjunction
                out.as_disjunction.annotations = out.annotations
        elif out.in_enumeration:
            # re-create original Applied Symbol
            core = deepcopy(AppliedSymbol.make(out.symbol, sub_exprs))
            out.as_disjunction = out.in_enumeration.contains(
                [core], False,
                interpretations=problem.interpretations,
                extensions=problem.extensions)
            if (out.as_disjunction.same_as(TRUE)
                    or out.as_disjunction.same_as(FALSE)):
                value = out.as_disjunction
            out.as_disjunction.annotations = out.annotations
        elif out.decl.name in problem.interpretations:
            interpretation = problem.interpretations[out.decl.name]
            if interpretation.block.name != DEFAULT:
                f = interpretation.interpret_application
                value = f(0, out, sub_exprs)

        if not out.in_head:
            # instantiate definition (for relevance)
            inst = [defin.instantiate_definition(out.decl, sub_exprs, problem)
                    for defin in problem.definitions]
            inst = [x for x in inst if x]
            if inst:
                co_constraint = AND(inst)
        elif self.co_constraint:
            co_constraint = self.co_constraint.interpret(problem, subs)

    out = (value if value else
           out._change(sub_exprs=sub_exprs, co_constraint=co_constraint))
    return out
AppliedSymbol._interpret = _interpret


# Class Variable  #######################################################

@catch_error
def _interpret(self: Variable,
               problem: Optional[Theory],
               subs: dict[str, Expression]
               ) -> Expression:
    """Interpret the sort of the variable (if any), then return its value
    in `subs`, or the variable itself when it is not substituted."""
    # NOTE(review): `self.sort` is reassigned on `self` itself (no copy is made
    # here, unlike in `clone_when_necessary`) — confirm this is intended
    if self.sort:
        self.sort = self.sort._interpret(problem, subs)
    out = subs.get(self.code, self)
    return out
Variable._interpret = _interpret

Done = True