Source code for skdatasets.repositories.keras

"""
Keras datasets (https://keras.io/datasets).

@author: David Diaz Vico
@license: MIT
"""

from __future__ import annotations

from typing import Any, Literal, Tuple, overload

import numpy as np
from sklearn.utils import Bunch

from tensorflow.keras.datasets import (
    boston_housing,
    cifar10,
    cifar100,
    fashion_mnist,
    imdb,
    mnist,
    reuters,
)

DATASETS = {
    'boston_housing': boston_housing.load_data,
    'cifar10': cifar10.load_data,
    'cifar100': cifar100.load_data,
    'fashion_mnist': fashion_mnist.load_data,
    'imdb': imdb.load_data,
    'mnist': mnist.load_data,
    'reuters': reuters.load_data,
}


@overload
def fetch(
    name: str,
    *,
    return_X_y: Literal[False] = False,
    **kwargs: Any,
) -> Bunch:
    pass


@overload
def fetch(
    name: str,
    *,
    return_X_y: Literal[True],
    **kwargs: Any,
) -> Tuple[np.typing.NDArray[float], np.typing.NDArray[int]]:
    pass


[docs]def fetch( name: str, *, return_X_y: bool = False, **kwargs: Any, ) -> Bunch | Tuple[np.typing.NDArray[float], np.typing.NDArray[int]]: """ Fetch Keras dataset. Fetch a Keras dataset by name. More info at https://keras.io/datasets. Parameters ---------- name : string Dataset name. return_X_y : bool, default=False If True, returns ``(data, target)`` instead of a Bunch object. **kwargs : dict Optional key-value arguments. See https://keras.io/datasets. Returns ------- data : Bunch Dictionary-like object with all the data and metadata. (data, target) : tuple if ``return_X_y`` is True """ (X_train, y_train), (X_test, y_test) = DATASETS[name](**kwargs) if len(X_train.shape) > 2: name = name + ' ' + str(X_train.shape[1:]) + ' shaped' X_max = np.iinfo(X_train[0][0].dtype).max n_features = np.prod(X_train.shape[1:]) X_train = X_train.reshape([X_train.shape[0], n_features]) / X_max X_test = X_test.reshape([X_test.shape[0], n_features]) / X_max X = np.concatenate((X_train, X_test)) y = np.concatenate((y_train, y_test)) if return_X_y: return X, y return Bunch( data=X, target=y, train_indices=list(range(len(X_train))), validation_indices=[], test_indices=list(range(len(X_train), len(X))), inner_cv=None, outer_cv=None, DESCR=name, )