mirror of
https://github.com/ArchiveBox/ArchiveBox.git
synced 2026-04-05 23:37:58 +10:00
wip major changes
This commit is contained in:
@@ -1,457 +1,330 @@
|
||||
"""
|
||||
Worker classes for processing queue items.
|
||||
|
||||
Workers poll the database for items to process, claim them atomically,
|
||||
and run the state machine tick() to process each item.
|
||||
|
||||
Architecture:
|
||||
Orchestrator (spawns workers)
|
||||
└── Worker (claims items from queue, processes them directly)
|
||||
"""
|
||||
|
||||
__package__ = 'archivebox.workers'
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
import json
|
||||
|
||||
from typing import ClassVar, Iterable, Type
|
||||
import traceback
|
||||
from typing import ClassVar, Any
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from multiprocessing import Process, cpu_count
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from django.utils import timezone
|
||||
from django.conf import settings
|
||||
|
||||
from rich import print
|
||||
|
||||
from django.db import transaction
|
||||
from django.db.models import QuerySet
|
||||
from django.utils import timezone
|
||||
from django.utils.functional import classproperty # type: ignore
|
||||
|
||||
from crawls.models import Crawl
|
||||
from core.models import Snapshot, ArchiveResult
|
||||
|
||||
from workers.models import Event, Process, EventDict
|
||||
from .pid_utils import (
|
||||
write_pid_file,
|
||||
remove_pid_file,
|
||||
get_all_worker_pids,
|
||||
get_next_worker_id,
|
||||
cleanup_stale_pid_files,
|
||||
)
|
||||
|
||||
|
||||
class WorkerType:
|
||||
# static class attributes
|
||||
name: ClassVar[str] # e.g. 'log' or 'filesystem' or 'crawl' or 'snapshot' or 'archiveresult' etc.
|
||||
|
||||
listens_to: ClassVar[str] # e.g. 'LOG_' or 'FS_' or 'CRAWL_' or 'SNAPSHOT_' or 'ARCHIVERESULT_' etc.
|
||||
outputs: ClassVar[list[str]] # e.g. ['LOG_', 'FS_', 'CRAWL_', 'SNAPSHOT_', 'ARCHIVERESULT_'] etc.
|
||||
|
||||
poll_interval: ClassVar[int] = 1 # how long to wait before polling for new events
|
||||
|
||||
@classproperty
def event_queue(cls) -> QuerySet[Event]:
    """All Events addressed to this worker type (name starts with cls.listens_to)."""
    prefix = cls.listens_to
    return Event.objects.filter(name__startswith=prefix)
|
||||
CPU_COUNT = cpu_count()
|
||||
|
||||
@classmethod
def fork(cls, wait_for_first_event=False, exit_on_idle=True) -> Process:
    """Spawn `archivebox worker <name>` as a subprocess and return its Process row."""
    cmd = ['archivebox', 'worker', cls.name]
    cmd += ['--exit-on-idle'] if exit_on_idle else []
    cmd += ['--wait-for-first-event'] if wait_for_first_event else []
    return Process.create_and_fork(cmd=cmd, actor_type=cls.name)
|
||||
# Registry of worker types by name (defined at bottom, referenced here for _run_worker)
|
||||
WORKER_TYPES: dict[str, type['Worker']] = {}
|
||||
|
||||
@classproperty
def processes(cls) -> QuerySet[Process]:
    """All Process rows recorded for this worker type."""
    actor = cls.name
    return Process.objects.filter(actor_type=actor)
|
||||
|
||||
@classmethod
def run(cls, wait_for_first_event=False, exit_on_idle=True):
    """
    Main generator loop for this worker type.

    If wait_for_first_event is True, block until at least one unclaimed
    event matching our listens_to prefix exists.  Then repeatedly process
    the next claimed event (or run an idle tick when the queue is empty),
    yielding every output event produced.  When a pass produces no output:
    exit if exit_on_idle, otherwise sleep poll_interval and keep polling.

    NOTE: the original span was corrupted by a diff that spliced an
    unrelated `_run_worker` helper into this method's body; that foreign
    code has been removed and the loop reconstructed.
    """
    if wait_for_first_event:
        # block until at least one matching event is available to claim
        event = cls.event_queue.get_next_unclaimed()
        while not event:
            time.sleep(cls.poll_interval)
            event = cls.event_queue.get_next_unclaimed()

    while True:
        # process next event, or tick if idle
        output_events = list(cls.process_next_event()) or list(cls.process_idle_tick())
        yield from output_events
        if not output_events:
            if exit_on_idle:
                break
            else:
                time.sleep(cls.poll_interval)
|
||||
|
||||
@classmethod
def process_next_event(cls) -> Iterable[EventDict]:
    """
    Claim the next unclaimed event and run its handler.

    Yields each output event as the handler produces it, then records the
    event as succeeded; any exception records it as failed instead (with
    whatever outputs were produced before the error).
    """
    next_event = cls.event_queue.get_next_unclaimed()
    if not next_event:
        return []

    cls.mark_event_claimed(next_event)
    print(f'{cls.__name__}[{Process.current().pid}] {next_event}', file=sys.stderr)

    collected = []
    try:
        for out in cls.receive(next_event):
            collected.append(out)
            yield out
        # succeeded-marking stays inside the try so a marking failure is
        # recorded as a failed event rather than propagating
        cls.mark_event_succeeded(next_event, output_events=collected)
    except BaseException as err:
        cls.mark_event_failed(next_event, output_events=collected, error=err)
|
||||
class Worker:
|
||||
"""
|
||||
Base worker class that polls a queue and processes items directly.
|
||||
|
||||
@classmethod
|
||||
def process_idle_tick(cls) -> Iterable[EventDict]:
|
||||
# reset the idle event to be claimed by the current process
|
||||
event, _created = Event.objects.update_or_create(
|
||||
name=f'{cls.listens_to}IDLE',
|
||||
emitted_by=Process.current(),
|
||||
defaults={
|
||||
'deliver_at': timezone.now(),
|
||||
'claimed_proc': None,
|
||||
'claimed_at': None,
|
||||
'finished_at': None,
|
||||
'error': None,
|
||||
'parent': None,
|
||||
},
|
||||
Each item is processed by calling its state machine tick() method.
|
||||
Workers exit when idle for too long (unless daemon mode).
|
||||
"""
|
||||
|
||||
name: ClassVar[str] = 'worker'
|
||||
|
||||
# Configuration (can be overridden by subclasses)
|
||||
MAX_TICK_TIME: ClassVar[int] = 60
|
||||
POLL_INTERVAL: ClassVar[float] = 0.5
|
||||
IDLE_TIMEOUT: ClassVar[int] = 3 # Exit after N idle iterations (set to 0 to never exit)
|
||||
|
||||
def __init__(self, worker_id: int = 0, daemon: bool = False, **kwargs: Any):
    """Remember worker identity and reset runtime state; extra kwargs are ignored."""
    self.daemon = daemon              # daemon workers never exit on idle timeout
    self.worker_id = worker_id        # index of this worker among its type
    self.idle_count: int = 0          # consecutive empty polls so far
    self.pid: int = os.getpid()
    self.pid_file: Path | None = None  # set by on_startup()
|
||||
|
||||
def __repr__(self) -> str:
    """Rich-markup repr, e.g. [underline]Worker[/underline]\\[id=0, pid=1234]."""
    cls_name = self.__class__.__name__
    return f'[underline]{cls_name}[/underline]\\[id={self.worker_id}, pid={self.pid}]'
|
||||
|
||||
def get_model(self):
    """Return the Django model class whose rows this worker processes.

    Abstract: subclasses must override (see CrawlWorker / SnapshotWorker /
    ArchiveResultWorker).
    """
    raise NotImplementedError("Subclasses must implement get_model()")
|
||||
|
||||
def get_queue(self) -> QuerySet:
    """Objects due for processing: retry_at has passed and the status is not
    final yet, ordered oldest retry first."""
    Model = self.get_model()
    due = Model.objects.filter(retry_at__lte=timezone.now())
    return due.exclude(status__in=Model.FINAL_STATES).order_by('retry_at')
|
||||
|
||||
def claim_next(self):
    """
    Atomically claim the next object from the queue.

    Uses optimistic locking on retry_at: the UPDATE only matches if retry_at
    is still the value we observed, so exactly one worker wins; the winner
    bumps retry_at by MAX_TICK_TIME so other workers skip this object while
    it is being processed.

    Returns the claimed (refreshed) object, or None if the queue is empty
    or another worker claimed it first.

    FIX: removed a stray `yield from cls.process_next_event()` that a bad
    diff spliced into this method — it referenced an undefined `cls` and
    would have turned this method into a generator, breaking every caller
    that expects an object-or-None return.
    """
    Model = self.get_model()
    obj = self.get_queue().first()
    if obj is None:
        return None

    # Atomic claim using optimistic locking on retry_at
    claimed = Model.objects.filter(
        pk=obj.pk,
        retry_at=obj.retry_at,
    ).update(
        retry_at=timezone.now() + timedelta(seconds=self.MAX_TICK_TIME)
    )

    if claimed == 1:
        obj.refresh_from_db()
        return obj

    return None  # Someone else claimed it
|
||||
|
||||
def process_item(self, obj) -> bool:
    """
    Run one queue item through its state machine tick().

    Returns True if tick() completed without raising, False otherwise
    (the error and a traceback are printed).  Subclasses may override
    for custom processing.
    """
    try:
        obj.sm.tick()
    except Exception as err:
        print(f'[red]{self} error processing {obj.pk}:[/red] {err}')
        traceback.print_exc()
        return False
    return True
|
||||
|
||||
def on_startup(self) -> None:
    """Called when worker starts: record our real PID, write the PID file,
    and announce startup."""
    self.pid = os.getpid()
    # the PID file lets other processes discover and count live workers
    self.pid_file = write_pid_file(self.name, self.worker_id)
    print(f'[green]{self} STARTED[/green] pid_file={self.pid_file}')
|
||||
|
||||
def on_shutdown(self, error: BaseException | None = None) -> None:
|
||||
"""Called when worker shuts down."""
|
||||
# Remove PID file
|
||||
if self.pid_file:
|
||||
remove_pid_file(self.pid_file)
|
||||
|
||||
if error and not isinstance(error, KeyboardInterrupt):
|
||||
print(f'[red]{self} SHUTDOWN with error:[/red] {type(error).__name__}: {error}')
|
||||
else:
|
||||
print(f'[grey53]{self} SHUTDOWN[/grey53]')
|
||||
|
||||
def should_exit(self) -> bool:
    """True once we have been idle for IDLE_TIMEOUT polls in a row.

    Daemons never exit, and IDLE_TIMEOUT == 0 disables the timeout.
    """
    if self.daemon or self.IDLE_TIMEOUT == 0:
        return False
    return self.idle_count >= self.IDLE_TIMEOUT
|
||||
|
||||
def runloop(self) -> None:
    """
    Main worker loop: poll the queue, claim items, tick their state machines.

    Exits when should_exit() reports the idle timeout was reached (daemons
    never time out).  on_shutdown() is always called on the way out —
    including on KeyboardInterrupt, which previously fell through
    `except KeyboardInterrupt: pass` and skipped cleanup, leaking the
    PID file.
    """
    self.on_startup()

    try:
        while True:
            # Try to claim and process an item
            obj = self.claim_next()

            if obj is not None:
                self.idle_count = 0
                print(f'[blue]{self} processing:[/blue] {obj.pk}')

                start_time = time.time()
                success = self.process_item(obj)
                elapsed = time.time() - start_time

                if success:
                    print(f'[green]{self} completed ({elapsed:.1f}s):[/green] {obj.pk}')
                else:
                    print(f'[red]{self} failed ({elapsed:.1f}s):[/red] {obj.pk}')
            else:
                # No work available
                self.idle_count += 1
                if self.idle_count == 1:
                    print(f'[grey53]{self} idle, waiting for work...[/grey53]')

                # Check if we should exit
                if self.should_exit():
                    print(f'[grey53]{self} idle timeout reached, exiting[/grey53]')
                    break

                time.sleep(self.POLL_INTERVAL)

    except KeyboardInterrupt:
        # FIX: previously `pass`, which skipped the `else` clause and left
        # the PID file behind on Ctrl-C; treat it as a clean shutdown.
        self.on_shutdown()
    except BaseException as e:
        self.on_shutdown(error=e)
        raise
    else:
        self.on_shutdown()
|
||||
|
||||
@classmethod
def receive(cls, event: Event) -> Iterable[EventDict]:
    """Dispatch an event to this class's matching on_<EVENT_NAME> handler."""
    handler = getattr(cls, f'on_{event.name}', None)
    if handler is None:
        raise Exception(f'No handler method for event: {event.name}')
    yield from handler(event)
|
||||
def start(cls, worker_id: int | None = None, daemon: bool = False, **kwargs: Any) -> int:
|
||||
"""
|
||||
Fork a new worker as a subprocess.
|
||||
Returns the PID of the new process.
|
||||
"""
|
||||
if worker_id is None:
|
||||
worker_id = get_next_worker_id(cls.name)
|
||||
|
||||
@staticmethod
def on_IDLE() -> Iterable[EventDict]:
    """Default idle handler: nothing to do, emit no events."""
    return []
|
||||
|
||||
@staticmethod
def mark_event_claimed(event: Event):
    """
    Atomically claim an event for the current process.

    claimed_proc/claimed_at act as an optimistic lock: the UPDATE only
    matches while both are still NULL, so exactly one process wins.
    Raises if another process already claimed the event.
    """
    proc = Process.current()

    with transaction.atomic():
        claimed = Event.objects.filter(id=event.id, claimed_proc=None, claimed_at=None).update(claimed_proc=proc, claimed_at=timezone.now())
        event.refresh_from_db()
        if not claimed:
            raise Exception(f'Event already claimed by another process: {event.claimed_proc}')

    # FIX: this is a @staticmethod, so the original f'{self}...' raised
    # NameError at runtime; log the claiming process instead.
    print(f'{proc}.mark_event_claimed(): Claimed {event} ⛏️')

    # process_updated = Process.objects.filter(id=proc.id, active_event=None).update(active_event=event)
    # if not process_updated:
    #     raise Exception(f'Unable to update process.active_event: {proc}.active_event = {event}')
|
||||
# Use module-level function for pickling compatibility
|
||||
proc = Process(
|
||||
target=_run_worker,
|
||||
args=(cls.name, worker_id, daemon),
|
||||
kwargs=kwargs,
|
||||
name=f'{cls.name}_worker_{worker_id}',
|
||||
)
|
||||
proc.start()
|
||||
|
||||
@staticmethod
|
||||
def mark_event_succeeded(event: Event, output_events: Iterable[EventDict]):
|
||||
event.refresh_from_db()
|
||||
assert event.claimed_proc, f'Cannot mark event as succeeded if it is not claimed by a process: {event}'
|
||||
assert (event.claimed_proc == Process.current()), f'Cannot mark event as succeeded if it claimed by a different process: {event}.claimed_proc = {event.claimed_proc}, current_process = {Process.current()}'
|
||||
|
||||
with transaction.atomic():
|
||||
updated = Event.objects.filter(id=event.id, claimed_proc=event.claimed_proc, claimed_at=event.claimed_at, finished_at=None).update(finished_at=timezone.now())
|
||||
event.refresh_from_db()
|
||||
if not updated:
|
||||
raise Exception(f'Event {event} failed to mark as succeeded, it was modified by another process: {event.claimed_proc}')
|
||||
assert proc.pid is not None
|
||||
return proc.pid
|
||||
|
||||
# process_updated = Process.objects.filter(id=event.claimed_proc.id, active_event=event).update(active_event=None)
|
||||
# if not process_updated:
|
||||
# raise Exception(f'Unable to unset process.active_event: {event.claimed_proc}.active_event = {event}')
|
||||
@classmethod
def get_running_workers(cls) -> list[dict]:
    """Return PID info for all live workers of this type, pruning stale PID files first."""
    cleanup_stale_pid_files()
    workers = get_all_worker_pids(cls.name)
    return workers
|
||||
|
||||
# dispatch any output events
|
||||
for output_event in output_events:
|
||||
Event.dispatch(event=output_event, parent=event)
|
||||
|
||||
# trigger any callback events
|
||||
if event.on_success:
|
||||
Event.dispatch(event=event.on_success, parent=event)
|
||||
|
||||
@staticmethod
def mark_event_failed(event: Event, output_events: Iterable[EventDict]=(), error: BaseException | None = None):
    """Record an event as failed, emit a companion *_ERROR event, dispatch
    any outputs produced before the failure, and fire the on_failure callback.

    Only the process that claimed the event may mark it failed; the
    finished_at=None filter in the UPDATE acts as an optimistic lock
    against concurrent modification.
    """
    event.refresh_from_db()
    assert event.claimed_proc, f'Cannot mark event as failed if it is not claimed by a process: {event}'
    assert (event.claimed_proc == Process.current()), f'Cannot mark event as failed if it claimed by a different process: {event}.claimed_proc = {event.claimed_proc}, current_process = {Process.current()}'

    with transaction.atomic():
        # optimistic lock: only succeeds if nobody else finished the event first
        updated = Event.objects.filter(id=event.id, claimed_proc=event.claimed_proc, claimed_at=event.claimed_at, finished_at=None).update(finished_at=timezone.now(), error=str(error))
        event.refresh_from_db()
        if not updated:
            raise Exception(f'Event {event} failed to mark as failed, it was modified by another process: {event.claimed_proc}')

    # process_updated = Process.objects.filter(id=event.claimed_proc.id, active_event=event).update(active_event=None)
    # if not process_updated:
    #     raise Exception(f'Unable to unset process.active_event: {event.claimed_proc}.active_event = {event}')


    # add dedicated error event to the output events
    # (skipped when the failing event is itself an _ERROR event, to avoid loops)
    if not event.name.endswith('_ERROR'):
        output_events = [
            *output_events,
            {'name': f'{event.name}_ERROR', 'msg': f'{type(error).__name__}: {error}'},
        ]

    # dispatch any output events
    for output_event in output_events:
        Event.dispatch(event=output_event, parent=event)

    # trigger any callback events
    if event.on_failure:
        Event.dispatch(event=event.on_failure, parent=event)
|
||||
@classmethod
def get_worker_count(cls) -> int:
    """Get count of running workers of this type."""
    # delegates to get_running_workers(), which also prunes stale PID files
    return len(cls.get_running_workers())
|
||||
|
||||
|
||||
class CrawlWorker(Worker):
    """Queue-polling worker that ticks Crawl objects through their state machine."""

    name: ClassVar[str] = 'crawl'
    MAX_TICK_TIME: ClassVar[int] = 60  # seconds a claim is held before retry

    def get_model(self):
        """Return the Crawl model (imported lazily to avoid circular imports)."""
        from crawls.models import Crawl
        return Crawl
|
||||
|
||||
|
||||
class OrchestratorWorker(WorkerType):
    """Event-driven worker for PROC_* events: launches and kills worker Processes."""

    name = 'orchestrator'
    listens_to = 'PROC_'
    outputs = ['PROC_']

    @staticmethod
    def on_PROC_IDLE() -> Iterable[EventDict]:
        """When idle, request launch of the oldest not-yet-launched Process, if any."""
        pending = Process.objects.filter(launched_at=None).order_by('created_at').first()
        if pending is None:
            return []

        yield {'name': 'PROC_LAUNCH', 'id': pending.id}

    @staticmethod
    def on_PROC_LAUNCH(event: Event) -> Iterable[EventDict]:
        """Fork a new subprocess described by the event's kwargs."""
        child = Process.create_and_fork(**event.kwargs)
        yield {'name': 'PROC_LAUNCHED', 'process_id': child.id}

    @staticmethod
    def on_PROC_EXIT(event: Event) -> Iterable[EventDict]:
        # NOTE(review): identical to on_PROC_KILL — presumably EXIT was meant to
        # request a graceful shutdown rather than kill(); confirm intent.
        target = Process.objects.get(id=event.process_id)
        target.kill()
        yield {'name': 'PROC_KILLED', 'process_id': target.id}

    @staticmethod
    def on_PROC_KILL(event: Event) -> Iterable[EventDict]:
        """Forcibly kill the given Process."""
        target = Process.objects.get(id=event.process_id)
        target.kill()
        yield {'name': 'PROC_KILLED', 'process_id': target.id}
|
||||
class SnapshotWorker(Worker):
    """Queue-polling worker that ticks Snapshot objects through their state machine."""

    name: ClassVar[str] = 'snapshot'
    MAX_TICK_TIME: ClassVar[int] = 60  # seconds a claim is held before retry

    def get_model(self):
        """Return the Snapshot model (imported lazily to avoid circular imports)."""
        from core.models import Snapshot
        return Snapshot
|
||||
|
||||
|
||||
class FileSystemWorker(WorkerType):
    """Event-driven worker for FS_* events: performs all filesystem side effects.

    NOTE: this span of the file was corrupted by a diff that interleaved
    fragments of an unrelated ArchiveResultWorker class; the FS_* handlers
    below are the reconstructed FileSystemWorker.
    """

    name = 'filesystem'
    listens_to = 'FS_'
    outputs = ['FS_']

    @staticmethod
    def on_FS_IDLE(event: Event) -> Iterable[EventDict]:
        """When idle, queue deletion of leftover tmp files."""
        # check for tmp files that can be deleted
        for tmp_file in Path('/tmp').glob('archivebox/*'):
            yield {'name': 'FS_DELETE', 'path': str(tmp_file)}

    @staticmethod
    def on_FS_WRITE(event: Event) -> Iterable[EventDict]:
        """Overwrite event.path with event.content."""
        with open(event.path, 'w') as f:
            f.write(event.content)
        yield {'name': 'FS_CHANGED', 'path': event.path}

    @staticmethod
    def on_FS_APPEND(event: Event) -> Iterable[EventDict]:
        """Append event.content to event.path."""
        with open(event.path, 'a') as f:
            f.write(event.content)
        yield {'name': 'FS_CHANGED', 'path': event.path}

    @staticmethod
    def on_FS_DELETE(event: Event) -> Iterable[EventDict]:
        """Delete the file at event.path."""
        os.remove(event.path)
        yield {'name': 'FS_CHANGED', 'path': event.path}

    @staticmethod
    def on_FS_RSYNC(event: Event) -> Iterable[EventDict]:
        """Mirror event.src to event.dst with rsync."""
        # NOTE(review): paths are interpolated into a shell string; switch to
        # subprocess.run(['rsync', '-av', src, dst]) if src/dst can ever
        # contain spaces or untrusted input.
        os.system(f'rsync -av {event.src} {event.dst}')
        yield {'name': 'FS_CHANGED', 'path': event.dst}
|
||||
def __init__(self, extractor: str | None = None, **kwargs: Any):
    """Optionally restrict this worker to a single extractor's ArchiveResults."""
    self.extractor = extractor  # None means: process results for all extractors
    super().__init__(**kwargs)
|
||||
|
||||
def get_model(self):
    """Return the ArchiveResult model (imported lazily to avoid circular imports)."""
    from core.models import ArchiveResult
    return ArchiveResult
|
||||
|
||||
def get_queue(self) -> QuerySet:
    """ArchiveResults ready for processing.

    Applies the base due-and-not-final filter, optionally narrows to this
    worker's extractor, and skips results whose Snapshot already has a
    STARTED result (one extractor at a time per snapshot).
    """
    from django.db.models import Exists, OuterRef
    from core.models import ArchiveResult

    queue = super().get_queue()

    if self.extractor:
        queue = queue.filter(extractor=self.extractor)

    # Exclude ArchiveResults whose Snapshot already has one in progress
    started_sibling = ArchiveResult.objects.filter(
        snapshot_id=OuterRef('snapshot_id'),
        status=ArchiveResult.StatusChoices.STARTED,
    )
    return queue.exclude(Exists(started_sibling))
|
||||
|
||||
def process_item(self, obj) -> bool:
    """Process an ArchiveResult by running its extractor via the state machine.

    The original body was a byte-for-byte copy of Worker.process_item;
    delegate to the base implementation instead of duplicating it.  Kept as
    an explicit override so extractor-specific behavior has an obvious
    place to grow.
    """
    return super().process_item(obj)
|
||||
|
||||
@classmethod
def start(cls, worker_id: int | None = None, daemon: bool = False, extractor: str | None = None, **kwargs: Any) -> int:
    """Fork a new worker subprocess (optionally pinned to one extractor); return its PID."""
    worker_id = get_next_worker_id(cls.name) if worker_id is None else worker_id

    # _run_worker is a module-level function so multiprocessing can pickle it
    proc = Process(
        target=_run_worker,
        args=(cls.name, worker_id, daemon),
        kwargs={'extractor': extractor, **kwargs},
        name=f'{cls.name}_worker_{worker_id}',
    )
    proc.start()

    assert proc.pid is not None
    return proc.pid
|
||||
|
||||
|
||||
class CrawlWorker(WorkerType):
    """Event-driven worker for CRAWL_* events: creates, starts, updates, and seals Crawls."""

    name = 'crawl'
    listens_to = 'CRAWL_'
    outputs = ['CRAWL_', 'FS_', 'SNAPSHOT_']

    @staticmethod
    def on_CRAWL_IDLE(event: Event) -> Iterable[EventDict]:
        """When idle, nudge the oldest overdue crawl forward (start or seal it)."""
        overdue = Crawl.objects.filter(retry_at__lt=timezone.now()).first()
        if not overdue:
            return []

        if overdue.can_start():
            yield {'name': 'CRAWL_START', 'id': overdue.id}
        elif overdue.can_seal():
            yield {'name': 'CRAWL_SEAL', 'id': overdue.id}

    @staticmethod
    def on_CRAWL_CREATE(event: Event) -> Iterable[EventDict]:
        """Create a Crawl from the event payload (idempotent on id)."""
        crawl, was_created = Crawl.objects.get_or_create(id=event.id, defaults=event)
        if was_created:
            yield {'name': 'CRAWL_UPDATED', 'crawl_id': crawl.id}

    @staticmethod
    def on_CRAWL_UPDATE(event: Event) -> Iterable[EventDict]:
        """Apply only the fields that actually changed to the Crawl."""
        crawl = Crawl.objects.get(id=event.pop('crawl_id'))
        changes = {
            field: value
            for field, value in event.items()
            if getattr(crawl, field) != value
        }
        if changes:
            crawl.update(**changes)
            yield {'name': 'CRAWL_UPDATED', 'crawl_id': crawl.id}

    @staticmethod
    def on_CRAWL_UPDATED(event: Event) -> Iterable[EventDict]:
        """Refresh the crawl's output dir symlinks after any update."""
        crawl = Crawl.objects.get(id=event.crawl_id)
        yield {'name': 'FS_WRITE_SYMLINKS', 'path': crawl.OUTPUT_DIR, 'symlinks': crawl.output_dir_symlinks}

    @staticmethod
    def on_CRAWL_SEAL(event: Event) -> Iterable[EventDict]:
        """Seal a STARTED crawl and write its final index.json."""
        crawl = Crawl.objects.filter(id=event.id, status=Crawl.StatusChoices.STARTED).first()
        if crawl is None:
            return
        crawl.status = Crawl.StatusChoices.SEALED
        crawl.save()
        yield {'name': 'FS_WRITE', 'path': crawl.OUTPUT_DIR / 'index.json', 'content': json.dumps(crawl.as_json(), default=str, indent=4, sort_keys=True)}
        yield {'name': 'CRAWL_UPDATED', 'crawl_id': crawl.id}

    @staticmethod
    def on_CRAWL_START(event: Event) -> Iterable[EventDict]:
        """Kick off a crawl by creating and starting its root snapshot."""
        # NOTE(review): on_CRAWL_IDLE emits this event with key 'id', but this
        # handler reads event.crawl_id — confirm which key is canonical.
        crawl = Crawl.objects.get(id=event.crawl_id)
        root_snapshot_id = uuid.uuid4()
        yield {'name': 'SNAPSHOT_CREATE', 'snapshot_id': root_snapshot_id, 'crawl_id': crawl.id, 'url': crawl.seed.uri}
        yield {'name': 'SNAPSHOT_START', 'snapshot_id': root_snapshot_id}
        yield {'name': 'CRAWL_UPDATE', 'crawl_id': crawl.id, 'status': 'started', 'retry_at': None}
|
||||
# Populate the registry mapping worker name -> queue-polling Worker subclass
WORKER_TYPES['crawl'] = CrawlWorker
WORKER_TYPES['snapshot'] = SnapshotWorker
WORKER_TYPES['archiveresult'] = ArchiveResultWorker
|
||||
|
||||
|
||||
class SnapshotWorker(WorkerType):
    """Event-driven worker for SNAPSHOT_* events: creates, starts, and seals Snapshots."""

    name = 'snapshot'
    listens_to = 'SNAPSHOT_'
    outputs = ['SNAPSHOT_', 'FS_']

    @staticmethod
    def on_SNAPSHOT_IDLE(event: Event) -> Iterable[EventDict]:
        """When idle, nudge the first unsealed snapshot forward (start or seal it)."""
        pending = Snapshot.objects.exclude(status=Snapshot.StatusChoices.SEALED).first()
        if not pending:
            return []

        if pending.can_start():
            yield {'name': 'SNAPSHOT_START', 'id': pending.id}
        elif pending.can_seal():
            yield {'name': 'SNAPSHOT_SEAL', 'id': pending.id}

    @staticmethod
    def on_SNAPSHOT_CREATE(event: Event) -> Iterable[EventDict]:
        """Create the Snapshot row and write its initial index.json."""
        snapshot = Snapshot.objects.create(id=event.snapshot_id, **event.kwargs)
        yield {'name': 'FS_WRITE', 'path': snapshot.OUTPUT_DIR / 'index.json', 'content': json.dumps(snapshot.as_json(), default=str, indent=4, sort_keys=True)}
        yield {'name': 'SNAPSHOT_UPDATED', 'id': snapshot.id}

    @staticmethod
    def on_SNAPSHOT_SEAL(event: Event) -> Iterable[EventDict]:
        """Mark a STARTED snapshot SEALED and persist its index.json."""
        snapshot = Snapshot.objects.get(id=event.snapshot_id, status=Snapshot.StatusChoices.STARTED)
        assert snapshot.can_seal()
        snapshot.status = Snapshot.StatusChoices.SEALED
        snapshot.save()
        yield {'name': 'FS_WRITE', 'path': snapshot.OUTPUT_DIR / 'index.json', 'content': json.dumps(snapshot.as_json(), default=str, indent=4, sort_keys=True)}
        yield {'name': 'SNAPSHOT_UPDATED', 'id': snapshot.id}

    @staticmethod
    def on_SNAPSHOT_START(event: Event) -> Iterable[EventDict]:
        """Start a QUEUED snapshot: fan out one ArchiveResult per extractor, then mark STARTED."""
        snapshot = Snapshot.objects.get(id=event.snapshot_id, status=Snapshot.StatusChoices.QUEUED)
        assert snapshot.can_start()

        # create pending archiveresults for each extractor
        for extractor in snapshot.get_extractors():
            pending_result_id = uuid.uuid4()
            yield {'name': 'ARCHIVERESULT_CREATE', 'id': pending_result_id, 'snapshot_id': snapshot.id, 'extractor': extractor.name}
            yield {'name': 'ARCHIVERESULT_START', 'id': pending_result_id}

        snapshot.status = Snapshot.StatusChoices.STARTED
        snapshot.save()
        yield {'name': 'FS_WRITE', 'path': snapshot.OUTPUT_DIR / 'index.json', 'content': json.dumps(snapshot.as_json(), default=str, indent=4, sort_keys=True)}
        yield {'name': 'SNAPSHOT_UPDATED', 'id': snapshot.id}
|
||||
|
||||
|
||||
|
||||
class ArchiveResultWorker(WorkerType):
    """Event-driven worker for ARCHIVERESULT_* events: creates, starts, updates, and seals results."""

    name = 'archiveresult'
    listens_to = 'ARCHIVERESULT_'
    outputs = ['ARCHIVERESULT_', 'FS_']

    @staticmethod
    def on_ARCHIVERESULT_UPDATE(event: Event) -> Iterable[EventDict]:
        """Apply only the fields that actually changed to the ArchiveResult."""
        result = ArchiveResult.objects.get(id=event.id)
        changes = {
            field: value
            for field, value in event.items()
            if getattr(result, field) != value
        }
        if changes:
            result.update(**changes)
            yield {'name': 'ARCHIVERESULT_UPDATED', 'id': result.id}

    @staticmethod
    def on_ARCHIVERESULT_UPDATED(event: Event) -> Iterable[EventDict]:
        """Refresh the result's output dir symlinks after any update."""
        result = ArchiveResult.objects.get(id=event.id)
        yield {'name': 'FS_WRITE_SYMLINKS', 'path': result.OUTPUT_DIR, 'symlinks': result.output_dir_symlinks}

    @staticmethod
    def on_ARCHIVERESULT_CREATE(event: Event) -> Iterable[EventDict]:
        """Idempotently create an ArchiveResult; error if one exists with different values."""
        # NOTE(review): pops 'archiveresult_id', but SNAPSHOT_START emits these
        # events with key 'id' — confirm which key is canonical.
        result, was_created = ArchiveResult.objects.get_or_create(id=event.pop('archiveresult_id'), defaults=event)
        if was_created:
            yield {'name': 'ARCHIVERESULT_UPDATE', 'id': result.id}
        else:
            conflict = {
                field: value
                for field, value in event.items()
                if getattr(result, field) != value
            }
            assert not conflict, f'ArchiveResult {result.id} already exists and has different values, cannot create on top of it: {conflict}'

    @staticmethod
    def on_ARCHIVERESULT_SEAL(event: Event) -> Iterable[EventDict]:
        """Seal a STARTED result via an UPDATE event."""
        result = ArchiveResult.objects.get(id=event.id, status=ArchiveResult.StatusChoices.STARTED)
        assert result.can_seal()
        yield {'name': 'ARCHIVERESULT_UPDATE', 'id': result.id, 'status': 'sealed'}

    @staticmethod
    def on_ARCHIVERESULT_START(event: Event) -> Iterable[EventDict]:
        """Start a QUEUED result: emit the extractor SHELL_EXEC, then mark it STARTED."""
        result = ArchiveResult.objects.get(id=event.id, status=ArchiveResult.StatusChoices.QUEUED)

        yield {
            'name': 'SHELL_EXEC',
            'cmd': result.EXTRACTOR.get_cmd(),
            'cwd': result.OUTPUT_DIR,
            'on_exit': {
                'name': 'ARCHIVERESULT_SEAL',
                'id': result.id,
            },
        }

        result.status = ArchiveResult.StatusChoices.STARTED
        result.save()
        yield {'name': 'FS_WRITE', 'path': result.OUTPUT_DIR / 'index.json', 'content': json.dumps(result.as_json(), default=str, indent=4, sort_keys=True)}
        yield {'name': 'ARCHIVERESULT_UPDATED', 'id': result.id}

    @staticmethod
    def on_ARCHIVERESULT_IDLE(event: Event) -> Iterable[EventDict]:
        """When idle, nudge the first unfinished result forward (start and/or seal it)."""
        unfinished = ArchiveResult.objects.exclude(status__in=[ArchiveResult.StatusChoices.SUCCEEDED, ArchiveResult.StatusChoices.FAILED]).first()
        if not unfinished:
            return []
        if unfinished.can_start():
            yield {'name': 'ARCHIVERESULT_START', 'id': unfinished.id}
        if unfinished.can_seal():
            yield {'name': 'ARCHIVERESULT_SEAL', 'id': unfinished.id}
|
||||
|
||||
|
||||
# Ordered registry of all event-driven worker types; get_worker_type() scans
# this list in order, so earlier entries win on ambiguous names.
# NOTE(review): this rebinds WORKER_TYPES (declared earlier as a dict used by
# get_worker_class) to a list — confirm which representation callers expect.
WORKER_TYPES = [
    OrchestratorWorker,
    FileSystemWorker,
    CrawlWorker,
    SnapshotWorker,
    ArchiveResultWorker,
]
|
||||
|
||||
def get_worker_type(name: str) -> Type[WorkerType]:
    """Look up a WorkerType by verbose name, class name, or listens_to prefix
    (the latter two case-insensitively).  Raises if nothing matches."""
    needle = name.strip('_').lower()
    for candidate in WORKER_TYPES:
        if candidate.name == name:
            return candidate
        if candidate.__name__.lower() == name.lower():
            return candidate
        if candidate.listens_to.strip('_').lower() == needle:
            return candidate
    raise Exception(f'Worker type not found: {name}')
|
||||
def get_worker_class(name: str) -> type[Worker]:
    """Get worker class by name, raising ValueError for unknown names."""
    registry = WORKER_TYPES
    if name in registry:
        return registry[name]
    raise ValueError(f'Unknown worker type: {name}. Valid types: {list(WORKER_TYPES.keys())}')
|
||||
|
||||
Reference in New Issue
Block a user