from __future__ import absolute_import, division, print_function, unicode_literals
 
from PIL import Image
import numpy as np
import os
 
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
 
 
class ElaboraImmagini:
 
    # acquisiamo i dati per le immagini, con i seguenti argomenti
    # orig: lista che contiene il percorso alle immagini originali da tagliare
    # dest: cartella di destinazione nella quale salvare le immagini
    # n: numero di immagini in colonna
    # w: larghezza delle singole immagini
    def __init__(s, orig = [], dest = None, n = 4, w = 100):
 
        # prendiamo il percorso originale della cartella nella quale ci troviamo
        # ovvero dove si trova lo script di Python
        s.__orig_dir = os.path.dirname(os.path.realpath(__file__))
        # impostiamo la variabile privata con la lista delle immagini da elaborare
        s.__orig = orig
        # creiamo la cartella di destinazione
        s.__dest = s.__orig_dir + "\\" + dest
 
        # se la cartella non esiste...
        if not os.path.exists(s.__dest):
            # ... allora la creiamo
            os.makedirs(s.__dest)
 
        # salviamo n e w in variabili private
        s.__n = n
        s.__w = w
 
        # chiamiamo l'elaborazione delle immagini
        s.elabora()
 
    # creiamo il metodo per elaborare le immagini
    def elabora(s):
 
        # per ogni file nella lista di origine
        for f in s.__orig:
            # creiamo il nome del file con il percorso assoluto
            nome_file = s.__orig_dir + "\\" + f
            # per ogni riga...
            for i in range(s.__n):
                # ... e per ogni colonna
                for j in range(s.__n):
                    # ritagliamo le immagini con l'apposito metodo
                    # al metodo verranno passate le coordinate come 4-tupla
                    # calcolata da righe e colonne, oltre al nome del file da creare
                    # il nome lo creiamo come: nome del file senza estensione + conteggio
                    # quindi se stiamo tagliando banane.png avremo banane_x.png
                    # in questo modo potremo riconoscere le immagini dai nomi dei file
                    s.ritaglia( (j*s.__w,i*s.__w,(j+1)*s.__w,(i+1)*s.__w) , \
                                nome_file , \
                                os.path.basename(nome_file)[:-4] + "_" + str( i * s.__n + j) + ".png")
 
    # ritaglia con i seguenti argomenti:
    # coor: posizione left, top, right, bottom, come se fossero x e y insomma
    # nome_file: nome del file corrente
    # nome_dest: nome del file da creare
    def ritaglia(s, coor, nome_file, nome_dest):
        # apriamo il file
        img = Image.open(nome_file)
        # ritagliamolo
        cropped = img.crop(coor)
        # creiamo la posizione di salvataggio
        file_salvato = s.__dest + "\\" + nome_dest
        # salviamo
        cropped.save(file_salvato)
 
 
 
ElaboraImmagini(["mele.png","pere.png","banane.png"],"TRAINING")
ElaboraImmagini(["frutta.png"],"TEST")
 
class CreaDBImmagini:
 
    # creiamo delle "costanti" giusto per codificare il tipo di output
    # che vogliamo ottenere passando il tipo al costruttore
    TEST = 0
    TRAINING = 1
 
    # qui codifichiamo la frutta
    MELE = 0
    PERE = 1
    BANANE = 2
 
    # passiamo al costruttore due argomenti:
    # path: il percorso da cui leggere le immagini
    # tipo: il tipo di elaborazione che vogliamo fare (TEST o TRAINING)
    def __init__(s, path, tipo = 0):
        s.__path = path
        # lista che conterrà le immagini in byte
        s.__immagini = []
        # lista contenente le descrizioni
        s.__descrizioni = []
        s.__tipo = tipo
 
        s.elabora()
 
    # elaboriamo la cartella con le immagini
    def elabora(s):
        # per ciascuna immagine nella cartella
        for f in os.listdir(s.__path):
            # creiamo il percorso del file
            percorso_file = s.__path + "\\" + f
            # se il file esiste (per come è costruito il metodo non c'è
            # motivo per cui il file non debba esistere, in realtà
            # verifichiamo che eventualmente non sia una cartella)
            if os.path.isfile(percorso_file):
                # leggiamo l'immagine usando la libreria Pillow
                img = Image.open(percorso_file)
                # trasformiamo i byte dell'immagine in un array di NumPy
                aimg = np.asarray(img)
                # aggiungiamo tale array alla lista delle immagini
                s.__immagini.append(aimg)
 
                # se elaboriamo le immagini per il training...
                if s.__tipo == s.TRAINING:
                    # preniamo il tipo del frutto dal nome
                    # ricordiamoci che i nomi sono del tipo mela_1.png
                    # quindi spezziamo il nome su _ e prendiamo la prima
                    # parte
                    frutto = f.split("_")[0]
                    # codifichiamo il nome con il numero
                    if frutto == "mele": desc = s.MELE
                    if frutto == "pere": desc = s.PERE
                    if frutto == "banane": desc = s.BANANE
                    # aggiungiamo tale numero alla lista delle descrizioni
                    s.__descrizioni.append(desc)
 
                # in caso di test ci accontentiamo dei meri nomi dei file
                if s.__tipo == s.TEST:
                    s.__descrizioni.append(f)
 
    # metodo per restituire la lista di immagini e descrizioni
    # codificate come array di NumPy
    def get(s):
        return (np.array(s.__immagini), np.array(s.__descrizioni))
 
# qui alleno il modello (ora disattivato con False)
if True:
    train_img, train_desc = CreaDBImmagini("TRAINING",CreaDBImmagini.TRAINING).get()
 
    modello = keras.Sequential([
            keras.layers.Flatten(input_shape=(100,100,3)),
            keras.layers.Dense(128, activation="relu"),
            keras.layers.Dense(3, activation="softmax")
        ])
 
    modello.compile(optimizer="adam",
            loss="sparse_categorical_crossentropy",
            metrics=["accuracy"])
 
    modello.fit(train_img, train_desc, epochs=20)
 
    modello.save("modello_frutta")
 
# qui uso il modello
if True:
    modello = keras.models.load_model("modello_frutta")
 
    test_img, test_desc = CreaDBImmagini("TEST",CreaDBImmagini.TEST).get()
 
    print("elenco di previsioni grezzo")
    previsione = modello.predict(test_img)
 
    print("elenco previsioni gestito")
    for i, img in enumerate(test_img):
        previsione = modello.predict(np.array([img]))
        print(previsione, test_desc[i])