# Source code for skdatasets.repositories.libsvm

"""
LIBSVM datasets (https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets).

@author: David Diaz Vico
@license: MIT
"""
from __future__ import annotations

import os
import sys
from typing import Final, Literal, Sequence, Tuple, overload

import numpy as np
import scipy as sp
from sklearn.datasets import load_svmlight_file, load_svmlight_files
from sklearn.model_selection import PredefinedSplit
from sklearn.utils import Bunch

from .base import DatasetNotFoundError, fetch_file

# Root URL under which every collection/dataset URL is built.
BASE_URL: Final = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets'

# Collection names accepted by ``fetch``; each one is used as a URL path
# component and as a cache subfolder name.
COLLECTIONS: Final = frozenset(
    {'binary', 'multiclass', 'regression', 'string'},
)


def _fetch_partition(
    collection: str,
    name: str,
    partition: str,
    data_home: str | None = None,
) -> str | None:
    """
    Fetch one partition file of a dataset.

    The bzip2-compressed URL (``.bz2`` suffix) is tried first and the plain
    URL second; the first one that ``fetch_file`` resolves wins.

    Parameters
    ----------
    collection : string
        Collection name, used in the URL and as cache subfolder.
    name : string
        Dataset name. Any ``/`` is replaced by ``-`` for the name handed to
        ``fetch_file``.
    partition : string
        Suffix appended to the dataset name in the URL (e.g. ``''`` or
        ``'.t'``).
    data_home : string or None, default None
        Forwarded unchanged to ``fetch_file``.

    Returns
    -------
    path : string or None
        Filesystem path returned by ``fetch_file``, or ``None`` when both
        candidate URLs raised ``DatasetNotFoundError``.
    """
    plain_url = f"{BASE_URL}/{collection}/{name}{partition}"
    candidate_urls = (f"{plain_url}.bz2", plain_url)

    cache_subfolder = os.path.join('libsvm', collection)
    cache_name = name.replace('/', '-')

    for candidate in candidate_urls:
        try:
            local_path = os.fspath(
                fetch_file(
                    cache_name,
                    urlname=candidate,
                    subfolder=cache_subfolder,
                    data_home=data_home,
                ),
            )
        except DatasetNotFoundError:
            # This variant is not available; try the next candidate.
            continue
        return local_path

    return None


def _load(
    collection: str,
    name: str,
    data_home: str | None = None,
) -> Tuple[
    np.typing.NDArray[float],
    np.typing.NDArray[int | float],
    Sequence[int],
    Sequence[int],
    Sequence[int],
    PredefinedSplit,
]:
    """
    Load a dataset together with its predefined partitions, if any.

    The partitions ``''``, ``'.tr'``, ``'.val'``, ``'.t'`` and ``'.r'`` are
    requested (in that order) and the available ones decide the layout:

    * ``.tr`` and ``.val`` (optionally ``.t``): rows are stacked as
      train, validation[, test] and a ``PredefinedSplit`` marking the
      validation rows as fold 0 is returned.
    * ``.t`` (optionally ``.r``): rows of the unsuffixed file come first,
      followed by the ``.t`` rows[, then the ``.r`` rows]. The ``.r`` rows
      are part of the data but of none of the index lists.
    * otherwise: only the unsuffixed file is loaded and every index list
      is empty.

    Returns
    -------
    X, y, train_indices, validation_indices, test_indices, cv
        Stacked data and target, row indices of each partition within the
        stacked data, and the ``PredefinedSplit`` (``None`` when there is no
        train/validation partition).
    """
    path_full, path_tr, path_val, path_t, path_r = (
        _fetch_partition(collection, name, suffix, data_home)
        for suffix in ('', '.tr', '.val', '.t', '.r')
    )

    if path_tr is not None and path_val is not None:
        # The unsuffixed file is loaded alongside the partitions but its
        # arrays are discarded.
        # NOTE(review): presumably it is kept in the call so that every
        # partition is parsed with a common feature count -- confirm.
        if path_t is not None:
            (
                _, _, X_fit, y_fit, X_val, y_val, X_test, y_test,
            ) = load_svmlight_files([path_full, path_tr, path_val, path_t])
            X_parts = (X_fit, X_val, X_test)
            y_parts = (y_fit, y_val, y_test)
        else:
            _, _, X_fit, y_fit, X_val, y_val = load_svmlight_files(
                [path_full, path_tr, path_val],
            )
            X_parts = (X_fit, X_val)
            y_parts = (y_fit, y_val)

        n_fit = X_fit.shape[0]
        n_val = X_val.shape[0]

        # -1 marks rows as train-only, 0 puts them in the single
        # validation fold.
        inner_cv = PredefinedSplit([-1] * n_fit + [0] * n_val)

        X = sp.sparse.vstack(X_parts)
        y = np.hstack(y_parts)

        return (
            X,
            y,
            list(range(n_fit)),
            list(range(n_fit, n_fit + n_val)),
            # Empty when there is no test partition.
            list(range(n_fit + n_val, X.shape[0])),
            inner_cv,
        )

    if path_t is not None:
        if path_r is not None:
            X_fit, y_fit, X_test, y_test, X_rest, y_rest = (
                load_svmlight_files([path_full, path_t, path_r])
            )
            X_parts = (X_fit, X_test, X_rest)
            y_parts = (y_fit, y_test, y_rest)
        else:
            X_fit, y_fit, X_test, y_test = load_svmlight_files(
                [path_full, path_t],
            )
            X_parts = (X_fit, X_test)
            y_parts = (y_fit, y_test)

        n_fit = X_fit.shape[0]

        X = sp.sparse.vstack(X_parts)
        y = np.hstack(y_parts)

        return (
            X,
            y,
            list(range(n_fit)),
            [],
            # Test rows sit right after the training rows; any '.r' rows
            # that follow are not indexed.
            list(range(n_fit, n_fit + X_test.shape[0])),
            None,
        )

    # No predefined partition: a single file and no index information.
    X, y = load_svmlight_file(path_full)
    return X, y, [], [], [], None


@overload
def fetch(
    collection: str,
    name: str,
    *,
    data_home: str | None = None,
    return_X_y: Literal[False] = False,
) -> Bunch:
    """Typing overload: ``return_X_y`` omitted or ``False`` gives a Bunch."""
    ...


@overload
def fetch(
    collection: str,
    name: str,
    *,
    data_home: str | None = None,
    return_X_y: Literal[True],
) -> Tuple[np.typing.NDArray[float], np.typing.NDArray[int | float]]:
    """Typing overload: ``return_X_y=True`` gives a ``(data, target)`` tuple."""
    ...


def fetch(
    collection: str,
    name: str,
    *,
    data_home: str | None = None,
    return_X_y: bool = False,
) -> Bunch | Tuple[np.typing.NDArray[float], np.typing.NDArray[int | float]]:
    """
    Fetch LIBSVM dataset.

    Fetch a LIBSVM dataset by collection and name. More info at
    https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets.

    Parameters
    ----------
    collection : string
        Collection name. Must be one of ``COLLECTIONS``.
    name : string
        Dataset name.
    data_home : string or None, default None
        Specify another download and cache folder for the data sets. By
        default all scikit-learn data is stored in '~/scikit_learn_data'
        subfolders.
    return_X_y : bool, default=False
        If True, returns ``(data, target)`` instead of a Bunch object.

    Returns
    -------
    data : Bunch
        Dictionary-like object with all the data and metadata: ``data``,
        ``target``, ``train_indices``, ``validation_indices``,
        ``test_indices``, ``inner_cv``, ``outer_cv`` (always ``None``) and
        ``DESCR`` (the dataset name).
    (data, target) : tuple if ``return_X_y`` is True

    Raises
    ------
    ValueError
        If ``collection`` is not one of the available collections.
    """
    if collection not in COLLECTIONS:
        # ValueError still satisfies callers catching ``Exception``; the
        # collections are sorted so the message is stable across runs.
        raise ValueError(
            f"Unknown collection {collection!r}. "
            f"Available collections are {sorted(COLLECTIONS)}",
        )

    X, y, train_indices, validation_indices, test_indices, cv = _load(
        collection,
        name,
        data_home=data_home,
    )

    if return_X_y:
        return X, y

    return Bunch(
        data=X,
        target=y,
        train_indices=train_indices,
        validation_indices=validation_indices,
        test_indices=test_indices,
        inner_cv=cv,
        outer_cv=None,
        DESCR=name,
    )