mirror of
https://github.com/ArchiveBox/ArchiveBox.git
synced 2026-04-06 07:47:53 +10:00
add stricter locking around stage machine models
This commit is contained in:
5
.gitignore
vendored
5
.gitignore
vendored
@@ -38,6 +38,7 @@ lib/
|
||||
tmp/
|
||||
data/
|
||||
data*/
|
||||
archive/
|
||||
output/
|
||||
logs/
|
||||
index.sqlite3
|
||||
@@ -46,6 +47,10 @@ queue.sqlite3
|
||||
data.*
|
||||
.archivebox_id
|
||||
ArchiveBox.conf
|
||||
*.stdout
|
||||
*.stderr
|
||||
*.log
|
||||
.tmp/
|
||||
|
||||
# vim
|
||||
*.sw?
|
||||
|
||||
@@ -14,16 +14,21 @@ __package__ = 'archivebox'
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Protocol, cast
|
||||
|
||||
# Import uuid_compat early to monkey-patch uuid.uuid7 before Django loads migrations
|
||||
# This fixes migrations generated on Python 3.14+ that reference uuid.uuid7 directly
|
||||
from archivebox import uuid_compat # noqa: F401
|
||||
from abx_plugins import get_plugins_dir
|
||||
|
||||
|
||||
class _ReconfigurableStream(Protocol):
|
||||
def reconfigure(self, *, line_buffering: bool) -> object: ...
|
||||
|
||||
# Force unbuffered output for real-time logs
|
||||
if hasattr(sys.stdout, 'reconfigure'):
|
||||
sys.stdout.reconfigure(line_buffering=True)
|
||||
sys.stderr.reconfigure(line_buffering=True)
|
||||
cast(_ReconfigurableStream, sys.stdout).reconfigure(line_buffering=True)
|
||||
cast(_ReconfigurableStream, sys.stderr).reconfigure(line_buffering=True)
|
||||
os.environ['PYTHONUNBUFFERED'] = '1'
|
||||
|
||||
ASCII_LOGO = """
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
__package__ = 'archivebox.api'
|
||||
|
||||
from typing import Optional, cast
|
||||
from typing import Optional
|
||||
from datetime import timedelta
|
||||
|
||||
from django.http import HttpRequest
|
||||
from django.utils import timezone
|
||||
from django.http import HttpRequest
|
||||
from django.contrib.auth import authenticate
|
||||
from django.contrib.auth.models import AbstractBaseUser
|
||||
from django.contrib.auth.models import User
|
||||
|
||||
from ninja.security import HttpBearer, APIKeyQuery, APIKeyHeader, HttpBasicAuth
|
||||
from ninja.errors import HttpError
|
||||
|
||||
|
||||
def get_or_create_api_token(user):
|
||||
def get_or_create_api_token(user: User | None):
|
||||
from archivebox.api.models import APIToken
|
||||
|
||||
if user and user.is_superuser:
|
||||
@@ -23,122 +23,106 @@ def get_or_create_api_token(user):
|
||||
else:
|
||||
# does not exist, create a new one
|
||||
api_token = APIToken.objects.create(created_by_id=user.pk, expires=timezone.now() + timedelta(days=30))
|
||||
|
||||
|
||||
if api_token is None:
|
||||
return None
|
||||
assert api_token.is_valid(), f"API token is not valid {api_token}"
|
||||
|
||||
return api_token
|
||||
return None
|
||||
|
||||
|
||||
def auth_using_token(token, request: Optional[HttpRequest]=None) -> Optional[AbstractBaseUser]:
|
||||
def auth_using_token(token: str | None, request: HttpRequest | None = None) -> User | None:
|
||||
"""Given an API token string, check if a corresponding non-expired APIToken exists, and return its user"""
|
||||
from archivebox.api.models import APIToken # lazy import model to avoid loading it at urls.py import time
|
||||
|
||||
user = None
|
||||
user: User | None = None
|
||||
|
||||
submitted_empty_form = str(token).strip() in ('string', '', 'None', 'null')
|
||||
if not submitted_empty_form:
|
||||
try:
|
||||
token = APIToken.objects.get(token=token)
|
||||
if token.is_valid():
|
||||
user = token.created_by
|
||||
request._api_token = token
|
||||
api_token = APIToken.objects.get(token=token)
|
||||
if api_token.is_valid() and isinstance(api_token.created_by, User):
|
||||
user = api_token.created_by
|
||||
if request is not None:
|
||||
setattr(request, '_api_token', api_token)
|
||||
except APIToken.DoesNotExist:
|
||||
pass
|
||||
|
||||
if not user:
|
||||
# print('[❌] Failed to authenticate API user using API Key:', request)
|
||||
return None
|
||||
|
||||
return cast(AbstractBaseUser, user)
|
||||
return user
|
||||
|
||||
def auth_using_password(username, password, request: Optional[HttpRequest]=None) -> Optional[AbstractBaseUser]:
|
||||
|
||||
def auth_using_password(username: str | None, password: str | None, request: HttpRequest | None = None) -> User | None:
|
||||
"""Given a username and password, check if they are valid and return the corresponding user"""
|
||||
user = None
|
||||
user: User | None = None
|
||||
|
||||
submitted_empty_form = (username, password) in (('string', 'string'), ('', ''), (None, None))
|
||||
if not submitted_empty_form:
|
||||
user = authenticate(
|
||||
authenticated_user = authenticate(
|
||||
username=username,
|
||||
password=password,
|
||||
)
|
||||
|
||||
if not user:
|
||||
# print('[❌] Failed to authenticate API user using API Key:', request)
|
||||
user = None
|
||||
|
||||
return cast(AbstractBaseUser | None, user)
|
||||
if isinstance(authenticated_user, User):
|
||||
user = authenticated_user
|
||||
return user
|
||||
|
||||
|
||||
### Base Auth Types
|
||||
|
||||
|
||||
class APITokenAuthCheck:
|
||||
"""The base class for authentication methods that use an api.models.APIToken"""
|
||||
def authenticate(self, request: HttpRequest, key: Optional[str]=None) -> Optional[AbstractBaseUser]:
|
||||
request.user = auth_using_token(
|
||||
token=key,
|
||||
request=request,
|
||||
)
|
||||
if request.user and request.user.pk:
|
||||
# Don't set cookie/persist login ouside this erquest, user may be accessing the API from another domain (CSRF/CORS):
|
||||
# login(request, request.user, backend='django.contrib.auth.backends.ModelBackend')
|
||||
request._api_auth_method = self.__class__.__name__
|
||||
|
||||
if not request.user.is_superuser:
|
||||
raise HttpError(403, 'Valid API token but User does not have permission (make sure user.is_superuser=True)')
|
||||
return request.user
|
||||
|
||||
|
||||
class UserPassAuthCheck:
|
||||
"""The base class for authentication methods that use a username & password"""
|
||||
def authenticate(self, request: HttpRequest, username: Optional[str]=None, password: Optional[str]=None) -> Optional[AbstractBaseUser]:
|
||||
request.user = auth_using_password(
|
||||
username=username,
|
||||
password=password,
|
||||
request=request,
|
||||
)
|
||||
if request.user and request.user.pk:
|
||||
# Don't set cookie/persist login ouside this erquest, user may be accessing the API from another domain (CSRF/CORS):
|
||||
# login(request, request.user, backend='django.contrib.auth.backends.ModelBackend')
|
||||
request._api_auth_method = self.__class__.__name__
|
||||
|
||||
if not request.user.is_superuser:
|
||||
raise HttpError(403, 'Valid API token but User does not have permission (make sure user.is_superuser=True)')
|
||||
|
||||
return request.user
|
||||
def _require_superuser(user: User | None, request: HttpRequest, auth_method: str) -> User | None:
|
||||
if user and user.pk:
|
||||
request.user = user
|
||||
setattr(request, '_api_auth_method', auth_method)
|
||||
if not user.is_superuser:
|
||||
raise HttpError(403, 'Valid credentials but User does not have permission (make sure user.is_superuser=True)')
|
||||
return user
|
||||
|
||||
|
||||
### Django-Ninja-Provided Auth Methods
|
||||
|
||||
class HeaderTokenAuth(APITokenAuthCheck, APIKeyHeader):
|
||||
class HeaderTokenAuth(APIKeyHeader):
|
||||
"""Allow authenticating by passing X-API-Key=xyz as a request header"""
|
||||
param_name = "X-ArchiveBox-API-Key"
|
||||
|
||||
class BearerTokenAuth(APITokenAuthCheck, HttpBearer):
|
||||
"""Allow authenticating by passing Bearer=xyz as a request header"""
|
||||
pass
|
||||
def authenticate(self, request: HttpRequest, key: Optional[str]) -> User | None:
|
||||
return _require_superuser(auth_using_token(token=key, request=request), request, self.__class__.__name__)
|
||||
|
||||
class QueryParamTokenAuth(APITokenAuthCheck, APIKeyQuery):
|
||||
class BearerTokenAuth(HttpBearer):
|
||||
"""Allow authenticating by passing Bearer=xyz as a request header"""
|
||||
|
||||
def authenticate(self, request: HttpRequest, token: str) -> User | None:
|
||||
return _require_superuser(auth_using_token(token=token, request=request), request, self.__class__.__name__)
|
||||
|
||||
class QueryParamTokenAuth(APIKeyQuery):
|
||||
"""Allow authenticating by passing api_key=xyz as a GET/POST query parameter"""
|
||||
param_name = "api_key"
|
||||
|
||||
class UsernameAndPasswordAuth(UserPassAuthCheck, HttpBasicAuth):
|
||||
def authenticate(self, request: HttpRequest, key: Optional[str]) -> User | None:
|
||||
return _require_superuser(auth_using_token(token=key, request=request), request, self.__class__.__name__)
|
||||
|
||||
class UsernameAndPasswordAuth(HttpBasicAuth):
|
||||
"""Allow authenticating by passing username & password via HTTP Basic Authentication (not recommended)"""
|
||||
pass
|
||||
|
||||
def authenticate(self, request: HttpRequest, username: str, password: str) -> User | None:
|
||||
return _require_superuser(
|
||||
auth_using_password(username=username, password=password, request=request),
|
||||
request,
|
||||
self.__class__.__name__,
|
||||
)
|
||||
|
||||
class DjangoSessionAuth:
|
||||
"""Allow authenticating with existing Django session cookies (same-origin only)."""
|
||||
def __call__(self, request: HttpRequest) -> Optional[AbstractBaseUser]:
|
||||
def __call__(self, request: HttpRequest) -> User | None:
|
||||
return self.authenticate(request)
|
||||
|
||||
def authenticate(self, request: HttpRequest, **kwargs) -> Optional[AbstractBaseUser]:
|
||||
def authenticate(self, request: HttpRequest, **kwargs) -> User | None:
|
||||
user = getattr(request, 'user', None)
|
||||
if user and user.is_authenticated:
|
||||
request._api_auth_method = self.__class__.__name__
|
||||
if isinstance(user, User) and user.is_authenticated:
|
||||
setattr(request, '_api_auth_method', self.__class__.__name__)
|
||||
if not user.is_superuser:
|
||||
raise HttpError(403, 'Valid session but User does not have permission (make sure user.is_superuser=True)')
|
||||
return cast(AbstractBaseUser, user)
|
||||
return user
|
||||
return None
|
||||
|
||||
### Enabled Auth Methods
|
||||
|
||||
@@ -7,6 +7,7 @@ from contextlib import redirect_stdout, redirect_stderr
|
||||
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
from django.core.exceptions import ObjectDoesNotExist, EmptyResultSet, PermissionDenied
|
||||
from django.contrib.auth.models import User
|
||||
|
||||
from ninja import NinjaAPI, Swagger
|
||||
|
||||
@@ -16,6 +17,7 @@ from archivebox.config import VERSION
|
||||
from archivebox.config.version import get_COMMIT_HASH
|
||||
|
||||
from archivebox.api.auth import API_AUTH_METHODS
|
||||
from archivebox.api.models import APIToken
|
||||
|
||||
|
||||
COMMIT_HASH = get_COMMIT_HASH() or 'unknown'
|
||||
@@ -51,8 +53,8 @@ class NinjaAPIWithIOCapture(NinjaAPI):
|
||||
|
||||
with redirect_stderr(stderr):
|
||||
with redirect_stdout(stdout):
|
||||
request.stdout = stdout
|
||||
request.stderr = stderr
|
||||
setattr(request, 'stdout', stdout)
|
||||
setattr(request, 'stderr', stderr)
|
||||
|
||||
response = super().create_temporal_response(request)
|
||||
|
||||
@@ -60,19 +62,20 @@ class NinjaAPIWithIOCapture(NinjaAPI):
|
||||
response['Cache-Control'] = 'no-store'
|
||||
|
||||
# Add debug stdout and stderr headers to response
|
||||
response['X-ArchiveBox-Stdout'] = str(request.stdout)[200:]
|
||||
response['X-ArchiveBox-Stderr'] = str(request.stderr)[200:]
|
||||
response['X-ArchiveBox-Stdout'] = stdout.getvalue().replace('\n', '\\n')[:200]
|
||||
response['X-ArchiveBox-Stderr'] = stderr.getvalue().replace('\n', '\\n')[:200]
|
||||
# response['X-ArchiveBox-View'] = self.get_openapi_operation_id(request) or 'Unknown'
|
||||
|
||||
# Add Auth Headers to response
|
||||
api_token = getattr(request, '_api_token', None)
|
||||
api_token_attr = getattr(request, '_api_token', None)
|
||||
api_token = api_token_attr if isinstance(api_token_attr, APIToken) else None
|
||||
token_expiry = api_token.expires.isoformat() if api_token and api_token.expires else 'Never'
|
||||
|
||||
response['X-ArchiveBox-Auth-Method'] = getattr(request, '_api_auth_method', None) or 'None'
|
||||
response['X-ArchiveBox-Auth-Method'] = str(getattr(request, '_api_auth_method', 'None'))
|
||||
response['X-ArchiveBox-Auth-Expires'] = token_expiry
|
||||
response['X-ArchiveBox-Auth-Token-Id'] = str(api_token.id) if api_token else 'None'
|
||||
response['X-ArchiveBox-Auth-User-Id'] = request.user.pk if request.user.pk else 'None'
|
||||
response['X-ArchiveBox-Auth-User-Username'] = request.user.username if request.user.pk else 'None'
|
||||
response['X-ArchiveBox-Auth-User-Id'] = str(request.user.pk) if getattr(request.user, 'pk', None) else 'None'
|
||||
response['X-ArchiveBox-Auth-User-Username'] = request.user.username if isinstance(request.user, User) else 'None'
|
||||
|
||||
# import ipdb; ipdb.set_trace()
|
||||
# print('RESPONDING NOW', response)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
__package__ = 'archivebox.api'
|
||||
|
||||
from typing import Optional
|
||||
from django.http import HttpRequest
|
||||
|
||||
from ninja import Router, Schema
|
||||
|
||||
@@ -17,7 +18,7 @@ class PasswordAuthSchema(Schema):
|
||||
|
||||
|
||||
@router.post("/get_api_token", auth=None, summary='Generate an API token for a given username & password (or currently logged-in user)') # auth=None because they are not authed yet
|
||||
def get_api_token(request, auth_data: PasswordAuthSchema):
|
||||
def get_api_token(request: HttpRequest, auth_data: PasswordAuthSchema):
|
||||
user = auth_using_password(
|
||||
username=auth_data.username,
|
||||
password=auth_data.password,
|
||||
@@ -45,7 +46,7 @@ class TokenAuthSchema(Schema):
|
||||
|
||||
|
||||
@router.post("/check_api_token", auth=None, summary='Validate an API token to make sure its valid and non-expired') # auth=None because they are not authed yet
|
||||
def check_api_token(request, token_data: TokenAuthSchema):
|
||||
def check_api_token(request: HttpRequest, token_data: TokenAuthSchema):
|
||||
user = auth_using_token(
|
||||
token=token_data.token,
|
||||
request=request,
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
__package__ = 'archivebox.api'
|
||||
|
||||
import json
|
||||
from io import StringIO
|
||||
from typing import List, Dict, Any, Optional
|
||||
from enum import Enum
|
||||
|
||||
from django.http import HttpRequest
|
||||
|
||||
from ninja import Router, Schema
|
||||
|
||||
from archivebox.misc.util import ansi_to_html
|
||||
@@ -60,16 +63,13 @@ class AddCommandSchema(Schema):
|
||||
index_only: bool = False
|
||||
|
||||
class UpdateCommandSchema(Schema):
|
||||
resume: Optional[float] = 0
|
||||
only_new: bool = ARCHIVING_CONFIG.ONLY_NEW
|
||||
index_only: bool = False
|
||||
overwrite: bool = False
|
||||
resume: Optional[str] = None
|
||||
after: Optional[float] = 0
|
||||
before: Optional[float] = 999999999999999
|
||||
status: Optional[StatusChoices] = StatusChoices.unarchived
|
||||
filter_type: Optional[str] = FilterTypeChoices.substring
|
||||
filter_patterns: Optional[List[str]] = ['https://example.com']
|
||||
plugins: Optional[str] = ""
|
||||
batch_size: int = 100
|
||||
continuous: bool = False
|
||||
|
||||
class ScheduleCommandSchema(Schema):
|
||||
import_path: Optional[str] = None
|
||||
@@ -109,7 +109,7 @@ class RemoveCommandSchema(Schema):
|
||||
|
||||
|
||||
@router.post("/add", response=CLICommandResponseSchema, summary='archivebox add [args] [urls]')
|
||||
def cli_add(request, args: AddCommandSchema):
|
||||
def cli_add(request: HttpRequest, args: AddCommandSchema):
|
||||
from archivebox.cli.archivebox_add import add
|
||||
|
||||
result = add(
|
||||
@@ -132,44 +132,45 @@ def cli_add(request, args: AddCommandSchema):
|
||||
"snapshot_ids": snapshot_ids,
|
||||
"queued_urls": args.urls,
|
||||
}
|
||||
stdout = getattr(request, 'stdout', None)
|
||||
stderr = getattr(request, 'stderr', None)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"errors": [],
|
||||
"result": result_payload,
|
||||
"result_format": "json",
|
||||
"stdout": ansi_to_html(request.stdout.getvalue().strip()),
|
||||
"stderr": ansi_to_html(request.stderr.getvalue().strip()),
|
||||
"stdout": ansi_to_html(stdout.getvalue().strip()) if isinstance(stdout, StringIO) else '',
|
||||
"stderr": ansi_to_html(stderr.getvalue().strip()) if isinstance(stderr, StringIO) else '',
|
||||
}
|
||||
|
||||
|
||||
@router.post("/update", response=CLICommandResponseSchema, summary='archivebox update [args] [filter_patterns]')
|
||||
def cli_update(request, args: UpdateCommandSchema):
|
||||
def cli_update(request: HttpRequest, args: UpdateCommandSchema):
|
||||
from archivebox.cli.archivebox_update import update
|
||||
|
||||
result = update(
|
||||
resume=args.resume,
|
||||
only_new=args.only_new,
|
||||
index_only=args.index_only,
|
||||
overwrite=args.overwrite,
|
||||
before=args.before,
|
||||
filter_patterns=args.filter_patterns or [],
|
||||
filter_type=args.filter_type or FilterTypeChoices.substring,
|
||||
after=args.after,
|
||||
status=args.status,
|
||||
filter_type=args.filter_type,
|
||||
filter_patterns=args.filter_patterns,
|
||||
plugins=args.plugins,
|
||||
before=args.before,
|
||||
resume=args.resume,
|
||||
batch_size=args.batch_size,
|
||||
continuous=args.continuous,
|
||||
)
|
||||
stdout = getattr(request, 'stdout', None)
|
||||
stderr = getattr(request, 'stderr', None)
|
||||
return {
|
||||
"success": True,
|
||||
"errors": [],
|
||||
"result": result,
|
||||
"stdout": ansi_to_html(request.stdout.getvalue().strip()),
|
||||
"stderr": ansi_to_html(request.stderr.getvalue().strip()),
|
||||
"stdout": ansi_to_html(stdout.getvalue().strip()) if isinstance(stdout, StringIO) else '',
|
||||
"stderr": ansi_to_html(stderr.getvalue().strip()) if isinstance(stderr, StringIO) else '',
|
||||
}
|
||||
|
||||
|
||||
@router.post("/schedule", response=CLICommandResponseSchema, summary='archivebox schedule [args] [import_path]')
|
||||
def cli_schedule(request, args: ScheduleCommandSchema):
|
||||
def cli_schedule(request: HttpRequest, args: ScheduleCommandSchema):
|
||||
from archivebox.cli.archivebox_schedule import schedule
|
||||
|
||||
result = schedule(
|
||||
@@ -187,19 +188,21 @@ def cli_schedule(request, args: ScheduleCommandSchema):
|
||||
update=args.update,
|
||||
)
|
||||
|
||||
stdout = getattr(request, 'stdout', None)
|
||||
stderr = getattr(request, 'stderr', None)
|
||||
return {
|
||||
"success": True,
|
||||
"errors": [],
|
||||
"result": result,
|
||||
"result_format": "json",
|
||||
"stdout": ansi_to_html(request.stdout.getvalue().strip()),
|
||||
"stderr": ansi_to_html(request.stderr.getvalue().strip()),
|
||||
"stdout": ansi_to_html(stdout.getvalue().strip()) if isinstance(stdout, StringIO) else '',
|
||||
"stderr": ansi_to_html(stderr.getvalue().strip()) if isinstance(stderr, StringIO) else '',
|
||||
}
|
||||
|
||||
|
||||
|
||||
@router.post("/search", response=CLICommandResponseSchema, summary='archivebox search [args] [filter_patterns]')
|
||||
def cli_search(request, args: ListCommandSchema):
|
||||
def cli_search(request: HttpRequest, args: ListCommandSchema):
|
||||
from archivebox.cli.archivebox_search import search
|
||||
|
||||
result = search(
|
||||
@@ -224,25 +227,28 @@ def cli_search(request, args: ListCommandSchema):
|
||||
elif args.as_csv:
|
||||
result_format = "csv"
|
||||
|
||||
stdout = getattr(request, 'stdout', None)
|
||||
stderr = getattr(request, 'stderr', None)
|
||||
return {
|
||||
"success": True,
|
||||
"errors": [],
|
||||
"result": result,
|
||||
"result_format": result_format,
|
||||
"stdout": ansi_to_html(request.stdout.getvalue().strip()),
|
||||
"stderr": ansi_to_html(request.stderr.getvalue().strip()),
|
||||
"stdout": ansi_to_html(stdout.getvalue().strip()) if isinstance(stdout, StringIO) else '',
|
||||
"stderr": ansi_to_html(stderr.getvalue().strip()) if isinstance(stderr, StringIO) else '',
|
||||
}
|
||||
|
||||
|
||||
|
||||
@router.post("/remove", response=CLICommandResponseSchema, summary='archivebox remove [args] [filter_patterns]')
|
||||
def cli_remove(request, args: RemoveCommandSchema):
|
||||
def cli_remove(request: HttpRequest, args: RemoveCommandSchema):
|
||||
from archivebox.cli.archivebox_remove import remove
|
||||
from archivebox.cli.archivebox_search import get_snapshots
|
||||
from archivebox.core.models import Snapshot
|
||||
|
||||
filter_patterns = args.filter_patterns or []
|
||||
snapshots_to_remove = get_snapshots(
|
||||
filter_patterns=args.filter_patterns,
|
||||
filter_patterns=filter_patterns,
|
||||
filter_type=args.filter_type,
|
||||
after=args.after,
|
||||
before=args.before,
|
||||
@@ -256,7 +262,7 @@ def cli_remove(request, args: RemoveCommandSchema):
|
||||
before=args.before,
|
||||
after=args.after,
|
||||
filter_type=args.filter_type,
|
||||
filter_patterns=args.filter_patterns,
|
||||
filter_patterns=filter_patterns,
|
||||
)
|
||||
|
||||
result = {
|
||||
@@ -264,12 +270,14 @@ def cli_remove(request, args: RemoveCommandSchema):
|
||||
"removed_snapshot_ids": removed_snapshot_ids,
|
||||
"remaining_snapshots": Snapshot.objects.count(),
|
||||
}
|
||||
stdout = getattr(request, 'stdout', None)
|
||||
stderr = getattr(request, 'stderr', None)
|
||||
return {
|
||||
"success": True,
|
||||
"errors": [],
|
||||
"result": result,
|
||||
"result_format": "json",
|
||||
"stdout": ansi_to_html(request.stdout.getvalue().strip()),
|
||||
"stderr": ansi_to_html(request.stderr.getvalue().strip()),
|
||||
"stdout": ansi_to_html(stdout.getvalue().strip()) if isinstance(stdout, StringIO) else '',
|
||||
"stderr": ansi_to_html(stderr.getvalue().strip()) if isinstance(stderr, StringIO) else '',
|
||||
}
|
||||
|
||||
|
||||
@@ -2,16 +2,18 @@ __package__ = 'archivebox.api'
|
||||
|
||||
import math
|
||||
from uuid import UUID
|
||||
from typing import List, Optional, Union, Any
|
||||
from typing import List, Optional, Union, Any, Annotated
|
||||
from datetime import datetime
|
||||
|
||||
from django.db.models import Q
|
||||
from django.db.models import Model, Q
|
||||
from django.http import HttpRequest
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.auth.models import User
|
||||
from django.shortcuts import redirect
|
||||
from django.utils import timezone
|
||||
|
||||
from ninja import Router, Schema, FilterSchema, Field, Query
|
||||
from ninja import Router, Schema, FilterLookup, FilterSchema, Query
|
||||
from ninja.pagination import paginate, PaginationBase
|
||||
from ninja.errors import HttpError
|
||||
|
||||
@@ -24,12 +26,12 @@ router = Router(tags=['Core Models'])
|
||||
|
||||
|
||||
class CustomPagination(PaginationBase):
|
||||
class Input(Schema):
|
||||
class Input(PaginationBase.Input):
|
||||
limit: int = 200
|
||||
offset: int = 0
|
||||
page: int = 0
|
||||
|
||||
class Output(Schema):
|
||||
class Output(PaginationBase.Output):
|
||||
total_items: int
|
||||
total_pages: int
|
||||
page: int
|
||||
@@ -38,7 +40,7 @@ class CustomPagination(PaginationBase):
|
||||
num_items: int
|
||||
items: List[Any]
|
||||
|
||||
def paginate_queryset(self, queryset, pagination: Input, **params):
|
||||
def paginate_queryset(self, queryset, pagination: Input, request: HttpRequest, **params):
|
||||
limit = min(pagination.limit, 500)
|
||||
offset = pagination.offset or (pagination.page * limit)
|
||||
total = queryset.count()
|
||||
@@ -115,33 +117,33 @@ class ArchiveResultSchema(MinimalArchiveResultSchema):
|
||||
|
||||
|
||||
class ArchiveResultFilterSchema(FilterSchema):
|
||||
id: Optional[str] = Field(None, q=['id__startswith', 'snapshot__id__startswith', 'snapshot__timestamp__startswith'])
|
||||
search: Optional[str] = Field(None, q=['snapshot__url__icontains', 'snapshot__title__icontains', 'snapshot__tags__name__icontains', 'plugin', 'output_str__icontains', 'id__startswith', 'snapshot__id__startswith', 'snapshot__timestamp__startswith'])
|
||||
snapshot_id: Optional[str] = Field(None, q=['snapshot__id__startswith', 'snapshot__timestamp__startswith'])
|
||||
snapshot_url: Optional[str] = Field(None, q='snapshot__url__icontains')
|
||||
snapshot_tag: Optional[str] = Field(None, q='snapshot__tags__name__icontains')
|
||||
status: Optional[str] = Field(None, q='status')
|
||||
output_str: Optional[str] = Field(None, q='output_str__icontains')
|
||||
plugin: Optional[str] = Field(None, q='plugin__icontains')
|
||||
hook_name: Optional[str] = Field(None, q='hook_name__icontains')
|
||||
process_id: Optional[str] = Field(None, q='process__id__startswith')
|
||||
cmd: Optional[str] = Field(None, q='cmd__0__icontains')
|
||||
pwd: Optional[str] = Field(None, q='pwd__icontains')
|
||||
cmd_version: Optional[str] = Field(None, q='cmd_version')
|
||||
created_at: Optional[datetime] = Field(None, q='created_at')
|
||||
created_at__gte: Optional[datetime] = Field(None, q='created_at__gte')
|
||||
created_at__lt: Optional[datetime] = Field(None, q='created_at__lt')
|
||||
id: Annotated[Optional[str], FilterLookup(['id__startswith', 'snapshot__id__startswith', 'snapshot__timestamp__startswith'])] = None
|
||||
search: Annotated[Optional[str], FilterLookup(['snapshot__url__icontains', 'snapshot__title__icontains', 'snapshot__tags__name__icontains', 'plugin', 'output_str__icontains', 'id__startswith', 'snapshot__id__startswith', 'snapshot__timestamp__startswith'])] = None
|
||||
snapshot_id: Annotated[Optional[str], FilterLookup(['snapshot__id__startswith', 'snapshot__timestamp__startswith'])] = None
|
||||
snapshot_url: Annotated[Optional[str], FilterLookup('snapshot__url__icontains')] = None
|
||||
snapshot_tag: Annotated[Optional[str], FilterLookup('snapshot__tags__name__icontains')] = None
|
||||
status: Annotated[Optional[str], FilterLookup('status')] = None
|
||||
output_str: Annotated[Optional[str], FilterLookup('output_str__icontains')] = None
|
||||
plugin: Annotated[Optional[str], FilterLookup('plugin__icontains')] = None
|
||||
hook_name: Annotated[Optional[str], FilterLookup('hook_name__icontains')] = None
|
||||
process_id: Annotated[Optional[str], FilterLookup('process__id__startswith')] = None
|
||||
cmd: Annotated[Optional[str], FilterLookup('cmd__0__icontains')] = None
|
||||
pwd: Annotated[Optional[str], FilterLookup('pwd__icontains')] = None
|
||||
cmd_version: Annotated[Optional[str], FilterLookup('cmd_version')] = None
|
||||
created_at: Annotated[Optional[datetime], FilterLookup('created_at')] = None
|
||||
created_at__gte: Annotated[Optional[datetime], FilterLookup('created_at__gte')] = None
|
||||
created_at__lt: Annotated[Optional[datetime], FilterLookup('created_at__lt')] = None
|
||||
|
||||
|
||||
@router.get("/archiveresults", response=List[ArchiveResultSchema], url_name="get_archiveresult")
|
||||
@paginate(CustomPagination)
|
||||
def get_archiveresults(request, filters: ArchiveResultFilterSchema = Query(...)):
|
||||
def get_archiveresults(request: HttpRequest, filters: Query[ArchiveResultFilterSchema]):
|
||||
"""List all ArchiveResult entries matching these filters."""
|
||||
return filters.filter(ArchiveResult.objects.all()).distinct()
|
||||
|
||||
|
||||
@router.get("/archiveresult/{archiveresult_id}", response=ArchiveResultSchema, url_name="get_archiveresult")
|
||||
def get_archiveresult(request, archiveresult_id: str):
|
||||
def get_archiveresult(request: HttpRequest, archiveresult_id: str):
|
||||
"""Get a specific ArchiveResult by id."""
|
||||
return ArchiveResult.objects.get(Q(id__icontains=archiveresult_id))
|
||||
|
||||
@@ -185,7 +187,7 @@ class SnapshotSchema(Schema):
|
||||
|
||||
@staticmethod
|
||||
def resolve_archiveresults(obj, context):
|
||||
if context['request'].with_archiveresults:
|
||||
if bool(getattr(context['request'], 'with_archiveresults', False)):
|
||||
return obj.archiveresult_set.all().distinct()
|
||||
return ArchiveResult.objects.none()
|
||||
|
||||
@@ -217,36 +219,36 @@ def normalize_tag_list(tags: Optional[List[str]] = None) -> List[str]:
|
||||
|
||||
|
||||
class SnapshotFilterSchema(FilterSchema):
|
||||
id: Optional[str] = Field(None, q=['id__icontains', 'timestamp__startswith'])
|
||||
created_by_id: str = Field(None, q='crawl__created_by_id')
|
||||
created_by_username: str = Field(None, q='crawl__created_by__username__icontains')
|
||||
created_at__gte: datetime = Field(None, q='created_at__gte')
|
||||
created_at__lt: datetime = Field(None, q='created_at__lt')
|
||||
created_at: datetime = Field(None, q='created_at')
|
||||
modified_at: datetime = Field(None, q='modified_at')
|
||||
modified_at__gte: datetime = Field(None, q='modified_at__gte')
|
||||
modified_at__lt: datetime = Field(None, q='modified_at__lt')
|
||||
search: Optional[str] = Field(None, q=['url__icontains', 'title__icontains', 'tags__name__icontains', 'id__icontains', 'timestamp__startswith'])
|
||||
url: Optional[str] = Field(None, q='url')
|
||||
tag: Optional[str] = Field(None, q='tags__name')
|
||||
title: Optional[str] = Field(None, q='title__icontains')
|
||||
timestamp: Optional[str] = Field(None, q='timestamp__startswith')
|
||||
bookmarked_at__gte: Optional[datetime] = Field(None, q='bookmarked_at__gte')
|
||||
bookmarked_at__lt: Optional[datetime] = Field(None, q='bookmarked_at__lt')
|
||||
id: Annotated[Optional[str], FilterLookup(['id__icontains', 'timestamp__startswith'])] = None
|
||||
created_by_id: Annotated[Optional[str], FilterLookup('crawl__created_by_id')] = None
|
||||
created_by_username: Annotated[Optional[str], FilterLookup('crawl__created_by__username__icontains')] = None
|
||||
created_at__gte: Annotated[Optional[datetime], FilterLookup('created_at__gte')] = None
|
||||
created_at__lt: Annotated[Optional[datetime], FilterLookup('created_at__lt')] = None
|
||||
created_at: Annotated[Optional[datetime], FilterLookup('created_at')] = None
|
||||
modified_at: Annotated[Optional[datetime], FilterLookup('modified_at')] = None
|
||||
modified_at__gte: Annotated[Optional[datetime], FilterLookup('modified_at__gte')] = None
|
||||
modified_at__lt: Annotated[Optional[datetime], FilterLookup('modified_at__lt')] = None
|
||||
search: Annotated[Optional[str], FilterLookup(['url__icontains', 'title__icontains', 'tags__name__icontains', 'id__icontains', 'timestamp__startswith'])] = None
|
||||
url: Annotated[Optional[str], FilterLookup('url')] = None
|
||||
tag: Annotated[Optional[str], FilterLookup('tags__name')] = None
|
||||
title: Annotated[Optional[str], FilterLookup('title__icontains')] = None
|
||||
timestamp: Annotated[Optional[str], FilterLookup('timestamp__startswith')] = None
|
||||
bookmarked_at__gte: Annotated[Optional[datetime], FilterLookup('bookmarked_at__gte')] = None
|
||||
bookmarked_at__lt: Annotated[Optional[datetime], FilterLookup('bookmarked_at__lt')] = None
|
||||
|
||||
|
||||
@router.get("/snapshots", response=List[SnapshotSchema], url_name="get_snapshots")
|
||||
@paginate(CustomPagination)
|
||||
def get_snapshots(request, filters: SnapshotFilterSchema = Query(...), with_archiveresults: bool = False):
|
||||
def get_snapshots(request: HttpRequest, filters: Query[SnapshotFilterSchema], with_archiveresults: bool = False):
|
||||
"""List all Snapshot entries matching these filters."""
|
||||
request.with_archiveresults = with_archiveresults
|
||||
setattr(request, 'with_archiveresults', with_archiveresults)
|
||||
return filters.filter(Snapshot.objects.all()).distinct()
|
||||
|
||||
|
||||
@router.get("/snapshot/{snapshot_id}", response=SnapshotSchema, url_name="get_snapshot")
|
||||
def get_snapshot(request, snapshot_id: str, with_archiveresults: bool = True):
|
||||
def get_snapshot(request: HttpRequest, snapshot_id: str, with_archiveresults: bool = True):
|
||||
"""Get a specific Snapshot by id."""
|
||||
request.with_archiveresults = with_archiveresults
|
||||
setattr(request, 'with_archiveresults', with_archiveresults)
|
||||
try:
|
||||
return Snapshot.objects.get(Q(id__startswith=snapshot_id) | Q(timestamp__startswith=snapshot_id))
|
||||
except Snapshot.DoesNotExist:
|
||||
@@ -254,7 +256,7 @@ def get_snapshot(request, snapshot_id: str, with_archiveresults: bool = True):
|
||||
|
||||
|
||||
@router.post("/snapshots", response=SnapshotSchema, url_name="create_snapshot")
|
||||
def create_snapshot(request, data: SnapshotCreateSchema):
|
||||
def create_snapshot(request: HttpRequest, data: SnapshotCreateSchema):
|
||||
tags = normalize_tag_list(data.tags)
|
||||
if data.status is not None and data.status not in Snapshot.StatusChoices.values:
|
||||
raise HttpError(400, f'Invalid status: {data.status}')
|
||||
@@ -274,7 +276,7 @@ def create_snapshot(request, data: SnapshotCreateSchema):
|
||||
tags_str=','.join(tags),
|
||||
status=Crawl.StatusChoices.QUEUED,
|
||||
retry_at=timezone.now(),
|
||||
created_by=request.user,
|
||||
created_by=request.user if isinstance(request.user, User) else None,
|
||||
)
|
||||
|
||||
snapshot_defaults = {
|
||||
@@ -311,12 +313,12 @@ def create_snapshot(request, data: SnapshotCreateSchema):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
request.with_archiveresults = False
|
||||
setattr(request, 'with_archiveresults', False)
|
||||
return snapshot
|
||||
|
||||
|
||||
@router.patch("/snapshot/{snapshot_id}", response=SnapshotSchema, url_name="patch_snapshot")
|
||||
def patch_snapshot(request, snapshot_id: str, data: SnapshotUpdateSchema):
|
||||
def patch_snapshot(request: HttpRequest, snapshot_id: str, data: SnapshotUpdateSchema):
|
||||
"""Update a snapshot (e.g., set status=sealed to cancel queued work)."""
|
||||
try:
|
||||
snapshot = Snapshot.objects.get(Q(id__startswith=snapshot_id) | Q(timestamp__startswith=snapshot_id))
|
||||
@@ -343,15 +345,15 @@ def patch_snapshot(request, snapshot_id: str, data: SnapshotUpdateSchema):
|
||||
snapshot.save_tags(normalize_tag_list(tags))
|
||||
|
||||
snapshot.save(update_fields=update_fields)
|
||||
request.with_archiveresults = False
|
||||
setattr(request, 'with_archiveresults', False)
|
||||
return snapshot
|
||||
|
||||
|
||||
@router.delete("/snapshot/{snapshot_id}", response=SnapshotDeleteResponseSchema, url_name="delete_snapshot")
|
||||
def delete_snapshot(request, snapshot_id: str):
|
||||
def delete_snapshot(request: HttpRequest, snapshot_id: str):
|
||||
snapshot = get_snapshot(request, snapshot_id, with_archiveresults=False)
|
||||
snapshot_id_str = str(snapshot.id)
|
||||
crawl_id_str = str(snapshot.crawl_id)
|
||||
crawl_id_str = str(snapshot.crawl.pk)
|
||||
deleted_count, _ = snapshot.delete()
|
||||
return {
|
||||
'success': True,
|
||||
@@ -381,8 +383,10 @@ class TagSchema(Schema):
|
||||
|
||||
@staticmethod
|
||||
def resolve_created_by_username(obj):
|
||||
User = get_user_model()
|
||||
return User.objects.get(id=obj.created_by_id).username
|
||||
user_model = get_user_model()
|
||||
user = user_model.objects.get(id=obj.created_by_id)
|
||||
username = getattr(user, 'username', None)
|
||||
return username if isinstance(username, str) else str(user)
|
||||
|
||||
@staticmethod
|
||||
def resolve_num_snapshots(obj, context):
|
||||
@@ -390,23 +394,23 @@ class TagSchema(Schema):
|
||||
|
||||
@staticmethod
|
||||
def resolve_snapshots(obj, context):
|
||||
if context['request'].with_snapshots:
|
||||
if bool(getattr(context['request'], 'with_snapshots', False)):
|
||||
return obj.snapshot_set.all().distinct()
|
||||
return Snapshot.objects.none()
|
||||
|
||||
|
||||
@router.get("/tags", response=List[TagSchema], url_name="get_tags")
|
||||
@paginate(CustomPagination)
|
||||
def get_tags(request):
|
||||
request.with_snapshots = False
|
||||
request.with_archiveresults = False
|
||||
def get_tags(request: HttpRequest):
|
||||
setattr(request, 'with_snapshots', False)
|
||||
setattr(request, 'with_archiveresults', False)
|
||||
return Tag.objects.all().distinct()
|
||||
|
||||
|
||||
@router.get("/tag/{tag_id}", response=TagSchema, url_name="get_tag")
|
||||
def get_tag(request, tag_id: str, with_snapshots: bool = True):
|
||||
request.with_snapshots = with_snapshots
|
||||
request.with_archiveresults = False
|
||||
def get_tag(request: HttpRequest, tag_id: str, with_snapshots: bool = True):
|
||||
setattr(request, 'with_snapshots', with_snapshots)
|
||||
setattr(request, 'with_archiveresults', False)
|
||||
try:
|
||||
return Tag.objects.get(id__icontains=tag_id)
|
||||
except (Tag.DoesNotExist, ValidationError):
|
||||
@@ -414,15 +418,15 @@ def get_tag(request, tag_id: str, with_snapshots: bool = True):
|
||||
|
||||
|
||||
@router.get("/any/{id}", response=Union[SnapshotSchema, ArchiveResultSchema, TagSchema, CrawlSchema], url_name="get_any", summary="Get any object by its ID")
|
||||
def get_any(request, id: str):
|
||||
def get_any(request: HttpRequest, id: str):
|
||||
"""Get any object by its ID (e.g. snapshot, archiveresult, tag, crawl, etc.)."""
|
||||
request.with_snapshots = False
|
||||
request.with_archiveresults = False
|
||||
setattr(request, 'with_snapshots', False)
|
||||
setattr(request, 'with_archiveresults', False)
|
||||
|
||||
for getter in [get_snapshot, get_archiveresult, get_tag]:
|
||||
try:
|
||||
response = getter(request, id)
|
||||
if response:
|
||||
if isinstance(response, Model):
|
||||
return redirect(f"/api/v1/{response._meta.app_label}/{response._meta.model_name}/{response.id}?{request.META['QUERY_STRING']}")
|
||||
except Exception:
|
||||
pass
|
||||
@@ -430,7 +434,7 @@ def get_any(request, id: str):
|
||||
try:
|
||||
from archivebox.api.v1_crawls import get_crawl
|
||||
response = get_crawl(request, id)
|
||||
if response:
|
||||
if isinstance(response, Model):
|
||||
return redirect(f"/api/v1/{response._meta.app_label}/{response._meta.model_name}/{response.id}?{request.META['QUERY_STRING']}")
|
||||
except Exception:
|
||||
pass
|
||||
@@ -468,7 +472,7 @@ class TagSnapshotResponseSchema(Schema):
|
||||
|
||||
|
||||
@router.get("/tags/autocomplete/", response=TagAutocompleteSchema, url_name="tags_autocomplete")
|
||||
def tags_autocomplete(request, q: str = ""):
|
||||
def tags_autocomplete(request: HttpRequest, q: str = ""):
|
||||
"""Return tags matching the query for autocomplete."""
|
||||
if not q:
|
||||
# Return all tags if no query (limited to 50)
|
||||
@@ -482,7 +486,7 @@ def tags_autocomplete(request, q: str = ""):
|
||||
|
||||
|
||||
@router.post("/tags/create/", response=TagCreateResponseSchema, url_name="tags_create")
|
||||
def tags_create(request, data: TagCreateSchema):
|
||||
def tags_create(request: HttpRequest, data: TagCreateSchema):
|
||||
"""Create a new tag or return existing one."""
|
||||
name = data.name.strip()
|
||||
if not name:
|
||||
@@ -498,7 +502,10 @@ def tags_create(request, data: TagCreateSchema):
|
||||
|
||||
# If found by case-insensitive match, use that tag
|
||||
if not created:
|
||||
tag = Tag.objects.filter(name__iexact=name).first()
|
||||
existing_tag = Tag.objects.filter(name__iexact=name).first()
|
||||
if existing_tag is None:
|
||||
raise HttpError(500, 'Failed to load existing tag after get_or_create')
|
||||
tag = existing_tag
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
@@ -509,7 +516,7 @@ def tags_create(request, data: TagCreateSchema):
|
||||
|
||||
|
||||
@router.post("/tags/add-to-snapshot/", response=TagSnapshotResponseSchema, url_name="tags_add_to_snapshot")
|
||||
def tags_add_to_snapshot(request, data: TagSnapshotRequestSchema):
|
||||
def tags_add_to_snapshot(request: HttpRequest, data: TagSnapshotRequestSchema):
|
||||
"""Add a tag to a snapshot. Creates the tag if it doesn't exist."""
|
||||
# Get the snapshot
|
||||
try:
|
||||
@@ -522,6 +529,8 @@ def tags_add_to_snapshot(request, data: TagSnapshotRequestSchema):
|
||||
snapshot = Snapshot.objects.filter(
|
||||
Q(id__startswith=data.snapshot_id) | Q(timestamp__startswith=data.snapshot_id)
|
||||
).first()
|
||||
if snapshot is None:
|
||||
raise HttpError(404, 'Snapshot not found')
|
||||
|
||||
# Get or create the tag
|
||||
if data.tag_name:
|
||||
@@ -537,7 +546,9 @@ def tags_add_to_snapshot(request, data: TagSnapshotRequestSchema):
|
||||
}
|
||||
)
|
||||
# If found by case-insensitive match, use that tag
|
||||
tag = Tag.objects.filter(name__iexact=name).first() or tag
|
||||
existing_tag = Tag.objects.filter(name__iexact=name).first()
|
||||
if existing_tag is not None:
|
||||
tag = existing_tag
|
||||
elif data.tag_id:
|
||||
try:
|
||||
tag = Tag.objects.get(pk=data.tag_id)
|
||||
@@ -557,7 +568,7 @@ def tags_add_to_snapshot(request, data: TagSnapshotRequestSchema):
|
||||
|
||||
|
||||
@router.post("/tags/remove-from-snapshot/", response=TagSnapshotResponseSchema, url_name="tags_remove_from_snapshot")
|
||||
def tags_remove_from_snapshot(request, data: TagSnapshotRequestSchema):
|
||||
def tags_remove_from_snapshot(request: HttpRequest, data: TagSnapshotRequestSchema):
|
||||
"""Remove a tag from a snapshot."""
|
||||
# Get the snapshot
|
||||
try:
|
||||
@@ -570,6 +581,8 @@ def tags_remove_from_snapshot(request, data: TagSnapshotRequestSchema):
|
||||
snapshot = Snapshot.objects.filter(
|
||||
Q(id__startswith=data.snapshot_id) | Q(timestamp__startswith=data.snapshot_id)
|
||||
).first()
|
||||
if snapshot is None:
|
||||
raise HttpError(404, 'Snapshot not found')
|
||||
|
||||
# Get the tag
|
||||
if data.tag_id:
|
||||
|
||||
@@ -3,9 +3,11 @@ __package__ = 'archivebox.api'
|
||||
from uuid import UUID
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from django.http import HttpRequest
|
||||
from django.utils import timezone
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.auth.models import User
|
||||
|
||||
from ninja import Router, Schema
|
||||
from ninja.errors import HttpError
|
||||
@@ -44,12 +46,14 @@ class CrawlSchema(Schema):
|
||||
|
||||
@staticmethod
|
||||
def resolve_created_by_username(obj):
|
||||
User = get_user_model()
|
||||
return User.objects.get(id=obj.created_by_id).username
|
||||
user_model = get_user_model()
|
||||
user = user_model.objects.get(id=obj.created_by_id)
|
||||
username = getattr(user, 'username', None)
|
||||
return username if isinstance(username, str) else str(user)
|
||||
|
||||
@staticmethod
|
||||
def resolve_snapshots(obj, context):
|
||||
if context['request'].with_snapshots:
|
||||
if bool(getattr(context['request'], 'with_snapshots', False)):
|
||||
return obj.snapshot_set.all().distinct()
|
||||
return Snapshot.objects.none()
|
||||
|
||||
@@ -85,12 +89,12 @@ def normalize_tag_list(tags: Optional[List[str]] = None, tags_str: str = '') ->
|
||||
|
||||
|
||||
@router.get("/crawls", response=List[CrawlSchema], url_name="get_crawls")
|
||||
def get_crawls(request):
|
||||
def get_crawls(request: HttpRequest):
|
||||
return Crawl.objects.all().distinct()
|
||||
|
||||
|
||||
@router.post("/crawls", response=CrawlSchema, url_name="create_crawl")
|
||||
def create_crawl(request, data: CrawlCreateSchema):
|
||||
def create_crawl(request: HttpRequest, data: CrawlCreateSchema):
|
||||
urls = [url.strip() for url in data.urls if url and url.strip()]
|
||||
if not urls:
|
||||
raise HttpError(400, 'At least one URL is required')
|
||||
@@ -107,16 +111,16 @@ def create_crawl(request, data: CrawlCreateSchema):
|
||||
config=data.config,
|
||||
status=Crawl.StatusChoices.QUEUED,
|
||||
retry_at=timezone.now(),
|
||||
created_by=request.user,
|
||||
created_by=request.user if isinstance(request.user, User) else None,
|
||||
)
|
||||
crawl.create_snapshots_from_urls()
|
||||
return crawl
|
||||
|
||||
@router.get("/crawl/{crawl_id}", response=CrawlSchema | str, url_name="get_crawl")
|
||||
def get_crawl(request, crawl_id: str, as_rss: bool=False, with_snapshots: bool=False, with_archiveresults: bool=False):
|
||||
def get_crawl(request: HttpRequest, crawl_id: str, as_rss: bool=False, with_snapshots: bool=False, with_archiveresults: bool=False):
|
||||
"""Get a specific Crawl by id."""
|
||||
request.with_snapshots = with_snapshots
|
||||
request.with_archiveresults = with_archiveresults
|
||||
setattr(request, 'with_snapshots', with_snapshots)
|
||||
setattr(request, 'with_archiveresults', with_archiveresults)
|
||||
crawl = Crawl.objects.get(id__icontains=crawl_id)
|
||||
|
||||
if crawl and as_rss:
|
||||
@@ -135,7 +139,7 @@ def get_crawl(request, crawl_id: str, as_rss: bool=False, with_snapshots: bool=F
|
||||
|
||||
|
||||
@router.patch("/crawl/{crawl_id}", response=CrawlSchema, url_name="patch_crawl")
|
||||
def patch_crawl(request, crawl_id: str, data: CrawlUpdateSchema):
|
||||
def patch_crawl(request: HttpRequest, crawl_id: str, data: CrawlUpdateSchema):
|
||||
"""Update a crawl (e.g., set status=sealed to cancel queued work)."""
|
||||
crawl = Crawl.objects.get(id__icontains=crawl_id)
|
||||
payload = data.dict(exclude_unset=True)
|
||||
@@ -174,7 +178,7 @@ def patch_crawl(request, crawl_id: str, data: CrawlUpdateSchema):
|
||||
|
||||
|
||||
@router.delete("/crawl/{crawl_id}", response=CrawlDeleteResponseSchema, url_name="delete_crawl")
|
||||
def delete_crawl(request, crawl_id: str):
|
||||
def delete_crawl(request: HttpRequest, crawl_id: str):
|
||||
crawl = Crawl.objects.get(id__icontains=crawl_id)
|
||||
crawl_id_str = str(crawl.id)
|
||||
snapshot_count = crawl.snapshot_set.count()
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
__package__ = 'archivebox.api'
|
||||
|
||||
from uuid import UUID
|
||||
from typing import List, Optional
|
||||
from typing import Annotated, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from ninja import Router, Schema, FilterSchema, Field, Query
|
||||
from django.http import HttpRequest
|
||||
|
||||
from ninja import FilterLookup, FilterSchema, Query, Router, Schema
|
||||
from ninja.pagination import paginate
|
||||
|
||||
from archivebox.api.v1_core import CustomPagination
|
||||
@@ -41,16 +43,13 @@ class MachineSchema(Schema):
|
||||
|
||||
|
||||
class MachineFilterSchema(FilterSchema):
|
||||
id: Optional[str] = Field(None, q='id__startswith')
|
||||
hostname: Optional[str] = Field(None, q='hostname__icontains')
|
||||
os_platform: Optional[str] = Field(None, q='os_platform__icontains')
|
||||
os_arch: Optional[str] = Field(None, q='os_arch')
|
||||
hw_in_docker: Optional[bool] = Field(None, q='hw_in_docker')
|
||||
hw_in_vm: Optional[bool] = Field(None, q='hw_in_vm')
|
||||
|
||||
|
||||
# ============================================================================
|
||||
bin_providers: Optional[str] = Field(None, q='bin_providers__icontains')
|
||||
id: Annotated[Optional[str], FilterLookup('id__startswith')] = None
|
||||
hostname: Annotated[Optional[str], FilterLookup('hostname__icontains')] = None
|
||||
os_platform: Annotated[Optional[str], FilterLookup('os_platform__icontains')] = None
|
||||
os_arch: Annotated[Optional[str], FilterLookup('os_arch')] = None
|
||||
hw_in_docker: Annotated[Optional[bool], FilterLookup('hw_in_docker')] = None
|
||||
hw_in_vm: Annotated[Optional[bool], FilterLookup('hw_in_vm')] = None
|
||||
bin_providers: Annotated[Optional[str], FilterLookup('bin_providers__icontains')] = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
@@ -86,12 +85,12 @@ class BinarySchema(Schema):
|
||||
|
||||
|
||||
class BinaryFilterSchema(FilterSchema):
|
||||
id: Optional[str] = Field(None, q='id__startswith')
|
||||
name: Optional[str] = Field(None, q='name__icontains')
|
||||
binprovider: Optional[str] = Field(None, q='binprovider')
|
||||
status: Optional[str] = Field(None, q='status')
|
||||
machine_id: Optional[str] = Field(None, q='machine_id__startswith')
|
||||
version: Optional[str] = Field(None, q='version__icontains')
|
||||
id: Annotated[Optional[str], FilterLookup('id__startswith')] = None
|
||||
name: Annotated[Optional[str], FilterLookup('name__icontains')] = None
|
||||
binprovider: Annotated[Optional[str], FilterLookup('binprovider')] = None
|
||||
status: Annotated[Optional[str], FilterLookup('status')] = None
|
||||
machine_id: Annotated[Optional[str], FilterLookup('machine_id__startswith')] = None
|
||||
version: Annotated[Optional[str], FilterLookup('version__icontains')] = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
@@ -100,21 +99,21 @@ class BinaryFilterSchema(FilterSchema):
|
||||
|
||||
@router.get("/machines", response=List[MachineSchema], url_name="get_machines")
|
||||
@paginate(CustomPagination)
|
||||
def get_machines(request, filters: MachineFilterSchema = Query(...)):
|
||||
def get_machines(request: HttpRequest, filters: Query[MachineFilterSchema]):
|
||||
"""List all machines."""
|
||||
from archivebox.machine.models import Machine
|
||||
return filters.filter(Machine.objects.all()).distinct()
|
||||
|
||||
|
||||
@router.get("/machine/current", response=MachineSchema, url_name="get_current_machine")
|
||||
def get_current_machine(request):
|
||||
def get_current_machine(request: HttpRequest):
|
||||
"""Get the current machine."""
|
||||
from archivebox.machine.models import Machine
|
||||
return Machine.current()
|
||||
|
||||
|
||||
@router.get("/machine/{machine_id}", response=MachineSchema, url_name="get_machine")
|
||||
def get_machine(request, machine_id: str):
|
||||
def get_machine(request: HttpRequest, machine_id: str):
|
||||
"""Get a specific machine by ID."""
|
||||
from archivebox.machine.models import Machine
|
||||
from django.db.models import Q
|
||||
@@ -130,21 +129,21 @@ def get_machine(request, machine_id: str):
|
||||
|
||||
@router.get("/binaries", response=List[BinarySchema], url_name="get_binaries")
|
||||
@paginate(CustomPagination)
|
||||
def get_binaries(request, filters: BinaryFilterSchema = Query(...)):
|
||||
def get_binaries(request: HttpRequest, filters: Query[BinaryFilterSchema]):
|
||||
"""List all binaries."""
|
||||
from archivebox.machine.models import Binary
|
||||
return filters.filter(Binary.objects.all().select_related('machine')).distinct()
|
||||
|
||||
|
||||
@router.get("/binary/{binary_id}", response=BinarySchema, url_name="get_binary")
|
||||
def get_binary(request, binary_id: str):
|
||||
def get_binary(request: HttpRequest, binary_id: str):
|
||||
"""Get a specific binary by ID."""
|
||||
from archivebox.machine.models import Binary
|
||||
return Binary.objects.select_related('machine').get(id__startswith=binary_id)
|
||||
|
||||
|
||||
@router.get("/binary/by-name/{name}", response=List[BinarySchema], url_name="get_binaries_by_name")
|
||||
def get_binaries_by_name(request, name: str):
|
||||
def get_binaries_by_name(request: HttpRequest, name: str):
|
||||
"""Get all binaries with the given name."""
|
||||
from archivebox.machine.models import Binary
|
||||
return list(Binary.objects.filter(name__iexact=name).select_related('machine'))
|
||||
|
||||
@@ -39,7 +39,10 @@ def process_archiveresult_by_id(archiveresult_id: str) -> int:
|
||||
"""
|
||||
Run extraction for a single ArchiveResult by ID (used by workers).
|
||||
|
||||
Triggers the ArchiveResult's state machine tick() to run the extractor plugin.
|
||||
Triggers the ArchiveResult's state machine tick() to run the extractor
|
||||
plugin, but only after claiming ownership via retry_at. This keeps direct
|
||||
CLI execution aligned with the worker lifecycle and prevents duplicate hook
|
||||
runs if another process already owns the same ArchiveResult.
|
||||
"""
|
||||
from rich import print as rprint
|
||||
from archivebox.core.models import ArchiveResult
|
||||
@@ -53,9 +56,12 @@ def process_archiveresult_by_id(archiveresult_id: str) -> int:
|
||||
rprint(f'[blue]Extracting {archiveresult.plugin} for {archiveresult.snapshot.url}[/blue]', file=sys.stderr)
|
||||
|
||||
try:
|
||||
# Trigger state machine tick - this runs the actual extraction
|
||||
archiveresult.sm.tick()
|
||||
archiveresult.refresh_from_db()
|
||||
# Claim-before-tick is the required calling pattern for direct
|
||||
# state-machine drivers. If another worker already owns this row,
|
||||
# report that and exit without running duplicate extractor side effects.
|
||||
if not archiveresult.tick_claimed(lock_seconds=120):
|
||||
print(f'[yellow]Extraction already claimed by another process: {archiveresult.plugin}[/yellow]')
|
||||
return 0
|
||||
|
||||
if archiveresult.status == ArchiveResult.StatusChoices.SUCCEEDED:
|
||||
print(f'[green]Extraction succeeded: {archiveresult.output_str}[/green]')
|
||||
|
||||
@@ -382,6 +382,88 @@ class Crawl(ModelWithOutputDir, ModelWithConfig, ModelWithHealthStats, ModelWith
|
||||
|
||||
return created_snapshots
|
||||
|
||||
def install_declared_binaries(self, binary_names: set[str], machine=None) -> None:
|
||||
"""
|
||||
Install crawl-declared Binary rows without violating the retry_at lock lifecycle.
|
||||
|
||||
Correct calling pattern:
|
||||
1. Crawl hooks declare Binary records and queue them with retry_at <= now
|
||||
2. Exactly one actor claims each Binary by moving retry_at into the future
|
||||
3. Only that owner executes `.sm.tick()` and performs install side effects
|
||||
4. Everyone else waits for the claimed owner to finish instead of launching
|
||||
a second install against shared state such as the pip or npm trees
|
||||
|
||||
This helper follows that contract by claiming each Binary before ticking
|
||||
it, and by waiting when another worker already owns the row. That keeps
|
||||
synchronous crawl execution compatible with the global BinaryWorker and
|
||||
avoids duplicate installs of the same dependency.
|
||||
"""
|
||||
import time
|
||||
from archivebox.machine.models import Binary, Machine
|
||||
|
||||
if not binary_names:
|
||||
return
|
||||
|
||||
machine = machine or Machine.current()
|
||||
lock_seconds = 600
|
||||
deadline = time.monotonic() + max(lock_seconds, len(binary_names) * lock_seconds)
|
||||
|
||||
while time.monotonic() < deadline:
|
||||
unresolved_binaries = list(
|
||||
Binary.objects.filter(
|
||||
machine=machine,
|
||||
name__in=binary_names,
|
||||
).exclude(
|
||||
status=Binary.StatusChoices.INSTALLED,
|
||||
).order_by('name')
|
||||
)
|
||||
if not unresolved_binaries:
|
||||
return
|
||||
|
||||
claimed_any = False
|
||||
waiting_on_existing_owner = False
|
||||
now = timezone.now()
|
||||
|
||||
for binary in unresolved_binaries:
|
||||
try:
|
||||
if binary.tick_claimed(lock_seconds=lock_seconds):
|
||||
claimed_any = True
|
||||
continue
|
||||
except Exception:
|
||||
claimed_any = True
|
||||
continue
|
||||
|
||||
binary.refresh_from_db()
|
||||
if binary.status == Binary.StatusChoices.INSTALLED:
|
||||
claimed_any = True
|
||||
continue
|
||||
if binary.retry_at and binary.retry_at > now:
|
||||
waiting_on_existing_owner = True
|
||||
|
||||
if claimed_any:
|
||||
continue
|
||||
if waiting_on_existing_owner:
|
||||
time.sleep(0.5)
|
||||
continue
|
||||
break
|
||||
|
||||
unresolved_binaries = list(
|
||||
Binary.objects.filter(
|
||||
machine=machine,
|
||||
name__in=binary_names,
|
||||
).exclude(
|
||||
status=Binary.StatusChoices.INSTALLED,
|
||||
).order_by('name')
|
||||
)
|
||||
if unresolved_binaries:
|
||||
binary_details = ', '.join(
|
||||
f'{binary.name} (status={binary.status}, retry_at={binary.retry_at})'
|
||||
for binary in unresolved_binaries
|
||||
)
|
||||
raise RuntimeError(
|
||||
f'Crawl dependencies failed to install before continuing: {binary_details}'
|
||||
)
|
||||
|
||||
def run(self) -> 'Snapshot | None':
|
||||
"""
|
||||
Execute this Crawl: run hooks, process JSONL, create snapshots.
|
||||
@@ -428,47 +510,6 @@ class Crawl(ModelWithOutputDir, ModelWithConfig, ModelWithHealthStats, ModelWith
|
||||
chrome_binary=chrome_binary,
|
||||
)
|
||||
|
||||
def install_declared_binaries(binary_names: set[str]) -> None:
|
||||
if not binary_names:
|
||||
return
|
||||
|
||||
max_attempts = max(2, len(binary_names))
|
||||
|
||||
for _ in range(max_attempts):
|
||||
pending_binaries = list(
|
||||
Binary.objects.filter(
|
||||
machine=machine,
|
||||
name__in=binary_names,
|
||||
).exclude(
|
||||
status=Binary.StatusChoices.INSTALLED,
|
||||
).order_by('retry_at', 'name')
|
||||
)
|
||||
if not pending_binaries:
|
||||
return
|
||||
|
||||
for binary in pending_binaries:
|
||||
try:
|
||||
binary.sm.tick()
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
unresolved_binaries = list(
|
||||
Binary.objects.filter(
|
||||
machine=machine,
|
||||
name__in=binary_names,
|
||||
).exclude(
|
||||
status=Binary.StatusChoices.INSTALLED,
|
||||
).order_by('name')
|
||||
)
|
||||
if unresolved_binaries:
|
||||
binary_details = ', '.join(
|
||||
f'{binary.name} (status={binary.status})'
|
||||
for binary in unresolved_binaries
|
||||
)
|
||||
raise RuntimeError(
|
||||
f'Crawl dependencies failed to install before continuing: {binary_details}'
|
||||
)
|
||||
|
||||
executed_crawl_hooks: set[str] = set()
|
||||
|
||||
def run_crawl_hook(hook: Path) -> set[str]:
|
||||
@@ -598,11 +639,11 @@ class Crawl(ModelWithOutputDir, ModelWithConfig, ModelWithHealthStats, ModelWith
|
||||
for hook in hooks:
|
||||
hook_binary_names = run_crawl_hook(hook)
|
||||
if hook_binary_names:
|
||||
install_declared_binaries(resolve_provider_binaries(hook_binary_names))
|
||||
self.install_declared_binaries(resolve_provider_binaries(hook_binary_names), machine=machine)
|
||||
|
||||
# Safety check: don't create snapshots if any crawl-declared dependency
|
||||
# is still unresolved after all crawl hooks have run.
|
||||
install_declared_binaries(declared_binary_names)
|
||||
self.install_declared_binaries(declared_binary_names, machine=machine)
|
||||
|
||||
# Create snapshots from all URLs in self.urls
|
||||
if system_task:
|
||||
|
||||
143
archivebox/tests/test_state_machine_claims.py
Normal file
143
archivebox/tests/test_state_machine_claims.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from django.db import close_old_connections
|
||||
from django.utils import timezone
|
||||
|
||||
from archivebox.base_models.models import get_or_create_system_user_pk
|
||||
from archivebox.crawls.models import Crawl
|
||||
from archivebox.machine.models import Binary, Machine
|
||||
from archivebox.workers.worker import BinaryWorker
|
||||
|
||||
|
||||
def get_fresh_machine() -> Machine:
|
||||
import archivebox.machine.models as machine_models
|
||||
|
||||
machine_models._CURRENT_MACHINE = None
|
||||
machine_models._CURRENT_BINARIES.clear()
|
||||
return Machine.current()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_claim_processing_lock_does_not_steal_future_retry_at():
|
||||
"""
|
||||
retry_at is both the schedule and the ownership lock.
|
||||
|
||||
Once one process claims a due row and moves retry_at into the future, a
|
||||
fresh reader must not be able to "re-claim" that future timestamp and run
|
||||
the same side effects a second time.
|
||||
"""
|
||||
machine = get_fresh_machine()
|
||||
binary = Binary.objects.create(
|
||||
machine=machine,
|
||||
name='claim-test',
|
||||
binproviders='env',
|
||||
status=Binary.StatusChoices.QUEUED,
|
||||
retry_at=timezone.now(),
|
||||
)
|
||||
|
||||
owner = Binary.objects.get(pk=binary.pk)
|
||||
contender = Binary.objects.get(pk=binary.pk)
|
||||
|
||||
assert owner.claim_processing_lock(lock_seconds=30) is True
|
||||
|
||||
contender.refresh_from_db()
|
||||
assert contender.retry_at > timezone.now()
|
||||
assert contender.claim_processing_lock(lock_seconds=30) is False
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_binary_worker_skips_binary_claimed_by_other_owner(monkeypatch):
|
||||
"""
|
||||
BinaryWorker must never run install side effects for a Binary whose retry_at
|
||||
lock has already been claimed by another process.
|
||||
"""
|
||||
machine = get_fresh_machine()
|
||||
binary = Binary.objects.create(
|
||||
machine=machine,
|
||||
name='claimed-binary',
|
||||
binproviders='env',
|
||||
status=Binary.StatusChoices.QUEUED,
|
||||
retry_at=timezone.now(),
|
||||
)
|
||||
|
||||
owner = Binary.objects.get(pk=binary.pk)
|
||||
assert owner.claim_processing_lock(lock_seconds=30) is True
|
||||
|
||||
calls: list[str] = []
|
||||
|
||||
def fake_run(self):
|
||||
calls.append(self.name)
|
||||
self.status = self.StatusChoices.INSTALLED
|
||||
self.abspath = '/tmp/fake-binary'
|
||||
self.version = '1.0'
|
||||
self.save(update_fields=['status', 'abspath', 'version', 'modified_at'])
|
||||
|
||||
monkeypatch.setattr(Binary, 'run', fake_run)
|
||||
|
||||
worker = BinaryWorker(binary_id=str(binary.id))
|
||||
worker._process_single_binary()
|
||||
|
||||
assert calls == []
|
||||
|
||||
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_crawl_install_declared_binaries_waits_for_existing_owner(monkeypatch):
|
||||
"""
|
||||
Crawl.install_declared_binaries should wait for the current owner of a Binary
|
||||
to finish instead of launching a duplicate install against shared provider
|
||||
state such as the npm tree.
|
||||
"""
|
||||
machine = get_fresh_machine()
|
||||
crawl = Crawl.objects.create(
|
||||
urls='https://example.com',
|
||||
created_by_id=get_or_create_system_user_pk(),
|
||||
status=Crawl.StatusChoices.QUEUED,
|
||||
retry_at=timezone.now(),
|
||||
)
|
||||
binary = Binary.objects.create(
|
||||
machine=machine,
|
||||
name='puppeteer',
|
||||
binproviders='npm',
|
||||
status=Binary.StatusChoices.QUEUED,
|
||||
retry_at=timezone.now(),
|
||||
)
|
||||
|
||||
owner = Binary.objects.get(pk=binary.pk)
|
||||
assert owner.claim_processing_lock(lock_seconds=30) is True
|
||||
|
||||
calls: list[str] = []
|
||||
|
||||
def fake_run(self):
|
||||
calls.append(self.name)
|
||||
self.status = self.StatusChoices.INSTALLED
|
||||
self.abspath = '/tmp/should-not-run'
|
||||
self.version = '1.0'
|
||||
self.save(update_fields=['status', 'abspath', 'version', 'modified_at'])
|
||||
|
||||
monkeypatch.setattr(Binary, 'run', fake_run)
|
||||
|
||||
def finish_existing_install():
|
||||
close_old_connections()
|
||||
try:
|
||||
time.sleep(0.3)
|
||||
Binary.objects.filter(pk=binary.pk).update(
|
||||
status=Binary.StatusChoices.INSTALLED,
|
||||
retry_at=None,
|
||||
abspath='/tmp/finished-by-owner',
|
||||
version='1.0',
|
||||
modified_at=timezone.now(),
|
||||
)
|
||||
finally:
|
||||
close_old_connections()
|
||||
|
||||
thread = threading.Thread(target=finish_existing_install, daemon=True)
|
||||
thread.start()
|
||||
crawl.install_declared_binaries({'puppeteer'}, machine=machine)
|
||||
thread.join(timeout=5)
|
||||
|
||||
binary.refresh_from_db()
|
||||
assert binary.status == Binary.StatusChoices.INSTALLED
|
||||
assert binary.abspath == '/tmp/finished-by-owner'
|
||||
assert calls == []
|
||||
@@ -210,17 +210,71 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
|
||||
@classmethod
|
||||
def claim_for_worker(cls, obj: 'BaseModelWithStateMachine', lock_seconds: int = 60) -> bool:
|
||||
"""
|
||||
Atomically claim an object for processing using optimistic locking.
|
||||
Returns True if successfully claimed, False if another worker got it first.
|
||||
Atomically claim a due object for processing using retry_at as the lock.
|
||||
|
||||
Correct lifecycle for any state-machine-driven work item:
|
||||
1. Queue the item by setting retry_at <= now
|
||||
2. Exactly one owner claims it by moving retry_at into the future
|
||||
3. Only that owner may call .sm.tick() and perform side effects
|
||||
4. State-machine callbacks update retry_at again when the work completes,
|
||||
backs off, or is re-queued
|
||||
|
||||
The critical rule is that future retry_at values are already owned.
|
||||
Callers must never "steal" those future timestamps and start another
|
||||
copy of the same work. That is what prevents duplicate installs, hook
|
||||
runs, and other concurrent side effects.
|
||||
|
||||
Returns True if successfully claimed, False if another worker got it
|
||||
first or the object is not currently due.
|
||||
"""
|
||||
updated = cls.objects.filter(
|
||||
pk=obj.pk,
|
||||
retry_at=obj.retry_at,
|
||||
retry_at__lte=timezone.now(),
|
||||
).update(
|
||||
retry_at=timezone.now() + timedelta(seconds=lock_seconds)
|
||||
)
|
||||
return updated == 1
|
||||
|
||||
def claim_processing_lock(self, lock_seconds: int = 60) -> bool:
|
||||
"""
|
||||
Claim this model instance immediately before executing one state-machine tick.
|
||||
|
||||
This helper is the safe entrypoint for any direct state-machine driver
|
||||
(workers, synchronous crawl dependency installers, one-off CLI helpers).
|
||||
Calling `.sm.tick()` without claiming first turns retry_at into "just a
|
||||
schedule" instead of the ownership lock it is meant to be.
|
||||
|
||||
Returns True only for the caller that successfully moved retry_at into
|
||||
the future. False means another process already owns the work item or it
|
||||
is not currently due.
|
||||
"""
|
||||
if self.STATE in self.FINAL_STATES:
|
||||
return False
|
||||
if self.RETRY_AT is None:
|
||||
return False
|
||||
|
||||
claimed = type(self).claim_for_worker(self, lock_seconds=lock_seconds)
|
||||
if claimed:
|
||||
self.refresh_from_db()
|
||||
return claimed
|
||||
|
||||
def tick_claimed(self, lock_seconds: int = 60) -> bool:
|
||||
"""
|
||||
Claim ownership via retry_at and then execute exactly one `.sm.tick()`.
|
||||
|
||||
Future maintainers should prefer this helper over calling `.sm.tick()`
|
||||
directly whenever there is any chance another process could see the same
|
||||
queued row. If this method returns False, someone else already owns the
|
||||
work and the caller must not run side effects for it.
|
||||
"""
|
||||
if not self.claim_processing_lock(lock_seconds=lock_seconds):
|
||||
return False
|
||||
|
||||
self.sm.tick()
|
||||
self.refresh_from_db()
|
||||
return True
|
||||
|
||||
@classproperty
|
||||
def ACTIVE_STATE(cls) -> str:
|
||||
return cls._state_to_str(cls.active_state)
|
||||
|
||||
@@ -35,6 +35,7 @@ from datetime import timedelta
|
||||
from multiprocessing import Process as MPProcess
|
||||
from pathlib import Path
|
||||
|
||||
from django.db import connections
|
||||
from django.utils import timezone
|
||||
|
||||
from rich import print
|
||||
@@ -403,6 +404,17 @@ class Orchestrator:
|
||||
|
||||
return queue_sizes
|
||||
|
||||
def _refresh_db_connections(self) -> None:
|
||||
"""
|
||||
Drop long-lived DB connections before each poll tick.
|
||||
|
||||
The daemon orchestrator must observe rows created by sibling processes
|
||||
(server requests, CLI helpers, docker-compose run invocations). With
|
||||
SQLite, reusing the same connection indefinitely can miss externally
|
||||
committed rows until the process reconnects.
|
||||
"""
|
||||
connections.close_all()
|
||||
|
||||
def _should_process_schedules(self) -> bool:
|
||||
return (not self.exit_on_idle) and (self.crawl_id is None)
|
||||
|
||||
@@ -576,17 +588,10 @@ class Orchestrator:
|
||||
)
|
||||
|
||||
def _claim_crawl(self, crawl) -> bool:
|
||||
"""Atomically claim a crawl using optimistic locking."""
|
||||
"""Atomically claim a due crawl using the shared retry_at lock lifecycle."""
|
||||
from archivebox.crawls.models import Crawl
|
||||
|
||||
updated = Crawl.objects.filter(
|
||||
pk=crawl.pk,
|
||||
retry_at=crawl.retry_at,
|
||||
).update(
|
||||
retry_at=timezone.now() + timedelta(hours=24), # Long lock (crawls take time)
|
||||
)
|
||||
|
||||
return updated == 1
|
||||
return Crawl.claim_for_worker(crawl, lock_seconds=24 * 60 * 60)
|
||||
|
||||
def has_pending_work(self, queue_sizes: dict[str, int]) -> bool:
|
||||
"""Check if any queue has pending work."""
|
||||
@@ -726,6 +731,10 @@ class Orchestrator:
|
||||
while True:
|
||||
tick_count += 1
|
||||
|
||||
# Refresh DB state before polling so this long-lived daemon sees
|
||||
# work created by other processes using the same collection.
|
||||
self._refresh_db_connections()
|
||||
|
||||
# Check queues and spawn workers
|
||||
queue_sizes = self.check_queues_and_spawn_workers()
|
||||
|
||||
|
||||
@@ -569,6 +569,7 @@ class CrawlWorker(Worker):
|
||||
def _spawn_snapshot_workers(self) -> None:
|
||||
"""Spawn SnapshotWorkers for queued snapshots (up to limit)."""
|
||||
from pathlib import Path
|
||||
from archivebox.config.constants import CONSTANTS
|
||||
from archivebox.core.models import Snapshot
|
||||
from archivebox.machine.models import Process
|
||||
import sys
|
||||
@@ -636,6 +637,18 @@ class CrawlWorker(Worker):
|
||||
f.write(f' Spawning worker for {snapshot.url} (status={snapshot.status})\n')
|
||||
f.flush()
|
||||
|
||||
# Claim the snapshot before spawning the worker so retry_at remains
|
||||
# the single source of truth for ownership even if process tracking
|
||||
# lags or multiple schedulers look at the same queue.
|
||||
if not Snapshot.claim_for_worker(snapshot, lock_seconds=CONSTANTS.MAX_SNAPSHOT_RUNTIME_SECONDS):
|
||||
log_worker_event(
|
||||
worker_type='CrawlWorker',
|
||||
event=f'Skipped already-claimed Snapshot: {snapshot.url}',
|
||||
indent_level=1,
|
||||
pid=self.pid,
|
||||
)
|
||||
continue
|
||||
|
||||
pid = SnapshotWorker.start(parent=self.db_process, snapshot_id=str(snapshot.id))
|
||||
|
||||
log_worker_event(
|
||||
@@ -1195,9 +1208,15 @@ class BinaryWorker(Worker):
|
||||
return
|
||||
|
||||
print(f'[cyan]🔧 BinaryWorker installing: {binary.name}[/cyan]', file=sys.stderr)
|
||||
binary.sm.tick()
|
||||
if not binary.tick_claimed(lock_seconds=self.MAX_TICK_TIME):
|
||||
log_worker_event(
|
||||
worker_type='BinaryWorker',
|
||||
event=f'Skipped already-claimed binary: {binary.name}',
|
||||
indent_level=1,
|
||||
pid=self.pid,
|
||||
)
|
||||
return
|
||||
|
||||
binary.refresh_from_db()
|
||||
if binary.status == binary.__class__.StatusChoices.INSTALLED:
|
||||
log_worker_event(
|
||||
worker_type='BinaryWorker',
|
||||
@@ -1254,9 +1273,15 @@ class BinaryWorker(Worker):
|
||||
for binary in pending_binaries:
|
||||
try:
|
||||
print(f'[cyan]🔧 BinaryWorker processing: {binary.name}[/cyan]', file=sys.stderr)
|
||||
binary.sm.tick()
|
||||
if not binary.tick_claimed(lock_seconds=self.MAX_TICK_TIME):
|
||||
log_worker_event(
|
||||
worker_type='BinaryWorker',
|
||||
event=f'Skipped already-claimed binary: {binary.name}',
|
||||
indent_level=1,
|
||||
pid=self.pid,
|
||||
)
|
||||
continue
|
||||
|
||||
binary.refresh_from_db()
|
||||
if binary.status == binary.__class__.StatusChoices.INSTALLED:
|
||||
log_worker_event(
|
||||
worker_type='BinaryWorker',
|
||||
|
||||
Reference in New Issue
Block a user