Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions chebai_graph/models/dynamic_gni.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ def __init__(self, config: dict[str, Any], **kwargs: Any):
print("Using complete randomness: ", self.complete_randomness)

if not self.complete_randomness:
assert (
"pad_node_features" in config or "pad_edge_features" in config
), "Missing 'pad_node_features' or 'pad_edge_features' in config when complete_randomness is False"
assert "pad_node_features" in config or "pad_edge_features" in config, (
"Missing 'pad_node_features' or 'pad_edge_features' in config when complete_randomness is False"
)
self.pad_node_features = (
int(config["pad_node_features"])
if config.get("pad_node_features") is not None
Expand All @@ -112,9 +112,9 @@ def __init__(self, config: dict[str, Any], **kwargs: Any):
f"in each forward pass."
)

assert (
self.pad_node_features > 0 or self.pad_edge_features > 0
), "'pad_node_features' or 'pad_edge_features' must be positive integers"
assert self.pad_node_features > 0 or self.pad_edge_features > 0, (
"'pad_node_features' or 'pad_edge_features' must be positive integers"
)

self.resgated: BasicGNN = ResGatedModel(
in_channels=self.in_channels,
Expand Down Expand Up @@ -182,9 +182,9 @@ def forward(self, batch: dict[str, Any]) -> Tensor:
)
new_edge_attr = torch.cat((graph_data.edge_attr, pad_edge), dim=1)

assert (
new_x is not None and new_edge_attr is not None
), "Feature initialization failed"
assert new_x is not None and new_edge_attr is not None, (
"Feature initialization failed"
)
out = self.resgated(
x=new_x.float(),
edge_index=graph_data.edge_index.long(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
1
5
6
8
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ SINGLE
AROMATIC
TRIPLE
DOUBLE
UNSPECIFIED
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,5 @@
7
10
12
11
9
124 changes: 87 additions & 37 deletions chebai_graph/preprocessing/datasets/chebi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Optional

import pandas as pd
from chebai_graph.preprocessing.reader.augmented_reader import _AugmentorReader
import torch
import tqdm
from chebai.preprocessing.datasets.chebi import (
Expand All @@ -15,6 +16,7 @@
)
from lightning_utilities.core.rank_zero import rank_zero_info
from torch_geometric.data.data import Data as GeomData
from rdkit import Chem

from chebai_graph.preprocessing.properties import (
AllNodeTypeProperty,
Expand Down Expand Up @@ -126,31 +128,52 @@ def enc_if_not_none(encode, value):
else None
)

for property in self.properties:
if not os.path.isfile(self.get_property_path(property)):
rank_zero_info(f"Processing property {property.name}")
# read all property values first, then encode
rank_zero_info(f"\tReading property values of {property.name}...")
property_values = [
self.reader.read_property(feat, property)
for feat in tqdm.tqdm(features)
]
rank_zero_info(f"\tEncoding property values of {property.name}...")
property.encoder.on_start(property_values=property_values)
encoded_values = [
enc_if_not_none(property.encoder.encode, value)
for value in tqdm.tqdm(property_values)
if any(
not os.path.isfile(self.get_property_path(property))
for property in self.properties
):
# augment molecule graph if possible (this would also happen for the properties if needed, but this avoids redundancy)
if isinstance(self.reader, _AugmentorReader):
returned_results = []
for mol in features:
try:
r = self.reader._create_augmented_graph(mol)
except Exception as e:
r = None
returned_results.append(r)
mols = [
augmented_mol[1]
for augmented_mol in returned_results
if augmented_mol is not None
]

torch.save(
[
{property.name: torch.cat(feat), "ident": id}
for feat, id in zip(encoded_values, idents)
if feat is not None
],
self.get_property_path(property),
)
property.on_finish()
else:
mols = features

for property in self.properties:
if not os.path.isfile(self.get_property_path(property)):
rank_zero_info(f"Processing property {property.name}")
# read all property values first, then encode
rank_zero_info(f"\tReading property values of {property.name}...")
property_values = [
self.reader.read_property(mol, property)
for mol in tqdm.tqdm(mols)
]
rank_zero_info(f"\tEncoding property values of {property.name}...")
property.encoder.on_start(property_values=property_values)
encoded_values = [
enc_if_not_none(property.encoder.encode, value)
for value in tqdm.tqdm(property_values)
]

torch.save(
[
{property.name: torch.cat(feat), "ident": id}
for feat, id in zip(encoded_values, idents)
if feat is not None
],
self.get_property_path(property),
)
property.on_finish()

@property
def processed_properties_dir(self) -> str:
Expand Down Expand Up @@ -185,20 +208,23 @@ def _after_setup(self, **kwargs) -> None:
super()._after_setup(**kwargs)

def _preprocess_smiles_for_pred(
self, idx, smiles: str, model_hparams: Optional[dict] = None
) -> dict:
self, idx, raw_data: str | Chem.Mol, model_hparams: Optional[dict] = None
) -> Optional[dict]:
"""Preprocess prediction data."""
# Add dummy labels because the collate function requires them.
# Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`,
# which later causes `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty.
result = self.reader.to_data(
{"id": f"smiles_{idx}", "features": smiles, "labels": [1, 2]}
{"id": f"smiles_{idx}", "features": raw_data, "labels": [1, 2]}
)
# _read_data can return an updated version of the input data (e.g. augmented molecule dict) along with the GeomData object
if isinstance(result["features"], tuple):
result["features"], raw_data = result["features"]
if result is None or result["features"] is None:
return None
for property in self.properties:
property.encoder.eval = True
property_value = self.reader.read_property(smiles, property)
property_value = self.reader.read_property(raw_data, property)
if property_value is None or len(property_value) == 0:
encoded_value = None
else:
Expand Down Expand Up @@ -250,7 +276,9 @@ def __init__(
assert (
distribution is not None
and distribution in RandomFeatureInitializationReader.DISTRIBUTIONS
), "When using padding for features, a valid distribution must be specified."
), (
"When using padding for features, a valid distribution must be specified."
)
self.distribution = distribution
if self.pad_node_features:
print(
Expand Down Expand Up @@ -278,7 +306,12 @@ def _merge_props_into_base(self, row: pd.Series | dict) -> GeomData:
Returns:
A GeomData object with merged features.
"""
geom_data = row["features"]
if isinstance(row["features"], tuple):
geom_data, _ = row[
"features"
] # ignore additional returned data from _read_data (e.g. augmented molecule dict)
else:
geom_data = row["features"]
assert isinstance(geom_data, GeomData)
edge_attr = geom_data.edge_attr
x = geom_data.x
Expand Down Expand Up @@ -538,6 +571,10 @@ def _merge_props_into_base(
geom_data = row["features"]
if geom_data is None:
return None
if isinstance(geom_data, tuple):
geom_data = geom_data[
0
] # ignore additional returned data from _read_data (e.g. augmented molecule dict)
assert isinstance(geom_data, GeomData)

is_atom_node = geom_data.is_atom_node
Expand All @@ -550,9 +587,9 @@ def _merge_props_into_base(
edge_attr = geom_data.edge_attr

# Initialize node feature matrix
assert (
max_len_node_properties is not None
), "Maximum len of node properties should not be None"
assert max_len_node_properties is not None, (
"Maximum len of node properties should not be None"
)
x = torch.zeros((num_nodes, max_len_node_properties))

# Track column offsets for each node type
Expand All @@ -573,7 +610,14 @@ def _merge_props_into_base(
enc_len = property_values.shape[1]
# -------------- Node properties ---------------
if isinstance(property, AllNodeTypeProperty):
x[:, atom_offset : atom_offset + enc_len] = property_values
try:
x[:, atom_offset : atom_offset + enc_len] = property_values
except Exception as e:
raise ValueError(
f"Error assigning property '{property.name}' values to node features: {e}\n"
f"Property values shape: {property_values.shape}, expected (num_nodes, {enc_len})\n"
f"Node feature matrix shape: {x.shape}"
)
atom_offset += enc_len
fg_offset += enc_len
graph_offset += enc_len
Expand Down Expand Up @@ -607,9 +651,9 @@ def _merge_props_into_base(
raise TypeError(f"Unsupported property type: {type(property).__name__}")

total_used_columns = max(atom_offset, fg_offset, graph_offset)
assert (
total_used_columns <= max_len_node_properties
), f"Used {total_used_columns} columns, but max allowed is {max_len_node_properties}"
assert total_used_columns <= max_len_node_properties, (
f"Used {total_used_columns} columns, but max allowed is {max_len_node_properties}"
)

return GeomData(
x=x,
Expand Down Expand Up @@ -805,3 +849,9 @@ class ChEBI50_WFGE_WGN_AsPerNodeType(GraphPropAsPerNodeType, ChEBIOver50):

class ChEBI100_WFGE_WGN_AsPerNodeType(GraphPropAsPerNodeType, ChEBIOver100):
    """ChEBI dataset variant combining per-node-type property merging with the
    ChEBIOver100 base.

    Configuration-only subclass: it overrides ``READER`` and inherits all
    behaviour from its bases.
    """

    # Reader that, judging by its name, builds graphs with functional-group
    # edges and an extra graph-level node — confirm in the reader module.
    READER = AtomFGReader_WithFGEdges_WithGraphNode


class ChEBI25_WFGE_WGN_AsPerNodeType(GraphPropAsPerNodeType, ChEBIOverX):
    """ChEBI dataset variant with an explicit threshold of 25.

    NOTE(review): unlike the sibling classes, which subclass ready-made
    ChEBIOver50/ChEBIOver100 bases, this one subclasses the generic
    ``ChEBIOverX`` and sets ``THRESHOLD`` itself — presumably because no
    ``ChEBIOver25`` convenience base exists; confirm against the chebai
    package.
    """

    # Same reader as the 50/100 variants, keeping the three classes parallel.
    READER = AtomFGReader_WithFGEdges_WithGraphNode

    # Threshold value consumed by ChEBIOverX — presumably the minimum class
    # frequency for label selection; TODO confirm in chebai's ChEBIOverX.
    THRESHOLD = 25
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from rdkit.Chem import AllChem
from rdkit.Chem import MolToSmiles as m2s

from chebi_utils.sdf_extractor import _sanitize_molecule

from .fg_constants import ELEMENTS, FLAG_NO_FG


Expand Down Expand Up @@ -1911,7 +1913,11 @@ def get_structure(mol):
structure[frag] = {"atom": atom_idx, "is_ring_fg": False}

# Convert fragment SMILES back to mol to match with fused ring atom indices
frag_mol = Chem.MolFromSmiles(frag)
frag_mol = Chem.MolFromSmiles(frag, sanitize=False)
try:
frag_mol = _sanitize_molecule(frag_mol)
except:
pass
frag_rings = frag_mol.GetRingInfo().AtomRings()
if len(frag_rings) >= 1:
structure[frag]["is_ring_fg"] = True
Expand Down
12 changes: 6 additions & 6 deletions chebai_graph/preprocessing/properties/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,9 @@ def get_property_value(self, augmented_mol: dict) -> list:
)
prop_list.append(self.get_atom_value(graph_node))

assert (
len(prop_list) == augmented_mol[self.MAIN_KEY]["num_nodes"]
), "Number of property values should be equal to number of nodes"
assert len(prop_list) == augmented_mol[self.MAIN_KEY]["num_nodes"], (
"Number of property values should be equal to number of nodes"
)
return prop_list

def _check_modify_atom_prop_value(
Expand Down Expand Up @@ -390,9 +390,9 @@ def get_property_value(self, augmented_mol: dict) -> list:
)

num_directed_edges = augmented_mol[self.MAIN_KEY][k.NUM_EDGES] // 2
assert (
len(prop_list) == num_directed_edges
), f"Number of property values ({len(prop_list)}) should be equal to number of half the number of undirected edges i.e. must be equal to {num_directed_edges} "
assert len(prop_list) == num_directed_edges, (
f"Number of property values ({len(prop_list)}) should be equal to number of half the number of undirected edges i.e. must be equal to {num_directed_edges} "
)

return prop_list

Expand Down
9 changes: 6 additions & 3 deletions chebai_graph/preprocessing/property_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ def encode(self, token: str | None) -> torch.Tensor:
return torch.tensor([self._unk_token_idx])

if str(token) not in self.cache:
# Ensure cache is a mutable dict (jsonargparse may convert it to mappingproxy)
if not isinstance(self.cache, dict):
self.cache = dict(self.cache)
self.cache[str(token)] = len(self.cache)
return torch.tensor([self.cache[str(token)] + self.offset])

Expand Down Expand Up @@ -258,9 +261,9 @@ def encode(self, token: float | int | None) -> torch.Tensor:
"""
if token is None:
return torch.zeros(1, self.get_encoding_length())
assert (
len(token) == self.get_encoding_length()
), "Length of token should be equal to encoding length"
assert len(token) == self.get_encoding_length(), (
"Length of token should be equal to encoding length"
)
# return torch.tensor([token]) # token is an ndarray, no need to create list of ndarray due to below warning
# UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow.
# Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor.
Expand Down
Loading
Loading