# -*- coding: utf-8 -*-
# Copyright (c) 2025-present tandemdude
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import annotations
__all__ = ["load", "loads", "parser_registry"]
import datetime
import functools
import json
import os
import pathlib
import typing as t

from confspec import helpers
from confspec import interpolate
from confspec import parsers
if t.TYPE_CHECKING:
    # Only evaluated by static type checkers. At runtime every annotation that references these names stays
    # an unevaluated string because of ``from __future__ import annotations`` at the top of the module, so
    # (presumably) msgspec and pydantic can remain optional dependencies.
    from collections.abc import Callable
    import msgspec
    import pydantic
    # Type variables used by the ``load``/``loads`` overloads so that passing ``cls=SomeStruct`` or
    # ``cls=SomeModel`` is inferred to return an instance of that same class.
    StructT = t.TypeVar("StructT", bound=msgspec.Struct)
    BaseModelT = t.TypeVar("BaseModelT", bound=pydantic.BaseModel)
# Keys are file extensions without the leading dot; "yaml" and "yml" share the same parser class.
parser_registry: dict[str, type[parsers.Parser]] = dict(
    json=parsers.JsonParser,
    toml=parsers.TomlParser,
    yaml=parsers.YamlParser,
    yml=parsers.YamlParser,
)
"""Dictionary mapping file extension to parser class used when parsing data of that format."""

# Formats with a built-in parser; ``str`` is also accepted wherever this alias is used so that formats
# registered at runtime in ``parser_registry`` can be passed as well.
KnownFormats: t.TypeAlias = t.Literal["json", "toml", "yaml", "yml"]
def _merge_dicts(d1: dict[str, t.Any], d2: dict[str, t.Any]) -> dict[str, t.Any]:
for key in d2:
if key in d1 and isinstance(d1[key], dict) and isinstance(d2[key], dict):
_merge_dicts(d1[key], d2[key])
continue
d1[key] = d2[key]
return d1
def _loads(
hierarchy: list[str | bytes],
fmt: KnownFormats | str,
/,
*,
cls: type[pydantic.BaseModel | msgspec.Struct] | None = None,
strict: bool = False,
dec_hook: Callable[[type[t.Any], t.Any], t.Any] | None = None,
) -> dict[str, t.Any] | pydantic.BaseModel | msgspec.Struct:
parser = parser_registry.get(fmt)
if parser is None:
raise NotImplementedError(f"no parser registered for format {fmt!r}")
mappings: list[dict[str, t.Any]] = []
for raw in hierarchy:
mappings.append(parser().read(raw.encode() if isinstance(raw, str) else raw))
parsed = functools.reduce(_merge_dicts, mappings)
interpolated = interpolate.InterpolationVisitor().visit(parsed)
if cls is None:
return interpolated
dumped = json.dumps(interpolated)
if helpers.is_pydantic(cls):
return cls.model_validate_json(dumped, strict=strict)
elif helpers.is_msgspec(cls):
import msgspec
return msgspec.json.decode(dumped, type=cls, strict=strict, dec_hook=dec_hook)
raise NotImplementedError(f"unknown class '{cls}' provided")
@t.overload
def loads(raw: str | bytes, fmt: KnownFormats | str, /) -> dict[str, t.Any]: ...


@t.overload
def loads(
    raw: str | bytes,
    fmt: KnownFormats | str,
    /,
    *,
    cls: type[StructT],
    strict: bool = False,
    dec_hook: Callable[[type[t.Any], t.Any], t.Any] | None = None,
) -> StructT: ...


@t.overload
def loads(
    raw: str | bytes,
    fmt: KnownFormats | str,
    /,
    *,
    cls: type[BaseModelT],
    strict: bool = False,
) -> BaseModelT: ...


def loads(
    raw: str | bytes,
    fmt: KnownFormats | str,
    /,
    *,
    cls: type[pydantic.BaseModel | msgspec.Struct] | None = None,
    strict: bool = False,
    dec_hook: Callable[[type[t.Any], t.Any], t.Any] | None = None,
) -> dict[str, t.Any] | pydantic.BaseModel | msgspec.Struct:
    """
    Parse configuration from an in-memory string or bytes object instead of a file.

    Works like :meth:`~load`, except that there is no file extension to infer the format from, so ``fmt`` must
    always be given. No environment-specific override is merged. All other arguments have the same meaning as
    in :meth:`~load`.

    Args:
        raw: The configuration document. A :obj:`str` is encoded as UTF-8 before being parsed.
        fmt: The key in :obj:`~parser_registry` naming the parser to use (e.g. ``"yaml"``).

    Raises:
        :obj:`NotImplementedError`: If no parser is registered for ``fmt``.
    """
    # A single in-memory payload is simply a one-document hierarchy.
    single_document: list[str | bytes] = [raw]
    return _loads(single_document, fmt, cls=cls, strict=strict, dec_hook=dec_hook)
@t.overload
def load(path: str | pathlib.Path, /, *, env: str | None = None) -> dict[str, t.Any]: ...


@t.overload
def load(
    path: str | pathlib.Path,
    /,
    *,
    cls: type[StructT],
    env: str | None = None,
    strict: bool = False,
    dec_hook: Callable[[type[t.Any], t.Any], t.Any] | None = None,
) -> StructT: ...


@t.overload
def load(
    path: str | pathlib.Path,
    /,
    *,
    cls: type[BaseModelT],
    env: str | None = None,
    strict: bool = False,
) -> BaseModelT: ...


def load(
    path: str | pathlib.Path,
    /,
    *,
    cls: type[pydantic.BaseModel | msgspec.Struct] | None = None,
    env: str | None = None,
    strict: bool = False,
    dec_hook: Callable[[type[t.Any], t.Any], t.Any] | None = None,
) -> dict[str, t.Any] | pydantic.BaseModel | msgspec.Struct:
    """
    Loads arbitrary configuration from the given path, performing environment variable substitutions, and
    parses it into the given class (or to a dictionary if no class was provided).

    Currently supported formats are: yaml, toml and JSON. Additional formats can be supported by creating your own
    custom implementation of :obj:`~confspec.parsers.abc.Parser` and registering it with :obj:`~parser_registry`.
    The format is selected from the file extension (without the leading dot); the lookup is case-sensitive.

    Args:
        path: The path to the configuration file.
        cls: The pydantic BaseModel, or msgspec Struct to parse the configuration into. If :obj:`None`, the
            configuration will be parsed into a dictionary. Defaults to :obj:`None`.
        env: The name of an additional environment configuration to load and merge with the base configuration.
            When provided, an environment-specific file will be looked up using the same base name as the main
            configuration file and the given environment as a suffix (e.g., ``config.prod.yaml`` for ``env="prod"``).
            Values from the environment file take precedence over the base file. If that file does not exist
            it is silently skipped and only the base configuration is used.
            If not provided, the environment name will be read from the ``CONFSPEC_ENV`` environment variable if set.
            If neither is specified, only the base configuration file will be loaded.
        strict: Whether the parsing behaviour of pydantic/msgspec should be in strict mode. Defaults to :obj:`False`.
            If :obj:`True`, then parsers will not perform type coercion (e.g. digit string to int).
        dec_hook: Optional decode hook for msgspec to use when parsing to allow supporting additional types.

    Returns:
        The parsed configuration.

    Raises:
        :obj:`FileNotFoundError`: If the base configuration file does not exist.
        :obj:`NotImplementedError`: If a file with an unrecognised format is specified.
        :obj:`ValueError`: If the file cannot be parsed to a dictionary (e.g. the top level object is an array).
        :obj:`ImportError`: If a required dependency is not installed.
    """
    path = path if isinstance(path, pathlib.Path) else pathlib.Path(path)

    # The base document comes first; an environment override (if any) is appended after it so that it wins
    # when the documents are deep-merged.
    contents: list[str | bytes] = [path.read_bytes().strip()]

    resolved_env = (env or os.getenv("CONFSPEC_ENV", "")).strip()
    if resolved_env:
        env_file_path = path.parent / helpers.env_file_name(path, resolved_env)
        # A missing environment file is deliberately not an error.
        if env_file_path.is_file():
            contents.append(env_file_path.read_bytes().strip())

    return _loads(contents, path.suffix[1:], cls=cls, strict=strict, dec_hook=dec_hook)