Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 157 additions & 43 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,36 @@
import base64
import boto3
import io
import os
import pixeltable as pxt
import PIL.Image
from botocore.client import Config
from datasets import load_dataset
from dotenv import load_dotenv
from flask import Flask, request, render_template, send_from_directory, redirect
from flask import Flask, request, render_template, send_from_directory, redirect, send_file
from flask_htmx import HTMX
from openai import OpenAI
from typing import cast
from urllib.parse import urlparse
from uuid_extensions import uuid7
from pixeltable.functions.huggingface import clip
from pixeltable.functions.huggingface import clip, image_to_image
from pixeltable.functions import reve, replicate


def pil_to_base64(img: "PIL.Image.Image | None") -> str:
    """Encode an image as a base64 PNG string.

    Args:
        img: The image to encode. ``None`` is tolerated (a row may carry a
            missing/unloaded image) and yields an empty string — the original
            annotation claimed images only, but the body handles ``None``.

    Returns:
        The PNG bytes of ``img`` encoded as a base64 ASCII string, or ``""``
        when ``img`` is ``None``.
    """
    if img is None:
        return ""
    buffered = io.BytesIO()
    # PNG is lossless and browser-safe; this string is embedded in data: URLs.
    img.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")

load_dotenv()
oai = OpenAI()

# OpenAI client is optional - only initialize if API key is available
oai = None
if os.environ.get("OPENAI_API_KEY"):
oai = OpenAI()


@pxt.udf
Expand All @@ -24,15 +39,17 @@ def gen_uuid() -> str:

def import_screenshots():
data_files = "s3://xe-zohar-copy/ds/screenshots_sharded/*.parquet"
storage_options={"profile": "tigris-dev"}
storage_options = {
"profile": "tigris-dev",
"endpoint_url": "https://t3.storage.dev",
}

dataset = load_dataset(
"parquet",
split="train",
data_files=data_files,
streaming=False,
storage_options=storage_options,
if_exists="ignore",
)
screenshots = pxt.create_table("screenshots", source=dataset, if_exists="ignore")
screenshots.add_embedding_index(
Expand All @@ -51,30 +68,16 @@ def import_screenshots():
screenshots = cast(pxt.Table, screenshots)


@pxt.query
def get_image(image_id: str) -> PIL.Image.Image:
return (
screenshots.where(screenshots.uuid == image_id)
.select(screenshots.image)
.limit(1)
)


def encode_image(file_path):
    """Read the file at *file_path* and return its bytes base64-encoded (UTF-8 str)."""
    with open(file_path, "rb") as handle:
        raw = handle.read()
    return base64.b64encode(raw).decode("utf-8")


# @pxt.udf
# def image_edit(prompt: str, input: PIL.Image.Image) -> PIL.Image.Image:
# pass


generated_images = pxt.create_table(
"generated_images",
{
"input_image_id": pxt.String,
"input_image": pxt.Image,
"prompt": pxt.String,
},
if_exists="ignore",
Expand All @@ -84,12 +87,9 @@ def encode_image(file_path):
)
generated_images.add_computed_column(uuid=gen_uuid(), if_exists="ignore")
generated_images.add_computed_column(
input_image=get_image(generated_images.input_image_id),
gen_image=reve.edit(generated_images.input_image, generated_images.prompt),
if_exists="ignore",
)
# generated_images.add_computed_column(
# gen_image=image_generations(generated_images.prompt, model="gpt-image-1")
# )


def perform_search(screenshots: pxt.Table, query: str) -> pxt.ResultSet:
Expand All @@ -98,16 +98,16 @@ def perform_search(screenshots: pxt.Table, query: str) -> pxt.ResultSet:
screenshots.order_by(sim, asc=False)
.select(
uuid=screenshots.uuid,
url=screenshots.image.fileurl,
image=screenshots.image,
)
# .limit(6)
.limit(1)
.limit(6)
)
return results.collect()


app = Flask(__name__)
htmx = HTMX(app)
app.jinja_env.filters["b64encode"] = pil_to_base64
tigris = boto3.client(
"s3",
endpoint_url="https://t3.storage.dev",
Expand All @@ -132,6 +132,29 @@ def healthz():
return "OK"


@app.route("/api/image/<uuid>")
def api_image(uuid):
    """Serve a screenshot from the screenshots table as a PNG, looked up by UUID."""
    rows = (
        screenshots.where(screenshots.uuid == uuid)
        .select(screenshots.image)
        .limit(1)
        .collect()
    )
    # No matching row, or a row whose image column is empty -> 404.
    if not rows or rows[0]["image"] is None:
        return "Image not found", 404

    img = rows[0]["image"]
    # Render the PIL image into an in-memory PNG buffer and stream it back.
    payload = io.BytesIO()
    img.save(payload, format="PNG")
    payload.seek(0)
    return send_file(payload, mimetype="image/png")


@app.route("/")
def index():
if htmx:
Expand All @@ -154,26 +177,14 @@ def api_search():
results = []

for hit in hits:
print(hit)
uuid = hit["uuid"]
url = hit["url"]

# Parse S3 URL to extract bucket and key
parsed_url = urlparse(url)
bucket = parsed_url.netloc
key = parsed_url.path[1:]
print(bucket, key)

presigned_url = tigris.generate_presigned_url(
"get_object",
Params={"Bucket": bucket, "Key": key},
ExpiresIn=3600,
)

image = hit["image"]
# Convert image to base64 data URL to avoid concurrent request issues
b64 = pil_to_base64(image)
results.append(
{
"id": uuid,
"url": presigned_url,
"url": f"data:image/png;base64,{b64}",
}
)

Expand All @@ -182,3 +193,106 @@ def api_search():
query=query,
results=results,
)


def generate_with_reve(image: PIL.Image.Image, prompt: str) -> PIL.Image.Image:
    """Edit *image* according to *prompt* using the Reve editing backend."""
    edited = reve.edit(image, prompt)
    return edited


def generate_with_huggingface(image: PIL.Image.Image, prompt: str) -> PIL.Image.Image:
    """Apply *prompt* to *image* via a HuggingFace image-to-image model."""
    # instruct-pix2pix follows natural-language editing instructions.
    result = image_to_image(image, prompt, model_id="timbrooks/instruct-pix2pix")
    return result


def generate_with_replicate(image: PIL.Image.Image, prompt: str) -> PIL.Image.Image:
    """Run Replicate's FLUX 1.1 Pro model on *image* guided by *prompt*."""
    model_input = {
        "prompt": prompt,
        "image": image,
        "prompt_upsampling": True,
    }
    return replicate.run("black-forest-labs/flux-1.1-pro", model_input)


@app.route("/api/generate-image", methods=["POST"])
def api_generate_image():
    """HTMX endpoint: edit a selected screenshot with a prompt via one of
    three backends ("reve", "huggingface", or "fal" -> Replicate) and render
    the result partial.

    Non-HTMX requests are bounced to the index page. Validation and backend
    failures render the generate_error partial instead of raising.
    """
    if not htmx:
        return redirect("/")

    image_id = request.form.get("image_id", "")
    prompt = request.form.get("prompt", "")
    # Default backend when the form omits a model choice.
    model = request.form.get("model", "reve")

    if not image_id or not prompt:
        return render_template(
            "partials/api/generate_error.html",
            error="Please select an image and enter a prompt.",
        )

    # Get the input image from screenshots table
    input_image_result = (
        screenshots.where(screenshots.uuid == image_id)
        .select(screenshots.image)
        .limit(1)
        .collect()
    )

    if not input_image_result:
        return render_template(
            "partials/api/generate_error.html",
            error="Could not find the selected image.",
        )

    input_image = input_image_result[0]["image"]

    # Generate image using selected model
    try:
        if model == "reve":
            # Use the computed column approach for Reve
            # NOTE(review): the insert triggers the table's gen_image computed
            # column; the row is then re-read by matching on prompt and taking
            # the newest uuid. Looks race-prone if two requests share a prompt
            # concurrently — the lookup could fetch the other request's row.
            # TODO confirm whether insert() can return the new row directly.
            generated_images.insert([{"input_image": input_image, "prompt": prompt}])
            result = (
                generated_images.where(generated_images.prompt == prompt)
                .select(
                    input_image=generated_images.input_image,
                    gen_image=generated_images.gen_image,
                    prompt=generated_images.prompt,
                )
                # uuid is a uuid7 (time-ordered), so descending order yields
                # the most recently inserted row first — presumably ours.
                .order_by(generated_images.uuid, asc=False)
                .limit(1)
                .collect()
            )
            if not result:
                raise Exception("Failed to generate image")
            generated_image = result[0]["gen_image"]
        elif model == "huggingface":
            generated_image = generate_with_huggingface(input_image, prompt)
        elif model == "fal":
            # NOTE(review): form value "fal" is served by the Replicate helper,
            # not fal.ai — presumably a leftover label; verify against template.
            generated_image = generate_with_replicate(input_image, prompt)
        else:
            return render_template(
                "partials/api/generate_error.html",
                error=f"Unknown model: {model}",
            )
    except Exception as e:
        # Broad catch is deliberate: any backend failure becomes an inline
        # error partial rather than a 500 page.
        return render_template(
            "partials/api/generate_error.html",
            error=f"Generation failed: {str(e)}",
        )

    return render_template(
        "partials/api/generate_result.html",
        generated_image=generated_image,
        input_image=input_image,
        prompt=prompt,
    )


if __name__ == "__main__":
    # Dev entry point: Flask's built-in server, reachable from other hosts
    # (0.0.0.0) on port 5000, with debug mode / auto-reload enabled.
    app.run(debug=True, host="0.0.0.0", port=5000)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ dependencies = [
"ipykernel>=7.1.0",
"ipython>=9.7.0",
"openai>=2.8.1",
"pixeltable==0.5.0",
"pixeltable>=0.5.1",
"pixeltable-yolox>=0.4.2",
"pydantic>=2.12.4",
"s3fs>=2025.10.0",
Expand Down
Loading
Loading