"""
Worker classes for processing queue items.
Workers poll the database for items to process, claim them atomically,
and run the state machine tick() to process each item.
Architecture:
Orchestrator (spawns workers)
└── Worker (claims items from queue, processes them directly)
"""
__package__ = 'archivebox.workers'
import os
import time
import traceback
from typing import ClassVar, Any
from datetime import timedelta
from pathlib import Path
from multiprocessing import Process, cpu_count
from django.db.models import QuerySet
from django.utils import timezone
from django.conf import settings
from rich import print
from archivebox.misc.logging_util import log_worker_event
from .pid_utils import (
write_pid_file,
remove_pid_file,
get_all_worker_pids,
get_next_worker_id,
cleanup_stale_pid_files,
)
CPU_COUNT = cpu_count()
# Registry of worker types by name (populated at the bottom of this file; used by _run_worker)
WORKER_TYPES: dict[str, type['Worker']] = {}
def _run_worker(worker_class_name: str, worker_id: int, daemon: bool, **kwargs):
"""
Module-level function to run a worker. Must be at module level for pickling.
"""
from archivebox.config.django import setup_django
setup_django()
# Get worker class by name to avoid pickling class objects
worker_cls = WORKER_TYPES[worker_class_name]
worker = worker_cls(worker_id=worker_id, daemon=daemon, **kwargs)
worker.runloop()
class Worker:
"""
Base worker class that polls a queue and processes items directly.
Each item is processed by calling its state machine tick() method.
Workers exit when idle for too long (unless daemon mode).
"""
name: ClassVar[str] = 'worker'
# Configuration (can be overridden by subclasses)
MAX_TICK_TIME: ClassVar[int] = 60
MAX_CONCURRENT_TASKS: ClassVar[int] = 1
POLL_INTERVAL: ClassVar[float] = 0.2 # How often to check for new work (seconds)
IDLE_TIMEOUT: ClassVar[int] = 50 # Exit after N idle iterations (10 sec at 0.2 poll interval)
def __init__(self, worker_id: int = 0, daemon: bool = False, **kwargs: Any):
self.worker_id = worker_id
self.daemon = daemon
self.pid: int = os.getpid()
self.pid_file: Path | None = None
self.idle_count: int = 0
def __repr__(self) -> str:
return f'[underline]{self.__class__.__name__}[/underline]\\[id={self.worker_id}, pid={self.pid}]'
def get_model(self):
"""Get the Django model class. Subclasses must override this."""
raise NotImplementedError("Subclasses must implement get_model()")
def get_queue(self) -> QuerySet:
"""Get the queue of objects ready for processing."""
Model = self.get_model()
return Model.objects.filter(
retry_at__lte=timezone.now()
).exclude(
status__in=Model.FINAL_STATES
).order_by('retry_at')
def claim_next(self):
"""
Atomically claim the next object from the queue.
Returns the claimed object or None if queue is empty or claim failed.
"""
Model = self.get_model()
obj = self.get_queue().first()
if obj is None:
return None
# Atomic claim using optimistic locking on retry_at
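        # The retry_at=obj.retry_at condition makes this UPDATE a compare-and-swap:
        # if another worker already bumped retry_at between our read and this write,
        # the WHERE clause matches zero rows and we back off instead of processing
        # the same object twice. The new retry_at also acts as a lease: if this
        # worker dies mid-tick, the object becomes claimable again after
        # MAX_TICK_TIME seconds.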
claimed = Model.objects.filter(
pk=obj.pk,
retry_at=obj.retry_at,
).update(
retry_at=timezone.now() + timedelta(seconds=self.MAX_TICK_TIME)
)
if claimed == 1:
obj.refresh_from_db()
return obj
return None # Someone else claimed it
def process_item(self, obj) -> bool:
"""
Process a single item by calling its state machine tick().
Returns True on success, False on failure.
Subclasses can override for custom processing.
"""
try:
obj.sm.tick()
return True
        except Exception:
# Error will be logged in runloop's completion event
traceback.print_exc()
return False
def on_startup(self) -> None:
"""Called when worker starts."""
self.pid = os.getpid()
self.pid_file = write_pid_file(self.name, self.worker_id)
# Determine worker type for logging
worker_type_name = self.__class__.__name__
indent_level = 1 # Default for most workers
# Adjust indent level based on worker type
if 'Snapshot' in worker_type_name:
indent_level = 2
elif 'ArchiveResult' in worker_type_name:
indent_level = 3
log_worker_event(
worker_type=worker_type_name,
event='Starting...',
indent_level=indent_level,
pid=self.pid,
worker_id=str(self.worker_id),
metadata={
'max_concurrent': self.MAX_CONCURRENT_TASKS,
'poll_interval': self.POLL_INTERVAL,
},
)
def on_shutdown(self, error: BaseException | None = None) -> None:
"""Called when worker shuts down."""
# Remove PID file
if self.pid_file:
remove_pid_file(self.pid_file)
# Determine worker type for logging
worker_type_name = self.__class__.__name__
indent_level = 1
if 'Snapshot' in worker_type_name:
indent_level = 2
elif 'ArchiveResult' in worker_type_name:
indent_level = 3
log_worker_event(
worker_type=worker_type_name,
event='Shutting down',
indent_level=indent_level,
pid=self.pid,
worker_id=str(self.worker_id),
error=error if error and not isinstance(error, KeyboardInterrupt) else None,
)
def should_exit(self) -> bool:
"""Check if worker should exit due to idle timeout."""
if self.daemon:
return False
if self.IDLE_TIMEOUT == 0:
return False
return self.idle_count >= self.IDLE_TIMEOUT
def runloop(self) -> None:
"""Main worker loop - polls queue, processes items."""
self.on_startup()
# Determine worker type for logging
worker_type_name = self.__class__.__name__
indent_level = 1
if 'Snapshot' in worker_type_name:
indent_level = 2
elif 'ArchiveResult' in worker_type_name:
indent_level = 3
try:
while True:
# Try to claim and process an item
obj = self.claim_next()
if obj is not None:
self.idle_count = 0
# Build metadata for task start
start_metadata = {}
url = None
if hasattr(obj, 'url'):
# SnapshotWorker
url = str(obj.url) if obj.url else None
elif hasattr(obj, 'snapshot') and hasattr(obj.snapshot, 'url'):
# ArchiveResultWorker
url = str(obj.snapshot.url) if obj.snapshot.url else None
elif hasattr(obj, 'get_urls_list'):
# CrawlWorker
urls = obj.get_urls_list()
url = urls[0] if urls else None
plugin = None
if hasattr(obj, 'plugin'):
# ArchiveResultWorker, Crawl
plugin = obj.plugin
log_worker_event(
worker_type=worker_type_name,
event='Starting...',
indent_level=indent_level,
pid=self.pid,
worker_id=str(self.worker_id),
url=url,
plugin=plugin,
metadata=start_metadata if start_metadata else None,
)
start_time = time.time()
success = self.process_item(obj)
elapsed = time.time() - start_time
# Build metadata for task completion
complete_metadata = {
'duration': elapsed,
'status': 'success' if success else 'failed',
}
log_worker_event(
worker_type=worker_type_name,
event='Completed' if success else 'Failed',
indent_level=indent_level,
pid=self.pid,
worker_id=str(self.worker_id),
url=url,
plugin=plugin,
metadata=complete_metadata,
)
else:
# No work available - idle logging suppressed
self.idle_count += 1
# Check if we should exit
if self.should_exit():
# Exit logging suppressed - shutdown will be logged by on_shutdown()
break
time.sleep(self.POLL_INTERVAL)
        except KeyboardInterrupt as e:
            # Still run shutdown (PID file cleanup, logging); on_shutdown() already
            # excludes KeyboardInterrupt when deciding whether to log an error.
            self.on_shutdown(error=e)
except BaseException as e:
self.on_shutdown(error=e)
raise
else:
self.on_shutdown()
@classmethod
def start(cls, worker_id: int | None = None, daemon: bool = False, **kwargs: Any) -> int:
"""
Fork a new worker as a subprocess.
Returns the PID of the new process.
"""
if worker_id is None:
worker_id = get_next_worker_id(cls.name)
# Use module-level function for pickling compatibility
proc = Process(
target=_run_worker,
args=(cls.name, worker_id, daemon),
kwargs=kwargs,
name=f'{cls.name}_worker_{worker_id}',
)
proc.start()
assert proc.pid is not None
return proc.pid
@classmethod
def get_running_workers(cls) -> list[dict]:
"""Get info about all running workers of this type."""
cleanup_stale_pid_files()
return get_all_worker_pids(cls.name)
@classmethod
def get_worker_count(cls) -> int:
"""Get count of running workers of this type."""
return len(cls.get_running_workers())
class CrawlWorker(Worker):
"""Worker for processing Crawl objects."""
name: ClassVar[str] = 'crawl'
MAX_TICK_TIME: ClassVar[int] = 60
def get_model(self):
from crawls.models import Crawl
return Crawl
class SnapshotWorker(Worker):
"""Worker for processing Snapshot objects."""
name: ClassVar[str] = 'snapshot'
MAX_TICK_TIME: ClassVar[int] = 60
def get_model(self):
from core.models import Snapshot
return Snapshot
class ArchiveResultWorker(Worker):
"""Worker for processing ArchiveResult objects."""
name: ClassVar[str] = 'archiveresult'
MAX_TICK_TIME: ClassVar[int] = 120
def __init__(self, plugin: str | None = None, **kwargs: Any):
super().__init__(**kwargs)
self.plugin = plugin
def get_model(self):
from core.models import ArchiveResult
return ArchiveResult
def get_queue(self) -> QuerySet:
"""
Get queue of ArchiveResults ready for processing.
Uses step-based filtering: only claims ARs where hook step <= snapshot.current_step.
This ensures hooks execute in order (step 0 → 1 → 2 ... → 9).
"""
from core.models import ArchiveResult
from archivebox.hooks import extract_step
qs = super().get_queue()
if self.plugin:
qs = qs.filter(plugin=self.plugin)
# Step-based filtering: only process ARs whose step <= snapshot.current_step
# Since step is derived from hook_name, we filter in Python after initial query
# This is efficient because the base query already filters by retry_at and status
# Get candidate ARs
candidates = list(qs[:50]) # Limit to avoid loading too many
ready_pks = []
for ar in candidates:
if not ar.hook_name:
# Legacy ARs without hook_name - process them
ready_pks.append(ar.pk)
continue
ar_step = extract_step(ar.hook_name)
snapshot_step = ar.snapshot.current_step
if ar_step <= snapshot_step:
ready_pks.append(ar.pk)
# Return filtered queryset ordered by hook_name (so earlier hooks run first within a step)
return ArchiveResult.objects.filter(pk__in=ready_pks).order_by('hook_name', 'retry_at')
def process_item(self, obj) -> bool:
"""Process an ArchiveResult by running its plugin."""
try:
obj.sm.tick()
return True
        except Exception:
# Error will be logged in runloop's completion event
traceback.print_exc()
return False
@classmethod
def start(cls, worker_id: int | None = None, daemon: bool = False, plugin: str | None = None, **kwargs: Any) -> int:
"""Fork a new worker as subprocess with optional plugin filter."""
if worker_id is None:
worker_id = get_next_worker_id(cls.name)
# Use module-level function for pickling compatibility
proc = Process(
target=_run_worker,
args=(cls.name, worker_id, daemon),
kwargs={'plugin': plugin, **kwargs},
name=f'{cls.name}_worker_{worker_id}',
)
proc.start()
assert proc.pid is not None
return proc.pid
# Populate the registry
WORKER_TYPES.update({
'crawl': CrawlWorker,
'snapshot': SnapshotWorker,
'archiveresult': ArchiveResultWorker,
})
def get_worker_class(name: str) -> type[Worker]:
"""Get worker class by name."""
if name not in WORKER_TYPES:
raise ValueError(f'Unknown worker type: {name}. Valid types: {list(WORKER_TYPES.keys())}')
return WORKER_TYPES[name]
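# Illustrative: an orchestrator (or a management command) could resolve and spawn
# workers by name, e.g.:
#
#     get_worker_class('snapshot').start(daemon=True)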