Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
b0a7dde
chore: add optree package
abelaba Mar 31, 2026
d4d9d79
chore: add methods for working with optree
abelaba Mar 31, 2026
3bf5cd7
chore: replace pybaum tree methods with optree methods
abelaba Mar 31, 2026
5c3b563
fix: update implementation for setting data_col attribute for dataframes
abelaba Mar 31, 2026
51cc30f
chore: replace leaf_names with optree method
abelaba Apr 1, 2026
2edc2f2
chore: remove repeated OrderedDict check
abelaba Apr 1, 2026
1f07a02
chore: move namespace variable to typing.py
abelaba Apr 1, 2026
def8aa3
chore: remove unused method
abelaba Apr 1, 2026
109623f
chore: use optree context manager for ordering dict
abelaba Apr 1, 2026
e3eb382
chore: replace tree_equal method with optree impl
abelaba Apr 1, 2026
5aa2b08
chore: remove get_registry method and use namespace arugment
abelaba Apr 2, 2026
9d8c744
chore: use namespaces for passing data_col value for dataframes
abelaba Apr 2, 2026
94516da
chore: move namespaces list to typing.py
abelaba Apr 2, 2026
9e7996d
chore: register jax arrays
abelaba Apr 2, 2026
b932c30
chore: remove type hints
abelaba Apr 2, 2026
8276c6d
chore: rearrange method order
abelaba Apr 2, 2026
6f03b12
chore: remove pybaum dependency
abelaba Apr 2, 2026
d3edeed
chore: remove remaining pybaum string
abelaba Apr 2, 2026
ca240b5
chore: remove duplicate imports
abelaba Apr 2, 2026
1415d54
chore: add tests for tree methods and rearrange tree-registry methods
abelaba Apr 7, 2026
0e76b38
fix: move optree to pypi dependencies
abelaba Apr 7, 2026
c85fe08
docs: add documentation that tree_unflatten and tree_map respect inse…
abelaba Apr 7, 2026
81fd1ec
chore: raise warning for unregistered namespaces
abelaba Apr 8, 2026
8257c09
fix: update test for empty namespace
abelaba Apr 8, 2026
b1371b2
fix: add default namespace for dict insertion ordering
abelaba Apr 9, 2026
651fa5e
chore: rename tree_just_flatten to tree_leaves
abelaba Apr 9, 2026
bd29352
fix: use jax array type for registering a jax array
abelaba Apr 9, 2026
4f42b27
chore: change jaxlib registration class
abelaba Apr 9, 2026
b9e32d7
chore: add type hints
abelaba Apr 10, 2026
ad5cf22
chore: add tree_unflatten method to tests
abelaba Apr 10, 2026
35e7c36
Merge branch 'main' into migrate-pybaum-to-optree
abelaba Apr 10, 2026
19b4db3
test: update warning check for unregistered namespace
abelaba Apr 10, 2026
38db493
fix: use jax installation check from config file
abelaba Apr 10, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 7 additions & 45 deletions pixi.lock

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ dependencies = [
"numpy>=1.26",
"pandas>=2.1",
"plotly>=5.14",
"pybaum>=0.1.2",
"scipy>=1.11",
"sqlalchemy>=2.0",
"annotated-types>=0.4",
"typing-extensions>=4.5",
"optree>=0.19",
]
dynamic = ["version"]
keywords = [
Expand Down Expand Up @@ -349,7 +349,6 @@ ignore_errors = true

[[tool.mypy.overrides]]
module = [
"pybaum",
"scipy",
"scipy.linalg",
"scipy.linalg.lapack",
Expand Down
37 changes: 19 additions & 18 deletions src/estimagic/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,20 @@

import numpy as np
import pandas as pd
from pybaum import leaf_names, tree_flatten, tree_just_flatten, tree_unflatten

from estimagic.bootstrap_ci import calculate_ci
from estimagic.bootstrap_helpers import check_inputs
from estimagic.bootstrap_outcomes import get_bootstrap_outcomes
from estimagic.shared_covs import calculate_estimation_summary
from optimagic.batch_evaluators import joblib_batch_evaluator
from optimagic.parameters.block_trees import matrix_to_block_tree
from optimagic.parameters.tree_registry import get_registry
from optimagic.parameters.tree_registry import (
leaf_names,
tree_flatten,
tree_leaves,
tree_unflatten,
)
from optimagic.typing import VALUE_NAMESPACE
from optimagic.utilities import get_rng


Expand Down Expand Up @@ -102,9 +107,8 @@ def bootstrap(
# Process results
# ==================================================================================

registry = get_registry(extended=True)
flat_outcomes = [
tree_just_flatten(_outcome, registry=registry) for _outcome in all_outcomes
tree_leaves(_outcome, namespace=VALUE_NAMESPACE) for _outcome in all_outcomes
]
internal_outcomes = np.array(flat_outcomes)

Expand Down Expand Up @@ -162,11 +166,10 @@ def outcomes(self):
List[Any]: The boostrap outcomes as a list of pytrees.

"""
registry = get_registry(extended=True)
_, treedef = tree_flatten(self._base_outcome, registry=registry)
_, treedef = tree_flatten(self._base_outcome, namespace=VALUE_NAMESPACE)

outcomes = [
tree_unflatten(treedef, out, registry=registry)
tree_unflatten(treedef, out, namespace=VALUE_NAMESPACE)
for out in self._internal_outcomes
]
return outcomes
Expand All @@ -182,10 +185,9 @@ def se(self):
cov = self._internal_cov
se = np.sqrt(np.diagonal(cov))

registry = get_registry(extended=True)
_, treedef = tree_flatten(self._base_outcome, registry=registry)
_, treedef = tree_flatten(self._base_outcome, namespace=VALUE_NAMESPACE)

se = tree_unflatten(treedef, se, registry=registry)
se = tree_unflatten(treedef, se, namespace=VALUE_NAMESPACE)
return se

def cov(self, return_type="pytree"):
Expand All @@ -206,8 +208,7 @@ def cov(self, return_type="pytree"):
cov = self._internal_cov

if return_type == "dataframe":
registry = get_registry(extended=True)
names = np.array(leaf_names(self._base_outcome, registry=registry))
names = np.array(leaf_names(self._base_outcome, namespace=VALUE_NAMESPACE))
cov = pd.DataFrame(cov, columns=names, index=names)
elif return_type == "pytree":
cov = matrix_to_block_tree(cov, self._base_outcome, self._base_outcome)
Expand All @@ -234,15 +235,16 @@ def ci(self, ci_method="percentile", ci_level=0.95):
bounds of confidence intervals.

"""
registry = get_registry(extended=True)
base_outcome_flat, treedef = tree_flatten(self._base_outcome, registry=registry)
base_outcome_flat, treedef = tree_flatten(
self._base_outcome, namespace=VALUE_NAMESPACE
)

lower_flat, upper_flat = calculate_ci(
base_outcome_flat, self._internal_outcomes, ci_method, ci_level
)

lower = tree_unflatten(treedef, lower_flat, registry=registry)
upper = tree_unflatten(treedef, upper_flat, registry=registry)
lower = tree_unflatten(treedef, lower_flat, namespace=VALUE_NAMESPACE)
upper = tree_unflatten(treedef, upper_flat, namespace=VALUE_NAMESPACE)
return lower, upper

def p_values(self):
Expand Down Expand Up @@ -271,8 +273,7 @@ def summary(self, ci_method="percentile", ci_level=0.95):
Soon this will be a pytree.

"""
registry = get_registry(extended=True)
names = leaf_names(self.base_outcome, registry=registry)
names = leaf_names(self.base_outcome, namespace=VALUE_NAMESPACE)
summary_data = _calulcate_summary_data_bootstrap(
self, ci_method=ci_method, ci_level=ci_level
)
Expand Down
22 changes: 11 additions & 11 deletions src/estimagic/estimate_msm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import numpy as np
import pandas as pd
from pybaum import leaf_names, tree_just_flatten

from estimagic.msm_covs import cov_optimal, cov_robust
from estimagic.msm_sensitivity import (
Expand Down Expand Up @@ -51,10 +50,14 @@
from optimagic.parameters.bounds import Bounds, pre_process_bounds
from optimagic.parameters.conversion import Converter, get_converter
from optimagic.parameters.space_conversion import InternalParams
from optimagic.parameters.tree_registry import get_registry
from optimagic.parameters.tree_registry import (
leaf_names,
tree_leaves,
)
from optimagic.shared.check_option_dicts import (
check_optimization_options,
)
from optimagic.typing import VALUE_NAMESPACE
from optimagic.utilities import get_rng, to_pickle


Expand Down Expand Up @@ -318,8 +321,7 @@ def func(x):
sim_mom = simulate_moments(params, **simulate_moments_kwargs)
if isinstance(sim_mom, dict) and "simulated_moments" in sim_mom:
sim_mom = sim_mom["simulated_moments"]
registry = get_registry(extended=True)
out = np.array(tree_just_flatten(sim_mom, registry=registry))
out = np.array(tree_leaves(sim_mom, namespace=VALUE_NAMESPACE))
return out

int_jac = first_derivative(
Expand Down Expand Up @@ -418,8 +420,7 @@ def get_msm_optimization_functions(

chol_weights = np.linalg.cholesky(flat_weights)

registry = get_registry(extended=True)
flat_emp_mom = tree_just_flatten(empirical_moments, registry=registry)
flat_emp_mom = tree_leaves(empirical_moments, namespace=VALUE_NAMESPACE)

_simulate_moments = _partial_kwargs(simulate_moments, simulate_moments_kwargs)
_jacobian = _partial_kwargs(jacobian, jacobian_kwargs)
Expand All @@ -430,7 +431,7 @@ def get_msm_optimization_functions(
simulate_moments=_simulate_moments,
flat_empirical_moments=flat_emp_mom,
chol_weights=chol_weights,
registry=registry,
namespace=VALUE_NAMESPACE,
)
)

Expand All @@ -445,7 +446,7 @@ def get_msm_optimization_functions(


def _msm_criterion(
params, simulate_moments, flat_empirical_moments, chol_weights, registry
params, simulate_moments, flat_empirical_moments, chol_weights, namespace
):
"""Calculate msm criterion given parameters and building blocks."""
simulated = simulate_moments(params)
Expand All @@ -454,7 +455,7 @@ def _msm_criterion(
if isinstance(simulated, np.ndarray) and simulated.ndim == 1:
simulated_flat = simulated
else:
simulated_flat = np.array(tree_just_flatten(simulated, registry=registry))
simulated_flat = np.array(tree_leaves(simulated, namespace=namespace))

deviations = simulated_flat - flat_empirical_moments
residuals = deviations @ chol_weights
Expand Down Expand Up @@ -975,9 +976,8 @@ def sensitivity(
inner_tree=self._empirical_moments,
)
elif return_type == "dataframe":
registry = get_registry(extended=True)
row_names = self._internal_estimates.names
col_names = leaf_names(self._empirical_moments, registry=registry)
col_names = leaf_names(self._empirical_moments, namespace=VALUE_NAMESPACE)
out = pd.DataFrame(
data=raw,
index=row_names,
Expand Down
8 changes: 3 additions & 5 deletions src/estimagic/msm_weighting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

import numpy as np
import pandas as pd
from pybaum import tree_just_flatten
from scipy.linalg import block_diag

from estimagic.bootstrap import bootstrap
from optimagic.parameters.block_trees import block_tree_to_matrix, matrix_to_block_tree
from optimagic.parameters.tree_registry import get_registry
from optimagic.parameters.tree_registry import tree_leaves
from optimagic.typing import VALUE_NAMESPACE
from optimagic.utilities import robust_inverse


Expand Down Expand Up @@ -51,13 +51,11 @@ def get_moments_cov(

first_eval = calculate_moments(data, **moment_kwargs)

registry = get_registry(extended=True)

@functools.wraps(calculate_moments)
def func(data, **kwargs):
raw = calculate_moments(data, **kwargs)
out = pd.Series(
tree_just_flatten(raw, registry=registry)
tree_leaves(raw, namespace=VALUE_NAMESPACE)
) # xxxx won't be necessary soon!
return out

Expand Down
22 changes: 11 additions & 11 deletions src/estimagic/shared_covs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
import numpy as np
import pandas as pd
import scipy
from pybaum import tree_just_flatten, tree_unflatten

from optimagic.parameters.block_trees import matrix_to_block_tree
from optimagic.parameters.tree_registry import get_registry
from optimagic.parameters.tree_registry import (
tree_leaves,
tree_unflatten,
)
from optimagic.typing import VALUE_NAMESPACE


def transform_covariance(
Expand Down Expand Up @@ -146,9 +149,8 @@ def calculate_estimation_summary(
# Flatten summary and construct data frame for flat estimates
# ==================================================================================

registry = get_registry(extended=True)
flat_data = {
key: tree_just_flatten(val, registry=registry)
key: tree_leaves(val, namespace=VALUE_NAMESPACE)
for key, val in summary_data.items()
}

Expand All @@ -167,10 +169,10 @@ def calculate_estimation_summary(
# ==================================================================================

# create tree with values corresponding to indices of df
indices = tree_unflatten(summary_data["value"], names, registry=registry)
indices = tree_unflatten(summary_data["value"], names, namespace=VALUE_NAMESPACE)

estimates_flat = tree_just_flatten(summary_data["value"])
indices_flat = tree_just_flatten(indices)
estimates_flat = tree_leaves(summary_data["value"])
indices_flat = tree_leaves(indices)

# use index chunks in indices_flat to access the corresponding sub data frame of df,
# and use the index information stored in estimates_flat to form the correct (multi)
Expand Down Expand Up @@ -316,8 +318,7 @@ def calculate_free_estimates(estimates, internal_estimates):
mask = internal_estimates.free_mask
names = internal_estimates.names

registry = get_registry(extended=True)
external_flat = np.array(tree_just_flatten(estimates, registry=registry))
external_flat = np.array(tree_leaves(estimates, namespace=VALUE_NAMESPACE))

free_estimates = FreeParams(
values=external_flat[mask],
Expand Down Expand Up @@ -351,8 +352,7 @@ def transform_free_values_to_params_tree(values, free_params, params):
mask = free_params.free_mask
flat = np.full(len(mask), np.nan)
flat[np.ix_(mask)] = values
registry = get_registry(extended=True)
pytree = tree_unflatten(params, flat, registry=registry)
pytree = tree_unflatten(params, flat, namespace=VALUE_NAMESPACE)
return pytree


Expand Down
9 changes: 4 additions & 5 deletions src/optimagic/benchmarking/run_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
"""

import numpy as np
from pybaum import tree_just_flatten

from optimagic import batch_evaluators
from optimagic.algorithms import AVAILABLE_ALGORITHMS
from optimagic.optimization.optimize import minimize
from optimagic.parameters.tree_registry import get_registry
from optimagic.parameters.tree_registry import tree_leaves
from optimagic.typing import VALUE_NAMESPACE


def run_benchmark(
Expand Down Expand Up @@ -180,7 +180,6 @@ def _process_one_result(optimize_result, problem):
dict: Processed result.

"""
_registry = get_registry(extended=True)
_criterion = problem["noise_free_fun"]
_start_x = problem["inputs"]["params"]
_start_crit_value = _criterion(_start_x)
Expand All @@ -191,15 +190,15 @@ def _process_one_result(optimize_result, problem):

# This will happen if the optimization raised an error
if isinstance(optimize_result, str):
params_history_flat = [tree_just_flatten(_start_x, registry=_registry)]
params_history_flat = [tree_leaves(_start_x, namespace=VALUE_NAMESPACE)]
criterion_history = [_start_crit_value]
time_history = [np.inf]
batches_history = [0]
else:
history = optimize_result.history
params_history = history.params
params_history_flat = [
tree_just_flatten(p, registry=_registry) for p in params_history
tree_leaves(p, namespace=VALUE_NAMESPACE) for p in params_history
]
if _is_noisy:
criterion_history = np.array([_criterion(p) for p in params_history])
Expand Down
Loading
Loading