diff --git a/pyomo/core/expr/visitor.py b/pyomo/core/expr/visitor.py index 210c0ec32e7..7503b2fc67b 100644 --- a/pyomo/core/expr/visitor.py +++ b/pyomo/core/expr/visitor.py @@ -9,6 +9,7 @@ import inspect +import itertools import logging import sys from copy import deepcopy @@ -16,6 +17,7 @@ logger = logging.getLogger('pyomo.core') +from pyomo.common.collections import ComponentSet from pyomo.common.deprecation import deprecated, deprecation_warning from pyomo.common.errors import DeveloperError, TemplateExpressionError from pyomo.common.numeric_types import ( @@ -1395,32 +1397,75 @@ def identify_components(expr, component_types): class IdentifyVariableVisitor(StreamBasedExpressionVisitor): - def __init__(self, include_fixed=False, named_expression_cache=None): + def __init__( + self, include_fixed=False, named_expression_cache=None, var_cache=None + ): """Visitor that collects all unique variables participating in an expression - Args: - include_fixed (bool): Whether to include fixed variables - named_expression_cache (optional, dict): Dict mapping ids of named - expressions to a tuple of the list of all variables and the - set of all variable ids contained in the named expression. + :meth:`walk_expression` returns a generator of the unique + variables found in the expression. If `var_cache` was + specified, then only the *new* variables found in `expr` are + returned (the full list of all variables is maintained in the + `var_cache` dict). + + Parameters + ---------- + include_fixed : bool + If True, fixed variables will be reported + + named_expression_cache : dict + Cache of named expressions that have been visited by this + walker. The value includes the variables within the named + expression as well as information for detecting when the + named expression has changed (for cache invalidation). + + var_cache : ComponentSet + ComponentSet for recording all variables that have been + "seen" by this walker. If provided, this ComponentSet is + preserved between calls to :meth:`walk_expression` (so + repeated variables are not returned more than once). """ super().__init__() self._include_fixed = include_fixed + # cache of visited named expressions. This dict maps + # {eid: (seen, exprs)}. + # - eid is the id() of the named expression + # - seen is the processed result for the named expression + # (including any nested named expressions) + # - exprs is used for automatically invalidating the cache (see below). self._cache = named_expression_cache + # Stack of named expressions. This holds the tuple # (eid, _seen, _exprs) # where eid is the id() of the subexpression we are currently # processing, and _seen and _exprs are from the parent context. self._expr_stack = [] - # The following attributes will be added by initializeWalker: - # self._seen: dict(eid: obj) + + # cache of "seen" variables: dict(eid: VarData) + # + # Pyomo encourages the use of ComponentSet to store (ordered) + # sets of Pyomo components (and in particular, Pyomo Vars). + # However, to reduce overhead (this is about a 10-12% + # improvement), we will operate directly on the underlying dict + # data store. This is slightly evil (and definitely violates + # encapsulation), but we accept the risk as identify_variables() + # is a potentially expensive operation. + if isinstance(var_cache, ComponentSet): + var_cache = var_cache._data + self._seen = var_cache + + # The following attribute will be added by initializeWalker: # self._exprs: list of (e, e.expr) for any (nested) named expressions def initializeWalker(self, expr): assert not self._expr_stack - self._seen = {} + if self._seen is None: + self._seen = {} + self._expr_stack.append(None) + else: + self._expr_stack.append(len(self._seen)) self._exprs = None if not self.beforeChild(None, expr, 0)[0]: return False, self.finalizeResult(None) @@ -1452,8 +1497,17 @@ def exitNode(self, node, data): self._merge_obj_lists(_seen, _exprs) def finalizeResult(self, result): + seen = self._seen + initial_num_seen = self._expr_stack.pop() assert not self._expr_stack - return self._seen.values() + if initial_num_seen is None: + self._seen = None + return seen.values() + else: + # Only return the *new* variables found on this walk. This + # relies on dict iteration being in insertion order (which, + # since python 3.7, it is) + return itertools.islice(seen.values(), initial_num_seen, len(seen)) def _merge_obj_lists(self, _seen, _exprs): self._seen.update(_seen) diff --git a/pyomo/core/tests/unit/test_visitor.py b/pyomo/core/tests/unit/test_visitor.py index 6268afd45ad..b2a18fa8ceb 100644 --- a/pyomo/core/tests/unit/test_visitor.py +++ b/pyomo/core/tests/unit/test_visitor.py @@ -61,6 +61,7 @@ SimpleExpressionVisitor, StreamBasedExpressionVisitor, ExpressionReplacementVisitor, + IdentifyVariableVisitor, evaluate_expression, expression_to_string, replace_expressions, @@ -252,6 +253,26 @@ def test_identify_vars_linear_expression(self): expr = quicksum([m.x, m.x], linear=True) self.assertEqual(list(identify_variables(expr, include_fixed=False)), [m.x]) + def test_identify_vars_var_cache(self): + m = ConcreteModel() + m.x = Var() + m.y = Var() + m.z = Var() + + e1 = m.x + m.y + e2 = m.y + m.z + + v = IdentifyVariableVisitor() + self.assertEqual(list(v.walk_expression(e1)), [m.x, m.y]) + self.assertEqual(list(v.walk_expression(e2)), [m.y, m.z]) + + seen = {} + v = IdentifyVariableVisitor(var_cache=seen) + self.assertEqual(list(v.walk_expression(e2)), [m.y, m.z]) + self.assertEqual(list(seen.values()), [m.y, m.z]) + self.assertEqual(list(v.walk_expression(e1)), [m.x]) + self.assertEqual(list(seen.values()), [m.y, m.z, m.x]) + class TestIdentifyParams(unittest.TestCase): def test_identify_params_numeric(self): diff --git a/pyomo/util/tests/test_vars_from_expressions.py b/pyomo/util/tests/test_vars_from_expressions.py new file mode 100644 index 00000000000..bd60fb70f66 --- /dev/null +++ b/pyomo/util/tests/test_vars_from_expressions.py @@ -0,0 +1,59 @@ +# ____________________________________________________________________________________ +# +# Pyomo: Python Optimization Modeling Objects +# Copyright (c) 2008-2026 National Technology and Engineering Solutions of Sandia, LLC +# Under the terms of Contract DE-NA0003525 with National Technology and Engineering +# Solutions of Sandia, LLC, the U.S. Government retains certain rights in this +# software. This software is distributed under the 3-clause BSD License. +# ____________________________________________________________________________________ + +import pyomo.environ as pyo +from pyomo.common import unittest +from pyomo.util.vars_from_expressions import get_vars, get_vars_from_components + + +class TestVarsFromExpressions(unittest.TestCase): + def test_get_vars(self): + m = pyo.ConcreteModel() + m.x = pyo.Var(list(range(5))) + m.c1 = pyo.Constraint(expr=m.x[0] + m.x[1] == 0) + m.c2 = pyo.Constraint(expr=m.x[1] + m.x[2] == 0) + m.obj = pyo.Objective(expr=m.x[3] + m.x[4]) + + self.assertEqual(list(get_vars(m)), [m.x[0], m.x[1], m.x[2], m.x[3], m.x[4]]) + # verify the default values for active and include_fixed + m.x[0].fix(0) + m.c2.deactivate() + self.assertEqual(list(get_vars(m)), [m.x[1], m.x[3], m.x[4]]) + + def test_get_vars_from_components(self): + m = pyo.ConcreteModel() + m.x = pyo.Var(list(range(5))) + m.c1 = pyo.Constraint(expr=m.x[0] + m.x[1] == 0) + m.c2 = pyo.Constraint(expr=m.x[1] + m.x[2] == 0) + m.obj = pyo.Objective(expr=m.x[3] + m.x[4]) + + self.assertEqual( + list(get_vars_from_components(m, pyo.Constraint)), [m.x[0], m.x[1], m.x[2]] + ) + self.assertEqual( + list(get_vars_from_components(m, pyo.Objective)), [m.x[3], m.x[4]] + ) + self.assertEqual( + list(get_vars_from_components(m, (pyo.Constraint, pyo.Objective))), + [m.x[0], m.x[1], m.x[2], m.x[3], m.x[4]], + ) + + # verify the default values for active and include_fixed + m.x[0].fix(0) + m.c2.deactivate() + self.assertEqual( + list(get_vars_from_components(m, pyo.Constraint)), [m.x[0], m.x[1], m.x[2]] + ) + self.assertEqual( + list(get_vars_from_components(m, pyo.Objective)), [m.x[3], m.x[4]] + ) + self.assertEqual( + list(get_vars_from_components(m, (pyo.Constraint, pyo.Objective))), + [m.x[0], m.x[1], m.x[2], m.x[3], m.x[4]], + ) diff --git a/pyomo/util/vars_from_expressions.py b/pyomo/util/vars_from_expressions.py index 6ad2f80b56b..e2f6125c5c9 100644 --- a/pyomo/util/vars_from_expressions.py +++ b/pyomo/util/vars_from_expressions.py @@ -8,14 +8,15 @@ # ____________________________________________________________________________________ """ -This module contains a function to generate a list of the Vars which appear +This module contains functions to generate a list of the Vars appearing in expressions in the active tree. Note this is not the same as -component_data_objects(Var) because it does not look for Var objects which are +``component_data_objects(Var)`` because it does not look for Var objects which are not used in any expressions and it does not care if the Vars it finds are -actually in the subtree or not. +actually in the Block subtree or not. """ -from pyomo.core import Block +from pyomo.common.collections import ComponentSet +from pyomo.core import Block, Constraint, Objective from pyomo.core.expr.visitor import IdentifyVariableVisitor @@ -28,29 +29,97 @@ def get_vars_from_components( descend_into=Block, descent_order=None, ): - """Returns a generator of all the Var objects which are used in Constraint + """Returns a ComponentSet of all the Var objects that appear in expressions on the block. By default, this recurses into sub-blocks. - Args: - ctype: The type of component from which to get Vars, assumed to have - an expr attribute. - include_fixed: Whether or not to include fixed variables - active: Whether to find Vars that appear in Constraints accessible - via the active tree - sort: sort method for iterating through Constraint objects - descend_into: Ctypes to descend into when finding Constraints - descent_order: Traversal strategy for finding the objects of type ctype + Parameters + ---------- + include_fixed : bool + If True, both fixed and free variables will be returned + + ctype : type | tuple[type] + The "ctype" of component from which to get Vars. The components + must expose a ``expr`` attribute that will be walked looking for + variables. + + active : bool | None + If True, only variables accessible through the active component + tree will be returned. If None, all variables accessible + through either active or inactive components will be returned. + + sort: SortComponents | bool | None + sort method for iterating through Constraint objects + + descend_into : type | tuple[type] | None + "ctypes" to descend into when finding Constraints + + descent_order : TraversalStrategy | None + Traversal strategy for walking the block hierarchy + + Returns + ------- + ComponentSet : set of variables + """ - visitor = IdentifyVariableVisitor(include_fixed, {}) - seen = set() - for constraint in block.component_data_objects( + var_cache = ComponentSet() + visitor = IdentifyVariableVisitor(include_fixed, {}, var_cache=var_cache) + for component in block.component_data_objects( ctype, active=active, sort=sort, descend_into=descend_into, descent_order=descent_order, ): - for var in visitor.walk_expression(constraint.expr): - if id(var) not in seen: - seen.add(id(var)) - yield var + visitor.walk_expression(component.expr) + return var_cache + + +def get_vars( + block, + include_fixed=False, + active=True, + sort=False, + descend_into=Block, + descent_order=None, +): + """Return all vars referenced through expressions in the specified block. + + This is a simple wrapper around :func:`get_vars_from_components()` + that gathers all variables referenced by :class:`Constraint` and + :class:`Objective` objects within the specified block. Note that as + it is designed to return the "variables used in the current model," + it uses different defaults for `active` and `include_fixed`. + + Parameters + ---------- + include_fixed : bool + If True, both fixed and free variables will be returned + + active : bool | None + If True, only variables accessible through the active component + tree will be returned. If None, all variables accessible + through either active or inactive components will be returned. + + sort: SortComponents | bool | None + sort method for iterating through Constraint objects + + descend_into : type | tuple[type] | None + "ctypes" to descend into when finding Constraints + + descent_order : TraversalStrategy | None + Traversal strategy for walking the block hierarchy + + Returns + ------- + ComponentSet : set of variables + + """ + return get_vars_from_components( + block, + ctype=(Constraint, Objective), + include_fixed=include_fixed, + active=active, + sort=sort, + descend_into=descend_into, + descent_order=descent_order, + )