"""
Mathematical Expression Generator
This module provides functionality for generating, manipulating, and evaluating mathematical expressions, including:
1. Expression tree nodes (Variable, Constant, Add, Subtract, Multiply, Divide, Power, Comparison)
2. Formula configuration and serialization/deserialization
3. Random expression generation with configurable complexity
4. Z3 integration for constraint solving
5. LaTeX output for mathematical notation
"""
import random
import z3
import math
import re
[docs]
class ExprNode:
    """Abstract base class for expression-tree nodes.

    Every operation below is a stub that raises ``NotImplementedError``;
    concrete node classes are expected to override all four.
    """

    def to_z3(self, vars):
        """Translate the node for Z3 (not implemented on the base class)."""
        raise NotImplementedError

    def evaluate(self, values):
        """Compute the node's value (not implemented on the base class)."""
        raise NotImplementedError

    def to_latex(self, var_names):
        """Render the node as LaTeX (not implemented on the base class)."""
        raise NotImplementedError

    def serialize(self):
        """Render the node as a string (not implemented on the base class)."""
        raise NotImplementedError
[docs]
class Variable(ExprNode):
    """Leaf node that refers to a variable by its position in a sequence."""

    def __init__(self, index):
        # Position used to index whatever sequence each method receives.
        self.index = index

    def to_z3(self, vars):
        """Return ``(vars[index], [])``; a variable adds no side constraints."""
        selected = vars[self.index]
        return selected, []

    def evaluate(self, values):
        """Return the entry of ``values`` at this variable's index."""
        selected = values[self.index]
        return selected

    def to_latex(self, var_names):
        """Return the entry of ``var_names`` at this variable's index."""
        selected = var_names[self.index]
        return selected

    def serialize(self):
        """Return the textual form ``(Var <index>)``."""
        return "(Var {})".format(self.index)
[docs]
class Constant(ExprNode):
    """Leaf node holding a fixed value."""

    def __init__(self, value):
        # Stored untouched; every method hands it back or formats it.
        self.value = value

    def to_z3(self, _):
        """Return ``(value, [])``; the argument is ignored."""
        stored = self.value
        return stored, []

    def evaluate(self, _):
        """Return the stored value; the argument is ignored."""
        stored = self.value
        return stored

    def to_latex(self, _):
        """Return the stored value formatted as a string."""
        return format(self.value)

    def serialize(self):
        """Return the textual form ``(Const <value>)``."""
        return "(Const {})".format(self.value)
[docs]
class Add(ExprNode):
    """Binary node for the sum of two sub-expressions."""

    def __init__(self, left, right):
        self.left, self.right = left, right

    def to_z3(self, vars):
        """Return ``(left + right, constraints)`` built from both operands.

        The constraint lists produced by the operands are concatenated,
        left operand first.
        """
        lhs, lhs_constraints = self.left.to_z3(vars)
        rhs, rhs_constraints = self.right.to_z3(vars)
        combined = lhs_constraints + rhs_constraints
        return lhs + rhs, combined

    def evaluate(self, values):
        """Return the sum of the two evaluated operands."""
        lhs = self.left.evaluate(values)
        rhs = self.right.evaluate(values)
        return lhs + rhs

    def to_latex(self, var_names):
        """Return the parenthesised LaTeX form ``(left + right)``."""
        lhs = self.left.to_latex(var_names)
        rhs = self.right.to_latex(var_names)
        return "({} + {})".format(lhs, rhs)

    def serialize(self):
        """Return the textual form ``(Add <left> <right>)``."""
        lhs = self.left.serialize()
        rhs = self.right.serialize()
        return "(Add {} {})".format(lhs, rhs)
[docs]
class Subtract(ExprNode):
    """Binary node for the difference of two sub-expressions."""

    def __init__(self, left, right):
        self.left, self.right = left, right

    def to_z3(self, vars):
        """Return ``(left - right, constraints)`` built from both operands.

        The constraint lists produced by the operands are concatenated,
        left operand first.
        """
        lhs, lhs_constraints = self.left.to_z3(vars)
        rhs, rhs_constraints = self.right.to_z3(vars)
        combined = lhs_constraints + rhs_constraints
        return lhs - rhs, combined

    def evaluate(self, values):
        """Return the left evaluated operand minus the right one."""
        lhs = self.left.evaluate(values)
        rhs = self.right.evaluate(values)
        return lhs - rhs

    def to_latex(self, var_names):
        """Return the parenthesised LaTeX form ``(left - right)``."""
        lhs = self.left.to_latex(var_names)
        rhs = self.right.to_latex(var_names)
        return "({} - {})".format(lhs, rhs)

    def serialize(self):
        """Return the textual form ``(Sub <left> <right>)``."""
        lhs = self.left.serialize()
        rhs = self.right.serialize()
        return "(Sub {} {})".format(lhs, rhs)
[docs]
class Multiply(ExprNode):
    """Binary node for the product of two sub-expressions."""

    def __init__(self, left, right):
        self.left, self.right = left, right

    def to_z3(self, vars):
        """Return ``(left * right, constraints)`` built from both operands.

        The constraint lists produced by the operands are concatenated,
        left operand first.
        """
        lhs, lhs_constraints = self.left.to_z3(vars)
        rhs, rhs_constraints = self.right.to_z3(vars)
        combined = lhs_constraints + rhs_constraints
        return lhs * rhs, combined

    def evaluate(self, values):
        """Return the product of the two evaluated operands."""
        lhs = self.left.evaluate(values)
        rhs = self.right.evaluate(values)
        return lhs * rhs

    def to_latex(self, var_names):
        """Return both operands joined by a LaTeX centred dot.

        Unlike addition and subtraction, no parentheses are added here.
        """
        lhs = self.left.to_latex(var_names)
        rhs = self.right.to_latex(var_names)
        return "{} \\cdot {}".format(lhs, rhs)

    def serialize(self):
        """Return the textual form ``(Mul <left> <right>)``."""
        lhs = self.left.serialize()
        rhs = self.right.serialize()
        return "(Mul {} {})".format(lhs, rhs)
[docs]
class Divide(ExprNode):
    """Binary node for the quotient of a numerator and a denominator."""

    def __init__(self, num, den):
        self.num, self.den = num, den

    def to_z3(self, vars):
        """Return ``(num / den, constraints)`` with both operands as reals.

        Plain Python values are wrapped with ``z3.RealVal`` and integer-sorted
        Z3 expressions are converted with ``z3.ToReal``. A ``den != 0``
        constraint is appended after the operands' own constraints.
        """
        raw_num, num_constraints = self.num.to_z3(vars)
        raw_den, den_constraints = self.den.to_z3(vars)

        coerced = []
        for operand in (raw_num, raw_den):
            # Wrap non-Z3 values first, then lift integer terms to reals.
            if not isinstance(operand, z3.ExprRef):
                operand = z3.RealVal(operand)
            if z3.is_int(operand):
                operand = z3.ToReal(operand)
            coerced.append(operand)
        numerator, denominator = coerced

        quotient = numerator / denominator
        constraints = num_constraints + den_constraints + [denominator != 0]
        return quotient, constraints

    def evaluate(self, values):
        """Return the numerator divided by the denominator (true division).

        Raises:
            ValueError: If the denominator evaluates to zero.
        """
        # The denominator is evaluated first so a zero is rejected before the
        # numerator is computed.
        divisor = self.den.evaluate(values)
        if divisor == 0:
            raise ValueError("Division by zero")
        dividend = self.num.evaluate(values)
        return dividend / divisor

    def to_latex(self, var_names):
        """Return the LaTeX fraction built from both operands."""
        top = self.num.to_latex(var_names)
        bottom = self.den.to_latex(var_names)
        return f"\\frac{{{top}}}{{{bottom}}}"

    def serialize(self):
        """Return the textual form ``(Div <num> <den>)``."""
        top = self.num.serialize()
        bottom = self.den.serialize()
        return "(Div {} {})".format(top, bottom)
[docs]
class Power(ExprNode):
    """Node that raises a base node to a fixed exponent."""

    def __init__(self, var, exponent):
        # Intended to be a Variable and a positive integer respectively;
        # only the exponent is checked, and only inside to_z3.
        self.var = var
        self.exponent = exponent

    def to_z3(self, vars):
        """Return ``(base ** exponent, constraints)`` for Z3.

        A base that is not a Z3 expression is wrapped with ``z3.RealVal``.

        Raises:
            ValueError: If the exponent is not an ``int`` that is at least 1.
        """
        base, constraints = self.var.to_z3(vars)
        if not isinstance(base, z3.ExprRef):
            base = z3.RealVal(base)

        # Checked after the base has been translated.
        exponent_ok = isinstance(self.exponent, int) and self.exponent >= 1
        if not exponent_ok:
            raise ValueError(f"Exponent must be a positive integer, got {self.exponent}")

        return base ** self.exponent, constraints

    def evaluate(self, values):
        """Return the evaluated base raised to the exponent."""
        base = self.var.evaluate(values)
        return base ** self.exponent

    def to_latex(self, var_names):
        """Return the LaTeX form ``base^{exponent}``."""
        base = self.var.to_latex(var_names)
        return "{}^{{{}}}".format(base, self.exponent)

    def serialize(self):
        """Return the textual form ``(Power <base> <exponent>)``."""
        base = self.var.serialize()
        return "(Power {} {})".format(base, self.exponent)
[docs]
class Comparison(ExprNode):
    """Node relating two sub-expressions with a comparison operator."""

    # Operator token -> (callable applying the comparison, LaTeX symbol).
    OPS = {
        '<=': (lambda a, b: a <= b, r'\leq'),
        '>=': (lambda a, b: a >= b, r'\geq'),
        '==': (lambda a, b: a == b, '='),
        '<': (lambda a, b: a < b, '<'),
        '>': (lambda a, b: a > b, '>'),
    }

    def __init__(self, left, op, right):
        self.left, self.op, self.right = left, op, right

    def to_z3(self, vars):
        """Return ``(left <op> right, constraints)`` built from both operands.

        The operator is looked up in ``Comparison.OPS`` after both operands
        have been translated; an unknown token raises ``KeyError``.
        """
        lhs, lhs_constraints = self.left.to_z3(vars)
        rhs, rhs_constraints = self.right.to_z3(vars)
        compare = Comparison.OPS[self.op][0]
        relation = compare(lhs, rhs)
        return relation, lhs_constraints + rhs_constraints

    def evaluate(self, values):
        """Return the result of comparing the two evaluated operands."""
        compare = Comparison.OPS[self.op][0]
        lhs = self.left.evaluate(values)
        rhs = self.right.evaluate(values)
        return compare(lhs, rhs)

    def to_latex(self, var_names):
        """Return ``left <symbol> right`` using the operator's LaTeX symbol."""
        lhs = self.left.to_latex(var_names)
        symbol = Comparison.OPS[self.op][1]
        rhs = self.right.to_latex(var_names)
        return "{} {} {}".format(lhs, symbol, rhs)

    def serialize(self):
        """Return the textual form ``(Cmp <op> <left> <right>)``."""
        lhs = self.left.serialize()
        rhs = self.right.serialize()
        return "(Cmp {} {} {})".format(self.op, lhs, rhs)
[docs]
def to_value(s):
    """Parse a numeric string into an ``int`` when possible, else a ``float``.

    Integer parsing is attempted only when the text contains neither a
    decimal point nor the letter ``e`` (in either case). Anything else, and
    any text that fails integer parsing, is handed to ``float``.

    Raises:
        ValueError: If the text cannot be parsed as a float either.
    """
    has_float_marker = '.' in s or 'e' in s.lower()
    if not has_float_marker:
        try:
            return int(s)
        except ValueError:
            # Not an integer literal; fall through to the float attempt.
            pass

    try:
        number = float(s)
    except ValueError:
        raise ValueError(f"Input string '{s}' is not a valid numeric format")
    return number
[docs]
def generate_expr(vars_num, depth=0, max_depth=5, allow_const=True,
                  force_var=False, parent_op=None, allow_power=True, max_const=10):
    """
    Generate a random expression tree.

    Args:
        vars_num (int): Number of variables; variable indices are drawn from
            0 to vars_num - 1.
        depth (int): Current recursion depth, default 0.
        max_depth (int): Depth at which a leaf is always produced, default 5.
        allow_const (bool): Whether a leaf may be a constant, default True.
        force_var (bool): Force a variable when a leaf is produced, default
            False. Only consulted once depth >= max_depth.
        parent_op (str): Operation type of the parent node. Forwarded to the
            recursive calls; not otherwise read by this function.
        allow_power (bool): Whether power nodes may be generated, default True.
        max_const (int): Inclusive upper bound for generated constants,
            default 10. Forwarded to every recursive call so that nested
            constants respect the same bound.

    Returns:
        A Variable, Constant, Add, Subtract, Multiply or Power node.

    Notes:
        1. When depth >= max_depth a leaf (variable or constant) is returned.
        2. The two direct operands of a multiplication are never both
           constants.
        3. Power nodes always have a single variable as base and exponent
           2 or 3.
        4. The two direct operands of an addition/subtraction are never both
           constants.
    """
    if depth >= max_depth:
        # Depth budget exhausted: emit a leaf.
        if force_var or not allow_const:
            return Variable(random.randint(0, vars_num - 1))
        return random.choice([
            Variable(random.randint(0, vars_num - 1)),
            Constant(random.randint(1, max_const)),
        ])

    # Relative likelihood of each operation; a zero weight disables it.
    op_weights = {
        'Add': 4,
        'Subtract': 3,
        'Multiply': 2,
        'Power': 1 if allow_power else 0,
    }
    op_types = [op for op, weight in op_weights.items() if weight > 0]
    weights = [op_weights[op] for op in op_types]
    op_type = random.choices(op_types, weights=weights, k=1)[0]

    if op_type == 'Power':
        # 'Power' can only be drawn when allow_power is true (weight 0
        # otherwise), so no second check is needed here.
        var = Variable(random.randint(0, vars_num - 1))
        exponent = random.choice([2, 3])
        return Power(var, exponent)

    if op_type == 'Multiply':
        left = generate_expr(
            vars_num, depth + 1, max_depth,
            allow_const=True,
            force_var=False,
            parent_op='Multiply',
            allow_power=allow_power,
            max_const=max_const,
        )
        if isinstance(left, Constant):
            # Constant on the left: the right side must not be a constant.
            right = generate_expr(
                vars_num, depth + 1, max_depth,
                allow_const=False,
                force_var=True,
                parent_op='Multiply',
                allow_power=allow_power,
                max_const=max_const,
            )
        else:
            # Non-constant on the left: scale it by a constant factor. The
            # factor normally starts at 2; the lower bound drops to 1 only
            # when max_const itself is below 2, so the range is not empty
            # for max_const == 1.
            lowest_factor = 2 if max_const >= 2 else 1
            right = Constant(random.randint(lowest_factor, max_const))
        return Multiply(left, right)

    # Remaining operations: Add / Subtract.
    left = generate_expr(
        vars_num, depth + 1, max_depth,
        allow_const=allow_const,
        force_var=False,
        parent_op=op_type,
        allow_power=allow_power,
        max_const=max_const,
    )
    right = generate_expr(
        vars_num, depth + 1, max_depth,
        allow_const=allow_const,
        force_var=False,
        parent_op=op_type,
        allow_power=allow_power,
        max_const=max_const,
    )
    # Prevent both sides from being constants by replacing one at random.
    if isinstance(left, Constant) and isinstance(right, Constant):
        if random.choice([True, False]):
            left = Variable(random.randint(0, vars_num - 1))
        else:
            right = Variable(random.randint(0, vars_num - 1))
    return Add(left, right) if op_type == 'Add' else Subtract(left, right)
[docs]
def build_system(configs, vars, var_names=None):
    """
    Build an equation system: Z3 expressions plus a LaTeX string.

    Args:
        configs (list[str]): Serialized formula configurations. Each one is
            passed to ``FormulaConfig.deserialize`` together with the number
            of variables.
        vars (list): Z3 variable list.
        var_names (list[str], optional): Names used for LaTeX output. When
            omitted (or empty), each variable's ``decl().name()`` is used.

    Returns:
        dict: Dictionary containing the following keys:
            - 'z3_expr': List with one Z3 expression per formula
            - 'z3_constraint': List of all collected constraints
            - 'latex_str': LaTeX representation of the system

    Notes:
        1. With exactly one formula, 'latex_str' is that formula's LaTeX.
        2. With several formulas, they are joined inside a LaTeX ``cases``
           environment.
        3. With no formulas, both lists are empty and 'latex_str' is ''.
        4. When the root node of a formula is a Power, the plain Python
           boolean ``exponent > 0`` is appended to the constraints. Nested
           power nodes do not add this entry.
    """
    vars = list(vars)
    parsed = [
        FormulaConfig.deserialize(config, vars_num=len(vars))
        for config in configs
    ]
    str_vars = var_names if var_names else [v.decl().name() for v in vars]

    # Collect the expression and side constraints of every formula.
    all_exprs = []
    all_constraints = []
    for config in parsed:
        expr, constraints = config.root.to_z3(vars)
        all_exprs.append(expr)
        all_constraints.extend(constraints)
        if isinstance(config.root, Power):
            all_constraints.append(config.root.exponent > 0)

    formulas = [config.get_latex(str_vars) for config in parsed]
    if len(formulas) > 1:
        latex_str = r'\begin{cases} ' + ' \\\\ '.join(formulas) + r' \end{cases}'
    elif formulas:
        latex_str = formulas[0]
    else:
        # No formulas at all: return an empty system rather than failing on
        # an out-of-range index.
        latex_str = ''

    return {
        'z3_expr': all_exprs,
        'z3_constraint': all_constraints,
        'latex_str': latex_str,
    }
if __name__ == "__main__":
    # Demo run: generate two formulas over three variables, build the system
    # and show what ends up in a Z3 solver.
    configs = generate_formulas(formula_num=2, vars_num=3, is_cond=True)
    print("configs:", configs)

    vars = [z3.Real(name) for name in ('x', 'y', 'z')]
    system = build_system(configs, vars)
    print("Generated System:", system['latex_str'], sep="\n")

    # Only the formula expressions are added to the solver here; the entries
    # of system['z3_constraint'] are not.
    solver = z3.Solver()
    solver.add(system['z3_expr'])
    print("\nZ3 Constraints:", solver)