Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13,213 changes: 6,835 additions & 6,378 deletions pixi.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ classifiers = [
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Programming Language :: Python :: Implementation :: CPython",
]
dependencies = [
Expand Down Expand Up @@ -90,7 +91,7 @@ exclude = "(\\w+/)*test_\\w+\\.py$|spaces/skops_model_card_creator|old"
ignore_missing_imports = true
no_implicit_optional = true

[tool.pixi.project]
[tool.pixi.workspace]
channels = ["conda-forge"]
platforms = ["linux-64", "osx-arm64", "osx-64", "win-64"]

Expand Down Expand Up @@ -164,7 +165,7 @@ python = "~=3.10.0"
scikit-learn = "~=1.4.0"
pandas = "~=2.1.0"
numpy = "~=1.26.0"
scipy = "~=1.11.0"
scipy = "~=1.13.0"
fairlearn = "~=0.10.0"
catboost = ">=1.0"
python = "~=3.11.0"
Expand Down
31 changes: 16 additions & 15 deletions skops/card/tests/test_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,8 @@ def test_default(self, model_card):
"Model description/Training Procedure/Model Plot"
).format()
# don't compare whole text, as it's quite long and non-deterministic
assert result.startswith("<style>#sk-")
assert "<style>" in result
# regex matches both old (#sk-) and new (.sk-) sklearn CSS selector formats
assert re.match(r"<style>[#.]sk-", result)
assert "MyRegressor()" in result

def test_no_overflow(self, model_card):
Expand Down Expand Up @@ -264,24 +264,25 @@ def test_model_diagram_str(self):
assert result == CONTENT_PLACEHOLDER

# now check that the actual model diagram is in the other section
# regex matches both old (#sk-) and new (.sk-) sklearn CSS selector formats
result = model_card.select(other_section_name).format()
assert result.startswith("<style>#sk-")
assert "<style>" in result
assert re.match(r"<style>[#.]sk-", result)
assert "MyRegressor()" in result

def test_other_section(self, model_card):
model_card.add_model_plot(section="Other section")
result = model_card.select("Other section").content
assert result.startswith("<style>#sk-")
assert "<style>" in result
# regex matches both old (#sk-) and new (.sk-) sklearn CSS selector formats
assert re.match(r"<style>[#.]sk-", result)
assert "MyRegressor()" in result

def test_with_description(self, model_card):
model_card.add_model_plot(description="Awesome diagram below")
result = model_card.select(
"Model description/Training Procedure/Model Plot"
).format()
assert result.startswith("Awesome diagram below\n\n<style>#sk-")
# regex matches both old (#sk-) and new (.sk-) sklearn CSS selector formats
assert re.match(r"Awesome diagram below\n\n<style>[#.]sk-", result)

@pytest.mark.parametrize("template", CUSTOM_TEMPLATES)
def test_custom_template_no_section_uses_default(self, template):
Expand All @@ -293,8 +294,8 @@ def test_custom_template_no_section_uses_default(self, template):
).format()

# don't compare whole text, as it's quite long and non-deterministic
assert result.startswith("<style>#sk-")
assert "<style>" in result
# regex matches both old (#sk-) and new (.sk-) sklearn CSS selector formats
assert re.match(r"<style>[#.]sk-", result)
assert "MyRegressor()" in result

@pytest.mark.parametrize("template", CUSTOM_TEMPLATES)
Expand All @@ -304,8 +305,8 @@ def test_custom_template_init_str_works(self, template):
model_card = Card(model, template=template, model_diagram=section_name)

result = model_card.select(section_name).format()
assert result.startswith("<style>#sk-")
assert "<style>" in result
# regex matches both old (#sk-) and new (.sk-) sklearn CSS selector formats
assert re.match(r"<style>[#.]sk-", result)
assert "MyRegressor()" in result

def test_default_template_and_model_diagram_true(self, model_card):
Expand All @@ -317,8 +318,8 @@ def test_default_template_and_model_diagram_true(self, model_card):
"Model description/Training Procedure/Model Plot"
).format()
# don't compare whole text, as it's quite long and non-deterministic
assert result.startswith("<style>#sk-")
assert "<style>" in result
# regex matches both old (#sk-) and new (.sk-) sklearn CSS selector formats
assert re.match(r"<style>[#.]sk-", result)
assert "MyRegressor()" in result

@pytest.mark.parametrize("template", CUSTOM_TEMPLATES)
Expand All @@ -333,8 +334,8 @@ def test_custom_template_and_model_diagram_true_uses_default(
"Model description/Training Procedure/Model Plot"
).format()
# don't compare whole text, as it's quite long and non-deterministic
assert result.startswith("<style>#sk-")
assert "<style>" in result
# regex matches both old (#sk-) and new (.sk-) sklearn CSS selector formats
assert re.match(r"<style>[#.]sk-", result)
assert "MyRegressor()" in result

def test_add_twice(self, model_card):
Expand Down
4 changes: 2 additions & 2 deletions skops/io/_quantile_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ._protocol import PROTOCOL
from ._sklearn import ReduceNode, reduce_get_state
from ._utils import LoadContext, SaveContext
from ._utils import LoadContext, SaveContext, get_module

try:
from quantile_forest._quantile_forest_fast import QuantileForest
Expand Down Expand Up @@ -40,7 +40,7 @@ def __init__(
super().__init__(
state,
load_context,
constructor=QuantileForest,
constructor=(get_module(QuantileForest), "QuantileForest"),
trusted=trusted,
)
self.trusted = self._get_trusted(trusted, [])
Expand Down
19 changes: 12 additions & 7 deletions skops/io/_sklearn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Optional, Sequence, Type
from typing import Any, Optional, Sequence

from sklearn.cluster import Birch
from sklearn.tree._tree import Tree
Expand Down Expand Up @@ -153,19 +153,19 @@ def __init__(
self,
state: dict[str, Any],
load_context: LoadContext,
constructor: Type[Any],
constructor: tuple[str, str],
trusted: Optional[Sequence[str]] = None,
) -> None:
super().__init__(state, load_context, trusted)
reduce = state["__reduce__"]
ctor_module, ctor_class = constructor
self.children = {
"attrs": get_tree(state["content"], load_context, trusted=trusted),
"args": get_tree(reduce["args"], load_context, trusted=trusted),
"constructor": TypeNode(
{
"__class__": constructor.__name__,
"__module__": get_module(constructor),
"__id__": id(constructor),
"__class__": ctor_class,
"__module__": ctor_module,
},
load_context,
trusted=trusted,
Expand Down Expand Up @@ -216,7 +216,12 @@ def __init__(
trusted: Optional[Sequence[str]] = None,
) -> None:
self.trusted = self._get_trusted(trusted, [get_module(Tree) + ".Tree"])
super().__init__(state, load_context, constructor=Tree, trusted=self.trusted)
super().__init__(
state,
load_context,
constructor=(get_module(Tree), "Tree"),
trusted=self.trusted,
)


def loss_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
Expand Down Expand Up @@ -255,7 +260,7 @@ def __init__(
super().__init__(
state,
load_context,
constructor=gettype(state["__module__"], state["__class__"]),
constructor=(state["__module__"], state["__class__"]),
trusted=self.trusted,
)

Expand Down
32 changes: 32 additions & 0 deletions skops/io/tests/test_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,3 +1169,35 @@ def test_custom_reduce():

loaded_obj = loads(dumps(obj), trusted=[CustomReduce])
assert obj.value == loaded_obj.value


def test_loss_node_does_not_import_before_audit(monkeypatch):
    """Regression test for https://github.com/skops-dev/skops/pull/506

    Dumps a ``CyAbsoluteError`` instance, rewrites the ``__module__`` and
    ``__class__`` entries of the archive's ``schema.json`` to point at an
    untrusted type, and checks that ``loads`` rejects the archive with
    ``UntrustedTypesFoundException`` without ever calling
    ``skops.io._sklearn.gettype``.
    """
    try:
        from sklearn._loss._loss import CyAbsoluteError
    except ImportError:
        pytest.skip("sklearn version does not have CyAbsoluteError")

    original_archive = dumps(CyAbsoluteError())
    tampered_buffer = io.BytesIO()

    # Copy the archive member by member, swapping in a schema whose top-level
    # type has been replaced by an untrusted one.
    with ZipFile(io.BytesIO(original_archive), "r") as src, ZipFile(
        tampered_buffer, "w"
    ) as dst:
        schema = json.loads(src.read("schema.json"))
        schema["__module__"] = "malicious_mod"
        schema["__class__"] = "Payload"

        for info in src.infolist():
            if info.filename == "schema.json":
                dst.writestr("schema.json", json.dumps(schema))
            else:
                dst.writestr(info, src.read(info.filename))

    tampered_archive = tampered_buffer.getvalue()

    def fail_gettype(*args, **kwargs):
        raise AssertionError("gettype() should not be called before audit")

    # Any call to gettype while loading now fails the test loudly instead of
    # silently resolving the tampered type.
    monkeypatch.setattr("skops.io._sklearn.gettype", fail_gettype)

    # ``match`` is a regular expression: escape the dot so that only the
    # literal qualified name "malicious_mod.Payload" satisfies the check.
    with pytest.raises(
        UntrustedTypesFoundException, match=r"malicious_mod\.Payload"
    ):
        loads(tampered_archive)