"""Utilites for reading and writing files and other I/O operations"""
import os
import sys
import logging
import logging.config
import json
import pathlib
import math
from datetime import datetime, timezone
from collections.abc import Iterable
# tomllib was added to the standard library in Python 3.11, need tomli for DUNE
try:
import tomllib # type: ignore
except ImportError:
import tomli as tomllib
import yaml
# Module-level logger used by every helper in this file
logger = logging.getLogger(__name__)
[docs]
def pkg_dir() -> str:
    """
    Get the base directory of the package

    A non-empty MERGE_UTILS_DIR environment variable takes precedence;
    otherwise the directory three levels above this file is returned.

    :return: Path to the package base directory
    """
    override = os.environ.get('MERGE_UTILS_DIR')
    if not override:
        # Walk up three directory levels from this source file
        location = __file__
        for _ in range(3):
            location = os.path.dirname(location)
        return location
    return override
[docs]
def src_dir() -> str:
    """
    Get the source directory of the package

    :return: The 'src/merge_utils' directory below the package base directory
    """
    base = pkg_dir()
    return os.path.join(base, 'src', 'merge_utils')
[docs]
def get_timestamp() -> str:
    """
    Get the current timestamp as a string

    :return: Current UTC time formatted as YYYYMMDDTHHMMSS
    """
    now_utc = datetime.now(tz=timezone.utc)
    return f"{now_utc:%Y%m%dT%H%M%S}"
[docs]
def expand_path(path: str, base_dir: str = None) -> str:
    """
    Expand environment variables and the user home directory in a path.

    A relative path is anchored at the (likewise expanded) base directory
    when one is given. The result is always passed through os.path.abspath.

    :param path: Path to expand
    :param base_dir: Base directory for relative paths
    :return: Expanded path
    """
    def _expand(raw: str) -> str:
        # Environment variables first, then a leading '~'
        return os.path.expanduser(os.path.expandvars(raw))

    expanded = _expand(path)
    if base_dir is None or os.path.isabs(expanded):
        return os.path.abspath(expanded)
    return os.path.abspath(os.path.join(_expand(base_dir), expanded))
[docs]
def find_file(name: str, dirs: list[str] = None, recursive: bool = False) -> str:
    """
    Locate a file by name in a list of directories

    The name is first tried as given (after expanding environment variables
    and the user home directory). If that fails and the name is relative,
    each search directory is tried in order; relative search directories are
    resolved against the package directory. With ``recursive`` set, the
    sub-directories are searched breadth-first. Names ending in ``.fcl`` are
    additionally looked up in the directories listed in FHICL_FILE_PATH.

    The caller's ``dirs`` list is never modified. Search directories that do
    not exist are skipped, and each real directory is visited at most once.

    :param name: File name or path to locate
    :param dirs: List of directories to search
    :param recursive: Check sub-directories recursively
    :return: Full path to the located file
    :raises FileNotFoundError: If the file does not exist
    """
    path = os.path.expanduser(os.path.expandvars(name))
    # First, check if the path exists as given
    if os.path.exists(path):
        return os.path.abspath(path)
    # An absolute path names exactly one location, so there is nowhere else to look
    if os.path.isabs(path):
        raise FileNotFoundError(f"Failed to read file {path}")

    # Search the provided directories one level at a time (breadth-first).
    # Work on a copy so the list passed in by the caller is left untouched.
    level = list(dirs) if dirs is not None else []
    visited = set()
    while level:
        next_level = []
        for directory in level:
            search_dir = os.path.expanduser(os.path.expandvars(directory))
            if not os.path.isabs(search_dir):
                search_dir = os.path.join(pkg_dir(), search_dir)
            # Guard against duplicates and symlink cycles
            real_dir = os.path.realpath(search_dir)
            if real_dir in visited:
                continue
            visited.add(real_dir)
            test_path = os.path.join(search_dir, path)
            if os.path.exists(test_path):
                return os.path.abspath(test_path)
            # Scan the same resolved directory that was just tested, and only
            # if it really is a directory, so a bad entry cannot abort the search
            if recursive and os.path.isdir(search_dir):
                with os.scandir(search_dir) as entries:
                    next_level.extend(entry.path for entry in entries if entry.is_dir())
        level = next_level

    # For FCL files, also check the FHICL_FILE_PATH environment variable
    if name.endswith(".fcl"):
        fcl_dirs = os.getenv("FHICL_FILE_PATH")
        if fcl_dirs is None:
            logger.warning("FHICL_FILE_PATH environment variable is not set")
        else:
            for directory in fcl_dirs.split(':'):
                test_path = os.path.expanduser(os.path.expandvars(os.path.join(directory, name)))
                if os.path.exists(test_path):
                    return os.path.abspath(test_path)

    # If we reach this point, the file was not found
    raise FileNotFoundError(f"Failed to read file {name}")
[docs]
def find_cfg(name: str) -> str:
    """
    Find the full path to a configuration file

    The 'config' directory of the package and all of its sub-directories
    are searched.

    :param name: Name of the configuration file
    :return: Full path to the configuration file
    :raises FileNotFoundError: If the file does not exist
    """
    config_root = os.path.join(pkg_dir(), "config")
    return find_file(name, [config_root], recursive=True)
[docs]
def find_runner(name: str) -> str:
    """
    Find the full path to a runner script

    Only the 'src/runners' directory of the package is searched
    (no sub-directories).

    :param name: Name of the runner script
    :return: Full path to the runner script
    :raises FileNotFoundError: If the file does not exist
    """
    runner_root = os.path.join(pkg_dir(), "src", "runners")
    return find_file(name, [runner_root])
[docs]
def read_config_file(name: str = None) -> dict:
    """
    Read a configuration file in JSON, TOML, or YAML format

    The file is located with find_cfg and the parser is chosen from the
    file extension (.json, .toml, .yaml or .yml).

    :param name: Name of the configuration file
    :return: Dictionary containing the configuration settings,
        or None when no name is given
    :raises FileNotFoundError: If the file does not exist
    :raises ValueError: If the file type is not supported
    """
    if name is None:
        return None

    path = find_cfg(name)
    suffix = pathlib.Path(path).suffix

    if suffix == ".json":
        logger.debug("Reading JSON file %s", path)
        with open(path, encoding="utf-8") as stream:
            return json.load(stream)

    if suffix == ".toml":
        logger.debug("Reading TOML file %s", path)
        # The TOML parser requires a binary file object
        with open(path, mode="rb") as stream:
            return tomllib.load(stream)

    if suffix in (".yaml", ".yml"):
        logger.debug("Reading YAML file %s", path)
        with open(path, encoding="utf-8") as stream:
            return yaml.safe_load(stream)

    logger.error("Unknown file type: %s", suffix)
    raise ValueError(f"Unknown file type: {suffix}")
[docs]
def setup_log(name: str, log_file: str = None, verbosity: int = 0) -> None:
    """
    Configure logging

    The logging configuration is read from "logging.json". When no log file
    is given, the file name from the configuration is used; a relative name
    from the configuration is placed in the "logs" directory of the package.

    :param name: Name of the running script (only the base name is logged)
    :param log_file: Path of the log file, overriding the configured one
    :param verbosity: Console verbosity, passed on to set_log_level
    """
    config = read_config_file("logging.json")
    file_handler = config['handlers']['file']

    if not log_file:
        # Fall back to the configured file name
        log_file = file_handler['filename']
        if not os.path.isabs(log_file):
            log_file = os.path.join(pkg_dir(), "logs", log_file)
    file_handler['filename'] = log_file

    # If we're appending to an existing log file, add a newline before the new log
    if os.path.exists(log_file):
        with open(log_file, 'a', encoding="utf-8") as existing_log:
            existing_log.write("\n")

    logging.config.dictConfig(config)
    logger.info("Starting script %s", os.path.basename(name))
    set_log_level(verbosity)
[docs]
def _allow_on_console(record: logging.LogRecord) -> bool:
    """Reject records whose 'block' attribute is "console", accept all others"""
    return getattr(record, 'block', None) != "console"


def set_log_level(level: int) -> None:
    """
    Override the logging level for the console

    The verbosity is mapped to a logging level as 0 -> ERROR, 1 -> WARNING,
    2 -> INFO and 3 or more -> DEBUG. Values below 0 are treated as 0.

    Every handler of the root logger that is named "console" gets the new
    level and a filter that drops records marked with block="console"
    (log_print marks its records this way because it prints them itself).
    The filter is a single module-level function, so calling this function
    repeatedly does not pile up duplicate filters on the handler.

    :param level: Verbosity level
    """
    level_names = ("ERROR", "WARNING", "INFO", "DEBUG")
    # Clamp into the valid range so out-of-range values cannot reach setLevel
    level_name = level_names[max(0, min(int(level), len(level_names) - 1))]

    for handler in logging.getLogger().handlers:
        if handler.get_name() == "console":
            handler.setLevel(level_name)
            # addFilter ignores a filter object that is already installed
            handler.addFilter(_allow_on_console)
[docs]
def log_print(msg: str, level=logging.INFO) -> None:
    """
    Print a message and save it to the log file

    The record carries block="console" as an extra attribute so that a
    filter can keep it off the console handler, where it would otherwise
    appear a second time next to the printed copy.

    :param msg: Message to print and log
    :param level: Logging level for the log record
    """
    console_marker = {'block': 'console'}
    # stacklevel=2 attributes the record to the caller of log_print
    logger.log(level, msg, extra=console_marker, stacklevel=2)
    print(msg)
[docs]
def log_nonzero(msg: str, value: int, level=logging.DEBUG) -> int:
    """
    Log a message if the value is non-zero

    The message is a str.format template that may use the fields {n} (the
    value), {s} and {es} (plural endings, empty when the value is 1).

    :param msg: Message template
    :param value: Count to report
    :param level: Logging level for the message
    :return: The value, or 0 if nothing was logged
    """
    if value == 0:
        return 0

    if value == 1:
        count, plural_s, plural_es = 1, "", ""
    else:
        count, plural_s, plural_es = value, "s", "es"

    # stacklevel=2 attributes the record to the caller of log_nonzero
    logger.log(level, msg.format(n=count, s=plural_s, es=plural_es), stacklevel=2)
    return value
[docs]
def log_list(msg: str, items: Iterable, level=logging.WARNING) -> int:
    """
    Log a message for a list of items

    The message is a str.format template that may use the fields {n} (the
    number of items), {s} and {es} (plural endings, empty for one item).
    The items are appended in sorted order, one per line. Nothing is logged
    for an empty collection.

    Any iterable is accepted, including generators and other iterables
    without a length; it is consumed exactly once.

    :param msg: Message template
    :param items: Items to list
    :param level: Logging level for the message
    :return: Number of items
    """
    # Materialize once: the input may be a one-shot iterator without len()
    entries = list(items)
    total = len(entries)
    if total == 0:
        return 0

    if total == 1:
        lines = [msg.format(n=1, s="", es="")]
    else:
        lines = [msg.format(n=total, s="s", es="es")]
    lines.extend(f"\n {item}" for item in sorted(entries))

    # stacklevel=2 attributes the record to the caller of log_list
    logger.log(level, "".join(lines), stacklevel=2)
    return total
[docs]
def log_dict(msg: str, items: dict, level=logging.WARNING) -> int:
    """
    Log a message for a dictionary of items with counts

    The message is a str.format template that may use the fields {n} (the
    sum of all counts), {s} and {es} (plural endings, empty when the sum is
    1). The keys are appended in sorted order, one per line. If the largest
    count is 1 only the keys are listed; otherwise each key is preceded by
    its count, right-aligned to the width of the largest count. Nothing is
    logged when the counts sum to zero.

    :param msg: Message template
    :param items: Mapping of item to count
    :param level: Logging level for the message
    :return: Sum of all counts
    """
    total = sum(items.values())
    if total == 0:
        return 0

    if total == 1:
        lines = [msg.format(n=1, s="", es="")]
    else:
        lines = [msg.format(n=total, s="s", es="es")]

    mult = max(items.values())
    if mult == 1:
        lines += [f"\n {item}" for item in sorted(items)]
    else:
        # Width of the largest count; counting characters is exact for any
        # integer, unlike a floating-point log10, and cannot fail for values <= 0
        pad = len(str(mult))
        lines += [f"\n ({count:{pad}}) {item}" for item, count in sorted(items.items())]

    # stacklevel=2 attributes the record to the caller of log_dict
    logger.log(level, "".join(lines), stacklevel=2)
    return total