"""
|
|
Worker classes for processing queue items.
|
|
|
|
Workers poll the database for items to process, claim them atomically,
|
|
and run the state machine tick() to process each item.
|
|
|
|
Architecture:
|
|
Orchestrator (spawns workers)
|
|
└── Worker (claims items from queue, processes them directly)
|
|
"""

__package__ = 'archivebox.workers'

import os
import time
import traceback
from typing import ClassVar, Any
from datetime import timedelta
from pathlib import Path
from multiprocessing import Process as MPProcess, cpu_count

from django.db.models import QuerySet
from django.utils import timezone
from django.conf import settings

from rich import print

from archivebox.misc.logging_util import log_worker_event


CPU_COUNT = cpu_count()

# Registry of worker types by name (defined at bottom, referenced here for _run_worker)
WORKER_TYPES: dict[str, type['Worker']] = {}


def _run_worker(worker_class_name: str, worker_id: int, daemon: bool, **kwargs):
    """
    Module-level function to run a worker. Must be at module level for pickling.
    """
    from archivebox.config.django import setup_django
    setup_django()

    # Get worker class by name to avoid pickling class objects
    worker_cls = WORKER_TYPES[worker_class_name]
    worker = worker_cls(worker_id=worker_id, daemon=daemon, **kwargs)
    worker.runloop()


class Worker:
    """
    Base worker class that polls a queue and processes items directly.

    Each item is processed by calling its state machine tick() method.
    Workers exit when idle for too long (unless daemon mode).
    """

    name: ClassVar[str] = 'worker'

    # Configuration (can be overridden by subclasses)
    MAX_TICK_TIME: ClassVar[int] = 60
    MAX_CONCURRENT_TASKS: ClassVar[int] = 1
    POLL_INTERVAL: ClassVar[float] = 0.2   # How often to check for new work (seconds)
    IDLE_TIMEOUT: ClassVar[int] = 50       # Exit after N idle iterations (10 sec at 0.2s poll interval)

    def __init__(self, worker_id: int = 0, daemon: bool = False, crawl_id: str | None = None, **kwargs: Any):
        self.worker_id = worker_id
        self.daemon = daemon
        self.crawl_id = crawl_id  # If set, only process work for this crawl
        self.pid: int = os.getpid()
        self.pid_file: Path | None = None
        self.idle_count: int = 0

    def __repr__(self) -> str:
        return f'[underline]{self.__class__.__name__}[/underline]\\[id={self.worker_id}, pid={self.pid}]'

    def get_model(self):
        """Get the Django model class. Subclasses must override this."""
        raise NotImplementedError("Subclasses must implement get_model()")

    def get_queue(self) -> QuerySet:
        """Get the queue of objects ready for processing."""
        Model = self.get_model()
        return Model.objects.filter(
            retry_at__lte=timezone.now()
        ).exclude(
            status__in=Model.FINAL_STATES
        ).order_by('retry_at')

    def claim_next(self):
        """
        Atomically claim the next object from the queue.
        Returns the claimed object or None if the queue is empty or the claim failed.
        """
        Model = self.get_model()
        obj = self.get_queue().first()
        if obj is None:
            return None

        # Atomic claim using optimistic locking on retry_at
        claimed = Model.objects.filter(
            pk=obj.pk,
            retry_at=obj.retry_at,
        ).update(
            retry_at=timezone.now() + timedelta(seconds=self.MAX_TICK_TIME)
        )

        if claimed == 1:
            obj.refresh_from_db()
            return obj

        return None  # Someone else claimed it

    def process_item(self, obj) -> bool:
        """
        Process a single item by calling its state machine tick().
        Returns True on success, False on failure.
        Subclasses can override for custom processing.
        """
        try:
            obj.sm.tick()
            return True
        except Exception:
            # Error will be logged in runloop's completion event
            traceback.print_exc()
            return False

    def on_startup(self) -> None:
        """Called when worker starts."""
        from archivebox.machine.models import Process

        self.pid = os.getpid()
        # Register this worker process in the database
        self.db_process = Process.current()
        # Explicitly set process_type to WORKER to prevent mis-detection
        if self.db_process.process_type != Process.TypeChoices.WORKER:
            self.db_process.process_type = Process.TypeChoices.WORKER
            self.db_process.save(update_fields=['process_type'])

        # Determine worker type for logging
        worker_type_name = self.__class__.__name__
        indent_level = 1  # Default for most workers

        # Adjust indent level based on worker type
        if 'Snapshot' in worker_type_name:
            indent_level = 2
        elif 'ArchiveResult' in worker_type_name:
            indent_level = 3

        log_worker_event(
            worker_type=worker_type_name,
            event='Starting...',
            indent_level=indent_level,
            pid=self.pid,
            worker_id=str(self.worker_id),
            metadata={
                'max_concurrent': self.MAX_CONCURRENT_TASKS,
                'poll_interval': self.POLL_INTERVAL,
            },
        )

    def on_shutdown(self, error: BaseException | None = None) -> None:
        """Called when worker shuts down."""
        # Update Process record status
        if hasattr(self, 'db_process') and self.db_process:
            self.db_process.exit_code = 1 if error else 0
            self.db_process.status = self.db_process.StatusChoices.EXITED
            self.db_process.ended_at = timezone.now()
            self.db_process.save()

        # Determine worker type for logging
        worker_type_name = self.__class__.__name__
        indent_level = 1

        if 'Snapshot' in worker_type_name:
            indent_level = 2
        elif 'ArchiveResult' in worker_type_name:
            indent_level = 3

        log_worker_event(
            worker_type=worker_type_name,
            event='Shutting down',
            indent_level=indent_level,
            pid=self.pid,
            worker_id=str(self.worker_id),
            error=error if error and not isinstance(error, KeyboardInterrupt) else None,
        )

    def should_exit(self) -> bool:
        """Check if worker should exit due to idle timeout."""
        if self.daemon:
            return False

        if self.IDLE_TIMEOUT == 0:
            return False

        return self.idle_count >= self.IDLE_TIMEOUT

    def runloop(self) -> None:
        """Main worker loop - polls queue, processes items."""
        self.on_startup()

        # Determine worker type for logging
        worker_type_name = self.__class__.__name__
        indent_level = 1

        if 'Snapshot' in worker_type_name:
            indent_level = 2
        elif 'ArchiveResult' in worker_type_name:
            indent_level = 3

        try:
            while True:
                # Try to claim and process an item
                obj = self.claim_next()

                if obj is not None:
                    self.idle_count = 0

                    # Build metadata for task start
                    start_metadata = {}
                    url = None
                    if hasattr(obj, 'url'):
                        # SnapshotWorker
                        url = str(obj.url) if obj.url else None
                    elif hasattr(obj, 'snapshot') and hasattr(obj.snapshot, 'url'):
                        # ArchiveResultWorker
                        url = str(obj.snapshot.url) if obj.snapshot.url else None
                    elif hasattr(obj, 'get_urls_list'):
                        # CrawlWorker
                        urls = obj.get_urls_list()
                        url = urls[0] if urls else None

                    plugin = None
                    if hasattr(obj, 'plugin'):
                        # ArchiveResultWorker, Crawl
                        plugin = obj.plugin

                    log_worker_event(
                        worker_type=worker_type_name,
                        event='Starting...',
                        indent_level=indent_level,
                        pid=self.pid,
                        worker_id=str(self.worker_id),
                        url=url,
                        plugin=plugin,
                        metadata=start_metadata if start_metadata else None,
                    )

                    start_time = time.time()
                    success = self.process_item(obj)
                    elapsed = time.time() - start_time

                    # Build metadata for task completion
                    complete_metadata = {
                        'duration': elapsed,
                        'status': 'success' if success else 'failed',
                    }

                    log_worker_event(
                        worker_type=worker_type_name,
                        event='Completed' if success else 'Failed',
                        indent_level=indent_level,
                        pid=self.pid,
                        worker_id=str(self.worker_id),
                        url=url,
                        plugin=plugin,
                        metadata=complete_metadata,
                    )
                else:
                    # No work available - idle logging suppressed
                    self.idle_count += 1

                    # Check if we should exit
                    if self.should_exit():
                        # Exit logging suppressed - shutdown will be logged by on_shutdown()
                        break

                time.sleep(self.POLL_INTERVAL)
        except KeyboardInterrupt as e:
            # Treat Ctrl-C as a normal shutdown; on_shutdown() filters it out of error logs
            self.on_shutdown(error=e)
        except BaseException as e:
            self.on_shutdown(error=e)
            raise
        else:
            self.on_shutdown()

    @classmethod
    def start(cls, worker_id: int | None = None, daemon: bool = False, **kwargs: Any) -> int:
        """
        Fork a new worker as a subprocess.
        Returns the PID of the new process.
        """
        from archivebox.machine.models import Process

        if worker_id is None:
            worker_id = Process.get_next_worker_id(process_type=Process.TypeChoices.WORKER)

        # Use module-level function for pickling compatibility
        proc = MPProcess(
            target=_run_worker,
            args=(cls.name, worker_id, daemon),
            kwargs=kwargs,
            name=f'{cls.name}_worker_{worker_id}',
        )
        proc.start()

        assert proc.pid is not None
        return proc.pid

    @classmethod
    def get_running_workers(cls) -> list:
        """Get info about all running workers of this type."""
        from archivebox.machine.models import Process

        Process.cleanup_stale_running()
        # Convert Process objects to dicts to match the expected API contract
        processes = Process.get_running(process_type=Process.TypeChoices.WORKER)
        # Note: worker_id is not stored on the Process model, it's dynamically generated.
        # We return process_id (UUID) and pid (OS process ID) instead.
        return [
            {
                'pid': p.pid,
                'process_id': str(p.id),  # UUID of Process record
                'started_at': p.started_at.isoformat() if p.started_at else None,
                'status': p.status,
            }
            for p in processes
        ]

    @classmethod
    def get_worker_count(cls) -> int:
        """Get count of running workers of this type."""
        from archivebox.machine.models import Process

        return Process.get_running_count(process_type=Process.TypeChoices.WORKER)


class CrawlWorker(Worker):
    """Worker for processing Crawl objects."""

    name: ClassVar[str] = 'crawl'
    MAX_TICK_TIME: ClassVar[int] = 60

    def get_model(self):
        from archivebox.crawls.models import Crawl
        return Crawl

    def get_queue(self) -> QuerySet:
        """Get queue of Crawls ready for processing, optionally filtered by crawl_id."""
        qs = super().get_queue()
        if self.crawl_id:
            qs = qs.filter(id=self.crawl_id)
        return qs


class SnapshotWorker(Worker):
    """Worker for processing Snapshot objects."""

    name: ClassVar[str] = 'snapshot'
    MAX_TICK_TIME: ClassVar[int] = 60

    def get_model(self):
        from archivebox.core.models import Snapshot
        return Snapshot


class ArchiveResultWorker(Worker):
    """Worker for processing ArchiveResult objects."""

    name: ClassVar[str] = 'archiveresult'
    MAX_TICK_TIME: ClassVar[int] = 120

    def __init__(self, plugin: str | None = None, **kwargs: Any):
        super().__init__(**kwargs)
        self.plugin = plugin

    def get_model(self):
        from archivebox.core.models import ArchiveResult
        return ArchiveResult

    def get_queue(self) -> QuerySet:
        """
        Get queue of ArchiveResults ready for processing.

        Uses step-based filtering: only claims ARs where the hook step <= snapshot.current_step.
        This ensures hooks execute in order (step 0 → 1 → 2 ... → 9).
        """
        from archivebox.core.models import ArchiveResult
        from archivebox.hooks import extract_step

        qs = super().get_queue()

        if self.plugin:
            qs = qs.filter(plugin=self.plugin)

        # Step-based filtering: only process ARs whose step <= snapshot.current_step.
        # Since step is derived from hook_name, we filter in Python after the initial query.
        # This is efficient because the base query already filters by retry_at and status.

        # Get candidate ARs
        candidates = list(qs[:50])  # Limit to avoid loading too many
        ready_pks = []

        for ar in candidates:
            if not ar.hook_name:
                # Legacy ARs without hook_name - process them
                ready_pks.append(ar.pk)
                continue

            ar_step = extract_step(ar.hook_name)
            snapshot_step = ar.snapshot.current_step

            if ar_step <= snapshot_step:
                ready_pks.append(ar.pk)

        # Return filtered queryset ordered by hook_name (so earlier hooks run first within a step)
        return ArchiveResult.objects.filter(pk__in=ready_pks).order_by('hook_name', 'retry_at')

    def process_item(self, obj) -> bool:
        """Process an ArchiveResult by running its plugin."""
        try:
            obj.sm.tick()
            return True
        except Exception:
            # Error will be logged in runloop's completion event
            traceback.print_exc()
            return False

    @classmethod
    def start(cls, worker_id: int | None = None, daemon: bool = False, plugin: str | None = None, **kwargs: Any) -> int:
        """Fork a new worker as a subprocess with an optional plugin filter."""
        from archivebox.machine.models import Process

        if worker_id is None:
            worker_id = Process.get_next_worker_id(process_type=Process.TypeChoices.WORKER)

        # Use module-level function for pickling compatibility
        proc = MPProcess(
            target=_run_worker,
            args=(cls.name, worker_id, daemon),
            kwargs={'plugin': plugin, **kwargs},
            name=f'{cls.name}_worker_{worker_id}',
        )
        proc.start()

        assert proc.pid is not None
        return proc.pid


# Populate the registry
WORKER_TYPES.update({
    'crawl': CrawlWorker,
    'snapshot': SnapshotWorker,
    'archiveresult': ArchiveResultWorker,
})


def get_worker_class(name: str) -> type[Worker]:
    """Get worker class by name."""
    if name not in WORKER_TYPES:
        raise ValueError(f'Unknown worker type: {name}. Valid types: {list(WORKER_TYPES.keys())}')
    return WORKER_TYPES[name]
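

# Example registry lookup (illustrative sketch; the actual callers are outside this module):
#
#     worker_cls = get_worker_class('snapshot')   # -> SnapshotWorker
#     pid = worker_cls.start(daemon=True)          # fork it as a long-running daemon worker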