Skip to content

Commit d61e716

Browse files
committed
feat: enhance gaze prediction with new model and NaN handling
1 parent afb61f5 commit d61e716

File tree

2 files changed

+250
-408
lines changed

2 files changed

+250
-408
lines changed

app/routes/session.py

Lines changed: 57 additions & 202 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@
44
import time
55
import json
66
import csv
7+
import math
8+
import numpy as np
79

810
from pathlib import Path
911
import os
1012
import pandas as pd
1113
import traceback
1214
import re
1315
import requests
14-
from flask import Flask, request, Response, send_file
16+
from flask import Flask, request, Response, send_file, jsonify
1517

1618
# Local imports from app
1719
from app.services.storage import save_file_locally
@@ -29,126 +31,32 @@
2931
app = Flask(__name__)
3032

3133

32-
# def allowed_file(filename):
33-
# return '.' in filename and \
34-
# filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
34+
# Helper function to convert NaN values to None for JSON serialization
35+
def convert_nan_to_none(obj):
    """
    Recursively converts NaN and Inf values to None for proper JSON serialization.

    Handles dicts, lists, tuples, plain floats, numpy scalars, and numpy
    arrays; any other object is returned unchanged. Tuples and arrays are
    returned as lists, matching how ``json.dumps`` serializes them anyway.

    Args:
        obj: Python object (dict, list, tuple, float, numpy scalar/array, etc.)

    Returns:
        The object with NaN/Inf values converted to None.
    """
    if isinstance(obj, dict):
        return {k: convert_nan_to_none(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        # Previously tuples leaked NaN/Inf through untouched; recurse into
        # them too so json.dumps never emits invalid NaN/Infinity tokens.
        return [convert_nan_to_none(item) for item in obj]
    if isinstance(obj, np.ndarray):
        # Arrays are not JSON-serializable at all; unbox to nested lists first.
        return convert_nan_to_none(obj.tolist())
    if isinstance(obj, float):
        # Also catches np.float64, which is a subclass of Python float.
        return None if math.isnan(obj) or math.isinf(obj) else obj
    if isinstance(obj, np.floating):
        return None if np.isnan(obj) or np.isinf(obj) else float(obj)
    if isinstance(obj, np.integer):
        # Integers can never be NaN/Inf; just unbox to a JSON-safe int.
        return int(obj)
    return obj
3558

3659

37-
# def create_session():
38-
# # # Get files from request
39-
# if 'webcamfile' not in request.files or 'screenfile' not in request.files:
40-
# return Response('Error: Files not found on the request', status=400, mimetype='application/json')
41-
42-
# webcam_file = request.files['webcamfile']
43-
# screen_file = request.files['screenfile']
44-
# title = request.form['title']
45-
# description = request.form['description']
46-
# website_url = request.form['website_url']
47-
# user_id = request.form['user_id']
48-
# calib_points = json.loads(request.form['calib_points'])
49-
# iris_points = json.loads(request.form['iris_points'])
50-
# timestamp = time.time()
51-
# session_id = re.sub(r"\s+", "", f'{timestamp}{title}')
52-
53-
# # Check if extension is valid
54-
# if webcam_file and allowed_file(webcam_file.filename) and screen_file and allowed_file(screen_file.filename):
55-
# webcam_url = save_file_locally(webcam_file, f'/{session_id}')
56-
# screen_url = save_file_locally(screen_file, f'/{session_id}')
57-
# else:
58-
# return Response('Error: Files do not follow the extension guidelines', status=400, mimetype='application/json')
59-
60-
# # Save session on database
61-
# session = Session(
62-
# id=session_id,
63-
# title=title,
64-
# description=description,
65-
# user_id=user_id,
66-
# created_date=timestamp,
67-
# website_url=website_url,
68-
# screen_record_url=screen_url,
69-
# webcam_record_url=webcam_url,
70-
# heatmap_url='',
71-
# calib_points=calib_points,
72-
# iris_points=iris_points
73-
# )
74-
75-
# db.create_document(COLLECTION_NAME, session_id, session.to_dict())
76-
77-
# # Generate csv dataset of calibration points
78-
# os.makedirs(
79-
# f'{Path().absolute()}/public/training/{session_id}/', exist_ok=True)
80-
# csv_file = f'{Path().absolute()}/public/training/{session_id}/train_data.csv'
81-
# csv_columns = ['timestamp', 'left_iris_x', 'left_iris_y',
82-
# 'right_iris_x', 'right_iris_y', 'mouse_x', 'mouse_y']
83-
# try:
84-
# with open(csv_file, 'w') as csvfile:
85-
# writer = csv.DictWriter(csvfile, fieldnames=csv_columns)
86-
# writer.writeheader()
87-
# for data in calib_points:
88-
# writer.writerow(data)
89-
# except IOError:
90-
# print("I/O error")
91-
92-
# # Generate csv of iris points of session
93-
# os.makedirs(
94-
# f'{Path().absolute()}/public/sessions/{session_id}/', exist_ok=True)
95-
# csv_file = f'{Path().absolute()}/public/sessions/{session_id}/session_data.csv'
96-
# csv_columns = ['timestamp', 'left_iris_x', 'left_iris_y',
97-
# 'right_iris_x', 'right_iris_y']
98-
# try:
99-
# with open(csv_file, 'w') as csvfile:
100-
# writer = csv.DictWriter(csvfile, fieldnames=csv_columns)
101-
# writer.writeheader()
102-
# for data in iris_points:
103-
# writer.writerow(data)
104-
# except IOError:
105-
# print("I/O error")
106-
107-
# return Response('Session Created!', status=201, mimetype='application/json')
108-
109-
110-
# def get_user_sessions():
111-
# user_id = request.args.__getitem__('user_id')
112-
# field = u'user_id'
113-
# op = u'=='
114-
# docs = db.get_documents(COLLECTION_NAME, field, op, user_id)
115-
# sessions = []
116-
# for doc in docs:
117-
# sessions.append(
118-
# doc.to_dict()
119-
# )
120-
# return Response(json.dumps(sessions), status=200, mimetype='application/json')
121-
122-
123-
# def get_session_by_id():
124-
# session_id = request.args.__getitem__('id')
125-
# doc = db.get_document(COLLECTION_NAME, doc_id=session_id)
126-
127-
# if doc.exists:
128-
# session = doc.to_dict()
129-
# return Response(json.dumps(session), status=200, mimetype='application/json')
130-
# else:
131-
# return Response('Session does not exist', status=404, mimetype='application/json')
132-
133-
134-
# def delete_session_by_id():
135-
# session_id = request.args.__getitem__('id')
136-
# db.delete_document(COLLECTION_NAME, session_id)
137-
# return Response(f'Session deleted with id {session_id}', status=200, mimetype='application/json')
138-
139-
140-
# def update_session_by_id():
141-
# id = request.form['id']
142-
# title = request.form['title']
143-
# description = request.form['description']
144-
145-
# data = {
146-
# u'title': title,
147-
# u'description': description,
148-
# }
149-
150-
# db.update_document(COLLECTION_NAME, id, data)
151-
# return Response(f'Session updated with id {id}', status=200, mimetype='application/json')
15260

15361

15462
def calib_results():
@@ -235,107 +143,54 @@ def calib_results():
235143
except Exception as e:
236144
print("Erro ao enviar para RuxaiLab:", e)
237145

146+
# Convert NaN values to None before returning JSON
147+
data = convert_nan_to_none(data)
238148
return Response(json.dumps(data), status=200, mimetype='application/json')
239149

240150
def batch_predict():
    """
    Flask route handler: run batch gaze prediction over posted iris samples.

    Expects a JSON body with:
        iris_tracking_data: list of dicts with keys ``left_iris_x``,
            ``left_iris_y``, ``right_iris_x``, ``right_iris_y`` (and
            optionally other fields such as a timestamp, which are ignored).
        screen_width / screen_height: optional screen dimensions forwarded
            to the predictor.
        model_X / model_Y: optional model names (default "Linear Regression").
            NOTE(review): currently read but not forwarded — the predictor's
            model kwargs are intentionally not passed; confirm whether the
            models should be wired through to predict_new_data_simple.
        calib_id: required id used to locate the calibration CSV.

    Returns:
        200 with the predictor's result as JSON (NaN/Inf scrubbed to null),
        400 on a malformed/incomplete request, 500 on internal errors.
    """
    try:
        # silent=True: a missing or unparsable body yields None instead of
        # raising, so we can answer with a 400 rather than a generic 500.
        data = request.get_json(silent=True)
        if not data:
            return Response("Missing or invalid JSON body", status=400)

        iris_data = data.get("iris_tracking_data")
        if not iris_data:
            # Previously a KeyError here surfaced as a 500 "Erro interno";
            # a missing field is a client error.
            return Response("Missing 'iris_tracking_data' in request", status=400)

        screen_width = data.get("screen_width")
        screen_height = data.get("screen_height")
        model_X = data.get("model_X", "Linear Regression")  # noqa: F841 (see docstring)
        model_Y = data.get("model_Y", "Linear Regression")  # noqa: F841 (see docstring)
        calib_id = data.get("calib_id")

        if not calib_id:
            return Response("Missing calib_id", status=400)

        base_path = Path().absolute() / "app/services/calib_validation/csv/data"
        calib_csv_path = base_path / f"{calib_id}_fixed_train_data.csv"
        predict_csv_path = base_path / "temp_batch_predict.csv"

        # Write the incoming samples to a temporary CSV the predictor reads.
        fieldnames = ["left_iris_x", "left_iris_y", "right_iris_x", "right_iris_y"]
        with open(predict_csv_path, "w", newline="") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            for item in iris_data:
                writer.writerow({key: item[key] for key in fieldnames})

        result = gaze_tracker.predict_new_data_simple(
            calib_csv_path=calib_csv_path,
            predict_csv_path=predict_csv_path,
            iris_data=iris_data,
            screen_width=screen_width,
            screen_height=screen_height,
        )

        # Scrub NaN/Inf before serialization: json would otherwise emit the
        # invalid tokens NaN/Infinity, which browsers reject.
        return jsonify(convert_nan_to_none(result))

    except Exception as e:
        print("Erro batch_predict:", e)
        traceback.print_exc()
        return Response("Erro interno", status=500)

0 commit comments

Comments
 (0)