mirror of
https://github.com/ArchiveBox/ArchiveBox.git
synced 2026-04-06 07:47:53 +10:00
type and test fixes
This commit is contained in:
@@ -6,7 +6,7 @@ __package__ = 'archivebox.misc'
|
||||
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
from archivebox.config import DATA_DIR
|
||||
from archivebox.misc.util import enforce_types
|
||||
@@ -48,8 +48,8 @@ def apply_migrations(out_dir: Path = DATA_DIR) -> List[str]:
|
||||
|
||||
|
||||
@enforce_types
|
||||
def get_admins(out_dir: Path = DATA_DIR) -> List:
|
||||
def get_admins(out_dir: Path = DATA_DIR) -> List[Any]:
|
||||
"""Get list of superuser accounts"""
|
||||
from django.contrib.auth.models import User
|
||||
|
||||
return User.objects.filter(is_superuser=True).exclude(username='system')
|
||||
return list(User.objects.filter(is_superuser=True).exclude(username='system'))
|
||||
|
||||
@@ -14,7 +14,7 @@ from pathlib import Path
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, List, Dict, Union, Iterable, IO, TYPE_CHECKING
|
||||
from typing import Any, Optional, List, Dict, Union, Iterable, IO, TYPE_CHECKING, cast
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from archivebox.core.models import Snapshot
|
||||
@@ -397,7 +397,8 @@ def log_list_finished(snapshots):
|
||||
from archivebox.core.models import Snapshot
|
||||
print()
|
||||
print('---------------------------------------------------------------------------------------------------')
|
||||
print(Snapshot.objects.filter(pk__in=[s.pk for s in snapshots]).to_csv(cols=['timestamp', 'is_archived', 'num_outputs', 'url'], header=True, ljust=16, separator=' | '))
|
||||
csv_queryset = cast(Any, Snapshot.objects.filter(pk__in=[s.pk for s in snapshots]))
|
||||
print(csv_queryset.to_csv(cols=['timestamp', 'is_archived', 'num_outputs', 'url'], header=True, ljust=16, separator=' | '))
|
||||
print('---------------------------------------------------------------------------------------------------')
|
||||
print()
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ django_stubs_ext.monkeypatch()
|
||||
|
||||
|
||||
# monkey patch django timezone to add back utc (it was removed in Django 5.0)
|
||||
timezone.utc = datetime.timezone.utc
|
||||
setattr(timezone, 'utc', datetime.timezone.utc)
|
||||
|
||||
# monkey patch django-signals-webhooks to change how it shows up in Admin UI
|
||||
# from signal_webhooks.apps import DjangoSignalWebhooksConfig
|
||||
|
||||
@@ -13,12 +13,17 @@ class AccelleratedPaginator(Paginator):
|
||||
|
||||
@cached_property
|
||||
def count(self):
|
||||
if self.object_list._has_filters(): # type: ignore
|
||||
has_filters = getattr(self.object_list, '_has_filters', None)
|
||||
if callable(has_filters) and has_filters():
|
||||
# fallback to normal count method on filtered queryset
|
||||
return super().count
|
||||
else:
|
||||
# otherwise count total rows in a separate fast query
|
||||
return self.object_list.model.objects.count()
|
||||
|
||||
model = getattr(self.object_list, 'model', None)
|
||||
if model is None:
|
||||
return super().count
|
||||
|
||||
# otherwise count total rows in a separate fast query
|
||||
return model.objects.count()
|
||||
|
||||
# Alternative approach for PostgreSQL: fallback count takes > 200ms
|
||||
# from django.db import connection, transaction, OperationalError
|
||||
|
||||
@@ -17,7 +17,7 @@ from collections import deque
|
||||
from pathlib import Path
|
||||
|
||||
from rich import box
|
||||
from rich.console import Group
|
||||
from rich.console import Group, RenderableType
|
||||
from rich.layout import Layout
|
||||
from rich.columns import Columns
|
||||
from rich.panel import Panel
|
||||
@@ -48,7 +48,7 @@ class CrawlQueuePanel:
|
||||
self.max_crawl_workers = 8
|
||||
self.crawl_id: Optional[str] = None
|
||||
|
||||
def __rich__(self) -> Panel:
|
||||
def __rich__(self) -> RenderableType:
|
||||
grid = Table.grid(expand=True)
|
||||
grid.add_column(justify="left", ratio=1)
|
||||
grid.add_column(justify="center", ratio=1)
|
||||
@@ -104,7 +104,7 @@ class ProcessLogPanel:
|
||||
self.compact = compact
|
||||
self.bg_terminating = bg_terminating
|
||||
|
||||
def __rich__(self) -> Panel:
|
||||
def __rich__(self) -> RenderableType:
|
||||
completed_line = self._completed_output_line()
|
||||
if completed_line:
|
||||
style = "green" if self._completed_ok() else "yellow"
|
||||
|
||||
@@ -111,7 +111,7 @@ def _render_markdown_fallback(text: str) -> str:
|
||||
return _markdown.markdown(
|
||||
text,
|
||||
extensions=["extra", "toc", "sane_lists"],
|
||||
output_format="html5",
|
||||
output_format="html",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -9,13 +9,14 @@ import sys
|
||||
from json import dump
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Tuple
|
||||
from subprocess import _mswindows, PIPE, Popen, CalledProcessError, CompletedProcess, TimeoutExpired
|
||||
from subprocess import PIPE, Popen, CalledProcessError, CompletedProcess, TimeoutExpired
|
||||
|
||||
from atomicwrites import atomic_write as lib_atomic_write
|
||||
|
||||
from archivebox.config.common import STORAGE_CONFIG
|
||||
from archivebox.misc.util import enforce_types, ExtendedEncoder
|
||||
|
||||
IS_WINDOWS = os.name == 'nt'
|
||||
|
||||
def run(cmd, *args, input=None, capture_output=True, timeout=None, check=False, text=False, start_new_session=True, **kwargs):
|
||||
"""Patched of subprocess.run to kill forked child subprocesses and fix blocking io making timeout=innefective
|
||||
@@ -47,13 +48,15 @@ def run(cmd, *args, input=None, capture_output=True, timeout=None, check=False,
|
||||
stdout, stderr = process.communicate(input, timeout=timeout)
|
||||
except TimeoutExpired as exc:
|
||||
process.kill()
|
||||
if _mswindows:
|
||||
if IS_WINDOWS:
|
||||
# Windows accumulates the output in a single blocking
|
||||
# read() call run on child threads, with the timeout
|
||||
# being done in a join() on those threads. communicate()
|
||||
# _after_ kill() is required to collect that and add it
|
||||
# to the exception.
|
||||
exc.stdout, exc.stderr = process.communicate()
|
||||
timed_out_stdout, timed_out_stderr = process.communicate()
|
||||
exc.stdout = timed_out_stdout.encode() if isinstance(timed_out_stdout, str) else timed_out_stdout
|
||||
exc.stderr = timed_out_stderr.encode() if isinstance(timed_out_stderr, str) else timed_out_stderr
|
||||
else:
|
||||
# POSIX _communicate already populated the output so
|
||||
# far into the TimeoutExpired exception.
|
||||
@@ -71,11 +74,12 @@ def run(cmd, *args, input=None, capture_output=True, timeout=None, check=False,
|
||||
finally:
|
||||
# force kill any straggler subprocesses that were forked from the main proc
|
||||
try:
|
||||
os.killpg(pgid, signal.SIGINT)
|
||||
if pgid is not None:
|
||||
os.killpg(pgid, signal.SIGINT)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return CompletedProcess(process.args, retcode, stdout, stderr)
|
||||
return CompletedProcess(process.args, retcode or 0, stdout, stderr)
|
||||
|
||||
|
||||
@enforce_types
|
||||
|
||||
@@ -42,7 +42,7 @@ def convert(ini_str: str) -> str:
|
||||
"""Convert a string of INI config into its TOML equivalent (warning: strips comments)"""
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.optionxform = str # capitalize key names
|
||||
setattr(config, 'optionxform', str) # capitalize key names
|
||||
config.read_string(ini_str)
|
||||
|
||||
# Initialize an empty dictionary to store the TOML representation
|
||||
@@ -77,12 +77,12 @@ class JSONSchemaWithLambdas(GenerateJsonSchema):
|
||||
Usage:
|
||||
>>> json.dumps(value, encoder=JSONSchemaWithLambdas())
|
||||
"""
|
||||
def encode_default(self, default: Any) -> Any:
|
||||
def encode_default(self, dft: Any) -> Any:
|
||||
config = self._config
|
||||
if isinstance(default, Callable):
|
||||
return '{{lambda ' + inspect.getsource(default).split('=lambda ')[-1].strip()[:-1] + '}}'
|
||||
if isinstance(dft, Callable):
|
||||
return '{{lambda ' + inspect.getsource(dft).split('=lambda ')[-1].strip()[:-1] + '}}'
|
||||
return to_jsonable_python(
|
||||
default,
|
||||
dft,
|
||||
timedelta_mode=config.ser_json_timedelta,
|
||||
bytes_mode=config.ser_json_bytes,
|
||||
serialize_unknown=True
|
||||
|
||||
@@ -56,9 +56,19 @@ urldecode = lambda s: s and unquote(s)
|
||||
htmlencode = lambda s: s and escape(s, quote=True)
|
||||
htmldecode = lambda s: s and unescape(s)
|
||||
|
||||
short_ts = lambda ts: str(parse_date(ts).timestamp()).split('.')[0]
|
||||
ts_to_date_str = lambda ts: ts and parse_date(ts).strftime('%Y-%m-%d %H:%M')
|
||||
ts_to_iso = lambda ts: ts and parse_date(ts).isoformat()
|
||||
def short_ts(ts: Any) -> str | None:
|
||||
parsed = parse_date(ts)
|
||||
return None if parsed is None else str(parsed.timestamp()).split('.')[0]
|
||||
|
||||
|
||||
def ts_to_date_str(ts: Any) -> str | None:
|
||||
parsed = parse_date(ts)
|
||||
return None if parsed is None else parsed.strftime('%Y-%m-%d %H:%M')
|
||||
|
||||
|
||||
def ts_to_iso(ts: Any) -> str | None:
|
||||
parsed = parse_date(ts)
|
||||
return None if parsed is None else parsed.isoformat()
|
||||
|
||||
COLOR_REGEX = re.compile(r'\[(?P<arg_1>\d+)(;(?P<arg_2>\d+)(;(?P<arg_3>\d+))?)?m')
|
||||
|
||||
@@ -175,7 +185,7 @@ def docstring(text: Optional[str]):
|
||||
|
||||
|
||||
@enforce_types
|
||||
def str_between(string: str, start: str, end: str=None) -> str:
|
||||
def str_between(string: str, start: str, end: str | None = None) -> str:
|
||||
"""(<abc>12345</def>, <abc>, </def>) -> 12345"""
|
||||
|
||||
content = string.split(start, 1)[-1]
|
||||
@@ -186,7 +196,7 @@ def str_between(string: str, start: str, end: str=None) -> str:
|
||||
|
||||
|
||||
@enforce_types
|
||||
def parse_date(date: Any) -> datetime:
|
||||
def parse_date(date: Any) -> datetime | None:
|
||||
"""Parse unix timestamps, iso format, and human-readable strings"""
|
||||
|
||||
if date is None:
|
||||
@@ -196,20 +206,24 @@ def parse_date(date: Any) -> datetime:
|
||||
if date.tzinfo is None:
|
||||
return date.replace(tzinfo=timezone.utc)
|
||||
|
||||
assert date.tzinfo.utcoffset(datetime.now()).seconds == 0, 'Refusing to load a non-UTC date!'
|
||||
offset = date.utcoffset()
|
||||
assert offset == datetime.now(timezone.utc).utcoffset(), 'Refusing to load a non-UTC date!'
|
||||
return date
|
||||
|
||||
if isinstance(date, (float, int)):
|
||||
date = str(date)
|
||||
|
||||
if isinstance(date, str):
|
||||
return dateparser(date, settings={'TIMEZONE': 'UTC'}).astimezone(timezone.utc)
|
||||
parsed_date = dateparser(date, settings={'TIMEZONE': 'UTC'})
|
||||
if parsed_date is None:
|
||||
raise ValueError(f'Tried to parse invalid date string! {date}')
|
||||
return parsed_date.astimezone(timezone.utc)
|
||||
|
||||
raise ValueError('Tried to parse invalid date! {}'.format(date))
|
||||
|
||||
|
||||
@enforce_types
|
||||
def download_url(url: str, timeout: int=None) -> str:
|
||||
def download_url(url: str, timeout: int | None = None) -> str:
|
||||
"""Download the contents of a remote url and return the text"""
|
||||
|
||||
from archivebox.config.common import ARCHIVING_CONFIG
|
||||
@@ -221,7 +235,8 @@ def download_url(url: str, timeout: int=None) -> str:
|
||||
cookie_jar = http.cookiejar.MozillaCookieJar(ARCHIVING_CONFIG.COOKIES_FILE)
|
||||
cookie_jar.load(ignore_discard=True, ignore_expires=True)
|
||||
for cookie in cookie_jar:
|
||||
session.cookies.set(cookie.name, cookie.value, domain=cookie.domain, path=cookie.path)
|
||||
if cookie.value is not None:
|
||||
session.cookies.set(cookie.name, cookie.value, domain=cookie.domain, path=cookie.path)
|
||||
|
||||
response = session.get(
|
||||
url,
|
||||
@@ -331,47 +346,47 @@ class ExtendedEncoder(pyjson.JSONEncoder):
|
||||
fields and objects
|
||||
"""
|
||||
|
||||
def default(self, obj):
|
||||
cls_name = obj.__class__.__name__
|
||||
def default(self, o):
|
||||
cls_name = o.__class__.__name__
|
||||
|
||||
if hasattr(obj, '_asdict'):
|
||||
return obj._asdict()
|
||||
if hasattr(o, '_asdict'):
|
||||
return o._asdict()
|
||||
|
||||
elif isinstance(obj, bytes):
|
||||
return obj.decode()
|
||||
elif isinstance(o, bytes):
|
||||
return o.decode()
|
||||
|
||||
elif isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
elif isinstance(o, datetime):
|
||||
return o.isoformat()
|
||||
|
||||
elif isinstance(obj, Exception):
|
||||
return '{}: {}'.format(obj.__class__.__name__, obj)
|
||||
elif isinstance(o, Exception):
|
||||
return '{}: {}'.format(o.__class__.__name__, o)
|
||||
|
||||
elif isinstance(obj, Path):
|
||||
return str(obj)
|
||||
elif isinstance(o, Path):
|
||||
return str(o)
|
||||
|
||||
elif cls_name in ('dict_items', 'dict_keys', 'dict_values'):
|
||||
return list(obj)
|
||||
return list(o)
|
||||
|
||||
elif isinstance(obj, Callable):
|
||||
return str(obj)
|
||||
elif isinstance(o, Callable):
|
||||
return str(o)
|
||||
|
||||
# Try dict/list conversion as fallback
|
||||
try:
|
||||
return dict(obj)
|
||||
return dict(o)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
return list(obj)
|
||||
return list(o)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
return str(obj)
|
||||
return str(o)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return pyjson.JSONEncoder.default(self, obj)
|
||||
return pyjson.JSONEncoder.default(self, o)
|
||||
|
||||
|
||||
@enforce_types
|
||||
|
||||
Reference in New Issue
Block a user