diff --git a/CHANGELOG.md b/CHANGELOG.md index 3f90063..4fcca57 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## 1.6.0 /2025-01-27 +* Fix typo by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/258 +* Improve Disk Caching by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/227 + +**Full Changelog**: https://github.com/opentensor/async-substrate-interface/compare/v1.5.15...v1.6.0 + ## 1.5.15 /2025-12-22 * Modifies the CachedFetcher to not keep pending exceptions by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/253 diff --git a/README.md b/README.md index 17960c4..6d5ad4d 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ pip install async-substrate-interface ## Usage -Here are examples of how to use the sync and async inferfaces: +Here are examples of how to use the sync and async interfaces: ```python from async_substrate_interface import SubstrateInterface diff --git a/async_substrate_interface/async_substrate.py b/async_substrate_interface/async_substrate.py index dae91e5..91c4ca0 100644 --- a/async_substrate_interface/async_substrate.py +++ b/async_substrate_interface/async_substrate.py @@ -1181,7 +1181,10 @@ async def __aenter__(self): await self.initialize() return self - async def initialize(self): + async def initialize(self) -> None: + await self._initialize() + + async def _initialize(self) -> None: """ Initialize the connection to the chain. """ @@ -1206,7 +1209,7 @@ async def initialize(self): self._initializing = False async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.ws.shutdown() + await self.close() @property def metadata(self): @@ -2428,7 +2431,6 @@ async def get_block_metadata( "MetadataVersioned", data=ScaleBytes(result) ) metadata_decoder.decode() - return metadata_decoder else: return result @@ -4289,20 +4291,27 @@ async def _handler(block_data: dict[str, Any]): class DiskCachedAsyncSubstrateInterface(AsyncSubstrateInterface): """ - Experimental new class that uses disk-caching in addition to memory-caching for the cached methods + Uses disk-caching in addition to memory-caching for the cached methods + + Loads the cache from the disk at startup, where it is kept in-memory, and dumps to the disk + when the connection is closed. """ + async def initialize(self) -> None: + await self.runtime_cache.load_from_disk(self.url) + await self._initialize() + async def close(self): """ - Closes the substrate connection, and the websocket connection. 
+ Closes the substrate connection and the websocket connection, dumps the runtime cache to disk """ try: + await self.runtime_cache.dump_to_disk(self.url) await self.ws.shutdown() except AttributeError: pass - db_conn = AsyncSqliteDB(self.url) - if db_conn._db is not None: - await db_conn._db.close() + db = AsyncSqliteDB(self.url) + await db.close() @async_sql_lru_cache(maxsize=SUBSTRATE_CACHE_METHOD_SIZE) async def get_parent_block_hash(self, block_hash): diff --git a/async_substrate_interface/types.py b/async_substrate_interface/types.py index b9fbfae..8878497 100644 --- a/async_substrate_interface/types.py +++ b/async_substrate_interface/types.py @@ -2,6 +2,7 @@ from abc import ABC from collections import defaultdict, deque from collections.abc import Iterable +from contextlib import suppress from dataclasses import dataclass from datetime import datetime from typing import Optional, Union, Any @@ -16,6 +17,7 @@ from .const import SS58_FORMAT from .utils import json +from .utils.cache import AsyncSqliteDB logger = logging.getLogger("async_substrate_interface") @@ -34,8 +36,8 @@ class RuntimeCache: is important you are utilizing the correct version. """ - blocks: dict[int, "Runtime"] - block_hashes: dict[str, "Runtime"] + blocks: dict[int, str] + block_hashes: dict[str, int] versions: dict[int, "Runtime"] last_used: Optional["Runtime"] @@ -56,10 +58,10 @@ def add_item( Adds a Runtime object to the cache mapped to its version, block number, and/or block hash. """ self.last_used = runtime - if block is not None: - self.blocks[block] = runtime - if block_hash is not None: - self.block_hashes[block_hash] = runtime + if block is not None and block_hash is not None: + self.blocks[block] = block_hash + if block_hash is not None and runtime_version is not None: + self.block_hashes[block_hash] = runtime_version if runtime_version is not None: self.versions[runtime_version] = runtime @@ -73,33 +75,52 @@ def retrieve( Retrieves a Runtime object from the cache, using the key of its block number, block hash, or runtime version. Retrieval happens in this order. If no Runtime is found mapped to any of your supplied keys, returns `None`. 
""" + runtime = None if block is not None: - runtime = self.blocks.get(block) - if runtime is not None: - if block_hash is not None: - # if lookup occurs for block_hash and block, but only block matches, also map to block_hash - self.add_item(runtime, block_hash=block_hash) + if block_hash is not None: + self.blocks[block] = block_hash + if runtime_version is not None: + self.block_hashes[block_hash] = runtime_version + with suppress(KeyError): + runtime = self.versions[self.block_hashes[self.blocks[block]]] self.last_used = runtime return runtime if block_hash is not None: - runtime = self.block_hashes.get(block_hash) - if runtime is not None: - if block is not None: - # if lookup occurs for block_hash and block, but only block_hash matches, also map to block - self.add_item(runtime, block=block) + if runtime_version is not None: + self.block_hashes[block_hash] = runtime_version + with suppress(KeyError): + runtime = self.versions[self.block_hashes[block_hash]] self.last_used = runtime return runtime if runtime_version is not None: - runtime = self.versions.get(runtime_version) - if runtime is not None: - # if runtime_version matches, also map to block and block_hash (if supplied) - if block is not None: - self.add_item(runtime, block=block) - if block_hash is not None: - self.add_item(runtime, block_hash=block_hash) + with suppress(KeyError): + runtime = self.versions[runtime_version] self.last_used = runtime return runtime - return None + return runtime + + async def load_from_disk(self, chain_endpoint: str): + db = AsyncSqliteDB(chain_endpoint=chain_endpoint) + ( + block_mapping, + block_hash_mapping, + runtime_version_mapping, + ) = await db.load_runtime_cache(chain_endpoint) + if not any([block_mapping, block_hash_mapping, runtime_version_mapping]): + logger.debug("No runtime mappings in disk cache") + else: + logger.debug("Found runtime mappings in disk cache") + self.blocks = block_mapping + self.block_hashes = block_hash_mapping + self.versions = { + x: Runtime.deserialize(y) for x, y in runtime_version_mapping.items() + } + + async def dump_to_disk(self, chain_endpoint: str): + db = AsyncSqliteDB(chain_endpoint=chain_endpoint) + await db.dump_runtime_cache( + chain_endpoint, self.blocks, self.block_hashes, self.versions + ) class Runtime: @@ -149,6 +170,45 @@ def __init__( if registry is not None: self.load_registry_type_map() + def serialize(self): + metadata_value = self.metadata.data.data + return { + "chain": self.chain, + "type_registry": self.type_registry, + "metadata_value": metadata_value, + "metadata_v15": self.metadata_v15.encode_to_metadata_option(), + "runtime_info": { + "specVersion": self.runtime_version, + "transactionVersion": self.transaction_version, + }, + "registry": self.registry.registry if self.registry is not None else None, + "ss58_format": self.ss58_format, + } + + @classmethod + def deserialize(cls, serialized: dict) -> "Runtime": + ss58_format = serialized["ss58_format"] + runtime_config = RuntimeConfigurationObject(ss58_format=ss58_format) + runtime_config.clear_type_registry() + runtime_config.update_type_registry(load_type_registry_preset(name="core")) + metadata = runtime_config.create_scale_object( + "MetadataVersioned", data=ScaleBytes(serialized["metadata_value"]) + ) + metadata.decode() + registry = PortableRegistry.from_json(serialized["registry"]) + return cls( + chain=serialized["chain"], + metadata=metadata, + type_registry=serialized["type_registry"], + runtime_config=runtime_config, + metadata_v15=MetadataV15.decode_from_metadata_option( 
+ serialized["metadata_v15"] + ), + registry=registry, + ss58_format=ss58_format, + runtime_info=serialized["runtime_info"], + ) + def load_runtime(self): """ Initial loading of the runtime's type registry information. diff --git a/async_substrate_interface/utils/cache.py b/async_substrate_interface/utils/cache.py index d8b494a..24c609c 100644 --- a/async_substrate_interface/utils/cache.py +++ b/async_substrate_interface/utils/cache.py @@ -20,6 +20,7 @@ if USE_CACHE else ":memory:" ) +SUBSTRATE_CACHE_METHOD_SIZE = int(os.getenv("SUBSTRATE_CACHE_METHOD_SIZE", "512")) logger = logging.getLogger("async_substrate_interface") @@ -38,13 +39,13 @@ def __new__(cls, chain_endpoint: str): cls._instances[chain_endpoint] = instance return instance - async def __call__(self, chain, other_self, func, args, kwargs) -> Optional[Any]: + async def close(self): async with self._lock: - if not self._db: - _ensure_dir() - self._db = await aiosqlite.connect(CACHE_LOCATION) - table_name = _get_table_name(func) - key = None + if self._db: + await self._db.close() + self._db = None + + async def _create_if_not_exists(self, chain: str, table_name: str): if not (local_chain := _check_if_local(chain)) or not USE_CACHE: await self._db.execute( f""" @@ -54,7 +55,8 @@ async def __call__(self, chain, other_self, func, args, kwargs) -> Optional[Any] key BLOB, value BLOB, chain TEXT, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE(key, chain) ); """ ) @@ -66,25 +68,34 @@ async def __call__(self, chain, other_self, func, args, kwargs) -> Optional[Any] WHERE rowid IN ( SELECT rowid FROM {table_name} ORDER BY created_at DESC - LIMIT -1 OFFSET 500 + LIMIT -1 OFFSET {SUBSTRATE_CACHE_METHOD_SIZE} ); END; """ ) await self._db.commit() - key = pickle.dumps((args, kwargs or None)) - try: - cursor: aiosqlite.Cursor = await self._db.execute( - f"SELECT value FROM {table_name} WHERE key=? AND chain=?", - (key, chain), - ) - result = await cursor.fetchone() - await cursor.close() - if result is not None: - return pickle.loads(result[0]) - except (pickle.PickleError, sqlite3.Error) as e: - logger.exception("Cache error", exc_info=e) - pass + return local_chain + + async def __call__(self, chain, other_self, func, args, kwargs) -> Optional[Any]: + async with self._lock: + if not self._db: + _ensure_dir() + self._db = await aiosqlite.connect(CACHE_LOCATION) + table_name = _get_table_name(func) + local_chain = await self._create_if_not_exists(chain, table_name) + key = pickle.dumps((args, kwargs or None)) + try: + cursor: aiosqlite.Cursor = await self._db.execute( + f"SELECT value FROM {table_name} WHERE key=? 
AND chain=?", + (key, chain), + ) + result = await cursor.fetchone() + await cursor.close() + if result is not None: + return pickle.loads(result[0]) + except (pickle.PickleError, sqlite3.Error) as e: + logger.exception("Cache error", exc_info=e) + pass result = await func(other_self, *args, **kwargs) if not local_chain or not USE_CACHE: # TODO use a task here @@ -95,6 +106,85 @@ async def __call__(self, chain, other_self, func, args, kwargs) -> Optional[Any] await self._db.commit() return result + async def load_runtime_cache(self, chain: str) -> tuple[dict, dict, dict]: + async with self._lock: + if not self._db: + _ensure_dir() + self._db = await aiosqlite.connect(CACHE_LOCATION) + block_mapping = {} + block_hash_mapping = {} + version_mapping = {} + tables = { + "RuntimeCache_blocks": block_mapping, + "RuntimeCache_block_hashes": block_hash_mapping, + "RuntimeCache_versions": version_mapping, + } + for table in tables.keys(): + async with self._lock: + local_chain = await self._create_if_not_exists(chain, table) + if local_chain: + return {}, {}, {} + for table_name, mapping in tables.items(): + try: + async with self._lock: + cursor: aiosqlite.Cursor = await self._db.execute( + f"SELECT key, value FROM {table_name} WHERE chain=?", + (chain,), + ) + results = await cursor.fetchall() + await cursor.close() + if results is None: + continue + for row in results: + key, value = row + runtime = pickle.loads(value) + mapping[key] = runtime + except (pickle.PickleError, sqlite3.Error) as e: + logger.exception("Cache error", exc_info=e) + return {}, {}, {} + return block_mapping, block_hash_mapping, version_mapping + + async def dump_runtime_cache( + self, + chain: str, + block_mapping: dict, + block_hash_mapping: dict, + version_mapping: dict, + ) -> None: + async with self._lock: + if not self._db: + _ensure_dir() + self._db = await aiosqlite.connect(CACHE_LOCATION) + + tables = { + "RuntimeCache_blocks": block_mapping, + "RuntimeCache_block_hashes": block_hash_mapping, + "RuntimeCache_versions": version_mapping, + } + for table, mapping in tables.items(): + local_chain = await self._create_if_not_exists(chain, table) + if local_chain: + return None + serialized_mapping = {} + for key, value in mapping.items(): + if not isinstance(value, (str, int)): + serialized_value = pickle.dumps(value.serialize()) + else: + serialized_value = pickle.dumps(value) + serialized_mapping[key] = serialized_value + + await self._db.executemany( + f"INSERT OR REPLACE INTO {table} (key, value, chain) VALUES (?,?,?)", + [ + (key, serialized_value_, chain) + for key, serialized_value_ in serialized_mapping.items() + ], + ) + + await self._db.commit() + + return None + def _ensure_dir(): path = Path(CACHE_LOCATION).parent @@ -119,7 +209,8 @@ def _create_table(c, conn, table_name): key BLOB, value BLOB, chain TEXT, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE(key, chain) ); """ ) @@ -130,7 +221,7 @@ def _create_table(c, conn, table_name): WHERE rowid IN ( SELECT rowid FROM {table_name} ORDER BY created_at DESC - LIMIT -1 OFFSET 500 + LIMIT -1 OFFSET {SUBSTRATE_CACHE_METHOD_SIZE} ); END;""" ) @@ -205,7 +296,7 @@ def inner(self, *args, **kwargs): def async_sql_lru_cache(maxsize: Optional[int] = None): def decorator(func): - @cached_fetcher(max_size=maxsize) + @cached_fetcher(max_size=maxsize, cache_key_index=None) async def inner(self, *args, **kwargs): async_sql_db = AsyncSqliteDB(self.url) result = await async_sql_db(self.url, self, func, args, kwargs) 
@@ -300,7 +391,7 @@ def make_cache_key(self, args: tuple, kwargs: dict) -> Hashable: key_name = list(bound.arguments)[self._cache_key_index] return bound.arguments[key_name] - return (tuple(bound.arguments.items()),) + return pickle.dumps(dict(bound.arguments)) async def __call__(self, *args: Any, **kwargs: Any) -> Any: key = self.make_cache_key(args, kwargs) @@ -354,7 +445,7 @@ def __get__(self, instance, owner): return self._instances[instance] -def cached_fetcher(max_size: Optional[int] = None, cache_key_index: int = 0): +def cached_fetcher(max_size: Optional[int] = None, cache_key_index: Optional[int] = 0): """Wrapper for CachedFetcher. See example in CachedFetcher docstring.""" def wrapper(method): diff --git a/pyproject.toml b/pyproject.toml index 7eb8606..ee9db3b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "async-substrate-interface" -version = "1.5.15" +version = "1.6.0" description = "Asyncio library for interacting with substrate. Mostly API-compatible with py-substrate-interface" readme = "README.md" license = { file = "LICENSE" } @@ -8,11 +8,11 @@ keywords = ["substrate", "development", "bittensor"] dependencies = [ "wheel", + "aiosqlite>=0.21.0,<1.0.0", "bt-decode==v0.8.0", "scalecodec~=1.2.11", "websockets>=14.1", "xxhash", - "aiosqlite>=0.21.0,<1.0.0" ] requires-python = ">=3.9,<3.15" diff --git a/tests/integration_tests/test_disk_cache.py b/tests/integration_tests/test_disk_cache.py index b6cbf45..cdebcc6 100644 --- a/tests/integration_tests/test_disk_cache.py +++ b/tests/integration_tests/test_disk_cache.py @@ -1,5 +1,5 @@ import pytest - +import time from async_substrate_interface.async_substrate import ( DiskCachedAsyncSubstrateInterface, AsyncSubstrateInterface, @@ -41,6 +41,7 @@ async def test_disk_cache(): assert parent_block_hash == parent_block_hash_from_cache assert block_runtime_info == block_runtime_info_from_cache assert block_runtime_version_for == block_runtime_version_from_cache + # Verify data integrity with non-disk cached Async Substrate Interface async with AsyncSubstrateInterface(entrypoint) as non_cache_substrate: block_hash_non_cache = await non_cache_substrate.get_block_hash(current_block) parent_block_hash_non_cache = await non_cache_substrate.get_parent_block_hash( @@ -58,6 +59,7 @@ async def test_disk_cache(): assert parent_block_hash == parent_block_hash_non_cache assert block_runtime_info == block_runtime_info_non_cache assert block_runtime_version_for == block_runtime_version_for_non_cache + # Verify data integrity with sync Substrate Interface with SubstrateInterface(entrypoint) as sync_substrate: block_hash_sync = sync_substrate.get_block_hash(current_block) parent_block_hash_sync = sync_substrate.get_parent_block_hash( @@ -73,4 +75,55 @@ async def test_disk_cache(): assert parent_block_hash == parent_block_hash_sync assert block_runtime_info == block_runtime_info_sync assert block_runtime_version_for == block_runtime_version_for_sync - print("test_disk_cache succeeded") + # Verify data is pulling from disk cache + async with DiskCachedAsyncSubstrateInterface(entrypoint) as disk_cached_substrate: + start = time.monotonic() + new_block_hash = await disk_cached_substrate.get_block_hash(current_block) + new_time = time.monotonic() + assert new_time - start < 0.001 + + start = time.monotonic() + new_parent_block_hash = await disk_cached_substrate.get_parent_block_hash( + block_hash + ) + new_time = time.monotonic() + assert new_time - start < 0.001 + start = time.monotonic() + new_block_runtime_info = await 
disk_cached_substrate.get_block_runtime_info( + block_hash + ) + new_time = time.monotonic() + assert new_time - start < 0.001 + start = time.monotonic() + new_block_runtime_version_for = ( + await disk_cached_substrate.get_block_runtime_version_for(block_hash) + ) + new_time = time.monotonic() + assert new_time - start < 0.001 + start = time.monotonic() + new_block_hash_from_cache = await disk_cached_substrate.get_block_hash( + current_block + ) + new_time = time.monotonic() + assert new_time - start < 0.001 + start = time.monotonic() + new_parent_block_hash_from_cache = ( + await disk_cached_substrate.get_parent_block_hash(block_hash_from_cache) + ) + new_time = time.monotonic() + assert new_time - start < 0.001 + start = time.monotonic() + new_block_runtime_info_from_cache = ( + await disk_cached_substrate.get_block_runtime_info(block_hash_from_cache) + ) + new_time = time.monotonic() + assert new_time - start < 0.001 + start = time.monotonic() + new_block_runtime_version_from_cache = ( + await disk_cached_substrate.get_block_runtime_version_for( + block_hash_from_cache + ) + ) + new_time = time.monotonic() + assert new_time - start < 0.001 + print("Disk Cache tests passed") diff --git a/tests/unit_tests/test_types.py b/tests/unit_tests/test_types.py index 7292177..f2e13b4 100644 --- a/tests/unit_tests/test_types.py +++ b/tests/unit_tests/test_types.py @@ -1,4 +1,12 @@ from async_substrate_interface.types import ScaleObj, Runtime, RuntimeCache +from async_substrate_interface.async_substrate import DiskCachedAsyncSubstrateInterface +from async_substrate_interface.utils import cache + +import sqlite3 +import os +import pickle +import pytest +from unittest.mock import patch def test_scale_object(): @@ -72,13 +80,83 @@ def test_runtime_cache(): # cache does not yet know that new_fake_block has the same runtime assert runtime_cache.retrieve(new_fake_block) is None assert ( - runtime_cache.retrieve(new_fake_block, runtime_version=fake_version) is not None + runtime_cache.retrieve( + new_fake_block, new_fake_hash, runtime_version=fake_version + ) + is not None ) # after checking the runtime with the new block, it now knows this runtime should also map to this block assert runtime_cache.retrieve(new_fake_block) is not None assert runtime_cache.retrieve(newer_fake_block) is None assert runtime_cache.retrieve(newer_fake_block, fake_hash) is not None assert runtime_cache.retrieve(newer_fake_block) is not None - assert runtime_cache.retrieve(block_hash=new_fake_hash) is None assert runtime_cache.retrieve(fake_block, block_hash=new_fake_hash) is not None assert runtime_cache.retrieve(block_hash=new_fake_hash) is not None + + +@pytest.mark.asyncio +async def test_runtime_cache_from_disk(): + test_db_location = "/tmp/async-substrate-interface-test-cache" + fake_chain = "ws://fake.com" + fake_block = 1 + fake_hash = "0xignore" + new_fake_block = 2 + new_fake_hash = "0xnewfakehash" + + if os.path.exists(test_db_location): + os.remove(test_db_location) + with patch.object(cache, "CACHE_LOCATION", test_db_location): + substrate = DiskCachedAsyncSubstrateInterface(fake_chain, _mock=True) + # Needed to avoid trying to initialize on the network during `substrate.initialize()` + substrate.initialized = True + + # runtime cache should be completely empty + assert substrate.runtime_cache.block_hashes == {} + assert substrate.runtime_cache.blocks == {} + assert substrate.runtime_cache.versions == {} + await substrate.initialize() + + # after initialization, runtime cache should still be completely empty + assert 
substrate.runtime_cache.block_hashes == {} + assert substrate.runtime_cache.blocks == {} + assert substrate.runtime_cache.versions == {} + await substrate.close() + + # ensure we have created the SQLite DB during initialize() + assert os.path.exists(test_db_location) + + # insert some fake data into our DB + conn = sqlite3.connect(test_db_location) + conn.execute( + "INSERT INTO RuntimeCache_blocks (key, value, chain) VALUES (?, ?, ?)", + (fake_block, pickle.dumps(fake_hash), fake_chain), + ) + conn.commit() + conn.close() + + substrate.initialized = True + await substrate.initialize() + assert substrate.runtime_cache.blocks == {fake_block: fake_hash} + # add an item to the cache + substrate.runtime_cache.add_item( + runtime=None, block_hash=new_fake_hash, block=new_fake_block + ) + await substrate.close() + + # verify that our added item is now in the DB + conn = sqlite3.connect(test_db_location) + cursor = conn.cursor() + cursor.execute("SELECT key, value, chain FROM RuntimeCache_blocks") + query = cursor.fetchall() + cursor.close() + conn.close() + + first_row = query[0] + assert first_row[0] == fake_block + assert pickle.loads(first_row[1]) == fake_hash + assert first_row[2] == fake_chain + + second_row = query[1] + assert second_row[0] == new_fake_block + assert pickle.loads(second_row[1]) == new_fake_hash + assert second_row[2] == fake_chain
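For orientation, a minimal usage sketch of the disk-backed interface introduced in this diff. The endpoint URL is a placeholder, and the warm-start behaviour mirrors the integration test above: the first session populates the on-disk SQLite cache when the connection closes, and a later session reloads it during `initialize()` so repeated lookups are served from cache.

```python
import asyncio

from async_substrate_interface.async_substrate import DiskCachedAsyncSubstrateInterface

ENTRYPOINT = "wss://node.example.org"  # placeholder endpoint


async def main():
    # First session: calls go to the chain; on close, the runtime cache and the
    # cached-method results are dumped to the per-endpoint SQLite store.
    async with DiskCachedAsyncSubstrateInterface(ENTRYPOINT) as substrate:
        block_hash = await substrate.get_block_hash(100)
        runtime_info = await substrate.get_block_runtime_info(block_hash)

    # Second session: initialize() loads the runtime cache back from disk, so
    # the same lookups no longer need a round-trip to the node.
    async with DiskCachedAsyncSubstrateInterface(ENTRYPOINT) as substrate:
        cached_hash = await substrate.get_block_hash(100)
        assert cached_hash == block_hash


asyncio.run(main())
```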