1818
1919
2020def numpy2blob (tensor : np .ndarray ) -> tuple :
21- """ Convert the numpy input from user to `Tensor` """
21+ """Convert the numpy input from user to `Tensor`. """
2222 try :
2323 dtype = dtype_dict [str (tensor .dtype )]
2424 except KeyError :
@@ -29,7 +29,7 @@ def numpy2blob(tensor: np.ndarray) -> tuple:
2929
3030
3131def blob2numpy (value : ByteString , shape : Union [list , tuple ], dtype : str ) -> np .ndarray :
32- """ Convert `BLOB` result from RedisAI to `np.ndarray` """
32+ """Convert `BLOB` result from RedisAI to `np.ndarray`. """
3333 mm = {
3434 'FLOAT' : 'float32' ,
3535 'DOUBLE' : 'float64'
@@ -40,6 +40,7 @@ def blob2numpy(value: ByteString, shape: Union[list, tuple], dtype: str) -> np.n
4040
4141
4242def list2dict (lst ):
43+ """Convert the list from RedisAI to a dict."""
4344 if len (lst ) % 2 != 0 :
4445 raise RuntimeError ("Can't unpack the list: {}" .format (lst ))
4546 out = {}
@@ -55,10 +56,8 @@ def list2dict(lst):
5556def recursive_bytetransform (arr : List [AnyStr ], target : Callable ) -> list :
5657 """
5758 Recurse value, replacing each element of b'' with the appropriate element.
58- Function returns the same array after inplace operation which updates `arr`
5959
60- :param target: Type of tensor | array
61- :param arr: The array with b'' numbers or recursive array of b''
60+ Function returns the same array after inplace operation which updates `arr`
6261 """
6362 for ix in range (len (arr )):
6463 obj = arr [ix ]
@@ -70,10 +69,16 @@ def recursive_bytetransform(arr: List[AnyStr], target: Callable) -> list:
7069
7170
7271def listify (inp : Union [str , Sequence [str ]]) -> Sequence [str ]:
72+ """Wrap the ``inp`` with a list if it's not a list already."""
7373 return (inp ,) if not isinstance (inp , (list , tuple )) else inp
7474
7575
76- def tensorget_postprocessor (as_numpy , meta_only , rai_result ):
76+ def tensorget_postprocessor (rai_result , as_numpy , meta_only ):
77+ """Process the tensorget output.
78+
79+ If ``as_numpy`` is True, it'll be converted to a numpy array. The required
80+ information such as datatype and shape must be in ``rai_result`` itself.
81+ """
7782 rai_result = list2dict (rai_result )
7883 if meta_only :
7984 return rai_result
0 commit comments