55import requests
66import aiohttp
77
8- from .providers import get_provider , convert_history , DEFAULT_MAX_TOKENS , DEFAULT_PROVIDER
8+ from .providers import get_provider , convert_history
9+ from .utils import ensure_image_uri
910
1011##
1112## printing
@@ -34,40 +35,44 @@ def prepare_model(prov, model_key, model=None):
3435 return {'model' : model } if model is not None else {}
3536
3637def prepare_auth (prov , api_key = None ):
37- if (auth_func := prov .get ( ' authorize' ) ) is not None :
38- if (api_key := os .environ .get (prov [ ' api_key_env' ] )) is None :
39- raise Exception ('Cannot find API key in {api_key_env}' )
38+ if (auth_func := prov .authorize ) is not None :
39+ if (api_key := os .environ .get (prov . api_key_env )) is None :
40+ raise Exception (f 'Cannot find API key in { prov . api_key_env } ' )
4041 headers_auth = auth_func (api_key )
4142 else :
4243 headers_auth = {}
4344 return headers_auth
4445
4546def prepare_request (
46- query , provider = DEFAULT_PROVIDER , system = None , image = None , prefill = None , prediction = None , history = None ,
47- base_url = None , path = None , api_key = None , model = None , max_tokens = DEFAULT_MAX_TOKENS , ** kwargs
47+ query , provider = None , system = None , image = None , prefill = None , prediction = None , history = None ,
48+ base_url = None , path = None , api_key = None , model = None , max_tokens = None , ** kwargs
4849):
4950 # external provider details
5051 prov = get_provider (provider )
51- max_tokens_name = prov .get ('max_tokens_name' , 'max_completion_tokens' )
5252 url = prepare_url (prov , 'chat_path' , base_url = base_url , path = path )
5353 payload_model = prepare_model (prov , 'chat_model' , model = model )
5454
5555 # convert history to provider format
56- history = convert_history (history , prov [ ' content' ] )
56+ history = convert_history (history , prov . content )
5757
5858 # get extra headers
5959 headers_auth = prepare_auth (prov , api_key = api_key )
6060 headers_extra = prov .get ('headers' , {})
6161
6262 # get message payload
63- content = prov ['content' ](query , image = image )
64- payload_message = prov ['payload' ](
63+ img_data = ensure_image_uri (image )
64+ content = prov .content (query , image = img_data )
65+ payload_message = prov .payload (
6566 content , system = system , prefill = prefill , prediction = prediction , history = history
6667 )
6768
6869 # compose request
6970 headers = {'Content-Type' : 'application/json' , ** headers_auth , ** headers_extra }
70- payload = {** payload_model , ** payload_message , max_tokens_name : max_tokens , ** kwargs }
71+ payload = {** payload_model , ** payload_message , ** kwargs }
72+
73+ # add in max tokens
74+ if max_tokens is not None :
75+ payload [prov .max_tokens_name ] = max_tokens
7176
7277 # return url, headers, payload
7378 return url , headers , payload
@@ -76,10 +81,9 @@ def prepare_request(
7681## requests
7782##
7883
79- def reply (query , provider = DEFAULT_PROVIDER , history = None , prefill = None , dryrun = False , ** kwargs ):
84+ def reply (query , provider = None , history = None , prefill = None , dryrun = False , ** kwargs ):
8085 # get provider
8186 prov = get_provider (provider )
82- extractor = prov ['response' ]
8387
8488 # prepare request
8589 url , headers , payload = prepare_request (
@@ -97,7 +101,7 @@ def reply(query, provider=DEFAULT_PROVIDER, history=None, prefill=None, dryrun=F
97101
98102 # extract text
99103 data = response .json ()
100- text = extractor (data )
104+ text = prov . response (data )
101105
102106 # add in prefill
103107 if prefill is not None :
@@ -106,10 +110,9 @@ def reply(query, provider=DEFAULT_PROVIDER, history=None, prefill=None, dryrun=F
106110 # return text
107111 return text
108112
109- async def reply_async (query , provider = DEFAULT_PROVIDER , history = None , prefill = None , ** kwargs ):
113+ async def reply_async (query , provider = None , history = None , prefill = None , ** kwargs ):
110114 # get provider
111115 prov = get_provider (provider )
112- extractor = prov ['response' ]
113116
114117 # prepare request
115118 url , headers , payload = prepare_request (
@@ -123,7 +126,7 @@ async def reply_async(query, provider=DEFAULT_PROVIDER, history=None, prefill=No
123126
124127 # extract text
125128 data = await response .json ()
126- text = extractor (data )
129+ text = prov . response (data )
127130
128131 # add in prefill
129132 if prefill is not None :
@@ -154,10 +157,9 @@ async def iter_lines(inputs):
154157 if len (buffer ) > 0 :
155158 yield buffer
156159
157- def stream (query , provider = DEFAULT_PROVIDER , history = None , prefill = None , ** kwargs ):
160+ def stream (query , provider = None , history = None , prefill = None , ** kwargs ):
158161 # get provider
159162 prov = get_provider (provider )
160- extractor = prov ['stream' ]
161163
162164 # prepare request
163165 url , headers , payload = prepare_request (
@@ -181,14 +183,13 @@ def stream(query, provider=DEFAULT_PROVIDER, history=None, prefill=None, **kwarg
181183 for line in response .iter_lines ():
182184 if (data := parse_sse (line )) is not None :
183185 parsed = json .loads (data )
184- text = extractor (parsed )
186+ text = prov . stream (parsed )
185187 if text is not None :
186188 yield text
187189
188- async def stream_async (query , provider = DEFAULT_PROVIDER , history = None , prefill = None , ** kwargs ):
190+ async def stream_async (query , provider = None , history = None , prefill = None , ** kwargs ):
189191 # get provider
190192 prov = get_provider (provider )
191- extractor = prov ['stream' ]
192193
193194 # prepare request
194195 url , headers , payload = prepare_request (
@@ -214,15 +215,15 @@ async def stream_async(query, provider=DEFAULT_PROVIDER, history=None, prefill=N
214215 async for line in iter_lines (chunks ):
215216 if (data := parse_sse (line )) is not None :
216217 parsed = json .loads (data )
217- text = extractor (parsed )
218+ text = prov . stream (parsed )
218219 if text is not None :
219220 yield text
220221
221222##
222223## embeddings
223224##
224225
225- def embed (text , provider = DEFAULT_PROVIDER , base_url = None , path = None , api_key = None , model = None , timeout = None , ** kwargs ):
226+ def embed (text , provider = None , base_url = None , path = None , api_key = None , model = None , timeout = None , ** kwargs ):
226227 # get provider details
227228 prov = get_provider (provider )
228229 url = prepare_url (prov , f'embed_path' , base_url = base_url , path = path )
@@ -233,7 +234,7 @@ def embed(text, provider=DEFAULT_PROVIDER, base_url=None, path=None, api_key=Non
233234
234235 # make payload
235236 payload_model = prepare_model (prov , 'embed_model' , model = model )
236- payload_message = prov [ ' embed_payload' ] (text )
237+ payload_message = prov . embed_payload (text )
237238
238239 # compose request
239240 headers = {'Content-Type' : 'application/json' , ** headers_auth , ** headers_extra }
@@ -245,12 +246,12 @@ def embed(text, provider=DEFAULT_PROVIDER, base_url=None, path=None, api_key=Non
245246
246247 # extract result
247248 data = response .json ()
248- result = prov [ ' embed_response' ] (data )
249+ result = prov . embed_response (data )
249250
250251 # return result
251252 return result
252253
253- def tokenize (text , provider = DEFAULT_PROVIDER , base_url = None , path = None , api_key = None , model = None , timeout = None , ** kwargs ):
254+ def tokenize (text , provider = None , base_url = None , path = None , api_key = None , model = None , timeout = None , ** kwargs ):
254255 # get provider details
255256 prov = get_provider (provider )
256257 url = prepare_url (prov , 'tokenize_path' , base_url = base_url , path = path )
@@ -261,7 +262,7 @@ def tokenize(text, provider=DEFAULT_PROVIDER, base_url=None, path=None, api_key=
261262
262263 # make payload
263264 payload_model = prepare_model (prov , 'embed_model' , model = model )
264- payload_message = prov [ ' tokenize_payload' ] (text )
265+ payload_message = prov . tokenize_payload (text )
265266
266267 # compose request
267268 headers = {'Content-Type' : 'application/json' , ** headers_auth , ** headers_extra }
@@ -273,7 +274,7 @@ def tokenize(text, provider=DEFAULT_PROVIDER, base_url=None, path=None, api_key=
273274
274275 # extract result
275276 data = response .json ()
276- result = prov [ ' tokenize_response' ] (data )
277+ result = prov . tokenize_response (data )
277278
278279 # return result
279280 return result
0 commit comments