Source code for daisy.data_sources.data_source

# Copyright (C) 2024-2025 DAI-Labor and others
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
"""A collection of the core interface and base classes for the first component of any
data handler (see the docstring of the data handler class), that provides the origin of
any data points being processed for further (ML) tasks. Supports generic generators,
but also remote communication endpoints that hand over generic data points in
streaming-manner, and any other implementations of the DataSource class. Note each
different kind of data may need its own implementation of the DataSource.

Author: Fabian Hofmann, Jonathan Ackerschewski
Modified: 04.11.24
"""

import csv
import logging
import os
from abc import ABC, abstractmethod
from typing import IO, Iterator

from natsort import natsorted

from daisy.communication import StreamEndpoint


class DataSource(ABC):
    """An abstract wrapper around a generator-like structure that has to yield data
    points as objects as they come for processing. That generator may be infinite or
    finite, as long as it is bounded on both sides by the following two methods that
    must be implemented:

        - open(): Enables the "generator" to provision data points.
        - close(): Closes the "generator".

    Note that as DataHandler wraps itself around given data sources to retrieve
    objects, open() and close() do not need to be implemented to be idempotent and
    arbitrarily permutable. The same can be assumed for __iter__(), as it will only be
    called when the data source has been opened already. At the same time, __iter__()
    must be exhausted after close() has been called.
    """

    _logger: logging.Logger

    def __init__(self, name: str = ""):
        """Creates a data source. Note that this should not enable the immediate
        generation of data points via __iter__() --- this behavior is implemented
        through open() (see the class documentation for more information).

        :param name: Name of data source for logging purposes.
        """
        self._logger = logging.getLogger(name)

    @abstractmethod
    def open(self):
        """Prepares the data source to be used for data point generation, setting up
        necessary environment variables, starting up background processes to
        read/generate data, etc.
        """
        raise NotImplementedError

    @abstractmethod
    def close(self):
        """Closes the data source, after which data point generation is no longer
        available until opened again.
        """
        raise NotImplementedError

    @abstractmethod
    def __iter__(self) -> Iterator[object]:
        """After the data source has been opened (see open()), returns a generator ---
        either the object itself or a newly created one (e.g. through use of the yield
        statement).

        :return: Generator object for data points as objects.
        """
        raise NotImplementedError
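

# A minimal sketch of a custom implementation (illustrative only, not part of this
# module): a finite DataSource over an in-memory list. The name ListDataSource and
# its internals are assumptions for demonstration, not library API.
#
#     class ListDataSource(DataSource):
#         def __init__(self, items: list[object], name: str = ""):
#             super().__init__(name)
#             self._items = items
#             self._opened = False
#
#         def open(self):
#             self._opened = True
#
#         def close(self):
#             self._opened = False
#
#         def __iter__(self) -> Iterator[object]:
#             for item in self._items:
#                 if not self._opened:  # exhausted once close() has been called
#                     return
#                 yield item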


class SimpleDataSource(DataSource):
    """The simplest productive data source --- an actual wrapper around a generator
    that is always open and cannot be closed, passing data points through as objects
    as they are yielded. The generator may be infinite or finite; either way, no
    control over it is natively supported.
    """

    _generator: Iterator[object]

    def __init__(self, generator: Iterator[object], name: str = ""):
        """Creates a data source, simply wrapping it around the given generator.

        :param generator: Generator object from which data points are retrieved.
        :param name: Name of data source for logging purposes.
        """
        super().__init__(name)
        self._generator = generator

    def open(self):
        pass

    def close(self):
        pass

    def __iter__(self) -> Iterator[object]:
        """Returns the wrapped generator, requiring neither open() nor close().

        :return: Generator object for data points as objects.
        """
        return self._generator
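

# Usage sketch (illustrative only): wrapping a plain generator expression. open()
# and close() are no-ops here but keep the DataSource contract intact.
#
#     source = SimpleDataSource((x * x for x in range(10)), name="squares")
#     source.open()
#     for point in source:
#         print(point)
#     source.close()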


class SimpleRemoteDataSource(DataSource):
    """The simple wrapper implementation to support and handle remote streaming
    endpoints of the Endpoint module as data sources. Considered infinite in nature,
    as it allows the generation of data point objects from a connected endpoint until
    the client closes the data source.
    """

    _endpoint: StreamEndpoint

    def __init__(self, endpoint: StreamEndpoint, name: str = ""):
        """Creates a new remote data source from a given stream endpoint.

        :param endpoint: Streaming endpoint from which data points are retrieved.
        :param name: Name of data source for logging purposes.
        """
        super().__init__(name)
        self._logger.info("Initializing remote data source...")
        self._endpoint = endpoint
        self._logger.info("Remote data source initialized.")

    def open(self):
        """Starts and opens/connects the endpoint of the data source."""
        self._logger.info("Starting remote data source...")
        try:
            self._endpoint.start()
        except RuntimeError:
            # Endpoint has already been started; nothing to do.
            pass
        self._logger.info("Remote data source started.")

    def close(self):
        """Stops and closes the endpoint of the data source."""
        self._logger.info("Stopping remote data source...")
        try:
            self._endpoint.stop()
        except RuntimeError:
            # Endpoint has already been stopped; nothing to do.
            pass
        self._logger.info("Remote data source stopped.")

    def __iter__(self) -> Iterator[object]:
        """Returns the wrapped endpoint generator, as it supports object retrieval
        directly.

        :return: Endpoint generator object for data points as objects.
        """
        return self._endpoint.__iter__()
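

# Usage sketch (illustrative only): the StreamEndpoint constructor is not shown in
# this module, so `endpoint` is assumed to be an already-configured instance, and
# handle() is a hypothetical consumer.
#
#     source = SimpleRemoteDataSource(endpoint, name="remote")
#     source.open()            # starts the endpoint unless it is already running
#     for point in source:     # streams objects from the connected endpoint
#         handle(point)
#     source.close()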


class CSVFileDataSource(DataSource):
    """This implementation of the DataSource reads one or multiple CSV files and
    yields their contents. The output of this class are dictionaries containing the
    headers (first row) of the CSV files as the keys and the fields of a line as the
    values. Each CSV file is, therefore, expected to have a header line as its first
    row.
    """

    _files: list[str]
    _cur_index: int
    _cur_handle: IO | None
    _cur_csv: Iterator[list[str]]
    _cur_headers: list[str]

    def __init__(self, files: str | list[str], name: str = ""):
        """Creates a new CSV file data source. Either a single file or a list of
        files is expected as the input.

        :param files: Either a single CSV file/directory or a list of CSV
            files/directories to read.
        :param name: Name of data source for logging purposes.
        """
        super().__init__(name)
        self._logger.info("Initializing CSV file data source...")
        if isinstance(files, str):
            tmp_files = [files]
        elif isinstance(files, list):
            tmp_files = files
        else:
            raise TypeError(
                f"Expected either string or list of strings, but got {type(files)}"
            )

        self._files = []
        for path in tmp_files:
            if os.path.isdir(path):
                self._files += [
                    os.path.join(path, file) for file in natsorted(os.listdir(path))
                ]
            else:
                self._files.append(path)

        self._cur_handle = None
        self._logger.info("CSV file data source initialized.")

    def open(self):
        """Opens the CSV file data source by setting the required parameters."""
        self._logger.info("Opening CSV file data source...")
        self._cur_index = 0

    def close(self):
        """Closes the CSV file data source."""
        self._logger.info("Closing CSV file data source...")
        if self._cur_handle:
            self._cur_handle.close()
            self._cur_handle = None

    def _open_next_file(self):
        """Opens the next CSV file to read. First, the last read file is closed.
        Afterward, the next CSV file is opened and its headers are extracted.
        """
        self._logger.info("Opening next CSV file...")
        if self._cur_handle:
            self._cur_handle.close()
            self._cur_handle = None

        next_file = self._files[self._cur_index]
        self._cur_index += 1
        # newline="" is required by the csv module to correctly handle line breaks
        # inside quoted fields.
        self._cur_handle = open(next_file, "r", newline="")
        self._cur_csv = csv.reader(self._cur_handle)
        self._cur_headers = next(self._cur_csv)
        self._logger.info("Next CSV file opened and headers extracted.")

    def _line_to_dict(self, line, header) -> dict[str, object]:
        """Converts a line into a dictionary using the provided headers.

        :param line: The line to convert.
        :param header: The headers to use as the keys.
        :return: A dictionary containing the headers as keys and the line's fields
            as values.
        """
        if len(line) != len(header):
            raise ValueError(
                f"Malformed line detected. Line length does not match header "
                f"length: {line}"
            )
        return dict(zip(header, line))

    def __iter__(self) -> Iterator[dict[str, object]]:
        """Iterates through the provided CSV files and yields each line as a
        dictionary.
        """
        while self._cur_index < len(self._files):
            self._open_next_file()
            for line in self._cur_csv:
                yield self._line_to_dict(line, self._cur_headers)
        self._logger.info("All CSV files exhausted.")
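

# Usage sketch (illustrative only): reading every CSV file in a directory. The path
# "data/flows" is a placeholder; each yielded row is a dict mapping header to field.
#
#     source = CSVFileDataSource("data/flows", name="csv-source")
#     source.open()
#     for row in source:
#         print(row["some_header"])  # "some_header" is a hypothetical column name
#     source.close()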