Skip to content

Commit 4c07a77

Browse files
authored
Merge pull request #1876 from weaviate/introduce-data-ingest-for-ssb
Introduce `collection.data.ingest` for sync/async SSB usage
2 parents ab08410 + 687c7de commit 4c07a77

12 files changed

Lines changed: 174 additions & 55 deletions

File tree

.github/workflows/main.yaml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ env:
2727
WEAVIATE_133: 1.33.11
2828
WEAVIATE_134: 1.34.8
2929
WEAVIATE_135: 1.35.2
30-
WEAVIATE_136: 1.36.0-dev-c8f578d
30+
WEAVIATE_136: 1.36.0-dev-0bbf31a
3131

3232
jobs:
3333
lint-and-format:

integration/test_batch_v4.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -433,7 +433,7 @@ def test_add_ten_thousand_data_objects(
433433
"""Test adding ten thousand data objects."""
434434
client, name = client_factory()
435435
if (
436-
request.node.callspec.id == "test_add_ten_thousand_data_objects_experimental"
436+
request.node.callspec.id == "test_add_ten_thousand_data_objects_stream"
437437
and client._connection._weaviate_version.is_lower_than(1, 36, 0)
438438
):
439439
pytest.skip("Server-side batching not supported in Weaviate < 1.36.0")
@@ -641,7 +641,7 @@ def test_add_one_object_and_a_self_reference(
641641
"""Test adding one object and a self reference."""
642642
client, name = client_factory()
643643
if (
644-
request.node.callspec.id == "test_add_one_object_and_a_self_reference_experimental"
644+
request.node.callspec.id == "test_add_one_object_and_a_self_reference_stream"
645645
and client._connection._weaviate_version.is_lower_than(1, 36, 0)
646646
):
647647
pytest.skip("Server-side batching not supported in Weaviate < 1.36.0")

integration/test_collection_batch.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def test_non_existant_collection(collection_factory_get: CollectionFactoryGet) -
271271

272272

273273
@pytest.mark.asyncio
274-
async def test_add_one_hundred_thousand_objects_async_collection(
274+
async def test_batch_one_hundred_thousand_objects_async_collection(
275275
batch_collection_async: BatchCollectionAsync,
276276
) -> None:
277277
"""Test adding one hundred thousand data objects."""
@@ -295,3 +295,46 @@ async def test_add_one_hundred_thousand_objects_async_collection(
295295
assert await col.length() == nr_objects
296296
assert col.batch.results.objs.has_errors is False
297297
assert len(col.batch.failed_objects) == 0, [obj.message for obj in col.batch.failed_objects]
298+
299+
300+
@pytest.mark.asyncio
301+
async def test_ingest_one_hundred_thousand_data_objects_async(
302+
batch_collection_async: BatchCollectionAsync,
303+
) -> None:
304+
col = await batch_collection_async()
305+
if col._connection._weaviate_version.is_lower_than(1, 36, 0):
306+
pytest.skip("Server-side batching not supported in Weaviate < 1.36.0")
307+
nr_objects = 100000
308+
import time
309+
310+
start = time.time()
311+
results = await col.data.ingest({"name": "test" + str(i)} for i in range(nr_objects))
312+
end = time.time()
313+
print(f"Time taken to add {nr_objects} objects: {end - start} seconds")
314+
assert len(results.errors) == 0
315+
assert len(results.all_responses) == nr_objects
316+
assert len(results.uuids) == nr_objects
317+
assert await col.length() == nr_objects
318+
assert results.has_errors is False
319+
assert len(results.errors) == 0, [obj.message for obj in results.errors.values()]
320+
321+
322+
def test_ingest_one_hundred_thousand_data_objects(
323+
batch_collection: BatchCollection,
324+
) -> None:
325+
col = batch_collection()
326+
if col._connection._weaviate_version.is_lower_than(1, 36, 0):
327+
pytest.skip("Server-side batching not supported in Weaviate < 1.36.0")
328+
nr_objects = 100000
329+
import time
330+
331+
start = time.time()
332+
results = col.data.ingest({"name": "test" + str(i)} for i in range(nr_objects))
333+
end = time.time()
334+
print(f"Time taken to add {nr_objects} objects: {end - start} seconds")
335+
assert len(results.errors) == 0
336+
assert len(results.all_responses) == nr_objects
337+
assert len(results.uuids) == nr_objects
338+
assert len(col) == nr_objects
339+
assert results.has_errors is False
340+
assert len(results.errors) == 0, [obj.message for obj in results.errors.values()]

integration/test_rbac.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -742,8 +742,8 @@ def test_server_side_batching_with_auth() -> None:
742742
with connect_to_local(
743743
port=RBAC_PORTS[0], grpc_port=RBAC_PORTS[1], auth_credentials=RBAC_AUTH_CREDS
744744
) as client:
745-
if client._connection._weaviate_version.is_lower_than(1, 34, 0):
746-
pytest.skip("Server-side batching not supported in Weaviate < 1.34.0")
745+
if client._connection._weaviate_version.is_lower_than(1, 36, 0):
746+
pytest.skip("Server-side batching not supported in Weaviate < 1.36.0")
747747
collection = client.collections.create(collection_name)
748748
with client.batch.stream() as batch:
749749
batch.add_object(collection_name)

weaviate/collections/batch/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,9 @@ def number_errors(self) -> int:
359359
def _start(self):
360360
pass
361361

362+
def _wait(self):
363+
pass
364+
362365
def _shutdown(self) -> None:
363366
"""Shutdown the current batch and wait for all requests to be finished."""
364367
self.flush()

weaviate/collections/batch/batch_wrapper.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
_ClusterBatch,
1111
_ClusterBatchAsync,
1212
_DynamicBatching,
13-
_ServerSideBatching,
1413
)
1514
from weaviate.collections.batch.sync import _BatchBaseSync
1615
from weaviate.collections.classes.batch import (
@@ -140,8 +139,6 @@ def __init__(
140139
self._connection = connection
141140
self._consistency_level = consistency_level
142141
self._current_batch: Optional[_BatchBaseAsync] = None
143-
# config options
144-
self._batch_mode: _BatchMode = _ServerSideBatching(1)
145142

146143
self._batch_data = _BatchDataWrapper()
147144
self._cluster = _ClusterBatchAsync(connection)
@@ -371,7 +368,7 @@ async def add_reference(
371368
"""
372369
...
373370

374-
async def flush(self) -> None:
371+
def flush(self) -> None:
375372
"""Flush the current batch.
376373
377374
This will send all the objects and references in the current batch to Weaviate.
@@ -505,19 +502,20 @@ def number_errors(self) -> int:
505502
Q = TypeVar("Q", bound=Union[BatchClientProtocolAsync, BatchCollectionProtocolAsync[Properties]])
506503

507504

508-
class _ContextManagerWrapper(Generic[T, P]):
505+
class _ContextManagerSync(Generic[T, P]):
509506
def __init__(self, current_batch: T):
510507
self.__current_batch: T = current_batch
511508

512509
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
513510
self.__current_batch._shutdown()
511+
self.__current_batch._wait()
514512

515513
def __enter__(self) -> P:
516514
self.__current_batch._start()
517515
return self.__current_batch # pyright: ignore[reportReturnType]
518516

519517

520-
class _ContextManagerWrapperAsync(Generic[Q]):
518+
class _ContextManagerAsync(Generic[Q]):
521519
def __init__(self, current_batch: _BatchBaseAsync):
522520
self.__current_batch = current_batch
523521

weaviate/collections/batch/client.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
_BatchMode,
2020
_BatchWrapper,
2121
_BatchWrapperAsync,
22-
_ContextManagerWrapper,
23-
_ContextManagerWrapperAsync,
22+
_ContextManagerAsync,
23+
_ContextManagerSync,
2424
)
2525
from weaviate.collections.batch.sync import _BatchBaseSync
2626
from weaviate.collections.classes.config import ConsistencyLevel, Vectorizers
@@ -146,10 +146,10 @@ async def add_reference(
146146
BatchClient = _BatchClient
147147
BatchClientSync = _BatchClientSync
148148
BatchClientAsync = _BatchClientAsync
149-
ClientBatchingContextManager = _ContextManagerWrapper[
149+
ClientBatchingContextManager = _ContextManagerSync[
150150
Union[BatchClient, BatchClientSync], BatchClientProtocol
151151
]
152-
AsyncClientBatchingContextManager = _ContextManagerWrapperAsync[BatchClientProtocolAsync]
152+
ClientBatchingContextManagerAsync = _ContextManagerAsync[BatchClientProtocolAsync]
153153

154154

155155
class _BatchClientWrapper(_BatchWrapper):
@@ -196,7 +196,7 @@ def __create_batch_and_reset(
196196

197197
self._batch_data = _BatchDataWrapper() # clear old data
198198

199-
return _ContextManagerWrapper(
199+
return _ContextManagerSync(
200200
batch_client(
201201
connection=self._connection,
202202
consistency_level=self._consistency_level,
@@ -310,7 +310,7 @@ def __init__(
310310

311311
def __create_batch_and_reset(self):
312312
self._batch_data = _BatchDataWrapper() # clear old data
313-
return _ContextManagerWrapperAsync(
313+
return _ContextManagerAsync(
314314
BatchClientAsync(
315315
connection=self._connection,
316316
consistency_level=self._consistency_level,
@@ -328,15 +328,15 @@ def experimental(
328328
*,
329329
concurrency: Optional[int] = None,
330330
consistency_level: Optional[ConsistencyLevel] = None,
331-
) -> AsyncClientBatchingContextManager:
331+
) -> ClientBatchingContextManagerAsync:
332332
return self.stream(concurrency=concurrency, consistency_level=consistency_level)
333333

334334
def stream(
335335
self,
336336
*,
337337
concurrency: Optional[int] = None,
338338
consistency_level: Optional[ConsistencyLevel] = None,
339-
) -> AsyncClientBatchingContextManager:
339+
) -> ClientBatchingContextManagerAsync:
340340
"""Configure the batching context manager to use batch streaming.
341341
342342
When you exit the context manager, the final batch will be sent automatically.
@@ -345,9 +345,9 @@ def stream(
345345
concurrency: The number of concurrent streams to use when sending batches. If not provided, the default will be one.
346346
consistency_level: The consistency level to be used when inserting data. If not provided, the default value is `None`.
347347
"""
348-
if self._connection._weaviate_version.is_lower_than(1, 34, 0):
348+
if self._connection._weaviate_version.is_lower_than(1, 36, 0):
349349
raise WeaviateUnsupportedFeatureError(
350-
"Server-side batching", str(self._connection._weaviate_version), "1.34.0"
350+
"Server-side batching", str(self._connection._weaviate_version), "1.36.0"
351351
)
352352
self._batch_mode = _ServerSideBatching(
353353
# concurrency=concurrency

weaviate/collections/batch/collection.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
BatchCollectionProtocolAsync,
2020
_BatchWrapper,
2121
_BatchWrapperAsync,
22-
_ContextManagerWrapper,
23-
_ContextManagerWrapperAsync,
22+
_ContextManagerAsync,
23+
_ContextManagerSync,
2424
)
2525
from weaviate.collections.batch.sync import _BatchBaseSync
2626
from weaviate.collections.classes.config import ConsistencyLevel, Vectorizers
@@ -88,14 +88,14 @@ def add_reference(
8888
class _BatchCollectionSync(Generic[Properties], _BatchBaseSync):
8989
def __init__(
9090
self,
91-
executor: ThreadPoolExecutor,
9291
connection: ConnectionSync,
9392
consistency_level: Optional[ConsistencyLevel],
9493
results: _BatchDataWrapper,
95-
batch_mode: _BatchMode,
9694
name: str,
9795
tenant: Optional[str],
98-
vectorizer_batching: bool,
96+
executor: Optional[ThreadPoolExecutor] = None,
97+
batch_mode: Optional[_BatchMode] = None,
98+
vectorizer_batching: bool = False,
9999
) -> None:
100100
super().__init__(
101101
connection=connection,
@@ -184,11 +184,11 @@ async def add_reference(
184184
BatchCollection = _BatchCollection
185185
BatchCollectionSync = _BatchCollectionSync
186186
BatchCollectionAsync = _BatchCollectionAsync
187-
CollectionBatchingContextManager = _ContextManagerWrapper[
187+
CollectionBatchingContextManager = _ContextManagerSync[
188188
Union[BatchCollection[Properties], BatchCollectionSync[Properties]],
189189
BatchCollectionProtocol[Properties],
190190
]
191-
CollectionBatchingContextManagerAsync = _ContextManagerWrapperAsync[
191+
CollectionBatchingContextManagerAsync = _ContextManagerAsync[
192192
BatchCollectionProtocolAsync[Properties]
193193
]
194194

@@ -239,7 +239,7 @@ def __create_batch_and_reset(
239239
self._vectorizer_batching = False
240240

241241
self._batch_data = _BatchDataWrapper() # clear old data
242-
return _ContextManagerWrapper(
242+
return _ContextManagerSync(
243243
batch_client(
244244
connection=self._connection,
245245
consistency_level=self._consistency_level,
@@ -311,9 +311,9 @@ def stream(
311311
concurrency: The number of concurrent requests when sending batches. This controls the number of concurrent requests
312312
made to Weaviate. If not provided, the default value is 1.
313313
"""
314-
if self._connection._weaviate_version.is_lower_than(1, 34, 0):
314+
if self._connection._weaviate_version.is_lower_than(1, 36, 0):
315315
raise WeaviateUnsupportedFeatureError(
316-
"Server-side batching", str(self._connection._weaviate_version), "1.34.0"
316+
"Server-side batching", str(self._connection._weaviate_version), "1.36.0"
317317
)
318318
self._batch_mode = _ServerSideBatching(
319319
# concurrency=concurrency
@@ -338,7 +338,7 @@ def __init__(
338338

339339
def __create_batch_and_reset(self):
340340
self._batch_data = _BatchDataWrapper() # clear old data
341-
return _ContextManagerWrapperAsync(
341+
return _ContextManagerAsync(
342342
BatchCollectionAsync(
343343
connection=self._connection,
344344
consistency_level=self._consistency_level,
@@ -371,9 +371,9 @@ def stream(
371371
concurrency: The number of concurrent requests when sending batches. This controls the number of concurrent requests
372372
made to Weaviate. If not provided, the default value is 1.
373373
"""
374-
if self._connection._weaviate_version.is_lower_than(1, 34, 0):
374+
if self._connection._weaviate_version.is_lower_than(1, 36, 0):
375375
raise WeaviateUnsupportedFeatureError(
376-
"Server-side batching", str(self._connection._weaviate_version), "1.34.0"
376+
"Server-side batching", str(self._connection._weaviate_version), "1.36.0"
377377
)
378378
self._batch_mode = _ServerSideBatching(
379379
# concurrency=concurrency

weaviate/collections/batch/sync.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
_BatchMode,
1616
_BgThreads,
1717
_ClusterBatch,
18-
_ServerSideBatching,
1918
)
2019
from weaviate.collections.batch.grpc_batch import _BatchGRPC
2120
from weaviate.collections.classes.batch import (
@@ -54,9 +53,9 @@ def __init__(
5453
connection: ConnectionSync,
5554
consistency_level: Optional[ConsistencyLevel],
5655
results: _BatchDataWrapper,
57-
batch_mode: _BatchMode,
58-
executor: ThreadPoolExecutor,
59-
vectorizer_batching: bool,
56+
batch_mode: Optional[_BatchMode] = None,
57+
executor: Optional[ThreadPoolExecutor] = None,
58+
vectorizer_batching: bool = False,
6059
objects: Optional[ObjectsBatchRequest[BatchObject]] = None,
6160
references: Optional[ReferencesBatchRequest[BatchReference]] = None,
6261
) -> None:
@@ -108,8 +107,6 @@ def __init__(
108107
self.__reqs: Queue[Optional[batch_pb2.BatchStreamRequest]] = Queue(maxsize=1)
109108
self.__stop = False
110109

111-
self.__batch_mode = batch_mode
112-
113110
@property
114111
def number_errors(self) -> int:
115112
"""Return the number of errors in the batch."""
@@ -123,12 +120,7 @@ def __all_threads_alive(self) -> bool:
123120
)
124121

125122
def _start(self) -> None:
126-
assert isinstance(self.__batch_mode, _ServerSideBatching), (
127-
"Only server-side batching is supported in this mode"
128-
)
129-
self.__bg_threads = [
130-
self.__start_bg_threads() for _ in range(self.__batch_mode.concurrency)
131-
]
123+
self.__bg_threads = [self.__start_bg_threads() for _ in range(1)]
132124
logger.info(
133125
f"Provisioned {len(self.__bg_threads)} stream(s) to the server for batch processing"
134126
)

weaviate/collections/data/async_.pyi

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import uuid as uuid_package
2-
from typing import Generic, List, Literal, Optional, Sequence, Union, overload
2+
from typing import Generic, Iterable, List, Literal, Optional, Sequence, Union, overload
33

4-
from weaviate.collections.batch.collection import _BatchCollectionWrapper
54
from weaviate.collections.batch.grpc_batch import _BatchGRPC
65
from weaviate.collections.batch.grpc_batch_delete import _BatchDeleteGRPC
76
from weaviate.collections.batch.rest import _BatchREST
@@ -30,7 +29,6 @@ class _DataCollectionAsync(
3029
__batch_delete: _BatchDeleteGRPC
3130
__batch_grpc: _BatchGRPC
3231
__batch_rest: _BatchREST
33-
__batch: _BatchCollectionWrapper[Properties]
3432

3533
async def insert(
3634
self,
@@ -81,3 +79,6 @@ class _DataCollectionAsync(
8179
async def delete_many(
8280
self, where: _Filters, *, verbose: bool = False, dry_run: bool = False
8381
) -> Union[DeleteManyReturn[List[DeleteManyObject]], DeleteManyReturn[None]]: ...
82+
async def ingest(
83+
self, objs: Iterable[Union[Properties, DataObject[Properties, Optional[ReferenceInputs]]]]
84+
) -> BatchObjectReturn: ...

0 commit comments

Comments
 (0)