
import numpy as np
import os

from sklearn.datasets import fetch_openml
from sklearn.linear_model import SGDClassifier

def main() -> None:
    """Train a binary "is it a 5?" classifier on MNIST and print sample predictions.

    The function fetches the ``mnist_784`` dataset with ``fetch_openml``,
    prints a short description of the returned structure, converts the
    labels to ``np.uint8``, uses the first 60000 samples as the training
    split, fits an ``SGDClassifier`` on the boolean target ``label == 5``,
    and finally prints the true label next to the prediction for the first
    24 training images.

    All results are written to standard output; nothing is returned.
    """
    # Number of leading samples that form the training split; everything
    # after this index is the test split.
    train_size = 60000
    # Number of leading training images whose predictions are printed.
    n = 24

    print('\n **Guten Morgen!** \n')

    mnist = fetch_openml('mnist_784', version=1)

    print('\n Type of MNIST data structure. \n')

    print(type(mnist))

    print('\n Keys: \n')

    # Iterating the returned mapping yields its keys directly.
    for key in mnist:
        print('\t', key)

    X, y = mnist["data"], mnist["target"]

    print('\n Shape of data structure that contains images. \n')
    print('\t', X.shape)

    print('\n Shape of data structure that contains labels. \n')
    print('\t', y.shape)

    print('\n Type of elements in y. \n')
    print(y.dtype)

    # Convert the labels to small unsigned integers so they can be
    # compared with the integer 5 below and formatted with ``:2d``.
    y = y.astype(np.uint8)

    print('\n Type of elements in y after conversion. \n')
    print(y.dtype)

    # Split the data into training and test parts.
    #     X_train and X_test are images
    #     y_train and y_test are labels
    X_train, X_test = X[:train_size], X[train_size:]
    y_train, y_test = y[:train_size], y[train_size:]

    # Create binary targets.
    # The i-th element is True when the i-th image is a 5, and False
    # otherwise.  The test-side target is prepared here but is not used
    # further in this function.
    y_train_5 = (y_train == 5)
    y_test_5 = (y_test == 5)

    # Create an instance of SGDClassifier.
    # max_iter and tol are set explicitly so the run does not depend on
    # the defaults of the installed scikit-learn version; random_state
    # fixes the seed passed to the classifier.
    sgd_clf = SGDClassifier(max_iter=1000, tol=1e-3, random_state=42)

    # Fit the model.
    # The fit method here gets images and labels that are either
    # True (the image is a 5) or False (the image is not a 5).
    sgd_clf.fit(X_train, y_train_5)

    # Predict on the first n images of the training split.
    some_digits = X_train[:n]

    predictions = sgd_clf.predict(some_digits)

    print('\n The labels and predictions on the first n images. \n')
    # Walk labels and predictions in lockstep instead of indexing both.
    for label, prediction in zip(y_train[:n], predictions):
        print(f'{label:2d} {prediction}')

    print('\n **Guten Abend!** \n')
# end of main()

# Run the demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
