diff --git a/main.py b/main.py index fbd8548..a54a3b8 100644 --- a/main.py +++ b/main.py @@ -1,21 +1,36 @@ import base64 import boto3 +import io import os import pixeltable as pxt import PIL.Image from botocore.client import Config from datasets import load_dataset from dotenv import load_dotenv -from flask import Flask, request, render_template, send_from_directory, redirect +from flask import Flask, request, render_template, send_from_directory, redirect, send_file from flask_htmx import HTMX from openai import OpenAI from typing import cast from urllib.parse import urlparse from uuid_extensions import uuid7 -from pixeltable.functions.huggingface import clip +from pixeltable.functions.huggingface import clip, image_to_image +from pixeltable.functions import reve, replicate + + +def pil_to_base64(img: PIL.Image.Image) -> str: + """Convert a PIL Image to a base64 string.""" + if img is None: + return "" + buffered = io.BytesIO() + img.save(buffered, format="PNG") + return base64.b64encode(buffered.getvalue()).decode("utf-8") load_dotenv() -oai = OpenAI() + +# OpenAI client is optional - only initialize if API key is available +oai = None +if os.environ.get("OPENAI_API_KEY"): + oai = OpenAI() @pxt.udf @@ -24,7 +39,10 @@ def gen_uuid() -> str: def import_screenshots(): data_files = "s3://xe-zohar-copy/ds/screenshots_sharded/*.parquet" - storage_options={"profile": "tigris-dev"} + storage_options = { + "profile": "tigris-dev", + "endpoint_url": "https://t3.storage.dev", + } dataset = load_dataset( "parquet", @@ -32,7 +50,6 @@ def import_screenshots(): data_files=data_files, streaming=False, storage_options=storage_options, - if_exists="ignore", ) screenshots = pxt.create_table("screenshots", source=dataset, if_exists="ignore") screenshots.add_embedding_index( @@ -51,30 +68,16 @@ def import_screenshots(): screenshots = cast(pxt.Table, screenshots) -@pxt.query -def get_image(image_id: str) -> PIL.Image.Image: - return ( - screenshots.where(screenshots.uuid == image_id) - .select(screenshots.image) - .limit(1) - ) - - def encode_image(file_path): with open(file_path, "rb") as f: base64_image = base64.b64encode(f.read()).decode("utf-8") return base64_image -# @pxt.udf -# def image_edit(prompt: str, input: PIL.Image.Image) -> PIL.Image.Image: -# pass - - generated_images = pxt.create_table( "generated_images", { - "input_image_id": pxt.String, + "input_image": pxt.Image, "prompt": pxt.String, }, if_exists="ignore", @@ -84,12 +87,9 @@ def encode_image(file_path): ) generated_images.add_computed_column(uuid=gen_uuid(), if_exists="ignore") generated_images.add_computed_column( - input_image=get_image(generated_images.input_image_id), + gen_image=reve.edit(generated_images.input_image, generated_images.prompt), if_exists="ignore", ) -# generated_images.add_computed_column( -# gen_image=image_generations(generated_images.prompt, model="gpt-image-1") -# ) def perform_search(screenshots: pxt.Table, query: str) -> pxt.ResultSet: @@ -98,16 +98,16 @@ def perform_search(screenshots: pxt.Table, query: str) -> pxt.ResultSet: screenshots.order_by(sim, asc=False) .select( uuid=screenshots.uuid, - url=screenshots.image.fileurl, + image=screenshots.image, ) - # .limit(6) - .limit(1) + .limit(6) ) return results.collect() app = Flask(__name__) htmx = HTMX(app) +app.jinja_env.filters["b64encode"] = pil_to_base64 tigris = boto3.client( "s3", endpoint_url="https://t3.storage.dev", @@ -132,6 +132,29 @@ def healthz(): return "OK" +@app.route("/api/image/") +def api_image(uuid): + """Serve an image from the screenshots table by UUID.""" + result = ( + screenshots.where(screenshots.uuid == uuid) + .select(screenshots.image) + .limit(1) + .collect() + ) + if not result: + return "Image not found", 404 + + img = result[0]["image"] + if img is None: + return "Image not found", 404 + + # Convert PIL Image to bytes and serve + buffered = io.BytesIO() + img.save(buffered, format="PNG") + buffered.seek(0) + return send_file(buffered, mimetype="image/png") + + @app.route("/") def index(): if htmx: @@ -154,26 +177,14 @@ def api_search(): results = [] for hit in hits: - print(hit) uuid = hit["uuid"] - url = hit["url"] - - # Parse S3 URL to extract bucket and key - parsed_url = urlparse(url) - bucket = parsed_url.netloc - key = parsed_url.path[1:] - print(bucket, key) - - presigned_url = tigris.generate_presigned_url( - "get_object", - Params={"Bucket": bucket, "Key": key}, - ExpiresIn=3600, - ) - + image = hit["image"] + # Convert image to base64 data URL to avoid concurrent request issues + b64 = pil_to_base64(image) results.append( { "id": uuid, - "url": presigned_url, + "url": f"data:image/png;base64,{b64}", } ) @@ -182,3 +193,106 @@ def api_search(): query=query, results=results, ) + + +def generate_with_reve(image: PIL.Image.Image, prompt: str) -> PIL.Image.Image: + """Generate image using Reve edit.""" + return reve.edit(image, prompt) + + +def generate_with_huggingface(image: PIL.Image.Image, prompt: str) -> PIL.Image.Image: + """Generate image using HuggingFace image-to-image.""" + return image_to_image( + image, + prompt, + model_id="timbrooks/instruct-pix2pix", + ) + + +def generate_with_replicate(image: PIL.Image.Image, prompt: str) -> PIL.Image.Image: + """Generate image using Replicate FLUX.""" + return replicate.run( + "black-forest-labs/flux-1.1-pro", + { + "prompt": prompt, + "image": image, + "prompt_upsampling": True, + }, + ) + + +@app.route("/api/generate-image", methods=["POST"]) +def api_generate_image(): + if not htmx: + return redirect("/") + + image_id = request.form.get("image_id", "") + prompt = request.form.get("prompt", "") + model = request.form.get("model", "reve") + + if not image_id or not prompt: + return render_template( + "partials/api/generate_error.html", + error="Please select an image and enter a prompt.", + ) + + # Get the input image from screenshots table + input_image_result = ( + screenshots.where(screenshots.uuid == image_id) + .select(screenshots.image) + .limit(1) + .collect() + ) + + if not input_image_result: + return render_template( + "partials/api/generate_error.html", + error="Could not find the selected image.", + ) + + input_image = input_image_result[0]["image"] + + # Generate image using selected model + try: + if model == "reve": + # Use the computed column approach for Reve + generated_images.insert([{"input_image": input_image, "prompt": prompt}]) + result = ( + generated_images.where(generated_images.prompt == prompt) + .select( + input_image=generated_images.input_image, + gen_image=generated_images.gen_image, + prompt=generated_images.prompt, + ) + .order_by(generated_images.uuid, asc=False) + .limit(1) + .collect() + ) + if not result: + raise Exception("Failed to generate image") + generated_image = result[0]["gen_image"] + elif model == "huggingface": + generated_image = generate_with_huggingface(input_image, prompt) + elif model == "fal": + generated_image = generate_with_replicate(input_image, prompt) + else: + return render_template( + "partials/api/generate_error.html", + error=f"Unknown model: {model}", + ) + except Exception as e: + return render_template( + "partials/api/generate_error.html", + error=f"Generation failed: {str(e)}", + ) + + return render_template( + "partials/api/generate_result.html", + generated_image=generated_image, + input_image=input_image, + prompt=prompt, + ) + + +if __name__ == "__main__": + app.run(debug=True, host="0.0.0.0", port=5000) diff --git a/pyproject.toml b/pyproject.toml index 630aec0..65e23e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ dependencies = [ "ipykernel>=7.1.0", "ipython>=9.7.0", "openai>=2.8.1", - "pixeltable==0.5.0", + "pixeltable>=0.5.1", "pixeltable-yolox>=0.4.2", "pydantic>=2.12.4", "s3fs>=2025.10.0", diff --git a/static/css/main.css b/static/css/main.css index 6dd4819..b672b42 100644 --- a/static/css/main.css +++ b/static/css/main.css @@ -307,3 +307,121 @@ input.form-control.search::placeholder { h1 { margin-bottom: 0.5rem; } + +/* Form controls for generation panel */ +textarea.form-control, +select.form-control { + background-color: #282828; + border: 2px solid #665c54; + color: #f9f5d7; + border-radius: 0.5rem; + padding: 0.75rem 1rem; + font-family: "Iosevka Aile Iaso", sans-serif; + font-size: 1rem; + transition: all 0.3s ease; +} + +select.form-control { + cursor: pointer; + appearance: none; + background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='12' height='12' viewBox='0 0 12 12'%3E%3Cpath fill='%23bdae93' d='M6 8L1 3h10z'/%3E%3C/svg%3E"); + background-repeat: no-repeat; + background-position: right 1rem center; + padding-right: 2.5rem; +} + +textarea.form-control:focus { + outline: none; + border-color: #b16286; + box-shadow: 0 0 0 0.2rem rgba(177, 98, 134, 0.25); + background-color: #3c3836; +} + +textarea.form-control::placeholder { + color: #bdae93; + opacity: 0.7; +} + +/* Button styles */ +.btn { + background-color: #b16286; + color: #f9f5d7; + border: none; + border-radius: 0.5rem; + padding: 0.5rem 1.5rem; + font-family: "Iosevka Aile Iaso", sans-serif; + font-size: 1rem; + cursor: pointer; + transition: all 0.2s ease; +} + +.btn:hover { + background-color: #d3869b; + transform: translateY(-1px); +} + +.btn:active { + transform: translateY(0); +} + +.btn:disabled { + background-color: #665c54; + cursor: not-allowed; + transform: none; +} + +/* Generation panel styles */ +#generation-panel { + background-color: #282828; + border: 1px solid #3c3836; + border-radius: 0.5rem; + padding: 1.5rem; + margin-top: 2rem; +} + +#generation-panel h3 { + color: #f9f5d7; + margin-bottom: 1rem; +} + +/* Light mode styles */ +@media (prefers-color-scheme: light) { + textarea.form-control, + select.form-control { + background-color: #fbf1c7; + border-color: #928374; + color: #1d2021; + } + + select.form-control { + background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='12' height='12' viewBox='0 0 12 12'%3E%3Cpath fill='%23665c54' d='M6 8L1 3h10z'/%3E%3C/svg%3E"); + } + + textarea.form-control:focus, + select.form-control:focus { + border-color: #b16286; + background-color: #f9f5d7; + } + + textarea.form-control::placeholder { + color: #665c54; + } + + .btn { + background-color: #b16286; + color: #f9f5d7; + } + + .btn:hover { + background-color: #d3869b; + } + + #generation-panel { + background-color: #fbf1c7; + border-color: #ebdbb2; + } + + #generation-panel h3 { + color: #1d2021; + } +} diff --git a/templates/partials/api/generate_error.html b/templates/partials/api/generate_error.html new file mode 100644 index 0000000..43bfe9a --- /dev/null +++ b/templates/partials/api/generate_error.html @@ -0,0 +1,11 @@ +
+

Error

+

{{ error }}

+ +
diff --git a/templates/partials/api/generate_result.html b/templates/partials/api/generate_result.html new file mode 100644 index 0000000..9715513 --- /dev/null +++ b/templates/partials/api/generate_result.html @@ -0,0 +1,31 @@ +
+

Generated Image

+

Prompt: "{{ prompt }}"

+ +
+
+

Original

+ Original image +
+
+

Generated

+ Generated image +
+
+ + +
diff --git a/templates/partials/api/search.html b/templates/partials/api/search.html index 96c1a0a..669f869 100644 --- a/templates/partials/api/search.html +++ b/templates/partials/api/search.html @@ -1,21 +1,77 @@
{% for result in results %} -
-
- - - Search result -
+
+
{% endfor %} -
\ No newline at end of file + + +
+

Select an image above to generate a variation

+ +
+ + +
+ + +
+ +
+ +
+ +
+ +
+
+
+
+
+
+
+
+

Generating image... this may take a few seconds

+
+
+
+ +