Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 51 additions & 10 deletions pyomo/core/expr/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@


import inspect
import itertools
import logging
import sys
from copy import deepcopy
Expand Down Expand Up @@ -1395,32 +1396,63 @@ def identify_components(expr, component_types):


class IdentifyVariableVisitor(StreamBasedExpressionVisitor):
def __init__(self, include_fixed=False, named_expression_cache=None):
def __init__(
self, include_fixed=False, named_expression_cache=None, var_cache=None
):
"""Visitor that collects all unique variables participating in an
expression

Args:
include_fixed (bool): Whether to include fixed variables
named_expression_cache (optional, dict): Dict mapping ids of named
expressions to a tuple of the list of all variables and the
set of all variable ids contained in the named expression.
:meth:`walk_expression` returns a generator of the unique
variables found in the expression. If `var_cache` was
specified, then only the *new* variables found in `expr` are
returned (the full list of all variables is maintained in the
`var_cache` dict).

Parameters
----------
include_fixed : bool
If True, fixed variables will be reported

named_expression_cache : dict
Cache of named expressions that have been visited by this
walker. The value includes the variables within the named
expression as well as information for detecting when the
named expression has changed (for cache invalidation).

var_cache : dict[int, VarData]
Dict mapping the :func:`id()` of variables to
:class:`VarData` for all variables that have been "seen" by
this walker. If provided, this dictionary is preserved
between calls to :meth:`walk_expression` (so repeated
variables are not returned more than once).

"""
super().__init__()
self._include_fixed = include_fixed
# cache of visited named expressions. This dict maps
# {eid: (seen, exprs)}.
# - eid is the id() of the named expression
# - seen is the processed result for the named expression
# (including any nested named expressions)
# - exprs is used for automatically invalidating the cache (see below).
self._cache = named_expression_cache
# Stack of named expressions. This holds the tuple
# (eid, _seen, _exprs)
# where eid is the id() of the subexpression we are currently
# processing, and _seen and _exprs are from the parent context.
self._expr_stack = []
# The following attributes will be added by initializeWalker:
# self._seen: dict(eid: obj)
# cache of "seen" variables: dict(eid: VarData)
self._seen = var_cache
# The following attribute will be added by initializeWalker:
# self._exprs: list of (e, e.expr) for any (nested) named expressions

def initializeWalker(self, expr):
assert not self._expr_stack
self._seen = {}
if self._seen is None:
self._seen = {}
self._expr_stack.append(None)
else:
self._expr_stack.append(len(self._seen))
self._exprs = None
if not self.beforeChild(None, expr, 0)[0]:
return False, self.finalizeResult(None)
Expand Down Expand Up @@ -1452,8 +1484,17 @@ def exitNode(self, node, data):
self._merge_obj_lists(_seen, _exprs)

def finalizeResult(self, result):
seen = self._seen
initial_num_seen = self._expr_stack.pop()
assert not self._expr_stack
return self._seen.values()
if initial_num_seen is None:
self._seen = None
return seen.values()
else:
# Only return the *new* variables found on this walk. This
# relies on dict iteration being in insertion order (which,
# since python 3.7, it is)
return itertools.islice(seen.values(), initial_num_seen, len(seen))

def _merge_obj_lists(self, _seen, _exprs):
self._seen.update(_seen)
Expand Down
21 changes: 21 additions & 0 deletions pyomo/core/tests/unit/test_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
SimpleExpressionVisitor,
StreamBasedExpressionVisitor,
ExpressionReplacementVisitor,
IdentifyVariableVisitor,
evaluate_expression,
expression_to_string,
replace_expressions,
Expand Down Expand Up @@ -252,6 +253,26 @@ def test_identify_vars_linear_expression(self):
expr = quicksum([m.x, m.x], linear=True)
self.assertEqual(list(identify_variables(expr, include_fixed=False)), [m.x])

def test_identify_vars_var_cache(self):
m = ConcreteModel()
m.x = Var()
m.y = Var()
m.z = Var()

e1 = m.x + m.y
e2 = m.y + m.z

v = IdentifyVariableVisitor()
self.assertEqual(list(v.walk_expression(e1)), [m.x, m.y])
self.assertEqual(list(v.walk_expression(e2)), [m.y, m.z])

seen = {}
v = IdentifyVariableVisitor(var_cache=seen)
self.assertEqual(list(v.walk_expression(e2)), [m.y, m.z])
self.assertEqual(list(seen.values()), [m.y, m.z])
self.assertEqual(list(v.walk_expression(e1)), [m.x])
self.assertEqual(list(seen.values()), [m.y, m.z, m.x])


class TestIdentifyParams(unittest.TestCase):
def test_identify_params_numeric(self):
Expand Down
59 changes: 59 additions & 0 deletions pyomo/util/tests/test_vars_from_expressions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# ____________________________________________________________________________________
#
# Pyomo: Python Optimization Modeling Objects
# Copyright (c) 2008-2026 National Technology and Engineering Solutions of Sandia, LLC
# Under the terms of Contract DE-NA0003525 with National Technology and Engineering
# Solutions of Sandia, LLC, the U.S. Government retains certain rights in this
# software. This software is distributed under the 3-clause BSD License.
# ____________________________________________________________________________________

import pyomo.environ as pyo
from pyomo.common import unittest
from pyomo.util.vars_from_expressions import get_vars, get_vars_from_components


class TestVarsFromExpressions(unittest.TestCase):
def test_get_vars(self):
m = pyo.ConcreteModel()
m.x = pyo.Var(list(range(5)))
m.c1 = pyo.Constraint(expr=m.x[0] + m.x[1] == 0)
m.c2 = pyo.Constraint(expr=m.x[1] + m.x[2] == 0)
m.obj = pyo.Objective(expr=m.x[3] + m.x[4])

self.assertEqual(list(get_vars(m)), [m.x[0], m.x[1], m.x[2], m.x[3], m.x[4]])
# verify the default values for active and include_fixed
m.x[0].fix(0)
m.c2.deactivate()
self.assertEqual(list(get_vars(m)), [m.x[1], m.x[3], m.x[4]])

def test_get_vars_from_components(self):
m = pyo.ConcreteModel()
m.x = pyo.Var(list(range(5)))
m.c1 = pyo.Constraint(expr=m.x[0] + m.x[1] == 0)
m.c2 = pyo.Constraint(expr=m.x[1] + m.x[2] == 0)
m.obj = pyo.Objective(expr=m.x[3] + m.x[4])

self.assertEqual(
list(get_vars_from_components(m, pyo.Constraint)), [m.x[0], m.x[1], m.x[2]]
)
self.assertEqual(
list(get_vars_from_components(m, pyo.Objective)), [m.x[3], m.x[4]]
)
self.assertEqual(
list(get_vars_from_components(m, (pyo.Constraint, pyo.Objective))),
[m.x[0], m.x[1], m.x[2], m.x[3], m.x[4]],
)

# verify the default values for active and include_fixed
m.x[0].fix(0)
m.c2.deactivate()
self.assertEqual(
list(get_vars_from_components(m, pyo.Constraint)), [m.x[0], m.x[1], m.x[2]]
)
self.assertEqual(
list(get_vars_from_components(m, pyo.Objective)), [m.x[3], m.x[4]]
)
self.assertEqual(
list(get_vars_from_components(m, (pyo.Constraint, pyo.Objective))),
[m.x[0], m.x[1], m.x[2], m.x[3], m.x[4]],
)
69 changes: 57 additions & 12 deletions pyomo/util/vars_from_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
# ____________________________________________________________________________________

"""
This module contains a function to generate a list of the Vars which appear
This module contains functions to generate a list of the Vars appearing
in expressions in the active tree. Note this is not the same as
component_data_objects(Var) because it does not look for Var objects which are
``component_data_objects(Var)`` because it does not look for Var objects which are
not used in any expressions and it does not care if the Vars it finds are
actually in the subtree or not.
actually in the Block subtree or not.
"""

from pyomo.core import Block
from pyomo.core import Block, Constraint, Objective
from pyomo.core.expr.visitor import IdentifyVariableVisitor


Expand All @@ -28,7 +28,7 @@ def get_vars_from_components(
descend_into=Block,
descent_order=None,
):
"""Returns a generator of all the Var objects which are used in Constraint
"""Returns a generator of all the Var objects which are used in
expressions on the block. By default, this recurses into sub-blocks.

Args:
Expand All @@ -40,17 +40,62 @@ def get_vars_from_components(
sort: sort method for iterating through Constraint objects
descend_into: Ctypes to descend into when finding Constraints
descent_order: Traversal strategy for finding the objects of type ctype

"""
visitor = IdentifyVariableVisitor(include_fixed, {})
seen = set()
for constraint in block.component_data_objects(
var_cache = {}
visitor = IdentifyVariableVisitor(include_fixed, {}, var_cache=var_cache)
for component in block.component_data_objects(
ctype,
active=active,
sort=sort,
descend_into=descend_into,
descent_order=descent_order,
):
for var in visitor.walk_expression(constraint.expr):
if id(var) not in seen:
seen.add(id(var))
yield var
visitor.walk_expression(component.expr)
return var_cache.values()


def get_vars(
block,
include_fixed=False,
active=True,
sort=False,
descend_into=Block,
descent_order=None,
):
"""Return all vars referenced through expressions in the specified block.

This is a simple wrapper around :func:`get_vars_from_components()`
that gathers all variables referenced by :class:`Constraint` and
:class:`Objective` objects within the specified block. Note that as
it is designed to return the "variables used in the current model,"
it uses different defaults for `active` and `include_fixed`.

Parameters
----------
include_fixed : bool
If True, both fixed and free variables will be returned

active : bool | None
If True, only variables accessible through the active component
tree will be returned

sort: SortOrder | None
sort method for iterating through Constraint objects

descend_into : None | type | tuple[type]
Ctypes to descend into when finding Constraints

descent_order : None | TraversalStrategy
Traversal strategy for walking the block hierarchy

"""
return get_vars_from_components(
block,
ctype=(Constraint, Objective),
include_fixed=include_fixed,
active=active,
sort=sort,
descend_into=descend_into,
descent_order=descent_order,
)
Loading