Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions chebifier/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Note: The top-level package __init__.py runs only once,
# even if multiple subpackages are imported later.

from ._custom_cache import PerSmilesPerModelLRUCache

modelwise_smiles_lru_cache = PerSmilesPerModelLRUCache(max_size=100)
208 changes: 208 additions & 0 deletions chebifier/_custom_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
import os
import pickle
import threading
from collections import OrderedDict
from collections.abc import Iterable
from functools import wraps
from typing import Any, Callable


class PerSmilesPerModelLRUCache:
"""
A thread-safe, optionally persistent LRU cache for storing
(SMILES, model_name) → result mappings.
"""

def __init__(self, max_size: int = 100, persist_path: str | None = None):
"""
Initialize the cache.

Args:
max_size (int): Maximum number of items to keep in the cache.
persist_path (str | None): Optional path to persist cache using pickle.
"""
self._cache: OrderedDict[tuple[str, str], Any] = OrderedDict()
self._max_size = max_size
self._lock = threading.Lock()
self._persist_path = persist_path

self.hits = 0
self.misses = 0

if self._persist_path:
self._load_cache()

def get(self, smiles: str, model_name: str) -> Any | None:
"""
Retrieve value from cache if present, otherwise return None.

Args:
smiles (str): SMILES string key.
model_name (str): Model identifier.

Returns:
Any | None: Cached value or None.
"""
key = (smiles, model_name)
with self._lock:
if key in self._cache:
self._cache.move_to_end(key)
self.hits += 1
return self._cache[key]
else:
self.misses += 1
return None

def set(self, smiles: str, model_name: str, value: Any) -> None:
"""
Store value in cache under (smiles, model_name) key.

Args:
smiles (str): SMILES string key.
model_name (str): Model identifier.
value (Any): Value to cache.
"""
assert value is not None, "Value must not be None"
key = (smiles, model_name)
with self._lock:
if key in self._cache:
self._cache.move_to_end(key)
self._cache[key] = value
if len(self._cache) > self._max_size:
self._cache.popitem(last=False)

def clear(self) -> None:
"""
Clear the cache and remove the persistence file if present.
"""
self._save_cache()
with self._lock:
self._cache.clear()
self.hits = 0
self.misses = 0
if self._persist_path and os.path.exists(self._persist_path):
os.remove(self._persist_path)

def stats(self) -> dict[str, int]:
"""
Return cache hit/miss statistics.

Returns:
dict[str, int]: Dictionary with 'hits' and 'misses' keys.
"""
return {"hits": self.hits, "misses": self.misses}

def batch_decorator(self, func: Callable) -> Callable:
"""
Decorator for class methods that accept a batch of SMILES as a list,
and cache predictions per (smiles, model_name) key.

The instance is expected to have a `model_name` attribute.

Args:
func (Callable): The method to decorate.

Returns:
Callable: The wrapped method.
"""

@wraps(func)
def wrapper(instance, smiles_list: list[str]) -> list[Any]:
assert isinstance(smiles_list, list), "smiles_list must be a list."
model_name = getattr(instance, "model_name", None)
assert model_name is not None, "Instance must have a model_name attribute."

missing_smiles: list[str] = []
missing_indices: list[int] = []
ordered_results: list[Any] = [None] * len(smiles_list)

# First: try to fetch all from cache
for idx, smiles in enumerate(smiles_list):
prediction = self.get(smiles=smiles, model_name=model_name)
if prediction is not None:
# For debugging purposes, you can uncomment the print statement below
# print(
# f"[Cache Hit] Prediction for smiles: {smiles} and model: {model_name} are retrieved from cache."
# )
ordered_results[idx] = prediction
else:
missing_smiles.append(smiles)
missing_indices.append(idx)

# If some are missing, call original function
if missing_smiles:
new_results = func(instance, tuple(missing_smiles))
assert isinstance(
new_results, Iterable
), "Function must return an Iterable."

# Save to cache and append
for smiles, prediction, missing_idx in zip(
missing_smiles, new_results, missing_indices
):
if prediction is not None:
self.set(smiles, model_name, prediction)
ordered_results[missing_idx] = prediction

return ordered_results

return wrapper

def __len__(self) -> int:
"""
Return number of items in the cache.

Returns:
int: Number of entries in the cache.
"""
with self._lock:
return len(self._cache)

def __repr__(self) -> str:
"""
String representation of the underlying cache.

Returns:
str: String version of the OrderedDict.
"""
return self._cache.__repr__()

def save(self) -> None:
"""
Save the cache to disk, if persistence is enabled.
"""
self._save_cache()

def load(self) -> None:
"""
Load the cache from disk, if persistence is enabled.
"""
self._load_cache()

def _save_cache(self) -> None:
"""
Serialize the cache to disk using pickle.
"""
if self._persist_path:
try:
with open(self._persist_path, "wb") as f:
pickle.dump(self._cache, f)
except Exception as e:
print(f"[Cache Save Error] {e}")

def _load_cache(self) -> None:
"""
Load the cache from disk, if the file exists and is non-empty.
"""
if (
self._persist_path
and os.path.exists(self._persist_path)
and os.path.getsize(self._persist_path) > 0
):
try:
with open(self._persist_path, "rb") as f:
loaded = pickle.load(f)
if isinstance(loaded, OrderedDict):
self._cache = loaded
except Exception as e:
print(f"[Cache Load Error] {e}")
10 changes: 3 additions & 7 deletions chebifier/prediction_models/base_predictor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from abc import ABC

from functools import lru_cache
from chebifier import modelwise_smiles_lru_cache


class BasePredictor(ABC):
Expand All @@ -23,17 +23,13 @@ def __init__(

self._description = kwargs.get("description", None)

@modelwise_smiles_lru_cache.batch_decorator
def predict_smiles_list(self, smiles_list: list[str]) -> dict:
# list is not hashable, so we convert it to a tuple (useful for caching)
return self.predict_smiles_tuple(tuple(smiles_list))

@lru_cache(maxsize=100)
def predict_smiles_tuple(self, smiles_tuple: tuple[str]) -> dict:
raise NotImplementedError()

def predict_smiles(self, smiles: str) -> dict:
# by default, use list-based prediction
return self.predict_smiles_tuple((smiles,))[0]
return self.predict_smiles_list([smiles])[0]

@property
def info_text(self):
Expand Down
8 changes: 4 additions & 4 deletions chebifier/prediction_models/c3p_predictor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from functools import lru_cache
from typing import Optional, List
from pathlib import Path
from typing import List, Optional

from c3p import classifier as c3p_classifier

from chebifier import modelwise_smiles_lru_cache
from chebifier.prediction_models import BasePredictor


Expand All @@ -24,8 +24,8 @@ def __init__(
self.chemical_classes = chemical_classes
self.chebi_graph = kwargs.get("chebi_graph", None)

@lru_cache(maxsize=100)
def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
@modelwise_smiles_lru_cache.batch_decorator
def predict_smiles_list(self, smiles_list: list[str]) -> list:
result_list = c3p_classifier.classify(
list(smiles_list),
self.program_directory,
Expand Down
17 changes: 9 additions & 8 deletions chebifier/prediction_models/chebi_lookup.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from functools import lru_cache
import json
import os
from typing import Optional

from chebifier.prediction_models import BasePredictor
import os
import networkx as nx
from rdkit import Chem
import json

from chebifier import modelwise_smiles_lru_cache
from chebifier.prediction_models import BasePredictor
from chebifier.utils import load_chebi_graph


class ChEBILookupPredictor(BasePredictor):

def __init__(
self,
model_name: str,
Expand Down Expand Up @@ -67,7 +67,6 @@ def build_smiles_lookup(self):
)
return smiles_lookup

@lru_cache(maxsize=100)
def predict_smiles(self, smiles: str) -> Optional[dict]:
if not smiles:
return None
Expand All @@ -94,7 +93,8 @@ def predict_smiles(self, smiles: str) -> Optional[dict]:
else:
return None

def predict_smiles_tuple(self, smiles_list: list[str]) -> list:
@modelwise_smiles_lru_cache.batch_decorator
def predict_smiles_list(self, smiles_list: list[str]) -> list:
predictions = []
for smiles in smiles_list:
predictions.append(self.predict_smiles(smiles))
Expand Down Expand Up @@ -145,7 +145,8 @@ def explain_smiles(self, smiles: str) -> dict:
# Example usage
smiles_list = [
"CCO",
"C1=CC=CC=C1" "*C(=O)OC[C@H](COP(=O)([O-])OCC[N+](C)(C)C)OC(*)=O",
"C1=CC=CC=C1",
"*C(=O)OC[C@H](COP(=O)([O-])OCC[N+](C)(C)C)OC(*)=O",
] # SMILES with 251 matches in ChEBI
predictions = predictor.predict_smiles_list(smiles_list)
print(predictions)
14 changes: 6 additions & 8 deletions chebifier/prediction_models/chemlog_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
)
from chemlog.cli import CLASSIFIERS, _smiles_to_mol, strategy_call
from chemlog_extra.alg_classification.by_element_classification import (
XMolecularEntityClassifier,
OrganoXCompoundClassifier,
XMolecularEntityClassifier,
)
from functools import lru_cache

from chebifier import modelwise_smiles_lru_cache

from .base_predictor import BasePredictor

Expand Down Expand Up @@ -47,15 +48,14 @@


class ChemlogExtraPredictor(BasePredictor):

CHEMLOG_CLASSIFIER = None

def __init__(self, model_name: str, **kwargs):
super().__init__(model_name, **kwargs)
self.chebi_graph = kwargs.get("chebi_graph", None)
self.classifier = self.CHEMLOG_CLASSIFIER()

def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
def predict_smiles_list(self, smiles_list: list[str]) -> list:
Comment thread
aditya0by0 marked this conversation as resolved.
mol_list = [_smiles_to_mol(smiles) for smiles in smiles_list]
res = self.classifier.classify(mol_list)
if self.chebi_graph is not None:
Expand All @@ -72,12 +72,10 @@ def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:


class ChemlogXMolecularEntityPredictor(ChemlogExtraPredictor):

CHEMLOG_CLASSIFIER = XMolecularEntityClassifier


class ChemlogOrganoXCompoundPredictor(ChemlogExtraPredictor):

CHEMLOG_CLASSIFIER = OrganoXCompoundClassifier


Expand All @@ -97,7 +95,6 @@ def __init__(self, model_name: str, **kwargs):
# fmt: on
print(f"Initialised ChemLog model {self.model_name}")

@lru_cache(maxsize=100)
def predict_smiles(self, smiles: str) -> Optional[dict]:
mol = _smiles_to_mol(smiles)
if mol is None:
Expand All @@ -122,7 +119,8 @@ def predict_smiles(self, smiles: str) -> Optional[dict]:
for label in self.peptide_labels + pos_labels
}

def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
@modelwise_smiles_lru_cache.batch_decorator
def predict_smiles_list(self, smiles_list: list[str]) -> list:
results = []
for i, smiles in tqdm.tqdm(enumerate(smiles_list)):
results.append(self.predict_smiles(smiles))
Expand Down
Loading