|
| 1 | +from typing import Dict, Text, Optional, Generator, Any, Tuple |
| 2 | +from .tools.system.id import UniqueID |
| 3 | +from .tools.typing.response import ModelResponse |
| 4 | +from libraries.colorama import init, Fore, Style |
| 5 | +import yaml |
| 6 | +import requests |
| 7 | +import fake_useragent |
| 8 | +import logging |
| 9 | +import os |
| 10 | + |
| 11 | +init() |
| 12 | + |
| 13 | +class OpenGPTError(Exception): |
| 14 | + @staticmethod |
| 15 | + def Print(context: Text, warn: Optional[bool] = False) -> None: |
| 16 | + if warn: |
| 17 | + print(Fore.YELLOW + "Warning: " + context + Style.RESET_ALL) |
| 18 | + else: |
| 19 | + print(Fore.RED + "Error: " + context + Style.RESET_ALL) |
| 20 | + sys.exit(1) |
| 21 | + |
| 22 | +class Model: |
| 23 | + @classmethod |
| 24 | + def __init__(self: type, style: Optional[Text] = "Hotpot Art 9") -> None: |
| 25 | + self._SETUP_LOGGER() |
| 26 | + self.__LoadStyles() |
| 27 | + self.__DIR: Text = os.getcwd() |
| 28 | + self.__session: requests.Session = requests.Session() |
| 29 | + self.__UNIQUE_ID: str = UniqueID(16) |
| 30 | + self.STYLE: Text = style |
| 31 | + self.__STYLE_ID = self.__GetStyleID(self.STYLE) |
| 32 | + self.__HEADERS: Dict[str, str] = { |
| 33 | + "Accept": "*/*", |
| 34 | + "Accept-Language": "pt-BR,pt;q=0.9,en-US;q=0.8,en;q=0.7", |
| 35 | + "Content-Type": f"multipart/form-data; boundary=----WebKitFormBoundary{self.__UNIQUE_ID}", |
| 36 | + "Authorization": "hotpot-temp9n88MmVw8uaDzmoBq", |
| 37 | + "Host": "api.hotpot.ai", |
| 38 | + "Origin": "https://hotpot.ai", |
| 39 | + "Referer": "https://hotpot.ai/", |
| 40 | + "Sec-Ch-Ua": "\"Chromium\";v=\"112\", \"Google Chrome\";v=\"112\", \"Not:A-Brand\";v=\"99\"", |
| 41 | + "Sec-Ch-Ua-mobile": "?0", |
| 42 | + "Sec-Ch-Ua-platform": "\"Windows\"", |
| 43 | + "Sec-Fetch-Dest": "empty", |
| 44 | + "Sec-Fetch-Mode": "cors", |
| 45 | + "Sec-Fetch-Site": "same-site", |
| 46 | + "User-Agent": fake_useragent.UserAgent().random |
| 47 | + } |
| 48 | + |
| 49 | + @classmethod |
| 50 | + def _SETUP_LOGGER(self: type) -> None: |
| 51 | + self.__logger: logging.getLogger = logging.getLogger(__name__) |
| 52 | + self.__logger.setLevel(logging.DEBUG) |
| 53 | + console_handler: logging.StreamHandler = logging.StreamHandler() |
| 54 | + console_handler.setLevel(logging.DEBUG) |
| 55 | + formatter: logging.Formatter = logging.Formatter("Model - %(levelname)s - %(message)s") |
| 56 | + console_handler.setFormatter(formatter) |
| 57 | + |
| 58 | + self.__logger.addHandler(console_handler) |
| 59 | + |
| 60 | + @classmethod |
| 61 | + def __GetStyleID(self: type, style: Text) -> int: |
| 62 | + if style in self.__DATA: |
| 63 | + return int(self.__DATA[style]) |
| 64 | + else: |
| 65 | + OpenGPTError.Print(context=f"The style \"{style}\" not found. Changing to \"Hotpot Art 9\"", warn=True) |
| 66 | + self.STYLE = "Hotpot Art 9" |
| 67 | + return int(140) |
| 68 | + |
| 69 | + @classmethod |
| 70 | + def UpdateStyle(self: type, style: Text) -> None: |
| 71 | + self.STYLE = style |
| 72 | + self.__STYLE_ID = self.__GetStyleID(self.STYLE) |
| 73 | + |
| 74 | + @classmethod |
| 75 | + def __LoadStyles(self: type) -> None: |
| 76 | + self.__DATA: Dict[Text, Text] = yaml.safe_load(open(self.__DIR + "/styles.yml").read()) |
| 77 | + |
| 78 | + @classmethod |
| 79 | + def __Fields(self: type, *args: Tuple[int, str], **kwargs: Dict[str, Any]) -> Text: |
| 80 | + return kwargs |
| 81 | + |
| 82 | + @classmethod |
| 83 | + def __AddField(self: type, field: Text, value: Any, end: Optional[bool] = False) -> Text: |
| 84 | + form: Text = '' |
| 85 | + |
| 86 | + form += f"\n\n------WebKitFormBoundary{self.__UNIQUE_ID}" |
| 87 | + form += f"\nContent-Disposition: form-data; name=\"{field}\"" |
| 88 | + form += f"\n\n{value}" |
| 89 | + |
| 90 | + if end: |
| 91 | + form += f"\n------WebKitFormBoundary{self.__UNIQUE_ID}--" |
| 92 | + return form |
| 93 | + |
| 94 | + @classmethod |
| 95 | + def Generate(self: type, prompt: Text, width: Optional[int] = 256, height: Optional[int] = 256) -> Generator[ModelResponse, None, None]: |
| 96 | + __DATA: Dict[str, str] = self.__Fields(seedValue=-1, inputText=prompt, width=width, height=height, styleId=self.__STYLE_ID, |
| 97 | + styleLabel=self.STYLE, isPrivate=False, requestId=f"8-{self.__UNIQUE_ID}", |
| 98 | + resultUrl=f"https://hotpotmedia.s3.us-east-2.amazonaws.com/8-{self.__UNIQUE_ID}.png") |
| 99 | + |
| 100 | + __FORM_DATA: Text = '' |
| 101 | + |
| 102 | + for field in __DATA: |
| 103 | + if field != "resultUrl": |
| 104 | + __FORM_DATA += self.__AddField(field, __DATA[field]) |
| 105 | + else: |
| 106 | + __FORM_DATA += self.__AddField(field, __DATA[field], end=True) |
| 107 | + |
| 108 | + self.__logger.debug("Generating image " + Fore.CYAN + f"\"{prompt}\"" + Style.RESET_ALL) |
| 109 | + url: Text = self.__session.post("https://api.hotpot.ai/art-premium-test1", headers=self.__HEADERS, data=__FORM_DATA).content |
| 110 | + return ModelResponse(**{ |
| 111 | + "id": __DATA["requestId"], |
| 112 | + "url": url.decode().replace("\"", ""), |
| 113 | + "style": self.STYLE, |
| 114 | + "width": __DATA["width"], |
| 115 | + "height": __DATA["height"] |
| 116 | + }) |
0 commit comments