"""
|
|
Worker classes for processing queue items.
|
|
|
|
Workers poll the database for items to process, claim them atomically,
|
|
and run the state machine tick() to process each item.
|
|
|
|
Architecture:
|
|
Orchestrator (spawns workers)
|
|
└── Worker (claims items from queue, processes them directly)
|
|
"""

__package__ = 'archivebox.workers'

import os
import time
import traceback
from typing import ClassVar, Any
from datetime import timedelta
from pathlib import Path
from multiprocessing import Process as MPProcess, cpu_count

from django.db.models import QuerySet
from django.utils import timezone
from django.conf import settings

from rich import print

from archivebox.misc.logging_util import log_worker_event


CPU_COUNT = cpu_count()

# Registry of worker types by name (defined at bottom, referenced here for _run_worker)
WORKER_TYPES: dict[str, type['Worker']] = {}


def _run_worker(worker_class_name: str, worker_id: int, **kwargs):
    """
    Module-level function to run a worker. Must be at module level for pickling.
    """
    from archivebox.config.django import setup_django
    setup_django()

    # Get worker class by name to avoid pickling class objects
    worker_cls = WORKER_TYPES[worker_class_name]
    worker = worker_cls(worker_id=worker_id, **kwargs)
    worker.runloop()


def _run_snapshot_worker(snapshot_id: str, worker_id: int, **kwargs):
    """
    Module-level function to run a SnapshotWorker for a specific snapshot.
    Must be at module level for pickling compatibility.
    """
    from archivebox.config.django import setup_django
    setup_django()

    worker = SnapshotWorker(snapshot_id=snapshot_id, worker_id=worker_id, **kwargs)
    worker.runloop()
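
# Example (illustrative sketch, not executed on import): this mirrors what Worker.start()
# does below; the spawn target is one of the module-level functions above so it stays
# picklable, and '<crawl-uuid>' stands in for the id of an existing queued Crawl:
#
#   proc = MPProcess(target=_run_worker, args=('crawl', 1), kwargs={'crawl_id': '<crawl-uuid>'})
#   proc.start()     # the child calls setup_django(), then CrawlWorker(worker_id=1, ...).runloop()
#   print(proc.pid)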


class Worker:
    """
    Base worker class for CrawlWorker and SnapshotWorker.

    Workers are spawned as subprocesses to process crawls and snapshots.
    Each worker type has its own custom runloop implementation.
    """

    name: ClassVar[str] = 'worker'

    # Configuration (can be overridden by subclasses)
    MAX_TICK_TIME: ClassVar[int] = 60
    MAX_CONCURRENT_TASKS: ClassVar[int] = 1

    def __init__(self, worker_id: int = 0, **kwargs: Any):
        self.worker_id = worker_id
        self.pid: int = os.getpid()

    def __repr__(self) -> str:
        return f'[underline]{self.__class__.__name__}[/underline]\\[id={self.worker_id}, pid={self.pid}]'

    def get_model(self):
        """Get the Django model class. Subclasses must override this."""
        raise NotImplementedError("Subclasses must implement get_model()")

    def on_startup(self) -> None:
        """Called when worker starts."""
        from archivebox.machine.models import Process

        self.pid = os.getpid()
        # Register this worker process in the database
        self.db_process = Process.current()
        # Explicitly set process_type to WORKER and store worker type name
        update_fields = []
        if self.db_process.process_type != Process.TypeChoices.WORKER:
            self.db_process.process_type = Process.TypeChoices.WORKER
            update_fields.append('process_type')
        # Store worker type name (crawl/snapshot) in worker_type field
        if not self.db_process.worker_type:
            self.db_process.worker_type = self.name
            update_fields.append('worker_type')
        if update_fields:
            self.db_process.save(update_fields=update_fields)

        # Determine worker type for logging
        worker_type_name = self.__class__.__name__
        indent_level = 1  # Default for CrawlWorker

        # SnapshotWorker gets indent level 2
        if 'Snapshot' in worker_type_name:
            indent_level = 2

        log_worker_event(
            worker_type=worker_type_name,
            event='Starting...',
            indent_level=indent_level,
            pid=self.pid,
            worker_id=str(self.worker_id),
        )

    def on_shutdown(self, error: BaseException | None = None) -> None:
        """Called when worker shuts down."""
        # Update Process record status
        if hasattr(self, 'db_process') and self.db_process:
            self.db_process.exit_code = 1 if error else 0
            self.db_process.status = self.db_process.StatusChoices.EXITED
            self.db_process.ended_at = timezone.now()
            self.db_process.save()

        # Determine worker type for logging
        worker_type_name = self.__class__.__name__
        indent_level = 1  # CrawlWorker

        if 'Snapshot' in worker_type_name:
            indent_level = 2

        log_worker_event(
            worker_type=worker_type_name,
            event='Shutting down',
            indent_level=indent_level,
            pid=self.pid,
            worker_id=str(self.worker_id),
            error=error if error and not isinstance(error, KeyboardInterrupt) else None,
        )

    def _terminate_background_hooks(
        self,
        background_processes: dict[str, 'Process'],
        worker_type: str,
        indent_level: int,
    ) -> None:
        """
        Terminate background hooks in 3 phases (shared logic for Crawl/Snapshot workers).

        Phase 1: Send SIGTERM to all bg hooks + children in parallel (polite request to wrap up)
        Phase 2: Wait for each hook's remaining timeout before SIGKILL
        Phase 3: SIGKILL any stragglers that exceeded their timeout

        Args:
            background_processes: Dict mapping hook name -> Process instance
            worker_type: Worker type name for logging (e.g., 'CrawlWorker', 'SnapshotWorker')
            indent_level: Logging indent level (1 for Crawl, 2 for Snapshot)
        """
        import signal
        import time

        if not background_processes:
            return

        now = time.time()

        # Phase 1: Send SIGTERM to ALL background processes + children in parallel
        log_worker_event(
            worker_type=worker_type,
            event=f'Sending SIGTERM to {len(background_processes)} background hooks (+ children)',
            indent_level=indent_level,
            pid=self.pid,
        )

        # Build deadline map first (before killing, to get accurate remaining time)
        deadlines = {}
        for hook_name, process in background_processes.items():
            elapsed = now - process.started_at.timestamp()
            remaining = max(0, process.timeout - elapsed)
            deadline = now + remaining
            deadlines[hook_name] = (process, deadline)
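
        # Worked example of the deadline math above: a hook that started 90s ago with
        # timeout=120 still has max(0, 120 - 90) = 30s of grace after the SIGTERM below;
        # a hook already past its timeout gets remaining = 0, so it is SIGKILLed on the
        # first pass of the polling loop if it ignores the SIGTERM.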

        # Send SIGTERM to all process trees in parallel (non-blocking)
        for hook_name, process in background_processes.items():
            try:
                # Get chrome children (renderer processes etc) before sending signal
                children_pids = process.get_children_pids()
                if children_pids:
                    # Chrome hook with children - kill tree
                    os.kill(process.pid, signal.SIGTERM)
                    for child_pid in children_pids:
                        try:
                            os.kill(child_pid, signal.SIGTERM)
                        except ProcessLookupError:
                            pass
                    log_worker_event(
                        worker_type=worker_type,
                        event=f'Sent SIGTERM to {hook_name} + {len(children_pids)} children',
                        indent_level=indent_level,
                        pid=self.pid,
                    )
                else:
                    # No children - normal kill
                    os.kill(process.pid, signal.SIGTERM)
            except ProcessLookupError:
                pass  # Already dead
            except Exception as e:
                log_worker_event(
                    worker_type=worker_type,
                    event=f'Failed to SIGTERM {hook_name}: {e}',
                    indent_level=indent_level,
                    pid=self.pid,
                )

        # Phase 2: Wait for all processes in parallel, respecting individual timeouts
        for hook_name, (process, deadline) in deadlines.items():
            remaining = deadline - now
            log_worker_event(
                worker_type=worker_type,
                event=f'Waiting up to {remaining:.1f}s for {hook_name}',
                indent_level=indent_level,
                pid=self.pid,
            )

        # Poll all processes in parallel using Process.poll()
        still_running = set(deadlines.keys())

        while still_running:
            time.sleep(0.1)
            now = time.time()

            for hook_name in list(still_running):
                process, deadline = deadlines[hook_name]

                # Check if process exited using Process.poll()
                exit_code = process.poll()
                if exit_code is not None:
                    # Process exited
                    still_running.remove(hook_name)
                    log_worker_event(
                        worker_type=worker_type,
                        event=f'✓ {hook_name} exited with code {exit_code}',
                        indent_level=indent_level,
                        pid=self.pid,
                    )
                    continue

                # Check if deadline exceeded
                if now >= deadline:
                    # Timeout exceeded - SIGKILL process tree
                    try:
                        # Get children before killing (chrome may have spawned more)
                        children_pids = process.get_children_pids()
                        if children_pids:
                            # Kill children first
                            for child_pid in children_pids:
                                try:
                                    os.kill(child_pid, signal.SIGKILL)
                                except ProcessLookupError:
                                    pass
                        # Then kill parent
                        process.kill(signal_num=signal.SIGKILL)
                        log_worker_event(
                            worker_type=worker_type,
                            event=f'⚠ Sent SIGKILL to {hook_name} + {len(children_pids) if children_pids else 0} children (exceeded timeout)',
                            indent_level=indent_level,
                            pid=self.pid,
                        )
                    except Exception as e:
                        log_worker_event(
                            worker_type=worker_type,
                            event=f'Failed to SIGKILL {hook_name}: {e}',
                            indent_level=indent_level,
                            pid=self.pid,
                        )
                    still_running.remove(hook_name)

    @classmethod
    def start(cls, **kwargs: Any) -> int:
        """
        Fork a new worker as a subprocess.
        Returns the PID of the new process.
        """
        from archivebox.machine.models import Process

        worker_id = Process.get_next_worker_id(process_type=Process.TypeChoices.WORKER)

        # Use module-level function for pickling compatibility
        proc = MPProcess(
            target=_run_worker,
            args=(cls.name, worker_id),
            kwargs=kwargs,
            name=f'{cls.name}_worker_{worker_id}',
        )
        proc.start()

        assert proc.pid is not None
        return proc.pid

    @classmethod
    def get_running_workers(cls) -> list:
        """Get info about all running workers of this type."""
        from archivebox.machine.models import Process

        Process.cleanup_stale_running()
        # Convert Process objects to dicts to match the expected API contract
        # Filter by worker_type to get only workers of this specific type (crawl/snapshot/archiveresult)
        processes = Process.objects.filter(
            process_type=Process.TypeChoices.WORKER,
            worker_type=cls.name,  # Filter by specific worker type
            status__in=['running', 'started']
        )
        # Note: worker_id is not stored on Process model, it's dynamically generated
        # We return process_id (UUID) and pid (OS process ID) instead
        return [
            {
                'pid': p.pid,
                'process_id': str(p.id),  # UUID of Process record
                'started_at': p.started_at.isoformat() if p.started_at else None,
                'status': p.status,
            }
            for p in processes
        ]

    @classmethod
    def get_worker_count(cls) -> int:
        """Get count of running workers of this type."""
        from archivebox.machine.models import Process

        return Process.objects.filter(
            process_type=Process.TypeChoices.WORKER,
            worker_type=cls.name,  # Filter by specific worker type
            status__in=['running', 'started']
        ).count()
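
# Example usage of the classmethods above (illustrative sketch; assumes Django is already
# configured and '<crawl-uuid>' is the id of an existing queued Crawl):
#
#   pid = CrawlWorker.start(crawl_id='<crawl-uuid>')   # fork a 'crawl' worker subprocess
#   print(CrawlWorker.get_worker_count())              # -> number of live 'crawl' workers
#   for info in CrawlWorker.get_running_workers():
#       print(info['pid'], info['process_id'], info['status'])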


class CrawlWorker(Worker):
    """
    Worker for processing Crawl objects.

    Responsibilities:
    1. Run on_Crawl__* hooks (e.g., chrome launcher)
    2. Create Snapshots from URLs
    3. Spawn SnapshotWorkers (up to MAX_SNAPSHOT_WORKERS)
    4. Monitor snapshots and seal crawl when all done
    """

    name: ClassVar[str] = 'crawl'
    MAX_TICK_TIME: ClassVar[int] = 60
    MAX_SNAPSHOT_WORKERS: ClassVar[int] = 8  # Per crawl limit

    def __init__(self, crawl_id: str, **kwargs: Any):
        super().__init__(**kwargs)
        self.crawl_id = crawl_id
        self.crawl = None

    def get_model(self):
        from archivebox.crawls.models import Crawl
        return Crawl

    def on_startup(self) -> None:
        """Load crawl."""
        super().on_startup()

        from archivebox.crawls.models import Crawl
        self.crawl = Crawl.objects.get(id=self.crawl_id)

    def runloop(self) -> None:
        """Run crawl state machine, spawn SnapshotWorkers."""
        import sys
        self.on_startup()

        try:
            print(f'[cyan]🔄 CrawlWorker.runloop: Starting tick() for crawl {self.crawl_id}[/cyan]', file=sys.stderr)
            # Advance state machine: QUEUED → STARTED (triggers run() via @started.enter)
            self.crawl.sm.tick()
            self.crawl.refresh_from_db()
            print(f'[cyan]🔄 tick() complete, crawl status={self.crawl.status}[/cyan]', file=sys.stderr)

            # Now spawn SnapshotWorkers and monitor progress
            while True:
                # Check if crawl is done
                if self._is_crawl_finished():
                    print(f'[cyan]🔄 Crawl finished, sealing...[/cyan]', file=sys.stderr)
                    self.crawl.sm.seal()
                    break

                # Spawn workers for queued snapshots
                self._spawn_snapshot_workers()

                time.sleep(2)  # Check every 2s

        finally:
            self.on_shutdown()

    def _spawn_snapshot_workers(self) -> None:
        """Spawn SnapshotWorkers for queued snapshots (up to limit)."""
        from archivebox.core.models import Snapshot
        from archivebox.machine.models import Process

        # Count running SnapshotWorkers for this crawl
        running_count = Process.objects.filter(
            process_type=Process.TypeChoices.WORKER,
            worker_type='snapshot',
            parent_id=self.db_process.id,  # Children of this CrawlWorker
            status__in=['running', 'started'],
        ).count()

        if running_count >= self.MAX_SNAPSHOT_WORKERS:
            return  # At limit

        # Get queued snapshots for this crawl (SnapshotWorker will mark as STARTED in on_startup)
        queued_snapshots = Snapshot.objects.filter(
            crawl_id=self.crawl_id,
            status=Snapshot.StatusChoices.QUEUED,
        ).order_by('created_at')[:self.MAX_SNAPSHOT_WORKERS - running_count]

        import sys
        print(f'[yellow]🔧 _spawn_snapshot_workers: running={running_count}/{self.MAX_SNAPSHOT_WORKERS}, queued={queued_snapshots.count()}[/yellow]', file=sys.stderr)

        # Spawn workers
        for snapshot in queued_snapshots:
            print(f'[yellow]🔧 Spawning worker for {snapshot.url} (status={snapshot.status})[/yellow]', file=sys.stderr)
            SnapshotWorker.start(snapshot_id=str(snapshot.id))
            log_worker_event(
                worker_type='CrawlWorker',
                event=f'Spawned SnapshotWorker for {snapshot.url}',
                indent_level=1,
                pid=self.pid,
            )

    def _is_crawl_finished(self) -> bool:
        """Check if all snapshots are sealed."""
        from archivebox.core.models import Snapshot

        pending = Snapshot.objects.filter(
            crawl_id=self.crawl_id,
            status__in=[Snapshot.StatusChoices.QUEUED, Snapshot.StatusChoices.STARTED],
        ).count()

        return pending == 0
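
    # Example: a crawl with 3 snapshots where 2 are SEALED and 1 is still STARTED has
    # pending == 1, so runloop() keeps polling; once that last snapshot seals,
    # pending == 0 and the crawl itself gets sealed.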

    def on_shutdown(self, error: BaseException | None = None) -> None:
        """
        Terminate all background Crawl hooks when crawl finishes.

        Background hooks (e.g., chrome launcher) should only be killed when:
        - All snapshots are done (crawl is sealed)
        - Worker is shutting down
        """
        from archivebox.machine.models import Process

        # Query for all running hook processes that are children of this CrawlWorker
        background_hooks = Process.objects.filter(
            parent_id=self.db_process.id,
            process_type=Process.TypeChoices.HOOK,
            status=Process.StatusChoices.RUNNING,
        ).select_related('machine')

        # Build dict for shared termination logic
        background_processes = {
            hook.cmd[0] if hook.cmd else f'hook-{hook.pid}': hook
            for hook in background_hooks
        }

        # Use shared termination logic from Worker base class
        self._terminate_background_hooks(
            background_processes=background_processes,
            worker_type='CrawlWorker',
            indent_level=1,
        )

        super().on_shutdown(error)


class SnapshotWorker(Worker):
    """
    Worker that owns sequential hook execution for ONE snapshot.

    Unlike other workers, SnapshotWorker doesn't poll a queue - it's given
    a specific snapshot_id and runs all hooks for that snapshot sequentially.

    Execution flow:
    1. Mark snapshot as STARTED
    2. Discover hooks for snapshot
    3. For each hook (sorted by name):
       a. Fork hook Process
       b. If foreground: wait for completion
       c. If background: track but continue to next hook
       d. Update ArchiveResult status
       e. Advance current_step when all step's hooks complete
    4. When all hooks done: seal snapshot
    5. On shutdown: SIGTERM all background hooks
    """

    name: ClassVar[str] = 'snapshot'

    def __init__(self, snapshot_id: str, **kwargs: Any):
        super().__init__(**kwargs)
        self.snapshot_id = snapshot_id
        self.snapshot = None
        self.background_processes: dict[str, Any] = {}  # hook_name -> Process

    def get_model(self):
        """Not used - SnapshotWorker doesn't poll queues."""
        from archivebox.core.models import Snapshot
        return Snapshot

    def on_startup(self) -> None:
        """Load snapshot and mark as STARTED using state machine."""
        super().on_startup()

        from archivebox.core.models import Snapshot
        self.snapshot = Snapshot.objects.get(id=self.snapshot_id)

        # Use state machine to transition queued -> started (triggers enter_started())
        self.snapshot.sm.tick()
        self.snapshot.refresh_from_db()

    def runloop(self) -> None:
        """Execute all hooks sequentially."""
        from archivebox.hooks import discover_hooks, is_background_hook, extract_step
        from archivebox.core.models import ArchiveResult

        self.on_startup()

        try:
            # Discover all hooks for this snapshot
            hooks = discover_hooks('Snapshot', config=self.snapshot.config)
            hooks = sorted(hooks, key=lambda h: h.name)  # Sort by name (includes step prefix)

            # Execute each hook sequentially
            for hook_path in hooks:
                hook_name = hook_path.name
                plugin = self._extract_plugin_name(hook_name)
                hook_step = extract_step(hook_name)
                is_background = is_background_hook(hook_name)

                # Create ArchiveResult for THIS HOOK (not per plugin)
                # One plugin can have multiple hooks (e.g., chrome/on_Snapshot__20_launch_chrome.js, chrome/on_Snapshot__21_navigate_chrome.js)
                # Unique key = (snapshot, plugin, hook_name) for idempotency
                ar, created = ArchiveResult.objects.get_or_create(
                    snapshot=self.snapshot,
                    plugin=plugin,
                    hook_name=hook_name,
                    defaults={
                        'status': ArchiveResult.StatusChoices.STARTED,
                        'start_ts': timezone.now(),
                    }
                )

                if not created:
                    # Update existing AR to STARTED
                    ar.status = ArchiveResult.StatusChoices.STARTED
                    ar.start_ts = timezone.now()
                    ar.save(update_fields=['status', 'start_ts', 'modified_at'])

                # Fork and run the hook
                process = self._run_hook(hook_path, ar)

                if is_background:
                    # Track but don't wait
                    self.background_processes[hook_name] = process
                    log_worker_event(
                        worker_type='SnapshotWorker',
                        event=f'Started background hook: {hook_name} (timeout={process.timeout}s)',
                        indent_level=2,
                        pid=self.pid,
                    )
                else:
                    # Wait for foreground hook to complete
                    self._wait_for_hook(process, ar)
                    log_worker_event(
                        worker_type='SnapshotWorker',
                        event=f'Completed hook: {hook_name}',
                        indent_level=2,
                        pid=self.pid,
                    )

                # Check if we can advance to next step
                self._try_advance_step()

            # All hooks launched (or completed) - seal using state machine
            # This triggers enter_sealed() which calls cleanup() and checks parent crawl sealing
            self.snapshot.sm.seal()
            self.snapshot.refresh_from_db()

        except Exception as e:
            # Mark snapshot as sealed even on error (still triggers cleanup)
            self.snapshot.sm.seal()
            self.snapshot.refresh_from_db()
            raise
        finally:
            self.on_shutdown()

    def _run_hook(self, hook_path: Path, ar: Any) -> Any:
        """Fork and run a hook using Process model, return Process."""
        from archivebox.hooks import run_hook

        # Create output directory
        output_dir = ar.create_output_dir()

        # Run hook using Process.launch() - returns Process model directly
        # Pass self.db_process as parent to track SnapshotWorker -> Hook hierarchy
        process = run_hook(
            script=hook_path,
            output_dir=output_dir,
            config=self.snapshot.config,
            timeout=120,
            parent=self.db_process,
            url=str(self.snapshot.url),
            snapshot_id=str(self.snapshot.id),
        )

        # Link ArchiveResult to Process for tracking
        ar.process = process
        ar.save(update_fields=['process_id', 'modified_at'])

        return process

    def _wait_for_hook(self, process: Any, ar: Any) -> None:
        """Wait for hook using Process.wait(), update AR status."""
        # Use Process.wait() helper instead of manual polling
        try:
            exit_code = process.wait(timeout=process.timeout)
        except TimeoutError:
            # Hook exceeded timeout - kill it
            process.kill(signal_num=9)
            exit_code = -1

        # Update ArchiveResult from hook output
        ar.update_from_output()
        ar.end_ts = timezone.now()

        # Determine final status from hook exit code
        if exit_code == 0:
            ar.status = ar.StatusChoices.SUCCEEDED
        else:
            ar.status = ar.StatusChoices.FAILED

        ar.save(update_fields=['status', 'end_ts', 'modified_at'])

    def _try_advance_step(self) -> None:
        """Advance current_step if all foreground hooks in current step are done."""
        from django.db.models import Q
        from archivebox.core.models import ArchiveResult

        current_step = self.snapshot.current_step

        # Single query: foreground hooks in current step that aren't finished
        # Foreground hooks: hook_name doesn't contain '.bg.'
        pending_foreground = self.snapshot.archiveresult_set.filter(
            Q(hook_name__contains=f'__{current_step}_') &  # Current step
            ~Q(hook_name__contains='.bg.') &  # Not background
            ~Q(status__in=ArchiveResult.FINAL_STATES)  # Not finished
        ).exists()
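
        # Example: with current_step=21, this matches results whose hook_name contains
        # '__21_' (e.g. 'on_Snapshot__21_navigate_chrome.js'), skips any '.bg.' hooks,
        # and only counts results not yet in a final state.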

        if pending_foreground:
            return  # Still waiting for hooks

        # All foreground hooks done - advance!
        self.snapshot.current_step += 1
        self.snapshot.save(update_fields=['current_step', 'modified_at'])

        log_worker_event(
            worker_type='SnapshotWorker',
            event=f'Advanced to step {self.snapshot.current_step}',
            indent_level=2,
            pid=self.pid,
        )

    def on_shutdown(self, error: BaseException | None = None) -> None:
        """
        Terminate all background Snapshot hooks when snapshot finishes.

        Background hooks should only be killed when:
        - All foreground hooks are done (snapshot is sealed)
        - Worker is shutting down
        """
        # Use shared termination logic from Worker base class
        self._terminate_background_hooks(
            background_processes=self.background_processes,
            worker_type='SnapshotWorker',
            indent_level=2,
        )

        super().on_shutdown(error)

    @staticmethod
    def _extract_plugin_name(hook_name: str) -> str:
        """Extract plugin name from hook filename."""
        # on_Snapshot__50_wget.py -> wget
        name = hook_name.split('__')[-1]  # Get part after last __ (e.g. '50_wget.py')
        name = name.replace('.py', '').replace('.js', '').replace('.sh', '')
        name = name.replace('.bg', '')  # Remove .bg suffix
        # Drop the leading numeric step prefix so '50_wget' becomes 'wget'
        prefix, _, rest = name.partition('_')
        if prefix.isdigit() and rest:
            name = rest
        return name
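
    # Examples (with the numeric step prefix stripped as above):
    #   _extract_plugin_name('on_Snapshot__50_wget.py')             -> 'wget'
    #   _extract_plugin_name('on_Snapshot__20_launch_chrome.bg.js') -> 'launch_chrome'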

    @classmethod
    def start(cls, snapshot_id: str, **kwargs: Any) -> int:
        """Fork a SnapshotWorker for a specific snapshot."""
        from archivebox.machine.models import Process

        worker_id = Process.get_next_worker_id(process_type=Process.TypeChoices.WORKER)

        proc = MPProcess(
            target=_run_snapshot_worker,  # Module-level function (required for pickling)
            args=(snapshot_id, worker_id),
            kwargs=kwargs,
            name=f'snapshot_worker_{snapshot_id[:8]}',
        )
        proc.start()

        assert proc.pid is not None
        return proc.pid


# Populate the registry
WORKER_TYPES.update({
    'crawl': CrawlWorker,
    'snapshot': SnapshotWorker,
})


def get_worker_class(name: str) -> type[Worker]:
    """Get worker class by name."""
    if name not in WORKER_TYPES:
        raise ValueError(f'Unknown worker type: {name}. Valid types: {list(WORKER_TYPES.keys())}')
    return WORKER_TYPES[name]
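
# Example (illustrative): resolving a worker class by name and forking it,
# assuming '<snapshot-uuid>' is the id of an existing queued Snapshot:
#
#   worker_cls = get_worker_class('snapshot')              # -> SnapshotWorker
#   pid = worker_cls.start(snapshot_id='<snapshot-uuid>')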