# ArchiveBox/archivebox/workers/worker.py
"""
Worker classes for processing crawls and snapshots.
Each worker is forked as a subprocess for a specific Crawl or Snapshot and drives
its state machine (tick()/seal()) and its hooks to completion.
Architecture:
Orchestrator (spawns workers)
└── CrawlWorker (ticks one crawl, spawns SnapshotWorkers, seals when all snapshots finish)
    └── SnapshotWorker (runs all hooks for one snapshot sequentially)
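Example (illustrative only; the real entrypoint is the Orchestrator):

    pid = CrawlWorker.start(crawl_id=str(crawl.id))           # forks a subprocess that runs runloop()
    pid = SnapshotWorker.start(snapshot_id=str(snapshot.id))  # normally done by CrawlWorker itself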
"""
__package__ = 'archivebox.workers'
import os
import time
import traceback
from typing import ClassVar, Any
from datetime import timedelta
from pathlib import Path
from multiprocessing import Process as MPProcess, cpu_count
from django.db.models import QuerySet
from django.utils import timezone
from django.conf import settings
from rich import print
from archivebox.misc.logging_util import log_worker_event
CPU_COUNT = cpu_count()
# Registry of worker types by name (defined at bottom, referenced here for _run_worker)
WORKER_TYPES: dict[str, type['Worker']] = {}
def _run_worker(worker_class_name: str, worker_id: int, **kwargs):
"""
Module-level function to run a worker. Must be at module level for pickling.
"""
from archivebox.config.django import setup_django
setup_django()
# Get worker class by name to avoid pickling class objects
worker_cls = WORKER_TYPES[worker_class_name]
worker = worker_cls(worker_id=worker_id, **kwargs)
worker.runloop()
def _run_snapshot_worker(snapshot_id: str, worker_id: int, **kwargs):
"""
Module-level function to run a SnapshotWorker for a specific snapshot.
Must be at module level for pickling compatibility.
"""
from archivebox.config.django import setup_django
setup_django()
worker = SnapshotWorker(snapshot_id=snapshot_id, worker_id=worker_id, **kwargs)
worker.runloop()
class Worker:
"""
Base worker class for CrawlWorker and SnapshotWorker.
Workers are spawned as subprocesses to process crawls and snapshots.
Each worker type has its own custom runloop implementation.
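Minimal subclass sketch (illustrative, not an exhaustive contract):

    class MyWorker(Worker):
        name = 'my_worker'
        def get_model(self):   # required override: return the Django model this worker processes
            ...
        def runloop(self):     # each subclass provides its own loop (called by _run_worker)
            ...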
"""
name: ClassVar[str] = 'worker'
# Configuration (can be overridden by subclasses)
MAX_TICK_TIME: ClassVar[int] = 60
MAX_CONCURRENT_TASKS: ClassVar[int] = 1
def __init__(self, worker_id: int = 0, **kwargs: Any):
self.worker_id = worker_id
self.pid: int = os.getpid()
def __repr__(self) -> str:
return f'[underline]{self.__class__.__name__}[/underline]\\[id={self.worker_id}, pid={self.pid}]'
def get_model(self):
"""Get the Django model class. Subclasses must override this."""
raise NotImplementedError("Subclasses must implement get_model()")
def on_startup(self) -> None:
"""Called when worker starts."""
from archivebox.machine.models import Process
self.pid = os.getpid()
# Register this worker process in the database
self.db_process = Process.current()
# Explicitly set process_type to WORKER and store worker type name
update_fields = []
if self.db_process.process_type != Process.TypeChoices.WORKER:
self.db_process.process_type = Process.TypeChoices.WORKER
update_fields.append('process_type')
# Store worker type name (crawl/snapshot) in worker_type field
if not self.db_process.worker_type:
self.db_process.worker_type = self.name
update_fields.append('worker_type')
if update_fields:
self.db_process.save(update_fields=update_fields)
# Determine worker type for logging
worker_type_name = self.__class__.__name__
indent_level = 1 # Default for CrawlWorker
# SnapshotWorker gets indent level 2
if 'Snapshot' in worker_type_name:
indent_level = 2
log_worker_event(
worker_type=worker_type_name,
event='Starting...',
indent_level=indent_level,
pid=self.pid,
worker_id=str(self.worker_id),
)
def on_shutdown(self, error: BaseException | None = None) -> None:
"""Called when worker shuts down."""
# Update Process record status
if hasattr(self, 'db_process') and self.db_process:
self.db_process.exit_code = 1 if error else 0
self.db_process.status = self.db_process.StatusChoices.EXITED
self.db_process.ended_at = timezone.now()
self.db_process.save()
# Determine worker type for logging
worker_type_name = self.__class__.__name__
indent_level = 1 # CrawlWorker
if 'Snapshot' in worker_type_name:
indent_level = 2
log_worker_event(
worker_type=worker_type_name,
event='Shutting down',
indent_level=indent_level,
pid=self.pid,
worker_id=str(self.worker_id),
error=error if error and not isinstance(error, KeyboardInterrupt) else None,
)
def _terminate_background_hooks(
self,
background_processes: dict[str, 'Process'],
worker_type: str,
indent_level: int,
) -> None:
"""
Terminate background hooks in 3 phases (shared logic for Crawl/Snapshot workers).
Phase 1: Send SIGTERM to all bg hooks + children in parallel (polite request to wrap up)
Phase 2: Wait for each hook's remaining timeout before SIGKILL
Phase 3: SIGKILL any stragglers that exceeded their timeout
Args:
background_processes: Dict mapping hook name -> Process instance
worker_type: Worker type name for logging (e.g., 'CrawlWorker', 'SnapshotWorker')
indent_level: Logging indent level (1 for Crawl, 2 for Snapshot)
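Example (illustrative): a background hook launched 45s ago with a 60s timeout
gets SIGTERM immediately, then up to ~15s (its remaining budget) to exit
cleanly before Phase 3 sends SIGKILL.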
"""
import signal
import time
if not background_processes:
return
now = time.time()
# Phase 1: Send SIGTERM to ALL background processes + children in parallel
log_worker_event(
worker_type=worker_type,
event=f'Sending SIGTERM to {len(background_processes)} background hooks (+ children)',
indent_level=indent_level,
pid=self.pid,
)
# Build deadline map first (before killing, to get accurate remaining time)
deadlines = {}
for hook_name, process in background_processes.items():
elapsed = now - process.started_at.timestamp()
remaining = max(0, process.timeout - elapsed)
deadline = now + remaining
deadlines[hook_name] = (process, deadline)
# Send SIGTERM to all process trees in parallel (non-blocking)
for hook_name, process in background_processes.items():
try:
# Get chrome children (renderer processes etc) before sending signal
children_pids = process.get_children_pids()
if children_pids:
# Chrome hook with children - kill tree
os.kill(process.pid, signal.SIGTERM)
for child_pid in children_pids:
try:
os.kill(child_pid, signal.SIGTERM)
except ProcessLookupError:
pass
log_worker_event(
worker_type=worker_type,
event=f'Sent SIGTERM to {hook_name} + {len(children_pids)} children',
indent_level=indent_level,
pid=self.pid,
)
else:
# No children - normal kill
os.kill(process.pid, signal.SIGTERM)
except ProcessLookupError:
pass # Already dead
except Exception as e:
log_worker_event(
worker_type=worker_type,
event=f'Failed to SIGTERM {hook_name}: {e}',
indent_level=indent_level,
pid=self.pid,
)
# Phase 2: Wait for all processes in parallel, respecting individual timeouts
for hook_name, (process, deadline) in deadlines.items():
remaining = deadline - now
log_worker_event(
worker_type=worker_type,
event=f'Waiting up to {remaining:.1f}s for {hook_name}',
indent_level=indent_level,
pid=self.pid,
)
# Poll all processes in parallel using Process.poll()
still_running = set(deadlines.keys())
while still_running:
time.sleep(0.1)
now = time.time()
for hook_name in list(still_running):
process, deadline = deadlines[hook_name]
# Check if process exited using Process.poll()
exit_code = process.poll()
if exit_code is not None:
# Process exited
still_running.remove(hook_name)
log_worker_event(
worker_type=worker_type,
event=f'{hook_name} exited with code {exit_code}',
indent_level=indent_level,
pid=self.pid,
)
continue
# Check if deadline exceeded
if now >= deadline:
# Timeout exceeded - SIGKILL process tree
try:
# Get children before killing (chrome may have spawned more)
children_pids = process.get_children_pids()
if children_pids:
# Kill children first
for child_pid in children_pids:
try:
os.kill(child_pid, signal.SIGKILL)
except ProcessLookupError:
pass
# Then kill parent
process.kill(signal_num=signal.SIGKILL)
log_worker_event(
worker_type=worker_type,
event=f'⚠ Sent SIGKILL to {hook_name} + {len(children_pids) if children_pids else 0} children (exceeded timeout)',
indent_level=indent_level,
pid=self.pid,
)
except Exception as e:
log_worker_event(
worker_type=worker_type,
event=f'Failed to SIGKILL {hook_name}: {e}',
indent_level=indent_level,
pid=self.pid,
)
still_running.remove(hook_name)
@classmethod
def start(cls, **kwargs: Any) -> int:
"""
Fork a new worker as a subprocess.
Returns the PID of the new process.
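Extra **kwargs are forwarded to the worker's __init__ via _run_worker, e.g.
(illustrative): CrawlWorker.start(crawl_id=str(crawl.id)).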
"""
from archivebox.machine.models import Process
worker_id = Process.get_next_worker_id(process_type=Process.TypeChoices.WORKER)
# Use module-level function for pickling compatibility
proc = MPProcess(
target=_run_worker,
args=(cls.name, worker_id),
kwargs=kwargs,
name=f'{cls.name}_worker_{worker_id}',
)
proc.start()
assert proc.pid is not None
return proc.pid
@classmethod
def get_running_workers(cls) -> list:
"""Get info about all running workers of this type."""
from archivebox.machine.models import Process
Process.cleanup_stale_running()
# Convert Process objects to dicts to match the expected API contract
# Filter by worker_type to get only workers of this specific type (crawl/snapshot/archiveresult)
processes = Process.objects.filter(
process_type=Process.TypeChoices.WORKER,
worker_type=cls.name, # Filter by specific worker type
status__in=['running', 'started']
)
# Note: worker_id is not stored on Process model, it's dynamically generated
# We return process_id (UUID) and pid (OS process ID) instead
return [
{
'pid': p.pid,
'process_id': str(p.id), # UUID of Process record
'started_at': p.started_at.isoformat() if p.started_at else None,
'status': p.status,
}
for p in processes
]
@classmethod
def get_worker_count(cls) -> int:
"""Get count of running workers of this type."""
from archivebox.machine.models import Process
return Process.objects.filter(
process_type=Process.TypeChoices.WORKER,
worker_type=cls.name, # Filter by specific worker type
status__in=['running', 'started']
).count()
class CrawlWorker(Worker):
"""
Worker for processing Crawl objects.
Responsibilities:
1. Run on_Crawl__* hooks (e.g., chrome launcher)
2. Create Snapshots from URLs
3. Spawn SnapshotWorkers (up to MAX_SNAPSHOT_WORKERS)
4. Monitor snapshots and seal crawl when all done
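Typical lifecycle (illustrative):
    CrawlWorker.start(crawl_id=...) -> runloop():
        crawl.sm.tick()                      # QUEUED -> STARTED, runs on_Crawl__* hooks
        loop: spawn SnapshotWorkers (<= MAX_SNAPSHOT_WORKERS), sleep 2s
        crawl.sm.seal()                      # once no snapshots are queued/started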
"""
name: ClassVar[str] = 'crawl'
MAX_TICK_TIME: ClassVar[int] = 60
MAX_SNAPSHOT_WORKERS: ClassVar[int] = 8 # Per crawl limit
def __init__(self, crawl_id: str, **kwargs: Any):
super().__init__(**kwargs)
self.crawl_id = crawl_id
self.crawl = None
def get_model(self):
from archivebox.crawls.models import Crawl
return Crawl
def on_startup(self) -> None:
"""Load crawl."""
super().on_startup()
from archivebox.crawls.models import Crawl
self.crawl = Crawl.objects.get(id=self.crawl_id)
def runloop(self) -> None:
"""Run crawl state machine, spawn SnapshotWorkers."""
import sys
self.on_startup()
try:
print(f'[cyan]🔄 CrawlWorker.runloop: Starting tick() for crawl {self.crawl_id}[/cyan]', file=sys.stderr)
# Advance state machine: QUEUED → STARTED (triggers run() via @started.enter)
self.crawl.sm.tick()
self.crawl.refresh_from_db()
print(f'[cyan]🔄 tick() complete, crawl status={self.crawl.status}[/cyan]', file=sys.stderr)
# Now spawn SnapshotWorkers and monitor progress
while True:
# Check if crawl is done
if self._is_crawl_finished():
print(f'[cyan]🔄 Crawl finished, sealing...[/cyan]', file=sys.stderr)
self.crawl.sm.seal()
break
# Spawn workers for queued snapshots
self._spawn_snapshot_workers()
time.sleep(2) # Check every 2s
finally:
self.on_shutdown()
def _spawn_snapshot_workers(self) -> None:
"""Spawn SnapshotWorkers for queued snapshots (up to limit)."""
from archivebox.core.models import Snapshot
from archivebox.machine.models import Process
# Count running SnapshotWorkers for this crawl
running_count = Process.objects.filter(
process_type=Process.TypeChoices.WORKER,
worker_type='snapshot',
parent_id=self.db_process.id, # Children of this CrawlWorker
status__in=['running', 'started'],
).count()
if running_count >= self.MAX_SNAPSHOT_WORKERS:
return # At limit
# Get queued snapshots for this crawl (SnapshotWorker will mark as STARTED in on_startup)
queued_snapshots = Snapshot.objects.filter(
crawl_id=self.crawl_id,
status=Snapshot.StatusChoices.QUEUED,
).order_by('created_at')[:self.MAX_SNAPSHOT_WORKERS - running_count]
import sys
print(f'[yellow]🔧 _spawn_snapshot_workers: running={running_count}/{self.MAX_SNAPSHOT_WORKERS}, queued={queued_snapshots.count()}[/yellow]', file=sys.stderr)
# Spawn workers
for snapshot in queued_snapshots:
print(f'[yellow]🔧 Spawning worker for {snapshot.url} (status={snapshot.status})[/yellow]', file=sys.stderr)
SnapshotWorker.start(snapshot_id=str(snapshot.id))
log_worker_event(
worker_type='CrawlWorker',
event=f'Spawned SnapshotWorker for {snapshot.url}',
indent_level=1,
pid=self.pid,
)
def _is_crawl_finished(self) -> bool:
"""Check if all snapshots are sealed."""
from archivebox.core.models import Snapshot
pending = Snapshot.objects.filter(
crawl_id=self.crawl_id,
status__in=[Snapshot.StatusChoices.QUEUED, Snapshot.StatusChoices.STARTED],
).count()
return pending == 0
def on_shutdown(self, error: BaseException | None = None) -> None:
"""
Terminate all background Crawl hooks when crawl finishes.
Background hooks (e.g., chrome launcher) should only be killed when:
- All snapshots are done (crawl is sealed)
- Worker is shutting down
"""
from archivebox.machine.models import Process
# Query for all running hook processes that are children of this CrawlWorker
background_hooks = Process.objects.filter(
parent_id=self.db_process.id,
process_type=Process.TypeChoices.HOOK,
status=Process.StatusChoices.RUNNING,
).select_related('machine')
# Build dict for shared termination logic
background_processes = {
hook.cmd[0] if hook.cmd else f'hook-{hook.pid}': hook
for hook in background_hooks
}
# Use shared termination logic from Worker base class
self._terminate_background_hooks(
background_processes=background_processes,
worker_type='CrawlWorker',
indent_level=1,
)
super().on_shutdown(error)
class SnapshotWorker(Worker):
"""
Worker that owns sequential hook execution for ONE snapshot.
SnapshotWorker doesn't poll a queue - it's given a specific snapshot_id
and runs all hooks for that snapshot sequentially.
Execution flow:
1. Mark snapshot as STARTED
2. Discover hooks for snapshot
3. For each hook (sorted by name):
a. Fork hook Process
b. If foreground: wait for completion
c. If background: track but continue to next hook
d. Update ArchiveResult status
e. Advance current_step when all step's hooks complete
4. When all hooks done: seal snapshot
5. On shutdown: SIGTERM all background hooks
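Hook ordering example (illustrative filenames):
    on_Snapshot__20_launch_chrome.bg.js   # background: tracked, not awaited
    on_Snapshot__21_navigate_chrome.js    # foreground: awaited before the next hook
    on_Snapshot__50_wget.py               # foreground
Sorting by name means the numeric step prefix drives execution order.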
"""
name: ClassVar[str] = 'snapshot'
def __init__(self, snapshot_id: str, **kwargs: Any):
super().__init__(**kwargs)
self.snapshot_id = snapshot_id
self.snapshot = None
self.background_processes: dict[str, Any] = {} # hook_name -> Process
def get_model(self):
"""Not used - SnapshotWorker doesn't poll queues."""
from archivebox.core.models import Snapshot
return Snapshot
def on_startup(self) -> None:
"""Load snapshot and mark as STARTED using state machine."""
super().on_startup()
from archivebox.core.models import Snapshot
self.snapshot = Snapshot.objects.get(id=self.snapshot_id)
# Use state machine to transition queued -> started (triggers enter_started())
self.snapshot.sm.tick()
self.snapshot.refresh_from_db()
def runloop(self) -> None:
"""Execute all hooks sequentially."""
from archivebox.hooks import discover_hooks, is_background_hook, extract_step
from archivebox.core.models import ArchiveResult
self.on_startup()
try:
# Discover all hooks for this snapshot
hooks = discover_hooks('Snapshot', config=self.snapshot.config)
hooks = sorted(hooks, key=lambda h: h.name) # Sort by name (includes step prefix)
# Execute each hook sequentially
for hook_path in hooks:
hook_name = hook_path.name
plugin = self._extract_plugin_name(hook_name)
hook_step = extract_step(hook_name)
is_background = is_background_hook(hook_name)
# Create ArchiveResult for THIS HOOK (not per plugin)
# One plugin can have multiple hooks (e.g., chrome/on_Snapshot__20_launch_chrome.js, chrome/on_Snapshot__21_navigate_chrome.js)
# Unique key = (snapshot, plugin, hook_name) for idempotency
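# e.g. (illustrative) re-running this snapshot re-uses the row keyed by
# (snapshot, 'wget', 'on_Snapshot__50_wget.py') instead of creating a duplicate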
ar, created = ArchiveResult.objects.get_or_create(
snapshot=self.snapshot,
plugin=plugin,
hook_name=hook_name,
defaults={
'status': ArchiveResult.StatusChoices.STARTED,
'start_ts': timezone.now(),
}
)
if not created:
# Update existing AR to STARTED
ar.status = ArchiveResult.StatusChoices.STARTED
ar.start_ts = timezone.now()
ar.save(update_fields=['status', 'start_ts', 'modified_at'])
# Fork and run the hook
process = self._run_hook(hook_path, ar)
if is_background:
# Track but don't wait
self.background_processes[hook_name] = process
log_worker_event(
worker_type='SnapshotWorker',
event=f'Started background hook: {hook_name} (timeout={process.timeout}s)',
indent_level=2,
pid=self.pid,
)
else:
# Wait for foreground hook to complete
self._wait_for_hook(process, ar)
log_worker_event(
worker_type='SnapshotWorker',
event=f'Completed hook: {hook_name}',
indent_level=2,
pid=self.pid,
)
# Check if we can advance to next step
self._try_advance_step()
# All hooks launched (or completed) - seal using state machine
# This triggers enter_sealed() which calls cleanup() and checks parent crawl sealing
self.snapshot.sm.seal()
self.snapshot.refresh_from_db()
except Exception:
# Mark snapshot as sealed even on error (still triggers cleanup)
self.snapshot.sm.seal()
self.snapshot.refresh_from_db()
raise
finally:
self.on_shutdown()
def _run_hook(self, hook_path: Path, ar: Any) -> Any:
"""Fork and run a hook using Process model, return Process."""
from archivebox.hooks import run_hook
# Create output directory
output_dir = ar.create_output_dir()
# Run hook using Process.launch() - returns Process model directly
# Pass self.db_process as parent to track SnapshotWorker -> Hook hierarchy
process = run_hook(
script=hook_path,
output_dir=output_dir,
config=self.snapshot.config,
timeout=120,
parent=self.db_process,
url=str(self.snapshot.url),
snapshot_id=str(self.snapshot.id),
)
# Link ArchiveResult to Process for tracking
ar.process = process
ar.save(update_fields=['process_id', 'modified_at'])
return process
def _wait_for_hook(self, process: Any, ar: Any) -> None:
"""Wait for hook using Process.wait(), update AR status."""
# Use Process.wait() helper instead of manual polling
try:
exit_code = process.wait(timeout=process.timeout)
except TimeoutError:
# Hook exceeded timeout - kill it
process.kill(signal_num=9)
exit_code = -1
# Update ArchiveResult from hook output
ar.update_from_output()
ar.end_ts = timezone.now()
# Determine final status from hook exit code
if exit_code == 0:
ar.status = ar.StatusChoices.SUCCEEDED
else:
ar.status = ar.StatusChoices.FAILED
ar.save(update_fields=['status', 'end_ts', 'modified_at'])
def _try_advance_step(self) -> None:
"""Advance current_step if all foreground hooks in current step are done."""
from django.db.models import Q
from archivebox.core.models import ArchiveResult
current_step = self.snapshot.current_step
# Single query: foreground hooks in current step that aren't finished
# Foreground hooks: hook_name doesn't contain '.bg.'
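# e.g. (illustrative) with current_step=50, '__50_' matches 'on_Snapshot__50_wget.py';
# note this assumes step prefixes are not zero-padded ('__05_' would not match step 5)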
pending_foreground = self.snapshot.archiveresult_set.filter(
Q(hook_name__contains=f'__{current_step}_') & # Current step
~Q(hook_name__contains='.bg.') & # Not background
~Q(status__in=ArchiveResult.FINAL_STATES) # Not finished
).exists()
if pending_foreground:
return # Still waiting for hooks
# All foreground hooks done - advance!
self.snapshot.current_step += 1
self.snapshot.save(update_fields=['current_step', 'modified_at'])
log_worker_event(
worker_type='SnapshotWorker',
event=f'Advanced to step {self.snapshot.current_step}',
indent_level=2,
pid=self.pid,
)
def on_shutdown(self, error: BaseException | None = None) -> None:
"""
Terminate all background Snapshot hooks when snapshot finishes.
Background hooks should only be killed when:
- All foreground hooks are done (snapshot is sealed)
- Worker is shutting down
"""
# Use shared termination logic from Worker base class
self._terminate_background_hooks(
background_processes=self.background_processes,
worker_type='SnapshotWorker',
indent_level=2,
)
super().on_shutdown(error)
@staticmethod
def _extract_plugin_name(hook_name: str) -> str:
"""Extract plugin name from hook filename."""
# on_Snapshot__50_wget.py -> wget
name = hook_name.split('__')[-1]  # Get part after last __, e.g. '50_wget.py'
name = name.replace('.py', '').replace('.js', '').replace('.sh', '')
name = name.replace('.bg', '')  # Remove .bg suffix
# Strip the numeric step prefix so '50_wget' becomes 'wget' (matches the example above)
prefix, _, rest = name.partition('_')
if prefix.isdigit() and rest:
    name = rest
return name
@classmethod
def start(cls, snapshot_id: str, **kwargs: Any) -> int:
"""Fork a SnapshotWorker for a specific snapshot."""
from archivebox.machine.models import Process
worker_id = Process.get_next_worker_id(process_type=Process.TypeChoices.WORKER)
proc = MPProcess(
target=_run_snapshot_worker, # New module-level function
args=(snapshot_id, worker_id),
kwargs=kwargs,
name=f'snapshot_worker_{snapshot_id[:8]}',
)
proc.start()
assert proc.pid is not None
return proc.pid
# Populate the registry
WORKER_TYPES.update({
'crawl': CrawlWorker,
'snapshot': SnapshotWorker,
})
def get_worker_class(name: str) -> type[Worker]:
"""Get worker class by name."""
if name not in WORKER_TYPES:
raise ValueError(f'Unknown worker type: {name}. Valid types: {list(WORKER_TYPES.keys())}')
return WORKER_TYPES[name]
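# Usage example (illustrative):
#   get_worker_class('crawl') is CrawlWorker        # True
#   get_worker_class('snapshot') is SnapshotWorker  # True
#   get_worker_class('bogus')                       # raises ValueError listing valid types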