{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "
"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "from sklearn.datasets import fetch_openml\n",
    "from sklearn.linear_model import SGDClassifier\n",
    "from sklearn.metrics import roc_curve\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from sklearn.model_selection import cross_val_predict\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "mnist = fetch_openml('mnist_784', version=1, cache=True)\n",
    "mnist.target = mnist.target.astype(np.int8) # fetch_openml() returns targets as strings\n",
    "\n",
    "X, y = mnist[\"data\"], mnist[\"target\"]\n",
    "# Keep the first 60,000 rows as the training split.\n",
    "X_train, y_train = X[:60000], y[:60000]\n",
    "# Binary target: True where the digit is 5, False otherwise.\n",
    "y_train_5 = (y_train == 5)\n",
    "\n",
    "# tol=None turns off the tolerance-based stopping check, so all max_iter epochs run.\n",
    "sgd_clf = SGDClassifier(max_iter=5, tol=None, random_state=42)\n",
    "# Out-of-fold decision scores (not class predictions) for every training row.\n",
    "y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3, method=\"decision_function\")\n",
    "\n",
    "fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)\n",
    "\n",
    "\n",
    "def plot_roc_curve(fpr, tpr, label=None):\n",
    "    \"\"\"Draw a ROC curve on the current matplotlib axes.\n",
    "\n",
    "    Args:\n",
    "        fpr: False positive rates (x values), e.g. from roc_curve().\n",
    "        tpr: True positive rates (y values), matching fpr.\n",
    "        label: Optional label attached to the curve; no legend is drawn\n",
    "            here, so call plt.legend() afterwards to display it.\n",
    "    \"\"\"\n",
    "    plt.plot(fpr, tpr, linewidth=2, label=label)\n",
    "    plt.plot([0, 1], [0, 1], \"k--\")  # dashed black diagonal for reference\n",
    "    plt.axis([0, 1, 0, 1])  # clamp both axes to [0, 1]\n",
    "    plt.xlabel(\"False Positive Rate\")\n",
    "    plt.ylabel(\"True Positive Rate\")\n",
    "\n",
    "\n",
    "plot_roc_curve(fpr, tpr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9536789698168869"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "roc_auc_score(y_train_5, y_scores)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}