from flask import Flask, request, jsonify
from flask_cors import CORS
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import numpy as np
import os
import uuid

app = Flask(__name__)
CORS(app)  # allow the React frontend to talk to the Flask backend
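# Note: CORS(app) allows all origins, which is fine for local development.
# If you want to lock this down later, flask_cors also accepts an origins
# argument, e.g. CORS(app, origins=["http://localhost:3000"]).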

# ---- config ----
MODEL_PATH = "../models/model.keras"   # your model file
IMG_SIZE = (128, 128)                  # must match your training size
UPLOAD_DIR = "../data/temp"            # temporary save location for incoming files
os.makedirs(UPLOAD_DIR, exist_ok=True)

# These are the 3 classes we want to serve in the app UI
SERVE_LABELS = ["Healthy", "Early Blight", "Late Blight"]

# ---- load model once at startup ----
model = load_model(MODEL_PATH)
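# If the model is only ever used for inference, load_model(MODEL_PATH, compile=False)
# also works and skips restoring the optimizer/training configuration.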


@app.route("/predict", methods=["POST"])
def predict():
    if "file" not in request.files or request.files["file"].filename == "":
        return jsonify({"error": "No file uploaded"}), 400

    f = request.files["file"]

    # save with a random name to avoid collisions on repeated tests
    ext = os.path.splitext(f.filename)[1].lower()
    fname = f"{uuid.uuid4().hex}{ext if ext else '.jpg'}"
    fpath = os.path.join(UPLOAD_DIR, fname)
    f.save(fpath)

    # --- preprocess exactly like training ---
    img = image.load_img(fpath, target_size=IMG_SIZE)
    arr = image.img_to_array(img)
    arr = np.expand_dims(arr, 0) / 255.0  # add batch dimension and rescale to [0, 1]
    os.remove(fpath)  # the temp copy is no longer needed once the image is in memory

    # --- raw prediction (model outputs 10 classes in your case) ---
    preds = model.predict(arr)
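    # preds has shape (1, num_classes); preds[0] holds the per-class scores for this single image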

    # ---- IMPORTANT: keep only the first 3 outputs (Healthy, Early, Late)
    # If your model's class order differs, adjust the slice or indices here.
    trimmed = preds[0][:3].astype(np.float64)
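    # For example, if the three served classes were not the first three outputs,
    # you could pick them out by position instead (indices below are hypothetical):
    #   SERVE_INDICES = [4, 0, 2]  # positions of Healthy, Early Blight, Late Blight
    #   trimmed = preds[0][SERVE_INDICES].astype(np.float64)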

    # renormalize the 3 scores so they sum to 1
    total = np.sum(trimmed)
    if total > 0:
        trimmed /= total

    idx = int(np.argmax(trimmed))
    conf = float(trimmed[idx])

    # console logging for debugging
    print("Raw preds:", preds)
    print("Trimmed (3-class) preds:", trimmed)
    print("Predicted:", SERVE_LABELS[idx], "Confidence:", conf)

    return jsonify({
        "label": SERVE_LABELS[idx],
        "confidence": round(conf * 100, 2)
    })

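# Example request against a locally running server (the file name and the
# response values below are illustrative, not real output):
#   curl -X POST -F "file=@leaf.jpg" http://localhost:5000/predict
#   {"confidence": 97.42, "label": "Early Blight"}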

@app.route("/", methods=["GET"])
def health():
    # simple health check
    return jsonify({"ok": True, "model": MODEL_PATH})


if __name__ == "__main__":
    app.run(port=5000, debug=True)
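    # debug=True is convenient during development. For an actual deployment,
    # a WSGI server is the usual choice, e.g. (assuming this file is saved as app.py):
    #   gunicorn -w 1 -b 0.0.0.0:5000 app:app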