mirror of
https://github.com/ArchiveBox/ArchiveBox.git
synced 2026-04-05 23:37:58 +10:00
wip major changes
This commit is contained in:
@@ -1,9 +1,7 @@
|
||||
__package__ = 'archivebox.workers'
|
||||
__order__ = 100
|
||||
|
||||
import abx
|
||||
|
||||
@abx.hookimpl
|
||||
def register_admin(admin_site):
|
||||
from workers.admin import register_admin
|
||||
register_admin(admin_site)
|
||||
|
||||
@@ -1,166 +0,0 @@
|
||||
# __package__ = 'archivebox.workers'
|
||||
|
||||
# import time
|
||||
|
||||
|
||||
# from typing import ClassVar, Type, Iterable, TypedDict
|
||||
# from django.db.models import QuerySet
|
||||
# from django.db import transaction
|
||||
# from django.utils import timezone
|
||||
# from django.utils.functional import classproperty # type: ignore
|
||||
|
||||
# from .models import Event, Process, EventDict
|
||||
|
||||
|
||||
# class ActorType:
|
||||
# # static class attributes
|
||||
# name: ClassVar[str]
|
||||
# event_prefix: ClassVar[str]
|
||||
# poll_interval: ClassVar[int] = 1
|
||||
|
||||
# @classproperty
|
||||
# def event_queue(cls) -> QuerySet[Event]:
|
||||
# return Event.objects.filter(type__startswith=cls.event_prefix)
|
||||
|
||||
# @classmethod
|
||||
# def fork(cls, wait_for_first_event=False, exit_on_idle=True) -> Process:
|
||||
# cmd = ['archivebox', 'actor', cls.name]
|
||||
# if exit_on_idle:
|
||||
# cmd.append('--exit-on-idle')
|
||||
# if wait_for_first_event:
|
||||
# cmd.append('--wait-for-first-event')
|
||||
# return Process.create_and_fork(cmd=cmd, actor_type=cls.name)
|
||||
|
||||
# @classproperty
|
||||
# def processes(cls) -> QuerySet[Process]:
|
||||
# return Process.objects.filter(actor_type=cls.name)
|
||||
|
||||
# @classmethod
|
||||
# def run(cls, wait_for_first_event=False, exit_on_idle=True):
|
||||
|
||||
# if wait_for_first_event:
|
||||
# event = cls.event_queue.get_next_unclaimed()
|
||||
# while not event:
|
||||
# time.sleep(cls.poll_interval)
|
||||
# event = cls.event_queue.get_next_unclaimed()
|
||||
|
||||
# while True:
|
||||
# output_events = list(cls.process_next_event()) or list(cls.process_idle_tick()) # process next event, or tick if idle
|
||||
# yield from output_events
|
||||
# if not output_events:
|
||||
# if exit_on_idle:
|
||||
# break
|
||||
# else:
|
||||
# time.sleep(cls.poll_interval)
|
||||
|
||||
# @classmethod
|
||||
# def process_next_event(cls) -> Iterable[EventDict]:
|
||||
# event = cls.event_queue.get_next_unclaimed()
|
||||
# output_events = []
|
||||
|
||||
# if not event:
|
||||
# return []
|
||||
|
||||
# cls.mark_event_claimed(event, duration=60)
|
||||
# try:
|
||||
# for output_event in cls.receive(event):
|
||||
# output_events.append(output_event)
|
||||
# yield output_event
|
||||
# cls.mark_event_succeeded(event, output_events=output_events)
|
||||
# except BaseException as e:
|
||||
# cls.mark_event_failed(event, output_events=output_events, error=e)
|
||||
|
||||
# @classmethod
|
||||
# def process_idle_tick(cls) -> Iterable[EventDict]:
|
||||
# # reset the idle event to be claimed by the current process
|
||||
# event, _created = Event.objects.update_or_create(
|
||||
# name=f'{cls.event_prefix}IDLE',
|
||||
# emitted_by=Process.current(),
|
||||
# defaults={
|
||||
# 'deliver_at': timezone.now(),
|
||||
# 'claimed_proc': None,
|
||||
# 'claimed_at': None,
|
||||
# 'finished_at': None,
|
||||
# 'error': None,
|
||||
# 'parent': None,
|
||||
# },
|
||||
# )
|
||||
|
||||
# # then process it like any other event
|
||||
# yield from cls.process_next_event()
|
||||
|
||||
# @classmethod
|
||||
# def receive(cls, event: Event) -> Iterable[EventDict]:
|
||||
# handler_method = getattr(cls, f'on_{event.name}', None)
|
||||
# if handler_method:
|
||||
# yield from handler_method(event)
|
||||
# else:
|
||||
# raise Exception(f'No handler method for event: {event.name}')
|
||||
|
||||
# @staticmethod
|
||||
# def on_IDLE() -> Iterable[EventDict]:
|
||||
# return []
|
||||
|
||||
# @staticmethod
|
||||
# def mark_event_claimed(event: Event, duration: int=60):
|
||||
# proc = Process.current()
|
||||
|
||||
# with transaction.atomic():
|
||||
# claimed = Event.objects.filter(id=event.id, claimed_proc=None, claimed_at=None).update(claimed_proc=proc, claimed_at=timezone.now())
|
||||
# if not claimed:
|
||||
# event.refresh_from_db()
|
||||
# raise Exception(f'Event already claimed by another process: {event.claimed_proc}')
|
||||
|
||||
# process_updated = Process.objects.filter(id=proc.id, active_event=None).update(active_event=event)
|
||||
# if not process_updated:
|
||||
# raise Exception(f'Unable to update process.active_event: {proc}.active_event = {event}')
|
||||
|
||||
# @staticmethod
|
||||
# def mark_event_succeeded(event: Event, output_events: Iterable[EventDict]):
|
||||
# assert event.claimed_proc and (event.claimed_proc == Process.current())
|
||||
# with transaction.atomic():
|
||||
# updated = Event.objects.filter(id=event.id, claimed_proc=event.claimed_proc, claimed_at=event.claimed_at, finished_at=None).update(finished_at=timezone.now())
|
||||
# if not updated:
|
||||
# event.refresh_from_db()
|
||||
# raise Exception(f'Event {event} failed to mark as succeeded, it was modified by another process: {event.claimed_proc}')
|
||||
|
||||
# process_updated = Process.objects.filter(id=event.claimed_proc.id, active_event=event).update(active_event=None)
|
||||
# if not process_updated:
|
||||
# raise Exception(f'Unable to unset process.active_event: {event.claimed_proc}.active_event = {event}')
|
||||
|
||||
# # dispatch any output events
|
||||
# for output_event in output_events:
|
||||
# Event.dispatch(event=output_event, parent=event)
|
||||
|
||||
# # trigger any callback events
|
||||
# if event.on_success:
|
||||
# Event.dispatch(event=event.on_success, parent=event)
|
||||
|
||||
# @staticmethod
|
||||
# def mark_event_failed(event: Event, output_events: Iterable[EventDict]=(), error: BaseException | None = None):
|
||||
# assert event.claimed_proc and (event.claimed_proc == Process.current())
|
||||
# with transaction.atomic():
|
||||
# updated = event.objects.filter(id=event.id, claimed_proc=event.claimed_proc, claimed_at=event.claimed_at, finished_at=None).update(finished_at=timezone.now(), error=str(error))
|
||||
# if not updated:
|
||||
# event.refresh_from_db()
|
||||
# raise Exception(f'Event {event} failed to mark as failed, it was modified by another process: {event.claimed_proc}')
|
||||
|
||||
# process_updated = Process.objects.filter(id=event.claimed_proc.id, active_event=event).update(active_event=None)
|
||||
# if not process_updated:
|
||||
# raise Exception(f'Unable to unset process.active_event: {event.claimed_proc}.active_event = {event}')
|
||||
|
||||
|
||||
# # add dedicated error event to the output events
|
||||
# output_events = [
|
||||
# *output_events,
|
||||
# {'name': f'{event.name}_ERROR', 'error': f'{type(error).__name__}: {error}'},
|
||||
# ]
|
||||
|
||||
# # dispatch any output events
|
||||
# for output_event in output_events:
|
||||
# Event.dispatch(event=output_event, parent=event)
|
||||
|
||||
# # trigger any callback events
|
||||
# if event.on_failure:
|
||||
# Event.dispatch(event=event.on_failure, parent=event)
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
__package__ = 'archivebox.workers'
|
||||
|
||||
import abx
|
||||
|
||||
from django.contrib.auth import get_permission_codename
|
||||
|
||||
from huey_monitor.apps import HueyMonitorConfig
|
||||
@@ -20,7 +18,6 @@ class CustomTaskModelAdmin(TaskModelAdmin):
|
||||
|
||||
|
||||
|
||||
@abx.hookimpl
|
||||
def register_admin(admin_site):
|
||||
admin_site.register(TaskModel, CustomTaskModelAdmin)
|
||||
admin_site.register(SignalInfoModel, SignalInfoModelAdmin)
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
|
||||
|
||||
from django.core.management.base import BaseCommand
|
||||
|
||||
from workers.orchestrator import ArchivingOrchestrator
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = 'Run the archivebox orchestrator'
|
||||
|
||||
# def add_arguments(self, parser):
|
||||
# parser.add_argument('subcommand', type=str, help='The subcommand you want to run')
|
||||
# parser.add_argument('command_args', nargs='*', help='Arguments to pass to the subcommand')
|
||||
|
||||
|
||||
def handle(self, *args, **kwargs):
|
||||
orchestrator = ArchivingOrchestrator()
|
||||
orchestrator.start()
|
||||
@@ -1,20 +1,14 @@
|
||||
__package__ = 'archivebox.workers'
|
||||
|
||||
import uuid
|
||||
import json
|
||||
|
||||
from typing import ClassVar, Type, Iterable, TypedDict
|
||||
from typing import ClassVar, Type, Iterable
|
||||
from datetime import datetime, timedelta
|
||||
from statemachine.mixins import MachineMixin
|
||||
|
||||
from django.db import models
|
||||
from django.db.models import QuerySet
|
||||
from django.core import checks
|
||||
from django.utils import timezone
|
||||
from django.utils.functional import classproperty
|
||||
|
||||
from machine.models import Process
|
||||
|
||||
from statemachine import registry, StateMachine, State
|
||||
|
||||
|
||||
@@ -33,31 +27,31 @@ ObjectStateList = Iterable[ObjectState]
|
||||
|
||||
class BaseModelWithStateMachine(models.Model, MachineMixin):
|
||||
id: models.UUIDField
|
||||
|
||||
|
||||
StatusChoices: ClassVar[Type[models.TextChoices]]
|
||||
|
||||
|
||||
# status: models.CharField
|
||||
# retry_at: models.DateTimeField
|
||||
|
||||
|
||||
state_machine_name: ClassVar[str]
|
||||
state_field_name: ClassVar[str]
|
||||
state_machine_attr: ClassVar[str] = 'sm'
|
||||
bind_events_as_methods: ClassVar[bool] = True
|
||||
|
||||
|
||||
active_state: ClassVar[ObjectState]
|
||||
retry_at_field_name: ClassVar[str]
|
||||
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
|
||||
@classmethod
|
||||
def check(cls, sender=None, **kwargs):
|
||||
errors = super().check(**kwargs)
|
||||
|
||||
|
||||
found_id_field = False
|
||||
found_status_field = False
|
||||
found_retry_at_field = False
|
||||
|
||||
|
||||
for field in cls._meta.get_fields():
|
||||
if getattr(field, '_is_state_field', False):
|
||||
if cls.state_field_name == field.name:
|
||||
@@ -74,7 +68,7 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
|
||||
found_retry_at_field = True
|
||||
if field.name == 'id' and getattr(field, 'primary_key', False):
|
||||
found_id_field = True
|
||||
|
||||
|
||||
if not found_status_field:
|
||||
errors.append(checks.Error(
|
||||
f'{cls.__name__}.state_field_name must be defined and point to a StatusField()',
|
||||
@@ -89,7 +83,7 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
|
||||
obj=cls,
|
||||
id='workers.E013',
|
||||
))
|
||||
|
||||
|
||||
if not found_id_field:
|
||||
errors.append(checks.Error(
|
||||
f'{cls.__name__} must have an id field that is a primary key',
|
||||
@@ -97,7 +91,7 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
|
||||
obj=cls,
|
||||
id='workers.E014',
|
||||
))
|
||||
|
||||
|
||||
if not isinstance(cls.state_machine_name, str):
|
||||
errors.append(checks.Error(
|
||||
f'{cls.__name__}.state_machine_name must be a dotted-import path to a StateMachine class',
|
||||
@@ -105,7 +99,7 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
|
||||
obj=cls,
|
||||
id='workers.E015',
|
||||
))
|
||||
|
||||
|
||||
try:
|
||||
cls.StateMachineClass
|
||||
except Exception as err:
|
||||
@@ -115,7 +109,7 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
|
||||
obj=cls,
|
||||
id='workers.E016',
|
||||
))
|
||||
|
||||
|
||||
if cls.INITIAL_STATE not in cls.StatusChoices.values:
|
||||
errors.append(checks.Error(
|
||||
f'{cls.__name__}.StateMachineClass.initial_state must be present within {cls.__name__}.StatusChoices',
|
||||
@@ -123,7 +117,7 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
|
||||
obj=cls,
|
||||
id='workers.E017',
|
||||
))
|
||||
|
||||
|
||||
if cls.ACTIVE_STATE not in cls.StatusChoices.values:
|
||||
errors.append(checks.Error(
|
||||
f'{cls.__name__}.active_state must be set to a valid State present within {cls.__name__}.StatusChoices',
|
||||
@@ -131,8 +125,8 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
|
||||
obj=cls,
|
||||
id='workers.E018',
|
||||
))
|
||||
|
||||
|
||||
|
||||
|
||||
for state in cls.FINAL_STATES:
|
||||
if state not in cls.StatusChoices.values:
|
||||
errors.append(checks.Error(
|
||||
@@ -143,55 +137,106 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
|
||||
))
|
||||
break
|
||||
return errors
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _state_to_str(state: ObjectState) -> str:
|
||||
"""Convert a statemachine.State, models.TextChoices.choices value, or Enum value to a str"""
|
||||
return str(state.value) if isinstance(state, State) else str(state)
|
||||
|
||||
|
||||
|
||||
|
||||
@property
|
||||
def RETRY_AT(self) -> datetime:
|
||||
return getattr(self, self.retry_at_field_name)
|
||||
|
||||
|
||||
@RETRY_AT.setter
|
||||
def RETRY_AT(self, value: datetime):
|
||||
setattr(self, self.retry_at_field_name, value)
|
||||
|
||||
|
||||
@property
|
||||
def STATE(self) -> str:
|
||||
return getattr(self, self.state_field_name)
|
||||
|
||||
|
||||
@STATE.setter
|
||||
def STATE(self, value: str):
|
||||
setattr(self, self.state_field_name, value)
|
||||
|
||||
|
||||
def bump_retry_at(self, seconds: int = 10):
|
||||
self.RETRY_AT = timezone.now() + timedelta(seconds=seconds)
|
||||
|
||||
|
||||
def update_for_workers(self, **kwargs) -> bool:
|
||||
"""
|
||||
Atomically update the object's fields for worker processing.
|
||||
Returns True if the update was successful, False if the object was modified by another worker.
|
||||
"""
|
||||
# Get the current retry_at to use as optimistic lock
|
||||
current_retry_at = self.RETRY_AT
|
||||
|
||||
# Apply the updates
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
# Try to save with optimistic locking
|
||||
updated = type(self).objects.filter(
|
||||
pk=self.pk,
|
||||
retry_at=current_retry_at,
|
||||
).update(**{k: getattr(self, k) for k in kwargs})
|
||||
|
||||
if updated == 1:
|
||||
self.refresh_from_db()
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_queue(cls):
|
||||
"""
|
||||
Get the sorted and filtered QuerySet of objects that are ready for processing.
|
||||
Objects are ready if:
|
||||
- status is not in FINAL_STATES
|
||||
- retry_at is in the past (or now)
|
||||
"""
|
||||
return cls.objects.filter(
|
||||
retry_at__lte=timezone.now()
|
||||
).exclude(
|
||||
status__in=cls.FINAL_STATES
|
||||
).order_by('retry_at')
|
||||
|
||||
@classmethod
|
||||
def claim_for_worker(cls, obj: 'BaseModelWithStateMachine', lock_seconds: int = 60) -> bool:
|
||||
"""
|
||||
Atomically claim an object for processing using optimistic locking.
|
||||
Returns True if successfully claimed, False if another worker got it first.
|
||||
"""
|
||||
updated = cls.objects.filter(
|
||||
pk=obj.pk,
|
||||
retry_at=obj.retry_at,
|
||||
).update(
|
||||
retry_at=timezone.now() + timedelta(seconds=lock_seconds)
|
||||
)
|
||||
return updated == 1
|
||||
|
||||
@classproperty
|
||||
def ACTIVE_STATE(cls) -> str:
|
||||
return cls._state_to_str(cls.active_state)
|
||||
|
||||
|
||||
@classproperty
|
||||
def INITIAL_STATE(cls) -> str:
|
||||
return cls._state_to_str(cls.StateMachineClass.initial_state)
|
||||
|
||||
|
||||
@classproperty
|
||||
def FINAL_STATES(cls) -> list[str]:
|
||||
return [cls._state_to_str(state) for state in cls.StateMachineClass.final_states]
|
||||
|
||||
|
||||
@classproperty
|
||||
def FINAL_OR_ACTIVE_STATES(cls) -> list[str]:
|
||||
return [*cls.FINAL_STATES, cls.ACTIVE_STATE]
|
||||
|
||||
|
||||
@classmethod
|
||||
def extend_choices(cls, base_choices: Type[models.TextChoices]):
|
||||
"""
|
||||
Decorator to extend the base choices with extra choices, e.g.:
|
||||
|
||||
|
||||
class MyModel(ModelWithStateMachine):
|
||||
|
||||
|
||||
@ModelWithStateMachine.extend_choices(ModelWithStateMachine.StatusChoices)
|
||||
class StatusChoices(models.TextChoices):
|
||||
SUCCEEDED = 'succeeded'
|
||||
@@ -207,12 +252,12 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
|
||||
joined[item[0]] = item[1]
|
||||
return models.TextChoices('StatusChoices', joined)
|
||||
return wrapper
|
||||
|
||||
|
||||
@classmethod
|
||||
def StatusField(cls, **kwargs) -> models.CharField:
|
||||
"""
|
||||
Used on subclasses to extend/modify the status field with updated kwargs. e.g.:
|
||||
|
||||
|
||||
class MyModel(ModelWithStateMachine):
|
||||
class StatusChoices(ModelWithStateMachine.StatusChoices):
|
||||
QUEUED = 'queued', 'Queued'
|
||||
@@ -221,7 +266,7 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
|
||||
BACKOFF = 'backoff', 'Backoff'
|
||||
FAILED = 'failed', 'Failed'
|
||||
SKIPPED = 'skipped', 'Skipped'
|
||||
|
||||
|
||||
status = ModelWithStateMachine.StatusField(choices=StatusChoices.choices, default=StatusChoices.QUEUED)
|
||||
"""
|
||||
default_kwargs = default_status_field.deconstruct()[3]
|
||||
@@ -234,7 +279,7 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
|
||||
def RetryAtField(cls, **kwargs) -> models.DateTimeField:
|
||||
"""
|
||||
Used on subclasses to extend/modify the retry_at field with updated kwargs. e.g.:
|
||||
|
||||
|
||||
class MyModel(ModelWithStateMachine):
|
||||
retry_at = ModelWithStateMachine.RetryAtField(editable=False)
|
||||
"""
|
||||
@@ -243,7 +288,7 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
|
||||
field = models.DateTimeField(**updated_kwargs)
|
||||
field._is_retry_at_field = True # type: ignore
|
||||
return field
|
||||
|
||||
|
||||
@classproperty
|
||||
def StateMachineClass(cls) -> Type[StateMachine]:
|
||||
"""Get the StateMachine class for the given django Model that inherits from MachineMixin"""
|
||||
@@ -254,271 +299,21 @@ class BaseModelWithStateMachine(models.Model, MachineMixin):
|
||||
assert issubclass(StateMachineCls, StateMachine)
|
||||
return StateMachineCls
|
||||
raise NotImplementedError(f'ActorType[{cls.__name__}] must define .state_machine_name: str that points to a valid StateMachine')
|
||||
|
||||
# @classproperty
|
||||
# def final_q(cls) -> Q:
|
||||
# """Get the filter for objects that are in a final state"""
|
||||
# return Q(**{f'{cls.state_field_name}__in': cls.final_states})
|
||||
|
||||
# @classproperty
|
||||
# def active_q(cls) -> Q:
|
||||
# """Get the filter for objects that are actively processing right now"""
|
||||
# return Q(**{cls.state_field_name: cls._state_to_str(cls.active_state)}) # e.g. Q(status='started')
|
||||
|
||||
# @classproperty
|
||||
# def stalled_q(cls) -> Q:
|
||||
# """Get the filter for objects that are marked active but have timed out"""
|
||||
# return cls.active_q & Q(retry_at__lte=timezone.now()) # e.g. Q(status='started') AND Q(<retry_at is in the past>)
|
||||
|
||||
# @classproperty
|
||||
# def future_q(cls) -> Q:
|
||||
# """Get the filter for objects that have a retry_at in the future"""
|
||||
# return Q(retry_at__gt=timezone.now())
|
||||
|
||||
# @classproperty
|
||||
# def pending_q(cls) -> Q:
|
||||
# """Get the filter for objects that are ready for processing."""
|
||||
# return ~(cls.active_q) & ~(cls.final_q) & ~(cls.future_q)
|
||||
|
||||
# @classmethod
|
||||
# def get_queue(cls) -> QuerySet:
|
||||
# """
|
||||
# Get the sorted and filtered QuerySet of objects that are ready for processing.
|
||||
# e.g. qs.exclude(status__in=('sealed', 'started'), retry_at__gt=timezone.now()).order_by('retry_at')
|
||||
# """
|
||||
# return cls.objects.filter(cls.pending_q)
|
||||
|
||||
|
||||
class ModelWithStateMachine(BaseModelWithStateMachine):
|
||||
StatusChoices: ClassVar[Type[DefaultStatusChoices]] = DefaultStatusChoices
|
||||
|
||||
|
||||
status: models.CharField = BaseModelWithStateMachine.StatusField()
|
||||
retry_at: models.DateTimeField = BaseModelWithStateMachine.RetryAtField()
|
||||
|
||||
|
||||
state_machine_name: ClassVar[str] # e.g. 'core.statemachines.ArchiveResultMachine'
|
||||
state_field_name: ClassVar[str] = 'status'
|
||||
state_machine_attr: ClassVar[str] = 'sm'
|
||||
bind_events_as_methods: ClassVar[bool] = True
|
||||
|
||||
|
||||
active_state: ClassVar[str] = StatusChoices.STARTED
|
||||
retry_at_field_name: ClassVar[str] = 'retry_at'
|
||||
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class EventDict(TypedDict, total=False):
|
||||
name: str
|
||||
|
||||
id: str | uuid.UUID
|
||||
path: str
|
||||
content: str
|
||||
status: str
|
||||
retry_at: datetime | None
|
||||
url: str
|
||||
seed_id: str | uuid.UUID
|
||||
crawl_id: str | uuid.UUID
|
||||
snapshot_id: str | uuid.UUID
|
||||
process_id: str | uuid.UUID
|
||||
extractor: str
|
||||
error: str
|
||||
on_success: dict | None
|
||||
on_failure: dict | None
|
||||
|
||||
class EventManager(models.Manager):
|
||||
pass
|
||||
|
||||
class EventQuerySet(models.QuerySet):
|
||||
def get_next_unclaimed(self) -> 'Event | None':
|
||||
return self.filter(claimed_at=None).order_by('deliver_at').first()
|
||||
|
||||
def expired(self, older_than: int=60 * 10) -> QuerySet['Event']:
|
||||
return self.filter(claimed_at__lt=timezone.now() - timedelta(seconds=older_than))
|
||||
|
||||
|
||||
class Event(models.Model):
|
||||
id = models.UUIDField(primary_key=True, default=uuid.uuid4, null=False, editable=False, unique=True)
|
||||
|
||||
# immutable fields
|
||||
deliver_at = models.DateTimeField(default=timezone.now, null=False, editable=False, unique=True, db_index=True)
|
||||
name = models.CharField(max_length=255, null=False, blank=False, db_index=True)
|
||||
kwargs = models.JSONField(default=dict)
|
||||
timeout = models.IntegerField(null=False, default=60)
|
||||
parent = models.ForeignKey('Event', null=True, on_delete=models.SET_NULL, related_name='child_events')
|
||||
emitted_by = models.ForeignKey(Process, null=False, on_delete=models.PROTECT, related_name='emitted_events')
|
||||
on_success = models.JSONField(null=True)
|
||||
on_failure = models.JSONField(null=True)
|
||||
|
||||
# mutable fields
|
||||
modified_at = models.DateTimeField(auto_now=True)
|
||||
claimed_proc = models.ForeignKey(Process, null=True, on_delete=models.CASCADE, related_name='claimed_events')
|
||||
claimed_at = models.DateTimeField(null=True)
|
||||
finished_at = models.DateTimeField(null=True)
|
||||
error = models.TextField(null=True)
|
||||
|
||||
objects: EventManager = EventManager.from_queryset(EventQuerySet)()
|
||||
|
||||
child_events: models.RelatedManager['Event']
|
||||
|
||||
@classmethod
|
||||
def get_next_timestamp(cls):
|
||||
"""Get the next monotonically increasing timestamp for the next event.dispatch_at"""
|
||||
latest_event = cls.objects.order_by('-deliver_at').first()
|
||||
ts = timezone.now()
|
||||
if latest_event:
|
||||
assert ts > latest_event.deliver_at, f'Event.deliver_at is not monotonically increasing: {latest_event.deliver_at} > {ts}'
|
||||
return ts
|
||||
|
||||
@classmethod
|
||||
def dispatch(cls, name: str | EventDict | None = None, event: EventDict | None = None, **event_init_kwargs) -> 'Event':
|
||||
"""
|
||||
Create a new Event and save it to the database.
|
||||
|
||||
Can be called as either:
|
||||
>>> Event.dispatch(name, {**kwargs}, **event_init_kwargs)
|
||||
# OR
|
||||
>>> Event.dispatch({name, **kwargs}, **event_init_kwargs)
|
||||
"""
|
||||
event_kwargs: EventDict = event or {}
|
||||
if isinstance(name, dict):
|
||||
event_kwargs.update(name)
|
||||
assert isinstance(event_kwargs, dict), 'must be called as Event.dispatch(name, {**kwargs}) or Event.dispatch({name, **kwargs})'
|
||||
|
||||
event_name: str = name if (isinstance(name, str) and name) else event_kwargs.pop('name')
|
||||
|
||||
new_event = cls(
|
||||
name=event_name,
|
||||
kwargs=event_kwargs,
|
||||
emitted_by=Process.current(),
|
||||
**event_init_kwargs,
|
||||
)
|
||||
new_event.save()
|
||||
return new_event
|
||||
|
||||
def clean(self, *args, **kwargs) -> None:
|
||||
"""Fill and validate all the event fields"""
|
||||
|
||||
# check uuid and deliver_at are set
|
||||
assert self.id, 'Event.id must be set to a valid v4 UUID'
|
||||
if not self.deliver_at:
|
||||
self.deliver_at = self.get_next_timestamp()
|
||||
assert self.deliver_at and (datetime(2024, 12, 8, 12, 0, 0, tzinfo=timezone.utc) < self.deliver_at < datetime(2100, 12, 31, 23, 59, 0, tzinfo=timezone.utc)), (
|
||||
f'Event.deliver_at must be set to a valid UTC datetime (got Event.deliver_at = {self.deliver_at})')
|
||||
|
||||
# if name is not set but it's found in the kwargs, move it out of the kwargs to the name field
|
||||
if 'type' in self.kwargs and ((self.name == self.kwargs['type']) or not self.name):
|
||||
self.name = self.kwargs.pop('type')
|
||||
if 'name' in self.kwargs and ((self.name == self.kwargs['name']) or not self.name):
|
||||
self.name = self.kwargs.pop('name')
|
||||
|
||||
# check name is set and is a valid identifier
|
||||
assert isinstance(self.name, str) and len(self.name) > 3, 'Event.name must be set to a non-empty string'
|
||||
assert self.name.isidentifier(), f'Event.name must be a valid identifier (got Event.name = {self.name})'
|
||||
assert self.name.isupper(), f'Event.name must be in uppercase (got Event.name = {self.name})'
|
||||
|
||||
# check that kwargs keys and values are valid
|
||||
for key, value in self.kwargs.items():
|
||||
assert isinstance(key, str), f'Event kwargs keys can only be strings (got Event.kwargs[{key}: {type(key).__name__}])'
|
||||
assert key not in self._meta.get_fields(), f'Event.kwargs cannot contain "{key}" key (Event.kwargs[{key}] conflicts with with reserved attr Event.{key} = {getattr(self, key)})'
|
||||
assert json.dumps(value, sort_keys=True), f'Event can only contain JSON serializable values (got Event.kwargs[{key}]: {type(value).__name__} = {value})'
|
||||
|
||||
# validate on_success and on_failure are valid event dicts if set
|
||||
if self.on_success:
|
||||
assert isinstance(self.on_success, dict) and self.on_success.get('name', '!invalid').isidentifier(), f'Event.on_success must be a valid event dict (got {self.on_success})'
|
||||
if self.on_failure:
|
||||
assert isinstance(self.on_failure, dict) and self.on_failure.get('name', '!invalid').isidentifier(), f'Event.on_failure must be a valid event dict (got {self.on_failure})'
|
||||
|
||||
# validate mutable fields like claimed_at, claimed_proc, finished_at are set correctly
|
||||
if self.claimed_at:
|
||||
assert self.claimed_proc, f'Event.claimed_at and Event.claimed_proc must be set together (only found Event.claimed_at = {self.claimed_at})'
|
||||
if self.claimed_proc:
|
||||
assert self.claimed_at, f'Event.claimed_at and Event.claimed_proc must be set together (only found Event.claimed_proc = {self.claimed_proc})'
|
||||
if self.finished_at:
|
||||
assert self.claimed_at, f'If Event.finished_at is set, Event.claimed_at and Event.claimed_proc must also be set (Event.claimed_proc = {self.claimed_proc} and Event.claimed_at = {self.claimed_at})'
|
||||
|
||||
# validate error is a non-empty string or None
|
||||
if isinstance(self.error, BaseException):
|
||||
self.error = f'{type(self.error).__name__}: {self.error}'
|
||||
if self.error:
|
||||
assert isinstance(self.error, str) and str(self.error).strip(), f'Event.error must be a non-empty string (got Event.error: {type(self.error).__name__} = {self.error})'
|
||||
else:
|
||||
assert self.error is None, f'Event.error must be None or a non-empty string (got Event.error: {type(self.error).__name__} = {self.error})'
|
||||
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
self.clean()
|
||||
return super().save(*args, **kwargs)
|
||||
|
||||
def reset(self):
|
||||
"""Force-update an event to a pending/unclaimed state (without running any of its handlers or callbacks)"""
|
||||
self.claimed_proc = None
|
||||
self.claimed_at = None
|
||||
self.finished_at = None
|
||||
self.error = None
|
||||
self.save()
|
||||
|
||||
def abort(self):
|
||||
"""Force-update an event to a completed/failed state (without running any of its handlers or callbacks)"""
|
||||
self.claimed_proc = Process.current()
|
||||
self.claimed_at = timezone.now()
|
||||
self.finished_at = timezone.now()
|
||||
self.error = 'Aborted'
|
||||
self.save()
|
||||
|
||||
|
||||
def __repr__(self) -> str:
|
||||
label = f'[{self.name} {self.kwargs}]'
|
||||
if self.is_finished:
|
||||
label += f' ✅'
|
||||
elif self.claimed_proc:
|
||||
label += f' 🏃'
|
||||
return label
|
||||
|
||||
def __str__(self) -> str:
|
||||
return repr(self)
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def is_queued(self):
|
||||
return not self.is_claimed and not self.is_finished
|
||||
|
||||
@property
|
||||
def is_claimed(self):
|
||||
return self.claimed_at is not None
|
||||
|
||||
@property
|
||||
def is_expired(self):
|
||||
if not self.claimed_at:
|
||||
return False
|
||||
|
||||
elapsed_time = timezone.now() - self.claimed_at
|
||||
return elapsed_time > timedelta(seconds=self.timeout)
|
||||
|
||||
@property
|
||||
def is_processing(self):
|
||||
return self.is_claimed and not self.is_finished
|
||||
|
||||
@property
|
||||
def is_finished(self):
|
||||
return self.finished_at is not None
|
||||
|
||||
@property
|
||||
def is_failed(self):
|
||||
return self.is_finished and bool(self.error)
|
||||
|
||||
@property
|
||||
def is_succeeded(self):
|
||||
return self.is_finished and not bool(self.error)
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
"""
|
||||
Allow access to the event kwargs as attributes e.g.
|
||||
Event(name='CRAWL_CREATE', kwargs={'some_key': 'some_val'}).some_key -> 'some_val'
|
||||
"""
|
||||
return self.kwargs.get(key)
|
||||
|
||||
@@ -1,206 +1,287 @@
|
||||
"""
|
||||
Orchestrator for managing worker processes.
|
||||
|
||||
The Orchestrator polls queues for each model type (Crawl, Snapshot, ArchiveResult)
|
||||
and lazily spawns worker processes when there is work to be done.
|
||||
|
||||
Architecture:
|
||||
Orchestrator (main loop, polls queues)
|
||||
├── CrawlWorker subprocess(es)
|
||||
├── SnapshotWorker subprocess(es)
|
||||
└── ArchiveResultWorker subprocess(es)
|
||||
└── Each worker spawns task subprocesses via CLI
|
||||
|
||||
Usage:
|
||||
# Embedded in other commands (exits when done)
|
||||
orchestrator = Orchestrator(exit_on_idle=True)
|
||||
orchestrator.runloop()
|
||||
|
||||
# Daemon mode (runs forever)
|
||||
orchestrator = Orchestrator(exit_on_idle=False)
|
||||
orchestrator.start() # fork and return
|
||||
|
||||
# Or run via CLI
|
||||
archivebox orchestrator [--daemon]
|
||||
"""
|
||||
|
||||
__package__ = 'archivebox.workers'
|
||||
|
||||
import os
|
||||
import time
|
||||
import sys
|
||||
import itertools
|
||||
from typing import Dict, Type, Literal, TYPE_CHECKING
|
||||
from django.utils.functional import classproperty
|
||||
from typing import Type
|
||||
from multiprocessing import Process
|
||||
|
||||
from django.utils import timezone
|
||||
|
||||
import multiprocessing
|
||||
|
||||
|
||||
|
||||
from rich import print
|
||||
|
||||
# from django.db.models import QuerySet
|
||||
|
||||
from django.apps import apps
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .actor import ActorType
|
||||
|
||||
|
||||
multiprocessing.set_start_method('fork', force=True)
|
||||
from .worker import Worker, CrawlWorker, SnapshotWorker, ArchiveResultWorker
|
||||
from .pid_utils import (
|
||||
write_pid_file,
|
||||
remove_pid_file,
|
||||
get_all_worker_pids,
|
||||
cleanup_stale_pid_files,
|
||||
)
|
||||
|
||||
|
||||
class Orchestrator:
|
||||
pid: int
|
||||
idle_count: int = 0
|
||||
actor_types: Dict[str, Type['ActorType']] = {}
|
||||
mode: Literal['thread', 'process'] = 'process'
|
||||
exit_on_idle: bool = True
|
||||
max_concurrent_actors: int = 20
|
||||
"""
|
||||
Manages worker processes by polling queues and spawning workers as needed.
|
||||
|
||||
def __init__(self, actor_types: Dict[str, Type['ActorType']] | None = None, mode: Literal['thread', 'process'] | None=None, exit_on_idle: bool=True, max_concurrent_actors: int=max_concurrent_actors):
|
||||
self.actor_types = actor_types or self.actor_types or self.autodiscover_actor_types()
|
||||
self.mode = mode or self.mode
|
||||
The orchestrator:
|
||||
1. Polls each model queue (Crawl, Snapshot, ArchiveResult)
|
||||
2. If items exist and fewer than MAX_CONCURRENT workers are running, spawns workers
|
||||
3. Monitors worker health and cleans up stale PIDs
|
||||
4. Exits when all queues are empty (unless daemon mode)
|
||||
"""
|
||||
|
||||
WORKER_TYPES: list[Type[Worker]] = [CrawlWorker, SnapshotWorker, ArchiveResultWorker]
|
||||
|
||||
# Configuration
|
||||
POLL_INTERVAL: float = 1.0
|
||||
IDLE_TIMEOUT: int = 3 # Exit after N idle ticks (0 = never exit)
|
||||
MAX_WORKERS_PER_TYPE: int = 4 # Max workers per model type
|
||||
MAX_TOTAL_WORKERS: int = 12 # Max workers across all types
|
||||
|
||||
def __init__(self, exit_on_idle: bool = True):
|
||||
self.exit_on_idle = exit_on_idle
|
||||
self.max_concurrent_actors = max_concurrent_actors
|
||||
|
||||
self.pid: int = os.getpid()
|
||||
self.pid_file = None
|
||||
self.idle_count: int = 0
|
||||
|
||||
def __repr__(self) -> str:
|
||||
label = 'tid' if self.mode == 'thread' else 'pid'
|
||||
return f'[underline]{self.name}[/underline]\\[{label}={self.pid}]'
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.__repr__()
|
||||
|
||||
@classproperty
|
||||
def name(cls) -> str:
|
||||
return cls.__name__ # type: ignore
|
||||
|
||||
# def _fork_as_thread(self):
|
||||
# self.thread = Thread(target=self.runloop)
|
||||
# self.thread.start()
|
||||
# assert self.thread.native_id is not None
|
||||
# return self.thread.native_id
|
||||
|
||||
def _fork_as_process(self):
|
||||
self.process = multiprocessing.Process(target=self.runloop)
|
||||
self.process.start()
|
||||
assert self.process.pid is not None
|
||||
return self.process.pid
|
||||
|
||||
def start(self) -> int:
|
||||
if self.mode == 'thread':
|
||||
# return self._fork_as_thread()
|
||||
raise NotImplementedError('Thread-based orchestrators are disabled for now to reduce codebase complexity')
|
||||
elif self.mode == 'process':
|
||||
return self._fork_as_process()
|
||||
raise ValueError(f'Invalid orchestrator mode: {self.mode}')
|
||||
return f'[underline]Orchestrator[/underline]\\[pid={self.pid}]'
|
||||
|
||||
@classmethod
|
||||
def autodiscover_actor_types(cls) -> Dict[str, Type['ActorType']]:
|
||||
from archivebox.config.django import setup_django
|
||||
setup_django()
|
||||
def is_running(cls) -> bool:
|
||||
"""Check if an orchestrator is already running."""
|
||||
workers = get_all_worker_pids('orchestrator')
|
||||
return len(workers) > 0
|
||||
|
||||
def on_startup(self) -> None:
|
||||
"""Called when orchestrator starts."""
|
||||
self.pid = os.getpid()
|
||||
self.pid_file = write_pid_file('orchestrator', worker_id=0)
|
||||
print(f'[green]👨✈️ {self} STARTED[/green]')
|
||||
|
||||
# returns a Dict of all discovered {actor_type_id: ActorType} across the codebase
|
||||
# override this method in a subclass to customize the actor types that are used
|
||||
# return {'Snapshot': SnapshotWorker, 'ArchiveResult_chrome': ChromeActorType, ...}
|
||||
from crawls.statemachines import CrawlWorker
|
||||
from core.statemachines import SnapshotWorker, ArchiveResultWorker
|
||||
return {
|
||||
'CrawlWorker': CrawlWorker,
|
||||
'SnapshotWorker': SnapshotWorker,
|
||||
'ArchiveResultWorker': ArchiveResultWorker,
|
||||
# look through all models and find all classes that inherit from ActorType
|
||||
# actor_type.__name__: actor_type
|
||||
# for actor_type in abx.pm.hook.get_all_ACTORS_TYPES().values()
|
||||
}
|
||||
# Clean up any stale PID files from previous runs
|
||||
stale_count = cleanup_stale_pid_files()
|
||||
if stale_count:
|
||||
print(f'[yellow]👨✈️ {self} cleaned up {stale_count} stale PID files[/yellow]')
|
||||
|
||||
@classmethod
|
||||
def get_orphaned_objects(cls, all_queues) -> list:
|
||||
# returns a list of objects that are in the queues of all actor types but not in the queues of any other actor types
|
||||
all_queued_ids = itertools.chain(*[queue.values('id', flat=True) for queue in all_queues.values()])
|
||||
orphaned_objects = []
|
||||
for model in apps.get_models():
|
||||
if hasattr(model, 'retry_at'):
|
||||
orphaned_objects.extend(model.objects.filter(retry_at__lt=timezone.now()).exclude(id__in=all_queued_ids))
|
||||
return orphaned_objects
|
||||
|
||||
@classmethod
|
||||
def has_future_objects(cls, all_queues) -> bool:
|
||||
# returns a list of objects that are in the queues of all actor types but not in the queues of any other actor types
|
||||
|
||||
return any(
|
||||
queue.filter(retry_at__gte=timezone.now()).exists()
|
||||
for queue in all_queues.values()
|
||||
)
|
||||
|
||||
def on_startup(self):
|
||||
if self.mode == 'thread':
|
||||
# self.pid = get_native_id()
|
||||
print(f'[green]👨✈️ {self}.on_startup() STARTUP (THREAD)[/green]')
|
||||
raise NotImplementedError('Thread-based orchestrators are disabled for now to reduce codebase complexity')
|
||||
elif self.mode == 'process':
|
||||
self.pid = os.getpid()
|
||||
print(f'[green]👨✈️ {self}.on_startup() STARTUP (PROCESS)[/green]')
|
||||
# abx.pm.hook.on_orchestrator_startup(self)
|
||||
|
||||
def on_shutdown(self, err: BaseException | None = None):
|
||||
print(f'[grey53]👨✈️ {self}.on_shutdown() SHUTTING DOWN[/grey53]', err or '[green](gracefully)[/green]')
|
||||
# abx.pm.hook.on_orchestrator_shutdown(self)
|
||||
def on_shutdown(self, error: BaseException | None = None) -> None:
|
||||
"""Called when orchestrator shuts down."""
|
||||
if self.pid_file:
|
||||
remove_pid_file(self.pid_file)
|
||||
|
||||
def on_tick_started(self, all_queues):
|
||||
# total_pending = sum(queue.count() for queue in all_queues.values())
|
||||
# if total_pending:
|
||||
# print(f'👨✈️ {self}.on_tick_started()', f'total_pending={total_pending}')
|
||||
# abx.pm.hook.on_orchestrator_tick_started(self, actor_types, all_queues)
|
||||
pass
|
||||
if error and not isinstance(error, KeyboardInterrupt):
|
||||
print(f'[red]👨✈️ {self} SHUTDOWN with error:[/red] {type(error).__name__}: {error}')
|
||||
else:
|
||||
print(f'[grey53]👨✈️ {self} SHUTDOWN[/grey53]')
|
||||
|
||||
def on_tick_finished(self, all_queues, all_existing_actors, all_spawned_actors):
|
||||
# if all_spawned_actors:
|
||||
# total_queue_length = sum(queue.count() for queue in all_queues.values())
|
||||
# print(f'[grey53]👨✈️ {self}.on_tick_finished() queue={total_queue_length} existing_actors={len(all_existing_actors)} spawned_actors={len(all_spawned_actors)}[/grey53]')
|
||||
# abx.pm.hook.on_orchestrator_tick_finished(self, actor_types, all_queues)
|
||||
pass
|
||||
|
||||
def on_idle(self, all_queues):
|
||||
# print(f'👨✈️ {self}.on_idle()', f'idle_count={self.idle_count}')
|
||||
print('.', end='', flush=True, file=sys.stderr)
|
||||
# abx.pm.hook.on_orchestrator_idle(self)
|
||||
# check for orphaned objects left behind
|
||||
if self.idle_count == 60:
|
||||
orphaned_objects = self.get_orphaned_objects(all_queues)
|
||||
if orphaned_objects:
|
||||
print('[red]👨✈️ WARNING: some objects may not be processed, no actor has claimed them after 30s:[/red]', orphaned_objects)
|
||||
if self.idle_count > 3 and self.exit_on_idle and not self.has_future_objects(all_queues):
|
||||
raise KeyboardInterrupt('✅ All tasks completed, exiting')
|
||||
|
||||
def runloop(self):
|
||||
from archivebox.config.django import setup_django
|
||||
setup_django()
|
||||
def get_total_worker_count(self) -> int:
|
||||
"""Get total count of running workers across all types."""
|
||||
cleanup_stale_pid_files()
|
||||
return sum(len(W.get_running_workers()) for W in self.WORKER_TYPES)
|
||||
|
||||
def should_spawn_worker(self, WorkerClass: Type[Worker], queue_count: int) -> bool:
|
||||
"""Determine if we should spawn a new worker of the given type."""
|
||||
if queue_count == 0:
|
||||
return False
|
||||
|
||||
# Check per-type limit
|
||||
running_workers = WorkerClass.get_running_workers()
|
||||
if len(running_workers) >= self.MAX_WORKERS_PER_TYPE:
|
||||
return False
|
||||
|
||||
# Check total limit
|
||||
if self.get_total_worker_count() >= self.MAX_TOTAL_WORKERS:
|
||||
return False
|
||||
|
||||
# Check if we already have enough workers for the queue size
|
||||
# Spawn more gradually - don't flood with workers
|
||||
if len(running_workers) > 0 and queue_count <= len(running_workers) * WorkerClass.MAX_CONCURRENT_TASKS:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def spawn_worker(self, WorkerClass: Type[Worker]) -> int | None:
|
||||
"""Spawn a new worker process. Returns PID or None if spawn failed."""
|
||||
try:
|
||||
pid = WorkerClass.start(daemon=False)
|
||||
print(f'[blue]👨✈️ {self} spawned {WorkerClass.name} worker[/blue] pid={pid}')
|
||||
return pid
|
||||
except Exception as e:
|
||||
print(f'[red]👨✈️ {self} failed to spawn {WorkerClass.name} worker:[/red] {e}')
|
||||
return None
|
||||
|
||||
def check_queues_and_spawn_workers(self) -> dict[str, int]:
|
||||
"""
|
||||
Check all queues and spawn workers as needed.
|
||||
Returns dict of queue sizes by worker type.
|
||||
"""
|
||||
queue_sizes = {}
|
||||
|
||||
for WorkerClass in self.WORKER_TYPES:
|
||||
# Get queue for this worker type
|
||||
# Need to instantiate worker to get queue (for model access)
|
||||
worker = WorkerClass(worker_id=-1) # temp instance just for queue access
|
||||
queue = worker.get_queue()
|
||||
queue_count = queue.count()
|
||||
queue_sizes[WorkerClass.name] = queue_count
|
||||
|
||||
# Spawn worker if needed
|
||||
if self.should_spawn_worker(WorkerClass, queue_count):
|
||||
self.spawn_worker(WorkerClass)
|
||||
|
||||
return queue_sizes
|
||||
|
||||
def has_pending_work(self, queue_sizes: dict[str, int]) -> bool:
|
||||
"""Check if any queue has pending work."""
|
||||
return any(count > 0 for count in queue_sizes.values())
|
||||
|
||||
def has_running_workers(self) -> bool:
|
||||
"""Check if any workers are still running."""
|
||||
return self.get_total_worker_count() > 0
|
||||
|
||||
def has_future_work(self) -> bool:
|
||||
"""Check if there's work scheduled for the future (retry_at > now)."""
|
||||
for WorkerClass in self.WORKER_TYPES:
|
||||
worker = WorkerClass(worker_id=-1)
|
||||
Model = worker.get_model()
|
||||
# Check for items not in final state with future retry_at
|
||||
future_count = Model.objects.filter(
|
||||
retry_at__gt=timezone.now()
|
||||
).exclude(
|
||||
status__in=Model.FINAL_STATES
|
||||
).count()
|
||||
if future_count > 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
def on_tick(self, queue_sizes: dict[str, int]) -> None:
|
||||
"""Called each orchestrator tick. Override for custom behavior."""
|
||||
total_queued = sum(queue_sizes.values())
|
||||
total_workers = self.get_total_worker_count()
|
||||
|
||||
if total_queued > 0 or total_workers > 0:
|
||||
# Build status line
|
||||
status_parts = []
|
||||
for WorkerClass in self.WORKER_TYPES:
|
||||
name = WorkerClass.name
|
||||
queued = queue_sizes.get(name, 0)
|
||||
workers = len(WorkerClass.get_running_workers())
|
||||
if queued > 0 or workers > 0:
|
||||
status_parts.append(f'{name}={queued}q/{workers}w')
|
||||
|
||||
if status_parts:
|
||||
print(f'[grey53]👨✈️ {self} tick:[/grey53] {" ".join(status_parts)}')
|
||||
|
||||
def on_idle(self) -> None:
|
||||
"""Called when orchestrator is idle (no work, no workers)."""
|
||||
if self.idle_count == 1:
|
||||
print(f'[grey53]👨✈️ {self} idle, waiting for work...[/grey53]')
|
||||
|
||||
def should_exit(self, queue_sizes: dict[str, int]) -> bool:
|
||||
"""Determine if orchestrator should exit."""
|
||||
if not self.exit_on_idle:
|
||||
return False
|
||||
|
||||
if self.IDLE_TIMEOUT == 0:
|
||||
return False
|
||||
|
||||
# Don't exit if there's pending or future work
|
||||
if self.has_pending_work(queue_sizes):
|
||||
return False
|
||||
|
||||
if self.has_running_workers():
|
||||
return False
|
||||
|
||||
if self.has_future_work():
|
||||
return False
|
||||
|
||||
# Exit after idle timeout
|
||||
return self.idle_count >= self.IDLE_TIMEOUT
|
||||
|
||||
def runloop(self) -> None:
|
||||
"""Main orchestrator loop."""
|
||||
self.on_startup()
|
||||
|
||||
try:
|
||||
while True:
|
||||
all_queues = {
|
||||
actor_type: actor_type.get_queue()
|
||||
for actor_type in self.actor_types.values()
|
||||
}
|
||||
if not all_queues:
|
||||
raise Exception('Failed to find any actor_types to process')
|
||||
|
||||
self.on_tick_started(all_queues)
|
||||
|
||||
all_existing_actors = []
|
||||
all_spawned_actors = []
|
||||
|
||||
for actor_type, queue in all_queues.items():
|
||||
if not queue.exists():
|
||||
continue
|
||||
|
||||
next_obj = queue.first()
|
||||
print()
|
||||
print(f'🏃♂️ {self}.runloop() {actor_type.__name__.ljust(20)} queue={str(queue.count()).ljust(3)} next={next_obj.id if next_obj else "None"} {next_obj.status if next_obj else "None"} {(timezone.now() - next_obj.retry_at).total_seconds() if next_obj and next_obj.retry_at else "None"}')
|
||||
# Check queues and spawn workers
|
||||
queue_sizes = self.check_queues_and_spawn_workers()
|
||||
|
||||
# Track idle state
|
||||
if self.has_pending_work(queue_sizes) or self.has_running_workers():
|
||||
self.idle_count = 0
|
||||
try:
|
||||
existing_actors = actor_type.get_running_actors()
|
||||
all_existing_actors.extend(existing_actors)
|
||||
actors_to_spawn = actor_type.get_actors_to_spawn(queue, existing_actors)
|
||||
can_spawn_num_remaining = self.max_concurrent_actors - len(all_existing_actors) # set max_concurrent_actors=1 to disable multitasking
|
||||
for launch_kwargs in actors_to_spawn[:can_spawn_num_remaining]:
|
||||
new_actor_pid = actor_type.start(mode='process', **launch_kwargs)
|
||||
all_spawned_actors.append(new_actor_pid)
|
||||
except Exception as err:
|
||||
print(f'🏃♂️ ERROR: {self} Failed to get {actor_type} queue & running actors', err)
|
||||
except BaseException:
|
||||
raise
|
||||
|
||||
if not any(queue.exists() for queue in all_queues.values()):
|
||||
self.on_idle(all_queues)
|
||||
self.idle_count += 1
|
||||
time.sleep(0.5)
|
||||
self.on_tick(queue_sizes)
|
||||
else:
|
||||
self.idle_count = 0
|
||||
|
||||
self.on_tick_finished(all_queues, all_existing_actors, all_spawned_actors)
|
||||
time.sleep(1)
|
||||
|
||||
except BaseException as err:
|
||||
if isinstance(err, KeyboardInterrupt):
|
||||
print()
|
||||
else:
|
||||
print(f'\n[red]🏃♂️ {self}.runloop() FATAL:[/red]', err.__class__.__name__, err)
|
||||
self.on_shutdown(err=err)
|
||||
self.idle_count += 1
|
||||
self.on_idle()
|
||||
|
||||
# Check if we should exit
|
||||
if self.should_exit(queue_sizes):
|
||||
print(f'[green]👨✈️ {self} all work complete, exiting[/green]')
|
||||
break
|
||||
|
||||
time.sleep(self.POLL_INTERVAL)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print() # Newline after ^C
|
||||
except BaseException as e:
|
||||
self.on_shutdown(error=e)
|
||||
raise
|
||||
else:
|
||||
self.on_shutdown()
|
||||
|
||||
def start(self) -> int:
|
||||
"""
|
||||
Fork orchestrator as a background process.
|
||||
Returns the PID of the new process.
|
||||
"""
|
||||
def run_orchestrator():
|
||||
from archivebox.config.django import setup_django
|
||||
setup_django()
|
||||
self.runloop()
|
||||
|
||||
proc = Process(target=run_orchestrator, name='orchestrator')
|
||||
proc.start()
|
||||
|
||||
assert proc.pid is not None
|
||||
print(f'[green]👨✈️ Orchestrator started in background[/green] pid={proc.pid}')
|
||||
return proc.pid
|
||||
|
||||
@classmethod
|
||||
def get_or_start(cls, exit_on_idle: bool = True) -> 'Orchestrator':
|
||||
"""
|
||||
Get running orchestrator or start a new one.
|
||||
Used by commands like 'add' to ensure orchestrator is running.
|
||||
"""
|
||||
if cls.is_running():
|
||||
print('[grey53]👨✈️ Orchestrator already running[/grey53]')
|
||||
# Return a placeholder - actual orchestrator is in another process
|
||||
return cls(exit_on_idle=exit_on_idle)
|
||||
|
||||
orchestrator = cls(exit_on_idle=exit_on_idle)
|
||||
return orchestrator
|
||||
|
||||
191
archivebox/workers/pid_utils.py
Normal file
191
archivebox/workers/pid_utils.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""
|
||||
PID file utilities for tracking worker and orchestrator processes.
|
||||
|
||||
PID files are stored in data/tmp/workers/ and contain:
|
||||
- Line 1: PID
|
||||
- Line 2: Worker type (orchestrator, crawl, snapshot, archiveresult)
|
||||
- Line 3: Extractor filter (optional, for archiveresult workers)
|
||||
- Line 4: Started at ISO timestamp
|
||||
"""
|
||||
|
||||
__package__ = 'archivebox.workers'
|
||||
|
||||
import os
|
||||
import signal
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
|
||||
def get_pid_dir() -> Path:
|
||||
"""Get the directory for PID files, creating it if needed."""
|
||||
pid_dir = Path(settings.DATA_DIR) / 'tmp' / 'workers'
|
||||
pid_dir.mkdir(parents=True, exist_ok=True)
|
||||
return pid_dir
|
||||
|
||||
|
||||
def write_pid_file(worker_type: str, worker_id: int = 0, extractor: str | None = None) -> Path:
|
||||
"""
|
||||
Write a PID file for the current process.
|
||||
Returns the path to the PID file.
|
||||
"""
|
||||
pid_dir = get_pid_dir()
|
||||
|
||||
if worker_type == 'orchestrator':
|
||||
pid_file = pid_dir / 'orchestrator.pid'
|
||||
else:
|
||||
pid_file = pid_dir / f'{worker_type}_worker_{worker_id}.pid'
|
||||
|
||||
content = f"{os.getpid()}\n{worker_type}\n{extractor or ''}\n{datetime.now(timezone.utc).isoformat()}\n"
|
||||
pid_file.write_text(content)
|
||||
|
||||
return pid_file
|
||||
|
||||
|
||||
def read_pid_file(path: Path) -> dict | None:
|
||||
"""
|
||||
Read and parse a PID file.
|
||||
Returns dict with pid, worker_type, extractor, started_at or None if invalid.
|
||||
"""
|
||||
try:
|
||||
if not path.exists():
|
||||
return None
|
||||
|
||||
lines = path.read_text().strip().split('\n')
|
||||
if len(lines) < 4:
|
||||
return None
|
||||
|
||||
return {
|
||||
'pid': int(lines[0]),
|
||||
'worker_type': lines[1],
|
||||
'extractor': lines[2] or None,
|
||||
'started_at': datetime.fromisoformat(lines[3]),
|
||||
'pid_file': path,
|
||||
}
|
||||
except (ValueError, IndexError, OSError):
|
||||
return None
|
||||
|
||||
|
||||
def remove_pid_file(path: Path) -> None:
|
||||
"""Remove a PID file if it exists."""
|
||||
try:
|
||||
path.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def is_process_alive(pid: int) -> bool:
|
||||
"""Check if a process with the given PID is still running."""
|
||||
try:
|
||||
os.kill(pid, 0) # Signal 0 doesn't kill, just checks
|
||||
return True
|
||||
except (OSError, ProcessLookupError):
|
||||
return False
|
||||
|
||||
|
||||
def get_all_pid_files() -> list[Path]:
|
||||
"""Get all PID files in the workers directory."""
|
||||
pid_dir = get_pid_dir()
|
||||
return list(pid_dir.glob('*.pid'))
|
||||
|
||||
|
||||
def get_all_worker_pids(worker_type: str | None = None) -> list[dict]:
|
||||
"""
|
||||
Get info about all running workers.
|
||||
Optionally filter by worker_type.
|
||||
"""
|
||||
workers = []
|
||||
|
||||
for pid_file in get_all_pid_files():
|
||||
info = read_pid_file(pid_file)
|
||||
if info is None:
|
||||
continue
|
||||
|
||||
# Skip if process is dead
|
||||
if not is_process_alive(info['pid']):
|
||||
continue
|
||||
|
||||
# Filter by type if specified
|
||||
if worker_type and info['worker_type'] != worker_type:
|
||||
continue
|
||||
|
||||
workers.append(info)
|
||||
|
||||
return workers
|
||||
|
||||
|
||||
def cleanup_stale_pid_files() -> int:
|
||||
"""
|
||||
Remove PID files for processes that are no longer running.
|
||||
Returns the number of stale files removed.
|
||||
"""
|
||||
removed = 0
|
||||
|
||||
for pid_file in get_all_pid_files():
|
||||
info = read_pid_file(pid_file)
|
||||
if info is None:
|
||||
# Invalid PID file, remove it
|
||||
remove_pid_file(pid_file)
|
||||
removed += 1
|
||||
continue
|
||||
|
||||
if not is_process_alive(info['pid']):
|
||||
remove_pid_file(pid_file)
|
||||
removed += 1
|
||||
|
||||
return removed
|
||||
|
||||
|
||||
def get_running_worker_count(worker_type: str) -> int:
|
||||
"""Get the count of running workers of a specific type."""
|
||||
return len(get_all_worker_pids(worker_type))
|
||||
|
||||
|
||||
def get_next_worker_id(worker_type: str) -> int:
|
||||
"""Get the next available worker ID for a given type."""
|
||||
existing_ids = set()
|
||||
|
||||
for pid_file in get_all_pid_files():
|
||||
# Parse worker ID from filename like "snapshot_worker_3.pid"
|
||||
name = pid_file.stem
|
||||
if name.startswith(f'{worker_type}_worker_'):
|
||||
try:
|
||||
worker_id = int(name.split('_')[-1])
|
||||
existing_ids.add(worker_id)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# Find the lowest unused ID
|
||||
next_id = 0
|
||||
while next_id in existing_ids:
|
||||
next_id += 1
|
||||
|
||||
return next_id
|
||||
|
||||
|
||||
def stop_worker(pid: int, graceful: bool = True) -> bool:
|
||||
"""
|
||||
Stop a worker process.
|
||||
If graceful=True, sends SIGTERM first, then SIGKILL after timeout.
|
||||
Returns True if process was stopped.
|
||||
"""
|
||||
if not is_process_alive(pid):
|
||||
return True
|
||||
|
||||
try:
|
||||
if graceful:
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
# Give it a moment to shut down
|
||||
import time
|
||||
for _ in range(10): # Wait up to 1 second
|
||||
time.sleep(0.1)
|
||||
if not is_process_alive(pid):
|
||||
return True
|
||||
# Force kill if still running
|
||||
os.kill(pid, signal.SIGKILL)
|
||||
else:
|
||||
os.kill(pid, signal.SIGKILL)
|
||||
return True
|
||||
except (OSError, ProcessLookupError):
|
||||
return True # Process already dead
|
||||
@@ -1,103 +0,0 @@
|
||||
# import uuid
|
||||
# from functools import wraps
|
||||
# from django.db import connection, transaction
|
||||
# from django.utils import timezone
|
||||
# from huey.exceptions import TaskLockedException
|
||||
|
||||
# from archivebox.config import CONSTANTS
|
||||
|
||||
# class SqliteSemaphore:
|
||||
# def __init__(self, db_path, table_name, name, value=1, timeout=None):
|
||||
# self.db_path = db_path
|
||||
# self.table_name = table_name
|
||||
# self.name = name
|
||||
# self.value = value
|
||||
# self.timeout = timeout or 86400 # Set a max age for lock holders
|
||||
|
||||
# # Ensure the table exists
|
||||
# with connection.cursor() as cursor:
|
||||
# cursor.execute(f"""
|
||||
# CREATE TABLE IF NOT EXISTS {self.table_name} (
|
||||
# id TEXT PRIMARY KEY,
|
||||
# name TEXT,
|
||||
# timestamp DATETIME
|
||||
# )
|
||||
# """)
|
||||
|
||||
# def acquire(self, name=None):
|
||||
# name = name or str(uuid.uuid4())
|
||||
# now = timezone.now()
|
||||
# expiration = now - timezone.timedelta(seconds=self.timeout)
|
||||
|
||||
# with transaction.atomic():
|
||||
# # Remove expired locks
|
||||
# with connection.cursor() as cursor:
|
||||
# cursor.execute(f"""
|
||||
# DELETE FROM {self.table_name}
|
||||
# WHERE name = %s AND timestamp < %s
|
||||
# """, [self.name, expiration])
|
||||
|
||||
# # Try to acquire the lock
|
||||
# with connection.cursor() as cursor:
|
||||
# cursor.execute(f"""
|
||||
# INSERT INTO {self.table_name} (id, name, timestamp)
|
||||
# SELECT %s, %s, %s
|
||||
# WHERE (
|
||||
# SELECT COUNT(*) FROM {self.table_name}
|
||||
# WHERE name = %s
|
||||
# ) < %s
|
||||
# """, [name, self.name, now, self.name, self.value])
|
||||
|
||||
# if cursor.rowcount > 0:
|
||||
# return name
|
||||
|
||||
# # If we couldn't acquire the lock, remove our attempted entry
|
||||
# with connection.cursor() as cursor:
|
||||
# cursor.execute(f"""
|
||||
# DELETE FROM {self.table_name}
|
||||
# WHERE id = %s AND name = %s
|
||||
# """, [name, self.name])
|
||||
|
||||
# return None
|
||||
|
||||
# def release(self, name):
|
||||
# with connection.cursor() as cursor:
|
||||
# cursor.execute(f"""
|
||||
# DELETE FROM {self.table_name}
|
||||
# WHERE id = %s AND name = %s
|
||||
# """, [name, self.name])
|
||||
# return cursor.rowcount > 0
|
||||
|
||||
|
||||
# LOCKS_DB_PATH = CONSTANTS.DATABASE_FILE.parent / 'locks.sqlite3'
|
||||
|
||||
|
||||
# def lock_task_semaphore(db_path, table_name, lock_name, value=1, timeout=None):
|
||||
# """
|
||||
# Lock which can be acquired multiple times (default = 1).
|
||||
|
||||
# NOTE: no provisions are made for blocking, waiting, or notifying. This is
|
||||
# just a lock which can be acquired a configurable number of times.
|
||||
|
||||
# Example:
|
||||
|
||||
# # Allow up to 3 workers to run this task concurrently. If the task is
|
||||
# # locked, retry up to 2 times with a delay of 60s.
|
||||
# @huey.task(retries=2, retry_delay=60)
|
||||
# @lock_task_semaphore('path/to/db.sqlite3', 'semaphore_locks', 'my-lock', 3)
|
||||
# def my_task():
|
||||
# ...
|
||||
# """
|
||||
# sem = SqliteSemaphore(db_path, table_name, lock_name, value, timeout)
|
||||
# def decorator(fn):
|
||||
# @wraps(fn)
|
||||
# def inner(*args, **kwargs):
|
||||
# tid = sem.acquire()
|
||||
# if tid is None:
|
||||
# raise TaskLockedException(f'unable to acquire lock {lock_name}')
|
||||
# try:
|
||||
# return fn(*args, **kwargs)
|
||||
# finally:
|
||||
# sem.release(tid)
|
||||
# return inner
|
||||
# return decorator
|
||||
@@ -63,61 +63,68 @@ def bg_add(add_kwargs, task=None, parent_task_id=None):
|
||||
|
||||
|
||||
@task(queue="commands", context=True)
|
||||
def bg_archive_links(args, kwargs=None, task=None, parent_task_id=None):
|
||||
def bg_archive_snapshots(snapshots, kwargs=None, task=None, parent_task_id=None):
|
||||
"""
|
||||
Queue multiple snapshots for archiving via the state machine system.
|
||||
|
||||
This sets snapshots to 'queued' status so the orchestrator workers pick them up.
|
||||
The actual archiving happens through ArchiveResult.run().
|
||||
"""
|
||||
get_or_create_supervisord_process(daemonize=False)
|
||||
|
||||
from ..extractors import archive_links
|
||||
|
||||
|
||||
from django.utils import timezone
|
||||
from core.models import Snapshot
|
||||
|
||||
if task and parent_task_id:
|
||||
TaskModel.objects.set_parent_task(main_task_id=parent_task_id, sub_task_id=task.id)
|
||||
|
||||
assert args and args[0]
|
||||
assert snapshots
|
||||
kwargs = kwargs or {}
|
||||
|
||||
rough_count = len(args[0])
|
||||
|
||||
process_info = ProcessInfo(task, desc="archive_links", parent_task_id=parent_task_id, total=rough_count)
|
||||
|
||||
result = archive_links(*args, **kwargs)
|
||||
process_info.update(n=rough_count)
|
||||
return result
|
||||
|
||||
rough_count = len(snapshots) if hasattr(snapshots, '__len__') else snapshots.count()
|
||||
process_info = ProcessInfo(task, desc="archive_snapshots", parent_task_id=parent_task_id, total=rough_count)
|
||||
|
||||
@task(queue="commands", context=True)
|
||||
def bg_archive_link(args, kwargs=None,task=None, parent_task_id=None):
|
||||
get_or_create_supervisord_process(daemonize=False)
|
||||
|
||||
from ..extractors import archive_link
|
||||
|
||||
if task and parent_task_id:
|
||||
TaskModel.objects.set_parent_task(main_task_id=parent_task_id, sub_task_id=task.id)
|
||||
# Queue snapshots by setting status to queued with immediate retry_at
|
||||
queued_count = 0
|
||||
for snapshot in snapshots:
|
||||
if hasattr(snapshot, 'id'):
|
||||
# Update snapshot to queued state so workers pick it up
|
||||
Snapshot.objects.filter(id=snapshot.id).update(
|
||||
status=Snapshot.StatusChoices.QUEUED,
|
||||
retry_at=timezone.now(),
|
||||
)
|
||||
queued_count += 1
|
||||
|
||||
assert args and args[0]
|
||||
kwargs = kwargs or {}
|
||||
|
||||
rough_count = len(args[0])
|
||||
|
||||
process_info = ProcessInfo(task, desc="archive_link", parent_task_id=parent_task_id, total=rough_count)
|
||||
|
||||
result = archive_link(*args, **kwargs)
|
||||
process_info.update(n=rough_count)
|
||||
return result
|
||||
process_info.update(n=queued_count)
|
||||
return queued_count
|
||||
|
||||
|
||||
@task(queue="commands", context=True)
|
||||
def bg_archive_snapshot(snapshot, overwrite=False, methods=None, task=None, parent_task_id=None):
|
||||
# get_or_create_supervisord_process(daemonize=False)
|
||||
"""
|
||||
Queue a single snapshot for archiving via the state machine system.
|
||||
|
||||
This sets the snapshot to 'queued' status so the orchestrator workers pick it up.
|
||||
The actual archiving happens through ArchiveResult.run().
|
||||
"""
|
||||
get_or_create_supervisord_process(daemonize=False)
|
||||
|
||||
from django.utils import timezone
|
||||
from core.models import Snapshot
|
||||
|
||||
from ..extractors import archive_link
|
||||
|
||||
if task and parent_task_id:
|
||||
TaskModel.objects.set_parent_task(main_task_id=parent_task_id, sub_task_id=task.id)
|
||||
|
||||
process_info = ProcessInfo(task, desc="archive_link", parent_task_id=parent_task_id, total=1)
|
||||
|
||||
link = snapshot.as_link_with_details()
|
||||
|
||||
result = archive_link(link, overwrite=overwrite, methods=methods)
|
||||
process_info.update(n=1)
|
||||
return result
|
||||
process_info = ProcessInfo(task, desc="archive_snapshot", parent_task_id=parent_task_id, total=1)
|
||||
|
||||
# Queue the snapshot by setting status to queued
|
||||
if hasattr(snapshot, 'id'):
|
||||
Snapshot.objects.filter(id=snapshot.id).update(
|
||||
status=Snapshot.StatusChoices.QUEUED,
|
||||
retry_at=timezone.now(),
|
||||
)
|
||||
process_info.update(n=1)
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
@@ -115,7 +115,7 @@
|
||||
const jobElement = document.createElement('div');
|
||||
jobElement.className = 'job-item';
|
||||
jobElement.innerHTML = `
|
||||
<p><a href="/api/v1/core/any/${job.abid}?api_key={{api_token|default:'NONE PROVIDED BY VIEW'}}"><code>${job.abid}</code></a></p>
|
||||
<p><a href="/api/v1/core/any/${job.id}?api_key={{api_token|default:'NONE PROVIDED BY VIEW'}}"><code>${job.id}</code></a></p>
|
||||
<p>
|
||||
<span class="badge badge-${job.status}">${job.status}</span>
|
||||
<span class="date">♻️ ${formatDate(job.retry_at)}</span>
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
from django.test import TestCase
|
||||
|
||||
# Create your tests here.
|
||||
|
||||
|
||||
class CrawlActorTest(TestCase):
|
||||
|
||||
def test_crawl_creation(self):
|
||||
seed = Seed.objects.create(uri='https://example.com')
|
||||
Event.dispatch('CRAWL_CREATE', {'seed_id': seed.id})
|
||||
|
||||
crawl_actor = CrawlActor()
|
||||
|
||||
output_events = list(crawl_actor.process_next_event())
|
||||
|
||||
assert len(output_events) == 1
|
||||
assert output_events[0].get('name', 'unset') == 'FS_WRITE'
|
||||
assert output_events[0].get('path') == '/tmp/test_crawl/index.json'
|
||||
|
||||
output_events = list(crawl_actor.process_next_event())
|
||||
assert len(output_events) == 1
|
||||
assert output_events[0].get('name', 'unset') == 'CRAWL_CREATED'
|
||||
|
||||
assert Crawl.objects.filter(seed_id=seed.id).exists(), 'Crawl was not created'
|
||||
|
||||
@@ -1,457 +1,330 @@
|
||||
"""
|
||||
Worker classes for processing queue items.
|
||||
|
||||
Workers poll the database for items to process, claim them atomically,
|
||||
and run the state machine tick() to process each item.
|
||||
|
||||
Architecture:
|
||||
Orchestrator (spawns workers)
|
||||
└── Worker (claims items from queue, processes them directly)
|
||||
"""
|
||||
|
||||
__package__ = 'archivebox.workers'
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
import json
|
||||
|
||||
from typing import ClassVar, Iterable, Type
|
||||
import traceback
|
||||
from typing import ClassVar, Any
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from multiprocessing import Process, cpu_count
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from django.utils import timezone
|
||||
from django.conf import settings
|
||||
|
||||
from rich import print
|
||||
|
||||
from django.db import transaction
|
||||
from django.db.models import QuerySet
|
||||
from django.utils import timezone
|
||||
from django.utils.functional import classproperty # type: ignore
|
||||
|
||||
from crawls.models import Crawl
|
||||
from core.models import Snapshot, ArchiveResult
|
||||
|
||||
from workers.models import Event, Process, EventDict
|
||||
from .pid_utils import (
|
||||
write_pid_file,
|
||||
remove_pid_file,
|
||||
get_all_worker_pids,
|
||||
get_next_worker_id,
|
||||
cleanup_stale_pid_files,
|
||||
)
|
||||
|
||||
|
||||
class WorkerType:
|
||||
# static class attributes
|
||||
name: ClassVar[str] # e.g. 'log' or 'filesystem' or 'crawl' or 'snapshot' or 'archiveresult' etc.
|
||||
|
||||
listens_to: ClassVar[str] # e.g. 'LOG_' or 'FS_' or 'CRAWL_' or 'SNAPSHOT_' or 'ARCHIVERESULT_' etc.
|
||||
outputs: ClassVar[list[str]] # e.g. ['LOG_', 'FS_', 'CRAWL_', 'SNAPSHOT_', 'ARCHIVERESULT_'] etc.
|
||||
|
||||
poll_interval: ClassVar[int] = 1 # how long to wait before polling for new events
|
||||
|
||||
@classproperty
|
||||
def event_queue(cls) -> QuerySet[Event]:
|
||||
return Event.objects.filter(name__startswith=cls.listens_to)
|
||||
CPU_COUNT = cpu_count()
|
||||
|
||||
@classmethod
|
||||
def fork(cls, wait_for_first_event=False, exit_on_idle=True) -> Process:
|
||||
cmd = ['archivebox', 'worker', cls.name]
|
||||
if exit_on_idle:
|
||||
cmd.append('--exit-on-idle')
|
||||
if wait_for_first_event:
|
||||
cmd.append('--wait-for-first-event')
|
||||
return Process.create_and_fork(cmd=cmd, actor_type=cls.name)
|
||||
# Registry of worker types by name (defined at bottom, referenced here for _run_worker)
|
||||
WORKER_TYPES: dict[str, type['Worker']] = {}
|
||||
|
||||
@classproperty
|
||||
def processes(cls) -> QuerySet[Process]:
|
||||
return Process.objects.filter(actor_type=cls.name)
|
||||
|
||||
@classmethod
|
||||
def run(cls, wait_for_first_event=False, exit_on_idle=True):
|
||||
def _run_worker(worker_class_name: str, worker_id: int, daemon: bool, **kwargs):
|
||||
"""
|
||||
Module-level function to run a worker. Must be at module level for pickling.
|
||||
"""
|
||||
from archivebox.config.django import setup_django
|
||||
setup_django()
|
||||
|
||||
if wait_for_first_event:
|
||||
event = cls.event_queue.get_next_unclaimed()
|
||||
while not event:
|
||||
time.sleep(cls.poll_interval)
|
||||
event = cls.event_queue.get_next_unclaimed()
|
||||
# Get worker class by name to avoid pickling class objects
|
||||
worker_cls = WORKER_TYPES[worker_class_name]
|
||||
worker = worker_cls(worker_id=worker_id, daemon=daemon, **kwargs)
|
||||
worker.runloop()
|
||||
|
||||
while True:
|
||||
output_events = list(cls.process_next_event()) or list(cls.process_idle_tick()) # process next event, or tick if idle
|
||||
yield from output_events
|
||||
if not output_events:
|
||||
if exit_on_idle:
|
||||
break
|
||||
else:
|
||||
time.sleep(cls.poll_interval)
|
||||
|
||||
@classmethod
|
||||
def process_next_event(cls) -> Iterable[EventDict]:
|
||||
event = cls.event_queue.get_next_unclaimed()
|
||||
output_events = []
|
||||
|
||||
if not event:
|
||||
return []
|
||||
|
||||
cls.mark_event_claimed(event)
|
||||
print(f'{cls.__name__}[{Process.current().pid}] {event}', file=sys.stderr)
|
||||
try:
|
||||
for output_event in cls.receive(event):
|
||||
output_events.append(output_event)
|
||||
yield output_event
|
||||
cls.mark_event_succeeded(event, output_events=output_events)
|
||||
except BaseException as e:
|
||||
cls.mark_event_failed(event, output_events=output_events, error=e)
|
||||
class Worker:
|
||||
"""
|
||||
Base worker class that polls a queue and processes items directly.
|
||||
|
||||
@classmethod
|
||||
def process_idle_tick(cls) -> Iterable[EventDict]:
|
||||
# reset the idle event to be claimed by the current process
|
||||
event, _created = Event.objects.update_or_create(
|
||||
name=f'{cls.listens_to}IDLE',
|
||||
emitted_by=Process.current(),
|
||||
defaults={
|
||||
'deliver_at': timezone.now(),
|
||||
'claimed_proc': None,
|
||||
'claimed_at': None,
|
||||
'finished_at': None,
|
||||
'error': None,
|
||||
'parent': None,
|
||||
},
|
||||
Each item is processed by calling its state machine tick() method.
|
||||
Workers exit when idle for too long (unless daemon mode).
|
||||
"""
|
||||
|
||||
name: ClassVar[str] = 'worker'
|
||||
|
||||
# Configuration (can be overridden by subclasses)
|
||||
MAX_TICK_TIME: ClassVar[int] = 60
|
||||
POLL_INTERVAL: ClassVar[float] = 0.5
|
||||
IDLE_TIMEOUT: ClassVar[int] = 3 # Exit after N idle iterations (set to 0 to never exit)
|
||||
|
||||
def __init__(self, worker_id: int = 0, daemon: bool = False, **kwargs: Any):
|
||||
self.worker_id = worker_id
|
||||
self.daemon = daemon
|
||||
self.pid: int = os.getpid()
|
||||
self.pid_file: Path | None = None
|
||||
self.idle_count: int = 0
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'[underline]{self.__class__.__name__}[/underline]\\[id={self.worker_id}, pid={self.pid}]'
|
||||
|
||||
def get_model(self):
|
||||
"""Get the Django model class. Subclasses must override this."""
|
||||
raise NotImplementedError("Subclasses must implement get_model()")
|
||||
|
||||
def get_queue(self) -> QuerySet:
|
||||
"""Get the queue of objects ready for processing."""
|
||||
Model = self.get_model()
|
||||
return Model.objects.filter(
|
||||
retry_at__lte=timezone.now()
|
||||
).exclude(
|
||||
status__in=Model.FINAL_STATES
|
||||
).order_by('retry_at')
|
||||
|
||||
def claim_next(self):
|
||||
"""
|
||||
Atomically claim the next object from the queue.
|
||||
Returns the claimed object or None if queue is empty or claim failed.
|
||||
"""
|
||||
Model = self.get_model()
|
||||
obj = self.get_queue().first()
|
||||
if obj is None:
|
||||
return None
|
||||
|
||||
# Atomic claim using optimistic locking on retry_at
|
||||
claimed = Model.objects.filter(
|
||||
pk=obj.pk,
|
||||
retry_at=obj.retry_at,
|
||||
).update(
|
||||
retry_at=timezone.now() + timedelta(seconds=self.MAX_TICK_TIME)
|
||||
)
|
||||
|
||||
# then process it like any other event
|
||||
yield from cls.process_next_event()
|
||||
|
||||
if claimed == 1:
|
||||
obj.refresh_from_db()
|
||||
return obj
|
||||
|
||||
return None # Someone else claimed it
|
||||
|
||||
def process_item(self, obj) -> bool:
|
||||
"""
|
||||
Process a single item by calling its state machine tick().
|
||||
Returns True on success, False on failure.
|
||||
Subclasses can override for custom processing.
|
||||
"""
|
||||
try:
|
||||
obj.sm.tick()
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f'[red]{self} error processing {obj.pk}:[/red] {e}')
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def on_startup(self) -> None:
|
||||
"""Called when worker starts."""
|
||||
self.pid = os.getpid()
|
||||
self.pid_file = write_pid_file(self.name, self.worker_id)
|
||||
print(f'[green]{self} STARTED[/green] pid_file={self.pid_file}')
|
||||
|
||||
def on_shutdown(self, error: BaseException | None = None) -> None:
|
||||
"""Called when worker shuts down."""
|
||||
# Remove PID file
|
||||
if self.pid_file:
|
||||
remove_pid_file(self.pid_file)
|
||||
|
||||
if error and not isinstance(error, KeyboardInterrupt):
|
||||
print(f'[red]{self} SHUTDOWN with error:[/red] {type(error).__name__}: {error}')
|
||||
else:
|
||||
print(f'[grey53]{self} SHUTDOWN[/grey53]')
|
||||
|
||||
def should_exit(self) -> bool:
|
||||
"""Check if worker should exit due to idle timeout."""
|
||||
if self.daemon:
|
||||
return False
|
||||
|
||||
if self.IDLE_TIMEOUT == 0:
|
||||
return False
|
||||
|
||||
return self.idle_count >= self.IDLE_TIMEOUT
|
||||
|
||||
def runloop(self) -> None:
|
||||
"""Main worker loop - polls queue, processes items."""
|
||||
self.on_startup()
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Try to claim and process an item
|
||||
obj = self.claim_next()
|
||||
|
||||
if obj is not None:
|
||||
self.idle_count = 0
|
||||
print(f'[blue]{self} processing:[/blue] {obj.pk}')
|
||||
|
||||
start_time = time.time()
|
||||
success = self.process_item(obj)
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
if success:
|
||||
print(f'[green]{self} completed ({elapsed:.1f}s):[/green] {obj.pk}')
|
||||
else:
|
||||
print(f'[red]{self} failed ({elapsed:.1f}s):[/red] {obj.pk}')
|
||||
else:
|
||||
# No work available
|
||||
self.idle_count += 1
|
||||
if self.idle_count == 1:
|
||||
print(f'[grey53]{self} idle, waiting for work...[/grey53]')
|
||||
|
||||
# Check if we should exit
|
||||
if self.should_exit():
|
||||
print(f'[grey53]{self} idle timeout reached, exiting[/grey53]')
|
||||
break
|
||||
|
||||
time.sleep(self.POLL_INTERVAL)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
except BaseException as e:
|
||||
self.on_shutdown(error=e)
|
||||
raise
|
||||
else:
|
||||
self.on_shutdown()
|
||||
|
||||
@classmethod
|
||||
def receive(cls, event: Event) -> Iterable[EventDict]:
|
||||
handler_method = getattr(cls, f'on_{event.name}', None)
|
||||
if handler_method:
|
||||
yield from handler_method(event)
|
||||
else:
|
||||
raise Exception(f'No handler method for event: {event.name}')
|
||||
def start(cls, worker_id: int | None = None, daemon: bool = False, **kwargs: Any) -> int:
|
||||
"""
|
||||
Fork a new worker as a subprocess.
|
||||
Returns the PID of the new process.
|
||||
"""
|
||||
if worker_id is None:
|
||||
worker_id = get_next_worker_id(cls.name)
|
||||
|
||||
@staticmethod
|
||||
def on_IDLE() -> Iterable[EventDict]:
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def mark_event_claimed(event: Event):
|
||||
proc = Process.current()
|
||||
|
||||
with transaction.atomic():
|
||||
claimed = Event.objects.filter(id=event.id, claimed_proc=None, claimed_at=None).update(claimed_proc=proc, claimed_at=timezone.now())
|
||||
event.refresh_from_db()
|
||||
if not claimed:
|
||||
raise Exception(f'Event already claimed by another process: {event.claimed_proc}')
|
||||
|
||||
print(f'{self}.mark_event_claimed(): Claimed {event} ⛏️')
|
||||
|
||||
# process_updated = Process.objects.filter(id=proc.id, active_event=None).update(active_event=event)
|
||||
# if not process_updated:
|
||||
# raise Exception(f'Unable to update process.active_event: {proc}.active_event = {event}')
|
||||
# Use module-level function for pickling compatibility
|
||||
proc = Process(
|
||||
target=_run_worker,
|
||||
args=(cls.name, worker_id, daemon),
|
||||
kwargs=kwargs,
|
||||
name=f'{cls.name}_worker_{worker_id}',
|
||||
)
|
||||
proc.start()
|
||||
|
||||
@staticmethod
|
||||
def mark_event_succeeded(event: Event, output_events: Iterable[EventDict]):
|
||||
event.refresh_from_db()
|
||||
assert event.claimed_proc, f'Cannot mark event as succeeded if it is not claimed by a process: {event}'
|
||||
assert (event.claimed_proc == Process.current()), f'Cannot mark event as succeeded if it claimed by a different process: {event}.claimed_proc = {event.claimed_proc}, current_process = {Process.current()}'
|
||||
|
||||
with transaction.atomic():
|
||||
updated = Event.objects.filter(id=event.id, claimed_proc=event.claimed_proc, claimed_at=event.claimed_at, finished_at=None).update(finished_at=timezone.now())
|
||||
event.refresh_from_db()
|
||||
if not updated:
|
||||
raise Exception(f'Event {event} failed to mark as succeeded, it was modified by another process: {event.claimed_proc}')
|
||||
assert proc.pid is not None
|
||||
return proc.pid
|
||||
|
||||
# process_updated = Process.objects.filter(id=event.claimed_proc.id, active_event=event).update(active_event=None)
|
||||
# if not process_updated:
|
||||
# raise Exception(f'Unable to unset process.active_event: {event.claimed_proc}.active_event = {event}')
|
||||
@classmethod
|
||||
def get_running_workers(cls) -> list[dict]:
|
||||
"""Get info about all running workers of this type."""
|
||||
cleanup_stale_pid_files()
|
||||
return get_all_worker_pids(cls.name)
|
||||
|
||||
# dispatch any output events
|
||||
for output_event in output_events:
|
||||
Event.dispatch(event=output_event, parent=event)
|
||||
|
||||
# trigger any callback events
|
||||
if event.on_success:
|
||||
Event.dispatch(event=event.on_success, parent=event)
|
||||
|
||||
@staticmethod
|
||||
def mark_event_failed(event: Event, output_events: Iterable[EventDict]=(), error: BaseException | None = None):
|
||||
event.refresh_from_db()
|
||||
assert event.claimed_proc, f'Cannot mark event as failed if it is not claimed by a process: {event}'
|
||||
assert (event.claimed_proc == Process.current()), f'Cannot mark event as failed if it claimed by a different process: {event}.claimed_proc = {event.claimed_proc}, current_process = {Process.current()}'
|
||||
|
||||
with transaction.atomic():
|
||||
updated = Event.objects.filter(id=event.id, claimed_proc=event.claimed_proc, claimed_at=event.claimed_at, finished_at=None).update(finished_at=timezone.now(), error=str(error))
|
||||
event.refresh_from_db()
|
||||
if not updated:
|
||||
raise Exception(f'Event {event} failed to mark as failed, it was modified by another process: {event.claimed_proc}')
|
||||
|
||||
# process_updated = Process.objects.filter(id=event.claimed_proc.id, active_event=event).update(active_event=None)
|
||||
# if not process_updated:
|
||||
# raise Exception(f'Unable to unset process.active_event: {event.claimed_proc}.active_event = {event}')
|
||||
|
||||
|
||||
# add dedicated error event to the output events
|
||||
if not event.name.endswith('_ERROR'):
|
||||
output_events = [
|
||||
*output_events,
|
||||
{'name': f'{event.name}_ERROR', 'msg': f'{type(error).__name__}: {error}'},
|
||||
]
|
||||
|
||||
# dispatch any output events
|
||||
for output_event in output_events:
|
||||
Event.dispatch(event=output_event, parent=event)
|
||||
|
||||
# trigger any callback events
|
||||
if event.on_failure:
|
||||
Event.dispatch(event=event.on_failure, parent=event)
|
||||
@classmethod
|
||||
def get_worker_count(cls) -> int:
|
||||
"""Get count of running workers of this type."""
|
||||
return len(cls.get_running_workers())
|
||||
|
||||
|
||||
class CrawlWorker(Worker):
|
||||
"""Worker for processing Crawl objects."""
|
||||
|
||||
name: ClassVar[str] = 'crawl'
|
||||
MAX_TICK_TIME: ClassVar[int] = 60
|
||||
|
||||
def get_model(self):
|
||||
from crawls.models import Crawl
|
||||
return Crawl
|
||||
|
||||
|
||||
class OrchestratorWorker(WorkerType):
|
||||
name = 'orchestrator'
|
||||
listens_to = 'PROC_'
|
||||
outputs = ['PROC_']
|
||||
|
||||
@staticmethod
|
||||
def on_PROC_IDLE() -> Iterable[EventDict]:
|
||||
# look through all Processes that are not yet launched and launch them
|
||||
to_launch = Process.objects.filter(launched_at=None).order_by('created_at').first()
|
||||
if not to_launch:
|
||||
return []
|
||||
|
||||
yield {'name': 'PROC_LAUNCH', 'id': to_launch.id}
|
||||
|
||||
@staticmethod
|
||||
def on_PROC_LAUNCH(event: Event) -> Iterable[EventDict]:
|
||||
process = Process.create_and_fork(**event.kwargs)
|
||||
yield {'name': 'PROC_LAUNCHED', 'process_id': process.id}
|
||||
|
||||
@staticmethod
|
||||
def on_PROC_EXIT(event: Event) -> Iterable[EventDict]:
|
||||
process = Process.objects.get(id=event.process_id)
|
||||
process.kill()
|
||||
yield {'name': 'PROC_KILLED', 'process_id': process.id}
|
||||
|
||||
@staticmethod
|
||||
def on_PROC_KILL(event: Event) -> Iterable[EventDict]:
|
||||
process = Process.objects.get(id=event.process_id)
|
||||
process.kill()
|
||||
yield {'name': 'PROC_KILLED', 'process_id': process.id}
|
||||
class SnapshotWorker(Worker):
|
||||
"""Worker for processing Snapshot objects."""
|
||||
|
||||
name: ClassVar[str] = 'snapshot'
|
||||
MAX_TICK_TIME: ClassVar[int] = 60
|
||||
|
||||
def get_model(self):
|
||||
from core.models import Snapshot
|
||||
return Snapshot
|
||||
|
||||
|
||||
class FileSystemWorker(WorkerType):
|
||||
name = 'filesystem'
|
||||
listens_to = 'FS_'
|
||||
outputs = ['FS_']
|
||||
class ArchiveResultWorker(Worker):
|
||||
"""Worker for processing ArchiveResult objects."""
|
||||
|
||||
@staticmethod
|
||||
def on_FS_IDLE(event: Event) -> Iterable[EventDict]:
|
||||
# check for tmp files that can be deleted
|
||||
for tmp_file in Path('/tmp').glob('archivebox/*'):
|
||||
yield {'name': 'FS_DELETE', 'path': str(tmp_file)}
|
||||
|
||||
@staticmethod
|
||||
def on_FS_WRITE(event: Event) -> Iterable[EventDict]:
|
||||
with open(event.path, 'w') as f:
|
||||
f.write(event.content)
|
||||
yield {'name': 'FS_CHANGED', 'path': event.path}
|
||||
name: ClassVar[str] = 'archiveresult'
|
||||
MAX_TICK_TIME: ClassVar[int] = 120
|
||||
|
||||
@staticmethod
|
||||
def on_FS_APPEND(event: Event) -> Iterable[EventDict]:
|
||||
with open(event.path, 'a') as f:
|
||||
f.write(event.content)
|
||||
yield {'name': 'FS_CHANGED', 'path': event.path}
|
||||
|
||||
@staticmethod
|
||||
def on_FS_DELETE(event: Event) -> Iterable[EventDict]:
|
||||
os.remove(event.path)
|
||||
yield {'name': 'FS_CHANGED', 'path': event.path}
|
||||
|
||||
@staticmethod
|
||||
def on_FS_RSYNC(event: Event) -> Iterable[EventDict]:
|
||||
os.system(f'rsync -av {event.src} {event.dst}')
|
||||
yield {'name': 'FS_CHANGED', 'path': event.dst}
|
||||
def __init__(self, extractor: str | None = None, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
self.extractor = extractor
|
||||
|
||||
def get_model(self):
|
||||
from core.models import ArchiveResult
|
||||
return ArchiveResult
|
||||
|
||||
def get_queue(self) -> QuerySet:
|
||||
"""Get queue of ArchiveResults ready for processing."""
|
||||
from django.db.models import Exists, OuterRef
|
||||
from core.models import ArchiveResult
|
||||
|
||||
qs = super().get_queue()
|
||||
|
||||
if self.extractor:
|
||||
qs = qs.filter(extractor=self.extractor)
|
||||
|
||||
# Exclude ArchiveResults whose Snapshot already has one in progress
|
||||
in_progress = ArchiveResult.objects.filter(
|
||||
snapshot_id=OuterRef('snapshot_id'),
|
||||
status=ArchiveResult.StatusChoices.STARTED,
|
||||
)
|
||||
qs = qs.exclude(Exists(in_progress))
|
||||
|
||||
return qs
|
||||
|
||||
def process_item(self, obj) -> bool:
|
||||
"""Process an ArchiveResult by running its extractor."""
|
||||
try:
|
||||
obj.sm.tick()
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f'[red]{self} error processing {obj.pk}:[/red] {e}')
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def start(cls, worker_id: int | None = None, daemon: bool = False, extractor: str | None = None, **kwargs: Any) -> int:
|
||||
"""Fork a new worker as subprocess with optional extractor filter."""
|
||||
if worker_id is None:
|
||||
worker_id = get_next_worker_id(cls.name)
|
||||
|
||||
# Use module-level function for pickling compatibility
|
||||
proc = Process(
|
||||
target=_run_worker,
|
||||
args=(cls.name, worker_id, daemon),
|
||||
kwargs={'extractor': extractor, **kwargs},
|
||||
name=f'{cls.name}_worker_{worker_id}',
|
||||
)
|
||||
proc.start()
|
||||
|
||||
assert proc.pid is not None
|
||||
return proc.pid
|
||||
|
||||
|
||||
class CrawlWorker(WorkerType):
|
||||
name = 'crawl'
|
||||
listens_to = 'CRAWL_'
|
||||
outputs = ['CRAWL_', 'FS_', 'SNAPSHOT_']
|
||||
|
||||
@staticmethod
|
||||
def on_CRAWL_IDLE(event: Event) -> Iterable[EventDict]:
|
||||
# check for any stale crawls that can be started or sealed
|
||||
stale_crawl = Crawl.objects.filter(retry_at__lt=timezone.now()).first()
|
||||
if not stale_crawl:
|
||||
return []
|
||||
|
||||
if stale_crawl.can_start():
|
||||
yield {'name': 'CRAWL_START', 'id': stale_crawl.id}
|
||||
|
||||
elif stale_crawl.can_seal():
|
||||
yield {'name': 'CRAWL_SEAL', 'id': stale_crawl.id}
|
||||
|
||||
@staticmethod
|
||||
def on_CRAWL_CREATE(event: Event) -> Iterable[EventDict]:
|
||||
crawl, created = Crawl.objects.get_or_create(id=event.id, defaults=event)
|
||||
if created:
|
||||
yield {'name': 'CRAWL_UPDATED', 'crawl_id': crawl.id}
|
||||
|
||||
@staticmethod
|
||||
def on_CRAWL_UPDATE(event: Event) -> Iterable[EventDict]:
|
||||
crawl = Crawl.objects.get(id=event.pop('crawl_id'))
|
||||
diff = {
|
||||
key: val
|
||||
for key, val in event.items()
|
||||
if getattr(crawl, key) != val
|
||||
}
|
||||
if diff:
|
||||
crawl.update(**diff)
|
||||
yield {'name': 'CRAWL_UPDATED', 'crawl_id': crawl.id}
|
||||
|
||||
@staticmethod
|
||||
def on_CRAWL_UPDATED(event: Event) -> Iterable[EventDict]:
|
||||
crawl = Crawl.objects.get(id=event.crawl_id)
|
||||
yield {'name': 'FS_WRITE_SYMLINKS', 'path': crawl.OUTPUT_DIR, 'symlinks': crawl.output_dir_symlinks}
|
||||
|
||||
|
||||
@staticmethod
|
||||
def on_CRAWL_SEAL(event: Event) -> Iterable[EventDict]:
|
||||
crawl = Crawl.objects.filter(id=event.id, status=Crawl.StatusChoices.STARTED).first()
|
||||
if not crawl:
|
||||
return
|
||||
crawl.status = Crawl.StatusChoices.SEALED
|
||||
crawl.save()
|
||||
yield {'name': 'FS_WRITE', 'path': crawl.OUTPUT_DIR / 'index.json', 'content': json.dumps(crawl.as_json(), default=str, indent=4, sort_keys=True)}
|
||||
yield {'name': 'CRAWL_UPDATED', 'crawl_id': crawl.id}
|
||||
|
||||
@staticmethod
|
||||
def on_CRAWL_START(event: Event) -> Iterable[EventDict]:
|
||||
# create root snapshot
|
||||
crawl = Crawl.objects.get(id=event.crawl_id)
|
||||
new_snapshot_id = uuid.uuid4()
|
||||
yield {'name': 'SNAPSHOT_CREATE', 'snapshot_id': new_snapshot_id, 'crawl_id': crawl.id, 'url': crawl.seed.uri}
|
||||
yield {'name': 'SNAPSHOT_START', 'snapshot_id': new_snapshot_id}
|
||||
yield {'name': 'CRAWL_UPDATE', 'crawl_id': crawl.id, 'status': 'started', 'retry_at': None}
|
||||
# Populate the registry
|
||||
WORKER_TYPES.update({
|
||||
'crawl': CrawlWorker,
|
||||
'snapshot': SnapshotWorker,
|
||||
'archiveresult': ArchiveResultWorker,
|
||||
})
|
||||
|
||||
|
||||
class SnapshotWorker(WorkerType):
|
||||
name = 'snapshot'
|
||||
listens_to = 'SNAPSHOT_'
|
||||
outputs = ['SNAPSHOT_', 'FS_']
|
||||
|
||||
@staticmethod
|
||||
def on_SNAPSHOT_IDLE(event: Event) -> Iterable[EventDict]:
|
||||
# check for any snapshots that can be started or sealed
|
||||
snapshot = Snapshot.objects.exclude(status=Snapshot.StatusChoices.SEALED).first()
|
||||
if not snapshot:
|
||||
return []
|
||||
|
||||
if snapshot.can_start():
|
||||
yield {'name': 'SNAPSHOT_START', 'id': snapshot.id}
|
||||
elif snapshot.can_seal():
|
||||
yield {'name': 'SNAPSHOT_SEAL', 'id': snapshot.id}
|
||||
|
||||
@staticmethod
|
||||
def on_SNAPSHOT_CREATE(event: Event) -> Iterable[EventDict]:
|
||||
snapshot = Snapshot.objects.create(id=event.snapshot_id, **event.kwargs)
|
||||
yield {'name': 'FS_WRITE', 'path': snapshot.OUTPUT_DIR / 'index.json', 'content': json.dumps(snapshot.as_json(), default=str, indent=4, sort_keys=True)}
|
||||
yield {'name': 'SNAPSHOT_UPDATED', 'id': snapshot.id}
|
||||
|
||||
@staticmethod
|
||||
def on_SNAPSHOT_SEAL(event: Event) -> Iterable[EventDict]:
|
||||
snapshot = Snapshot.objects.get(id=event.snapshot_id, status=Snapshot.StatusChoices.STARTED)
|
||||
assert snapshot.can_seal()
|
||||
snapshot.status = Snapshot.StatusChoices.SEALED
|
||||
snapshot.save()
|
||||
yield {'name': 'FS_WRITE', 'path': snapshot.OUTPUT_DIR / 'index.json', 'content': json.dumps(snapshot.as_json(), default=str, indent=4, sort_keys=True)}
|
||||
yield {'name': 'SNAPSHOT_UPDATED', 'id': snapshot.id}
|
||||
|
||||
@staticmethod
|
||||
def on_SNAPSHOT_START(event: Event) -> Iterable[EventDict]:
|
||||
snapshot = Snapshot.objects.get(id=event.snapshot_id, status=Snapshot.StatusChoices.QUEUED)
|
||||
assert snapshot.can_start()
|
||||
|
||||
# create pending archiveresults for each extractor
|
||||
for extractor in snapshot.get_extractors():
|
||||
new_archiveresult_id = uuid.uuid4()
|
||||
yield {'name': 'ARCHIVERESULT_CREATE', 'id': new_archiveresult_id, 'snapshot_id': snapshot.id, 'extractor': extractor.name}
|
||||
yield {'name': 'ARCHIVERESULT_START', 'id': new_archiveresult_id}
|
||||
|
||||
snapshot.status = Snapshot.StatusChoices.STARTED
|
||||
snapshot.save()
|
||||
yield {'name': 'FS_WRITE', 'path': snapshot.OUTPUT_DIR / 'index.json', 'content': json.dumps(snapshot.as_json(), default=str, indent=4, sort_keys=True)}
|
||||
yield {'name': 'SNAPSHOT_UPDATED', 'id': snapshot.id}
|
||||
|
||||
|
||||
|
||||
class ArchiveResultWorker(WorkerType):
|
||||
name = 'archiveresult'
|
||||
listens_to = 'ARCHIVERESULT_'
|
||||
outputs = ['ARCHIVERESULT_', 'FS_']
|
||||
|
||||
@staticmethod
|
||||
def on_ARCHIVERESULT_UPDATE(event: Event) -> Iterable[EventDict]:
|
||||
archiveresult = ArchiveResult.objects.get(id=event.id)
|
||||
diff = {
|
||||
key: val
|
||||
for key, val in event.items()
|
||||
if getattr(archiveresult, key) != val
|
||||
}
|
||||
if diff:
|
||||
archiveresult.update(**diff)
|
||||
yield {'name': 'ARCHIVERESULT_UPDATED', 'id': archiveresult.id}
|
||||
|
||||
@staticmethod
|
||||
def on_ARCHIVERESULT_UPDATED(event: Event) -> Iterable[EventDict]:
|
||||
archiveresult = ArchiveResult.objects.get(id=event.id)
|
||||
yield {'name': 'FS_WRITE_SYMLINKS', 'path': archiveresult.OUTPUT_DIR, 'symlinks': archiveresult.output_dir_symlinks}
|
||||
|
||||
@staticmethod
|
||||
def on_ARCHIVERESULT_CREATE(event: Event) -> Iterable[EventDict]:
|
||||
archiveresult, created = ArchiveResult.objects.get_or_create(id=event.pop('archiveresult_id'), defaults=event)
|
||||
if created:
|
||||
yield {'name': 'ARCHIVERESULT_UPDATE', 'id': archiveresult.id}
|
||||
else:
|
||||
diff = {
|
||||
key: val
|
||||
for key, val in event.items()
|
||||
if getattr(archiveresult, key) != val
|
||||
}
|
||||
assert not diff, f'ArchiveResult {archiveresult.id} already exists and has different values, cannot create on top of it: {diff}'
|
||||
|
||||
@staticmethod
|
||||
def on_ARCHIVERESULT_SEAL(event: Event) -> Iterable[EventDict]:
|
||||
archiveresult = ArchiveResult.objects.get(id=event.id, status=ArchiveResult.StatusChoices.STARTED)
|
||||
assert archiveresult.can_seal()
|
||||
yield {'name': 'ARCHIVERESULT_UPDATE', 'id': archiveresult.id, 'status': 'sealed'}
|
||||
|
||||
@staticmethod
|
||||
def on_ARCHIVERESULT_START(event: Event) -> Iterable[EventDict]:
|
||||
archiveresult = ArchiveResult.objects.get(id=event.id, status=ArchiveResult.StatusChoices.QUEUED)
|
||||
|
||||
yield {
|
||||
'name': 'SHELL_EXEC',
|
||||
'cmd': archiveresult.EXTRACTOR.get_cmd(),
|
||||
'cwd': archiveresult.OUTPUT_DIR,
|
||||
'on_exit': {
|
||||
'name': 'ARCHIVERESULT_SEAL',
|
||||
'id': archiveresult.id,
|
||||
},
|
||||
}
|
||||
|
||||
archiveresult.status = ArchiveResult.StatusChoices.STARTED
|
||||
archiveresult.save()
|
||||
yield {'name': 'FS_WRITE', 'path': archiveresult.OUTPUT_DIR / 'index.json', 'content': json.dumps(archiveresult.as_json(), default=str, indent=4, sort_keys=True)}
|
||||
yield {'name': 'ARCHIVERESULT_UPDATED', 'id': archiveresult.id}
|
||||
|
||||
@staticmethod
|
||||
def on_ARCHIVERESULT_IDLE(event: Event) -> Iterable[EventDict]:
|
||||
stale_archiveresult = ArchiveResult.objects.exclude(status__in=[ArchiveResult.StatusChoices.SUCCEEDED, ArchiveResult.StatusChoices.FAILED]).first()
|
||||
if not stale_archiveresult:
|
||||
return []
|
||||
if stale_archiveresult.can_start():
|
||||
yield {'name': 'ARCHIVERESULT_START', 'id': stale_archiveresult.id}
|
||||
if stale_archiveresult.can_seal():
|
||||
yield {'name': 'ARCHIVERESULT_SEAL', 'id': stale_archiveresult.id}
|
||||
|
||||
|
||||
WORKER_TYPES = [
|
||||
OrchestratorWorker,
|
||||
FileSystemWorker,
|
||||
CrawlWorker,
|
||||
SnapshotWorker,
|
||||
ArchiveResultWorker,
|
||||
]
|
||||
|
||||
def get_worker_type(name: str) -> Type[WorkerType]:
|
||||
for worker_type in WORKER_TYPES:
|
||||
matches_verbose_name = (worker_type.name == name)
|
||||
matches_class_name = (worker_type.__name__.lower() == name.lower())
|
||||
matches_listens_to = (worker_type.listens_to.strip('_').lower() == name.strip('_').lower())
|
||||
if matches_verbose_name or matches_class_name or matches_listens_to:
|
||||
return worker_type
|
||||
raise Exception(f'Worker type not found: {name}')
|
||||
def get_worker_class(name: str) -> type[Worker]:
|
||||
"""Get worker class by name."""
|
||||
if name not in WORKER_TYPES:
|
||||
raise ValueError(f'Unknown worker type: {name}. Valid types: {list(WORKER_TYPES.keys())}')
|
||||
return WORKER_TYPES[name]
|
||||
|
||||
Reference in New Issue
Block a user