diff --git a/src/protosym/core/differentiate.py b/src/protosym/core/differentiate.py index 2d6dc72..cec95a5 100644 --- a/src/protosym/core/differentiate.py +++ b/src/protosym/core/differentiate.py @@ -8,9 +8,11 @@ from dataclasses import dataclass from dataclasses import field +from functools import reduce from typing import TYPE_CHECKING as _TYPE_CHECKING from protosym.core.tree import forward_graph +from protosym.core.tree import Tr __all__ = [ @@ -20,11 +22,239 @@ if _TYPE_CHECKING: + from typing import Callable, Sequence + from protosym.core.atom import AtomType from protosym.core.tree import Tree, SubsFunc _DiffRules = dict[tuple[Tree, int], SubsFunc] +@dataclass(frozen=True) +class RingOps: + """Collection of ring operations.""" + + Integer: AtomType[int] + iadd: Callable[[int, int], int] + imul: Callable[[int, int], int] + add: Tree + mul: Tree + pow: Tree + + def split_integers(self, args: Sequence[Tree]) -> tuple[list[int], list[Tree]]: + integers: list[int] = [] + new_args: list[Tree] = [] + for arg in args: + if not arg.children and (atom := arg.value).atom_type == self.Integer: + integers.append(atom.value) # type: ignore + else: + new_args.append(arg) + return integers, new_args + + def flatten_add(self, args: list[Tree]) -> Tree: + integers: list[int] = [] + new_args: list[Tree] = [] + + # Associativity of add + # (x + y) + z -> x + y + z + for arg in args: + if arg.children and arg.children[0] == self.add: + new_args.extend(arg.children[1:]) + else: + new_args.append(arg) + + # Collect integer part + # x + 1 + 2 -> 3 + x + new_args2: list[Tree] = [] + for arg in new_args: + if not arg.children and (atom := arg.value).atom_type == self.Integer: + integers.append(atom.value) # type:ignore + else: + new_args2.append(arg) + + # Process all muls, extracting their coefficients + # 2*x + 3*x -> 5*x + totals = {} + for arg in new_args2: + if arg.children and arg.children[0] == self.mul: + intfacs, factors = self.split_integers(arg.children[1:]) + if len(factors) == 1: + [fac] = factors + else: + fac = self.mul(*factors) + integer = reduce(self.imul, intfacs, 1) + if fac not in totals: + totals[fac] = 0 + totals[fac] += integer + else: + if arg not in totals: + totals[arg] = 0 + totals[arg] += 1 + + new_args3: list[Tree] = [] + for fac, c in totals.items(): + if c == 0: + continue + elif c == 1: + new_args3.append(fac) + elif fac.children and fac.children[0] == self.mul: + new_args3.append(self.mul(Tr(self.Integer(c)), *fac.children[1:])) + else: + new_args3.append(self.mul(Tr(self.Integer(c)), fac)) + + int_value = reduce(self.iadd, integers, 0) + + if int_value: + new_args3.insert(0, Tr(self.Integer(int_value))) + + if not new_args3: + expr = Tr(self.Integer(0)) + elif len(new_args3) == 1: + [expr] = new_args3 + else: + expr = self.add(*new_args3) + + return expr + + def flatten_mul(self, args: list[Tree]) -> Tree: + integers: list[int] = [] + new_args: list[Tree] = [] + + for arg in args: + if arg.children and arg.children[0] == self.mul: + new_args.extend(arg.children[1:]) + else: + new_args.append(arg) + + new_args2: list[Tree] = [] + for arg in new_args: + if not arg.children and (atom := arg.value).atom_type == self.Integer: + integers.append(atom.value) # type:ignore + else: + new_args2.append(arg) + + powers = {} + for arg in new_args2: + if arg.children and arg.children[0] == self.pow: + base, s_exp = arg.children[1:] + exp: int + if s_exp.value.atom_type == self.Integer: + exp = s_exp.value.value # type: ignore + else: + base, exp = arg, 1 + else: + base, exp = arg, 1 + + if base not in powers: + powers[base] = 0 + powers[base] += exp + + new_args3: list[Tree] = [] + for base, exp in powers.items(): + if exp == 0: + continue + elif exp == 1: + new_args3.append(base) + else: + new_args3.append(self.pow(base, Tr(self.Integer(exp)))) + + int_value = reduce(self.imul, integers, 1) + + if int_value == 0: + return Tr(self.Integer(int_value)) + elif int_value != 1: + new_args3.insert(0, Tr(self.Integer(int_value))) + + if not new_args3: + expr = Tr(self.Integer(1)) + elif len(new_args3) == 1: + [expr] = new_args3 + else: + expr = self.mul(*new_args3) + + return expr + + def flatten_pow(self, args: list[Tree]) -> Tree: + base, exponent = args + + # (x**y)**a -> x**(a*y) for integer a + if base.children and base.children[0] == self.pow: + base_base, base_exp = base.children[1:] + if not exponent.children and exponent.value.atom_type == self.Integer: + exponent = self.flatten_mul([base_exp, exponent]) + base = base_base + + if exponent == Tr(self.Integer(0)): + expr = Tr(self.Integer(1)) + elif exponent == Tr(self.Integer(1)): + expr = base + else: + expr = self.pow(base, exponent) + return expr + + def flatten(self, expr: Tree) -> Tree: + """Apply the standard ring simplification rules. + + Identity (addition): :math:`x + 0 = x` + Identity (multiplication): :math:`x * 1 = x` + Associativity (addition): :math:`(x + y) + z = x + (y + z)` + Associativity (multiplication): :math:`(x * y) * z = x * (y * z)` + Commutativity (addition): :math:`x + y = y + x` + Commutativity (multiplication): :math:`x * y = y * x` + Add to Mul: :math:`2*x + 3*x = 5*x` + Mul to Pow: :math:`x^2 * x^3 = x^5` + """ + graph = forward_graph(expr) + stack = list(graph.atoms) + for func, indices in graph.operations: + + args = [stack[i] for i in indices] + + if func == self.add: + expr = self.flatten_add(args) + elif func == self.mul: + expr = self.flatten_mul(args) + elif func == self.pow and len(args) == 2: + expr = self.flatten_pow(args) + else: + expr = func(*args) + + stack.append(expr) + + return stack[-1] + + def flatten(self, expr: Tree) -> Tree: + """Apply common ring simplification rules. + + Identity (addition): :math:`x + 0 = x` + Identity (multiplication): :math:`x * 1 = x` + Associativity (addition): :math:`(x + y) + z = x + (y + z)` + Associativity (multiplication): :math:`(x * y) * z = x * (y * z)` + Commutativity (addition): :math:`x + y = y + x` + Commutativity (multiplication): :math:`x * y = y * x` + Add to Mul: :math:`2*x + 3*x = 5*x` + Mul to Pow: :math:`x^2 * x^3 = x^5` + ... + """ + graph = forward_graph(expr) + stack = list(graph.atoms) + for func, indices in graph.operations: + + args = [stack[i] for i in indices] + + if func == self.add: + expr = self.flatten_add(args) + elif func == self.mul: + expr = self.flatten_mul(args) + elif func == self.pow and len(args) == 2: + expr = self.flatten_pow(args) + else: + expr = func(*args) + + stack.append(expr) + + return stack[-1] + + @dataclass(frozen=True) class DiffProperties: """Collection of properties needed for differentiation.""" diff --git a/src/protosym/core/sym.py b/src/protosym/core/sym.py index c10e546..4a34700 100644 --- a/src/protosym/core/sym.py +++ b/src/protosym/core/sym.py @@ -22,6 +22,7 @@ from protosym.core.atom import AtomType from protosym.core.differentiate import diff_forward from protosym.core.differentiate import DiffProperties +from protosym.core.differentiate import RingOps from protosym.core.evaluate import Evaluator from protosym.core.exceptions import BadRuleError from protosym.core.tree import SubsFunc @@ -617,6 +618,26 @@ def __call__( return self.evaluator(expr.rep, values_rep) +class SymRingOps(Generic[T_sym]): + """Representation of ring operations.""" + + def __init__( + self, + new_sym: Type[T_sym], + integer: SymAtomType[T_sym, int], + iadd: Callable[[int, int], int], + imul: Callable[[int, int], int], + add: Sym, + mul: Sym, + pow: Sym, + ): + self.new_sym = new_sym + self.ringops = RingOps(integer.atom_type, iadd, imul, add.rep, mul.rep, pow.rep) + + def __call__(self, expr: T_sym) -> T_sym: + return self.new_sym(self.ringops.flatten(expr.rep)) + + class SymDifferentiator(Generic[T_sym]): """Representation of differentiation rules. diff --git a/src/protosym/simplecas/expr.py b/src/protosym/simplecas/expr.py index ce38d1b..88d7310 100644 --- a/src/protosym/simplecas/expr.py +++ b/src/protosym/simplecas/expr.py @@ -3,6 +3,8 @@ from functools import reduce from functools import wraps +from operator import add +from operator import mul from typing import Any from typing import Callable from typing import Optional @@ -17,6 +19,7 @@ from protosym.core.sym import HeadRule from protosym.core.sym import Sym from protosym.core.sym import SymDifferentiator +from protosym.core.sym import SymRingOps from protosym.core.tree import SubsFunc from protosym.core.tree import topological_sort from protosym.simplecas.exceptions import ExpressifyError @@ -425,26 +428,50 @@ def count_ops_graph(self) -> int: """ return len(topological_sort(self.rep)) - def diff(self, sym: Expr, ntimes: int = 1) -> Expr: + def flatten(self) -> Expr: + """Apply the usual simplification rules for Add, Mul and Pow.""" + return ring_ops(self) + + def diff(self, sym: Expr, ntimes: int = 1, flatten: bool = True) -> Expr: """Differentiate ``expr`` wrt ``sym`` (``ntimes`` times). >>> from protosym.simplecas import x, sin >>> sin(x).diff(x) cos(x) - Currently no simplification is done which can lead to some strange - looking output: + Large expressions can be generated and differentiated efficiently: + + >>> expr = sin(sin(sin(sin(sin(x))))).diff(x, 10) + >>> expr.count_ops_graph() + 20427 + >>> expr.count_ops_tree() + 597557 + + By default ``flatten`` is called during the calculation of derivatives + but that can be disabled by passing ``flatten=False`` which gives + unsimplified derivative expressions: >>> sin(x).diff(x, 4) + sin(x) + >>> sin(x).diff(x, 4, flatten=False) (-1*(-1*sin(x))) - Large expressions can be generated and differentiated efficiently: + Although ``flatten`` makes simple expressions look simpler it can also + make the graph structure of large expressions more complicated: - >>> expr = sin(sin(sin(sin(sin(x))))).diff(x, 10) - >>> expr.count_ops_graph() + >>> expr = sin(sin(sin(sin(sin(x))))) + >>> expr.diff(x, 10, flatten=False).count_ops_graph() 1552 - >>> expr.count_ops_tree() + >>> expr.diff(x, 10, flatten=True).count_ops_graph() + 20427 + >>> expr.diff(x, 10, flatten=False).count_ops_tree() 893621974 + >>> expr.diff(x, 10, flatten=True).count_ops_tree() + 597557 + + Note that the smallest operation count here arises when not flattening + and when using the graph representation rather than the tree + representation. Differentiation rules for new functions can be added as needed: @@ -473,6 +500,8 @@ def diff(self, sym: Expr, ntimes: int = 1) -> Expr: deriv = self for _ in range(ntimes): deriv = diff(deriv, sym) + if flatten: + deriv = deriv.flatten() return deriv def bin_expand(self) -> Expr: @@ -582,6 +611,12 @@ def call(self, args: list[Tree]) -> Tree: count_ops_tree[AtomRule[a]] = one_func(a) count_ops_tree[HeadRule(a, b)] = sum_plus_one(a, b) +# +# Basic ring operations for simplifying Add, Mul, Pow. +# + +ring_ops = SymRingOps(Expr, Integer, iadd=add, imul=mul, add=Add, mul=Mul, pow=Pow) + # # Differentiation. # diff --git a/tests/test_simplecas.py b/tests/test_simplecas.py index 102f771..5f7368f 100644 --- a/tests/test_simplecas.py +++ b/tests/test_simplecas.py @@ -30,6 +30,7 @@ two = Integer(2) +z = Symbol("z") def test_simplecas_types() -> None: @@ -276,12 +277,57 @@ def test_simplecas_differentation() -> None: assert x.diff(x) == one assert sin(1).diff(x) == zero assert (2 * sin(x)).diff(x) == 2 * cos(x) - assert (x**3).diff(x) == 3 * x ** (Add(3, -1)) + assert (x**3).diff(x) == 3 * x**2 assert sin(x).diff(x) == cos(x) assert cos(x).diff(x) == -sin(x) assert (sin(x) + cos(x)).diff(x) == cos(x) + -1 * sin(x) - assert (sin(x) ** 2).diff(x) == 2 * sin(x) ** Add(2, -1) * cos(x) - assert (x * sin(x)).diff(x) == 1 * sin(x) + x * cos(x) + assert (sin(x) ** 2).diff(x) == Mul(2, sin(x), cos(x)) + assert (x * sin(x)).diff(x) == sin(x) + x * cos(x) + assert sin(x).diff(x, 4) == sin(x) + assert sin(x).diff(x, 4, flatten=False) == -1 * (-1 * sin(x)) + + +def test_simplecas_flatten() -> None: + """Test flattening simplecas expressions.""" + assert x**2 * x**3 != x**5 + assert (x**2 * x**3).flatten() == x**5 + assert x + x + x != 3 * x + assert (x + x + x).flatten() == 3 * x + assert x + y + z != Add(x, y, z) + assert (x + y + z).flatten() == Add(x, y, z) + assert x - x != zero + assert (x - x).flatten() == zero + assert x**0 != one + assert (x**0).flatten() == one + assert x**1 != x + assert (x**1).flatten() == x + assert 1 * x != x + assert (1 * x).flatten() == x + assert 0 * x != zero + assert (0 * x).flatten() == zero + assert 0 + x != x + assert (0 + x).flatten() == x + assert x + 1 + 2 != 3 + x + assert (x + 1 + 2).flatten() == 3 + x + assert x * 2 * 3 != 6 * x + assert (x * 2 * 3).flatten() == 6 * x + assert 2 * x + 3 * x != 5 * x + assert (2 * x + 3 * x).flatten() == 5 * x + assert 2 * x * y + 3 * x * y != Mul(5, x, y) + assert (2 * x * y + 3 * x * y).flatten() == Mul(5, x, y) + assert (x**2) ** 3 != x**6 + assert ((x**2) ** 3).flatten() == x**6 + # Don't flatten this: + assert ((x**2) ** y).flatten() == (x**2) ** y != x ** (2 * y) + # Do flatten this: + assert ((x**y) ** 2) != Pow(x, 2 * y) + assert ((x**y) ** 2).flatten() == x ** (2 * y) + # Don't flatten this: + assert ((x**y) * x**2).flatten() == x**y * x**2 + assert ((x**y) * x**2).flatten() != x ** (2 + y) + # This flatten is undesirable (x could be zero): + assert x * (1 / x) != one + assert (x * (1 / x)).flatten() == one def test_simplecas_differentiation_rules() -> None: