Commit afb61f5

Merge pull request #41 from jvJUCA/eye-tracking
[GSOC-25] Update and Simplification of Gaze Tracking Prediction Pipeline
2 parents 50ffbb7 + 24e2de9 commit afb61f5

File tree

9 files changed: +284 / -41 lines changed


.dockerignore

Lines changed: 34 additions & 4 deletions
@@ -1,7 +1,37 @@
-Dockerfile
-README.md
+# Files and folders that don't need to go into the container
+
+# Git
+.git
+.gitignore
+
+# Python cache
+__pycache__/
 *.pyc
 *.pyo
 *.pyd
-__pycache__
-.pytest_cache
+*.pdb
+.pytest_cache/
+*.pytest_cache
+
+# Virtualenvs
+env/
+venv/
+.venv/
+
+# Build / distributions
+build/
+dist/
+*.egg-info/
+*.egg
+
+# Local logs and DBs
+*.log
+*.sqlite3
+*.db
+
+# Node
+node_modules/
+
+# Others
+.DS_Store
+*.swp

Dockerfile

Lines changed: 17 additions & 6 deletions
@@ -1,12 +1,23 @@
 FROM python:3.11-slim

-ENV PYTHONUNBUFFERED True
-ENV APP_HOME /app
-ENV PORT 5000
+ENV PYTHONDONTWRITEBYTECODE=1 \
+    PYTHONUNBUFFERED=1 \
+    APP_HOME=/app \
+    PORT=8080

 WORKDIR $APP_HOME
-COPY . ./

-RUN pip install --no-cache-dir -r requirements.txt
+COPY requirements.txt .

-CMD exec gunicorn --bind :$PORT --workers 1 --threads 8 --timeout 0 wsgi:app
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    build-essential \
+    && pip install --no-cache-dir -r requirements.txt \
+    && apt-get purge -y --auto-remove build-essential \
+    && rm -rf /var/lib/apt/lists/*
+
+COPY . .
+
+EXPOSE 8080
+
+# Using a JSON array in CMD (safer)
+CMD ["gunicorn", "--bind", ":8080", "--workers", "1", "--threads", "8", "--timeout", "0", "wsgi:app"]

app/main.py

Lines changed: 7 additions & 3 deletions
@@ -70,6 +70,10 @@ def calib_validation():
     """
     if request.method == "POST":
         return session_route.calib_results()
-    return Response(
-        "Invalid request method for route", status=405, mimetype="application/json"
-    )
+    return Response('Invalid request method for route', status=405, mimetype='application/json')
+
+@app.route('/api/session/batch_predict', methods=['POST'])
+def batch_predict():
+    if request.method == 'POST':
+        return session_route.batch_predict()
+    return Response('Invalid request method for route', status=405, mimetype='application/json')

app/requirements.txt

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+blinker==1.9.0
+click==8.1.8
+Flask==3.1.0
+flask-cors==5.0.1
+itsdangerous==2.2.0
+Jinja2==3.1.6
+joblib==1.4.2
+MarkupSafe==3.0.2
+numpy==2.2.4
+pandas==2.2.3
+python-dateutil==2.9.0.post0
+pytz==2025.2
+scikit-learn==1.6.1
+scipy==1.15.2
+six==1.17.0
+threadpoolctl==3.6.0
+tzdata==2025.2
+Werkzeug==3.1.3
+gunicorn==23.0.0
+requests==2.31.0

app/routes/session.py

Lines changed: 112 additions & 25 deletions
@@ -6,6 +6,11 @@
 import csv

 from pathlib import Path
+import os
+import pandas as pd
+import traceback
+import re
+import requests
 from flask import Flask, request, Response, send_file

 # Local imports from app
@@ -147,26 +152,15 @@
 

 def calib_results():
-    """
-    Generate calibration results.
-
-    This function generates calibration results based on the provided form data.
-    It saves the calibration points to a CSV file. Then, it uses the gaze_tracker module to predict the calibration results.
-
-    Returns:
-        Response: A JSON response containing the calibration results.
-
-    Raises:
-        IOError: If there is an error while writing to the CSV files.
-    """
-    # Get form data from request
-    file_name = json.loads(request.form["file_name"])
-    fixed_points = json.loads(request.form["fixed_circle_iris_points"])
-    calib_points = json.loads(request.form["calib_circle_iris_points"])
-    screen_height = json.loads(request.form["screen_height"])
-    screen_width = json.loads(request.form["screen_width"])
-    k = json.loads(request.form["k"])
-    model = json.loads(request.form["model"])
+    from_ruxailab = json.loads(request.form['from_ruxailab'])
+    file_name = json.loads(request.form['file_name'])
+    fixed_points = json.loads(request.form['fixed_circle_iris_points'])
+    calib_points = json.loads(request.form['calib_circle_iris_points'])
+    screen_height = json.loads(request.form['screen_height'])
+    screen_width = json.loads(request.form['screen_width'])
+    model_X = json.loads(request.form.get('model', '"Linear Regression"'))
+    model_Y = json.loads(request.form.get('model', '"Linear Regression"'))
+    k = json.loads(request.form['k'])

     # Generate csv dataset of calibration points
     os.makedirs(
@@ -219,14 +213,107 @@ def calib_results():
     except IOError:
         print("I/O error")

-    # data = gaze_tracker.train_to_validate_calib(calib_csv_file, predict_csv_file)
+    # Run prediction
+    data = gaze_tracker.predict(calib_csv_file, k, model_X, model_Y)
+
+    if from_ruxailab:
+        try:
+            payload = {
+                "session_id": file_name,
+                "model": data,
+                "screen_height": screen_height,
+                "screen_width": screen_width,
+                "k": k
+            }
+
+            RUXAILAB_WEBHOOK_URL = "https://receivecalibration-ffptzpxikq-uc.a.run.app"

-    # Predict calibration results
-    data = gaze_tracker.predict(calib_csv_file, k, model_X=model, model_Y=model)
+            print("file_name:", file_name)

-    # Return calibration results
-    return Response(json.dumps(data), status=200, mimetype="application/json")
+            resp = requests.post(RUXAILAB_WEBHOOK_URL, json=payload)
+            print("Sent to RuxaiLab:", resp.status_code, resp.text)
+        except Exception as e:
+            print("Error sending to RuxaiLab:", e)

+    return Response(json.dumps(data), status=200, mimetype='application/json')
+
+def batch_predict():
+    try:
+        data = request.get_json()
+        iris_data = data['iris_tracking_data']
+        k = data.get('k', 3)
+        screen_height = data.get('screen_height')
+        screen_width = data.get('screen_width')
+        model_X = data.get('model_X', 'Linear Regression')
+        model_Y = data.get('model_Y', 'Linear Regression')
+        calib_id = data.get('calib_id')
+        if not calib_id:
+            return Response("Missing 'calib_id' in request", status=400)
+
+        base_path = Path().absolute() / 'app/services/calib_validation/csv/data'
+        calib_csv_path = base_path / f"{calib_id}_fixed_train_data.csv"
+        predict_csv_path = base_path / 'temp_batch_predict.csv'
+
+        print(f"Calib CSV Path: {calib_csv_path}")
+        print(f"Predict CSV Path: {predict_csv_path}")
+        print(f"Iris data sample (up to 3): {iris_data[:3]}")
+
+        # Generate a temporary CSV with the iris data
+        with open(predict_csv_path, 'w', newline='') as csvfile:
+            writer = csv.DictWriter(csvfile, fieldnames=[
+                'left_iris_x', 'left_iris_y', 'right_iris_x', 'right_iris_y'
+            ])
+            writer.writeheader()
+            for item in iris_data:
+                writer.writerow({
+                    'left_iris_x': item['left_iris_x'],
+                    'left_iris_y': item['left_iris_y'],
+                    'right_iris_x': item['right_iris_x'],
+                    'right_iris_y': item['right_iris_y']
+                })
+
+        # Call the prediction function
+        predictions_raw = gaze_tracker.predict_new_data(
+            calib_csv_path,
+            predict_csv_path,
+            model_X,
+            model_Y,
+            k
+        )
+
+        # Build a simpler, flatter response
+        result = []
+        if isinstance(predictions_raw, dict):
+            # Walk the returned dict and flatten it into a plain list
+            for true_x, inner_dict in predictions_raw.items():
+                if true_x == "centroids":
+                    continue
+                for true_y, info in inner_dict.items():
+                    pred_x_list = info.get("predicted_x", [])
+                    pred_y_list = info.get("predicted_y", [])
+                    precision = info.get("PrecisionSD")
+                    accuracy = info.get("Accuracy")
+
+                    for i, (px, py) in enumerate(zip(pred_x_list, pred_y_list)):
+                        timestamp = iris_data[i].get("timestamp") if i < len(iris_data) else None
+                        result.append({
+                            "timestamp": timestamp,
+                            "predicted_x": px,
+                            "predicted_y": py,
+                            "precision": precision,
+                            "accuracy": accuracy,
+                            "screen_width": screen_width,
+                            "screen_height": screen_height
+                        })
+        else:
+            print("Unexpected return type from predict:", type(predictions_raw))
+
+        return Response(json.dumps(result), status=200, mimetype='application/json')
+
+    except Exception as e:
+        print("Error in batch_predict:", e)
+        traceback.print_exc()
+        return Response("Internal prediction error", status=500)

 # def session_results():
 #     session_id = request.args.__getitem__('id')
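
For reference, a minimal client-side sketch of the new /api/session/batch_predict contract. It assumes the service is reachable on localhost at the port exposed by the Dockerfile (8080), that a calibration file named <calib_id>_fixed_train_data.csv already exists on the server from a previous calibration run, and that the sample values and timestamps are purely illustrative.

```python
# Minimal sketch of a client call to the new batch_predict endpoint.
# Host/port and all sample values are illustrative; "demo_calib" must match an
# existing <calib_id>_fixed_train_data.csv already saved by a calibration run.
import requests

payload = {
    "calib_id": "demo_calib",
    "k": 3,                                # number of KMeans clusters
    "screen_width": 1920,
    "screen_height": 1080,
    "model_X": "Linear Regression",
    "model_Y": "Linear Regression",
    "iris_tracking_data": [                # one entry per captured frame
        {"timestamp": 0,  "left_iris_x": 0.41, "left_iris_y": 0.52,
         "right_iris_x": 0.58, "right_iris_y": 0.51},
        {"timestamp": 33, "left_iris_x": 0.42, "left_iris_y": 0.53,
         "right_iris_x": 0.59, "right_iris_y": 0.52},
    ],
}

resp = requests.post("http://localhost:8080/api/session/batch_predict", json=payload)
resp.raise_for_status()
# The route returns a flat list of points:
# [{"timestamp": ..., "predicted_x": ..., "predicted_y": ..., "precision": ...,
#   "accuracy": ..., "screen_width": ..., "screen_height": ...}, ...]
print(resp.json()[:2])
```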

app/services/gaze_tracker.py

Lines changed: 84 additions & 0 deletions
@@ -235,6 +235,90 @@ def predict(data, k, model_X, model_Y):
     # Return the data
     return data

+def predict_new_data_simple(calib_csv_path, predict_csv_path, model_X, model_Y, k=3):
+    """
+    Simplified version of predict_new_data.
+    Trains models on the calibration data and predicts coordinates for the new data.
+    Returns the same format as the `predict` function.
+    """
+    # -------------------- SCALERS --------------------
+    sc_x = StandardScaler()
+    sc_y = StandardScaler()
+
+    # -------------------- TRAINING --------------------
+    df_train = pd.read_csv(calib_csv_path).drop(["screen_height", "screen_width"], axis=1)
+
+    X_train_x = df_train[["left_iris_x", "right_iris_x"]].values
+    y_train_x = df_train["point_x"].values
+    X_train_y = df_train[["left_iris_y", "right_iris_y"]].values
+    y_train_y = df_train["point_y"].values
+
+    X_train_x_scaled = sc_x.fit_transform(X_train_x)
+    X_train_y_scaled = sc_y.fit_transform(X_train_y)
+
+    # Models
+    model_fit_x = models[model_X].fit(X_train_x_scaled, y_train_x)
+    model_fit_y = models[model_Y].fit(X_train_y_scaled, y_train_y)
+
+    # -------------------- NEW DATA --------------------
+    df_predict = pd.read_csv(predict_csv_path)
+    X_pred_x = sc_x.transform(df_predict[["left_iris_x", "right_iris_x"]].values)
+    X_pred_y = sc_y.transform(df_predict[["left_iris_y", "right_iris_y"]].values)
+
+    y_pred_x = model_fit_x.predict(X_pred_x)
+    y_pred_y = model_fit_y.predict(X_pred_y)
+
+    # Ensure non-negative values
+    y_pred_x = np.clip(y_pred_x, 0, None)
+    y_pred_y = np.clip(y_pred_y, 0, None)
+
+    # -------------------- KMEANS --------------------
+    data_pred = np.array([y_pred_x, y_pred_y]).T
+    kmeans_model = KMeans(n_clusters=k, n_init="auto", init="k-means++")
+    y_kmeans = kmeans_model.fit_predict(data_pred)
+
+    # -------------------- FORMAT DATA --------------------
+    df_data = pd.DataFrame({
+        "Predicted X": y_pred_x,
+        "Predicted Y": y_pred_y,
+        "True X": df_predict["point_x"] if "point_x" in df_predict else y_pred_x,
+        "True Y": df_predict["point_y"] if "point_y" in df_predict else y_pred_y
+    })
+
+    # Compute metrics
+    precision_x = df_data.groupby(["True X", "True Y"]).apply(func_precision_x)
+    precision_y = df_data.groupby(["True X", "True Y"]).apply(func_presicion_y)
+    precision_xy = (precision_x + precision_y) / 2
+    precision_xy /= np.mean(precision_xy)
+
+    accuracy_x = df_data.groupby(["True X", "True Y"]).apply(func_accuracy_x)
+    accuracy_y = df_data.groupby(["True X", "True Y"]).apply(func_accuracy_y)
+    accuracy_xy = (accuracy_x + accuracy_y) / 2
+    accuracy_xy /= np.mean(accuracy_xy)
+
+    # Final structure
+    data = {}
+    for index, row in df_data.iterrows():
+        outer_key = str(int(row["True X"]))
+        inner_key = str(int(row["True Y"]))
+        if outer_key not in data:
+            data[outer_key] = {}
+        data[outer_key][inner_key] = {
+            "predicted_x": df_data[
+                (df_data["True X"] == row["True X"]) &
+                (df_data["True Y"] == row["True Y"])
+            ]["Predicted X"].tolist(),
+            "predicted_y": df_data[
+                (df_data["True X"] == row["True X"]) &
+                (df_data["True Y"] == row["True Y"])
+            ]["Predicted Y"].tolist(),
+            "PrecisionSD": precision_xy[(row["True X"], row["True Y"])],
+            "Accuracy": accuracy_xy[(row["True X"], row["True Y"])],
+        }
+
+    data["centroids"] = kmeans_model.cluster_centers_.tolist()
+    return data
+

 def train_to_validate_calib(calib_csv_file, predict_csv_file):
     dataset_train_path = calib_csv_file
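
A quick way to exercise the new helper outside the Flask route is to call it on two small CSVs. The sketch below is only an illustration under assumptions inferred from this diff: the import path app.services.gaze_tracker, the "Linear Regression" key in the module's models dict, and the column layout the function reads (iris coordinates plus point_x/point_y targets and screen dimensions in the calibration file); file names and values are made up.

```python
# Minimal sketch exercising predict_new_data_simple outside the API.
# Import path, model key, and column layout are inferred from this diff;
# file names and values are illustrative only.
import pandas as pd
from app.services import gaze_tracker

# Calibration samples: iris coordinates plus the known on-screen target and screen size.
pd.DataFrame({
    "left_iris_x":  [0.40, 0.41, 0.60, 0.61],
    "left_iris_y":  [0.50, 0.51, 0.50, 0.51],
    "right_iris_x": [0.55, 0.56, 0.75, 0.76],
    "right_iris_y": [0.50, 0.51, 0.50, 0.51],
    "point_x":      [100, 100, 900, 900],
    "point_y":      [100, 100, 500, 500],
    "screen_height": [1080] * 4,
    "screen_width":  [1920] * 4,
}).to_csv("calib.csv", index=False)

# New samples to predict; point_x/point_y are optional and only used as grouping keys.
pd.DataFrame({
    "left_iris_x":  [0.405, 0.415, 0.605, 0.615],
    "left_iris_y":  [0.505, 0.515, 0.505, 0.515],
    "right_iris_x": [0.555, 0.565, 0.755, 0.765],
    "right_iris_y": [0.505, 0.515, 0.505, 0.515],
    "point_x":      [100, 100, 900, 900],
    "point_y":      [100, 100, 500, 500],
}).to_csv("new_samples.csv", index=False)

result = gaze_tracker.predict_new_data_simple(
    "calib.csv", "new_samples.csv",
    model_X="Linear Regression", model_Y="Linear Regression", k=2,
)
# Same nested layout as predict(): result["100"]["100"] holds predicted_x/predicted_y
# lists plus PrecisionSD and Accuracy, and result["centroids"] the KMeans centers.
print(result["centroids"])
```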

package-lock.json

Lines changed: 6 additions & 0 deletions
Generated file; diff not rendered.

requirements.txt

Lines changed: 3 additions & 1 deletion
@@ -15,4 +15,6 @@ scipy==1.15.2
 six==1.17.0
 threadpoolctl==3.6.0
 tzdata==2025.2
-Werkzeug==3.1.3
+Werkzeug==3.1.3
+gunicorn==23.0.0
+requests==2.31.0

wsgi.py

Lines changed: 1 addition & 2 deletions
@@ -13,6 +13,5 @@
 import os
 from app.main import app

-
 if __name__ == "__main__":
-    app.run(debug=True, host="0.0.0.0", port=int(os.environ.get("PORT", 5000)))
+    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 8080)))
