type and test fixes

This commit is contained in:
Nick Sweeting
2026-03-15 20:12:27 -07:00
parent 3889eb4efa
commit bc21d4bfdb
52 changed files with 762 additions and 1317 deletions

View File

@@ -6,7 +6,7 @@ __package__ = 'archivebox.misc'
from io import StringIO
from pathlib import Path
from typing import List, Tuple
from typing import Any, List, Tuple
from archivebox.config import DATA_DIR
from archivebox.misc.util import enforce_types
@@ -48,8 +48,8 @@ def apply_migrations(out_dir: Path = DATA_DIR) -> List[str]:
@enforce_types
def get_admins(out_dir: Path = DATA_DIR) -> List:
def get_admins(out_dir: Path = DATA_DIR) -> List[Any]:
"""Get list of superuser accounts"""
from django.contrib.auth.models import User
return User.objects.filter(is_superuser=True).exclude(username='system')
return list(User.objects.filter(is_superuser=True).exclude(username='system'))

View File

@@ -14,7 +14,7 @@ from pathlib import Path
from datetime import datetime, timezone
from dataclasses import dataclass
from typing import Any, Optional, List, Dict, Union, Iterable, IO, TYPE_CHECKING
from typing import Any, Optional, List, Dict, Union, Iterable, IO, TYPE_CHECKING, cast
if TYPE_CHECKING:
from archivebox.core.models import Snapshot
@@ -397,7 +397,8 @@ def log_list_finished(snapshots):
from archivebox.core.models import Snapshot
print()
print('---------------------------------------------------------------------------------------------------')
print(Snapshot.objects.filter(pk__in=[s.pk for s in snapshots]).to_csv(cols=['timestamp', 'is_archived', 'num_outputs', 'url'], header=True, ljust=16, separator=' | '))
csv_queryset = cast(Any, Snapshot.objects.filter(pk__in=[s.pk for s in snapshots]))
print(csv_queryset.to_csv(cols=['timestamp', 'is_archived', 'num_outputs', 'url'], header=True, ljust=16, separator=' | '))
print('---------------------------------------------------------------------------------------------------')
print()

View File

@@ -13,7 +13,7 @@ django_stubs_ext.monkeypatch()
# monkey patch django timezone to add back utc (it was removed in Django 5.0)
timezone.utc = datetime.timezone.utc
setattr(timezone, 'utc', datetime.timezone.utc)
# monkey patch django-signals-webhooks to change how it shows up in Admin UI
# from signal_webhooks.apps import DjangoSignalWebhooksConfig

View File

@@ -13,12 +13,17 @@ class AccelleratedPaginator(Paginator):
@cached_property
def count(self):
if self.object_list._has_filters(): # type: ignore
has_filters = getattr(self.object_list, '_has_filters', None)
if callable(has_filters) and has_filters():
# fallback to normal count method on filtered queryset
return super().count
else:
# otherwise count total rows in a separate fast query
return self.object_list.model.objects.count()
model = getattr(self.object_list, 'model', None)
if model is None:
return super().count
# otherwise count total rows in a separate fast query
return model.objects.count()
# Alternative approach for PostgreSQL: fallback count takes > 200ms
# from django.db import connection, transaction, OperationalError

View File

@@ -17,7 +17,7 @@ from collections import deque
from pathlib import Path
from rich import box
from rich.console import Group
from rich.console import Group, RenderableType
from rich.layout import Layout
from rich.columns import Columns
from rich.panel import Panel
@@ -48,7 +48,7 @@ class CrawlQueuePanel:
self.max_crawl_workers = 8
self.crawl_id: Optional[str] = None
def __rich__(self) -> Panel:
def __rich__(self) -> RenderableType:
grid = Table.grid(expand=True)
grid.add_column(justify="left", ratio=1)
grid.add_column(justify="center", ratio=1)
@@ -104,7 +104,7 @@ class ProcessLogPanel:
self.compact = compact
self.bg_terminating = bg_terminating
def __rich__(self) -> Panel:
def __rich__(self) -> RenderableType:
completed_line = self._completed_output_line()
if completed_line:
style = "green" if self._completed_ok() else "yellow"

View File

@@ -111,7 +111,7 @@ def _render_markdown_fallback(text: str) -> str:
return _markdown.markdown(
text,
extensions=["extra", "toc", "sane_lists"],
output_format="html5",
output_format="html",
)
except Exception:
pass

View File

@@ -9,13 +9,14 @@ import sys
from json import dump
from pathlib import Path
from typing import Optional, Union, Tuple
from subprocess import _mswindows, PIPE, Popen, CalledProcessError, CompletedProcess, TimeoutExpired
from subprocess import PIPE, Popen, CalledProcessError, CompletedProcess, TimeoutExpired
from atomicwrites import atomic_write as lib_atomic_write
from archivebox.config.common import STORAGE_CONFIG
from archivebox.misc.util import enforce_types, ExtendedEncoder
IS_WINDOWS = os.name == 'nt'
def run(cmd, *args, input=None, capture_output=True, timeout=None, check=False, text=False, start_new_session=True, **kwargs):
"""Patched of subprocess.run to kill forked child subprocesses and fix blocking io making timeout=innefective
@@ -47,13 +48,15 @@ def run(cmd, *args, input=None, capture_output=True, timeout=None, check=False,
stdout, stderr = process.communicate(input, timeout=timeout)
except TimeoutExpired as exc:
process.kill()
if _mswindows:
if IS_WINDOWS:
# Windows accumulates the output in a single blocking
# read() call run on child threads, with the timeout
# being done in a join() on those threads. communicate()
# _after_ kill() is required to collect that and add it
# to the exception.
exc.stdout, exc.stderr = process.communicate()
timed_out_stdout, timed_out_stderr = process.communicate()
exc.stdout = timed_out_stdout.encode() if isinstance(timed_out_stdout, str) else timed_out_stdout
exc.stderr = timed_out_stderr.encode() if isinstance(timed_out_stderr, str) else timed_out_stderr
else:
# POSIX _communicate already populated the output so
# far into the TimeoutExpired exception.
@@ -71,11 +74,12 @@ def run(cmd, *args, input=None, capture_output=True, timeout=None, check=False,
finally:
# force kill any straggler subprocesses that were forked from the main proc
try:
os.killpg(pgid, signal.SIGINT)
if pgid is not None:
os.killpg(pgid, signal.SIGINT)
except Exception:
pass
return CompletedProcess(process.args, retcode, stdout, stderr)
return CompletedProcess(process.args, retcode or 0, stdout, stderr)
@enforce_types

View File

@@ -42,7 +42,7 @@ def convert(ini_str: str) -> str:
"""Convert a string of INI config into its TOML equivalent (warning: strips comments)"""
config = configparser.ConfigParser()
config.optionxform = str # capitalize key names
setattr(config, 'optionxform', str) # capitalize key names
config.read_string(ini_str)
# Initialize an empty dictionary to store the TOML representation
@@ -77,12 +77,12 @@ class JSONSchemaWithLambdas(GenerateJsonSchema):
Usage:
>>> json.dumps(value, encoder=JSONSchemaWithLambdas())
"""
def encode_default(self, default: Any) -> Any:
def encode_default(self, dft: Any) -> Any:
config = self._config
if isinstance(default, Callable):
return '{{lambda ' + inspect.getsource(default).split('=lambda ')[-1].strip()[:-1] + '}}'
if isinstance(dft, Callable):
return '{{lambda ' + inspect.getsource(dft).split('=lambda ')[-1].strip()[:-1] + '}}'
return to_jsonable_python(
default,
dft,
timedelta_mode=config.ser_json_timedelta,
bytes_mode=config.ser_json_bytes,
serialize_unknown=True

View File

@@ -56,9 +56,19 @@ urldecode = lambda s: s and unquote(s)
htmlencode = lambda s: s and escape(s, quote=True)
htmldecode = lambda s: s and unescape(s)
short_ts = lambda ts: str(parse_date(ts).timestamp()).split('.')[0]
ts_to_date_str = lambda ts: ts and parse_date(ts).strftime('%Y-%m-%d %H:%M')
ts_to_iso = lambda ts: ts and parse_date(ts).isoformat()
def short_ts(ts: Any) -> str | None:
parsed = parse_date(ts)
return None if parsed is None else str(parsed.timestamp()).split('.')[0]
def ts_to_date_str(ts: Any) -> str | None:
parsed = parse_date(ts)
return None if parsed is None else parsed.strftime('%Y-%m-%d %H:%M')
def ts_to_iso(ts: Any) -> str | None:
parsed = parse_date(ts)
return None if parsed is None else parsed.isoformat()
COLOR_REGEX = re.compile(r'\[(?P<arg_1>\d+)(;(?P<arg_2>\d+)(;(?P<arg_3>\d+))?)?m')
@@ -175,7 +185,7 @@ def docstring(text: Optional[str]):
@enforce_types
def str_between(string: str, start: str, end: str=None) -> str:
def str_between(string: str, start: str, end: str | None = None) -> str:
"""(<abc>12345</def>, <abc>, </def>) -> 12345"""
content = string.split(start, 1)[-1]
@@ -186,7 +196,7 @@ def str_between(string: str, start: str, end: str=None) -> str:
@enforce_types
def parse_date(date: Any) -> datetime:
def parse_date(date: Any) -> datetime | None:
"""Parse unix timestamps, iso format, and human-readable strings"""
if date is None:
@@ -196,20 +206,24 @@ def parse_date(date: Any) -> datetime:
if date.tzinfo is None:
return date.replace(tzinfo=timezone.utc)
assert date.tzinfo.utcoffset(datetime.now()).seconds == 0, 'Refusing to load a non-UTC date!'
offset = date.utcoffset()
assert offset == datetime.now(timezone.utc).utcoffset(), 'Refusing to load a non-UTC date!'
return date
if isinstance(date, (float, int)):
date = str(date)
if isinstance(date, str):
return dateparser(date, settings={'TIMEZONE': 'UTC'}).astimezone(timezone.utc)
parsed_date = dateparser(date, settings={'TIMEZONE': 'UTC'})
if parsed_date is None:
raise ValueError(f'Tried to parse invalid date string! {date}')
return parsed_date.astimezone(timezone.utc)
raise ValueError('Tried to parse invalid date! {}'.format(date))
@enforce_types
def download_url(url: str, timeout: int=None) -> str:
def download_url(url: str, timeout: int | None = None) -> str:
"""Download the contents of a remote url and return the text"""
from archivebox.config.common import ARCHIVING_CONFIG
@@ -221,7 +235,8 @@ def download_url(url: str, timeout: int=None) -> str:
cookie_jar = http.cookiejar.MozillaCookieJar(ARCHIVING_CONFIG.COOKIES_FILE)
cookie_jar.load(ignore_discard=True, ignore_expires=True)
for cookie in cookie_jar:
session.cookies.set(cookie.name, cookie.value, domain=cookie.domain, path=cookie.path)
if cookie.value is not None:
session.cookies.set(cookie.name, cookie.value, domain=cookie.domain, path=cookie.path)
response = session.get(
url,
@@ -331,47 +346,47 @@ class ExtendedEncoder(pyjson.JSONEncoder):
fields and objects
"""
def default(self, obj):
cls_name = obj.__class__.__name__
def default(self, o):
cls_name = o.__class__.__name__
if hasattr(obj, '_asdict'):
return obj._asdict()
if hasattr(o, '_asdict'):
return o._asdict()
elif isinstance(obj, bytes):
return obj.decode()
elif isinstance(o, bytes):
return o.decode()
elif isinstance(obj, datetime):
return obj.isoformat()
elif isinstance(o, datetime):
return o.isoformat()
elif isinstance(obj, Exception):
return '{}: {}'.format(obj.__class__.__name__, obj)
elif isinstance(o, Exception):
return '{}: {}'.format(o.__class__.__name__, o)
elif isinstance(obj, Path):
return str(obj)
elif isinstance(o, Path):
return str(o)
elif cls_name in ('dict_items', 'dict_keys', 'dict_values'):
return list(obj)
return list(o)
elif isinstance(obj, Callable):
return str(obj)
elif isinstance(o, Callable):
return str(o)
# Try dict/list conversion as fallback
try:
return dict(obj)
return dict(o)
except Exception:
pass
try:
return list(obj)
return list(o)
except Exception:
pass
try:
return str(obj)
return str(o)
except Exception:
pass
return pyjson.JSONEncoder.default(self, obj)
return pyjson.JSONEncoder.default(self, o)
@enforce_types