Source code for retriever

"""FileRetriever classes"""

import logging
import os
import sys
import json
import asyncio
from abc import ABC, abstractmethod
import collections
from dataclasses import dataclass

from typing import AsyncGenerator, Callable

from merge_utils import config, io_utils
from merge_utils.merge_set import MergeSet, MergeFileError
from merge_utils.metacat_utils import MetaCatWrapper

logger = logging.getLogger(__name__)

@dataclass
class InputBatch:
    """Batch of input file data that begins at a specific skip offset."""
    skip: int = -1
    files: list = None

    def __post_init__(self):
        # Replace the None sentinel with a fresh list per instance
        # (avoids the shared-mutable-default pitfall).
        if self.files is None:
            self.files = []

    def __bool__(self):
        """A batch is truthy when it holds at least one file."""
        return bool(self.files)

    def __len__(self):
        """Number of files in this batch."""
        return len(self.files)

    def __iter__(self):
        """Yield the files in this batch."""
        return iter(self.files)
def file_serializer(obj):
    """Custom JSON ``default`` hook: render MergeFileError objects by name."""
    if not isinstance(obj, MergeFileError):
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
    return obj.name
class MetaRetriever(ABC):
    """Base class for retrieving metadata from a source"""
    name: str = "metadata"        # subclasses override; also names the cache dir
    file_owner: bool = True       # when True, this instance owns its MergeSet

    def __init__(self):
        if self.file_owner:
            self._files = MergeSet()
        # Per-retriever on-disk cache directory for input batches
        self.dir = os.path.join(str(config.job.dir), 'cache', self.name)
        os.makedirs(self.dir, exist_ok=True)
        self.client = MetaCatWrapper()

    @property
    def files(self) -> MergeSet:
        """Return the set of files from the source"""
        return self._files

    @property
    def namespace(self) -> str:
        """
        Return the default namespace for files without an explicit namespace.
        Checks the config for input.namespace, then output.namespace,
        then defaults to 'usertests'.
        """
        for candidate in (config.input.namespace, config.output.namespace):
            if candidate:
                return str(candidate)
        return 'usertests'
[docs] async def get_done(self) -> None: """Asynchronously query MetaCat for merged files with the same tag as this job""" if config.validation.handling.already_done == 'include': return if not config.input.tag: logger.critical("Already-done checking requires a job tag to be specified!") sys.exit(1) tag = str(config.input.tag) logger.info("Checking MetaCat for already merged files with tag '%s'", tag) query = f"files where merge.tag == '{tag}' and dune.output_status == confirmed" dids = [] skip = 0 while True: batch_query = query + f" skip {skip} limit {config.validation.batch_size}" files = await self.client.query(batch_query, metadata=False, provenance=False) self.files.children.update(f['fid'] for f in files) dids.extend(f"{f['namespace']}:{f['name']}" for f in files) if len(files) < config.validation.batch_size: break skip += config.validation.batch_size if not dids: logger.info("No already merged files found with tag '%s'", tag) return io_utils.log_list("Found {n} merged file{s} with tag '%s':" % tag, dids, logging.INFO)
[docs] async def connect(self) -> None: """Connect to the MetaCat web API""" await self.client.connect() await self.get_done()
[docs] async def disconnect(self) -> None: """Disconnect from the MetaCat web API""" await self.client.disconnect()
    @abstractmethod
    async def get_metadata(self, batch: InputBatch, limit: int) -> list:
        """
        Asynchronously retrieve metadata for a specific batch of files.

        Subclasses implement the actual lookup (MQL query, DID list,
        or local metadata files).

        :param batch: empty InputBatch object with the skip index set
        :param limit: maximum number of files to retrieve
        :return: list of file metadata dictionaries
        """
# retrieve specific batch
[docs] async def get_batch(self, getter: Callable, batch: InputBatch, **kwargs) -> InputBatch: """ Asynchronously retrieve a batch of input data, with caching. :param getter: function to call to retrieve inputs :param batch: InputBatch object to retrieve data for :param kwargs: additional arguments to pass to getter :return: list of file dictionaries """ skip = batch.skip cache = os.path.join(self.dir, f"batch_{skip}.json") if os.path.exists(cache): logger.debug("Loading cached %s input batch %d", self.name, skip) files = io_utils.read_config_file(cache).get('files', []) else: logger.debug("Retrieving new %s input batch %d", self.name, skip) files = await getter(batch=batch, **kwargs) with open(cache, 'w', encoding="utf-8") as f: f.write(json.dumps({'files': files}, indent=2, default=file_serializer)) return InputBatch(skip=skip, files=files)
[docs] async def check_existence(self, files: list) -> None: """ Check that MetaCat records exist for a batch of input files. :param files: list of file metadata dictionaries to check """ logger.debug("Checking MetaCat to verify file records") skip_fids = not config.validation.check_fids provenance = False # To check for already merged files, we need the children of all input files if config.validation.handling.already_done != 'include': skip_fids = False provenance = True # Build list of DIDs to check, skipping files with FIDs if we can indices = {} for idx, file in enumerate(files): if file.get('errors'): continue if skip_fids and 'fid' in file: continue indices[f"{file['namespace']}:{file['name']}"] = idx # Request file info from MetaCat, with provenance if necessary dids = [{'did': did} for did in indices] res = await self.client.files(dids, metadata=False, provenance=provenance) # Update file list with the returned records for file in res: did = f"{file['namespace']}:{file['name']}" idx = indices.pop(did) files[idx].update(file) # Mark any files that were missing from MetaCat with an error if indices: for idx in indices.values(): files[idx]['errors'] = MergeFileError.NO_METADATA io_utils.log_list("MetaCat missing {n} file record{s}:", indices, logging.ERROR) io_utils.log_print("Did you mean to enable the grandparents option?", logging.ERROR)
[docs] async def get_siblings(self, files: list) -> None: """ We check for already merged files by looking at the children of the input files. But in grandparents mode, we actually want the children of the input file parents instead. We don't need the original children in that case, so just replace them with the siblings. :param files: list of file metadata dictionaries to check """ # Collect list of all parent FIDs to check parent_fids = set() for file in files: for parent in file['parents']: parent_fids.add(parent['fid']) fid_list = [{'fid': fid} for fid in parent_fids] # Retrieve children of parents from MetaCat parents = await self.client.files(fid_list, metadata = False, provenance = True) children = collections.defaultdict(set) for parent in parents: child_fids = set(c['fid'] for c in parent.get('children', [])) children[parent['fid']].update(child_fids) # Override the children of the input files for file in files: siblings = set() for parent in file['parents']: siblings.update(children[parent['fid']]) siblings.discard(file.get('fid', None)) file['children'] = [{'fid': f} for f in siblings]
[docs] async def check_parents(self, files: list) -> None: """ Check that MetaCat records exist for the parents of a batch of input files. Also get the siblings if we need them for the already merged check. :param files: list of file metadata dictionaries to check """ logger.debug("Checking MetaCat to verify parent records") skip_fids = not config.validation.check_fids provenance = False # To check for already merged files, we need the children of all input files if config.validation.handling.already_done != 'include': skip_fids = False provenance = True # Collect list of all parent DIDs to check parent_dids = set() for file in [f for f in files if not f.get('errors')]: for parent in file.get('parents', []): if skip_fids and 'fid' in parent: continue parent_dids.add(f"{parent['namespace']}:{parent['name']}") # Request parent info from MetaCat, with provenance if necessary dids = [{'did': did} for did in parent_dids] res = await self.client.files(dids, metadata=False, provenance=provenance) parents = {f"{file['namespace']}:{file['name']}": file for file in res} # Check for missing parents, and build child list if needed for already merged check missing = set() for file in [f for f in files if not f.get('errors')]: children = set() for parent in file.get('parents', []): if skip_fids and 'fid' in parent: continue did = f"{parent['namespace']}:{parent['name']}" parent_info = parents.get(did) if not parent_info: file['errors'] = MergeFileError.UNDECLARED missing.add(f"did: {did}") continue parent['fid'] = parent_info.get('fid') if not provenance: continue child_fids = set(c['fid'] for c in parent.get('children', [])) children.update(child_fids) if provenance: children.discard(file.get('fid', None)) file['children'] = [{'fid': f} for f in children] # Log any missing parents if missing: io_utils.log_list("MetaCat missing {n} grandparent record{s}:", missing, logging.ERROR)
[docs] async def get_files(self, query: list) -> list: """ Asynchronously retrieve file metadata for a specific list of DIDs. Also gets the siblings if we need them for the already merged check. :param query: list of dictionaries with 'did' keys to retrieve :return: list of file metadata dictionaries """ # In grandparents mode, we need the parents of the input files parents = bool(config.output.grandparents) # To check for already merged files, we need the children of the input files children = (config.validation.handling.already_done != 'include') # Request files from MetaCat, with provenance if necessary provenance = parents or children files = await self.client.files(query, metadata = True, provenance = provenance) # In grandparents mode, we actually need the children of the parents if parents and children: await self.get_siblings(files) return files
    async def input_batches(self) -> AsyncGenerator[InputBatch, None]:
        """
        Asynchronously retrieve input file metadata in batches.

        Pipelines the work: while one batch is being processed, the request
        for the next batch is already in flight.

        :return: InputBatch object containing skip index and list of MergeFile objects
        """
        skip0 = int(config.input.skip or 0)
        skip = skip0
        step = int(config.validation.batch_size)
        task = None
        while True:
            # Determine file limit for next batch (capped by the overall input limit)
            limit = step
            if config.input.limit:
                limit = min(limit, config.input.limit + skip0 - skip)
            # Get previous batch to process, if we have a request in flight
            batch = await task if task is not None else None
            task = None
            # Start request for next batch
            if limit > 0:
                req = InputBatch(skip=skip)
                task = asyncio.create_task(self.get_batch(self.get_metadata, req, limit=limit))
            # Increment skip for next batch
            skip += step
            # Process previous batch while we wait, if we have one
            # (first iteration has no previous batch yet)
            if batch is None:
                continue
            logger.info("Processing new %s input batch %d", self.name, batch.skip)
            # Add file to merge set, and yield if we added any
            # (run in a thread so the in-flight request keeps making progress)
            added = await asyncio.to_thread(self.files.add, batch.skip, batch.files)
            if added:
                yield InputBatch(skip=batch.skip, files=added)
            # If there is no next task, we're done
            if task is None:
                break
            # If the last batch was a partial batch, we're done
            if len(batch) < step:
                # Need to wait for the last task to finish
                await task
                break
        # Yield empty batch to signal completion
        yield InputBatch()
async def _loop(self) -> None: """Repeatedly get input_batches until all files are retrieved.""" # Connect to source await self.connect() # Loop over batches, checking for errors as we go async for _ in self.input_batches(): self.files.check_errors() # Close connections and do final error checking await self.client.disconnect() self.files.check_errors(final = True)
[docs] def run(self) -> None: """Retrieve metadata for all files.""" try: asyncio.run(self._loop()) except ValueError as err: logger.critical("%s", err) sys.exit(1)
class QueryRetriever(MetaRetriever):
    """Class for retrieving metadata from MetaCat using an MQL query."""
    name = "metacat_query"

    def __init__(self, query: str):
        """
        Initialize the QueryRetriever with an MQL query.

        :param query: MQL query to find files
        """
        super().__init__()
        if 'skip' in query or 'limit' in query:
            # Paging keywords embedded in the query interfere with our batching
            logger.warning("Consider using command line options for 'skip' and 'limit'!")
        elif not query.endswith(' ordered'):
            # Deterministic ordering is required for stable skip/limit paging
            query += ' ordered'
        else:
            logger.info("Merge-Utils will append the 'ordered' keyword to queries automatically.")
        self.query = query
[docs] async def get_metadata(self, batch: InputBatch, limit: int) -> list: """ Asynchronously query MetaCat for a specific batch of files :param batch: InputBatch object with skip index set :param limit: maximum number of files to retrieve :return: list of file metadata dictionaries """ query_batch = self.query + f" skip {batch.skip} limit {limit}" # In grandparents mode, we need the parents of the input files parents = bool(config.output.grandparents) # To check for already merged files, we need the children of the input files children = (config.validation.handling.already_done != 'include') # Query MetaCat, with provenance if necessary provenance = parents or children files = await self.client.query(query_batch, metadata = True, provenance = provenance) # In grandparents mode, we actually need the children of the parents if parents and children: await self.get_siblings(files) return files
class DidRetriever(MetaRetriever):
    """Class for retrieving metadata from MetaCat using a list of DIDs."""
    name: str = "metacat_dids"

    def __init__(self, dids: list, dupes: set = None):
        """
        Initialize the DidRetriever with a list of DIDs.

        :param dids: list of file DIDs to find
        :param dupes: set of indices of duplicate DIDs
        """
        super().__init__()
        self.dids = dids
        self.check_namespaces()
        # Compute duplicate indices unless the caller supplied them
        if dupes is None:
            dupes = self.check_duplicates()
        self.dupes = dupes
[docs] def check_namespaces(self) -> None: """ Check DID list for namespace issues. """ # Check namespaces namespaces = collections.defaultdict(int) for did in self.dids: parts = did.split(':', 1) if len(parts) == 2: namespaces[parts[0]] += 1 if len(namespaces) == 0: ns = self.namespace logger.info("DID list missing namespaces, using default namespace '%s'", ns) self.dids = [f"{ns}:{did}" for did in self.dids] elif len(namespaces) == 1: ns, count = next(iter(namespaces.items())) if count < len(self.dids): logger.warning("Some DIDs missing namespaces, assuming shared namespace '%s'", ns) self.dids = [f"{ns}:{did}" if ':' not in did else did for did in self.dids] elif config.validation.handling.inconsistent == 'quit': io_utils.log_list("DID list contains multiple namespaces:", namespaces, logging.CRITICAL) sys.exit(1) else: count = sum(namespaces.values()) ns = self.namespace if count < len(self.dids): logger.warning("Some DIDs missing namespaces, assuming default namespace '%s'", ns) self.dids = [f"{ns}:{did}" if ':' not in did else did for did in self.dids]
[docs] def check_duplicates(self) -> set: """ Check DID list for duplicate entries. :return: set of indices of duplicate DIDs """ seen = set() dupes = set() for idx, did in enumerate(self.dids): if did in seen: dupes.add(idx) seen.add(did) if dupes and config.validation.handling.duplicate == 'quit': io_utils.log_list("DID list contains {n} duplicate file{s}:", list(self.dupes), logging.CRITICAL) sys.exit(1) return dupes
[docs] async def get_metadata(self, batch: InputBatch, limit: int) -> list: """ Asynchronously request a batch of DIDs from MetaCat :param batch: InputBatch object with skip index set :param limit: maximum number of files to retrieve :return: list of file metadata dictionaries """ skip = batch.skip dids = self.dids[skip:skip+limit] if len(dids) == 0: logger.debug("No DIDs to request for skip=%d, limit=%d", skip, limit) return [] # Build query and list of placeholder files query = [] files = [] indices = {} for idx, did in enumerate(dids): namespace, name = did.split(':') placeholder = { 'namespace': namespace, 'name': name, 'errors': MergeFileError.NO_METADATA } if skip+idx in self.dupes: logger.debug("Skipping duplicate DID: %s", did) placeholder['errors'] = MergeFileError.DUPLICATE else: query.append({'did': did}) indices[did] = idx files.append(placeholder) if len(query) == 0: logger.debug("All DIDs in batch are duplicates, skipping MetaCat request") return files # Request files from MetaCat res = await self.get_files(query) # Add returned files to output list in correct order for file in res: files[indices[f"{file['namespace']}:{file['name']}"]] = file return files
class LocalMetaRetriever(MetaRetriever):
    """MetaRetriever for local files"""
    name = "local_meta"

    def __init__(self, paths: list):
        """
        Initialize the LocalMetaRetriever with a list of json metadata files.

        :param paths: list of metadata file paths
        """
        super().__init__()
        # Paths are consumed in order by get_metadata, sliced by skip/limit
        self.paths = paths
[docs] async def get_metadata(self, batch: InputBatch, limit: int) -> list: """ Asynchronously retrieve metadata for a specific batch of files. :param batch: InputBatch object containing skip index and list of file names :param limit: maximum number of files to retrieve :return: list of file metadata dictionaries """ files = [] skip = batch.skip end = min(skip + limit, len(self.paths)) namespace = self.namespace # Read metadata from local files if possible missing = {} for path in self.paths[skip:end]: name = os.path.basename(path).rsplit('.', 1)[0] metadata = io_utils.read_json(path) # If file is missing or unreadable, create a placeholder if not metadata: metadata = { 'namespace': namespace, 'name': name, 'errors': MergeFileError.NO_METADATA } missing[name] = len(files) files.append(metadata) # Make sure files exist in MetaCat, for parent listing if config.output.grandparents: await self.check_parents(files) else: await self.check_existence(files) # If we were missing any local files, try to find them in MetaCat if missing: io_utils.log_list("Checking MetaCat for {n} missing metadata file{s}:", missing, logging.INFO) res = await self.get_files([{'did': f"{namespace}:{name}"} for name in missing]) for file in res: files[missing[file['name']]] = file return files
def _local_metadata_paths(inputs: list) -> list:
    """Pair each input with a metadata (.json) file, preserving input order."""
    # Names (without the .json suffix) of all metadata files given directly
    seen = set(os.path.basename(f)[:-5] for f in inputs if os.path.splitext(f)[1] == '.json')
    json_files = []
    for path in inputs:
        # Explicit metadata files are taken as-is
        if os.path.splitext(path)[1] == '.json':
            json_files.append(path)
            continue
        name = os.path.basename(path)
        # Skip data files whose metadata we already have
        if name in seen:
            continue
        seen.add(name)
        # Look for a sibling metadata file next to the data file...
        sibling = path + '.json'
        if os.path.isfile(sibling):
            json_files.append(sibling)
            continue
        # ...or in the provided search directories
        for search_dir in config.input.search_dirs:
            candidate = os.path.join(search_dir, name + '.json')
            if os.path.isfile(candidate):
                json_files.append(candidate)
                break
    return json_files


def get() -> MetaRetriever:
    """
    Create and return a metadata retriever based on input mode:

    files: LocalMetaRetriever if any metadata files were provided, otherwise DidRetriever
    dids: DidRetriever
    query: QueryRetriever
    dataset: QueryRetriever with query for files in the specified dataset

    :return: MetaRetriever object for retrieving file metadata
    """
    inputs = [str(f) for f in config.input.inputs]
    mode = config.input.mode
    if mode == 'files':
        json_files = _local_metadata_paths(inputs)
        # If we found any metadata files, return a LocalMetaRetriever
        if len(json_files) > 0:
            return LocalMetaRetriever(paths=json_files)
        # Otherwise, treat the input file names as DIDs
        return DidRetriever(dids=[os.path.basename(f) for f in inputs])
    if mode == 'dids':
        return DidRetriever(dids=inputs)
    if mode == 'query':
        if len(inputs) != 1:
            logger.critical("Multiple query inputs detected, did you forget to quote the query?")
            sys.exit(1)
        return QueryRetriever(query=str(inputs[0]))
    if mode == 'dataset':
        if len(inputs) != 1:
            logger.critical("Dataset input mode currently only supports a single dataset name.")
            sys.exit(1)
        return QueryRetriever(query=f"files from {inputs[0]} where dune.output_status=confirmed")
    logger.critical("Unknown input mode: %s", config.input.mode)
    sys.exit(1)