add stricter locking around state machine models

This commit is contained in:
Nick Sweeting
2026-03-15 19:21:41 -07:00
parent 311e4340ec
commit f932054915
15 changed files with 584 additions and 284 deletions

5
.gitignore vendored
View File

@@ -38,6 +38,7 @@ lib/
tmp/
data/
data*/
archive/
output/
logs/
index.sqlite3
@@ -46,6 +47,10 @@ queue.sqlite3
data.*
.archivebox_id
ArchiveBox.conf
*.stdout
*.stderr
*.log
.tmp/
# vim
*.sw?

View File

@@ -14,16 +14,21 @@ __package__ = 'archivebox'
import os
import sys
from pathlib import Path
from typing import Protocol, cast
# Import uuid_compat early to monkey-patch uuid.uuid7 before Django loads migrations
# This fixes migrations generated on Python 3.14+ that reference uuid.uuid7 directly
from archivebox import uuid_compat # noqa: F401
from abx_plugins import get_plugins_dir
class _ReconfigurableStream(Protocol):
def reconfigure(self, *, line_buffering: bool) -> object: ...
# Force unbuffered output for real-time logs
if hasattr(sys.stdout, 'reconfigure'):
sys.stdout.reconfigure(line_buffering=True)
sys.stderr.reconfigure(line_buffering=True)
cast(_ReconfigurableStream, sys.stdout).reconfigure(line_buffering=True)
cast(_ReconfigurableStream, sys.stderr).reconfigure(line_buffering=True)
os.environ['PYTHONUNBUFFERED'] = '1'
ASCII_LOGO = """

View File

@@ -1,18 +1,18 @@
__package__ = 'archivebox.api'
from typing import Optional, cast
from typing import Optional
from datetime import timedelta
from django.http import HttpRequest
from django.utils import timezone
from django.http import HttpRequest
from django.contrib.auth import authenticate
from django.contrib.auth.models import AbstractBaseUser
from django.contrib.auth.models import User
from ninja.security import HttpBearer, APIKeyQuery, APIKeyHeader, HttpBasicAuth
from ninja.errors import HttpError
def get_or_create_api_token(user):
def get_or_create_api_token(user: User | None):
from archivebox.api.models import APIToken
if user and user.is_superuser:
@@ -23,122 +23,106 @@ def get_or_create_api_token(user):
else:
# does not exist, create a new one
api_token = APIToken.objects.create(created_by_id=user.pk, expires=timezone.now() + timedelta(days=30))
if api_token is None:
return None
assert api_token.is_valid(), f"API token is not valid {api_token}"
return api_token
return None
def auth_using_token(token, request: Optional[HttpRequest]=None) -> Optional[AbstractBaseUser]:
def auth_using_token(token: str | None, request: HttpRequest | None = None) -> User | None:
"""Given an API token string, check if a corresponding non-expired APIToken exists, and return its user"""
from archivebox.api.models import APIToken # lazy import model to avoid loading it at urls.py import time
user = None
user: User | None = None
submitted_empty_form = str(token).strip() in ('string', '', 'None', 'null')
if not submitted_empty_form:
try:
token = APIToken.objects.get(token=token)
if token.is_valid():
user = token.created_by
request._api_token = token
api_token = APIToken.objects.get(token=token)
if api_token.is_valid() and isinstance(api_token.created_by, User):
user = api_token.created_by
if request is not None:
setattr(request, '_api_token', api_token)
except APIToken.DoesNotExist:
pass
if not user:
# print('[❌] Failed to authenticate API user using API Key:', request)
return None
return cast(AbstractBaseUser, user)
return user
def auth_using_password(username, password, request: Optional[HttpRequest]=None) -> Optional[AbstractBaseUser]:
def auth_using_password(username: str | None, password: str | None, request: HttpRequest | None = None) -> User | None:
"""Given a username and password, check if they are valid and return the corresponding user"""
user = None
user: User | None = None
submitted_empty_form = (username, password) in (('string', 'string'), ('', ''), (None, None))
if not submitted_empty_form:
user = authenticate(
authenticated_user = authenticate(
username=username,
password=password,
)
if not user:
# print('[❌] Failed to authenticate API user using API Key:', request)
user = None
return cast(AbstractBaseUser | None, user)
if isinstance(authenticated_user, User):
user = authenticated_user
return user
### Base Auth Types
class APITokenAuthCheck:
"""The base class for authentication methods that use an api.models.APIToken"""
def authenticate(self, request: HttpRequest, key: Optional[str]=None) -> Optional[AbstractBaseUser]:
request.user = auth_using_token(
token=key,
request=request,
)
if request.user and request.user.pk:
# Don't set cookie/persist login outside this request, user may be accessing the API from another domain (CSRF/CORS):
# login(request, request.user, backend='django.contrib.auth.backends.ModelBackend')
request._api_auth_method = self.__class__.__name__
if not request.user.is_superuser:
raise HttpError(403, 'Valid API token but User does not have permission (make sure user.is_superuser=True)')
return request.user
class UserPassAuthCheck:
"""The base class for authentication methods that use a username & password"""
def authenticate(self, request: HttpRequest, username: Optional[str]=None, password: Optional[str]=None) -> Optional[AbstractBaseUser]:
request.user = auth_using_password(
username=username,
password=password,
request=request,
)
if request.user and request.user.pk:
# Don't set cookie/persist login outside this request, user may be accessing the API from another domain (CSRF/CORS):
# login(request, request.user, backend='django.contrib.auth.backends.ModelBackend')
request._api_auth_method = self.__class__.__name__
if not request.user.is_superuser:
raise HttpError(403, 'Valid API token but User does not have permission (make sure user.is_superuser=True)')
return request.user
def _require_superuser(user: User | None, request: HttpRequest, auth_method: str) -> User | None:
if user and user.pk:
request.user = user
setattr(request, '_api_auth_method', auth_method)
if not user.is_superuser:
raise HttpError(403, 'Valid credentials but User does not have permission (make sure user.is_superuser=True)')
return user
### Django-Ninja-Provided Auth Methods
class HeaderTokenAuth(APITokenAuthCheck, APIKeyHeader):
class HeaderTokenAuth(APIKeyHeader):
"""Allow authenticating by passing X-API-Key=xyz as a request header"""
param_name = "X-ArchiveBox-API-Key"
class BearerTokenAuth(APITokenAuthCheck, HttpBearer):
"""Allow authenticating by passing Bearer=xyz as a request header"""
pass
def authenticate(self, request: HttpRequest, key: Optional[str]) -> User | None:
return _require_superuser(auth_using_token(token=key, request=request), request, self.__class__.__name__)
class QueryParamTokenAuth(APITokenAuthCheck, APIKeyQuery):
class BearerTokenAuth(HttpBearer):
"""Allow authenticating by passing Bearer=xyz as a request header"""
def authenticate(self, request: HttpRequest, token: str) -> User | None:
return _require_superuser(auth_using_token(token=token, request=request), request, self.__class__.__name__)
class QueryParamTokenAuth(APIKeyQuery):
"""Allow authenticating by passing api_key=xyz as a GET/POST query parameter"""
param_name = "api_key"
class UsernameAndPasswordAuth(UserPassAuthCheck, HttpBasicAuth):
def authenticate(self, request: HttpRequest, key: Optional[str]) -> User | None:
return _require_superuser(auth_using_token(token=key, request=request), request, self.__class__.__name__)
class UsernameAndPasswordAuth(HttpBasicAuth):
"""Allow authenticating by passing username & password via HTTP Basic Authentication (not recommended)"""
pass
def authenticate(self, request: HttpRequest, username: str, password: str) -> User | None:
return _require_superuser(
auth_using_password(username=username, password=password, request=request),
request,
self.__class__.__name__,
)
class DjangoSessionAuth:
"""Allow authenticating with existing Django session cookies (same-origin only)."""
def __call__(self, request: HttpRequest) -> Optional[AbstractBaseUser]:
def __call__(self, request: HttpRequest) -> User | None:
return self.authenticate(request)
def authenticate(self, request: HttpRequest, **kwargs) -> Optional[AbstractBaseUser]:
def authenticate(self, request: HttpRequest, **kwargs) -> User | None:
user = getattr(request, 'user', None)
if user and user.is_authenticated:
request._api_auth_method = self.__class__.__name__
if isinstance(user, User) and user.is_authenticated:
setattr(request, '_api_auth_method', self.__class__.__name__)
if not user.is_superuser:
raise HttpError(403, 'Valid session but User does not have permission (make sure user.is_superuser=True)')
return cast(AbstractBaseUser, user)
return user
return None
### Enabled Auth Methods

View File

@@ -7,6 +7,7 @@ from contextlib import redirect_stdout, redirect_stderr
from django.http import HttpRequest, HttpResponse
from django.core.exceptions import ObjectDoesNotExist, EmptyResultSet, PermissionDenied
from django.contrib.auth.models import User
from ninja import NinjaAPI, Swagger
@@ -16,6 +17,7 @@ from archivebox.config import VERSION
from archivebox.config.version import get_COMMIT_HASH
from archivebox.api.auth import API_AUTH_METHODS
from archivebox.api.models import APIToken
COMMIT_HASH = get_COMMIT_HASH() or 'unknown'
@@ -51,8 +53,8 @@ class NinjaAPIWithIOCapture(NinjaAPI):
with redirect_stderr(stderr):
with redirect_stdout(stdout):
request.stdout = stdout
request.stderr = stderr
setattr(request, 'stdout', stdout)
setattr(request, 'stderr', stderr)
response = super().create_temporal_response(request)
@@ -60,19 +62,20 @@ class NinjaAPIWithIOCapture(NinjaAPI):
response['Cache-Control'] = 'no-store'
# Add debug stdout and stderr headers to response
response['X-ArchiveBox-Stdout'] = str(request.stdout)[200:]
response['X-ArchiveBox-Stderr'] = str(request.stderr)[200:]
response['X-ArchiveBox-Stdout'] = stdout.getvalue().replace('\n', '\\n')[:200]
response['X-ArchiveBox-Stderr'] = stderr.getvalue().replace('\n', '\\n')[:200]
# response['X-ArchiveBox-View'] = self.get_openapi_operation_id(request) or 'Unknown'
# Add Auth Headers to response
api_token = getattr(request, '_api_token', None)
api_token_attr = getattr(request, '_api_token', None)
api_token = api_token_attr if isinstance(api_token_attr, APIToken) else None
token_expiry = api_token.expires.isoformat() if api_token and api_token.expires else 'Never'
response['X-ArchiveBox-Auth-Method'] = getattr(request, '_api_auth_method', None) or 'None'
response['X-ArchiveBox-Auth-Method'] = str(getattr(request, '_api_auth_method', 'None'))
response['X-ArchiveBox-Auth-Expires'] = token_expiry
response['X-ArchiveBox-Auth-Token-Id'] = str(api_token.id) if api_token else 'None'
response['X-ArchiveBox-Auth-User-Id'] = request.user.pk if request.user.pk else 'None'
response['X-ArchiveBox-Auth-User-Username'] = request.user.username if request.user.pk else 'None'
response['X-ArchiveBox-Auth-User-Id'] = str(request.user.pk) if getattr(request.user, 'pk', None) else 'None'
response['X-ArchiveBox-Auth-User-Username'] = request.user.username if isinstance(request.user, User) else 'None'
# import ipdb; ipdb.set_trace()
# print('RESPONDING NOW', response)

View File

@@ -1,6 +1,7 @@
__package__ = 'archivebox.api'
from typing import Optional
from django.http import HttpRequest
from ninja import Router, Schema
@@ -17,7 +18,7 @@ class PasswordAuthSchema(Schema):
@router.post("/get_api_token", auth=None, summary='Generate an API token for a given username & password (or currently logged-in user)') # auth=None because they are not authed yet
def get_api_token(request, auth_data: PasswordAuthSchema):
def get_api_token(request: HttpRequest, auth_data: PasswordAuthSchema):
user = auth_using_password(
username=auth_data.username,
password=auth_data.password,
@@ -45,7 +46,7 @@ class TokenAuthSchema(Schema):
@router.post("/check_api_token", auth=None, summary='Validate an API token to make sure its valid and non-expired') # auth=None because they are not authed yet
def check_api_token(request, token_data: TokenAuthSchema):
def check_api_token(request: HttpRequest, token_data: TokenAuthSchema):
user = auth_using_token(
token=token_data.token,
request=request,

View File

@@ -1,9 +1,12 @@
__package__ = 'archivebox.api'
import json
from io import StringIO
from typing import List, Dict, Any, Optional
from enum import Enum
from django.http import HttpRequest
from ninja import Router, Schema
from archivebox.misc.util import ansi_to_html
@@ -60,16 +63,13 @@ class AddCommandSchema(Schema):
index_only: bool = False
class UpdateCommandSchema(Schema):
resume: Optional[float] = 0
only_new: bool = ARCHIVING_CONFIG.ONLY_NEW
index_only: bool = False
overwrite: bool = False
resume: Optional[str] = None
after: Optional[float] = 0
before: Optional[float] = 999999999999999
status: Optional[StatusChoices] = StatusChoices.unarchived
filter_type: Optional[str] = FilterTypeChoices.substring
filter_patterns: Optional[List[str]] = ['https://example.com']
plugins: Optional[str] = ""
batch_size: int = 100
continuous: bool = False
class ScheduleCommandSchema(Schema):
import_path: Optional[str] = None
@@ -109,7 +109,7 @@ class RemoveCommandSchema(Schema):
@router.post("/add", response=CLICommandResponseSchema, summary='archivebox add [args] [urls]')
def cli_add(request, args: AddCommandSchema):
def cli_add(request: HttpRequest, args: AddCommandSchema):
from archivebox.cli.archivebox_add import add
result = add(
@@ -132,44 +132,45 @@ def cli_add(request, args: AddCommandSchema):
"snapshot_ids": snapshot_ids,
"queued_urls": args.urls,
}
stdout = getattr(request, 'stdout', None)
stderr = getattr(request, 'stderr', None)
return {
"success": True,
"errors": [],
"result": result_payload,
"result_format": "json",
"stdout": ansi_to_html(request.stdout.getvalue().strip()),
"stderr": ansi_to_html(request.stderr.getvalue().strip()),
"stdout": ansi_to_html(stdout.getvalue().strip()) if isinstance(stdout, StringIO) else '',
"stderr": ansi_to_html(stderr.getvalue().strip()) if isinstance(stderr, StringIO) else '',
}
@router.post("/update", response=CLICommandResponseSchema, summary='archivebox update [args] [filter_patterns]')
def cli_update(request, args: UpdateCommandSchema):
def cli_update(request: HttpRequest, args: UpdateCommandSchema):
from archivebox.cli.archivebox_update import update
result = update(
resume=args.resume,
only_new=args.only_new,
index_only=args.index_only,
overwrite=args.overwrite,
before=args.before,
filter_patterns=args.filter_patterns or [],
filter_type=args.filter_type or FilterTypeChoices.substring,
after=args.after,
status=args.status,
filter_type=args.filter_type,
filter_patterns=args.filter_patterns,
plugins=args.plugins,
before=args.before,
resume=args.resume,
batch_size=args.batch_size,
continuous=args.continuous,
)
stdout = getattr(request, 'stdout', None)
stderr = getattr(request, 'stderr', None)
return {
"success": True,
"errors": [],
"result": result,
"stdout": ansi_to_html(request.stdout.getvalue().strip()),
"stderr": ansi_to_html(request.stderr.getvalue().strip()),
"stdout": ansi_to_html(stdout.getvalue().strip()) if isinstance(stdout, StringIO) else '',
"stderr": ansi_to_html(stderr.getvalue().strip()) if isinstance(stderr, StringIO) else '',
}
@router.post("/schedule", response=CLICommandResponseSchema, summary='archivebox schedule [args] [import_path]')
def cli_schedule(request, args: ScheduleCommandSchema):
def cli_schedule(request: HttpRequest, args: ScheduleCommandSchema):
from archivebox.cli.archivebox_schedule import schedule
result = schedule(
@@ -187,19 +188,21 @@ def cli_schedule(request, args: ScheduleCommandSchema):
update=args.update,
)
stdout = getattr(request, 'stdout', None)
stderr = getattr(request, 'stderr', None)
return {
"success": True,
"errors": [],
"result": result,
"result_format": "json",
"stdout": ansi_to_html(request.stdout.getvalue().strip()),
"stderr": ansi_to_html(request.stderr.getvalue().strip()),
"stdout": ansi_to_html(stdout.getvalue().strip()) if isinstance(stdout, StringIO) else '',
"stderr": ansi_to_html(stderr.getvalue().strip()) if isinstance(stderr, StringIO) else '',
}
@router.post("/search", response=CLICommandResponseSchema, summary='archivebox search [args] [filter_patterns]')
def cli_search(request, args: ListCommandSchema):
def cli_search(request: HttpRequest, args: ListCommandSchema):
from archivebox.cli.archivebox_search import search
result = search(
@@ -224,25 +227,28 @@ def cli_search(request, args: ListCommandSchema):
elif args.as_csv:
result_format = "csv"
stdout = getattr(request, 'stdout', None)
stderr = getattr(request, 'stderr', None)
return {
"success": True,
"errors": [],
"result": result,
"result_format": result_format,
"stdout": ansi_to_html(request.stdout.getvalue().strip()),
"stderr": ansi_to_html(request.stderr.getvalue().strip()),
"stdout": ansi_to_html(stdout.getvalue().strip()) if isinstance(stdout, StringIO) else '',
"stderr": ansi_to_html(stderr.getvalue().strip()) if isinstance(stderr, StringIO) else '',
}
@router.post("/remove", response=CLICommandResponseSchema, summary='archivebox remove [args] [filter_patterns]')
def cli_remove(request, args: RemoveCommandSchema):
def cli_remove(request: HttpRequest, args: RemoveCommandSchema):
from archivebox.cli.archivebox_remove import remove
from archivebox.cli.archivebox_search import get_snapshots
from archivebox.core.models import Snapshot
filter_patterns = args.filter_patterns or []
snapshots_to_remove = get_snapshots(
filter_patterns=args.filter_patterns,
filter_patterns=filter_patterns,
filter_type=args.filter_type,
after=args.after,
before=args.before,
@@ -256,7 +262,7 @@ def cli_remove(request, args: RemoveCommandSchema):
before=args.before,
after=args.after,
filter_type=args.filter_type,
filter_patterns=args.filter_patterns,
filter_patterns=filter_patterns,
)
result = {
@@ -264,12 +270,14 @@ def cli_remove(request, args: RemoveCommandSchema):
"removed_snapshot_ids": removed_snapshot_ids,
"remaining_snapshots": Snapshot.objects.count(),
}
stdout = getattr(request, 'stdout', None)
stderr = getattr(request, 'stderr', None)
return {
"success": True,
"errors": [],
"result": result,
"result_format": "json",
"stdout": ansi_to_html(request.stdout.getvalue().strip()),
"stderr": ansi_to_html(request.stderr.getvalue().strip()),
"stdout": ansi_to_html(stdout.getvalue().strip()) if isinstance(stdout, StringIO) else '',
"stderr": ansi_to_html(stderr.getvalue().strip()) if isinstance(stderr, StringIO) else '',
}

View File

@@ -2,16 +2,18 @@ __package__ = 'archivebox.api'
import math
from uuid import UUID
from typing import List, Optional, Union, Any
from typing import List, Optional, Union, Any, Annotated
from datetime import datetime
from django.db.models import Q
from django.db.models import Model, Q
from django.http import HttpRequest
from django.core.exceptions import ValidationError
from django.contrib.auth import get_user_model
from django.contrib.auth.models import User
from django.shortcuts import redirect
from django.utils import timezone
from ninja import Router, Schema, FilterSchema, Field, Query
from ninja import Router, Schema, FilterLookup, FilterSchema, Query
from ninja.pagination import paginate, PaginationBase
from ninja.errors import HttpError
@@ -24,12 +26,12 @@ router = Router(tags=['Core Models'])
class CustomPagination(PaginationBase):
class Input(Schema):
class Input(PaginationBase.Input):
limit: int = 200
offset: int = 0
page: int = 0
class Output(Schema):
class Output(PaginationBase.Output):
total_items: int
total_pages: int
page: int
@@ -38,7 +40,7 @@ class CustomPagination(PaginationBase):
num_items: int
items: List[Any]
def paginate_queryset(self, queryset, pagination: Input, **params):
def paginate_queryset(self, queryset, pagination: Input, request: HttpRequest, **params):
limit = min(pagination.limit, 500)
offset = pagination.offset or (pagination.page * limit)
total = queryset.count()
@@ -115,33 +117,33 @@ class ArchiveResultSchema(MinimalArchiveResultSchema):
class ArchiveResultFilterSchema(FilterSchema):
id: Optional[str] = Field(None, q=['id__startswith', 'snapshot__id__startswith', 'snapshot__timestamp__startswith'])
search: Optional[str] = Field(None, q=['snapshot__url__icontains', 'snapshot__title__icontains', 'snapshot__tags__name__icontains', 'plugin', 'output_str__icontains', 'id__startswith', 'snapshot__id__startswith', 'snapshot__timestamp__startswith'])
snapshot_id: Optional[str] = Field(None, q=['snapshot__id__startswith', 'snapshot__timestamp__startswith'])
snapshot_url: Optional[str] = Field(None, q='snapshot__url__icontains')
snapshot_tag: Optional[str] = Field(None, q='snapshot__tags__name__icontains')
status: Optional[str] = Field(None, q='status')
output_str: Optional[str] = Field(None, q='output_str__icontains')
plugin: Optional[str] = Field(None, q='plugin__icontains')
hook_name: Optional[str] = Field(None, q='hook_name__icontains')
process_id: Optional[str] = Field(None, q='process__id__startswith')
cmd: Optional[str] = Field(None, q='cmd__0__icontains')
pwd: Optional[str] = Field(None, q='pwd__icontains')
cmd_version: Optional[str] = Field(None, q='cmd_version')
created_at: Optional[datetime] = Field(None, q='created_at')
created_at__gte: Optional[datetime] = Field(None, q='created_at__gte')
created_at__lt: Optional[datetime] = Field(None, q='created_at__lt')
id: Annotated[Optional[str], FilterLookup(['id__startswith', 'snapshot__id__startswith', 'snapshot__timestamp__startswith'])] = None
search: Annotated[Optional[str], FilterLookup(['snapshot__url__icontains', 'snapshot__title__icontains', 'snapshot__tags__name__icontains', 'plugin', 'output_str__icontains', 'id__startswith', 'snapshot__id__startswith', 'snapshot__timestamp__startswith'])] = None
snapshot_id: Annotated[Optional[str], FilterLookup(['snapshot__id__startswith', 'snapshot__timestamp__startswith'])] = None
snapshot_url: Annotated[Optional[str], FilterLookup('snapshot__url__icontains')] = None
snapshot_tag: Annotated[Optional[str], FilterLookup('snapshot__tags__name__icontains')] = None
status: Annotated[Optional[str], FilterLookup('status')] = None
output_str: Annotated[Optional[str], FilterLookup('output_str__icontains')] = None
plugin: Annotated[Optional[str], FilterLookup('plugin__icontains')] = None
hook_name: Annotated[Optional[str], FilterLookup('hook_name__icontains')] = None
process_id: Annotated[Optional[str], FilterLookup('process__id__startswith')] = None
cmd: Annotated[Optional[str], FilterLookup('cmd__0__icontains')] = None
pwd: Annotated[Optional[str], FilterLookup('pwd__icontains')] = None
cmd_version: Annotated[Optional[str], FilterLookup('cmd_version')] = None
created_at: Annotated[Optional[datetime], FilterLookup('created_at')] = None
created_at__gte: Annotated[Optional[datetime], FilterLookup('created_at__gte')] = None
created_at__lt: Annotated[Optional[datetime], FilterLookup('created_at__lt')] = None
@router.get("/archiveresults", response=List[ArchiveResultSchema], url_name="get_archiveresult")
@paginate(CustomPagination)
def get_archiveresults(request, filters: ArchiveResultFilterSchema = Query(...)):
def get_archiveresults(request: HttpRequest, filters: Query[ArchiveResultFilterSchema]):
"""List all ArchiveResult entries matching these filters."""
return filters.filter(ArchiveResult.objects.all()).distinct()
@router.get("/archiveresult/{archiveresult_id}", response=ArchiveResultSchema, url_name="get_archiveresult")
def get_archiveresult(request, archiveresult_id: str):
def get_archiveresult(request: HttpRequest, archiveresult_id: str):
"""Get a specific ArchiveResult by id."""
return ArchiveResult.objects.get(Q(id__icontains=archiveresult_id))
@@ -185,7 +187,7 @@ class SnapshotSchema(Schema):
@staticmethod
def resolve_archiveresults(obj, context):
if context['request'].with_archiveresults:
if bool(getattr(context['request'], 'with_archiveresults', False)):
return obj.archiveresult_set.all().distinct()
return ArchiveResult.objects.none()
@@ -217,36 +219,36 @@ def normalize_tag_list(tags: Optional[List[str]] = None) -> List[str]:
class SnapshotFilterSchema(FilterSchema):
id: Optional[str] = Field(None, q=['id__icontains', 'timestamp__startswith'])
created_by_id: str = Field(None, q='crawl__created_by_id')
created_by_username: str = Field(None, q='crawl__created_by__username__icontains')
created_at__gte: datetime = Field(None, q='created_at__gte')
created_at__lt: datetime = Field(None, q='created_at__lt')
created_at: datetime = Field(None, q='created_at')
modified_at: datetime = Field(None, q='modified_at')
modified_at__gte: datetime = Field(None, q='modified_at__gte')
modified_at__lt: datetime = Field(None, q='modified_at__lt')
search: Optional[str] = Field(None, q=['url__icontains', 'title__icontains', 'tags__name__icontains', 'id__icontains', 'timestamp__startswith'])
url: Optional[str] = Field(None, q='url')
tag: Optional[str] = Field(None, q='tags__name')
title: Optional[str] = Field(None, q='title__icontains')
timestamp: Optional[str] = Field(None, q='timestamp__startswith')
bookmarked_at__gte: Optional[datetime] = Field(None, q='bookmarked_at__gte')
bookmarked_at__lt: Optional[datetime] = Field(None, q='bookmarked_at__lt')
id: Annotated[Optional[str], FilterLookup(['id__icontains', 'timestamp__startswith'])] = None
created_by_id: Annotated[Optional[str], FilterLookup('crawl__created_by_id')] = None
created_by_username: Annotated[Optional[str], FilterLookup('crawl__created_by__username__icontains')] = None
created_at__gte: Annotated[Optional[datetime], FilterLookup('created_at__gte')] = None
created_at__lt: Annotated[Optional[datetime], FilterLookup('created_at__lt')] = None
created_at: Annotated[Optional[datetime], FilterLookup('created_at')] = None
modified_at: Annotated[Optional[datetime], FilterLookup('modified_at')] = None
modified_at__gte: Annotated[Optional[datetime], FilterLookup('modified_at__gte')] = None
modified_at__lt: Annotated[Optional[datetime], FilterLookup('modified_at__lt')] = None
search: Annotated[Optional[str], FilterLookup(['url__icontains', 'title__icontains', 'tags__name__icontains', 'id__icontains', 'timestamp__startswith'])] = None
url: Annotated[Optional[str], FilterLookup('url')] = None
tag: Annotated[Optional[str], FilterLookup('tags__name')] = None
title: Annotated[Optional[str], FilterLookup('title__icontains')] = None
timestamp: Annotated[Optional[str], FilterLookup('timestamp__startswith')] = None
bookmarked_at__gte: Annotated[Optional[datetime], FilterLookup('bookmarked_at__gte')] = None
bookmarked_at__lt: Annotated[Optional[datetime], FilterLookup('bookmarked_at__lt')] = None
@router.get("/snapshots", response=List[SnapshotSchema], url_name="get_snapshots")
@paginate(CustomPagination)
def get_snapshots(request, filters: SnapshotFilterSchema = Query(...), with_archiveresults: bool = False):
def get_snapshots(request: HttpRequest, filters: Query[SnapshotFilterSchema], with_archiveresults: bool = False):
"""List all Snapshot entries matching these filters."""
request.with_archiveresults = with_archiveresults
setattr(request, 'with_archiveresults', with_archiveresults)
return filters.filter(Snapshot.objects.all()).distinct()
@router.get("/snapshot/{snapshot_id}", response=SnapshotSchema, url_name="get_snapshot")
def get_snapshot(request, snapshot_id: str, with_archiveresults: bool = True):
def get_snapshot(request: HttpRequest, snapshot_id: str, with_archiveresults: bool = True):
"""Get a specific Snapshot by id."""
request.with_archiveresults = with_archiveresults
setattr(request, 'with_archiveresults', with_archiveresults)
try:
return Snapshot.objects.get(Q(id__startswith=snapshot_id) | Q(timestamp__startswith=snapshot_id))
except Snapshot.DoesNotExist:
@@ -254,7 +256,7 @@ def get_snapshot(request, snapshot_id: str, with_archiveresults: bool = True):
@router.post("/snapshots", response=SnapshotSchema, url_name="create_snapshot")
def create_snapshot(request, data: SnapshotCreateSchema):
def create_snapshot(request: HttpRequest, data: SnapshotCreateSchema):
tags = normalize_tag_list(data.tags)
if data.status is not None and data.status not in Snapshot.StatusChoices.values:
raise HttpError(400, f'Invalid status: {data.status}')
@@ -274,7 +276,7 @@ def create_snapshot(request, data: SnapshotCreateSchema):
tags_str=','.join(tags),
status=Crawl.StatusChoices.QUEUED,
retry_at=timezone.now(),
created_by=request.user,
created_by=request.user if isinstance(request.user, User) else None,
)
snapshot_defaults = {
@@ -311,12 +313,12 @@ def create_snapshot(request, data: SnapshotCreateSchema):
except Exception:
pass
request.with_archiveresults = False
setattr(request, 'with_archiveresults', False)
return snapshot
@router.patch("/snapshot/{snapshot_id}", response=SnapshotSchema, url_name="patch_snapshot")
def patch_snapshot(request, snapshot_id: str, data: SnapshotUpdateSchema):
def patch_snapshot(request: HttpRequest, snapshot_id: str, data: SnapshotUpdateSchema):
"""Update a snapshot (e.g., set status=sealed to cancel queued work)."""
try:
snapshot = Snapshot.objects.get(Q(id__startswith=snapshot_id) | Q(timestamp__startswith=snapshot_id))
@@ -343,15 +345,15 @@ def patch_snapshot(request, snapshot_id: str, data: SnapshotUpdateSchema):
snapshot.save_tags(normalize_tag_list(tags))
snapshot.save(update_fields=update_fields)
request.with_archiveresults = False
setattr(request, 'with_archiveresults', False)
return snapshot
@router.delete("/snapshot/{snapshot_id}", response=SnapshotDeleteResponseSchema, url_name="delete_snapshot")
def delete_snapshot(request, snapshot_id: str):
def delete_snapshot(request: HttpRequest, snapshot_id: str):
snapshot = get_snapshot(request, snapshot_id, with_archiveresults=False)
snapshot_id_str = str(snapshot.id)
crawl_id_str = str(snapshot.crawl_id)
crawl_id_str = str(snapshot.crawl.pk)
deleted_count, _ = snapshot.delete()
return {
'success': True,
@@ -381,8 +383,10 @@ class TagSchema(Schema):
@staticmethod
def resolve_created_by_username(obj):
User = get_user_model()
return User.objects.get(id=obj.created_by_id).username
user_model = get_user_model()
user = user_model.objects.get(id=obj.created_by_id)
username = getattr(user, 'username', None)
return username if isinstance(username, str) else str(user)
@staticmethod
def resolve_num_snapshots(obj, context):
@@ -390,23 +394,23 @@ class TagSchema(Schema):
@staticmethod
def resolve_snapshots(obj, context):
if context['request'].with_snapshots:
if bool(getattr(context['request'], 'with_snapshots', False)):
return obj.snapshot_set.all().distinct()
return Snapshot.objects.none()
@router.get("/tags", response=List[TagSchema], url_name="get_tags")
@paginate(CustomPagination)
def get_tags(request):
request.with_snapshots = False
request.with_archiveresults = False
def get_tags(request: HttpRequest):
setattr(request, 'with_snapshots', False)
setattr(request, 'with_archiveresults', False)
return Tag.objects.all().distinct()
@router.get("/tag/{tag_id}", response=TagSchema, url_name="get_tag")
def get_tag(request, tag_id: str, with_snapshots: bool = True):
request.with_snapshots = with_snapshots
request.with_archiveresults = False
def get_tag(request: HttpRequest, tag_id: str, with_snapshots: bool = True):
setattr(request, 'with_snapshots', with_snapshots)
setattr(request, 'with_archiveresults', False)
try:
return Tag.objects.get(id__icontains=tag_id)
except (Tag.DoesNotExist, ValidationError):
@@ -414,15 +418,15 @@ def get_tag(request, tag_id: str, with_snapshots: bool = True):
@router.get("/any/{id}", response=Union[SnapshotSchema, ArchiveResultSchema, TagSchema, CrawlSchema], url_name="get_any", summary="Get any object by its ID")
def get_any(request, id: str):
def get_any(request: HttpRequest, id: str):
"""Get any object by its ID (e.g. snapshot, archiveresult, tag, crawl, etc.)."""
request.with_snapshots = False
request.with_archiveresults = False
setattr(request, 'with_snapshots', False)
setattr(request, 'with_archiveresults', False)
for getter in [get_snapshot, get_archiveresult, get_tag]:
try:
response = getter(request, id)
if response:
if isinstance(response, Model):
return redirect(f"/api/v1/{response._meta.app_label}/{response._meta.model_name}/{response.id}?{request.META['QUERY_STRING']}")
except Exception:
pass
@@ -430,7 +434,7 @@ def get_any(request, id: str):
try:
from archivebox.api.v1_crawls import get_crawl
response = get_crawl(request, id)
if response:
if isinstance(response, Model):
return redirect(f"/api/v1/{response._meta.app_label}/{response._meta.model_name}/{response.id}?{request.META['QUERY_STRING']}")
except Exception:
pass
@@ -468,7 +472,7 @@ class TagSnapshotResponseSchema(Schema):
@router.get("/tags/autocomplete/", response=TagAutocompleteSchema, url_name="tags_autocomplete")
def tags_autocomplete(request, q: str = ""):
def tags_autocomplete(request: HttpRequest, q: str = ""):
"""Return tags matching the query for autocomplete."""
if not q:
# Return all tags if no query (limited to 50)
@@ -482,7 +486,7 @@ def tags_autocomplete(request, q: str = ""):
@router.post("/tags/create/", response=TagCreateResponseSchema, url_name="tags_create")
def tags_create(request, data: TagCreateSchema):
def tags_create(request: HttpRequest, data: TagCreateSchema):
"""Create a new tag or return existing one."""
name = data.name.strip()
if not name:
@@ -498,7 +502,10 @@ def tags_create(request, data: TagCreateSchema):
# If found by case-insensitive match, use that tag
if not created:
tag = Tag.objects.filter(name__iexact=name).first()
existing_tag = Tag.objects.filter(name__iexact=name).first()
if existing_tag is None:
raise HttpError(500, 'Failed to load existing tag after get_or_create')
tag = existing_tag
return {
'success': True,
@@ -509,7 +516,7 @@ def tags_create(request, data: TagCreateSchema):
@router.post("/tags/add-to-snapshot/", response=TagSnapshotResponseSchema, url_name="tags_add_to_snapshot")
def tags_add_to_snapshot(request, data: TagSnapshotRequestSchema):
def tags_add_to_snapshot(request: HttpRequest, data: TagSnapshotRequestSchema):
"""Add a tag to a snapshot. Creates the tag if it doesn't exist."""
# Get the snapshot
try:
@@ -522,6 +529,8 @@ def tags_add_to_snapshot(request, data: TagSnapshotRequestSchema):
snapshot = Snapshot.objects.filter(
Q(id__startswith=data.snapshot_id) | Q(timestamp__startswith=data.snapshot_id)
).first()
if snapshot is None:
raise HttpError(404, 'Snapshot not found')
# Get or create the tag
if data.tag_name:
@@ -537,7 +546,9 @@ def tags_add_to_snapshot(request, data: TagSnapshotRequestSchema):
}
)
# If found by case-insensitive match, use that tag
tag = Tag.objects.filter(name__iexact=name).first() or tag
existing_tag = Tag.objects.filter(name__iexact=name).first()
if existing_tag is not None:
tag = existing_tag
elif data.tag_id:
try:
tag = Tag.objects.get(pk=data.tag_id)
@@ -557,7 +568,7 @@ def tags_add_to_snapshot(request, data: TagSnapshotRequestSchema):
@router.post("/tags/remove-from-snapshot/", response=TagSnapshotResponseSchema, url_name="tags_remove_from_snapshot")
def tags_remove_from_snapshot(request, data: TagSnapshotRequestSchema):
def tags_remove_from_snapshot(request: HttpRequest, data: TagSnapshotRequestSchema):
"""Remove a tag from a snapshot."""
# Get the snapshot
try:
@@ -570,6 +581,8 @@ def tags_remove_from_snapshot(request, data: TagSnapshotRequestSchema):
snapshot = Snapshot.objects.filter(
Q(id__startswith=data.snapshot_id) | Q(timestamp__startswith=data.snapshot_id)
).first()
if snapshot is None:
raise HttpError(404, 'Snapshot not found')
# Get the tag
if data.tag_id:

View File

@@ -3,9 +3,11 @@ __package__ = 'archivebox.api'
from uuid import UUID
from typing import List, Optional
from datetime import datetime
from django.http import HttpRequest
from django.utils import timezone
from django.contrib.auth import get_user_model
from django.contrib.auth.models import User
from ninja import Router, Schema
from ninja.errors import HttpError
@@ -44,12 +46,14 @@ class CrawlSchema(Schema):
@staticmethod
def resolve_created_by_username(obj):
User = get_user_model()
return User.objects.get(id=obj.created_by_id).username
user_model = get_user_model()
user = user_model.objects.get(id=obj.created_by_id)
username = getattr(user, 'username', None)
return username if isinstance(username, str) else str(user)
@staticmethod
def resolve_snapshots(obj, context):
if context['request'].with_snapshots:
if bool(getattr(context['request'], 'with_snapshots', False)):
return obj.snapshot_set.all().distinct()
return Snapshot.objects.none()
@@ -85,12 +89,12 @@ def normalize_tag_list(tags: Optional[List[str]] = None, tags_str: str = '') ->
@router.get("/crawls", response=List[CrawlSchema], url_name="get_crawls")
def get_crawls(request):
def get_crawls(request: HttpRequest):
return Crawl.objects.all().distinct()
@router.post("/crawls", response=CrawlSchema, url_name="create_crawl")
def create_crawl(request, data: CrawlCreateSchema):
def create_crawl(request: HttpRequest, data: CrawlCreateSchema):
urls = [url.strip() for url in data.urls if url and url.strip()]
if not urls:
raise HttpError(400, 'At least one URL is required')
@@ -107,16 +111,16 @@ def create_crawl(request, data: CrawlCreateSchema):
config=data.config,
status=Crawl.StatusChoices.QUEUED,
retry_at=timezone.now(),
created_by=request.user,
created_by=request.user if isinstance(request.user, User) else None,
)
crawl.create_snapshots_from_urls()
return crawl
@router.get("/crawl/{crawl_id}", response=CrawlSchema | str, url_name="get_crawl")
def get_crawl(request, crawl_id: str, as_rss: bool=False, with_snapshots: bool=False, with_archiveresults: bool=False):
def get_crawl(request: HttpRequest, crawl_id: str, as_rss: bool=False, with_snapshots: bool=False, with_archiveresults: bool=False):
"""Get a specific Crawl by id."""
request.with_snapshots = with_snapshots
request.with_archiveresults = with_archiveresults
setattr(request, 'with_snapshots', with_snapshots)
setattr(request, 'with_archiveresults', with_archiveresults)
crawl = Crawl.objects.get(id__icontains=crawl_id)
if crawl and as_rss:
@@ -135,7 +139,7 @@ def get_crawl(request, crawl_id: str, as_rss: bool=False, with_snapshots: bool=F
@router.patch("/crawl/{crawl_id}", response=CrawlSchema, url_name="patch_crawl")
def patch_crawl(request, crawl_id: str, data: CrawlUpdateSchema):
def patch_crawl(request: HttpRequest, crawl_id: str, data: CrawlUpdateSchema):
"""Update a crawl (e.g., set status=sealed to cancel queued work)."""
crawl = Crawl.objects.get(id__icontains=crawl_id)
payload = data.dict(exclude_unset=True)
@@ -174,7 +178,7 @@ def patch_crawl(request, crawl_id: str, data: CrawlUpdateSchema):
@router.delete("/crawl/{crawl_id}", response=CrawlDeleteResponseSchema, url_name="delete_crawl")
def delete_crawl(request, crawl_id: str):
def delete_crawl(request: HttpRequest, crawl_id: str):
crawl = Crawl.objects.get(id__icontains=crawl_id)
crawl_id_str = str(crawl.id)
snapshot_count = crawl.snapshot_set.count()

View File

@@ -1,10 +1,12 @@
__package__ = 'archivebox.api'
from uuid import UUID
from typing import List, Optional
from typing import Annotated, List, Optional
from datetime import datetime
from ninja import Router, Schema, FilterSchema, Field, Query
from django.http import HttpRequest
from ninja import FilterLookup, FilterSchema, Query, Router, Schema
from ninja.pagination import paginate
from archivebox.api.v1_core import CustomPagination
@@ -41,16 +43,13 @@ class MachineSchema(Schema):
class MachineFilterSchema(FilterSchema):
id: Optional[str] = Field(None, q='id__startswith')
hostname: Optional[str] = Field(None, q='hostname__icontains')
os_platform: Optional[str] = Field(None, q='os_platform__icontains')
os_arch: Optional[str] = Field(None, q='os_arch')
hw_in_docker: Optional[bool] = Field(None, q='hw_in_docker')
hw_in_vm: Optional[bool] = Field(None, q='hw_in_vm')
# ============================================================================
bin_providers: Optional[str] = Field(None, q='bin_providers__icontains')
id: Annotated[Optional[str], FilterLookup('id__startswith')] = None
hostname: Annotated[Optional[str], FilterLookup('hostname__icontains')] = None
os_platform: Annotated[Optional[str], FilterLookup('os_platform__icontains')] = None
os_arch: Annotated[Optional[str], FilterLookup('os_arch')] = None
hw_in_docker: Annotated[Optional[bool], FilterLookup('hw_in_docker')] = None
hw_in_vm: Annotated[Optional[bool], FilterLookup('hw_in_vm')] = None
bin_providers: Annotated[Optional[str], FilterLookup('bin_providers__icontains')] = None
# ============================================================================
@@ -86,12 +85,12 @@ class BinarySchema(Schema):
class BinaryFilterSchema(FilterSchema):
id: Optional[str] = Field(None, q='id__startswith')
name: Optional[str] = Field(None, q='name__icontains')
binprovider: Optional[str] = Field(None, q='binprovider')
status: Optional[str] = Field(None, q='status')
machine_id: Optional[str] = Field(None, q='machine_id__startswith')
version: Optional[str] = Field(None, q='version__icontains')
id: Annotated[Optional[str], FilterLookup('id__startswith')] = None
name: Annotated[Optional[str], FilterLookup('name__icontains')] = None
binprovider: Annotated[Optional[str], FilterLookup('binprovider')] = None
status: Annotated[Optional[str], FilterLookup('status')] = None
machine_id: Annotated[Optional[str], FilterLookup('machine_id__startswith')] = None
version: Annotated[Optional[str], FilterLookup('version__icontains')] = None
# ============================================================================
@@ -100,21 +99,21 @@ class BinaryFilterSchema(FilterSchema):
@router.get("/machines", response=List[MachineSchema], url_name="get_machines")
@paginate(CustomPagination)
def get_machines(request, filters: MachineFilterSchema = Query(...)):
def get_machines(request: HttpRequest, filters: Query[MachineFilterSchema]):
"""List all machines."""
from archivebox.machine.models import Machine
return filters.filter(Machine.objects.all()).distinct()
@router.get("/machine/current", response=MachineSchema, url_name="get_current_machine")
def get_current_machine(request):
def get_current_machine(request: HttpRequest):
"""Get the current machine."""
from archivebox.machine.models import Machine
return Machine.current()
@router.get("/machine/{machine_id}", response=MachineSchema, url_name="get_machine")
def get_machine(request, machine_id: str):
def get_machine(request: HttpRequest, machine_id: str):
"""Get a specific machine by ID."""
from archivebox.machine.models import Machine
from django.db.models import Q
@@ -130,21 +129,21 @@ def get_machine(request, machine_id: str):
@router.get("/binaries", response=List[BinarySchema], url_name="get_binaries")
@paginate(CustomPagination)
def get_binaries(request, filters: BinaryFilterSchema = Query(...)):
def get_binaries(request: HttpRequest, filters: Query[BinaryFilterSchema]):
"""List all binaries."""
from archivebox.machine.models import Binary
return filters.filter(Binary.objects.all().select_related('machine')).distinct()
@router.get("/binary/{binary_id}", response=BinarySchema, url_name="get_binary")
def get_binary(request, binary_id: str):
def get_binary(request: HttpRequest, binary_id: str):
"""Get a specific binary by ID."""
from archivebox.machine.models import Binary
return Binary.objects.select_related('machine').get(id__startswith=binary_id)
@router.get("/binary/by-name/{name}", response=List[BinarySchema], url_name="get_binaries_by_name")
def get_binaries_by_name(request, name: str):
def get_binaries_by_name(request: HttpRequest, name: str):
"""Get all binaries with the given name."""
from archivebox.machine.models import Binary
return list(Binary.objects.filter(name__iexact=name).select_related('machine'))

View File

@@ -39,7 +39,10 @@ def process_archiveresult_by_id(archiveresult_id: str) -> int:
"""
Run extraction for a single ArchiveResult by ID (used by workers).
Triggers the ArchiveResult's state machine tick() to run the extractor plugin.
Triggers the ArchiveResult's state machine tick() to run the extractor
plugin, but only after claiming ownership via retry_at. This keeps direct
CLI execution aligned with the worker lifecycle and prevents duplicate hook
runs if another process already owns the same ArchiveResult.
"""
from rich import print as rprint
from archivebox.core.models import ArchiveResult
@@ -53,9 +56,12 @@ def process_archiveresult_by_id(archiveresult_id: str) -> int:
rprint(f'[blue]Extracting {archiveresult.plugin} for {archiveresult.snapshot.url}[/blue]', file=sys.stderr)
try:
# Trigger state machine tick - this runs the actual extraction
archiveresult.sm.tick()
archiveresult.refresh_from_db()
# Claim-before-tick is the required calling pattern for direct
# state-machine drivers. If another worker already owns this row,
# report that and exit without running duplicate extractor side effects.
if not archiveresult.tick_claimed(lock_seconds=120):
print(f'[yellow]Extraction already claimed by another process: {archiveresult.plugin}[/yellow]')
return 0
if archiveresult.status == ArchiveResult.StatusChoices.SUCCEEDED:
print(f'[green]Extraction succeeded: {archiveresult.output_str}[/green]')

View File

@@ -382,6 +382,88 @@ class Crawl(ModelWithOutputDir, ModelWithConfig, ModelWithHealthStats, ModelWith
return created_snapshots
def install_declared_binaries(self, binary_names: set[str], machine=None) -> None:
"""
Install crawl-declared Binary rows without violating the retry_at lock lifecycle.
Correct calling pattern:
1. Crawl hooks declare Binary records and queue them with retry_at <= now
2. Exactly one actor claims each Binary by moving retry_at into the future
3. Only that owner executes `.sm.tick()` and performs install side effects
4. Everyone else waits for the claimed owner to finish instead of launching
a second install against shared state such as the pip or npm trees
This helper follows that contract by claiming each Binary before ticking
it, and by waiting when another worker already owns the row. That keeps
synchronous crawl execution compatible with the global BinaryWorker and
avoids duplicate installs of the same dependency.
"""
import time
from archivebox.machine.models import Binary, Machine
if not binary_names:
return
machine = machine or Machine.current()
lock_seconds = 600
deadline = time.monotonic() + max(lock_seconds, len(binary_names) * lock_seconds)
while time.monotonic() < deadline:
unresolved_binaries = list(
Binary.objects.filter(
machine=machine,
name__in=binary_names,
).exclude(
status=Binary.StatusChoices.INSTALLED,
).order_by('name')
)
if not unresolved_binaries:
return
claimed_any = False
waiting_on_existing_owner = False
now = timezone.now()
for binary in unresolved_binaries:
try:
if binary.tick_claimed(lock_seconds=lock_seconds):
claimed_any = True
continue
except Exception:
claimed_any = True
continue
binary.refresh_from_db()
if binary.status == Binary.StatusChoices.INSTALLED:
claimed_any = True
continue
if binary.retry_at and binary.retry_at > now:
waiting_on_existing_owner = True
if claimed_any:
continue
if waiting_on_existing_owner:
time.sleep(0.5)
continue
break
unresolved_binaries = list(
Binary.objects.filter(
machine=machine,
name__in=binary_names,
).exclude(
status=Binary.StatusChoices.INSTALLED,
).order_by('name')
)
if unresolved_binaries:
binary_details = ', '.join(
f'{binary.name} (status={binary.status}, retry_at={binary.retry_at})'
for binary in unresolved_binaries
)
raise RuntimeError(
f'Crawl dependencies failed to install before continuing: {binary_details}'
)
def run(self) -> 'Snapshot | None':
"""
Execute this Crawl: run hooks, process JSONL, create snapshots.
@@ -428,47 +510,6 @@ class Crawl(ModelWithOutputDir, ModelWithConfig, ModelWithHealthStats, ModelWith
chrome_binary=chrome_binary,
)
def install_declared_binaries(binary_names: set[str]) -> None:
if not binary_names:
return
max_attempts = max(2, len(binary_names))
for _ in range(max_attempts):
pending_binaries = list(
Binary.objects.filter(
machine=machine,
name__in=binary_names,
).exclude(
status=Binary.StatusChoices.INSTALLED,
).order_by('retry_at', 'name')
)
if not pending_binaries:
return
for binary in pending_binaries:
try:
binary.sm.tick()
except Exception:
continue
unresolved_binaries = list(
Binary.objects.filter(
machine=machine,
name__in=binary_names,
).exclude(
status=Binary.StatusChoices.INSTALLED,
).order_by('name')
)
if unresolved_binaries:
binary_details = ', '.join(
f'{binary.name} (status={binary.status})'
for binary in unresolved_binaries
)
raise RuntimeError(
f'Crawl dependencies failed to install before continuing: {binary_details}'
)
executed_crawl_hooks: set[str] = set()
def run_crawl_hook(hook: Path) -> set[str]:
@@ -598,11 +639,11 @@ class Crawl(ModelWithOutputDir, ModelWithConfig, ModelWithHealthStats, ModelWith
for hook in hooks:
hook_binary_names = run_crawl_hook(hook)
if hook_binary_names:
install_declared_binaries(resolve_provider_binaries(hook_binary_names))
self.install_declared_binaries(resolve_provider_binaries(hook_binary_names), machine=machine)
# Safety check: don't create snapshots if any crawl-declared dependency
# is still unresolved after all crawl hooks have run.
install_declared_binaries(declared_binary_names)
self.install_declared_binaries(declared_binary_names, machine=machine)
# Create snapshots from all URLs in self.urls
if system_task:

View File

@@ -0,0 +1,143 @@
import threading
import time
import pytest
from django.db import close_old_connections
from django.utils import timezone
from archivebox.base_models.models import get_or_create_system_user_pk
from archivebox.crawls.models import Crawl
from archivebox.machine.models import Binary, Machine
from archivebox.workers.worker import BinaryWorker
def get_fresh_machine() -> Machine:
import archivebox.machine.models as machine_models
machine_models._CURRENT_MACHINE = None
machine_models._CURRENT_BINARIES.clear()
return Machine.current()
@pytest.mark.django_db
def test_claim_processing_lock_does_not_steal_future_retry_at():
"""
retry_at is both the schedule and the ownership lock.
Once one process claims a due row and moves retry_at into the future, a
fresh reader must not be able to "re-claim" that future timestamp and run
the same side effects a second time.
"""
machine = get_fresh_machine()
binary = Binary.objects.create(
machine=machine,
name='claim-test',
binproviders='env',
status=Binary.StatusChoices.QUEUED,
retry_at=timezone.now(),
)
owner = Binary.objects.get(pk=binary.pk)
contender = Binary.objects.get(pk=binary.pk)
assert owner.claim_processing_lock(lock_seconds=30) is True
contender.refresh_from_db()
assert contender.retry_at > timezone.now()
assert contender.claim_processing_lock(lock_seconds=30) is False
@pytest.mark.django_db
def test_binary_worker_skips_binary_claimed_by_other_owner(monkeypatch):
"""
BinaryWorker must never run install side effects for a Binary whose retry_at
lock has already been claimed by another process.
"""
machine = get_fresh_machine()
binary = Binary.objects.create(
machine=machine,
name='claimed-binary',
binproviders='env',
status=Binary.StatusChoices.QUEUED,
retry_at=timezone.now(),
)
owner = Binary.objects.get(pk=binary.pk)
assert owner.claim_processing_lock(lock_seconds=30) is True
calls: list[str] = []
def fake_run(self):
calls.append(self.name)
self.status = self.StatusChoices.INSTALLED
self.abspath = '/tmp/fake-binary'
self.version = '1.0'
self.save(update_fields=['status', 'abspath', 'version', 'modified_at'])
monkeypatch.setattr(Binary, 'run', fake_run)
worker = BinaryWorker(binary_id=str(binary.id))
worker._process_single_binary()
assert calls == []
@pytest.mark.django_db(transaction=True)
def test_crawl_install_declared_binaries_waits_for_existing_owner(monkeypatch):
"""
Crawl.install_declared_binaries should wait for the current owner of a Binary
to finish instead of launching a duplicate install against shared provider
state such as the npm tree.
"""
machine = get_fresh_machine()
crawl = Crawl.objects.create(
urls='https://example.com',
created_by_id=get_or_create_system_user_pk(),
status=Crawl.StatusChoices.QUEUED,
retry_at=timezone.now(),
)
binary = Binary.objects.create(
machine=machine,
name='puppeteer',
binproviders='npm',
status=Binary.StatusChoices.QUEUED,
retry_at=timezone.now(),
)
owner = Binary.objects.get(pk=binary.pk)
assert owner.claim_processing_lock(lock_seconds=30) is True
calls: list[str] = []
def fake_run(self):
calls.append(self.name)
self.status = self.StatusChoices.INSTALLED
self.abspath = '/tmp/should-not-run'
self.version = '1.0'
self.save(update_fields=['status', 'abspath', 'version', 'modified_at'])
monkeypatch.setattr(Binary, 'run', fake_run)
def finish_existing_install():
close_old_connections()
try:
time.sleep(0.3)
Binary.objects.filter(pk=binary.pk).update(
status=Binary.StatusChoices.INSTALLED,
retry_at=None,
abspath='/tmp/finished-by-owner',
version='1.0',
modified_at=timezone.now(),
)
finally:
close_old_connections()
thread = threading.Thread(target=finish_existing_install, daemon=True)
thread.start()
crawl.install_declared_binaries({'puppeteer'}, machine=machine)
thread.join(timeout=5)
binary.refresh_from_db()
assert binary.status == Binary.StatusChoices.INSTALLED
assert binary.abspath == '/tmp/finished-by-owner'
assert calls == []

View File

@@ -210,17 +210,71 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
@classmethod
def claim_for_worker(cls, obj: 'BaseModelWithStateMachine', lock_seconds: int = 60) -> bool:
"""
Atomically claim an object for processing using optimistic locking.
Returns True if successfully claimed, False if another worker got it first.
Atomically claim a due object for processing using retry_at as the lock.
Correct lifecycle for any state-machine-driven work item:
1. Queue the item by setting retry_at <= now
2. Exactly one owner claims it by moving retry_at into the future
3. Only that owner may call .sm.tick() and perform side effects
4. State-machine callbacks update retry_at again when the work completes,
backs off, or is re-queued
The critical rule is that future retry_at values are already owned.
Callers must never "steal" those future timestamps and start another
copy of the same work. That is what prevents duplicate installs, hook
runs, and other concurrent side effects.
Returns True if successfully claimed, False if another worker got it
first or the object is not currently due.
"""
updated = cls.objects.filter(
pk=obj.pk,
retry_at=obj.retry_at,
retry_at__lte=timezone.now(),
).update(
retry_at=timezone.now() + timedelta(seconds=lock_seconds)
)
return updated == 1
def claim_processing_lock(self, lock_seconds: int = 60) -> bool:
"""
Claim this model instance immediately before executing one state-machine tick.
This helper is the safe entrypoint for any direct state-machine driver
(workers, synchronous crawl dependency installers, one-off CLI helpers).
Calling `.sm.tick()` without claiming first turns retry_at into "just a
schedule" instead of the ownership lock it is meant to be.
Returns True only for the caller that successfully moved retry_at into
the future. False means another process already owns the work item or it
is not currently due.
"""
if self.STATE in self.FINAL_STATES:
return False
if self.RETRY_AT is None:
return False
claimed = type(self).claim_for_worker(self, lock_seconds=lock_seconds)
if claimed:
self.refresh_from_db()
return claimed
def tick_claimed(self, lock_seconds: int = 60) -> bool:
"""
Claim ownership via retry_at and then execute exactly one `.sm.tick()`.
Future maintainers should prefer this helper over calling `.sm.tick()`
directly whenever there is any chance another process could see the same
queued row. If this method returns False, someone else already owns the
work and the caller must not run side effects for it.
"""
if not self.claim_processing_lock(lock_seconds=lock_seconds):
return False
self.sm.tick()
self.refresh_from_db()
return True
@classproperty
def ACTIVE_STATE(cls) -> str:
return cls._state_to_str(cls.active_state)

View File

@@ -35,6 +35,7 @@ from datetime import timedelta
from multiprocessing import Process as MPProcess
from pathlib import Path
from django.db import connections
from django.utils import timezone
from rich import print
@@ -403,6 +404,17 @@ class Orchestrator:
return queue_sizes
def _refresh_db_connections(self) -> None:
"""
Drop long-lived DB connections before each poll tick.
The daemon orchestrator must observe rows created by sibling processes
(server requests, CLI helpers, docker-compose run invocations). With
SQLite, reusing the same connection indefinitely can miss externally
committed rows until the process reconnects.
"""
connections.close_all()
def _should_process_schedules(self) -> bool:
return (not self.exit_on_idle) and (self.crawl_id is None)
@@ -576,17 +588,10 @@ class Orchestrator:
)
def _claim_crawl(self, crawl) -> bool:
"""Atomically claim a crawl using optimistic locking."""
"""Atomically claim a due crawl using the shared retry_at lock lifecycle."""
from archivebox.crawls.models import Crawl
updated = Crawl.objects.filter(
pk=crawl.pk,
retry_at=crawl.retry_at,
).update(
retry_at=timezone.now() + timedelta(hours=24), # Long lock (crawls take time)
)
return updated == 1
return Crawl.claim_for_worker(crawl, lock_seconds=24 * 60 * 60)
def has_pending_work(self, queue_sizes: dict[str, int]) -> bool:
"""Check if any queue has pending work."""
@@ -726,6 +731,10 @@ class Orchestrator:
while True:
tick_count += 1
# Refresh DB state before polling so this long-lived daemon sees
# work created by other processes using the same collection.
self._refresh_db_connections()
# Check queues and spawn workers
queue_sizes = self.check_queues_and_spawn_workers()

View File

@@ -569,6 +569,7 @@ class CrawlWorker(Worker):
def _spawn_snapshot_workers(self) -> None:
"""Spawn SnapshotWorkers for queued snapshots (up to limit)."""
from pathlib import Path
from archivebox.config.constants import CONSTANTS
from archivebox.core.models import Snapshot
from archivebox.machine.models import Process
import sys
@@ -636,6 +637,18 @@ class CrawlWorker(Worker):
f.write(f' Spawning worker for {snapshot.url} (status={snapshot.status})\n')
f.flush()
# Claim the snapshot before spawning the worker so retry_at remains
# the single source of truth for ownership even if process tracking
# lags or multiple schedulers look at the same queue.
if not Snapshot.claim_for_worker(snapshot, lock_seconds=CONSTANTS.MAX_SNAPSHOT_RUNTIME_SECONDS):
log_worker_event(
worker_type='CrawlWorker',
event=f'Skipped already-claimed Snapshot: {snapshot.url}',
indent_level=1,
pid=self.pid,
)
continue
pid = SnapshotWorker.start(parent=self.db_process, snapshot_id=str(snapshot.id))
log_worker_event(
@@ -1195,9 +1208,15 @@ class BinaryWorker(Worker):
return
print(f'[cyan]🔧 BinaryWorker installing: {binary.name}[/cyan]', file=sys.stderr)
binary.sm.tick()
if not binary.tick_claimed(lock_seconds=self.MAX_TICK_TIME):
log_worker_event(
worker_type='BinaryWorker',
event=f'Skipped already-claimed binary: {binary.name}',
indent_level=1,
pid=self.pid,
)
return
binary.refresh_from_db()
if binary.status == binary.__class__.StatusChoices.INSTALLED:
log_worker_event(
worker_type='BinaryWorker',
@@ -1254,9 +1273,15 @@ class BinaryWorker(Worker):
for binary in pending_binaries:
try:
print(f'[cyan]🔧 BinaryWorker processing: {binary.name}[/cyan]', file=sys.stderr)
binary.sm.tick()
if not binary.tick_claimed(lock_seconds=self.MAX_TICK_TIME):
log_worker_event(
worker_type='BinaryWorker',
event=f'Skipped already-claimed binary: {binary.name}',
indent_level=1,
pid=self.pid,
)
continue
binary.refresh_from_db()
if binary.status == binary.__class__.StatusChoices.INSTALLED:
log_worker_event(
worker_type='BinaryWorker',