import tensorflow as tf
from tensorflow import keras

def main():
    """Train, evaluate, and demo a small dense classifier on Fashion-MNIST.

    Steps:
      1. Fetch Fashion-MNIST through ``keras.datasets``. The loader returns
         a training set and a test set that are already split.
      2. Scale pixel values from integers in [0, 255] to floats in
         [0.0, 1.0] for every split that is fed to the model.
      3. Hold out the first 5000 training records for validation.
      4. Train a 300-100-10 multilayer perceptron (softmax output over 10
         clothing classes) with SGD for 30 epochs.
      5. Evaluate on the test set.
      6. Print class probabilities and predicted clothing names for the
         first three test images.

    Returns:
        None. All results are printed or reported by Keras.
    """
    # Keras comes with functions for fetching some widely used datasets.
    fashion_mnist = keras.datasets.fashion_mnist
    (X_train_full, y_train_full), (X_test, y_test) = \
        fashion_mnist.load_data()

    print( '\n **Size of MNIST Fashion training set** \n' )
    print( X_train_full.shape )

    print( '\n **Type of elements in MNIST Fashion dataset** \n' )
    print( X_train_full.dtype )

    # Represent pixel brightness as floats in [0.0, 1.0] instead of
    # integers in [0, 255]. The test set must get the same scaling as the
    # training data; otherwise evaluation and prediction run on inputs
    # 255x larger than anything the model was trained on.
    X_train_full = X_train_full / 255.0
    X_test = X_test / 255.0

    # Validate with the first 5000 records; train with all the remaining
    # records of the full training set.
    X_validate, X_train = X_train_full[:5000], X_train_full[5000:]
    y_validate, y_train = y_train_full[:5000], y_train_full[5000:]

    # Names of the 10 kinds of clothing, indexed by predicted class number
    # (the output layer below has 10 units).
    types_of_clothing = [
        'T-shirt/top',
        'Trouser',
        'Pullover',
        'Dress',
        'Coat',
        'Sandal',
        'Shirt',
        'Sneaker',
        'Bag',
        'Ankle boot'
    ]

    model = keras.models.Sequential([
        keras.layers.Flatten( input_shape = [28, 28] ),
        keras.layers.Dense( 300, activation = 'relu' ),
        keras.layers.Dense( 100, activation = 'relu' ),
        keras.layers.Dense( 10, activation = 'softmax' )
    ])

    print( '\n **Model summary** \n' )
    # summary() prints the table itself and returns None; wrapping it in
    # print() would emit a stray "None".
    model.summary()

    # Labels are plain integer class ids, hence the *sparse* loss.
    model.compile(
        loss = 'sparse_categorical_crossentropy',
        optimizer = 'sgd',
        metrics = [ 'accuracy' ]
    )

    print( '\n **Fit the model** \n' )

    model.fit(
        X_train,
        y_train,
        epochs = 30,
        validation_data = (X_validate, y_validate)
    )

    print( '\n **Evaluate model** \n' )
    model.evaluate( X_test, y_test )

    X_new = X_test[:3]
    y_probabilities = model.predict( X_new )

    print( '\n **Prediction probabilities** \n' )
    print( y_probabilities.round(2) )

    # Sequential.predict_classes() was removed in TensorFlow 2.6. For a
    # softmax output, the predicted class is the index of the largest
    # probability. Reuse the probabilities above rather than running a
    # second forward pass.
    y_predictions = y_probabilities.argmax( axis = -1 )

    print( '\n **Predictions** \n' )
    print( y_predictions )

    for label in y_predictions:
        print( types_of_clothing[label] )

# end of main()

# Run the training demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()