Tighten CLI and admin typing

This commit is contained in:
Nick Sweeting
2026-03-15 19:33:15 -07:00
parent 5381f7584c
commit 49436af869
14 changed files with 317 additions and 97 deletions

View File

@@ -1,5 +1,7 @@
__package__ = 'archivebox.api'
from django.contrib import admin
from django.http import HttpRequest
from signal_webhooks.admin import WebhookAdmin
from signal_webhooks.utils import get_webhook_model
@@ -62,7 +64,11 @@ class CustomWebhookAdmin(WebhookAdmin, BaseModelAdmin):
}),
)
def lookup_allowed(self, lookup: str, value: str, request: HttpRequest | None = None) -> bool:
    """Keep WebhookAdmin's auth-token filtering restriction while matching Django's current lookup_allowed signature."""
    # Never allow admin list filtering on the sensitive auth_token field.
    if lookup.startswith("auth_token"):
        return False
    return admin.ModelAdmin.lookup_allowed(self, lookup, value, request)
def register_admin(admin_site):
def register_admin(admin_site: admin.AdminSite) -> None:
    """Attach the API token admin and the custom webhook admin to the given admin site."""
    admin_site.register(APIToken, APITokenAdmin)
    admin_site.register(get_webhook_model(), CustomWebhookAdmin)

View File

@@ -3,20 +3,32 @@
__package__ = 'archivebox.base_models'
import json
from collections.abc import Mapping
from typing import TypedDict
from django import forms
from django.contrib import admin
from django.utils.html import mark_safe
from django.db import models
from django.forms.renderers import BaseRenderer
from django.http import HttpRequest, QueryDict
from django.utils.safestring import SafeString, mark_safe
from django_object_actions import DjangoObjectActions
class ConfigOption(TypedDict):
    """Metadata for one plugin config key, as surfaced to the admin key-value widget."""
    plugin: str  # name of the plugin that declares this key
    type: str  # value type declared in the plugin's config schema
    default: object  # schema default value (shape depends on the plugin schema)
    description: str  # human-readable help text for the key
class KeyValueWidget(forms.Widget):
"""
A widget that renders JSON dict as editable key-value input fields
with + and - buttons to add/remove rows.
Includes autocomplete for available config keys from the plugin system.
"""
template_name = None # We render manually
template_name = "" # We render manually
class Media:
css = {
@@ -24,12 +36,12 @@ class KeyValueWidget(forms.Widget):
}
js = []
def _get_config_options(self):
def _get_config_options(self) -> dict[str, ConfigOption]:
"""Get available config options from plugins."""
try:
from archivebox.hooks import discover_plugin_configs
plugin_configs = discover_plugin_configs()
options = {}
options: dict[str, ConfigOption] = {}
for plugin_name, schema in plugin_configs.items():
for key, prop in schema.get('properties', {}).items():
options[key] = {
@@ -42,19 +54,28 @@ class KeyValueWidget(forms.Widget):
except Exception:
return {}
def render(self, name, value, attrs=None, renderer=None):
def _parse_value(self, value: object) -> dict[str, object]:
# Parse JSON value to dict
if value is None:
data = {}
elif isinstance(value, str):
return {}
if isinstance(value, str):
try:
data = json.loads(value) if value else {}
parsed = json.loads(value) if value else {}
except json.JSONDecodeError:
data = {}
elif isinstance(value, dict):
data = value
else:
data = {}
return {}
return parsed if isinstance(parsed, dict) else {}
if isinstance(value, Mapping):
return {str(key): item for key, item in value.items()}
return {}
def render(
self,
name: str,
value: object,
attrs: Mapping[str, str] | None = None,
renderer: BaseRenderer | None = None,
) -> SafeString:
data = self._parse_value(value)
widget_id = attrs.get('id', name) if attrs else name
config_options = self._get_config_options()
@@ -185,7 +206,7 @@ class KeyValueWidget(forms.Widget):
'''
return mark_safe(html)
def _render_row(self, widget_id, idx, key, value):
def _render_row(self, widget_id: str, idx: int, key: str, value: str) -> str:
return f'''
<div class="key-value-row" style="display: flex; gap: 8px; margin-bottom: 6px; align-items: center;">
<input type="text" class="kv-key" value="{self._escape(key)}" placeholder="KEY" list="{widget_id}_keys"
@@ -199,25 +220,35 @@ class KeyValueWidget(forms.Widget):
</div>
'''
def _escape(self, s):
def _escape(self, s: object) -> str:
"""Escape HTML special chars in attribute values."""
if not s:
return ''
return str(s).replace('&', '&amp;').replace('<', '&lt;').replace('>', '&gt;').replace('"', '&quot;')
def value_from_datadict(self, data, files, name):
def value_from_datadict(
    self,
    data: QueryDict | Mapping[str, object],
    files: object,
    name: str,
) -> str:
    """Pull this field's raw JSON string out of the submitted form data.

    Missing or non-string values fall back to the empty-object literal '{}'.
    """
    raw = data.get(name, '{}')
    return raw if isinstance(raw, str) else '{}'
class ConfigEditorMixin:
class ConfigEditorMixin(admin.ModelAdmin):
"""
Mixin for admin classes with a config JSON field.
Provides a key-value editor widget with autocomplete for available config keys.
"""
def formfield_for_dbfield(self, db_field, request, **kwargs):
def formfield_for_dbfield(
self,
db_field: models.Field[object, object],
request: HttpRequest,
**kwargs: object,
) -> forms.Field | None:
"""Use KeyValueWidget for the config JSON field."""
if db_field.name == 'config':
kwargs['widget'] = KeyValueWidget()
@@ -228,8 +259,14 @@ class BaseModelAdmin(DjangoObjectActions, admin.ModelAdmin):
list_display = ('id', 'created_at', 'created_by')
readonly_fields = ('id', 'created_at', 'modified_at')
def get_form(self, request, obj=None, **kwargs):
form = super().get_form(request, obj, **kwargs)
def get_form(
self,
request: HttpRequest,
obj: models.Model | None = None,
change: bool = False,
**kwargs: object,
):
form = super().get_form(request, obj, change=change, **kwargs)
if 'created_by' in form.base_fields:
form.base_fields['created_by'].initial = request.user
return form

View File

@@ -48,7 +48,7 @@ class ModelWithUUID(models.Model):
class Meta(TypedModelMeta):
abstract = True
def __str__(self):
def __str__(self) -> str:
return f'[{self.id}] {self.__class__.__name__}'
@property
@@ -57,7 +57,7 @@ class ModelWithUUID(models.Model):
@property
def api_url(self) -> str:
return reverse_lazy('api-1:get_any', args=[self.id])
return str(reverse_lazy('api-1:get_any', args=[self.id]))
@property
def api_docs_url(self) -> str:
@@ -101,7 +101,7 @@ class ModelWithConfig(models.Model):
class ModelWithOutputDir(ModelWithUUID):
class Meta:
class Meta(ModelWithUUID.Meta):
abstract = True
def save(self, *args, **kwargs):

View File

@@ -123,7 +123,9 @@ class ArchiveBoxGroup(click.Group):
@classmethod
def _lazy_load(cls, cmd_name_or_path):
import_path = cls.all_subcommands.get(cmd_name_or_path, cmd_name_or_path)
import_path = cls.all_subcommands.get(cmd_name_or_path)
if import_path is None:
import_path = cmd_name_or_path
modname, funcname = import_path.rsplit('.', 1)
# print(f'LAZY LOADING {import_path}')

View File

@@ -254,9 +254,15 @@ def main(plugins: str, wait: bool, args: tuple):
if all_are_archiveresult_ids:
# Process existing ArchiveResults by ID
from rich import print as rprint
exit_code = 0
for record in records:
archiveresult_id = record.get('id') or record.get('url')
if not isinstance(archiveresult_id, str):
rprint(f'[red]Invalid ArchiveResult input: {record}[/red]', file=sys.stderr)
exit_code = 1
continue
result = process_archiveresult_by_id(archiveresult_id)
if result != 0:
exit_code = result

View File

@@ -5,6 +5,7 @@ __package__ = 'archivebox.cli'
import os
import sys
from pathlib import Path
from typing import Mapping
from rich import print
import rich_click as click
@@ -12,6 +13,19 @@ import rich_click as click
from archivebox.misc.util import docstring, enforce_types
def _normalize_snapshot_record(link_dict: Mapping[str, object]) -> tuple[str, dict[str, object]] | None:
url = link_dict.get('url')
if not isinstance(url, str) or not url:
return None
record: dict[str, object] = {'url': url}
for key in ('timestamp', 'title', 'tags', 'sources'):
value = link_dict.get(key)
if value is not None:
record[key] = value
return url, record
@enforce_types
def init(force: bool=False, quick: bool=False, install: bool=False) -> None:
"""Initialize a new ArchiveBox collection in the current directory"""
@@ -96,7 +110,7 @@ def init(force: bool=False, quick: bool=False, install: bool=False) -> None:
from archivebox.core.models import Snapshot
all_links = Snapshot.objects.none()
pending_links: dict[str, SnapshotDict] = {}
pending_links: dict[str, dict[str, object]] = {}
if existing_index:
all_links = Snapshot.objects.all()
@@ -107,20 +121,26 @@ def init(force: bool=False, quick: bool=False, install: bool=False) -> None:
else:
try:
# Import orphaned links from legacy JSON indexes
orphaned_json_links = {
link_dict['url']: link_dict
for link_dict in parse_json_main_index(DATA_DIR)
if not all_links.filter(url=link_dict['url']).exists()
}
orphaned_json_links: dict[str, dict[str, object]] = {}
for link_dict in parse_json_main_index(DATA_DIR):
normalized = _normalize_snapshot_record(link_dict)
if normalized is None:
continue
url, record = normalized
if not all_links.filter(url=url).exists():
orphaned_json_links[url] = record
if orphaned_json_links:
pending_links.update(orphaned_json_links)
print(f' [yellow]√ Added {len(orphaned_json_links)} orphaned links from existing JSON index...[/yellow]')
orphaned_data_dir_links = {
link_dict['url']: link_dict
for link_dict in parse_json_links_details(DATA_DIR)
if not all_links.filter(url=link_dict['url']).exists()
}
orphaned_data_dir_links: dict[str, dict[str, object]] = {}
for link_dict in parse_json_links_details(DATA_DIR):
normalized = _normalize_snapshot_record(link_dict)
if normalized is None:
continue
url, record = normalized
if not all_links.filter(url=url).exists():
orphaned_data_dir_links[url] = record
if orphaned_data_dir_links:
pending_links.update(orphaned_data_dir_links)
print(f' [yellow]√ Added {len(orphaned_data_dir_links)} orphaned links from existing archive directories.[/yellow]')

View File

@@ -464,11 +464,10 @@ def create_personas(
else:
rprint(f'[dim]Persona already exists: {name}[/dim]', file=sys.stderr)
# Import browser profile if requested
if import_from and source_profile_dir:
cookies_file = Path(persona.path) / 'cookies.txt'
cookies_file = Path(persona.path) / 'cookies.txt'
if import_from in CHROMIUM_BROWSERS:
# Import browser profile if requested
if import_from in CHROMIUM_BROWSERS and source_profile_dir is not None:
persona_chrome_dir = Path(persona.CHROME_USER_DATA_DIR)
# Copy the browser profile

View File

@@ -41,12 +41,14 @@ def remove(filter_patterns: Iterable[str]=(),
from archivebox.cli.archivebox_search import get_snapshots
log_list_started(filter_patterns, filter_type)
pattern_list = list(filter_patterns)
log_list_started(pattern_list or None, filter_type)
timer = TimedProgress(360, prefix=' ')
try:
snapshots = get_snapshots(
snapshots=snapshots,
filter_patterns=list(filter_patterns) if filter_patterns else None,
filter_patterns=pattern_list or None,
filter_type=filter_type,
after=after,
before=before,

View File

@@ -3,42 +3,147 @@
__package__ = 'archivebox.cli'
__command__ = 'archivebox search'
import sys
from pathlib import Path
from typing import Optional, List
from typing import TYPE_CHECKING, Callable
import rich_click as click
from rich import print
from django.db.models import QuerySet
from django.db.models import Q, QuerySet
from archivebox.config import DATA_DIR
from archivebox.misc.logging import stderr
from archivebox.misc.util import enforce_types, docstring
if TYPE_CHECKING:
from archivebox.core.models import Snapshot
# Filter types for URL matching
LINK_FILTERS = {
'exact': lambda pattern: {'url': pattern},
'substring': lambda pattern: {'url__icontains': pattern},
'regex': lambda pattern: {'url__iregex': pattern},
'domain': lambda pattern: {'url__istartswith': f'http://{pattern}'},
'tag': lambda pattern: {'tags__name': pattern},
'timestamp': lambda pattern: {'timestamp': pattern},
# Map each supported --filter-type name to a builder producing the Q object for one pattern.
LINK_FILTERS: dict[str, Callable[[str], Q]] = {
    'exact': lambda pattern: Q(url=pattern),
    'substring': lambda pattern: Q(url__icontains=pattern),
    'regex': lambda pattern: Q(url__iregex=pattern),
    # 'domain' matches a URL beginning with the pattern under any of these schemes
    'domain': lambda pattern: (
        Q(url__istartswith=f'http://{pattern}')
        | Q(url__istartswith=f'https://{pattern}')
        | Q(url__istartswith=f'ftp://{pattern}')
    ),
    'tag': lambda pattern: Q(tags__name=pattern),
    'timestamp': lambda pattern: Q(timestamp=pattern),
}
STATUS_CHOICES = ['indexed', 'archived', 'unarchived']
def _apply_pattern_filters(
    snapshots: QuerySet['Snapshot', 'Snapshot'],
    filter_patterns: list[str],
    filter_type: str,
) -> QuerySet['Snapshot', 'Snapshot']:
    """Narrow *snapshots* to rows matching any of the patterns under *filter_type*.

    Exits with code 2 (after printing an error) when filter_type is unknown.
    """
    if filter_type not in LINK_FILTERS:
        stderr()
        stderr(f'[X] Got invalid pattern for --filter-type={filter_type}', color='red')
        raise SystemExit(2)
    build_q = LINK_FILTERS[filter_type]
    combined = Q()
    for pattern in filter_patterns:
        combined |= build_q(pattern)
    return snapshots.filter(combined)
def _snapshots_to_json(
    snapshots: QuerySet['Snapshot', 'Snapshot'],
    *,
    with_headers: bool,
) -> str:
    """Serialize snapshots to the archivebox.index.json format.

    With with_headers=True the output is the full main-index document
    (info/meta header plus 'links'); otherwise it is a bare list of
    snapshot dicts. Keys are sorted by to_json.
    """
    from datetime import datetime, timezone as tz
    from archivebox.config import VERSION
    from archivebox.config.common import SERVER_CONFIG
    from archivebox.misc.util import to_json

    # Stream rows in chunks to keep memory bounded on large archives.
    snapshot_dicts = [
        snapshot.to_dict(extended=True)
        for snapshot in snapshots.iterator(chunk_size=500)
    ]
    output: dict[str, object] | list[dict[str, object]]
    if with_headers:
        output = {
            'info': 'This is an index of site data archived by ArchiveBox: The self-hosted web archive.',
            'schema': 'archivebox.index.json',
            'copyright_info': SERVER_CONFIG.FOOTER_INFO,
            'meta': {
                'project': 'ArchiveBox',
                'version': VERSION,
                # NOTE(review): the HTML export uses get_COMMIT_HASH() or VERSION here —
                # confirm that plain VERSION is the intended git_sha for JSON output.
                'git_sha': VERSION,
                'website': 'https://ArchiveBox.io',
                'docs': 'https://github.com/ArchiveBox/ArchiveBox/wiki',
                'source': 'https://github.com/ArchiveBox/ArchiveBox',
                'issues': 'https://github.com/ArchiveBox/ArchiveBox/issues',
                'dependencies': {},
            },
            'num_links': len(snapshot_dicts),
            'updated': datetime.now(tz.utc),
            'last_run_cmd': sys.argv,
            'links': snapshot_dicts,
        }
    else:
        output = snapshot_dicts
    return to_json(output, indent=4, sort_keys=True)
def _snapshots_to_csv(
    snapshots: QuerySet['Snapshot', 'Snapshot'],
    *,
    cols: list[str],
    with_headers: bool,
) -> str:
    """Render one comma-separated line per snapshot for the given columns.

    The first line is the column-name header when with_headers is True,
    otherwise an empty first line (matching the legacy export format).
    """
    lines = [','.join(cols) if with_headers else '']
    for snapshot in snapshots.iterator(chunk_size=500):
        lines.append(snapshot.to_csv(cols=cols, separator=','))
    return '\n'.join(lines)
def _snapshots_to_html(
    snapshots: QuerySet['Snapshot', 'Snapshot'],
    *,
    with_headers: bool,
) -> str:
    """Render snapshots to the HTML index.

    Uses the full static_index.html page when with_headers is True,
    otherwise the minimal_index.html fragment.
    """
    from datetime import datetime, timezone as tz
    from django.template.loader import render_to_string
    from archivebox.config import VERSION
    from archivebox.config.common import SERVER_CONFIG
    from archivebox.config.version import get_COMMIT_HASH

    template = 'static_index.html' if with_headers else 'minimal_index.html'
    snapshot_list = list(snapshots.iterator(chunk_size=500))
    # Take a single timestamp so date_updated and time_updated can never
    # disagree when rendering straddles a minute/day boundary.
    now = datetime.now(tz.utc)
    return render_to_string(template, {
        'version': VERSION,
        'git_sha': get_COMMIT_HASH() or VERSION,
        'num_links': str(len(snapshot_list)),
        'date_updated': now.strftime('%Y-%m-%d'),
        'time_updated': now.strftime('%Y-%m-%d %H:%M'),
        'links': snapshot_list,
        'FOOTER_INFO': SERVER_CONFIG.FOOTER_INFO,
    })
def get_snapshots(snapshots: QuerySet['Snapshot', 'Snapshot'] | None=None,
filter_patterns: list[str] | None=None,
filter_type: str='substring',
after: Optional[float]=None,
before: Optional[float]=None,
out_dir: Path=DATA_DIR) -> QuerySet:
after: float | None=None,
before: float | None=None,
out_dir: Path=DATA_DIR) -> QuerySet['Snapshot', 'Snapshot']:
"""Filter and return Snapshots matching the given criteria."""
from archivebox.core.models import Snapshot
if snapshots:
if snapshots is not None:
result = snapshots
else:
result = Snapshot.objects.all()
@@ -48,12 +153,12 @@ def get_snapshots(snapshots: Optional[QuerySet]=None,
if before is not None:
result = result.filter(timestamp__lt=before)
if filter_patterns:
result = Snapshot.objects.filter_by_patterns(filter_patterns, filter_type)
result = _apply_pattern_filters(result, filter_patterns, filter_type)
# Prefetch crawl relationship to avoid N+1 queries when accessing output_dir
result = result.select_related('crawl', 'crawl__created_by')
if not result:
if not result.exists():
stderr('[!] No Snapshots matched your filters:', filter_patterns, f'({filter_type})', color='lightyellow')
return result
@@ -96,15 +201,15 @@ def search(filter_patterns: list[str] | None=None,
# Export to requested format
if json:
output = snapshots.to_json(with_headers=with_headers)
output = _snapshots_to_json(snapshots, with_headers=with_headers)
elif html:
output = snapshots.to_html(with_headers=with_headers)
output = _snapshots_to_html(snapshots, with_headers=with_headers)
elif csv:
output = snapshots.to_csv(cols=csv.split(','), header=with_headers)
output = _snapshots_to_csv(snapshots, cols=csv.split(','), with_headers=with_headers)
else:
from archivebox.misc.logging_util import printable_folders
# Convert to dict for printable_folders
folders = {s.output_dir: s for s in snapshots}
folders: dict[str, Snapshot | None] = {snapshot.output_dir: snapshot for snapshot in snapshots}
output = printable_folders(folders, with_headers)
print(output)

View File

@@ -20,7 +20,6 @@ def status(out_dir: Path=DATA_DIR) -> None:
"""Print out some info and statistics about the archive collection"""
from django.contrib.auth import get_user_model
from archivebox.misc.db import get_admins
from archivebox.core.models import Snapshot
User = get_user_model()
@@ -102,11 +101,12 @@ def status(out_dir: Path=DATA_DIR) -> None:
print()
print('[green]\\[*] Scanning recent archive changes and user logins:[/green]')
print(f'[yellow] {CONSTANTS.LOGS_DIR}/*[/yellow]')
users = get_admins().values_list('username', flat=True)
admin_users = User.objects.filter(is_superuser=True).exclude(username='system')
users = [user.get_username() for user in admin_users]
print(f' UI users {len(users)}: {", ".join(users)}')
last_login = User.objects.order_by('last_login').last()
last_login = admin_users.order_by('last_login').last()
if last_login:
print(f' Last UI login: {last_login.username} @ {str(last_login.last_login)[:16]}')
print(f' Last UI login: {last_login.get_username()} @ {str(last_login.last_login)[:16]}')
last_downloaded = Snapshot.objects.order_by('downloaded_at').last()
if last_downloaded:
print(f' Last changes: {str(last_downloaded.downloaded_at)[:16]}')

View File

@@ -4,13 +4,56 @@ __package__ = 'archivebox.cli'
import os
import time
import rich_click as click
from typing import Iterable
from typing import TYPE_CHECKING, Callable, Iterable
from pathlib import Path
import rich_click as click
from django.core.exceptions import ObjectDoesNotExist
from django.db.models import Q, QuerySet
from archivebox.misc.util import enforce_types, docstring
if TYPE_CHECKING:
from archivebox.core.models import Snapshot
from archivebox.crawls.models import Crawl
# Builders mapping each --filter-type to a Q object for one pattern.
# NOTE(review): duplicated in archivebox_search.py — consider sharing one definition.
LINK_FILTERS: dict[str, Callable[[str], Q]] = {
    'exact': lambda pattern: Q(url=pattern),
    'substring': lambda pattern: Q(url__icontains=pattern),
    'regex': lambda pattern: Q(url__iregex=pattern),
    # 'domain' matches a URL beginning with the pattern under any of these schemes
    'domain': lambda pattern: (
        Q(url__istartswith=f'http://{pattern}')
        | Q(url__istartswith=f'https://{pattern}')
        | Q(url__istartswith=f'ftp://{pattern}')
    ),
    'tag': lambda pattern: Q(tags__name=pattern),
    'timestamp': lambda pattern: Q(timestamp=pattern),
}
def _apply_pattern_filters(
    snapshots: QuerySet['Snapshot', 'Snapshot'],
    filter_patterns: list[str],
    filter_type: str,
) -> QuerySet['Snapshot', 'Snapshot']:
    """Return only the snapshots matching at least one pattern for *filter_type*."""
    try:
        build_q = LINK_FILTERS[filter_type]
    except KeyError:
        # Unknown --filter-type: abort with exit code 2.
        raise SystemExit(2) from None
    combined = Q()
    for pattern in filter_patterns:
        combined |= build_q(pattern)
    return snapshots.filter(combined)
def _get_snapshot_crawl(snapshot: 'Snapshot') -> 'Crawl | None':
try:
return snapshot.crawl
except ObjectDoesNotExist:
return None
@enforce_types
def update(filter_patterns: Iterable[str] = (),
@@ -84,7 +127,7 @@ def update(filter_patterns: Iterable[str] = (),
resume = None
def drain_old_archive_dirs(resume_from: str = None, batch_size: int = 100) -> dict:
def drain_old_archive_dirs(resume_from: str | None = None, batch_size: int = 100) -> dict[str, int]:
"""
Drain old archive/ directories (0.8.x → 0.9.x migration).
@@ -153,21 +196,17 @@ def drain_old_archive_dirs(resume_from: str = None, batch_size: int = 100) -> di
continue
# Ensure snapshot has a valid crawl (migration 0024 may have failed)
from archivebox.crawls.models import Crawl
has_valid_crawl = False
if snapshot.crawl_id:
# Check if the crawl actually exists
has_valid_crawl = Crawl.objects.filter(id=snapshot.crawl_id).exists()
has_valid_crawl = _get_snapshot_crawl(snapshot) is not None
if not has_valid_crawl:
# Create a new crawl (created_by will default to system user)
from archivebox.crawls.models import Crawl
crawl = Crawl.objects.create(urls=snapshot.url)
# Use queryset update to avoid triggering save() hooks
from archivebox.core.models import Snapshot as SnapshotModel
SnapshotModel.objects.filter(pk=snapshot.pk).update(crawl=crawl)
# Refresh the instance
snapshot.crawl = crawl
snapshot.crawl_id = crawl.id
print(f"[DEBUG Phase1] Created missing crawl for snapshot {str(snapshot.id)[:8]}")
# Check if needs migration (0.8.x → 0.9.x)
@@ -221,7 +260,7 @@ def drain_old_archive_dirs(resume_from: str = None, batch_size: int = 100) -> di
return stats
def process_all_db_snapshots(batch_size: int = 100) -> dict:
def process_all_db_snapshots(batch_size: int = 100) -> dict[str, int]:
"""
O(n) scan over entire DB from most recent to least recent.
@@ -246,7 +285,7 @@ def process_all_db_snapshots(batch_size: int = 100) -> dict:
stats['processed'] += 1
# Skip snapshots with missing crawl references (orphaned by migration errors)
if not snapshot.crawl_id:
if _get_snapshot_crawl(snapshot) is None:
continue
try:
@@ -303,7 +342,7 @@ def process_filtered_snapshots(
before: float | None,
after: float | None,
batch_size: int
) -> dict:
) -> dict[str, int]:
"""Process snapshots matching filters (DB query only)."""
from archivebox.core.models import Snapshot
from django.db import transaction
@@ -315,7 +354,7 @@ def process_filtered_snapshots(
snapshots = Snapshot.objects.all()
if filter_patterns:
snapshots = Snapshot.objects.filter_by_patterns(list(filter_patterns), filter_type)
snapshots = _apply_pattern_filters(snapshots, list(filter_patterns), filter_type)
if before:
snapshots = snapshots.filter(bookmarked_at__lt=datetime.fromtimestamp(before))
@@ -329,7 +368,7 @@ def process_filtered_snapshots(
stats['processed'] += 1
# Skip snapshots with missing crawl references
if not snapshot.crawl_id:
if _get_snapshot_crawl(snapshot) is None:
continue
try:

View File

@@ -15,6 +15,11 @@ from archivebox.config.constants import CONSTANTS
from archivebox.misc.logging import stderr
class CaseConfigParser(ConfigParser):
    """ConfigParser that preserves option-name case instead of lowercasing keys."""

    def optionxform(self, optionstr: str) -> str:
        # Identity transform: keep keys exactly as written in the config file.
        return optionstr
def get_real_name(key: str) -> str:
"""get the up-to-date canonical name for a given old alias or current key"""
# Config aliases are no longer used with the simplified config system
@@ -59,6 +64,8 @@ def load_config_val(key: str,
return default(config)
return default
assert isinstance(val, str)
# calculate value based on expected type
BOOL_TRUEIES = ('true', 'yes', '1')
BOOL_FALSEIES = ('false', 'no', '0')
@@ -95,8 +102,7 @@ def load_config_file() -> Optional[benedict]:
config_path = CONSTANTS.CONFIG_FILE
if os.access(config_path, os.R_OK):
config_file = ConfigParser()
config_file.optionxform = str
config_file = CaseConfigParser()
config_file.read(config_path)
# flatten into one namespace
config_file_vars = benedict({
@@ -108,8 +114,6 @@ def load_config_file() -> Optional[benedict]:
# print(config_file_vars)
return config_file_vars
return None
class PluginConfigSection:
"""Pseudo-section for all plugin config keys written to [PLUGINS] section in ArchiveBox.conf"""
toml_section_header = "PLUGINS"
@@ -181,8 +185,7 @@ def write_config_file(config: Dict[str, str]) -> benedict:
if not os.access(config_path, os.F_OK):
atomic_write(config_path, CONFIG_HEADER)
config_file = ConfigParser()
config_file.optionxform = str
config_file = CaseConfigParser()
config_file.read(config_path)
with open(config_path, 'r', encoding='utf-8') as old:
@@ -288,4 +291,3 @@ def load_all_config():
flat_config.update(dict(config_section))
return flat_config

View File

@@ -14,8 +14,12 @@ from pathlib import Path
from typing import Any, Dict, Optional, Type, Tuple
from configparser import ConfigParser
from pydantic import ConfigDict
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
class CaseConfigParser(ConfigParser):
    """Case-preserving ConfigParser (keys are not normalized to lowercase)."""

    def optionxform(self, optionstr: str) -> str:
        # Return the key unchanged so env-style UPPER_CASE names survive round-trips.
        return optionstr
class IniConfigSettingsSource(PydanticBaseSettingsSource):
@@ -42,8 +46,7 @@ class IniConfigSettingsSource(PydanticBaseSettingsSource):
if not config_path.exists():
return {}
parser = ConfigParser()
parser.optionxform = lambda x: x # preserve case
parser = CaseConfigParser()
parser.read(config_path)
# Flatten all sections into single namespace (ignore section headers)
@@ -66,7 +69,7 @@ class BaseConfigSet(BaseSettings):
USE_COLOR: bool = Field(default=True)
"""
model_config = ConfigDict(
model_config = SettingsConfigDict(
env_prefix="",
extra="ignore",
validate_default=True,
@@ -98,8 +101,7 @@ class BaseConfigSet(BaseSettings):
if not config_path.exists():
return {}
parser = ConfigParser()
parser.optionxform = lambda x: x # preserve case
parser = CaseConfigParser()
parser.read(config_path)
# Flatten all sections into single namespace