wip major changes

This commit is contained in:
Nick Sweeting
2025-12-24 20:09:51 -08:00
parent c1335fed37
commit 1915333b81
450 changed files with 35814 additions and 19015 deletions

View File

@@ -1,9 +1,7 @@
__package__ = 'archivebox.workers'
__order__ = 100
import abx
@abx.hookimpl
def register_admin(admin_site):
    """Hook: delegate admin registration for the workers app to workers.admin."""
    # imported lazily so the admin module is only loaded when an admin site exists
    from workers.admin import register_admin as _do_register
    _do_register(admin_site)

View File

@@ -1,166 +0,0 @@
# __package__ = 'archivebox.workers'
# import time
# from typing import ClassVar, Type, Iterable, TypedDict
# from django.db.models import QuerySet
# from django.db import transaction
# from django.utils import timezone
# from django.utils.functional import classproperty # type: ignore
# from .models import Event, Process, EventDict
# class ActorType:
# # static class attributes
# name: ClassVar[str]
# event_prefix: ClassVar[str]
# poll_interval: ClassVar[int] = 1
# @classproperty
# def event_queue(cls) -> QuerySet[Event]:
# return Event.objects.filter(type__startswith=cls.event_prefix)
# @classmethod
# def fork(cls, wait_for_first_event=False, exit_on_idle=True) -> Process:
# cmd = ['archivebox', 'actor', cls.name]
# if exit_on_idle:
# cmd.append('--exit-on-idle')
# if wait_for_first_event:
# cmd.append('--wait-for-first-event')
# return Process.create_and_fork(cmd=cmd, actor_type=cls.name)
# @classproperty
# def processes(cls) -> QuerySet[Process]:
# return Process.objects.filter(actor_type=cls.name)
# @classmethod
# def run(cls, wait_for_first_event=False, exit_on_idle=True):
# if wait_for_first_event:
# event = cls.event_queue.get_next_unclaimed()
# while not event:
# time.sleep(cls.poll_interval)
# event = cls.event_queue.get_next_unclaimed()
# while True:
# output_events = list(cls.process_next_event()) or list(cls.process_idle_tick()) # process next event, or tick if idle
# yield from output_events
# if not output_events:
# if exit_on_idle:
# break
# else:
# time.sleep(cls.poll_interval)
# @classmethod
# def process_next_event(cls) -> Iterable[EventDict]:
# event = cls.event_queue.get_next_unclaimed()
# output_events = []
# if not event:
# return []
# cls.mark_event_claimed(event, duration=60)
# try:
# for output_event in cls.receive(event):
# output_events.append(output_event)
# yield output_event
# cls.mark_event_succeeded(event, output_events=output_events)
# except BaseException as e:
# cls.mark_event_failed(event, output_events=output_events, error=e)
# @classmethod
# def process_idle_tick(cls) -> Iterable[EventDict]:
# # reset the idle event to be claimed by the current process
# event, _created = Event.objects.update_or_create(
# name=f'{cls.event_prefix}IDLE',
# emitted_by=Process.current(),
# defaults={
# 'deliver_at': timezone.now(),
# 'claimed_proc': None,
# 'claimed_at': None,
# 'finished_at': None,
# 'error': None,
# 'parent': None,
# },
# )
# # then process it like any other event
# yield from cls.process_next_event()
# @classmethod
# def receive(cls, event: Event) -> Iterable[EventDict]:
# handler_method = getattr(cls, f'on_{event.name}', None)
# if handler_method:
# yield from handler_method(event)
# else:
# raise Exception(f'No handler method for event: {event.name}')
# @staticmethod
# def on_IDLE() -> Iterable[EventDict]:
# return []
# @staticmethod
# def mark_event_claimed(event: Event, duration: int=60):
# proc = Process.current()
# with transaction.atomic():
# claimed = Event.objects.filter(id=event.id, claimed_proc=None, claimed_at=None).update(claimed_proc=proc, claimed_at=timezone.now())
# if not claimed:
# event.refresh_from_db()
# raise Exception(f'Event already claimed by another process: {event.claimed_proc}')
# process_updated = Process.objects.filter(id=proc.id, active_event=None).update(active_event=event)
# if not process_updated:
# raise Exception(f'Unable to update process.active_event: {proc}.active_event = {event}')
# @staticmethod
# def mark_event_succeeded(event: Event, output_events: Iterable[EventDict]):
# assert event.claimed_proc and (event.claimed_proc == Process.current())
# with transaction.atomic():
# updated = Event.objects.filter(id=event.id, claimed_proc=event.claimed_proc, claimed_at=event.claimed_at, finished_at=None).update(finished_at=timezone.now())
# if not updated:
# event.refresh_from_db()
# raise Exception(f'Event {event} failed to mark as succeeded, it was modified by another process: {event.claimed_proc}')
# process_updated = Process.objects.filter(id=event.claimed_proc.id, active_event=event).update(active_event=None)
# if not process_updated:
# raise Exception(f'Unable to unset process.active_event: {event.claimed_proc}.active_event = {event}')
# # dispatch any output events
# for output_event in output_events:
# Event.dispatch(event=output_event, parent=event)
# # trigger any callback events
# if event.on_success:
# Event.dispatch(event=event.on_success, parent=event)
# @staticmethod
# def mark_event_failed(event: Event, output_events: Iterable[EventDict]=(), error: BaseException | None = None):
# assert event.claimed_proc and (event.claimed_proc == Process.current())
# with transaction.atomic():
# updated = event.objects.filter(id=event.id, claimed_proc=event.claimed_proc, claimed_at=event.claimed_at, finished_at=None).update(finished_at=timezone.now(), error=str(error))
# if not updated:
# event.refresh_from_db()
# raise Exception(f'Event {event} failed to mark as failed, it was modified by another process: {event.claimed_proc}')
# process_updated = Process.objects.filter(id=event.claimed_proc.id, active_event=event).update(active_event=None)
# if not process_updated:
# raise Exception(f'Unable to unset process.active_event: {event.claimed_proc}.active_event = {event}')
# # add dedicated error event to the output events
# output_events = [
# *output_events,
# {'name': f'{event.name}_ERROR', 'error': f'{type(error).__name__}: {error}'},
# ]
# # dispatch any output events
# for output_event in output_events:
# Event.dispatch(event=output_event, parent=event)
# # trigger any callback events
# if event.on_failure:
# Event.dispatch(event=event.on_failure, parent=event)

View File

@@ -1,7 +1,5 @@
__package__ = 'archivebox.workers'
import abx
from django.contrib.auth import get_permission_codename
from huey_monitor.apps import HueyMonitorConfig
@@ -20,7 +18,6 @@ class CustomTaskModelAdmin(TaskModelAdmin):
@abx.hookimpl
def register_admin(admin_site):
    """Hook: register the huey-monitor task/signal models on the given admin site."""
    registrations = (
        (TaskModel, CustomTaskModelAdmin),
        (SignalInfoModel, SignalInfoModelAdmin),
    )
    for model_cls, admin_cls in registrations:
        admin_site.register(model_cls, admin_cls)

View File

@@ -1,18 +0,0 @@
from django.core.management.base import BaseCommand
from workers.orchestrator import ArchivingOrchestrator
class Command(BaseCommand):
    """Django management command that runs the ArchiveBox orchestrator."""

    help = 'Run the archivebox orchestrator'

    def handle(self, *args, **kwargs):
        """Entry point: construct the orchestrator and start its run loop."""
        ArchivingOrchestrator().start()

View File

@@ -1,20 +1,14 @@
__package__ = 'archivebox.workers'
import uuid
import json
from typing import ClassVar, Type, Iterable, TypedDict
from typing import ClassVar, Type, Iterable
from datetime import datetime, timedelta
from statemachine.mixins import MachineMixin
from django.db import models
from django.db.models import QuerySet
from django.core import checks
from django.utils import timezone
from django.utils.functional import classproperty
from machine.models import Process
from statemachine import registry, StateMachine, State
@@ -33,31 +27,31 @@ ObjectStateList = Iterable[ObjectState]
class BaseModelWithStateMachine(models.Model, MachineMixin):
id: models.UUIDField
StatusChoices: ClassVar[Type[models.TextChoices]]
# status: models.CharField
# retry_at: models.DateTimeField
state_machine_name: ClassVar[str]
state_field_name: ClassVar[str]
state_machine_attr: ClassVar[str] = 'sm'
bind_events_as_methods: ClassVar[bool] = True
active_state: ClassVar[ObjectState]
retry_at_field_name: ClassVar[str]
class Meta:
abstract = True
@classmethod
def check(cls, sender=None, **kwargs):
errors = super().check(**kwargs)
found_id_field = False
found_status_field = False
found_retry_at_field = False
for field in cls._meta.get_fields():
if getattr(field, '_is_state_field', False):
if cls.state_field_name == field.name:
@@ -74,7 +68,7 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
found_retry_at_field = True
if field.name == 'id' and getattr(field, 'primary_key', False):
found_id_field = True
if not found_status_field:
errors.append(checks.Error(
f'{cls.__name__}.state_field_name must be defined and point to a StatusField()',
@@ -89,7 +83,7 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
obj=cls,
id='workers.E013',
))
if not found_id_field:
errors.append(checks.Error(
f'{cls.__name__} must have an id field that is a primary key',
@@ -97,7 +91,7 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
obj=cls,
id='workers.E014',
))
if not isinstance(cls.state_machine_name, str):
errors.append(checks.Error(
f'{cls.__name__}.state_machine_name must be a dotted-import path to a StateMachine class',
@@ -105,7 +99,7 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
obj=cls,
id='workers.E015',
))
try:
cls.StateMachineClass
except Exception as err:
@@ -115,7 +109,7 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
obj=cls,
id='workers.E016',
))
if cls.INITIAL_STATE not in cls.StatusChoices.values:
errors.append(checks.Error(
f'{cls.__name__}.StateMachineClass.initial_state must be present within {cls.__name__}.StatusChoices',
@@ -123,7 +117,7 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
obj=cls,
id='workers.E017',
))
if cls.ACTIVE_STATE not in cls.StatusChoices.values:
errors.append(checks.Error(
f'{cls.__name__}.active_state must be set to a valid State present within {cls.__name__}.StatusChoices',
@@ -131,8 +125,8 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
obj=cls,
id='workers.E018',
))
for state in cls.FINAL_STATES:
if state not in cls.StatusChoices.values:
errors.append(checks.Error(
@@ -143,55 +137,106 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
))
break
return errors
@staticmethod
def _state_to_str(state: ObjectState) -> str:
    """Normalize a statemachine.State, a TextChoices value, or an Enum value to a plain str."""
    if isinstance(state, State):
        return str(state.value)
    return str(state)
@property
def RETRY_AT(self) -> datetime:
    # Indirection so subclasses can store retry_at under a different field name.
    return getattr(self, self.retry_at_field_name)

@RETRY_AT.setter
def RETRY_AT(self, value: datetime):
    setattr(self, self.retry_at_field_name, value)

@property
def STATE(self) -> str:
    # Indirection so subclasses can store status under a different field name.
    return getattr(self, self.state_field_name)

@STATE.setter
def STATE(self, value: str):
    setattr(self, self.state_field_name, value)

def bump_retry_at(self, seconds: int = 10):
    """Push this object's retry_at `seconds` into the future (defer/lease it)."""
    self.RETRY_AT = timezone.now() + timedelta(seconds=seconds)
def update_for_workers(self, **kwargs) -> bool:
    """
    Atomically update this object's fields for worker processing.

    Uses the current retry_at value as an optimistic lock: the UPDATE only
    matches if no other worker has changed retry_at since we read it.
    Returns True if the update was applied, False if another worker won.
    """
    # snapshot retry_at before mutating anything — this is the lock token
    expected_retry_at = self.RETRY_AT

    # apply the requested changes to the in-memory instance
    for field_name, new_value in kwargs.items():
        setattr(self, field_name, new_value)

    # push only the touched fields, guarded by the optimistic lock
    changes = {field_name: getattr(self, field_name) for field_name in kwargs}
    rows_matched = type(self).objects.filter(
        pk=self.pk,
        retry_at=expected_retry_at,
    ).update(**changes)

    if rows_matched != 1:
        # another worker modified the row first; local attrs may now be stale
        return False
    self.refresh_from_db()
    return True
@classmethod
def get_queue(cls):
    """
    Return the QuerySet of objects ready for processing, oldest retry_at first.

    An object is ready when:
    - its status is not in FINAL_STATES, and
    - its retry_at is now or in the past
    """
    ready = cls.objects.exclude(status__in=cls.FINAL_STATES)
    return ready.filter(retry_at__lte=timezone.now()).order_by('retry_at')
@classmethod
def claim_for_worker(cls, obj: 'BaseModelWithStateMachine', lock_seconds: int = 60) -> bool:
    """
    Atomically claim `obj` for processing via optimistic locking on retry_at.

    The claim both acquires the object and extends its retry_at into the
    future so other workers skip it for `lock_seconds`.
    Returns True if we claimed it, False if another worker got there first.
    """
    lock_until = timezone.now() + timedelta(seconds=lock_seconds)
    candidate = cls.objects.filter(pk=obj.pk, retry_at=obj.retry_at)
    return candidate.update(retry_at=lock_until) == 1
@classproperty
def ACTIVE_STATE(cls) -> str:
    """String value of the 'currently processing' state (e.g. 'started')."""
    return cls._state_to_str(cls.active_state)

@classproperty
def INITIAL_STATE(cls) -> str:
    """String value of the StateMachine's initial state."""
    return cls._state_to_str(cls.StateMachineClass.initial_state)

@classproperty
def FINAL_STATES(cls) -> list[str]:
    """String values of all terminal states (no further transitions possible)."""
    return [cls._state_to_str(state) for state in cls.StateMachineClass.final_states]

@classproperty
def FINAL_OR_ACTIVE_STATES(cls) -> list[str]:
    """Terminal states plus the active state, i.e. everything not waiting in a queue."""
    return [*cls.FINAL_STATES, cls.ACTIVE_STATE]
@classmethod
def extend_choices(cls, base_choices: Type[models.TextChoices]):
"""
Decorator to extend the base choices with extra choices, e.g.:
class MyModel(ModelWithStateMachine):
@ModelWithStateMachine.extend_choices(ModelWithStateMachine.StatusChoices)
class StatusChoices(models.TextChoices):
SUCCEEDED = 'succeeded'
@@ -207,12 +252,12 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
joined[item[0]] = item[1]
return models.TextChoices('StatusChoices', joined)
return wrapper
@classmethod
def StatusField(cls, **kwargs) -> models.CharField:
"""
Used on subclasses to extend/modify the status field with updated kwargs. e.g.:
class MyModel(ModelWithStateMachine):
class StatusChoices(ModelWithStateMachine.StatusChoices):
QUEUED = 'queued', 'Queued'
@@ -221,7 +266,7 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
BACKOFF = 'backoff', 'Backoff'
FAILED = 'failed', 'Failed'
SKIPPED = 'skipped', 'Skipped'
status = ModelWithStateMachine.StatusField(choices=StatusChoices.choices, default=StatusChoices.QUEUED)
"""
default_kwargs = default_status_field.deconstruct()[3]
@@ -234,7 +279,7 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
def RetryAtField(cls, **kwargs) -> models.DateTimeField:
"""
Used on subclasses to extend/modify the retry_at field with updated kwargs. e.g.:
class MyModel(ModelWithStateMachine):
retry_at = ModelWithStateMachine.RetryAtField(editable=False)
"""
@@ -243,7 +288,7 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
field = models.DateTimeField(**updated_kwargs)
field._is_retry_at_field = True # type: ignore
return field
@classproperty
def StateMachineClass(cls) -> Type[StateMachine]:
"""Get the StateMachine class for the given django Model that inherits from MachineMixin"""
@@ -254,271 +299,21 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
assert issubclass(StateMachineCls, StateMachine)
return StateMachineCls
raise NotImplementedError(f'ActorType[{cls.__name__}] must define .state_machine_name: str that points to a valid StateMachine')
# @classproperty
# def final_q(cls) -> Q:
# """Get the filter for objects that are in a final state"""
# return Q(**{f'{cls.state_field_name}__in': cls.final_states})
# @classproperty
# def active_q(cls) -> Q:
# """Get the filter for objects that are actively processing right now"""
# return Q(**{cls.state_field_name: cls._state_to_str(cls.active_state)}) # e.g. Q(status='started')
# @classproperty
# def stalled_q(cls) -> Q:
# """Get the filter for objects that are marked active but have timed out"""
# return cls.active_q & Q(retry_at__lte=timezone.now()) # e.g. Q(status='started') AND Q(<retry_at is in the past>)
# @classproperty
# def future_q(cls) -> Q:
# """Get the filter for objects that have a retry_at in the future"""
# return Q(retry_at__gt=timezone.now())
# @classproperty
# def pending_q(cls) -> Q:
# """Get the filter for objects that are ready for processing."""
# return ~(cls.active_q) & ~(cls.final_q) & ~(cls.future_q)
# @classmethod
# def get_queue(cls) -> QuerySet:
# """
# Get the sorted and filtered QuerySet of objects that are ready for processing.
# e.g. qs.exclude(status__in=('sealed', 'started'), retry_at__gt=timezone.now()).order_by('retry_at')
# """
# return cls.objects.filter(cls.pending_q)
class ModelWithStateMachine(BaseModelWithStateMachine):
    """Concrete base class: BaseModelWithStateMachine plus default status/retry_at fields."""

    StatusChoices: ClassVar[Type[DefaultStatusChoices]] = DefaultStatusChoices

    # default state-machine columns; subclasses may rebuild them via StatusField()/RetryAtField()
    status: models.CharField = BaseModelWithStateMachine.StatusField()
    retry_at: models.DateTimeField = BaseModelWithStateMachine.RetryAtField()

    state_machine_name: ClassVar[str]                   # e.g. 'core.statemachines.ArchiveResultMachine'
    state_field_name: ClassVar[str] = 'status'          # column holding the state value
    state_machine_attr: ClassVar[str] = 'sm'            # attr name the StateMachine instance is bound to
    bind_events_as_methods: ClassVar[bool] = True
    active_state: ClassVar[str] = StatusChoices.STARTED
    retry_at_field_name: ClassVar[str] = 'retry_at'     # column used for scheduling/locking

    class Meta:
        abstract = True
class EventDict(TypedDict, total=False):
    """JSON-serializable payload used to construct/dispatch an Event (all keys optional)."""
    name: str                       # event name, e.g. 'CRAWL_CREATE'
    id: str | uuid.UUID
    path: str
    content: str
    status: str
    retry_at: datetime | None
    url: str
    seed_id: str | uuid.UUID
    crawl_id: str | uuid.UUID
    snapshot_id: str | uuid.UUID
    process_id: str | uuid.UUID
    extractor: str
    error: str
    on_success: dict | None         # follow-up event dict dispatched on success
    on_failure: dict | None         # follow-up event dict dispatched on failure
class EventManager(models.Manager):
    """Default Event manager; combined with EventQuerySet via Manager.from_queryset below."""
    pass
class EventQuerySet(models.QuerySet):
    """Queryset helpers for selecting Events by claim state and age."""

    def get_next_unclaimed(self) -> 'Event | None':
        """Return the oldest (by deliver_at) event no process has claimed, or None."""
        unclaimed = self.filter(claimed_at=None)
        return unclaimed.order_by('deliver_at').first()

    def expired(self, older_than: int = 60 * 10) -> QuerySet['Event']:
        """Return events whose claim is older than `older_than` seconds (default 10 min)."""
        cutoff = timezone.now() - timedelta(seconds=older_than)
        return self.filter(claimed_at__lt=cutoff)
class Event(models.Model):
id = models.UUIDField(primary_key=True, default=uuid.uuid4, null=False, editable=False, unique=True)
# immutable fields
deliver_at = models.DateTimeField(default=timezone.now, null=False, editable=False, unique=True, db_index=True)
name = models.CharField(max_length=255, null=False, blank=False, db_index=True)
kwargs = models.JSONField(default=dict)
timeout = models.IntegerField(null=False, default=60)
parent = models.ForeignKey('Event', null=True, on_delete=models.SET_NULL, related_name='child_events')
emitted_by = models.ForeignKey(Process, null=False, on_delete=models.PROTECT, related_name='emitted_events')
on_success = models.JSONField(null=True)
on_failure = models.JSONField(null=True)
# mutable fields
modified_at = models.DateTimeField(auto_now=True)
claimed_proc = models.ForeignKey(Process, null=True, on_delete=models.CASCADE, related_name='claimed_events')
claimed_at = models.DateTimeField(null=True)
finished_at = models.DateTimeField(null=True)
error = models.TextField(null=True)
objects: EventManager = EventManager.from_queryset(EventQuerySet)()
child_events: models.RelatedManager['Event']
@classmethod
def get_next_timestamp(cls):
    """Get the next monotonically increasing timestamp for the next event's deliver_at."""
    # compare against the most recently delivered event to enforce monotonicity
    latest_event = cls.objects.order_by('-deliver_at').first()
    ts = timezone.now()
    if latest_event:
        # NOTE(review): strict > means two dispatches within the same clock tick
        # (or any clock skew) raise AssertionError instead of being serialized — confirm intended
        assert ts > latest_event.deliver_at, f'Event.deliver_at is not monotonically increasing: {latest_event.deliver_at} > {ts}'
    return ts
@classmethod
def dispatch(cls, name: str | EventDict | None = None, event: EventDict | None = None, **event_init_kwargs) -> 'Event':
    """
    Create a new Event and save it to the database.

    Can be called as either:
    >>> Event.dispatch(name, {**kwargs}, **event_init_kwargs)
    # OR
    >>> Event.dispatch({name, **kwargs}, **event_init_kwargs)

    Returns the saved Event (save() runs clean(), which validates everything).
    """
    # bugfix: copy the incoming dicts so we never mutate the caller's arguments —
    # the previous implementation .update()'d and .pop('name')'d the caller's own dict
    event_kwargs: EventDict = dict(event) if event else {}
    if isinstance(name, dict):
        event_kwargs.update(name)
    assert isinstance(event_kwargs, dict), 'must be called as Event.dispatch(name, {**kwargs}) or Event.dispatch({name, **kwargs})'
    # name wins if passed positionally as a str; otherwise it must be inside the kwargs dict
    event_name: str = name if (isinstance(name, str) and name) else event_kwargs.pop('name')
    new_event = cls(
        name=event_name,
        kwargs=event_kwargs,
        emitted_by=Process.current(),   # record which OS process emitted this event
        **event_init_kwargs,
    )
    new_event.save()
    return new_event
def clean(self, *args, **kwargs) -> None:
    """Fill in any missing Event fields and validate all of them before saving."""
    # stdlib timezone for tzinfo constants: django.utils.timezone.utc was removed in Django 5
    from datetime import timezone as dt_timezone

    # check uuid and deliver_at are set
    assert self.id, 'Event.id must be set to a valid v4 UUID'
    if not self.deliver_at:
        self.deliver_at = self.get_next_timestamp()
    assert self.deliver_at and (datetime(2024, 12, 8, 12, 0, 0, tzinfo=dt_timezone.utc) < self.deliver_at < datetime(2100, 12, 31, 23, 59, 0, tzinfo=dt_timezone.utc)), (
        f'Event.deliver_at must be set to a valid UTC datetime (got Event.deliver_at = {self.deliver_at})')

    # if name is not set but it's found in the kwargs, move it out of the kwargs to the name field
    if 'type' in self.kwargs and ((self.name == self.kwargs['type']) or not self.name):
        self.name = self.kwargs.pop('type')
    if 'name' in self.kwargs and ((self.name == self.kwargs['name']) or not self.name):
        self.name = self.kwargs.pop('name')

    # check name is set and is a valid identifier
    assert isinstance(self.name, str) and len(self.name) > 3, 'Event.name must be set to a non-empty string'
    assert self.name.isidentifier(), f'Event.name must be a valid identifier (got Event.name = {self.name})'
    assert self.name.isupper(), f'Event.name must be in uppercase (got Event.name = {self.name})'

    # check that kwargs keys and values are valid
    # bugfix: get_fields() yields Field objects, so `key not in get_fields()` could never
    # match a str key — compare against the set of field *names* instead
    reserved_keys = {field.name for field in self._meta.get_fields()}
    for key, value in self.kwargs.items():
        assert isinstance(key, str), f'Event kwargs keys can only be strings (got Event.kwargs[{key}: {type(key).__name__}])'
        assert key not in reserved_keys, f'Event.kwargs cannot contain "{key}" key (Event.kwargs[{key}] conflicts with reserved attr Event.{key} = {getattr(self, key)})'
        assert json.dumps(value, sort_keys=True), f'Event can only contain JSON serializable values (got Event.kwargs[{key}]: {type(value).__name__} = {value})'

    # validate on_success and on_failure are valid event dicts if set
    if self.on_success:
        assert isinstance(self.on_success, dict) and self.on_success.get('name', '!invalid').isidentifier(), f'Event.on_success must be a valid event dict (got {self.on_success})'
    if self.on_failure:
        assert isinstance(self.on_failure, dict) and self.on_failure.get('name', '!invalid').isidentifier(), f'Event.on_failure must be a valid event dict (got {self.on_failure})'

    # validate mutable fields like claimed_at, claimed_proc, finished_at are set correctly
    if self.claimed_at:
        assert self.claimed_proc, f'Event.claimed_at and Event.claimed_proc must be set together (only found Event.claimed_at = {self.claimed_at})'
    if self.claimed_proc:
        assert self.claimed_at, f'Event.claimed_at and Event.claimed_proc must be set together (only found Event.claimed_proc = {self.claimed_proc})'
    if self.finished_at:
        assert self.claimed_at, f'If Event.finished_at is set, Event.claimed_at and Event.claimed_proc must also be set (Event.claimed_proc = {self.claimed_proc} and Event.claimed_at = {self.claimed_at})'

    # validate error is a non-empty string or None (coerce exceptions to 'Type: msg' strings)
    if isinstance(self.error, BaseException):
        self.error = f'{type(self.error).__name__}: {self.error}'
    if self.error:
        assert isinstance(self.error, str) and str(self.error).strip(), f'Event.error must be a non-empty string (got Event.error: {type(self.error).__name__} = {self.error})'
    else:
        assert self.error is None, f'Event.error must be None or a non-empty string (got Event.error: {type(self.error).__name__} = {self.error})'
def save(self, *args, **kwargs):
    """Validate via clean() before persisting (Django does not call clean() on save by default)."""
    self.clean()
    return super().save(*args, **kwargs)
def reset(self):
    """Force-update an event to a pending/unclaimed state (without running any of its handlers or callbacks)"""
    self.claimed_proc = None
    self.claimed_at = None
    self.finished_at = None
    self.error = None
    self.save()
def abort(self):
    """Force-update an event to a completed/failed state (without running any of its handlers or callbacks)"""
    # claim it as the current process so the (claimed, finished, error) invariants in clean() hold
    self.claimed_proc = Process.current()
    self.claimed_at = timezone.now()
    self.finished_at = timezone.now()
    self.error = 'Aborted'
    self.save()
def __repr__(self) -> str:
    """Short human-readable form, e.g. "[CRAWL_CREATE {...}] 🏃" while claimed."""
    label = f'[{self.name} {self.kwargs}]'
    if self.is_finished:
        # NOTE(review): appends an empty string — likely a lost status emoji (e.g. ✅/❌);
        # confirm the intended suffix before "fixing"
        label += f''
    elif self.claimed_proc:
        label += f' 🏃'
    return label

def __str__(self) -> str:
    return repr(self)
@property
def type(self) -> str:
    """Alias for name (legacy callers referred to events by .type)."""
    return self.name

@property
def is_queued(self):
    # waiting: nobody has claimed it and it hasn't finished
    return not self.is_claimed and not self.is_finished

@property
def is_claimed(self):
    return self.claimed_at is not None

@property
def is_expired(self):
    # a claim expires once it has been held longer than the event's timeout
    if not self.claimed_at:
        return False
    elapsed_time = timezone.now() - self.claimed_at
    return elapsed_time > timedelta(seconds=self.timeout)

@property
def is_processing(self):
    return self.is_claimed and not self.is_finished

@property
def is_finished(self):
    return self.finished_at is not None

@property
def is_failed(self):
    # finished with an error message recorded
    return self.is_finished and bool(self.error)

@property
def is_succeeded(self):
    # finished cleanly (no error recorded)
    return self.is_finished and not bool(self.error)
def __getattr__(self, key: str):
    """
    Allow read access to the event kwargs as attributes, e.g.
    Event(name='CRAWL_CREATE', kwargs={'some_key': 'some_val'}).some_key -> 'some_val'

    Unknown (non-underscore) keys resolve to None rather than raising.
    """
    # bugfix: never satisfy dunder/private lookups from kwargs — silently returning
    # None for names like __deepcopy__ or __getstate__ breaks copy/pickle protocol
    # detection and can mask real attribute errors in internal code
    if key.startswith('_'):
        raise AttributeError(key)
    return self.kwargs.get(key)

View File

@@ -1,206 +1,287 @@
"""
Orchestrator for managing worker processes.
The Orchestrator polls queues for each model type (Crawl, Snapshot, ArchiveResult)
and lazily spawns worker processes when there is work to be done.
Architecture:
Orchestrator (main loop, polls queues)
├── CrawlWorker subprocess(es)
├── SnapshotWorker subprocess(es)
└── ArchiveResultWorker subprocess(es)
└── Each worker spawns task subprocesses via CLI
Usage:
# Embedded in other commands (exits when done)
orchestrator = Orchestrator(exit_on_idle=True)
orchestrator.runloop()
# Daemon mode (runs forever)
orchestrator = Orchestrator(exit_on_idle=False)
orchestrator.start() # fork and return
# Or run via CLI
archivebox orchestrator [--daemon]
"""
__package__ = 'archivebox.workers'
import os
import time
import sys
import itertools
from typing import Dict, Type, Literal, TYPE_CHECKING
from django.utils.functional import classproperty
from typing import Type
from multiprocessing import Process
from django.utils import timezone
import multiprocessing
from rich import print
# from django.db.models import QuerySet
from django.apps import apps
if TYPE_CHECKING:
from .actor import ActorType
multiprocessing.set_start_method('fork', force=True)
from .worker import Worker, CrawlWorker, SnapshotWorker, ArchiveResultWorker
from .pid_utils import (
write_pid_file,
remove_pid_file,
get_all_worker_pids,
cleanup_stale_pid_files,
)
class Orchestrator:
pid: int
idle_count: int = 0
actor_types: Dict[str, Type['ActorType']] = {}
mode: Literal['thread', 'process'] = 'process'
exit_on_idle: bool = True
max_concurrent_actors: int = 20
"""
Manages worker processes by polling queues and spawning workers as needed.
def __init__(self, actor_types: Dict[str, Type['ActorType']] | None = None, mode: Literal['thread', 'process'] | None=None, exit_on_idle: bool=True, max_concurrent_actors: int=max_concurrent_actors):
self.actor_types = actor_types or self.actor_types or self.autodiscover_actor_types()
self.mode = mode or self.mode
The orchestrator:
1. Polls each model queue (Crawl, Snapshot, ArchiveResult)
2. If items exist and fewer than MAX_CONCURRENT workers are running, spawns workers
3. Monitors worker health and cleans up stale PIDs
4. Exits when all queues are empty (unless daemon mode)
"""
WORKER_TYPES: list[Type[Worker]] = [CrawlWorker, SnapshotWorker, ArchiveResultWorker]
# Configuration
POLL_INTERVAL: float = 1.0
IDLE_TIMEOUT: int = 3 # Exit after N idle ticks (0 = never exit)
MAX_WORKERS_PER_TYPE: int = 4 # Max workers per model type
MAX_TOTAL_WORKERS: int = 12 # Max workers across all types
def __init__(self, exit_on_idle: bool = True):
self.exit_on_idle = exit_on_idle
self.max_concurrent_actors = max_concurrent_actors
self.pid: int = os.getpid()
self.pid_file = None
self.idle_count: int = 0
def __repr__(self) -> str:
label = 'tid' if self.mode == 'thread' else 'pid'
return f'[underline]{self.name}[/underline]\\[{label}={self.pid}]'
def __str__(self) -> str:
return self.__repr__()
@classproperty
def name(cls) -> str:
return cls.__name__ # type: ignore
# def _fork_as_thread(self):
# self.thread = Thread(target=self.runloop)
# self.thread.start()
# assert self.thread.native_id is not None
# return self.thread.native_id
def _fork_as_process(self):
self.process = multiprocessing.Process(target=self.runloop)
self.process.start()
assert self.process.pid is not None
return self.process.pid
def start(self) -> int:
if self.mode == 'thread':
# return self._fork_as_thread()
raise NotImplementedError('Thread-based orchestrators are disabled for now to reduce codebase complexity')
elif self.mode == 'process':
return self._fork_as_process()
raise ValueError(f'Invalid orchestrator mode: {self.mode}')
return f'[underline]Orchestrator[/underline]\\[pid={self.pid}]'
@classmethod
def autodiscover_actor_types(cls) -> Dict[str, Type['ActorType']]:
from archivebox.config.django import setup_django
setup_django()
def is_running(cls) -> bool:
"""Check if an orchestrator is already running."""
workers = get_all_worker_pids('orchestrator')
return len(workers) > 0
def on_startup(self) -> None:
"""Called when orchestrator starts."""
self.pid = os.getpid()
self.pid_file = write_pid_file('orchestrator', worker_id=0)
print(f'[green]👨‍✈️ {self} STARTED[/green]')
# returns a Dict of all discovered {actor_type_id: ActorType} across the codebase
# override this method in a subclass to customize the actor types that are used
# return {'Snapshot': SnapshotWorker, 'ArchiveResult_chrome': ChromeActorType, ...}
from crawls.statemachines import CrawlWorker
from core.statemachines import SnapshotWorker, ArchiveResultWorker
return {
'CrawlWorker': CrawlWorker,
'SnapshotWorker': SnapshotWorker,
'ArchiveResultWorker': ArchiveResultWorker,
# look through all models and find all classes that inherit from ActorType
# actor_type.__name__: actor_type
# for actor_type in abx.pm.hook.get_all_ACTORS_TYPES().values()
}
# Clean up any stale PID files from previous runs
stale_count = cleanup_stale_pid_files()
if stale_count:
print(f'[yellow]👨‍✈️ {self} cleaned up {stale_count} stale PID files[/yellow]')
@classmethod
def get_orphaned_objects(cls, all_queues) -> list:
    """
    Return stateful objects that are past their retry_at but missing from
    every actor queue (i.e. nothing is ever going to process them).
    """
    # bugfix 1: .values('id', flat=True) is invalid — flat= only exists on values_list()
    # bugfix 2: the bare chain() iterator was exhausted by the first model's .exclude(),
    #           so every later model compared against an empty id set — materialize once
    all_queued_ids = set(itertools.chain.from_iterable(
        queue.values_list('id', flat=True) for queue in all_queues.values()
    ))
    orphaned_objects = []
    for model in apps.get_models():
        if hasattr(model, 'retry_at'):
            orphaned_objects.extend(
                model.objects.filter(retry_at__lt=timezone.now()).exclude(id__in=all_queued_ids)
            )
    return orphaned_objects
@classmethod
def has_future_objects(cls, all_queues) -> bool:
    """Return True if any queue still holds objects scheduled for a future retry_at."""
    now = timezone.now()
    for queue in all_queues.values():
        if queue.filter(retry_at__gte=now).exists():
            return True
    return False
def on_startup(self):
if self.mode == 'thread':
# self.pid = get_native_id()
print(f'[green]👨‍✈️ {self}.on_startup() STARTUP (THREAD)[/green]')
raise NotImplementedError('Thread-based orchestrators are disabled for now to reduce codebase complexity')
elif self.mode == 'process':
self.pid = os.getpid()
print(f'[green]👨‍✈️ {self}.on_startup() STARTUP (PROCESS)[/green]')
# abx.pm.hook.on_orchestrator_startup(self)
def on_shutdown(self, err: BaseException | None = None):
print(f'[grey53]👨‍✈️ {self}.on_shutdown() SHUTTING DOWN[/grey53]', err or '[green](gracefully)[/green]')
# abx.pm.hook.on_orchestrator_shutdown(self)
def on_shutdown(self, error: BaseException | None = None) -> None:
"""Called when orchestrator shuts down."""
if self.pid_file:
remove_pid_file(self.pid_file)
def on_tick_started(self, all_queues):
# total_pending = sum(queue.count() for queue in all_queues.values())
# if total_pending:
# print(f'👨‍✈️ {self}.on_tick_started()', f'total_pending={total_pending}')
# abx.pm.hook.on_orchestrator_tick_started(self, actor_types, all_queues)
pass
if error and not isinstance(error, KeyboardInterrupt):
print(f'[red]👨‍✈️ {self} SHUTDOWN with error:[/red] {type(error).__name__}: {error}')
else:
print(f'[grey53]👨‍✈️ {self} SHUTDOWN[/grey53]')
def on_tick_finished(self, all_queues, all_existing_actors, all_spawned_actors):
# if all_spawned_actors:
# total_queue_length = sum(queue.count() for queue in all_queues.values())
# print(f'[grey53]👨‍✈️ {self}.on_tick_finished() queue={total_queue_length} existing_actors={len(all_existing_actors)} spawned_actors={len(all_spawned_actors)}[/grey53]')
# abx.pm.hook.on_orchestrator_tick_finished(self, actor_types, all_queues)
pass
def on_idle(self, all_queues):
# print(f'👨‍✈️ {self}.on_idle()', f'idle_count={self.idle_count}')
print('.', end='', flush=True, file=sys.stderr)
# abx.pm.hook.on_orchestrator_idle(self)
# check for orphaned objects left behind
if self.idle_count == 60:
orphaned_objects = self.get_orphaned_objects(all_queues)
if orphaned_objects:
print('[red]👨‍✈️ WARNING: some objects may not be processed, no actor has claimed them after 30s:[/red]', orphaned_objects)
if self.idle_count > 3 and self.exit_on_idle and not self.has_future_objects(all_queues):
raise KeyboardInterrupt('✅ All tasks completed, exiting')
def runloop(self):
from archivebox.config.django import setup_django
setup_django()
def get_total_worker_count(self) -> int:
    """Get total count of running workers across all types.

    Prunes stale PID files first so dead processes are not counted.
    """
    cleanup_stale_pid_files()
    return sum(len(W.get_running_workers()) for W in self.WORKER_TYPES)
def should_spawn_worker(self, WorkerClass: Type[Worker], queue_count: int) -> bool:
    """Determine if we should spawn a new worker of the given type.

    Spawns only when there is queued work, and never beyond the per-type
    (MAX_WORKERS_PER_TYPE) or global (MAX_TOTAL_WORKERS) limits.
    """
    # Nothing queued -> nothing to do.
    if queue_count == 0:
        return False
    # Check per-type limit
    running_workers = WorkerClass.get_running_workers()
    if len(running_workers) >= self.MAX_WORKERS_PER_TYPE:
        return False
    # Check total limit (across all worker types)
    if self.get_total_worker_count() >= self.MAX_TOTAL_WORKERS:
        return False
    # Check if we already have enough workers for the queue size
    # Spawn more gradually - don't flood with workers
    if len(running_workers) > 0 and queue_count <= len(running_workers) * WorkerClass.MAX_CONCURRENT_TASKS:
        return False
    return True
def spawn_worker(self, WorkerClass: Type[Worker]) -> int | None:
    """Spawn a new worker process. Returns PID or None if spawn failed."""
    try:
        pid = WorkerClass.start(daemon=False)
        print(f'[blue]👨‍✈️ {self} spawned {WorkerClass.name} worker[/blue] pid={pid}')
        return pid
    except Exception as e:
        # Spawning is best-effort: log the failure and let the next tick retry.
        print(f'[red]👨‍✈️ {self} failed to spawn {WorkerClass.name} worker:[/red] {e}')
        return None
def check_queues_and_spawn_workers(self) -> dict[str, int]:
    """
    Check all queues and spawn workers as needed.
    Returns dict of queue sizes by worker type.
    """
    queue_sizes = {}
    for WorkerClass in self.WORKER_TYPES:
        # Get queue for this worker type
        # Need to instantiate worker to get queue (for model access);
        # construction alone does not register a PID file (that happens in on_startup).
        worker = WorkerClass(worker_id=-1)  # temp instance just for queue access
        queue = worker.get_queue()
        queue_count = queue.count()
        queue_sizes[WorkerClass.name] = queue_count
        # Spawn worker if needed
        if self.should_spawn_worker(WorkerClass, queue_count):
            self.spawn_worker(WorkerClass)
    return queue_sizes
def has_pending_work(self, queue_sizes: dict[str, int]) -> bool:
    """Report whether at least one worker queue currently holds items."""
    for size in queue_sizes.values():
        if size > 0:
            return True
    return False
def has_running_workers(self) -> bool:
    """Check if any workers are still running (counts live PID files)."""
    return self.get_total_worker_count() > 0
def has_future_work(self) -> bool:
    """Check if there's work scheduled for the future (retry_at > now).

    Used to keep the orchestrator alive for retries that are not yet due.
    """
    for WorkerClass in self.WORKER_TYPES:
        worker = WorkerClass(worker_id=-1)  # temp instance for model access only
        Model = worker.get_model()
        # Check for items not in final state with future retry_at
        future_count = Model.objects.filter(
            retry_at__gt=timezone.now()
        ).exclude(
            status__in=Model.FINAL_STATES
        ).count()
        if future_count > 0:
            return True
    return False
def on_tick(self, queue_sizes: dict[str, int]) -> None:
    """Called each orchestrator tick. Override for custom behavior.

    Prints a one-line status (`<type>=<queued>q/<workers>w`) only for types
    that currently have queued items or running workers.
    """
    total_queued = sum(queue_sizes.values())
    total_workers = self.get_total_worker_count()
    if total_queued > 0 or total_workers > 0:
        # Build status line
        status_parts = []
        for WorkerClass in self.WORKER_TYPES:
            name = WorkerClass.name
            queued = queue_sizes.get(name, 0)
            workers = len(WorkerClass.get_running_workers())
            if queued > 0 or workers > 0:
                status_parts.append(f'{name}={queued}q/{workers}w')
        if status_parts:
            print(f'[grey53]👨‍✈️ {self} tick:[/grey53] {" ".join(status_parts)}')
def on_idle(self) -> None:
    """Called when orchestrator is idle (no work, no workers)."""
    # Only announce idleness once per idle stretch (idle_count resets when work appears).
    if self.idle_count == 1:
        print(f'[grey53]👨‍✈️ {self} idle, waiting for work...[/grey53]')
def should_exit(self, queue_sizes: dict[str, int]) -> bool:
    """Determine if orchestrator should exit.

    Exits only when exit_on_idle is enabled, there is no pending, running,
    or future-scheduled work, and we have been idle for IDLE_TIMEOUT ticks.
    """
    if not self.exit_on_idle:
        return False
    if self.IDLE_TIMEOUT == 0:
        # IDLE_TIMEOUT == 0 means "run forever".
        return False
    # Don't exit if there's pending or future work
    if self.has_pending_work(queue_sizes):
        return False
    if self.has_running_workers():
        return False
    if self.has_future_work():
        return False
    # Exit after idle timeout
    return self.idle_count >= self.IDLE_TIMEOUT
def runloop(self) -> None:
"""Main orchestrator loop."""
self.on_startup()
try:
while True:
all_queues = {
actor_type: actor_type.get_queue()
for actor_type in self.actor_types.values()
}
if not all_queues:
raise Exception('Failed to find any actor_types to process')
self.on_tick_started(all_queues)
all_existing_actors = []
all_spawned_actors = []
for actor_type, queue in all_queues.items():
if not queue.exists():
continue
next_obj = queue.first()
print()
print(f'🏃‍♂️ {self}.runloop() {actor_type.__name__.ljust(20)} queue={str(queue.count()).ljust(3)} next={next_obj.id if next_obj else "None"} {next_obj.status if next_obj else "None"} {(timezone.now() - next_obj.retry_at).total_seconds() if next_obj and next_obj.retry_at else "None"}')
# Check queues and spawn workers
queue_sizes = self.check_queues_and_spawn_workers()
# Track idle state
if self.has_pending_work(queue_sizes) or self.has_running_workers():
self.idle_count = 0
try:
existing_actors = actor_type.get_running_actors()
all_existing_actors.extend(existing_actors)
actors_to_spawn = actor_type.get_actors_to_spawn(queue, existing_actors)
can_spawn_num_remaining = self.max_concurrent_actors - len(all_existing_actors) # set max_concurrent_actors=1 to disable multitasking
for launch_kwargs in actors_to_spawn[:can_spawn_num_remaining]:
new_actor_pid = actor_type.start(mode='process', **launch_kwargs)
all_spawned_actors.append(new_actor_pid)
except Exception as err:
print(f'🏃‍♂️ ERROR: {self} Failed to get {actor_type} queue & running actors', err)
except BaseException:
raise
if not any(queue.exists() for queue in all_queues.values()):
self.on_idle(all_queues)
self.idle_count += 1
time.sleep(0.5)
self.on_tick(queue_sizes)
else:
self.idle_count = 0
self.on_tick_finished(all_queues, all_existing_actors, all_spawned_actors)
time.sleep(1)
except BaseException as err:
if isinstance(err, KeyboardInterrupt):
print()
else:
print(f'\n[red]🏃‍♂️ {self}.runloop() FATAL:[/red]', err.__class__.__name__, err)
self.on_shutdown(err=err)
self.idle_count += 1
self.on_idle()
# Check if we should exit
if self.should_exit(queue_sizes):
print(f'[green]👨‍✈️ {self} all work complete, exiting[/green]')
break
time.sleep(self.POLL_INTERVAL)
except KeyboardInterrupt:
print() # Newline after ^C
except BaseException as e:
self.on_shutdown(error=e)
raise
else:
self.on_shutdown()
def start(self) -> int:
"""
Fork orchestrator as a background process.
Returns the PID of the new process.
"""
def run_orchestrator():
from archivebox.config.django import setup_django
setup_django()
self.runloop()
proc = Process(target=run_orchestrator, name='orchestrator')
proc.start()
assert proc.pid is not None
print(f'[green]👨‍✈️ Orchestrator started in background[/green] pid={proc.pid}')
return proc.pid
@classmethod
def get_or_start(cls, exit_on_idle: bool = True) -> 'Orchestrator':
"""
Get running orchestrator or start a new one.
Used by commands like 'add' to ensure orchestrator is running.
"""
if cls.is_running():
print('[grey53]👨‍✈️ Orchestrator already running[/grey53]')
# Return a placeholder - actual orchestrator is in another process
return cls(exit_on_idle=exit_on_idle)
orchestrator = cls(exit_on_idle=exit_on_idle)
return orchestrator

View File

@@ -0,0 +1,191 @@
"""
PID file utilities for tracking worker and orchestrator processes.
PID files are stored in data/tmp/workers/ and contain:
- Line 1: PID
- Line 2: Worker type (orchestrator, crawl, snapshot, archiveresult)
- Line 3: Extractor filter (optional, for archiveresult workers)
- Line 4: Started at ISO timestamp
"""
__package__ = 'archivebox.workers'
import os
import signal
from pathlib import Path
from datetime import datetime, timezone
from django.conf import settings
def get_pid_dir() -> Path:
    """Return the directory where worker PID files live, creating it on first use."""
    workers_dir = Path(settings.DATA_DIR) / 'tmp' / 'workers'
    workers_dir.mkdir(parents=True, exist_ok=True)
    return workers_dir
def write_pid_file(worker_type: str, worker_id: int = 0, extractor: str | None = None) -> Path:
    """
    Write a PID file describing the current process.

    The file holds pid / worker type / extractor filter / start timestamp,
    one per line. Returns the path to the PID file.
    """
    base_dir = get_pid_dir()
    if worker_type == 'orchestrator':
        target = base_dir / 'orchestrator.pid'
    else:
        target = base_dir / f'{worker_type}_worker_{worker_id}.pid'
    fields = [
        str(os.getpid()),
        worker_type,
        extractor or '',
        datetime.now(timezone.utc).isoformat(),
    ]
    target.write_text('\n'.join(fields) + '\n')
    return target
def read_pid_file(path: Path) -> dict | None:
"""
Read and parse a PID file.
Returns dict with pid, worker_type, extractor, started_at or None if invalid.
"""
try:
if not path.exists():
return None
lines = path.read_text().strip().split('\n')
if len(lines) < 4:
return None
return {
'pid': int(lines[0]),
'worker_type': lines[1],
'extractor': lines[2] or None,
'started_at': datetime.fromisoformat(lines[3]),
'pid_file': path,
}
except (ValueError, IndexError, OSError):
return None
def remove_pid_file(path: Path) -> None:
    """Best-effort removal of a PID file; missing files and OS errors are ignored."""
    try:
        path.unlink(missing_ok=True)
    except OSError:
        # e.g. permission errors during shutdown races — safe to ignore here.
        pass
def is_process_alive(pid: int) -> bool:
    """
    Check if a process with the given PID is still running.

    Sends signal 0, which performs existence/permission checks without
    actually delivering a signal.

    Fixes two defects in the original:
    - PermissionError (EPERM) means the process EXISTS but is owned by
      another user; it is a subclass of OSError, so the old code wrongly
      reported such processes as dead.
    - pid <= 0 addresses a process group (or all processes) rather than a
      single process, so os.kill(0, 0) would wrongly report "alive".
    """
    if pid <= 0:
        return False
    try:
        os.kill(pid, 0)  # Signal 0 doesn't kill, just checks
    except PermissionError:
        return True  # exists, but we lack permission to signal it
    except (OSError, ProcessLookupError):
        return False
    return True
def get_all_pid_files() -> list[Path]:
    """Return every '*.pid' file currently present in the workers directory."""
    return [pid_file for pid_file in get_pid_dir().glob('*.pid')]
def get_all_worker_pids(worker_type: str | None = None) -> list[dict]:
    """
    Collect info dicts for every live worker, optionally restricted to one type.

    Entries with unreadable PID files or dead processes are silently skipped.
    """
    running = []
    for pid_file in get_all_pid_files():
        info = read_pid_file(pid_file)
        # Skip unparseable files and dead processes.
        if info is None or not is_process_alive(info['pid']):
            continue
        # Apply the optional type filter.
        if worker_type and info['worker_type'] != worker_type:
            continue
        running.append(info)
    return running
def cleanup_stale_pid_files() -> int:
    """
    Delete PID files whose contents are unparseable or whose process is gone.

    Returns how many files were removed.
    """
    removed_count = 0
    for pid_file in get_all_pid_files():
        info = read_pid_file(pid_file)
        if info is not None and is_process_alive(info['pid']):
            continue  # healthy worker — keep its PID file
        remove_pid_file(pid_file)
        removed_count += 1
    return removed_count
def get_running_worker_count(worker_type: str) -> int:
    """Count the currently-running workers of the given type."""
    running = get_all_worker_pids(worker_type)
    return len(running)
def get_next_worker_id(worker_type: str) -> int:
    """
    Return the lowest worker ID not already claimed by a PID file of this type.

    IDs are recovered from filenames of the form '<type>_worker_<N>.pid'.
    """
    prefix = f'{worker_type}_worker_'
    taken: set[int] = set()
    for pid_file in get_all_pid_files():
        stem = pid_file.stem
        if not stem.startswith(prefix):
            continue
        try:
            taken.add(int(stem.split('_')[-1]))
        except ValueError:
            continue  # malformed filename — ignore
    # Lowest unused ID wins, so gaps left by exited workers get reused.
    candidate = 0
    while candidate in taken:
        candidate += 1
    return candidate
def stop_worker(pid: int, graceful: bool = True) -> bool:
    """
    Stop a worker process by PID.

    graceful=True sends SIGTERM, polls for up to ~1 second, then escalates
    to SIGKILL. graceful=False sends SIGKILL immediately.
    Returns True once the process is stopped (or was already dead).
    """
    if not is_process_alive(pid):
        return True
    try:
        if not graceful:
            os.kill(pid, signal.SIGKILL)
            return True
        os.kill(pid, signal.SIGTERM)
        import time
        for _attempt in range(10):  # poll for up to 1 second
            time.sleep(0.1)
            if not is_process_alive(pid):
                return True
        # Grace period expired — escalate.
        os.kill(pid, signal.SIGKILL)
        return True
    except (OSError, ProcessLookupError):
        return True  # process died out from under us — goal achieved

View File

@@ -1,103 +0,0 @@
# import uuid
# from functools import wraps
# from django.db import connection, transaction
# from django.utils import timezone
# from huey.exceptions import TaskLockedException
# from archivebox.config import CONSTANTS
# class SqliteSemaphore:
# def __init__(self, db_path, table_name, name, value=1, timeout=None):
# self.db_path = db_path
# self.table_name = table_name
# self.name = name
# self.value = value
# self.timeout = timeout or 86400 # Set a max age for lock holders
# # Ensure the table exists
# with connection.cursor() as cursor:
# cursor.execute(f"""
# CREATE TABLE IF NOT EXISTS {self.table_name} (
# id TEXT PRIMARY KEY,
# name TEXT,
# timestamp DATETIME
# )
# """)
# def acquire(self, name=None):
# name = name or str(uuid.uuid4())
# now = timezone.now()
# expiration = now - timezone.timedelta(seconds=self.timeout)
# with transaction.atomic():
# # Remove expired locks
# with connection.cursor() as cursor:
# cursor.execute(f"""
# DELETE FROM {self.table_name}
# WHERE name = %s AND timestamp < %s
# """, [self.name, expiration])
# # Try to acquire the lock
# with connection.cursor() as cursor:
# cursor.execute(f"""
# INSERT INTO {self.table_name} (id, name, timestamp)
# SELECT %s, %s, %s
# WHERE (
# SELECT COUNT(*) FROM {self.table_name}
# WHERE name = %s
# ) < %s
# """, [name, self.name, now, self.name, self.value])
# if cursor.rowcount > 0:
# return name
# # If we couldn't acquire the lock, remove our attempted entry
# with connection.cursor() as cursor:
# cursor.execute(f"""
# DELETE FROM {self.table_name}
# WHERE id = %s AND name = %s
# """, [name, self.name])
# return None
# def release(self, name):
# with connection.cursor() as cursor:
# cursor.execute(f"""
# DELETE FROM {self.table_name}
# WHERE id = %s AND name = %s
# """, [name, self.name])
# return cursor.rowcount > 0
# LOCKS_DB_PATH = CONSTANTS.DATABASE_FILE.parent / 'locks.sqlite3'
# def lock_task_semaphore(db_path, table_name, lock_name, value=1, timeout=None):
# """
# Lock which can be acquired multiple times (default = 1).
# NOTE: no provisions are made for blocking, waiting, or notifying. This is
# just a lock which can be acquired a configurable number of times.
# Example:
# # Allow up to 3 workers to run this task concurrently. If the task is
# # locked, retry up to 2 times with a delay of 60s.
# @huey.task(retries=2, retry_delay=60)
# @lock_task_semaphore('path/to/db.sqlite3', 'semaphore_locks', 'my-lock', 3)
# def my_task():
# ...
# """
# sem = SqliteSemaphore(db_path, table_name, lock_name, value, timeout)
# def decorator(fn):
# @wraps(fn)
# def inner(*args, **kwargs):
# tid = sem.acquire()
# if tid is None:
# raise TaskLockedException(f'unable to acquire lock {lock_name}')
# try:
# return fn(*args, **kwargs)
# finally:
# sem.release(tid)
# return inner
# return decorator

View File

@@ -63,61 +63,68 @@ def bg_add(add_kwargs, task=None, parent_task_id=None):
@task(queue="commands", context=True)
def bg_archive_links(args, kwargs=None, task=None, parent_task_id=None):
def bg_archive_snapshots(snapshots, kwargs=None, task=None, parent_task_id=None):
"""
Queue multiple snapshots for archiving via the state machine system.
This sets snapshots to 'queued' status so the orchestrator workers pick them up.
The actual archiving happens through ArchiveResult.run().
"""
get_or_create_supervisord_process(daemonize=False)
from ..extractors import archive_links
from django.utils import timezone
from core.models import Snapshot
if task and parent_task_id:
TaskModel.objects.set_parent_task(main_task_id=parent_task_id, sub_task_id=task.id)
assert args and args[0]
assert snapshots
kwargs = kwargs or {}
rough_count = len(args[0])
process_info = ProcessInfo(task, desc="archive_links", parent_task_id=parent_task_id, total=rough_count)
result = archive_links(*args, **kwargs)
process_info.update(n=rough_count)
return result
rough_count = len(snapshots) if hasattr(snapshots, '__len__') else snapshots.count()
process_info = ProcessInfo(task, desc="archive_snapshots", parent_task_id=parent_task_id, total=rough_count)
@task(queue="commands", context=True)
def bg_archive_link(args, kwargs=None,task=None, parent_task_id=None):
get_or_create_supervisord_process(daemonize=False)
from ..extractors import archive_link
if task and parent_task_id:
TaskModel.objects.set_parent_task(main_task_id=parent_task_id, sub_task_id=task.id)
# Queue snapshots by setting status to queued with immediate retry_at
queued_count = 0
for snapshot in snapshots:
if hasattr(snapshot, 'id'):
# Update snapshot to queued state so workers pick it up
Snapshot.objects.filter(id=snapshot.id).update(
status=Snapshot.StatusChoices.QUEUED,
retry_at=timezone.now(),
)
queued_count += 1
assert args and args[0]
kwargs = kwargs or {}
rough_count = len(args[0])
process_info = ProcessInfo(task, desc="archive_link", parent_task_id=parent_task_id, total=rough_count)
result = archive_link(*args, **kwargs)
process_info.update(n=rough_count)
return result
process_info.update(n=queued_count)
return queued_count
@task(queue="commands", context=True)
def bg_archive_snapshot(snapshot, overwrite=False, methods=None, task=None, parent_task_id=None):
# get_or_create_supervisord_process(daemonize=False)
"""
Queue a single snapshot for archiving via the state machine system.
This sets the snapshot to 'queued' status so the orchestrator workers pick it up.
The actual archiving happens through ArchiveResult.run().
"""
get_or_create_supervisord_process(daemonize=False)
from django.utils import timezone
from core.models import Snapshot
from ..extractors import archive_link
if task and parent_task_id:
TaskModel.objects.set_parent_task(main_task_id=parent_task_id, sub_task_id=task.id)
process_info = ProcessInfo(task, desc="archive_link", parent_task_id=parent_task_id, total=1)
link = snapshot.as_link_with_details()
result = archive_link(link, overwrite=overwrite, methods=methods)
process_info.update(n=1)
return result
process_info = ProcessInfo(task, desc="archive_snapshot", parent_task_id=parent_task_id, total=1)
# Queue the snapshot by setting status to queued
if hasattr(snapshot, 'id'):
Snapshot.objects.filter(id=snapshot.id).update(
status=Snapshot.StatusChoices.QUEUED,
retry_at=timezone.now(),
)
process_info.update(n=1)
return 1
return 0

View File

@@ -115,7 +115,7 @@
const jobElement = document.createElement('div');
jobElement.className = 'job-item';
jobElement.innerHTML = `
<p><a href="/api/v1/core/any/${job.abid}?api_key={{api_token|default:'NONE PROVIDED BY VIEW'}}"><code>${job.abid}</code></a></p>
<p><a href="/api/v1/core/any/${job.id}?api_key={{api_token|default:'NONE PROVIDED BY VIEW'}}"><code>${job.id}</code></a></p>
<p>
<span class="badge badge-${job.status}">${job.status}</span>
<span class="date">♻️ ${formatDate(job.retry_at)}</span>

View File

@@ -1,25 +0,0 @@
from django.test import TestCase
# Create your tests here.
class CrawlActorTest(TestCase):
def test_crawl_creation(self):
seed = Seed.objects.create(uri='https://example.com')
Event.dispatch('CRAWL_CREATE', {'seed_id': seed.id})
crawl_actor = CrawlActor()
output_events = list(crawl_actor.process_next_event())
assert len(output_events) == 1
assert output_events[0].get('name', 'unset') == 'FS_WRITE'
assert output_events[0].get('path') == '/tmp/test_crawl/index.json'
output_events = list(crawl_actor.process_next_event())
assert len(output_events) == 1
assert output_events[0].get('name', 'unset') == 'CRAWL_CREATED'
assert Crawl.objects.filter(seed_id=seed.id).exists(), 'Crawl was not created'

View File

@@ -1,457 +1,330 @@
"""
Worker classes for processing queue items.
Workers poll the database for items to process, claim them atomically,
and run the state machine tick() to process each item.
Architecture:
Orchestrator (spawns workers)
└── Worker (claims items from queue, processes them directly)
"""
__package__ = 'archivebox.workers'
import os
import sys
import time
import uuid
import json
from typing import ClassVar, Iterable, Type
import traceback
from typing import ClassVar, Any
from datetime import timedelta
from pathlib import Path
from multiprocessing import Process, cpu_count
from django.db.models import QuerySet
from django.utils import timezone
from django.conf import settings
from rich import print
from django.db import transaction
from django.db.models import QuerySet
from django.utils import timezone
from django.utils.functional import classproperty # type: ignore
from crawls.models import Crawl
from core.models import Snapshot, ArchiveResult
from workers.models import Event, Process, EventDict
from .pid_utils import (
write_pid_file,
remove_pid_file,
get_all_worker_pids,
get_next_worker_id,
cleanup_stale_pid_files,
)
class WorkerType:
# static class attributes
name: ClassVar[str] # e.g. 'log' or 'filesystem' or 'crawl' or 'snapshot' or 'archiveresult' etc.
listens_to: ClassVar[str] # e.g. 'LOG_' or 'FS_' or 'CRAWL_' or 'SNAPSHOT_' or 'ARCHIVERESULT_' etc.
outputs: ClassVar[list[str]] # e.g. ['LOG_', 'FS_', 'CRAWL_', 'SNAPSHOT_', 'ARCHIVERESULT_'] etc.
poll_interval: ClassVar[int] = 1 # how long to wait before polling for new events
@classproperty
def event_queue(cls) -> QuerySet[Event]:
return Event.objects.filter(name__startswith=cls.listens_to)
CPU_COUNT = cpu_count()
@classmethod
def fork(cls, wait_for_first_event=False, exit_on_idle=True) -> Process:
cmd = ['archivebox', 'worker', cls.name]
if exit_on_idle:
cmd.append('--exit-on-idle')
if wait_for_first_event:
cmd.append('--wait-for-first-event')
return Process.create_and_fork(cmd=cmd, actor_type=cls.name)
# Registry of worker types by name (defined at bottom, referenced here for _run_worker)
WORKER_TYPES: dict[str, type['Worker']] = {}
@classproperty
def processes(cls) -> QuerySet[Process]:
return Process.objects.filter(actor_type=cls.name)
@classmethod
def run(cls, wait_for_first_event=False, exit_on_idle=True):
def _run_worker(worker_class_name: str, worker_id: int, daemon: bool, **kwargs):
"""
Module-level function to run a worker. Must be at module level for pickling.
"""
from archivebox.config.django import setup_django
setup_django()
if wait_for_first_event:
event = cls.event_queue.get_next_unclaimed()
while not event:
time.sleep(cls.poll_interval)
event = cls.event_queue.get_next_unclaimed()
# Get worker class by name to avoid pickling class objects
worker_cls = WORKER_TYPES[worker_class_name]
worker = worker_cls(worker_id=worker_id, daemon=daemon, **kwargs)
worker.runloop()
while True:
output_events = list(cls.process_next_event()) or list(cls.process_idle_tick()) # process next event, or tick if idle
yield from output_events
if not output_events:
if exit_on_idle:
break
else:
time.sleep(cls.poll_interval)
@classmethod
def process_next_event(cls) -> Iterable[EventDict]:
event = cls.event_queue.get_next_unclaimed()
output_events = []
if not event:
return []
cls.mark_event_claimed(event)
print(f'{cls.__name__}[{Process.current().pid}] {event}', file=sys.stderr)
try:
for output_event in cls.receive(event):
output_events.append(output_event)
yield output_event
cls.mark_event_succeeded(event, output_events=output_events)
except BaseException as e:
cls.mark_event_failed(event, output_events=output_events, error=e)
class Worker:
"""
Base worker class that polls a queue and processes items directly.
@classmethod
def process_idle_tick(cls) -> Iterable[EventDict]:
# reset the idle event to be claimed by the current process
event, _created = Event.objects.update_or_create(
name=f'{cls.listens_to}IDLE',
emitted_by=Process.current(),
defaults={
'deliver_at': timezone.now(),
'claimed_proc': None,
'claimed_at': None,
'finished_at': None,
'error': None,
'parent': None,
},
Each item is processed by calling its state machine tick() method.
Workers exit when idle for too long (unless daemon mode).
"""
name: ClassVar[str] = 'worker'
# Configuration (can be overridden by subclasses)
MAX_TICK_TIME: ClassVar[int] = 60
POLL_INTERVAL: ClassVar[float] = 0.5
IDLE_TIMEOUT: ClassVar[int] = 3 # Exit after N idle iterations (set to 0 to never exit)
def __init__(self, worker_id: int = 0, daemon: bool = False, **kwargs: Any):
    # daemon=True means "never exit on idle" (see should_exit()).
    self.worker_id = worker_id
    self.daemon = daemon
    self.pid: int = os.getpid()
    self.pid_file: Path | None = None  # set by on_startup()
    self.idle_count: int = 0  # consecutive polls that found no work
def __repr__(self) -> str:
    # '\\[' renders a literal '[' in rich markup so it isn't parsed as a style tag.
    return f'[underline]{self.__class__.__name__}[/underline]\\[id={self.worker_id}, pid={self.pid}]'
def get_model(self):
    """Return the Django model class this worker processes. Subclasses must override."""
    raise NotImplementedError("Subclasses must implement get_model()")
def get_queue(self) -> QuerySet:
    """Get the queue of objects ready for processing.

    "Ready" means retry_at is due (<= now) and the object is not in a
    terminal state; ordered oldest-due first.
    """
    Model = self.get_model()
    return Model.objects.filter(
        retry_at__lte=timezone.now()
    ).exclude(
        status__in=Model.FINAL_STATES
    ).order_by('retry_at')
def claim_next(self):
    """
    Atomically claim the next object from the queue.

    Uses optimistic locking: the conditional UPDATE on (pk, retry_at) only
    succeeds if retry_at is unchanged since we read it, so exactly one
    worker wins the claim. The winner's retry_at is pushed MAX_TICK_TIME
    into the future, which both reserves the item and acts as a lease that
    expires if this worker dies mid-processing.

    Returns the claimed object or None if queue is empty or claim failed.

    Fix: removed two stray merge-residue lines from the old event-based
    implementation ('yield from cls.process_next_event()') that referenced
    an undefined 'cls' and silently turned this method into a generator.
    """
    Model = self.get_model()
    obj = self.get_queue().first()
    if obj is None:
        return None
    # Atomic claim using optimistic locking on retry_at
    claimed = Model.objects.filter(
        pk=obj.pk,
        retry_at=obj.retry_at,
    ).update(
        retry_at=timezone.now() + timedelta(seconds=self.MAX_TICK_TIME)
    )
    if claimed == 1:
        obj.refresh_from_db()
        return obj
    return None  # Someone else claimed it
def process_item(self, obj) -> bool:
    """
    Advance *obj* one step by ticking its state machine.

    Returns True on success, False on failure.
    Subclasses can override for custom processing.
    """
    try:
        obj.sm.tick()
    except Exception as exc:
        print(f'[red]{self} error processing {obj.pk}:[/red] {exc}')
        traceback.print_exc()
        return False
    return True
def on_startup(self) -> None:
    """Called when worker starts: record our PID on disk so it can be discovered."""
    self.pid = os.getpid()  # re-read: __init__ may have run in the parent before fork
    self.pid_file = write_pid_file(self.name, self.worker_id)
    print(f'[green]{self} STARTED[/green] pid_file={self.pid_file}')
def on_shutdown(self, error: BaseException | None = None) -> None:
    """Called when worker shuts down; deregisters the PID file and logs the outcome."""
    # Remove PID file so this worker no longer counts as "running"
    if self.pid_file:
        remove_pid_file(self.pid_file)
    # KeyboardInterrupt is a normal operator-initiated stop, not an error.
    if error and not isinstance(error, KeyboardInterrupt):
        print(f'[red]{self} SHUTDOWN with error:[/red] {type(error).__name__}: {error}')
    else:
        print(f'[grey53]{self} SHUTDOWN[/grey53]')
def should_exit(self) -> bool:
    """Report whether the idle timeout has elapsed (daemons never exit)."""
    if self.daemon or self.IDLE_TIMEOUT == 0:
        # Daemon mode and IDLE_TIMEOUT == 0 both mean "run forever".
        return False
    return self.idle_count >= self.IDLE_TIMEOUT
def runloop(self) -> None:
    """Main worker loop - polls queue, claims items, and processes them.

    Exits on idle timeout (unless daemon mode) or KeyboardInterrupt.

    Fix: the original caught KeyboardInterrupt with 'pass', which skipped
    the try's 'else' clause, so on_shutdown() never ran and the PID file
    was leaked on Ctrl-C. Now every exit path calls on_shutdown().
    """
    self.on_startup()
    try:
        while True:
            # Try to claim and process an item
            obj = self.claim_next()
            if obj is not None:
                self.idle_count = 0
                print(f'[blue]{self} processing:[/blue] {obj.pk}')
                start_time = time.time()
                success = self.process_item(obj)
                elapsed = time.time() - start_time
                if success:
                    print(f'[green]{self} completed ({elapsed:.1f}s):[/green] {obj.pk}')
                else:
                    print(f'[red]{self} failed ({elapsed:.1f}s):[/red] {obj.pk}')
            else:
                # No work available
                self.idle_count += 1
                if self.idle_count == 1:
                    print(f'[grey53]{self} idle, waiting for work...[/grey53]')
                # Check if we should exit
                if self.should_exit():
                    print(f'[grey53]{self} idle timeout reached, exiting[/grey53]')
                    break
                time.sleep(self.POLL_INTERVAL)
    except KeyboardInterrupt:
        # Operator-initiated stop: still clean up the PID file.
        self.on_shutdown()
    except BaseException as e:
        self.on_shutdown(error=e)
        raise
    else:
        self.on_shutdown()
@classmethod
def receive(cls, event: Event) -> Iterable[EventDict]:
handler_method = getattr(cls, f'on_{event.name}', None)
if handler_method:
yield from handler_method(event)
else:
raise Exception(f'No handler method for event: {event.name}')
def start(cls, worker_id: int | None = None, daemon: bool = False, **kwargs: Any) -> int:
"""
Fork a new worker as a subprocess.
Returns the PID of the new process.
"""
if worker_id is None:
worker_id = get_next_worker_id(cls.name)
@staticmethod
def on_IDLE() -> Iterable[EventDict]:
return []
@staticmethod
def mark_event_claimed(event: Event):
proc = Process.current()
with transaction.atomic():
claimed = Event.objects.filter(id=event.id, claimed_proc=None, claimed_at=None).update(claimed_proc=proc, claimed_at=timezone.now())
event.refresh_from_db()
if not claimed:
raise Exception(f'Event already claimed by another process: {event.claimed_proc}')
print(f'{self}.mark_event_claimed(): Claimed {event} ⛏️')
# process_updated = Process.objects.filter(id=proc.id, active_event=None).update(active_event=event)
# if not process_updated:
# raise Exception(f'Unable to update process.active_event: {proc}.active_event = {event}')
# Use module-level function for pickling compatibility
proc = Process(
target=_run_worker,
args=(cls.name, worker_id, daemon),
kwargs=kwargs,
name=f'{cls.name}_worker_{worker_id}',
)
proc.start()
@staticmethod
def mark_event_succeeded(event: Event, output_events: Iterable[EventDict]):
event.refresh_from_db()
assert event.claimed_proc, f'Cannot mark event as succeeded if it is not claimed by a process: {event}'
assert (event.claimed_proc == Process.current()), f'Cannot mark event as succeeded if it claimed by a different process: {event}.claimed_proc = {event.claimed_proc}, current_process = {Process.current()}'
with transaction.atomic():
updated = Event.objects.filter(id=event.id, claimed_proc=event.claimed_proc, claimed_at=event.claimed_at, finished_at=None).update(finished_at=timezone.now())
event.refresh_from_db()
if not updated:
raise Exception(f'Event {event} failed to mark as succeeded, it was modified by another process: {event.claimed_proc}')
assert proc.pid is not None
return proc.pid
# process_updated = Process.objects.filter(id=event.claimed_proc.id, active_event=event).update(active_event=None)
# if not process_updated:
# raise Exception(f'Unable to unset process.active_event: {event.claimed_proc}.active_event = {event}')
@classmethod
def get_running_workers(cls) -> list[dict]:
    """Get info dicts for all running workers of this type."""
    # Prune stale PID files first so dead processes are never reported.
    cleanup_stale_pid_files()
    return get_all_worker_pids(cls.name)
# dispatch any output events
for output_event in output_events:
Event.dispatch(event=output_event, parent=event)
# trigger any callback events
if event.on_success:
Event.dispatch(event=event.on_success, parent=event)
@staticmethod
def mark_event_failed(event: Event, output_events: Iterable[EventDict]=(), error: BaseException | None = None):
event.refresh_from_db()
assert event.claimed_proc, f'Cannot mark event as failed if it is not claimed by a process: {event}'
assert (event.claimed_proc == Process.current()), f'Cannot mark event as failed if it claimed by a different process: {event}.claimed_proc = {event.claimed_proc}, current_process = {Process.current()}'
with transaction.atomic():
updated = Event.objects.filter(id=event.id, claimed_proc=event.claimed_proc, claimed_at=event.claimed_at, finished_at=None).update(finished_at=timezone.now(), error=str(error))
event.refresh_from_db()
if not updated:
raise Exception(f'Event {event} failed to mark as failed, it was modified by another process: {event.claimed_proc}')
# process_updated = Process.objects.filter(id=event.claimed_proc.id, active_event=event).update(active_event=None)
# if not process_updated:
# raise Exception(f'Unable to unset process.active_event: {event.claimed_proc}.active_event = {event}')
# add dedicated error event to the output events
if not event.name.endswith('_ERROR'):
output_events = [
*output_events,
{'name': f'{event.name}_ERROR', 'msg': f'{type(error).__name__}: {error}'},
]
# dispatch any output events
for output_event in output_events:
Event.dispatch(event=output_event, parent=event)
# trigger any callback events
if event.on_failure:
Event.dispatch(event=event.on_failure, parent=event)
@classmethod
def get_worker_count(cls) -> int:
    """Get count of currently-running workers of this type."""
    return len(cls.get_running_workers())
class CrawlWorker(Worker):
    """Worker for processing Crawl objects."""
    name: ClassVar[str] = 'crawl'
    # Seconds a claim lease holds before another worker may re-claim the item.
    MAX_TICK_TIME: ClassVar[int] = 60
    def get_model(self):
        # Imported lazily so this module can load before the Django app registry is ready.
        from crawls.models import Crawl
        return Crawl
class OrchestratorWorker(WorkerType):
    """Reacts to PROC_* events: launches queued Processes and kills exited/killed ones."""
    name = 'orchestrator'
    listens_to = 'PROC_'
    outputs = ['PROC_']

    @staticmethod
    def on_PROC_IDLE() -> Iterable[EventDict]:
        # pick the oldest Process that has not been launched yet, if any
        next_proc = Process.objects.filter(launched_at=None).order_by('created_at').first()
        if next_proc is None:
            return []
        yield {'name': 'PROC_LAUNCH', 'id': next_proc.id}

    @staticmethod
    def on_PROC_LAUNCH(event: Event) -> Iterable[EventDict]:
        # fork a new OS process described by the event's kwargs
        spawned = Process.create_and_fork(**event.kwargs)
        yield {'name': 'PROC_LAUNCHED', 'process_id': spawned.id}

    @staticmethod
    def on_PROC_EXIT(event: Event) -> Iterable[EventDict]:
        # NOTE(review): identical to on_PROC_KILL — confirm killing on EXIT is intended
        proc = Process.objects.get(id=event.process_id)
        proc.kill()
        yield {'name': 'PROC_KILLED', 'process_id': proc.id}

    @staticmethod
    def on_PROC_KILL(event: Event) -> Iterable[EventDict]:
        proc = Process.objects.get(id=event.process_id)
        proc.kill()
        yield {'name': 'PROC_KILLED', 'process_id': proc.id}
class SnapshotWorker(Worker):
    """Worker for processing Snapshot objects."""
    # NOTE(review): this class is shadowed by a second `SnapshotWorker(WorkerType)`
    # definition later in this module — confirm which definition is intended to win.

    name: ClassVar[str] = 'snapshot'
    # presumably the max seconds a single tick may take — TODO confirm
    MAX_TICK_TIME: ClassVar[int] = 60

    def get_model(self):
        # lazy import to avoid circular imports at module load time
        from core.models import Snapshot
        return Snapshot
class FileSystemWorker(WorkerType):
    """
    Handles FS_* events: writes, appends, deletes, and rsyncs on the filesystem.

    BUGFIX(review): these on_FS_* handlers were interleaved into ArchiveResultWorker
    below, but FileSystemWorker declares listens_to = 'FS_' and had no handlers —
    they have been moved here. Every mutation emits an FS_CHANGED event so other
    workers can react.
    """
    name = 'filesystem'
    listens_to = 'FS_'
    outputs = ['FS_']

    @staticmethod
    def on_FS_IDLE(event: Event) -> Iterable[EventDict]:
        # check for tmp files that can be deleted
        for tmp_file in Path('/tmp').glob('archivebox/*'):
            yield {'name': 'FS_DELETE', 'path': str(tmp_file)}

    @staticmethod
    def on_FS_WRITE(event: Event) -> Iterable[EventDict]:
        # overwrite the file at event.path with event.content
        with open(event.path, 'w') as f:
            f.write(event.content)
        yield {'name': 'FS_CHANGED', 'path': event.path}

    @staticmethod
    def on_FS_APPEND(event: Event) -> Iterable[EventDict]:
        # append event.content to the file at event.path
        with open(event.path, 'a') as f:
            f.write(event.content)
        yield {'name': 'FS_CHANGED', 'path': event.path}

    @staticmethod
    def on_FS_DELETE(event: Event) -> Iterable[EventDict]:
        os.remove(event.path)
        yield {'name': 'FS_CHANGED', 'path': event.path}

    @staticmethod
    def on_FS_RSYNC(event: Event) -> Iterable[EventDict]:
        # SECURITY(review): event.src / event.dst are interpolated into a shell
        # command unquoted — confirm they can never contain untrusted input,
        # or switch to subprocess.run([...], shell=False).
        os.system(f'rsync -av {event.src} {event.dst}')
        yield {'name': 'FS_CHANGED', 'path': event.dst}


class ArchiveResultWorker(Worker):
    """
    Worker for processing ArchiveResult objects.

    NOTE(review): shadowed by a later `ArchiveResultWorker(WorkerType)` definition
    in this module — confirm which definition is intended to win.
    """
    name: ClassVar[str] = 'archiveresult'
    # presumably the max seconds a single tick may take — TODO confirm
    MAX_TICK_TIME: ClassVar[int] = 120

    def __init__(self, extractor: str | None = None, **kwargs: Any):
        """Optionally restrict this worker to a single extractor plugin."""
        super().__init__(**kwargs)
        self.extractor = extractor

    def get_model(self):
        # lazy import to avoid circular imports at module load time
        from core.models import ArchiveResult
        return ArchiveResult

    def get_queue(self) -> QuerySet:
        """Get queue of ArchiveResults ready for processing."""
        from django.db.models import Exists, OuterRef
        from core.models import ArchiveResult
        qs = super().get_queue()
        if self.extractor:
            qs = qs.filter(extractor=self.extractor)
        # Exclude ArchiveResults whose Snapshot already has one in progress
        in_progress = ArchiveResult.objects.filter(
            snapshot_id=OuterRef('snapshot_id'),
            status=ArchiveResult.StatusChoices.STARTED,
        )
        qs = qs.exclude(Exists(in_progress))
        return qs

    def process_item(self, obj) -> bool:
        """Process an ArchiveResult by running its extractor. Returns True on success."""
        try:
            obj.sm.tick()
            return True
        except Exception as e:
            print(f'[red]{self} error processing {obj.pk}:[/red] {e}')
            traceback.print_exc()
            return False

    @classmethod
    def start(cls, worker_id: int | None = None, daemon: bool = False, extractor: str | None = None, **kwargs: Any) -> int:
        """Fork a new worker as subprocess with optional extractor filter."""
        if worker_id is None:
            worker_id = get_next_worker_id(cls.name)
        # NOTE(review): `Process(target=...)` here looks like multiprocessing.Process,
        # which collides with the workers.models.Process used elsewhere in this
        # module — confirm the import aliasing at the top of the file.
        # Use module-level function for pickling compatibility
        proc = Process(
            target=_run_worker,
            args=(cls.name, worker_id, daemon),
            kwargs={'extractor': extractor, **kwargs},
            name=f'{cls.name}_worker_{worker_id}',
        )
        proc.start()
        assert proc.pid is not None
        return proc.pid
class CrawlWorker(WorkerType):
    """
    Handles CRAWL_* events: creating, starting, updating, and sealing Crawls.

    NOTE(review): this redefines the earlier `CrawlWorker(Worker)` class in this
    module — confirm which definition is intended to win.
    """
    name = 'crawl'
    listens_to = 'CRAWL_'
    outputs = ['CRAWL_', 'FS_', 'SNAPSHOT_']

    @staticmethod
    def on_CRAWL_IDLE(event: Event) -> Iterable[EventDict]:
        # check for any stale crawls that can be started or sealed
        stale_crawl = Crawl.objects.filter(retry_at__lt=timezone.now()).first()
        if not stale_crawl:
            return []
        if stale_crawl.can_start():
            # BUGFIX: was emitted with key 'id', but on_CRAWL_START reads event.crawl_id
            yield {'name': 'CRAWL_START', 'crawl_id': stale_crawl.id}
        elif stale_crawl.can_seal():
            yield {'name': 'CRAWL_SEAL', 'id': stale_crawl.id}

    @staticmethod
    def on_CRAWL_CREATE(event: Event) -> Iterable[EventDict]:
        # NOTE(review): defaults=event passes the whole event payload (including
        # its 'name' key) as model defaults — confirm Event filters that out.
        crawl, created = Crawl.objects.get_or_create(id=event.id, defaults=event)
        if created:
            yield {'name': 'CRAWL_UPDATED', 'crawl_id': crawl.id}

    @staticmethod
    def on_CRAWL_UPDATE(event: Event) -> Iterable[EventDict]:
        crawl = Crawl.objects.get(id=event.pop('crawl_id'))
        # only write fields whose values actually changed
        diff = {
            key: val
            for key, val in event.items()
            if getattr(crawl, key) != val
        }
        if diff:
            crawl.update(**diff)
            yield {'name': 'CRAWL_UPDATED', 'crawl_id': crawl.id}

    @staticmethod
    def on_CRAWL_UPDATED(event: Event) -> Iterable[EventDict]:
        # refresh the crawl's output dir symlinks whenever it changes
        crawl = Crawl.objects.get(id=event.crawl_id)
        yield {'name': 'FS_WRITE_SYMLINKS', 'path': crawl.OUTPUT_DIR, 'symlinks': crawl.output_dir_symlinks}

    @staticmethod
    def on_CRAWL_SEAL(event: Event) -> Iterable[EventDict]:
        crawl = Crawl.objects.filter(id=event.id, status=Crawl.StatusChoices.STARTED).first()
        if not crawl:
            # already sealed (or not started) — nothing to do
            return
        crawl.status = Crawl.StatusChoices.SEALED
        crawl.save()
        yield {'name': 'FS_WRITE', 'path': crawl.OUTPUT_DIR / 'index.json', 'content': json.dumps(crawl.as_json(), default=str, indent=4, sort_keys=True)}
        yield {'name': 'CRAWL_UPDATED', 'crawl_id': crawl.id}

    @staticmethod
    def on_CRAWL_START(event: Event) -> Iterable[EventDict]:
        # create root snapshot
        crawl = Crawl.objects.get(id=event.crawl_id)
        new_snapshot_id = uuid.uuid4()
        yield {'name': 'SNAPSHOT_CREATE', 'snapshot_id': new_snapshot_id, 'crawl_id': crawl.id, 'url': crawl.seed.uri}
        yield {'name': 'SNAPSHOT_START', 'snapshot_id': new_snapshot_id}
        yield {'name': 'CRAWL_UPDATE', 'crawl_id': crawl.id, 'status': 'started', 'retry_at': None}
# Populate the registry
# NOTE(review): at this point in the module, CrawlWorker / SnapshotWorker /
# ArchiveResultWorker still refer to the Worker-based classes defined above,
# not the WorkerType-based classes of the same names defined *below* — confirm
# which classes this registry is meant to hold.
# NOTE(review): .update() implies WORKER_TYPES is a dict here, but it is
# rebound to a list near the end of this file — confirm which representation
# is intended.
WORKER_TYPES.update({
    'crawl': CrawlWorker,
    'snapshot': SnapshotWorker,
    'archiveresult': ArchiveResultWorker,
})
class SnapshotWorker(WorkerType):
    """
    Handles SNAPSHOT_* events: creating, starting, and sealing Snapshots.

    NOTE(review): this redefines the earlier `SnapshotWorker(Worker)` class in
    this module — confirm which definition is intended to win.
    """
    name = 'snapshot'
    listens_to = 'SNAPSHOT_'
    outputs = ['SNAPSHOT_', 'FS_']

    @staticmethod
    def on_SNAPSHOT_IDLE(event: Event) -> Iterable[EventDict]:
        # check for any snapshots that can be started or sealed
        snapshot = Snapshot.objects.exclude(status=Snapshot.StatusChoices.SEALED).first()
        if not snapshot:
            return []
        if snapshot.can_start():
            # BUGFIX: was emitted with key 'id', but on_SNAPSHOT_START reads event.snapshot_id
            yield {'name': 'SNAPSHOT_START', 'snapshot_id': snapshot.id}
        elif snapshot.can_seal():
            # BUGFIX: was emitted with key 'id', but on_SNAPSHOT_SEAL reads event.snapshot_id
            yield {'name': 'SNAPSHOT_SEAL', 'snapshot_id': snapshot.id}

    @staticmethod
    def on_SNAPSHOT_CREATE(event: Event) -> Iterable[EventDict]:
        snapshot = Snapshot.objects.create(id=event.snapshot_id, **event.kwargs)
        yield {'name': 'FS_WRITE', 'path': snapshot.OUTPUT_DIR / 'index.json', 'content': json.dumps(snapshot.as_json(), default=str, indent=4, sort_keys=True)}
        yield {'name': 'SNAPSHOT_UPDATED', 'id': snapshot.id}

    @staticmethod
    def on_SNAPSHOT_SEAL(event: Event) -> Iterable[EventDict]:
        snapshot = Snapshot.objects.get(id=event.snapshot_id, status=Snapshot.StatusChoices.STARTED)
        assert snapshot.can_seal()
        snapshot.status = Snapshot.StatusChoices.SEALED
        snapshot.save()
        yield {'name': 'FS_WRITE', 'path': snapshot.OUTPUT_DIR / 'index.json', 'content': json.dumps(snapshot.as_json(), default=str, indent=4, sort_keys=True)}
        yield {'name': 'SNAPSHOT_UPDATED', 'id': snapshot.id}

    @staticmethod
    def on_SNAPSHOT_START(event: Event) -> Iterable[EventDict]:
        snapshot = Snapshot.objects.get(id=event.snapshot_id, status=Snapshot.StatusChoices.QUEUED)
        assert snapshot.can_start()
        # create pending archiveresults for each extractor
        for extractor in snapshot.get_extractors():
            new_archiveresult_id = uuid.uuid4()
            # BUGFIX: on_ARCHIVERESULT_CREATE pops 'archiveresult_id', not 'id'
            yield {'name': 'ARCHIVERESULT_CREATE', 'archiveresult_id': new_archiveresult_id, 'snapshot_id': snapshot.id, 'extractor': extractor.name}
            yield {'name': 'ARCHIVERESULT_START', 'id': new_archiveresult_id}
        snapshot.status = Snapshot.StatusChoices.STARTED
        snapshot.save()
        yield {'name': 'FS_WRITE', 'path': snapshot.OUTPUT_DIR / 'index.json', 'content': json.dumps(snapshot.as_json(), default=str, indent=4, sort_keys=True)}
        yield {'name': 'SNAPSHOT_UPDATED', 'id': snapshot.id}
class ArchiveResultWorker(WorkerType):
    """
    Handles ARCHIVERESULT_* events: creating, starting, updating, and sealing
    ArchiveResults, delegating extractor execution to SHELL_EXEC events.

    NOTE(review): this redefines the earlier `ArchiveResultWorker(Worker)` class
    in this module — confirm which definition is intended to win.
    """
    name = 'archiveresult'
    listens_to = 'ARCHIVERESULT_'
    outputs = ['ARCHIVERESULT_', 'FS_']

    @staticmethod
    def on_ARCHIVERESULT_UPDATE(event: Event) -> Iterable[EventDict]:
        archiveresult = ArchiveResult.objects.get(id=event.id)
        # only write fields whose values actually changed
        diff = {
            key: val
            for key, val in event.items()
            if getattr(archiveresult, key) != val
        }
        if diff:
            archiveresult.update(**diff)
            yield {'name': 'ARCHIVERESULT_UPDATED', 'id': archiveresult.id}

    @staticmethod
    def on_ARCHIVERESULT_UPDATED(event: Event) -> Iterable[EventDict]:
        # refresh the result's output dir symlinks whenever it changes
        archiveresult = ArchiveResult.objects.get(id=event.id)
        yield {'name': 'FS_WRITE_SYMLINKS', 'path': archiveresult.OUTPUT_DIR, 'symlinks': archiveresult.output_dir_symlinks}

    @staticmethod
    def on_ARCHIVERESULT_CREATE(event: Event) -> Iterable[EventDict]:
        # NOTE(review): defaults=event passes the whole event payload (including
        # its 'name' key) as model defaults — confirm Event filters that out.
        archiveresult, created = ArchiveResult.objects.get_or_create(id=event.pop('archiveresult_id'), defaults=event)
        if created:
            yield {'name': 'ARCHIVERESULT_UPDATE', 'id': archiveresult.id}
        else:
            diff = {
                key: val
                for key, val in event.items()
                if getattr(archiveresult, key) != val
            }
            assert not diff, f'ArchiveResult {archiveresult.id} already exists and has different values, cannot create on top of it: {diff}'

    @staticmethod
    def on_ARCHIVERESULT_SEAL(event: Event) -> Iterable[EventDict]:
        archiveresult = ArchiveResult.objects.get(id=event.id, status=ArchiveResult.StatusChoices.STARTED)
        assert archiveresult.can_seal()
        yield {'name': 'ARCHIVERESULT_UPDATE', 'id': archiveresult.id, 'status': 'sealed'}

    @staticmethod
    def on_ARCHIVERESULT_START(event: Event) -> Iterable[EventDict]:
        archiveresult = ArchiveResult.objects.get(id=event.id, status=ArchiveResult.StatusChoices.QUEUED)
        # run the extractor out-of-band; seal the result when the shell exits
        yield {
            'name': 'SHELL_EXEC',
            'cmd': archiveresult.EXTRACTOR.get_cmd(),
            'cwd': archiveresult.OUTPUT_DIR,
            'on_exit': {
                'name': 'ARCHIVERESULT_SEAL',
                'id': archiveresult.id,
            },
        }
        archiveresult.status = ArchiveResult.StatusChoices.STARTED
        archiveresult.save()
        yield {'name': 'FS_WRITE', 'path': archiveresult.OUTPUT_DIR / 'index.json', 'content': json.dumps(archiveresult.as_json(), default=str, indent=4, sort_keys=True)}
        yield {'name': 'ARCHIVERESULT_UPDATED', 'id': archiveresult.id}

    @staticmethod
    def on_ARCHIVERESULT_IDLE(event: Event) -> Iterable[EventDict]:
        stale_archiveresult = ArchiveResult.objects.exclude(status__in=[ArchiveResult.StatusChoices.SUCCEEDED, ArchiveResult.StatusChoices.FAILED]).first()
        if not stale_archiveresult:
            return []
        if stale_archiveresult.can_start():
            yield {'name': 'ARCHIVERESULT_START', 'id': stale_archiveresult.id}
        # BUGFIX: was a second `if`, which could emit both START and SEAL for the
        # same object in one tick; sibling IDLE handlers use if/elif
        elif stale_archiveresult.can_seal():
            yield {'name': 'ARCHIVERESULT_SEAL', 'id': stale_archiveresult.id}
# Registry of all WorkerType subclasses, searched in order by get_worker_type().
# NOTE(review): this rebinds WORKER_TYPES to a list, clobbering the dict-style
# registry populated via WORKER_TYPES.update(...) earlier in this file —
# confirm which representation downstream code expects.
WORKER_TYPES = [
    OrchestratorWorker,
    FileSystemWorker,
    CrawlWorker,
    SnapshotWorker,
    ArchiveResultWorker,
]
def get_worker_type(name: str) -> Type[WorkerType]:
    """
    Resolve a WorkerType class from a name.

    Matches against the class's verbose name, its class name (case-insensitive),
    or its listens_to event prefix (underscores and case ignored).

    Raises:
        ValueError: if no registered worker type matches.
        (ValueError is a subclass of Exception, so existing `except Exception`
        callers still work; this is consistent with get_worker_class.)
    """
    for worker_type in WORKER_TYPES:
        matches_verbose_name = (worker_type.name == name)
        matches_class_name = (worker_type.__name__.lower() == name.lower())
        matches_listens_to = (worker_type.listens_to.strip('_').lower() == name.strip('_').lower())
        if matches_verbose_name or matches_class_name or matches_listens_to:
            return worker_type
    raise ValueError(f'Worker type not found: {name}')
def get_worker_class(name: str) -> type[Worker]:
    """
    Get worker class by name.

    BUGFIX: this previously did `name not in WORKER_TYPES` / `WORKER_TYPES[name]`,
    treating WORKER_TYPES as a dict keyed by name — but WORKER_TYPES is defined
    as a list of classes, so the lookup always raised. Search the list by each
    class's `name` attribute instead.

    Raises:
        ValueError: if no registered worker class matches `name`.
    """
    for worker_cls in WORKER_TYPES:
        if getattr(worker_cls, 'name', None) == name:
            return worker_cls
    valid_names = [getattr(w, 'name', w.__name__) for w in WORKER_TYPES]
    raise ValueError(f'Unknown worker type: {name}. Valid types: {valid_names}')