|
1 | | -from typing import Any, Dict, Optional |
| 1 | +import concurrent.futures |
| 2 | +from typing import Any, Dict, List, Optional, Union |
2 | 3 |
|
3 | 4 | import together |
4 | 5 | from together.utils import create_post_request, get_logger |
|
7 | 8 | logger = get_logger(str(__name__)) |
8 | 9 |
|
9 | 10 |
|
| 11 | +class DataItem: |
| 12 | + def __init__(self, embedding: List[float]): |
| 13 | + self.embedding = embedding |
| 14 | + |
| 15 | + |
| 16 | +class EmbeddingsOutput: |
| 17 | + def __init__(self, data: List[DataItem]): |
| 18 | + self.data = data |
| 19 | + |
| 20 | + |
10 | 21 | class Embeddings: |
11 | 22 | @classmethod |
12 | 23 | def create( |
13 | | - self, |
14 | | - input: str, |
| 24 | + cls, |
| 25 | + input: Union[str, List[str]], |
15 | 26 | model: Optional[str] = "", |
16 | | - ) -> Dict[str, Any]: |
| 27 | + ) -> EmbeddingsOutput: |
17 | 28 | if model == "": |
18 | 29 | model = together.default_embedding_model |
19 | 30 |
|
20 | | - parameter_payload = { |
21 | | - "input": input, |
22 | | - "model": model, |
23 | | - } |
| 31 | + if isinstance(input, str): |
| 32 | + parameter_payload = { |
| 33 | + "input": input, |
| 34 | + "model": model, |
| 35 | + } |
| 36 | + |
| 37 | + response = cls._process_input(parameter_payload) |
| 38 | + |
| 39 | + return EmbeddingsOutput([DataItem(response["data"][0]["embedding"])]) |
24 | 40 |
|
| 41 | + elif isinstance(input, list): |
| 42 | + # If input is a list, process each string concurrently |
| 43 | + with concurrent.futures.ThreadPoolExecutor() as executor: |
| 44 | + parameter_payloads = [{"input": item, "model": model} for item in input] |
| 45 | + results = list(executor.map(cls._process_input, parameter_payloads)) |
| 46 | + |
| 47 | + return EmbeddingsOutput( |
| 48 | + [DataItem(item["data"][0]["embedding"]) for item in results] |
| 49 | + ) |
| 50 | + |
| 51 | + @classmethod |
| 52 | + def _process_input(cls, parameter_payload: Dict[str, Any]) -> Dict[str, Any]: |
25 | 53 | # send request |
26 | 54 | response = create_post_request( |
27 | 55 | url=together.api_base_embeddings, json=parameter_payload |
28 | 56 | ) |
29 | 57 |
|
| 58 | + # return the json as a DotDict |
30 | 59 | try: |
31 | 60 | response_json = dict(response.json()) |
32 | | - |
33 | 61 | except Exception as e: |
34 | 62 | raise together.JSONError(e, http_status=response.status_code) |
| 63 | + |
35 | 64 | return response_json |
0 commit comments