#!/usr/bin/env python3
"""Extract a 128-d face_recognition encoding from a base64 image payload."""

import base64
import io
import json
import sys
from typing import NoReturn

try:
    import face_recognition
except Exception as exc:
    print(json.dumps({"error": f"Python dependencies missing: {exc}"}))
    sys.exit(1)


def fail(message: str, code: int = 1) -> NoReturn:
    """Print ``{"error": message}`` as JSON on stdout and exit the process.

    Args:
        message: Human-readable error text placed under the ``"error"`` key.
        code: Process exit status passed to ``sys.exit`` (default 1).

    Raises:
        SystemExit: Always; this function never returns. The ``NoReturn``
            annotation tells type checkers so, which means a variable assigned
            in a ``try`` whose ``except`` branch calls ``fail`` is known to be
            bound after the ``try`` statement.
    """
    print(json.dumps({"error": message}))
    sys.exit(code)


def main():
    """Read a JSON request from stdin and print the face encoding as JSON.

    Expected stdin: a JSON object whose ``image_base64`` key holds either raw
    base64 or a ``data:`` URL (the part after the first comma is decoded).

    On success, prints ``{"encoding": [float, ...]}`` to stdout. On any
    failure, calls ``fail``, which prints ``{"error": "..."}`` and exits with
    status 1. The image must contain exactly one detectable face.
    """
    try:
        raw = sys.stdin.read().strip()
    except UnicodeDecodeError:
        # stdin is read as text; bytes that cannot be decoded are not JSON
        # we can accept, so report it instead of crashing with a traceback.
        fail("Invalid JSON input")
    if not raw:
        fail("Empty input")

    try:
        payload = json.loads(raw)
    except (ValueError, RecursionError):
        # JSONDecodeError subclasses ValueError; pathologically nested input
        # raises RecursionError instead.
        fail("Invalid JSON input")

    # Valid JSON that is not an object (list, string, number, null) has no
    # .get(); reject it rather than crash with an AttributeError.
    if not isinstance(payload, dict):
        fail("Invalid JSON input")

    image_base64 = payload.get("image_base64")
    if not image_base64:
        fail("image_base64 is required")
    if not isinstance(image_base64, str):
        # e.g. a number or list: the string handling below would raise TypeError.
        fail("Invalid image data")

    # Accept both data URL and raw base64.
    if image_base64.startswith("data:") and "," in image_base64:
        image_base64 = image_base64.split(",", 1)[1]

    try:
        image_bytes = base64.b64decode(image_base64)
        # Use face_recognition's own loader via BytesIO — it handles PIL conversion
        # and produces a numpy array compatible with the underlying dlib C extension.
        np_image = face_recognition.load_image_file(io.BytesIO(image_bytes))
    except Exception:
        # Deliberately broad: base64 errors and the many exception types an
        # image decoder can raise on corrupt or hostile bytes all map to the
        # same user-facing message.
        fail("Invalid image data")

    if np_image.ndim != 3 or np_image.shape[2] != 3:
        fail("Image could not be converted to RGB format")

    try:
        face_locations = face_recognition.face_locations(np_image, model="hog")
    except Exception as exc:
        fail(f"Face detection failed: {exc}")

    if not face_locations:
        fail("No face detected. Please center your face and try again.")

    if len(face_locations) > 1:
        fail("Multiple faces detected. Please capture only one face.")

    try:
        encodings = face_recognition.face_encodings(
            np_image, known_face_locations=face_locations
        )
    except Exception as exc:
        fail(f"Face encoding failed: {exc}")

    if not encodings:
        fail("Unable to extract face encoding. Please try again.")

    # float() guarantees plain Python floats so json.dumps never sees a numpy scalar.
    encoding = [float(x) for x in encodings[0].tolist()]
    print(json.dumps({"encoding": encoding}))


if __name__ == "__main__":
    # Entry point: main() runs only when this file is executed as a script,
    # not when it is imported.
    main()
