Skip to content

Commit 84dddb6

Browse files
committed
chore: rename tree_just_flatten to tree_leaves
1 parent b1371b2 commit 84dddb6

28 files changed

Lines changed: 551 additions & 98 deletions

src/estimagic/bootstrap.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from optimagic.parameters.tree_registry import (
1616
leaf_names,
1717
tree_flatten,
18-
tree_just_flatten,
18+
tree_leaves,
1919
tree_unflatten,
2020
)
2121
from optimagic.typing import VALUE_NAMESPACE
@@ -108,8 +108,7 @@ def bootstrap(
108108
# ==================================================================================
109109

110110
flat_outcomes = [
111-
tree_just_flatten(_outcome, namespace=VALUE_NAMESPACE)
112-
for _outcome in all_outcomes
111+
tree_leaves(_outcome, namespace=VALUE_NAMESPACE) for _outcome in all_outcomes
113112
]
114113
internal_outcomes = np.array(flat_outcomes)
115114

src/estimagic/estimate_msm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
from optimagic.parameters.space_conversion import InternalParams
5353
from optimagic.parameters.tree_registry import (
5454
leaf_names,
55-
tree_just_flatten,
55+
tree_leaves,
5656
)
5757
from optimagic.shared.check_option_dicts import (
5858
check_optimization_options,
@@ -321,7 +321,7 @@ def func(x):
321321
sim_mom = simulate_moments(params, **simulate_moments_kwargs)
322322
if isinstance(sim_mom, dict) and "simulated_moments" in sim_mom:
323323
sim_mom = sim_mom["simulated_moments"]
324-
out = np.array(tree_just_flatten(sim_mom, namespace=VALUE_NAMESPACE))
324+
out = np.array(tree_leaves(sim_mom, namespace=VALUE_NAMESPACE))
325325
return out
326326

327327
int_jac = first_derivative(
@@ -420,7 +420,7 @@ def get_msm_optimization_functions(
420420

421421
chol_weights = np.linalg.cholesky(flat_weights)
422422

423-
flat_emp_mom = tree_just_flatten(empirical_moments, namespace=VALUE_NAMESPACE)
423+
flat_emp_mom = tree_leaves(empirical_moments, namespace=VALUE_NAMESPACE)
424424

425425
_simulate_moments = _partial_kwargs(simulate_moments, simulate_moments_kwargs)
426426
_jacobian = _partial_kwargs(jacobian, jacobian_kwargs)
@@ -455,7 +455,7 @@ def _msm_criterion(
455455
if isinstance(simulated, np.ndarray) and simulated.ndim == 1:
456456
simulated_flat = simulated
457457
else:
458-
simulated_flat = np.array(tree_just_flatten(simulated, namespace=namespace))
458+
simulated_flat = np.array(tree_leaves(simulated, namespace=namespace))
459459

460460
deviations = simulated_flat - flat_empirical_moments
461461
residuals = deviations @ chol_weights

src/estimagic/msm_weighting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from estimagic.bootstrap import bootstrap
88
from optimagic.parameters.block_trees import block_tree_to_matrix, matrix_to_block_tree
9-
from optimagic.parameters.tree_registry import tree_just_flatten
9+
from optimagic.parameters.tree_registry import tree_leaves
1010
from optimagic.typing import VALUE_NAMESPACE
1111
from optimagic.utilities import robust_inverse
1212

@@ -55,7 +55,7 @@ def get_moments_cov(
5555
def func(data, **kwargs):
5656
raw = calculate_moments(data, **kwargs)
5757
out = pd.Series(
58-
tree_just_flatten(raw, namespace=VALUE_NAMESPACE)
58+
tree_leaves(raw, namespace=VALUE_NAMESPACE)
5959
) # xxxx won't be necessary soon!
6060
return out
6161

src/estimagic/shared_covs.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from optimagic.parameters.block_trees import matrix_to_block_tree
88
from optimagic.parameters.tree_registry import (
9-
tree_just_flatten,
9+
tree_leaves,
1010
tree_unflatten,
1111
)
1212
from optimagic.typing import VALUE_NAMESPACE
@@ -150,7 +150,7 @@ def calculate_estimation_summary(
150150
# ==================================================================================
151151

152152
flat_data = {
153-
key: tree_just_flatten(val, namespace=VALUE_NAMESPACE)
153+
key: tree_leaves(val, namespace=VALUE_NAMESPACE)
154154
for key, val in summary_data.items()
155155
}
156156

@@ -171,8 +171,8 @@ def calculate_estimation_summary(
171171
# create tree with values corresponding to indices of df
172172
indices = tree_unflatten(summary_data["value"], names, namespace=VALUE_NAMESPACE)
173173

174-
estimates_flat = tree_just_flatten(summary_data["value"])
175-
indices_flat = tree_just_flatten(indices)
174+
estimates_flat = tree_leaves(summary_data["value"])
175+
indices_flat = tree_leaves(indices)
176176

177177
# use index chunks in indices_flat to access the corresponding sub data frame of df,
178178
# and use the index information stored in estimates_flat to form the correct (multi)
@@ -318,7 +318,7 @@ def calculate_free_estimates(estimates, internal_estimates):
318318
mask = internal_estimates.free_mask
319319
names = internal_estimates.names
320320

321-
external_flat = np.array(tree_just_flatten(estimates, namespace=VALUE_NAMESPACE))
321+
external_flat = np.array(tree_leaves(estimates, namespace=VALUE_NAMESPACE))
322322

323323
free_estimates = FreeParams(
324324
values=external_flat[mask],

src/optimagic/benchmarking/run_benchmark.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from optimagic import batch_evaluators
1414
from optimagic.algorithms import AVAILABLE_ALGORITHMS
1515
from optimagic.optimization.optimize import minimize
16-
from optimagic.parameters.tree_registry import tree_just_flatten
16+
from optimagic.parameters.tree_registry import tree_leaves
1717
from optimagic.typing import VALUE_NAMESPACE
1818

1919

@@ -190,15 +190,15 @@ def _process_one_result(optimize_result, problem):
190190

191191
# This will happen if the optimization raised an error
192192
if isinstance(optimize_result, str):
193-
params_history_flat = [tree_just_flatten(_start_x, namespace=VALUE_NAMESPACE)]
193+
params_history_flat = [tree_leaves(_start_x, namespace=VALUE_NAMESPACE)]
194194
criterion_history = [_start_crit_value]
195195
time_history = [np.inf]
196196
batches_history = [0]
197197
else:
198198
history = optimize_result.history
199199
params_history = history.params
200200
params_history_flat = [
201-
tree_just_flatten(p, namespace=VALUE_NAMESPACE) for p in params_history
201+
tree_leaves(p, namespace=VALUE_NAMESPACE) for p in params_history
202202
]
203203
if _is_noisy:
204204
criterion_history = np.array([_criterion(p) for p in params_history])

src/optimagic/differentiation/derivatives.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,9 @@
2222
from optimagic.parameters.bounds import Bounds, get_internal_bounds, pre_process_bounds
2323
from optimagic.parameters.tree_registry import (
2424
tree_flatten,
25-
tree_just_flatten,
25+
tree_leaves,
2626
tree_unflatten,
2727
)
28-
from optimagic.parameters.tree_registry import tree_just_flatten as tree_leaves
2928
from optimagic.typing import VALUE_NAMESPACE, BatchEvaluatorLiteral, PyTree
3029

3130

@@ -226,18 +225,14 @@ def first_derivative(
226225

227226
if scaling_factor is not None and not np.isscalar(scaling_factor):
228227
scaling_factor = np.array(
229-
tree_just_flatten(scaling_factor, namespace=VALUE_NAMESPACE)
228+
tree_leaves(scaling_factor, namespace=VALUE_NAMESPACE)
230229
)
231230

232231
if min_steps is not None and not np.isscalar(min_steps):
233-
min_steps = np.array(
234-
tree_just_flatten(min_steps, namespace=VALUE_NAMESPACE)
235-
)
232+
min_steps = np.array(tree_leaves(min_steps, namespace=VALUE_NAMESPACE))
236233

237234
if step_size is not None and not np.isscalar(step_size):
238-
step_size = np.array(
239-
tree_just_flatten(step_size, namespace=VALUE_NAMESPACE)
240-
)
235+
step_size = np.array(tree_leaves(step_size, namespace=VALUE_NAMESPACE))
241236
else:
242237
x = params.astype(np.float64)
243238

@@ -544,18 +539,14 @@ def second_derivative(
544539

545540
if scaling_factor is not None and not np.isscalar(scaling_factor):
546541
scaling_factor = np.array(
547-
tree_just_flatten(scaling_factor, namespace=VALUE_NAMESPACE)
542+
tree_leaves(scaling_factor, namespace=VALUE_NAMESPACE)
548543
)
549544

550545
if min_steps is not None and not np.isscalar(min_steps):
551-
min_steps = np.array(
552-
tree_just_flatten(min_steps, namespace=VALUE_NAMESPACE)
553-
)
546+
min_steps = np.array(tree_leaves(min_steps, namespace=VALUE_NAMESPACE))
554547

555548
if step_size is not None and not np.isscalar(step_size):
556-
step_size = np.array(
557-
tree_just_flatten(step_size, namespace=VALUE_NAMESPACE)
558-
)
549+
step_size = np.array(tree_leaves(step_size, namespace=VALUE_NAMESPACE))
559550
else:
560551
x = params.astype(np.float64)
561552

src/optimagic/examples/criterion_functions.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
)
1818
from optimagic.parameters.block_trees import matrix_to_block_tree
1919
from optimagic.parameters.tree_registry import (
20-
tree_just_flatten,
20+
tree_leaves,
2121
tree_unflatten,
2222
)
2323
from optimagic.typing import VALUE_NAMESPACE, PyTree
@@ -214,9 +214,7 @@ def _get_x(params: PyTree) -> NDArray[np.float64]:
214214
if isinstance(params, np.ndarray) and params.ndim == 1:
215215
x = params.astype(float)
216216
else:
217-
x = np.array(
218-
tree_just_flatten(params, namespace=VALUE_NAMESPACE), dtype=np.float64
219-
)
217+
x = np.array(tree_leaves(params, namespace=VALUE_NAMESPACE), dtype=np.float64)
220218
return x
221219

222220

src/optimagic/optimization/fun_value.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from numpy.typing import NDArray
88

99
from optimagic.exceptions import InvalidFunctionError
10-
from optimagic.parameters.tree_registry import tree_just_flatten
10+
from optimagic.parameters.tree_registry import tree_leaves
1111
from optimagic.typing import VALUE_NAMESPACE, AggregationLevel, PyTree, Scalar
1212
from optimagic.utilities import isscalar
1313

@@ -123,7 +123,7 @@ def _get_flat_value(value: PyTree) -> NDArray[np.float64]:
123123
elif isinstance(value, np.ndarray):
124124
flat = value.flatten()
125125
else:
126-
flat = tree_just_flatten(value, namespace=VALUE_NAMESPACE)
126+
flat = tree_leaves(value, namespace=VALUE_NAMESPACE)
127127

128128
flat_arr = np.asarray(flat, dtype=np.float64)
129129
return flat_arr

src/optimagic/optimization/history.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from optimagic.parameters.tree_registry import (
1111
leaf_names,
12-
tree_just_flatten,
12+
tree_leaves,
1313
)
1414
from optimagic.timing import CostModel
1515
from optimagic.typing import VALUE_NAMESPACE, Direction, EvalTask, PyTree
@@ -400,7 +400,7 @@ def _get_flat_params(params: list[PyTree]) -> list[list[float]]:
400400
if fast_path:
401401
flatten = lambda x: x.tolist()
402402
else:
403-
flatten = partial(tree_just_flatten, namespace=VALUE_NAMESPACE)
403+
flatten = partial(tree_leaves, namespace=VALUE_NAMESPACE)
404404

405405
return [flatten(p) for p in params]
406406

src/optimagic/parameters/block_trees.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55

66
from optimagic.parameters.tree_registry import (
77
tree_flatten,
8+
tree_leaves,
89
tree_unflatten,
910
)
10-
from optimagic.parameters.tree_registry import tree_just_flatten as tree_leaves
1111
from optimagic.typing import VALUE_NAMESPACE
1212

1313

0 commit comments

Comments
 (0)