Skip to content

Commit 8c688a6

Browse files
committed
Add generic type bounds for ResultSet
1 parent 2c50d4d commit 8c688a6

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

src/codemodder/result.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import itertools
44
from dataclasses import dataclass, field
55
from pathlib import Path
6-
from typing import TYPE_CHECKING, Any, ClassVar, Sequence, Type
6+
from typing import TYPE_CHECKING, Any, ClassVar, Sequence, Type, TypeVar
77

88
import libcst as cst
99
from boltons.setutils import IndexedSet
@@ -225,8 +225,11 @@ def fuzzy_column_match(pos: CodeRange, location: Location) -> bool:
225225
)
226226

227227

228-
class ResultSet(dict[str, dict[Path, list[Result]]]):
229-
results_for_rule: dict[str, list[Result]]
228+
ResultType = TypeVar("ResultType", bound=Result)
229+
230+
231+
class ResultSet(dict[str, dict[Path, list[ResultType]]]):
232+
results_for_rule: dict[str, list[ResultType]]
230233
# stores SARIF runs.tool data
231234
tools: list[dict[str, dict]]
232235

@@ -235,7 +238,7 @@ def __init__(self, *args, **kwargs):
235238
self.results_for_rule = {}
236239
self.tools = []
237240

238-
def add_result(self, result: Result):
241+
def add_result(self, result: ResultType):
239242
self.results_for_rule.setdefault(result.rule_id, []).append(result)
240243
for loc in result.locations:
241244
self.setdefault(result.rule_id, {}).setdefault(loc.file, []).append(result)
@@ -246,7 +249,7 @@ def store_tool_data(self, tool_data: dict):
246249

247250
def results_for_rule_and_file(
248251
self, context: CodemodExecutionContext, rule_id: str, file: Path
249-
) -> list[Result]:
252+
) -> list[ResultType]:
250253
"""
251254
Return list of results for a given rule and file.
252255
@@ -258,7 +261,7 @@ def results_for_rule_and_file(
258261
"""
259262
return self.get(rule_id, {}).get(file.relative_to(context.directory), [])
260263

261-
def results_for_rules(self, rule_ids: list[str]) -> list[Result]:
264+
def results_for_rules(self, rule_ids: list[str]) -> list[ResultType]:
262265
"""
263266
Returns flat list of all results that match any of the given rule IDs.
264267
"""

0 commit comments

Comments
 (0)