mirror of
https://github.com/ArchiveBox/ArchiveBox.git
synced 2026-04-06 07:47:53 +10:00
cleanup archivebox tests
This commit is contained in:
@@ -3,8 +3,10 @@
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import tempfile
|
||||
import textwrap
|
||||
import time
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
|
||||
@@ -14,6 +16,9 @@ from archivebox.uuid_compat import uuid7
|
||||
|
||||
pytest_plugins = ["archivebox.tests.fixtures"]
|
||||
|
||||
SESSION_DATA_DIR = Path(tempfile.mkdtemp(prefix="archivebox-pytest-session-")).resolve()
|
||||
os.environ.setdefault("DATA_DIR", str(SESSION_DATA_DIR))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# CLI Helpers (defined before fixtures that use them)
|
||||
@@ -82,6 +87,36 @@ def run_archivebox_cmd(
|
||||
# Fixtures
|
||||
# =============================================================================
|
||||
|
||||
@pytest.fixture(autouse=True)
def isolate_test_runtime(tmp_path):
    """
    Run each pytest test from an isolated temp cwd and restore env mutations.

    The maintained pytest suite lives under ``archivebox/tests``. Many of those
    CLI tests shell out without passing ``cwd=`` explicitly, so the safest
    contract is that every test starts in its own temp directory and any
    in-process ``os.environ`` edits are rolled back afterwards.

    We intentionally clear ``DATA_DIR`` for the body of each test so subprocess
    tests that rely on cwd keep working. During collection/import time we still
    seed a separate session-scoped temp ``DATA_DIR`` above so any ArchiveBox
    config imported before this fixture runs never points at the repo root.
    """
    # Snapshot the state we must restore; copy() captures env before mutation.
    original_cwd = Path.cwd()
    original_env = os.environ.copy()
    os.chdir(tmp_path)
    # Force DATA_DIR resolution from cwd rather than the session-level seed.
    os.environ.pop("DATA_DIR", None)
    try:
        yield
    finally:
        # Restore cwd, then replace the environment wholesale: clear()+update()
        # removes keys the test added, not just the ones it overwrote.
        os.chdir(original_cwd)
        os.environ.clear()
        os.environ.update(original_env)
|
||||
|
||||
|
||||
def pytest_sessionfinish(session, exitstatus):
    """pytest hook: delete the session-scoped temp DATA_DIR after the whole run.

    ignore_errors keeps the cleanup best-effort so a failed rmtree can never
    turn a passing test session into a failing one.
    """
    shutil.rmtree(SESSION_DATA_DIR, ignore_errors=True)
|
||||
|
||||
@pytest.fixture
|
||||
def isolated_data_dir(tmp_path):
|
||||
"""
|
||||
|
||||
@@ -7,8 +7,11 @@ import pytest
|
||||
|
||||
@pytest.fixture
def process(tmp_path):
    """Initialize a fresh ArchiveBox data dir inside tmp_path.

    Returns the completed ``archivebox init`` CompletedProcess so tests can
    inspect its output/returncode.
    """
    # Some tests shell out without cwd=, so also make tmp_path the cwd.
    os.chdir(tmp_path)
    # Run `archivebox init` exactly once: the previous version invoked it
    # twice back-to-back and discarded the first result, doubling setup time.
    process = subprocess.run(
        ['archivebox', 'init'],
        capture_output=True,
        cwd=tmp_path,
    )
    return process
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
36
archivebox/tests/test_api_cli_schedule.py
Normal file
36
archivebox/tests/test_api_cli_schedule.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from io import StringIO
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.test import RequestFactory, TestCase
|
||||
|
||||
from archivebox.api.v1_cli import ScheduleCommandSchema, cli_schedule
|
||||
from archivebox.crawls.models import CrawlSchedule
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
class CLIScheduleAPITests(TestCase):
    """Exercise the `schedule` CLI command exposed through the v1 API handler."""

    def setUp(self):
        # The endpoint reads request.user, so a real authenticated user is needed.
        self.user = User.objects.create_user(
            username='api-user',
            password='testpass123',
            email='api@example.com',
        )

    def test_schedule_api_creates_schedule(self):
        """A valid schedule payload should persist exactly one CrawlSchedule."""
        request = RequestFactory().post('/api/v1/cli/schedule')
        request.user = self.user
        # CLI handlers write their human-readable output to request.stdout/stderr;
        # capture both in memory so nothing leaks to the test runner's console.
        request.stdout = StringIO()
        request.stderr = StringIO()
        args = ScheduleCommandSchema(
            every='daily',
            import_path='https://example.com/feed.xml',
            quiet=True,
        )

        response = cli_schedule(request, args)

        self.assertTrue(response['success'])
        self.assertEqual(response['result_format'], 'json')
        self.assertEqual(CrawlSchedule.objects.count(), 1)
        self.assertEqual(len(response['result']['created_schedule_ids']), 1)
|
||||
@@ -1,13 +1,10 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Integration tests for archivebox extract command."""
|
||||
"""Tests for archivebox extract input handling and pipelines."""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sqlite3
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
|
||||
def test_extract_runs_on_snapshot_id(tmp_path, process, disable_extractors_dict):
|
||||
@@ -271,7 +268,3 @@ class TestExtractCLI:
|
||||
|
||||
# Should show warning about no snapshots or exit normally (empty input)
|
||||
assert result.returncode == 0 or 'No' in result.stderr
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
377
archivebox/tests/test_cli_piping.py
Normal file
377
archivebox/tests/test_cli_piping.py
Normal file
@@ -0,0 +1,377 @@
|
||||
"""
|
||||
Tests for JSONL piping contracts and `archivebox run` / `archivebox orchestrator`.
|
||||
|
||||
This file covers both:
|
||||
- low-level JSONL/stdin parsing behavior that makes CLI piping work
|
||||
- subprocess integration for the supported records `archivebox run` consumes
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
import sys
|
||||
import uuid
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
|
||||
from archivebox.tests.conftest import (
|
||||
create_test_url,
|
||||
parse_jsonl_output,
|
||||
run_archivebox_cmd,
|
||||
)
|
||||
|
||||
|
||||
PIPE_TEST_ENV = {
|
||||
"PLUGINS": "favicon",
|
||||
"SAVE_FAVICON": "True",
|
||||
"USE_COLOR": "False",
|
||||
"SHOW_PROGRESS": "False",
|
||||
}
|
||||
|
||||
|
||||
class MockTTYStringIO(StringIO):
    """In-memory text stream whose isatty() answer is fixed at construction.

    Lets tests simulate both piped stdin (``is_tty=False``) and an interactive
    terminal (``is_tty=True``) without attaching a real TTY.
    """

    def __init__(self, initial_value: str = "", *, is_tty: bool):
        super().__init__(initial_value)
        self._tty_flag = is_tty

    def isatty(self) -> bool:
        # Report the configured answer instead of the real stream state.
        return self._tty_flag
|
||||
|
||||
|
||||
def _stdout_lines(stdout: str) -> list[str]:
    """Return stdout's non-blank lines (whitespace-only lines are dropped)."""
    kept = []
    for raw_line in stdout.splitlines():
        if raw_line.strip():
            kept.append(raw_line)
    return kept
|
||||
|
||||
|
||||
def _assert_stdout_is_jsonl_only(stdout: str) -> None:
    """Fail unless stdout is non-empty and every non-blank line is a JSON object."""
    lines = [ln for ln in stdout.splitlines() if ln.strip()]
    assert lines, "Expected stdout to contain JSONL records"
    for line in lines:
        # Any stray human-readable output (progress bars, log lines) fails here.
        assert line.lstrip().startswith("{"), stdout
|
||||
|
||||
|
||||
def _sqlite_param(value: object) -> object:
    """Normalize a UUID-formatted string to its 32-char hex form for DB queries.

    Non-strings and strings that are not valid UUIDs pass through unchanged.
    """
    if not isinstance(value, str):
        return value
    try:
        parsed = uuid.UUID(value)
    except ValueError:
        # Not a UUID at all -- use the raw string as the query parameter.
        return value
    return parsed.hex
|
||||
|
||||
|
||||
def _db_value(data_dir: Path, sql: str, params: tuple[object, ...] = ()) -> object | None:
    """Run a single-row query against data_dir/index.sqlite3, returning column 0.

    UUID-shaped string params are normalized to compact hex via _sqlite_param().
    Returns None when the query matches no rows.
    """
    normalized = tuple(_sqlite_param(param) for param in params)
    connection = sqlite3.connect(data_dir / "index.sqlite3")
    try:
        row = connection.execute(sql, normalized).fetchone()
    finally:
        connection.close()
    return None if row is None else row[0]
|
||||
|
||||
|
||||
def test_parse_line_accepts_supported_piping_inputs():
    """The JSONL parser should normalize the input forms CLI pipes accept."""
    from archivebox.misc.jsonl import TYPE_CRAWL, TYPE_SNAPSHOT, parse_line

    # Blank lines, comments, and non-archivable input are silently ignored.
    assert parse_line("") is None
    assert parse_line(" ") is None
    assert parse_line("# comment") is None
    assert parse_line("not-a-url") is None
    assert parse_line("ftp://example.com") is None

    # Bare https:// and file:// URLs become minimal Snapshot records.
    plain_url = parse_line("https://example.com")
    assert plain_url == {"type": TYPE_SNAPSHOT, "url": "https://example.com"}

    file_url = parse_line("file:///tmp/example.txt")
    assert file_url == {"type": TYPE_SNAPSHOT, "url": "file:///tmp/example.txt"}

    # Full JSONL records pass through with their extra fields intact.
    snapshot_json = parse_line('{"type":"Snapshot","url":"https://example.com","tags":"tag1,tag2"}')
    assert snapshot_json is not None
    assert snapshot_json["type"] == TYPE_SNAPSHOT
    assert snapshot_json["tags"] == "tag1,tag2"

    crawl_json = parse_line('{"type":"Crawl","id":"abc123","urls":"https://example.com","max_depth":1}')
    assert crawl_json is not None
    assert crawl_json["type"] == TYPE_CRAWL
    assert crawl_json["id"] == "abc123"
    assert crawl_json["max_depth"] == 1

    # Bare UUIDs (dashed or compact hex) are treated as Snapshot ids.
    snapshot_id = "01234567-89ab-cdef-0123-456789abcdef"
    parsed_id = parse_line(snapshot_id)
    assert parsed_id == {"type": TYPE_SNAPSHOT, "id": snapshot_id}

    compact_snapshot_id = "0123456789abcdef0123456789abcdef"
    compact_parsed_id = parse_line(compact_snapshot_id)
    assert compact_parsed_id == {"type": TYPE_SNAPSHOT, "id": compact_snapshot_id}
|
||||
|
||||
|
||||
def test_read_args_or_stdin_handles_args_stdin_and_mixed_jsonl():
    """Piping helpers should consume args, structured JSONL, and pass-through records."""
    from archivebox.misc.jsonl import TYPE_CRAWL, read_args_or_stdin

    # Positional CLI args are consumed directly, in order.
    records = list(read_args_or_stdin(("https://example1.com", "https://example2.com")))
    assert [record["url"] for record in records] == ["https://example1.com", "https://example2.com"]

    # With no args, records come from the (non-TTY) stream: plain URLs, typed
    # JSONL records, pass-through records of other types, and bare ids mix freely.
    stdin_records = list(
        read_args_or_stdin(
            (),
            stream=MockTTYStringIO(
                'https://plain-url.com\n'
                '{"type":"Snapshot","url":"https://jsonl-url.com","tags":"test"}\n'
                '{"type":"Tag","id":"tag-1","name":"example"}\n'
                '01234567-89ab-cdef-0123-456789abcdef\n'
                'not valid json\n',
                is_tty=False,
            ),
        )
    )
    # The unparseable "not valid json" line is dropped; the other four survive.
    assert len(stdin_records) == 4
    assert stdin_records[0]["url"] == "https://plain-url.com"
    assert stdin_records[1]["url"] == "https://jsonl-url.com"
    assert stdin_records[1]["tags"] == "test"
    assert stdin_records[2]["type"] == "Tag"
    assert stdin_records[2]["name"] == "example"
    assert stdin_records[3]["id"] == "01234567-89ab-cdef-0123-456789abcdef"

    # Crawl records keep their embedded newline-separated URL lists intact.
    crawl_records = list(
        read_args_or_stdin(
            (),
            stream=MockTTYStringIO(
                '{"type":"Crawl","id":"crawl-1","urls":"https://example.com\\nhttps://foo.com"}\n',
                is_tty=False,
            ),
        )
    )
    assert len(crawl_records) == 1
    assert crawl_records[0]["type"] == TYPE_CRAWL
    assert crawl_records[0]["id"] == "crawl-1"

    # An interactive TTY stdin must never be read implicitly.
    tty_records = list(read_args_or_stdin((), stream=MockTTYStringIO("https://example.com", is_tty=True)))
    assert tty_records == []
|
||||
|
||||
|
||||
def test_collect_urls_from_plugins_reads_only_parser_outputs(tmp_path):
    """Parser extractor `urls.jsonl` outputs should be discoverable for recursive piping."""
    from archivebox.hooks import collect_urls_from_plugins

    # Two plugin output dirs that each wrote a urls.jsonl...
    (tmp_path / "wget").mkdir()
    (tmp_path / "wget" / "urls.jsonl").write_text(
        '{"url":"https://wget-link-1.com"}\n'
        '{"url":"https://wget-link-2.com"}\n',
        encoding="utf-8",
    )
    (tmp_path / "parse_html_urls").mkdir()
    (tmp_path / "parse_html_urls" / "urls.jsonl").write_text(
        '{"url":"https://html-link-1.com"}\n'
        '{"url":"https://html-link-2.com","title":"HTML Link 2"}\n',
        encoding="utf-8",
    )
    # ...and one plugin dir without urls.jsonl, which must contribute nothing.
    (tmp_path / "screenshot").mkdir()

    urls = collect_urls_from_plugins(tmp_path)
    assert len(urls) == 4
    # Each discovered record is tagged with the plugin dir it came from.
    assert {url["plugin"] for url in urls} == {"wget", "parse_html_urls"}
    titled = [url for url in urls if url.get("title") == "HTML Link 2"]
    assert len(titled) == 1
    assert titled[0]["url"] == "https://html-link-2.com"

    # A missing snapshot dir yields an empty list rather than raising.
    assert collect_urls_from_plugins(tmp_path / "nonexistent") == []
|
||||
|
||||
|
||||
def test_crawl_create_stdout_pipes_into_run(initialized_archive):
    """`archivebox crawl create | archivebox run` should queue and materialize snapshots."""
    url = create_test_url()

    create_stdout, create_stderr, create_code = run_archivebox_cmd(
        ["crawl", "create", url],
        data_dir=initialized_archive,
    )
    assert create_code == 0, create_stderr
    # stdout must be pure JSONL so it can be piped to the next command.
    _assert_stdout_is_jsonl_only(create_stdout)

    crawl = next(record for record in parse_jsonl_output(create_stdout) if record.get("type") == "Crawl")

    # Feed crawl-create's stdout straight into `run`, like a shell pipe would.
    run_stdout, run_stderr, run_code = run_archivebox_cmd(
        ["run"],
        stdin=create_stdout,
        data_dir=initialized_archive,
        timeout=120,
        env=PIPE_TEST_ENV,
    )
    assert run_code == 0, run_stderr
    _assert_stdout_is_jsonl_only(run_stdout)

    # `run` must re-emit the crawl record it consumed.
    run_records = parse_jsonl_output(run_stdout)
    assert any(record.get("type") == "Crawl" and record.get("id") == crawl["id"] for record in run_records)

    # At least one snapshot row should now exist for this crawl in the index DB.
    snapshot_count = _db_value(
        initialized_archive,
        "SELECT COUNT(*) FROM core_snapshot WHERE crawl_id = ?",
        (crawl["id"],),
    )
    assert isinstance(snapshot_count, int)
    assert snapshot_count >= 1
|
||||
|
||||
|
||||
def test_snapshot_list_stdout_pipes_into_run(initialized_archive):
    """`archivebox snapshot list | archivebox run` should requeue listed snapshots."""
    url = create_test_url()

    create_stdout, create_stderr, create_code = run_archivebox_cmd(
        ["snapshot", "create", url],
        data_dir=initialized_archive,
    )
    assert create_code == 0, create_stderr
    snapshot = next(record for record in parse_jsonl_output(create_stdout) if record.get("type") == "Snapshot")

    list_stdout, list_stderr, list_code = run_archivebox_cmd(
        ["snapshot", "list", "--status=queued", f"--url__icontains={snapshot['id']}"],
        data_dir=initialized_archive,
    )
    # Fall back to a plain URL filter when the id-based filter matched nothing
    # (the snapshot may already have progressed past the "queued" status).
    if list_code != 0 or not parse_jsonl_output(list_stdout):
        list_stdout, list_stderr, list_code = run_archivebox_cmd(
            ["snapshot", "list", f"--url__icontains={url}"],
            data_dir=initialized_archive,
        )
    assert list_code == 0, list_stderr
    _assert_stdout_is_jsonl_only(list_stdout)

    run_stdout, run_stderr, run_code = run_archivebox_cmd(
        ["run"],
        stdin=list_stdout,
        data_dir=initialized_archive,
        timeout=120,
        env=PIPE_TEST_ENV,
    )
    assert run_code == 0, run_stderr
    _assert_stdout_is_jsonl_only(run_stdout)

    # `run` must re-emit the snapshot record it consumed from the pipe.
    run_records = parse_jsonl_output(run_stdout)
    assert any(record.get("type") == "Snapshot" and record.get("id") == snapshot["id"] for record in run_records)

    # End state: the piped snapshot should have been processed to completion.
    snapshot_status = _db_value(
        initialized_archive,
        "SELECT status FROM core_snapshot WHERE id = ?",
        (snapshot["id"],),
    )
    assert snapshot_status == "sealed"
|
||||
|
||||
|
||||
def test_archiveresult_list_stdout_pipes_into_orchestrator_alias(initialized_archive):
    """`archivebox archiveresult list | archivebox orchestrator` should preserve clean JSONL stdout."""
    url = create_test_url()

    snapshot_stdout, snapshot_stderr, snapshot_code = run_archivebox_cmd(
        ["snapshot", "create", url],
        data_dir=initialized_archive,
    )
    assert snapshot_code == 0, snapshot_stderr

    # Pipe the new snapshot record into archiveresult-create to queue a favicon job.
    ar_create_stdout, ar_create_stderr, ar_create_code = run_archivebox_cmd(
        ["archiveresult", "create", "--plugin=favicon"],
        stdin=snapshot_stdout,
        data_dir=initialized_archive,
    )
    assert ar_create_code == 0, ar_create_stderr

    created_records = parse_jsonl_output(ar_create_stdout)
    archiveresult = next(record for record in created_records if record.get("type") == "ArchiveResult")

    list_stdout, list_stderr, list_code = run_archivebox_cmd(
        ["archiveresult", "list", "--plugin=favicon"],
        data_dir=initialized_archive,
    )
    assert list_code == 0, list_stderr
    _assert_stdout_is_jsonl_only(list_stdout)

    # `orchestrator` is a deprecated alias for `run`: it must still consume the
    # pipe, keep stdout pure JSONL, and print its rename notice on stderr only.
    orchestrator_stdout, orchestrator_stderr, orchestrator_code = run_archivebox_cmd(
        ["orchestrator"],
        stdin=list_stdout,
        data_dir=initialized_archive,
        timeout=120,
        env=PIPE_TEST_ENV,
    )
    assert orchestrator_code == 0, orchestrator_stderr
    _assert_stdout_is_jsonl_only(orchestrator_stdout)
    assert "renamed to `archivebox run`" in orchestrator_stderr

    run_records = parse_jsonl_output(orchestrator_stdout)
    assert any(
        record.get("type") == "ArchiveResult" and record.get("id") == archiveresult["id"]
        for record in run_records
    )
|
||||
|
||||
|
||||
def test_binary_create_stdout_pipes_into_run(initialized_archive):
    """`archivebox binary create | archivebox run` should queue the binary record for processing."""
    # Register the test runner's own python as the binary so no download/install
    # of a real tool is required.
    create_stdout, create_stderr, create_code = run_archivebox_cmd(
        ["binary", "create", "--name=python3", f"--abspath={sys.executable}", "--version=test"],
        data_dir=initialized_archive,
    )
    assert create_code == 0, create_stderr
    _assert_stdout_is_jsonl_only(create_stdout)

    binary = next(record for record in parse_jsonl_output(create_stdout) if record.get("type") == "Binary")

    run_stdout, run_stderr, run_code = run_archivebox_cmd(
        ["run"],
        stdin=create_stdout,
        data_dir=initialized_archive,
        timeout=120,
    )
    assert run_code == 0, run_stderr
    _assert_stdout_is_jsonl_only(run_stdout)

    # `run` must re-emit the binary record it consumed from the pipe.
    run_records = parse_jsonl_output(run_stdout)
    assert any(record.get("type") == "Binary" and record.get("id") == binary["id"] for record in run_records)

    # Either outcome is acceptable: still queued, or install completed.
    status = _db_value(
        initialized_archive,
        "SELECT status FROM machine_binary WHERE id = ?",
        (binary["id"],),
    )
    assert status in {"queued", "installed"}
|
||||
|
||||
|
||||
def test_multi_stage_pipeline_into_run(initialized_archive):
    """`crawl create | snapshot create | archiveresult create | run` should preserve JSONL and finish work."""
    url = create_test_url()

    # Stage 1: create a crawl record.
    crawl_stdout, crawl_stderr, crawl_code = run_archivebox_cmd(
        ["crawl", "create", url],
        data_dir=initialized_archive,
    )
    assert crawl_code == 0, crawl_stderr
    _assert_stdout_is_jsonl_only(crawl_stdout)

    # Stage 2: pipe the crawl into snapshot-create.
    snapshot_stdout, snapshot_stderr, snapshot_code = run_archivebox_cmd(
        ["snapshot", "create"],
        stdin=crawl_stdout,
        data_dir=initialized_archive,
    )
    assert snapshot_code == 0, snapshot_stderr
    _assert_stdout_is_jsonl_only(snapshot_stdout)

    # Stage 3: pipe the snapshot into archiveresult-create (favicon job).
    archiveresult_stdout, archiveresult_stderr, archiveresult_code = run_archivebox_cmd(
        ["archiveresult", "create", "--plugin=favicon"],
        stdin=snapshot_stdout,
        data_dir=initialized_archive,
    )
    assert archiveresult_code == 0, archiveresult_stderr
    _assert_stdout_is_jsonl_only(archiveresult_stdout)

    # Stage 4: hand the accumulated records to `run` to execute the queued work.
    run_stdout, run_stderr, run_code = run_archivebox_cmd(
        ["run"],
        stdin=archiveresult_stdout,
        data_dir=initialized_archive,
        timeout=120,
        env=PIPE_TEST_ENV,
    )
    assert run_code == 0, run_stderr
    _assert_stdout_is_jsonl_only(run_stdout)

    run_records = parse_jsonl_output(run_stdout)
    snapshot = next(record for record in run_records if record.get("type") == "Snapshot")
    assert any(record.get("type") == "ArchiveResult" for record in run_records)

    # End-to-end: the snapshot should be fully processed in the index DB.
    snapshot_status = _db_value(
        initialized_archive,
        "SELECT status FROM core_snapshot WHERE id = ?",
        (snapshot["id"],),
    )
    assert snapshot_status == "sealed"
|
||||
@@ -1,156 +0,0 @@
|
||||
import json as pyjson
|
||||
import sqlite3
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
from .fixtures import disable_extractors_dict, process
|
||||
|
||||
FIXTURES = (disable_extractors_dict, process)
|
||||
|
||||
|
||||
def _find_snapshot_dir(data_dir: Path, snapshot_id: str) -> Path | None:
|
||||
candidates = {snapshot_id}
|
||||
if len(snapshot_id) == 32:
|
||||
candidates.add(f"{snapshot_id[:8]}-{snapshot_id[8:12]}-{snapshot_id[12:16]}-{snapshot_id[16:20]}-{snapshot_id[20:]}")
|
||||
elif len(snapshot_id) == 36 and "-" in snapshot_id:
|
||||
candidates.add(snapshot_id.replace("-", ""))
|
||||
|
||||
for needle in candidates:
|
||||
for path in data_dir.rglob(needle):
|
||||
if path.is_dir():
|
||||
return path
|
||||
return None
|
||||
|
||||
|
||||
def _latest_snapshot_dir(data_dir: Path) -> Path:
    """Return the output directory of the most recently created snapshot.

    Reads the newest snapshot id from index.sqlite3 and resolves it to an
    on-disk directory via _find_snapshot_dir(); fails the calling test if
    either lookup comes up empty.
    """
    conn = sqlite3.connect(data_dir / "index.sqlite3")
    try:
        snapshot_id = conn.execute(
            "SELECT id FROM core_snapshot ORDER BY created_at DESC LIMIT 1"
        ).fetchone()
    finally:
        conn.close()

    # fetchone() yields a 1-tuple, or None when the table is empty.
    assert snapshot_id is not None, "Expected a snapshot to be created"
    snapshot_dir = _find_snapshot_dir(data_dir, str(snapshot_id[0]))
    assert snapshot_dir is not None, f"Snapshot output directory not found for {snapshot_id[0]}"
    return snapshot_dir
|
||||
|
||||
|
||||
def _latest_plugin_result(data_dir: Path, plugin: str) -> tuple[str, str, dict]:
    """Fetch (snapshot_id, status, output_files) of the newest ArchiveResult for *plugin*.

    ``output_files`` is decoded from its JSON text column when necessary and is
    always returned as a dict (possibly empty). Fails the calling test when no
    row exists for the plugin.
    """
    db = sqlite3.connect(data_dir / "index.sqlite3")
    try:
        latest = db.execute(
            "SELECT snapshot_id, status, output_files FROM core_archiveresult "
            "WHERE plugin = ? ORDER BY created_at DESC LIMIT 1",
            (plugin,),
        ).fetchone()
    finally:
        db.close()

    assert latest is not None, f"Expected an ArchiveResult row for plugin={plugin}"
    snapshot_id, status, raw_outputs = latest
    if isinstance(raw_outputs, str):
        # Stored as JSON text; an empty string decodes to an empty dict.
        raw_outputs = pyjson.loads(raw_outputs or "{}")
    return str(snapshot_id), str(status), raw_outputs or {}
|
||||
|
||||
|
||||
def _plugin_output_paths(data_dir: Path, plugin: str) -> list[Path]:
    """Assert the latest ArchiveResult for *plugin* succeeded and return its output paths.

    Every path recorded in output_files must exist on disk under
    ``<snapshot_dir>/<plugin>/``; any missing file fails the calling test.
    """
    snapshot_id, status, output_files = _latest_plugin_result(data_dir, plugin)
    assert status == "succeeded", f"Expected {plugin} ArchiveResult to succeed, got {status}"
    assert output_files, f"Expected {plugin} ArchiveResult to record output_files"

    snapshot_dir = _find_snapshot_dir(data_dir, snapshot_id)
    assert snapshot_dir is not None, f"Snapshot output directory not found for {snapshot_id}"

    # output_files keys are paths relative to the plugin's own output dir.
    plugin_dir = snapshot_dir / plugin
    output_paths = [plugin_dir / rel_path for rel_path in output_files.keys()]
    missing_paths = [path for path in output_paths if not path.exists()]
    assert not missing_paths, f"Expected plugin outputs to exist on disk, missing: {missing_paths}"
    return output_paths
|
||||
|
||||
|
||||
def _archivebox_env(base_env: dict, data_dir: Path) -> dict:
    """Build a subprocess env for archivebox with a per-test TMP_DIR under /tmp.

    Returns a copy of *base_env*; the caller's dict is never mutated. The short
    /tmp path keeps derived paths small, and unix-socket enforcement is relaxed
    for constrained CI environments.
    """
    scratch = Path("/tmp") / f"abx-{data_dir.name}"
    scratch.mkdir(parents=True, exist_ok=True)
    env = dict(base_env)
    env.update({
        "TMP_DIR": str(scratch),
        "ARCHIVEBOX_ALLOW_NO_UNIX_SOCKETS": "true",
    })
    return env
|
||||
|
||||
|
||||
def test_singlefile_works(tmp_path, process, disable_extractors_dict):
    """`archivebox add --plugins=singlefile` should produce an HTML output file."""
    # The `process` fixture has already chdir'd into and initialized the data dir.
    data_dir = Path.cwd()
    env = _archivebox_env(disable_extractors_dict, data_dir)
    env.update({"SAVE_SINGLEFILE": "true"})
    add_process = subprocess.run(
        ['archivebox', 'add', '--plugins=singlefile', 'https://example.com'],
        capture_output=True,
        text=True,
        env=env,
        timeout=900,  # real network fetch + headless browser can be slow in CI
    )
    assert add_process.returncode == 0, add_process.stderr
    output_files = _plugin_output_paths(data_dir, "singlefile")
    assert any(path.suffix in (".html", ".htm") for path in output_files)
|
||||
|
||||
def test_readability_works(tmp_path, process, disable_extractors_dict):
    """`archivebox add --plugins=singlefile,readability` should produce a readability HTML file."""
    # The `process` fixture has already chdir'd into and initialized the data dir.
    data_dir = Path.cwd()
    env = _archivebox_env(disable_extractors_dict, data_dir)
    # NOTE(review): singlefile appears to be enabled as readability's input
    # source -- confirm against the readability extractor's requirements.
    env.update({"SAVE_SINGLEFILE": "true", "SAVE_READABILITY": "true"})
    add_process = subprocess.run(
        ['archivebox', 'add', '--plugins=singlefile,readability', 'https://example.com'],
        capture_output=True,
        text=True,
        env=env,
        timeout=900,  # real network fetch + headless browser can be slow in CI
    )
    assert add_process.returncode == 0, add_process.stderr
    output_files = _plugin_output_paths(data_dir, "readability")
    assert any(path.suffix in (".html", ".htm") for path in output_files)
|
||||
|
||||
def test_htmltotext_works(tmp_path, process, disable_extractors_dict):
    """`archivebox add --plugins=wget,htmltotext` should produce a plain-text output file."""
    # The `process` fixture has already chdir'd into and initialized the data dir.
    data_dir = Path.cwd()
    env = _archivebox_env(disable_extractors_dict, data_dir)
    # NOTE(review): wget appears to be enabled to provide the HTML input for
    # htmltotext -- confirm against the extractor's requirements.
    env.update({"SAVE_WGET": "true", "SAVE_HTMLTOTEXT": "true"})
    add_process = subprocess.run(
        ['archivebox', 'add', '--plugins=wget,htmltotext', 'https://example.com'],
        capture_output=True,
        text=True,
        env=env,
        timeout=900,  # real network fetch can be slow in CI
    )
    assert add_process.returncode == 0, add_process.stderr
    output_files = _plugin_output_paths(data_dir, "htmltotext")
    assert any(path.suffix == ".txt" for path in output_files)
|
||||
|
||||
def test_use_node_false_disables_readability_and_singlefile(tmp_path, process, disable_extractors_dict):
    """USE_NODE=false must skip the node-based extractors even when they are enabled."""
    env = _archivebox_env(disable_extractors_dict, Path.cwd())
    env.update({"SAVE_READABILITY": "true", "SAVE_DOM": "true", "SAVE_SINGLEFILE": "true", "USE_NODE": "false"})
    add_process = subprocess.run(
        ['archivebox', 'add', '--plugins=readability,dom,singlefile', 'https://example.com'],
        capture_output=True,
        text=True,  # decode like the sibling tests instead of manual .decode()
        env=env,
        timeout=900,  # every other add-invocation in this file is bounded; unbounded runs can hang CI forever
    )
    # The node-based extractors must not have been invoked.
    output_str = add_process.stdout
    assert "> singlefile" not in output_str
    assert "> readability" not in output_str
|
||||
|
||||
def test_headers_retrieved(tmp_path, process, disable_extractors_dict):
    """The headers plugin should record the response headers as a JSON file."""
    # The `process` fixture has already chdir'd into and initialized the data dir.
    data_dir = Path.cwd()
    env = _archivebox_env(disable_extractors_dict, data_dir)
    env.update({"SAVE_HEADERS": "true"})
    add_process = subprocess.run(
        ['archivebox', 'add', '--plugins=headers', 'https://example.com'],
        capture_output=True,
        text=True,
        env=env,
        timeout=900,  # real network fetch can be slow in CI
    )
    assert add_process.returncode == 0, add_process.stderr
    output_files = _plugin_output_paths(data_dir, "headers")
    output_file = next((path for path in output_files if path.suffix == ".json"), None)
    assert output_file is not None, f"Expected headers output_files to include a JSON file, got: {output_files}"
    with open(output_file, 'r', encoding='utf-8') as f:
        headers = pyjson.load(f)
    # The output schema varies: headers may live under either top-level key.
    response_headers = headers.get("response_headers") or headers.get("headers") or {}
    assert isinstance(response_headers, dict), f"Expected response_headers dict, got: {response_headers!r}"
    # Header-name casing depends on the server, so accept either spelling.
    assert 'Content-Type' in response_headers or 'content-type' in response_headers
|
||||
641
archivebox/tests/test_machine_models.py
Normal file
641
archivebox/tests/test_machine_models.py
Normal file
@@ -0,0 +1,641 @@
|
||||
"""
|
||||
Unit tests for machine module models: Machine, NetworkInterface, Binary, Process.
|
||||
|
||||
Tests cover:
|
||||
1. Machine model creation and current() method
|
||||
2. NetworkInterface model and network detection
|
||||
3. Binary model lifecycle and state machine
|
||||
4. Process model lifecycle, hierarchy, and state machine
|
||||
5. JSONL serialization/deserialization
|
||||
6. Manager methods
|
||||
7. Process tracking methods (replacing pid_utils)
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from typing import cast
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from django.test import TestCase
|
||||
from django.utils import timezone
|
||||
|
||||
from archivebox.machine.models import (
|
||||
BinaryManager,
|
||||
Machine,
|
||||
NetworkInterface,
|
||||
Binary,
|
||||
Process,
|
||||
BinaryMachine,
|
||||
ProcessMachine,
|
||||
MACHINE_RECHECK_INTERVAL,
|
||||
PID_REUSE_WINDOW,
|
||||
)
|
||||
|
||||
|
||||
class TestMachineModel(TestCase):
    """Test the Machine model."""

    def setUp(self):
        """Reset cached machine between tests."""
        # Machine.current() caches its result in a module-level global; clear
        # it so every test starts from a cold cache.
        import archivebox.machine.models as models
        models._CURRENT_MACHINE = None

    def test_machine_current_creates_machine(self):
        """Machine.current() should create a machine if none exists."""
        machine = Machine.current()

        self.assertIsNotNone(machine)
        self.assertIsNotNone(machine.id)
        self.assertIsNotNone(machine.guid)
        self.assertEqual(machine.hostname, os.uname().nodename)
        self.assertIn(machine.os_family, ['linux', 'darwin', 'windows', 'freebsd'])

    def test_machine_current_returns_cached(self):
        """Machine.current() should return cached machine within recheck interval."""
        machine1 = Machine.current()
        machine2 = Machine.current()

        self.assertEqual(machine1.id, machine2.id)

    def test_machine_current_refreshes_after_interval(self):
        """Machine.current() should refresh after recheck interval."""
        import archivebox.machine.models as models

        machine1 = Machine.current()

        # Manually expire the cache by modifying modified_at
        machine1.modified_at = timezone.now() - timedelta(seconds=MACHINE_RECHECK_INTERVAL + 1)
        machine1.save()
        models._CURRENT_MACHINE = machine1

        machine2 = Machine.current()

        # Should have fetched/updated the machine (same GUID)
        self.assertEqual(machine1.guid, machine2.guid)

    def test_machine_from_jsonl_update(self):
        """Machine.from_json() should update machine config."""
        Machine.current()  # Ensure machine exists
        record = {
            'config': {
                'WGET_BINARY': '/usr/bin/wget',
            },
        }

        result = Machine.from_json(record)

        self.assertIsNotNone(result)
        # Bare assert narrows the Optional for static type checkers.
        assert result is not None
        self.assertEqual(result.config.get('WGET_BINARY'), '/usr/bin/wget')

    def test_machine_from_jsonl_invalid(self):
        """Machine.from_json() should return None for invalid records."""
        result = Machine.from_json({'invalid': 'record'})
        self.assertIsNone(result)

    def test_machine_manager_current(self):
        """Machine.objects.current() should return current machine."""
        machine = Machine.current()
        self.assertIsNotNone(machine)
        self.assertEqual(machine.id, Machine.current().id)
|
||||
|
||||
|
||||
class TestNetworkInterfaceModel(TestCase):
    """Test the NetworkInterface model."""

    def setUp(self):
        """Reset cached interface between tests."""
        # Both caches are module-level globals; clear them so each test starts cold.
        import archivebox.machine.models as models
        models._CURRENT_MACHINE = None
        models._CURRENT_INTERFACE = None

    def test_networkinterface_current_creates_interface(self):
        """NetworkInterface.current() should create an interface if none exists."""
        interface = NetworkInterface.current()

        self.assertIsNotNone(interface)
        self.assertIsNotNone(interface.id)
        self.assertIsNotNone(interface.machine)  # tied to the current Machine row
        self.assertIsNotNone(interface.ip_local)

    def test_networkinterface_current_returns_cached(self):
        """NetworkInterface.current() should return cached interface within recheck interval."""
        interface1 = NetworkInterface.current()
        interface2 = NetworkInterface.current()

        self.assertEqual(interface1.id, interface2.id)

    def test_networkinterface_manager_current(self):
        """NetworkInterface.objects.current() should return current interface."""
        interface = NetworkInterface.current()
        self.assertIsNotNone(interface)
|
||||
|
||||
|
||||
class TestBinaryModel(TestCase):
    """Exercises creation, validity, lookup, and JSON import of Binary rows."""

    def setUp(self):
        """Drop cached machine/binaries and pin a Machine row for the tests."""
        import archivebox.machine.models as models
        models._CURRENT_MACHINE = None
        models._CURRENT_BINARIES = {}
        self.machine = Machine.current()

    def test_binary_creation(self):
        """A new Binary defaults to QUEUED status and is not yet valid."""
        wget = Binary.objects.create(
            machine=self.machine,
            name='wget',
            binproviders='apt,brew,env',
        )

        self.assertIsNotNone(wget.id)
        self.assertEqual(wget.name, 'wget')
        self.assertEqual(wget.status, Binary.StatusChoices.QUEUED)
        # No abspath/version yet, so the binary cannot be valid.
        self.assertFalse(wget.is_valid)

    def test_binary_is_valid(self):
        """Setting both abspath and version makes a Binary valid."""
        wget = Binary.objects.create(
            machine=self.machine,
            name='wget',
            abspath='/usr/bin/wget',
            version='1.21',
        )

        self.assertTrue(wget.is_valid)

    def test_binary_manager_get_valid_binary(self):
        """get_valid_binary() should skip invalid rows and return the valid one."""
        # One invalid row (no abspath) and one fully-populated row, same name.
        Binary.objects.create(machine=self.machine, name='wget')
        Binary.objects.create(
            machine=self.machine,
            name='wget',
            abspath='/usr/bin/wget',
            version='1.21',
        )

        found = cast(BinaryManager, Binary.objects).get_valid_binary('wget')

        self.assertIsNotNone(found)
        assert found is not None  # narrow for type checkers
        self.assertEqual(found.abspath, '/usr/bin/wget')

    def test_binary_update_and_requeue(self):
        """update_and_requeue() should save the given fields and bump modified_at."""
        row = Binary.objects.create(machine=self.machine, name='test')
        before = row.modified_at

        row.update_and_requeue(
            status=Binary.StatusChoices.QUEUED,
            retry_at=timezone.now() + timedelta(seconds=60),
        )

        row.refresh_from_db()
        self.assertEqual(row.status, Binary.StatusChoices.QUEUED)
        self.assertGreater(row.modified_at, before)

    def test_binary_from_json_preserves_install_args_overrides(self):
        """Canonical per-provider install_args overrides survive from_json() untouched."""
        overrides = {
            'apt': {'install_args': ['chromium']},
            'npm': {'install_args': 'puppeteer'},
            'custom': {'install_args': ['bash', '-lc', 'echo ok']},
        }
        payload = {
            'name': 'chrome',
            'binproviders': 'apt,npm,custom',
            'overrides': overrides,
        }

        imported = Binary.from_json(payload)

        self.assertIsNotNone(imported)
        assert imported is not None  # narrow for type checkers
        self.assertEqual(imported.overrides, overrides)

    def test_binary_from_json_does_not_coerce_legacy_override_shapes(self):
        """Legacy non-dict provider overrides are stored as-is, not rewritten."""
        overrides = {
            'apt': ['chromium'],
            'npm': 'puppeteer',
        }
        payload = {
            'name': 'chrome',
            'binproviders': 'apt,npm',
            'overrides': overrides,
        }

        imported = Binary.from_json(payload)

        self.assertIsNotNone(imported)
        assert imported is not None  # narrow for type checkers
        self.assertEqual(imported.overrides, overrides)

    def test_binary_from_json_prefers_published_readability_package(self):
        """from_json() swaps readability's npm git URL for the published package name."""
        imported = Binary.from_json({
            'name': 'readability-extractor',
            'binproviders': 'env,npm',
            'overrides': {
                'npm': {
                    'install_args': ['https://github.com/ArchiveBox/readability-extractor'],
                },
            },
        })

        self.assertIsNotNone(imported)
        assert imported is not None  # narrow for type checkers
        expected = {
            'npm': {
                'install_args': ['readability-extractor'],
            },
        }
        self.assertEqual(imported.overrides, expected)
||||
|
||||
class TestBinaryStateMachine(TestCase):
    """Test the BinaryMachine state machine."""

    def setUp(self):
        """Create a machine and a queued Binary to drive the state machine."""
        import archivebox.machine.models as models
        models._CURRENT_MACHINE = None  # bust the module-level cache between tests
        self.machine = Machine.current()
        self.binary = Binary.objects.create(
            machine=self.machine,
            name='test-binary',
            binproviders='env',
        )

    def test_binary_state_machine_initial_state(self):
        """BinaryMachine should start in the QUEUED state."""
        sm = BinaryMachine(self.binary)
        self.assertEqual(sm.current_state_value, Binary.StatusChoices.QUEUED)

    def test_binary_state_machine_can_start(self):
        """BinaryMachine.can_install() should require a name and binproviders."""
        sm = BinaryMachine(self.binary)
        self.assertTrue(sm.can_install())

        # Clearing binproviders should make the binary uninstallable.
        self.binary.binproviders = ''
        self.binary.save()
        sm = BinaryMachine(self.binary)
        self.assertFalse(sm.can_install())
||||
|
||||
|
||||
class TestProcessModel(TestCase):
    """Exercises basic CRUD behavior of the Process model."""

    def setUp(self):
        """Drop module-level caches and pin a Machine row for the tests."""
        import archivebox.machine.models as models
        models._CURRENT_MACHINE = None
        models._CURRENT_PROCESS = None
        self.machine = Machine.current()

    def test_process_creation(self):
        """A freshly created Process is QUEUED with unset runtime fields."""
        proc = Process.objects.create(
            machine=self.machine,
            cmd=['echo', 'hello'],
            pwd='/tmp',
        )

        self.assertIsNotNone(proc.id)
        self.assertEqual(proc.cmd, ['echo', 'hello'])
        self.assertEqual(proc.status, Process.StatusChoices.QUEUED)
        # Runtime bookkeeping fields stay empty until the process is spawned.
        self.assertIsNone(proc.pid)
        self.assertIsNone(proc.exit_code)

    def test_process_to_jsonl(self):
        """Process.to_json() should round-trip the fields it was created with."""
        proc = Process.objects.create(
            machine=self.machine,
            cmd=['echo', 'hello'],
            pwd='/tmp',
            timeout=60,
        )
        serialized = proc.to_json()

        expected = {
            'type': 'Process',
            'cmd': ['echo', 'hello'],
            'pwd': '/tmp',
            'timeout': 60,
        }
        for key, value in expected.items():
            self.assertEqual(serialized[key], value)

    def test_process_update_and_requeue(self):
        """update_and_requeue() should persist the given field values."""
        proc = Process.objects.create(machine=self.machine, cmd=['test'])

        proc.update_and_requeue(
            status=Process.StatusChoices.RUNNING,
            pid=12345,
            started_at=timezone.now(),
        )

        proc.refresh_from_db()
        self.assertEqual(proc.status, Process.StatusChoices.RUNNING)
        self.assertEqual(proc.pid, 12345)
        self.assertIsNotNone(proc.started_at)
||||
|
||||
|
||||
class TestProcessCurrent(TestCase):
    """Covers Process.current() and process-type detection from sys.argv."""

    def setUp(self):
        """Invalidate the cached machine/process singletons."""
        import archivebox.machine.models as models
        models._CURRENT_MACHINE = None
        models._CURRENT_PROCESS = None

    def test_process_current_creates_record(self):
        """Process.current() should create a RUNNING row for this PID."""
        current = Process.current()

        self.assertIsNotNone(current)
        self.assertEqual(current.pid, os.getpid())
        self.assertEqual(current.status, Process.StatusChoices.RUNNING)
        self.assertIsNotNone(current.machine)
        self.assertIsNotNone(current.started_at)

    def test_process_current_caches(self):
        """Back-to-back calls should hit the in-memory cache and share an id."""
        first = Process.current()
        second = Process.current()

        self.assertEqual(first.id, second.id)

    def test_process_detect_type_orchestrator(self):
        """'manage orchestrator' argv should classify as ORCHESTRATOR."""
        with patch('sys.argv', ['archivebox', 'manage', 'orchestrator']):
            detected = Process._detect_process_type()
        self.assertEqual(detected, Process.TypeChoices.ORCHESTRATOR)

    def test_process_detect_type_cli(self):
        """Plain archivebox subcommands should classify as CLI."""
        with patch('sys.argv', ['archivebox', 'add', 'http://example.com']):
            detected = Process._detect_process_type()
        self.assertEqual(detected, Process.TypeChoices.CLI)

    def test_process_detect_type_worker(self):
        """'python -m *_worker' argv should classify as WORKER."""
        with patch('sys.argv', ['python', '-m', 'crawl_worker']):
            detected = Process._detect_process_type()
        self.assertEqual(detected, Process.TypeChoices.WORKER)
||||
|
||||
class TestProcessHierarchy(TestCase):
    """Covers the parent/child tree structure of Process rows."""

    def setUp(self):
        """Invalidate the cached machine and grab a fresh one."""
        import archivebox.machine.models as models
        models._CURRENT_MACHINE = None
        self.machine = Machine.current()

    def _running(self, **extra):
        """Create a RUNNING Process on self.machine, with optional field overrides."""
        fields = {
            'machine': self.machine,
            'status': Process.StatusChoices.RUNNING,
            'started_at': timezone.now(),
        }
        fields.update(extra)
        return Process.objects.create(**fields)

    def test_process_parent_child(self):
        """Process should track parent/child relationships in both directions."""
        parent = self._running(process_type=Process.TypeChoices.CLI, pid=1)
        child = self._running(
            parent=parent,
            process_type=Process.TypeChoices.WORKER,
            pid=2,
        )

        self.assertEqual(child.parent, parent)
        self.assertIn(child, parent.children.all())

    def test_process_root(self):
        """Process.root should resolve to the top of the hierarchy from any depth."""
        root = self._running(process_type=Process.TypeChoices.CLI)
        child = self._running(parent=root)
        grandchild = self._running(parent=child)

        # Every node in the chain resolves to the same root; the root is its own root.
        self.assertEqual(grandchild.root, root)
        self.assertEqual(child.root, root)
        self.assertEqual(root.root, root)

    def test_process_depth(self):
        """Process.depth should count edges from the root (root itself is 0)."""
        root = self._running()
        child = self._running(parent=root)

        self.assertEqual(root.depth, 0)
        self.assertEqual(child.depth, 1)
|
||||
|
||||
class TestProcessLifecycle(TestCase):
    """Test Process lifecycle methods (is_running, poll, terminate)."""

    def setUp(self):
        """Bust the machine cache and pin a Machine row for the tests."""
        import archivebox.machine.models as models
        models._CURRENT_MACHINE = None
        self.machine = Machine.current()

    def test_process_is_running_current_pid(self):
        """is_running should be True for current PID."""
        import psutil
        from datetime import datetime

        # Use the OS-reported create time so the DB row matches the live
        # process exactly (presumably is_running cross-checks pid + start
        # time to guard against PID reuse — TODO confirm against
        # Process.is_running).
        proc_start = datetime.fromtimestamp(psutil.Process(os.getpid()).create_time(), tz=timezone.get_current_timezone())
        proc = Process.objects.create(
            machine=self.machine,
            status=Process.StatusChoices.RUNNING,
            pid=os.getpid(),
            started_at=proc_start,
        )

        self.assertTrue(proc.is_running)

    def test_process_is_running_fake_pid(self):
        """is_running should be False for non-existent PID."""
        proc = Process.objects.create(
            machine=self.machine,
            status=Process.StatusChoices.RUNNING,
            pid=999999,  # assumed not to be a live PID on the test host
            started_at=timezone.now(),
        )

        self.assertFalse(proc.is_running)

    def test_process_poll_detects_exit(self):
        """poll() should detect exited process."""
        proc = Process.objects.create(
            machine=self.machine,
            status=Process.StatusChoices.RUNNING,
            pid=999999,
            started_at=timezone.now(),
        )

        exit_code = proc.poll()

        # poll() both returns an exit code and persists the EXITED status.
        self.assertIsNotNone(exit_code)
        proc.refresh_from_db()
        self.assertEqual(proc.status, Process.StatusChoices.EXITED)

    def test_process_poll_normalizes_negative_exit_code(self):
        """poll() should normalize -1 exit codes to 137."""
        proc = Process.objects.create(
            machine=self.machine,
            status=Process.StatusChoices.EXITED,
            pid=999999,
            exit_code=-1,  # sentinel written when the real exit code was unknown
            started_at=timezone.now(),
        )

        exit_code = proc.poll()

        # 137 is the conventional 128+SIGKILL exit status.
        self.assertEqual(exit_code, 137)
        proc.refresh_from_db()
        self.assertEqual(proc.exit_code, 137)

    def test_process_terminate_dead_process(self):
        """terminate() should handle already-dead process."""
        proc = Process.objects.create(
            machine=self.machine,
            status=Process.StatusChoices.RUNNING,
            pid=999999,
            started_at=timezone.now(),
        )

        result = proc.terminate()

        # Returns False because nothing was actually killed, but the DB
        # status is still reconciled to EXITED.
        self.assertFalse(result)
        proc.refresh_from_db()
        self.assertEqual(proc.status, Process.StatusChoices.EXITED)
|
||||
|
||||
class TestProcessClassMethods(TestCase):
    """Covers Process query helpers: get_running, counts, and stale cleanup."""

    def setUp(self):
        """Invalidate the cached machine and grab a fresh one."""
        import archivebox.machine.models as models
        models._CURRENT_MACHINE = None
        self.machine = Machine.current()

    def test_get_running(self):
        """get_running() should include freshly-created RUNNING processes."""
        hook_proc = Process.objects.create(
            machine=self.machine,
            process_type=Process.TypeChoices.HOOK,
            status=Process.StatusChoices.RUNNING,
            pid=99999,
            started_at=timezone.now(),
        )

        matches = Process.get_running(process_type=Process.TypeChoices.HOOK)
        self.assertIn(hook_proc, matches)

    def test_get_running_count(self):
        """get_running_count() should reflect the number of RUNNING rows."""
        for offset in range(3):
            Process.objects.create(
                machine=self.machine,
                process_type=Process.TypeChoices.HOOK,
                status=Process.StatusChoices.RUNNING,
                pid=99900 + offset,  # distinct fake PIDs
                started_at=timezone.now(),
            )

        total = Process.get_running_count(process_type=Process.TypeChoices.HOOK)
        self.assertGreaterEqual(total, 3)

    def test_cleanup_stale_running(self):
        """cleanup_stale_running() should flip long-dead RUNNING rows to EXITED."""
        # Started well outside the PID-reuse window, pointing at a dead PID.
        stale_proc = Process.objects.create(
            machine=self.machine,
            status=Process.StatusChoices.RUNNING,
            pid=999999,
            started_at=timezone.now() - PID_REUSE_WINDOW - timedelta(hours=1),
        )

        reaped = Process.cleanup_stale_running()

        self.assertGreaterEqual(reaped, 1)
        stale_proc.refresh_from_db()
        self.assertEqual(stale_proc.status, Process.StatusChoices.EXITED)
|
||||
|
||||
class TestProcessStateMachine(TestCase):
    """Covers the ProcessMachine state machine wrapper."""

    def setUp(self):
        """Create a queued Process to drive the state machine."""
        import archivebox.machine.models as models
        models._CURRENT_MACHINE = None
        self.machine = Machine.current()
        self.process = Process.objects.create(
            machine=self.machine,
            cmd=['echo', 'test'],
            pwd='/tmp',
        )

    def test_process_state_machine_initial_state(self):
        """A new ProcessMachine starts out in the QUEUED state."""
        sm = ProcessMachine(self.process)
        self.assertEqual(sm.current_state_value, Process.StatusChoices.QUEUED)

    def test_process_state_machine_can_start(self):
        """can_start() requires a non-empty cmd (and a machine)."""
        self.assertTrue(ProcessMachine(self.process).can_start())

        # Emptying cmd should make the process unstartable.
        self.process.cmd = []
        self.process.save()
        self.assertFalse(ProcessMachine(self.process).can_start())

    def test_process_state_machine_is_exited(self):
        """is_exited() keys off of a populated exit_code."""
        self.assertFalse(ProcessMachine(self.process).is_exited())

        self.process.exit_code = 0
        self.process.save()
        self.assertTrue(ProcessMachine(self.process).is_exited())
|
||||
|
||||
if __name__ == '__main__':
    # Allow running this test module directly (outside the pytest CLI).
    pytest.main([__file__, '-v'])
|
||||
484
archivebox/tests/test_orchestrator.py
Normal file
484
archivebox/tests/test_orchestrator.py
Normal file
@@ -0,0 +1,484 @@
|
||||
"""
|
||||
Unit tests for the Orchestrator and Worker classes.
|
||||
|
||||
Tests cover:
|
||||
1. Orchestrator lifecycle (startup, shutdown)
|
||||
2. Queue polling and worker spawning
|
||||
3. Idle detection and exit logic
|
||||
4. Worker registration and management
|
||||
5. Process model methods (replacing old pid_utils)
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
from typing import ClassVar
|
||||
|
||||
import pytest
|
||||
from django.test import TestCase
|
||||
from django.utils import timezone
|
||||
|
||||
from archivebox.workers.orchestrator import Orchestrator
|
||||
from archivebox.workers.worker import Worker
|
||||
|
||||
|
||||
class FakeWorker(Worker):
    """Minimal Worker stub whose running-worker list tests can set directly."""

    name: ClassVar[str] = 'crawl'
    MAX_CONCURRENT_TASKS: ClassVar[int] = 5
    # NOTE: shared mutable class attribute — tests reassign it per-case.
    running_workers: ClassVar[list[dict[str, object]]] = []

    @classmethod
    def get_running_workers(cls) -> list[dict[str, object]]:
        # Return the canned list instead of inspecting real OS processes.
        return cls.running_workers
||||
|
||||
|
||||
class TestOrchestratorUnit(TestCase):
    """Unit tests for Orchestrator class (mocked dependencies)."""

    def test_orchestrator_creation(self):
        """Orchestrator should initialize with correct defaults."""
        orchestrator = Orchestrator(exit_on_idle=True)

        self.assertTrue(orchestrator.exit_on_idle)
        self.assertEqual(orchestrator.idle_count, 0)
        self.assertIsNone(orchestrator.pid_file)

    def test_orchestrator_repr(self):
        """Orchestrator __repr__ should include PID."""
        orchestrator = Orchestrator()
        repr_str = repr(orchestrator)

        self.assertIn('Orchestrator', repr_str)
        self.assertIn(str(os.getpid()), repr_str)

    def test_has_pending_work(self):
        """has_pending_work should check if any queue has items."""
        orchestrator = Orchestrator()

        # False only when every queue count is zero.
        self.assertFalse(orchestrator.has_pending_work({'crawl': 0, 'snapshot': 0}))
        self.assertTrue(orchestrator.has_pending_work({'crawl': 0, 'snapshot': 5}))
        self.assertTrue(orchestrator.has_pending_work({'crawl': 10, 'snapshot': 0}))

    def test_should_exit_not_exit_on_idle(self):
        """should_exit should return False when exit_on_idle is False."""
        orchestrator = Orchestrator(exit_on_idle=False)
        orchestrator.idle_count = 100  # well past any idle timeout

        self.assertFalse(orchestrator.should_exit({'crawl': 0}))

    def test_should_exit_pending_work(self):
        """should_exit should return False when there's pending work."""
        orchestrator = Orchestrator(exit_on_idle=True)
        orchestrator.idle_count = 100

        self.assertFalse(orchestrator.should_exit({'crawl': 5}))

    @patch.object(Orchestrator, 'has_running_workers')
    def test_should_exit_running_workers(self, mock_has_workers):
        """should_exit should return False when workers are running."""
        mock_has_workers.return_value = True
        orchestrator = Orchestrator(exit_on_idle=True)
        orchestrator.idle_count = 100

        self.assertFalse(orchestrator.should_exit({'crawl': 0}))

    @patch.object(Orchestrator, 'has_running_workers')
    @patch.object(Orchestrator, 'has_future_work')
    def test_should_exit_idle_timeout(self, mock_future, mock_workers):
        """should_exit should return True after idle timeout with no work."""
        # No live workers and nothing scheduled for the future.
        mock_workers.return_value = False
        mock_future.return_value = False

        orchestrator = Orchestrator(exit_on_idle=True)
        orchestrator.idle_count = orchestrator.IDLE_TIMEOUT

        self.assertTrue(orchestrator.should_exit({'crawl': 0, 'snapshot': 0}))

    @patch.object(Orchestrator, 'has_running_workers')
    @patch.object(Orchestrator, 'has_future_work')
    def test_should_exit_below_idle_timeout(self, mock_future, mock_workers):
        """should_exit should return False below idle timeout."""
        mock_workers.return_value = False
        mock_future.return_value = False

        orchestrator = Orchestrator(exit_on_idle=True)
        orchestrator.idle_count = orchestrator.IDLE_TIMEOUT - 1  # one tick short

        self.assertFalse(orchestrator.should_exit({'crawl': 0}))

    def test_should_spawn_worker_no_queue(self):
        """should_spawn_worker should return False when queue is empty."""
        orchestrator = Orchestrator()

        FakeWorker.running_workers = []
        self.assertFalse(orchestrator.should_spawn_worker(FakeWorker, 0))

    def test_should_spawn_worker_at_limit(self):
        """should_spawn_worker should return False when at per-type limit."""
        orchestrator = Orchestrator()

        # Fill the per-type slots so no new crawl worker may spawn.
        running_workers: list[dict[str, object]] = [{'worker_id': worker_id} for worker_id in range(orchestrator.MAX_CRAWL_WORKERS)]
        FakeWorker.running_workers = running_workers
        self.assertFalse(orchestrator.should_spawn_worker(FakeWorker, 10))

    @patch.object(Orchestrator, 'get_total_worker_count')
    def test_should_spawn_worker_at_total_limit(self, mock_total):
        """should_spawn_worker should return False when at total limit."""
        orchestrator = Orchestrator()
        # NOTE(review): mock_total returns 0, so the False result below still
        # comes from the per-type limit (running_workers is full), not the
        # total-worker cap this test is named for — verify the intent against
        # Orchestrator.should_spawn_worker and consider mocking a high total
        # with an empty running_workers list instead.
        mock_total.return_value = 0
        running_workers: list[dict[str, object]] = [{'worker_id': worker_id} for worker_id in range(orchestrator.MAX_CRAWL_WORKERS)]
        FakeWorker.running_workers = running_workers
        self.assertFalse(orchestrator.should_spawn_worker(FakeWorker, 10))

    @patch.object(Orchestrator, 'get_total_worker_count')
    def test_should_spawn_worker_success(self, mock_total):
        """should_spawn_worker should return True when conditions are met."""
        orchestrator = Orchestrator()
        mock_total.return_value = 0

        # Empty queue slots, zero total workers, non-empty queue -> spawn.
        FakeWorker.running_workers = []
        self.assertTrue(orchestrator.should_spawn_worker(FakeWorker, 10))

    @patch.object(Orchestrator, 'get_total_worker_count')
    def test_should_spawn_worker_enough_workers(self, mock_total):
        """should_spawn_worker should return False when enough workers for queue."""
        orchestrator = Orchestrator()
        mock_total.return_value = 2

        FakeWorker.running_workers = [{}]  # 1 worker already running
        self.assertFalse(orchestrator.should_spawn_worker(FakeWorker, 3))
||||
|
||||
|
||||
class TestOrchestratorWithProcess(TestCase):
    """Test Orchestrator using Process model for tracking."""

    def setUp(self):
        """Reset process cache."""
        import archivebox.machine.models as models
        models._CURRENT_MACHINE = None
        models._CURRENT_PROCESS = None

    def test_is_running_no_orchestrator(self):
        """is_running should return False when no orchestrator process exists."""
        from archivebox.machine.models import Process

        # Clean up any stale processes first
        Process.cleanup_stale_running()

        # Mark any running orchestrators as exited for clean test state
        Process.objects.filter(
            process_type=Process.TypeChoices.ORCHESTRATOR,
            status=Process.StatusChoices.RUNNING
        ).update(status=Process.StatusChoices.EXITED)

        self.assertFalse(Orchestrator.is_running())

    def test_is_running_with_orchestrator_process(self):
        """is_running should return True when orchestrator Process exists."""
        from archivebox.machine.models import Process, Machine
        import psutil

        machine = Machine.current()
        current_proc = psutil.Process(os.getpid())

        # Create an orchestrator Process record backed by this live test
        # process, so any liveness check against the OS passes.
        proc = Process.objects.create(
            machine=machine,
            process_type=Process.TypeChoices.ORCHESTRATOR,
            status=Process.StatusChoices.RUNNING,
            pid=os.getpid(),  # Use current PID so it appears alive
            started_at=datetime.fromtimestamp(current_proc.create_time(), tz=timezone.get_current_timezone()),
            cmd=current_proc.cmdline(),
        )

        try:
            # Should detect running orchestrator
            self.assertTrue(Orchestrator.is_running())
        finally:
            # Clean up so later tests don't see a phantom orchestrator
            proc.status = Process.StatusChoices.EXITED
            proc.save()

    def test_orchestrator_uses_process_for_is_running(self):
        """Orchestrator.is_running should use Process.get_running_count."""
        from archivebox.machine.models import Process

        # Verify is_running uses Process model, not pid files
        with patch.object(Process, 'get_running_count') as mock_count:
            mock_count.return_value = 1

            result = Orchestrator.is_running()

            # Should have called Process.get_running_count with orchestrator type
            mock_count.assert_called()
            self.assertTrue(result)

    def test_orchestrator_scoped_worker_count(self):
        """Orchestrator with crawl_id should count only descendant workers."""
        from archivebox.machine.models import Process, Machine

        machine = Machine.current()
        orchestrator = Orchestrator(exit_on_idle=True, crawl_id='test-crawl')

        orchestrator.db_process = Process.objects.create(
            machine=machine,
            process_type=Process.TypeChoices.ORCHESTRATOR,
            status=Process.StatusChoices.RUNNING,
            pid=12345,  # fake PID; cleanup is suppressed below
            started_at=timezone.now(),
        )

        # Prevent cleanup from marking fake PIDs as exited
        orchestrator._last_cleanup_time = time.time()

        # A worker parented to this orchestrator — should be counted.
        Process.objects.create(
            machine=machine,
            process_type=Process.TypeChoices.WORKER,
            worker_type='crawl',
            status=Process.StatusChoices.RUNNING,
            pid=12346,
            parent=orchestrator.db_process,
            started_at=timezone.now(),
        )

        # An orphan worker with no parent — should NOT be counted.
        Process.objects.create(
            machine=machine,
            process_type=Process.TypeChoices.WORKER,
            worker_type='crawl',
            status=Process.StatusChoices.RUNNING,
            pid=12347,
            started_at=timezone.now(),
        )

        self.assertEqual(orchestrator.get_total_worker_count(), 1)
|
||||
|
||||
class TestProcessBasedWorkerTracking(TestCase):
    """Test Process model methods that replace pid_utils functionality."""

    def setUp(self):
        """Reset caches."""
        import archivebox.machine.models as models
        models._CURRENT_MACHINE = None   # bust cached singletons between tests
        models._CURRENT_PROCESS = None

    def test_process_current_creates_record(self):
        """Process.current() should create a Process record for current PID."""
        from archivebox.machine.models import Process

        proc = Process.current()

        self.assertIsNotNone(proc)
        self.assertEqual(proc.pid, os.getpid())
        self.assertEqual(proc.status, Process.StatusChoices.RUNNING)
        self.assertIsNotNone(proc.machine)
        self.assertIsNotNone(proc.started_at)

    def test_process_current_caches_result(self):
        """Process.current() should return cached Process within interval."""
        from archivebox.machine.models import Process

        proc1 = Process.current()
        proc2 = Process.current()

        # Same row id proves the second call hit the cache, not the DB.
        self.assertEqual(proc1.id, proc2.id)

    def test_process_get_running_count(self):
        """Process.get_running_count should count running processes by type."""
        from archivebox.machine.models import Process, Machine

        machine = Machine.current()

        # Create some worker processes
        for i in range(3):
            Process.objects.create(
                machine=machine,
                process_type=Process.TypeChoices.WORKER,
                status=Process.StatusChoices.RUNNING,
                pid=99990 + i,  # Fake PIDs
                started_at=timezone.now(),
            )

        count = Process.get_running_count(process_type=Process.TypeChoices.WORKER)
        self.assertGreaterEqual(count, 3)

    def test_process_get_next_worker_id(self):
        """Process.get_next_worker_id should return count of running workers."""
        from archivebox.machine.models import Process, Machine

        machine = Machine.current()

        # Create 2 worker processes
        for i in range(2):
            Process.objects.create(
                machine=machine,
                process_type=Process.TypeChoices.WORKER,
                status=Process.StatusChoices.RUNNING,
                pid=99980 + i,
                started_at=timezone.now(),
            )

        # >= rather than == because other tests/fixtures may leave workers behind.
        next_id = Process.get_next_worker_id(process_type=Process.TypeChoices.WORKER)
        self.assertGreaterEqual(next_id, 2)

    def test_process_cleanup_stale_running(self):
        """Process.cleanup_stale_running should mark stale processes as exited."""
        from archivebox.machine.models import Process, Machine, PID_REUSE_WINDOW

        machine = Machine.current()

        # Create a stale process (old started_at, fake PID)
        stale_proc = Process.objects.create(
            machine=machine,
            process_type=Process.TypeChoices.WORKER,
            status=Process.StatusChoices.RUNNING,
            pid=999999,  # Fake PID that doesn't exist
            started_at=timezone.now() - PID_REUSE_WINDOW - timedelta(hours=1),
        )

        cleaned = Process.cleanup_stale_running()

        self.assertGreaterEqual(cleaned, 1)

        stale_proc.refresh_from_db()
        self.assertEqual(stale_proc.status, Process.StatusChoices.EXITED)

    def test_process_get_running(self):
        """Process.get_running should return queryset of running processes."""
        from archivebox.machine.models import Process, Machine

        machine = Machine.current()

        # Create a running process
        proc = Process.objects.create(
            machine=machine,
            process_type=Process.TypeChoices.HOOK,
            status=Process.StatusChoices.RUNNING,
            pid=99970,
            started_at=timezone.now(),
        )

        running = Process.get_running(process_type=Process.TypeChoices.HOOK)

        self.assertIn(proc, running)

    def test_process_type_detection(self):
        """Process._detect_process_type should detect process type from argv."""
        from archivebox.machine.models import Process

        # Test detection logic for each recognizable argv shape
        with patch('sys.argv', ['archivebox', 'manage', 'orchestrator']):
            result = Process._detect_process_type()
            self.assertEqual(result, Process.TypeChoices.ORCHESTRATOR)

        with patch('sys.argv', ['archivebox', 'add', 'http://example.com']):
            result = Process._detect_process_type()
            self.assertEqual(result, Process.TypeChoices.CLI)

        with patch('sys.argv', ['supervisord', '-c', 'config.ini']):
            result = Process._detect_process_type()
            self.assertEqual(result, Process.TypeChoices.SUPERVISORD)
||||
|
||||
|
||||
class TestProcessLifecycle(TestCase):
|
||||
"""Test Process model lifecycle methods."""
|
||||
|
||||
    def setUp(self):
        """Reset caches and create a machine."""
        import archivebox.machine.models as models
        models._CURRENT_MACHINE = None   # bust cached singletons between tests
        models._CURRENT_PROCESS = None
        self.machine = models.Machine.current()
||||
|
||||
    def test_process_is_running_property(self):
        """Process.is_running should check actual OS process."""
        from archivebox.machine.models import Process
        proc = Process.current()

        # Should be running (current process exists)
        self.assertTrue(proc.is_running)

        # Create a process with fake PID
        fake_proc = Process.objects.create(
            machine=self.machine,
            status=Process.StatusChoices.RUNNING,
            pid=999999,  # assumed not to be a live PID on the test host
            started_at=timezone.now(),
        )

        # Should not be running (PID doesn't exist)
        self.assertFalse(fake_proc.is_running)
||||
|
||||
def test_process_poll(self):
|
||||
"""Process.poll should check and update exit status."""
|
||||
from archivebox.machine.models import Process
|
||||
|
||||
# Create a process with fake PID (already exited)
|
||||
proc = Process.objects.create(
|
||||
machine=self.machine,
|
||||
status=Process.StatusChoices.RUNNING,
|
||||
pid=999999,
|
||||
started_at=timezone.now(),
|
||||
)
|
||||
|
||||
exit_code = proc.poll()
|
||||
|
||||
# Should have detected exit and updated status
|
||||
self.assertIsNotNone(exit_code)
|
||||
proc.refresh_from_db()
|
||||
self.assertEqual(proc.status, Process.StatusChoices.EXITED)
|
||||
|
||||
def test_process_terminate_already_dead(self):
|
||||
"""Process.terminate should handle already-dead processes."""
|
||||
from archivebox.machine.models import Process
|
||||
|
||||
# Create a process with fake PID
|
||||
proc = Process.objects.create(
|
||||
machine=self.machine,
|
||||
status=Process.StatusChoices.RUNNING,
|
||||
pid=999999,
|
||||
started_at=timezone.now(),
|
||||
)
|
||||
|
||||
result = proc.terminate()
|
||||
|
||||
# Should return False (was already dead)
|
||||
self.assertFalse(result)
|
||||
|
||||
proc.refresh_from_db()
|
||||
self.assertEqual(proc.status, Process.StatusChoices.EXITED)
|
||||
|
||||
def test_process_tree_traversal(self):
|
||||
"""Process parent/children relationships should work."""
|
||||
from archivebox.machine.models import Process
|
||||
|
||||
# Create parent process
|
||||
parent = Process.objects.create(
|
||||
machine=self.machine,
|
||||
process_type=Process.TypeChoices.CLI,
|
||||
status=Process.StatusChoices.RUNNING,
|
||||
pid=1,
|
||||
started_at=timezone.now(),
|
||||
)
|
||||
|
||||
# Create child process
|
||||
child = Process.objects.create(
|
||||
machine=self.machine,
|
||||
parent=parent,
|
||||
process_type=Process.TypeChoices.WORKER,
|
||||
status=Process.StatusChoices.RUNNING,
|
||||
pid=2,
|
||||
started_at=timezone.now(),
|
||||
)
|
||||
|
||||
# Test relationships
|
||||
self.assertEqual(child.parent, parent)
|
||||
self.assertIn(child, parent.children.all())
|
||||
self.assertEqual(child.root, parent)
|
||||
self.assertEqual(child.depth, 1)
|
||||
self.assertEqual(parent.depth, 0)
|
||||
|
||||
|
||||
# Allow running this test module directly (e.g. `python test_file.py`)
# instead of through the pytest CLI.
if __name__ == '__main__':
    pytest.main([__file__, '-v'])
|
||||
@@ -13,7 +13,6 @@ ADMIN_HOST = 'admin.archivebox.localhost:8000'
|
||||
|
||||
|
||||
def _run_savepagenow_script(initialized_archive: Path, request_url: str, expected_url: str, *, login: bool, public_add_view: bool, host: str):
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
script = textwrap.dedent(
|
||||
f"""
|
||||
import os
|
||||
@@ -81,7 +80,7 @@ def _run_savepagenow_script(initialized_archive: Path, request_url: str, expecte
|
||||
|
||||
return subprocess.run(
|
||||
[sys.executable, '-c', script],
|
||||
cwd=project_root,
|
||||
cwd=initialized_archive,
|
||||
env=env,
|
||||
text=True,
|
||||
capture_output=True,
|
||||
@@ -90,7 +89,6 @@ def _run_savepagenow_script(initialized_archive: Path, request_url: str, expecte
|
||||
|
||||
|
||||
def _run_savepagenow_not_found_script(initialized_archive: Path, request_url: str):
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
script = textwrap.dedent(
|
||||
f"""
|
||||
import os
|
||||
@@ -137,7 +135,7 @@ def _run_savepagenow_not_found_script(initialized_archive: Path, request_url: st
|
||||
|
||||
return subprocess.run(
|
||||
[sys.executable, '-c', script],
|
||||
cwd=project_root,
|
||||
cwd=initialized_archive,
|
||||
env=env,
|
||||
text=True,
|
||||
capture_output=True,
|
||||
@@ -146,7 +144,6 @@ def _run_savepagenow_not_found_script(initialized_archive: Path, request_url: st
|
||||
|
||||
|
||||
def _run_savepagenow_existing_snapshot_script(initialized_archive: Path, request_url: str, stored_url: str):
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
script = textwrap.dedent(
|
||||
f"""
|
||||
import os
|
||||
@@ -199,7 +196,7 @@ def _run_savepagenow_existing_snapshot_script(initialized_archive: Path, request
|
||||
|
||||
return subprocess.run(
|
||||
[sys.executable, '-c', script],
|
||||
cwd=project_root,
|
||||
cwd=initialized_archive,
|
||||
env=env,
|
||||
text=True,
|
||||
capture_output=True,
|
||||
|
||||
84
archivebox/tests/test_scheduled_crawls.py
Normal file
84
archivebox/tests/test_scheduled_crawls.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from datetime import timedelta
|
||||
from typing import cast
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.auth.models import UserManager
|
||||
from django.test import TestCase
|
||||
from django.utils import timezone
|
||||
|
||||
from archivebox.crawls.models import Crawl, CrawlSchedule
|
||||
from archivebox.workers.orchestrator import Orchestrator
|
||||
from archivebox.workers.worker import CrawlWorker
|
||||
|
||||
|
||||
class TestScheduledCrawlMaterialization(TestCase):
    """Tests for how the Orchestrator materializes due CrawlSchedules into Crawls."""

    def setUp(self):
        """Create the user that owns the template crawls and schedules."""
        # cast() keeps type-checkers happy: get_user_model() is typed with the
        # base manager, but create_user() lives on UserManager.
        user_manager = cast(UserManager, get_user_model().objects)
        self.user = user_manager.create_user(
            username='schedule-user',
            password='password',
        )

    def _create_due_schedule(self) -> CrawlSchedule:
        """Create a daily CrawlSchedule whose template crawl is backdated 2 days.

        Backdating is done via a queryset .update() rather than .save() —
        presumably to bypass auto-managed timestamp fields so the schedule
        actually reads as overdue; TODO confirm against Crawl's field defs.
        """
        template = Crawl.objects.create(
            urls='https://example.com/feed.xml',
            max_depth=1,
            tags_str='scheduled',
            label='Scheduled Feed',
            notes='template',
            created_by=self.user,
            status=Crawl.StatusChoices.SEALED,
            retry_at=None,
        )
        schedule = CrawlSchedule.objects.create(
            template=template,
            schedule='daily',
            is_enabled=True,
            label='Scheduled Feed',
            notes='template',
            created_by=self.user,
        )
        past = timezone.now() - timedelta(days=2)
        Crawl.objects.filter(pk=template.pk).update(created_at=past, modified_at=past)
        # Refresh both so in-memory instances reflect the raw .update() above.
        template.refresh_from_db()
        schedule.refresh_from_db()
        return schedule

    def test_global_orchestrator_materializes_due_schedule(self):
        """A long-running orchestrator should create a new QUEUED crawl from the template."""
        schedule = self._create_due_schedule()

        orchestrator = Orchestrator(exit_on_idle=False)
        orchestrator._materialize_due_schedules()

        # Expect 2 crawls attached to the schedule: the pre-existing one plus
        # the newly materialized one (the other tests show the count starts at 1).
        scheduled_crawls = Crawl.objects.filter(schedule=schedule).order_by('created_at')
        self.assertEqual(scheduled_crawls.count(), 2)

        queued_crawl = scheduled_crawls.last()
        self.assertIsNotNone(queued_crawl)
        assert queued_crawl is not None  # narrow Optional for type-checkers
        # The materialized crawl starts QUEUED and copies the template's settings.
        self.assertEqual(queued_crawl.status, Crawl.StatusChoices.QUEUED)
        self.assertEqual(queued_crawl.urls, 'https://example.com/feed.xml')
        self.assertEqual(queued_crawl.max_depth, 1)
        self.assertEqual(queued_crawl.tags_str, 'scheduled')

    def test_one_shot_orchestrator_does_not_materialize_due_schedule(self):
        """Only global long-running orchestrators may materialize schedules."""
        schedule = self._create_due_schedule()

        # One-shot orchestrators (exit_on_idle=True) must not create crawls...
        Orchestrator(exit_on_idle=True)._materialize_due_schedules()
        self.assertEqual(Crawl.objects.filter(schedule=schedule).count(), 1)

        # ...and neither must orchestrators scoped to a single crawl_id.
        Orchestrator(exit_on_idle=False, crawl_id=str(schedule.template.id))._materialize_due_schedules()
        self.assertEqual(Crawl.objects.filter(schedule=schedule).count(), 1)

    @patch.object(CrawlWorker, 'start')
    def test_global_orchestrator_waits_one_tick_before_spawning_materialized_schedule(self, mock_start):
        """A freshly materialized crawl is queued on this tick but not yet spawned."""
        schedule = self._create_due_schedule()

        orchestrator = Orchestrator(exit_on_idle=False)
        # Force claiming to succeed so the queue-size accounting runs.
        with patch.object(orchestrator, '_claim_crawl', return_value=True):
            queue_sizes = orchestrator.check_queues_and_spawn_workers()

        # The materialized crawl shows up in the queue and in the DB...
        self.assertEqual(queue_sizes['crawl'], 1)
        self.assertEqual(Crawl.objects.filter(schedule=schedule).count(), 2)
        # ...but no CrawlWorker is started until a later tick.
        mock_start.assert_not_called()
||||
76
archivebox/tests/test_snapshot_worker.py
Normal file
76
archivebox/tests/test_snapshot_worker.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.test import SimpleTestCase
|
||||
|
||||
from archivebox.workers.worker import SnapshotWorker
|
||||
|
||||
|
||||
class TestSnapshotWorkerRetryForegroundHooks(SimpleTestCase):
    """Tests for SnapshotWorker._retry_failed_empty_foreground_hooks."""

    def _make_worker(self):
        """Build a SnapshotWorker stub without invoking __init__ (no DB/process setup)."""
        # __new__ bypasses __init__ so no real snapshot/process wiring happens.
        worker = SnapshotWorker.__new__(SnapshotWorker)
        worker.pid = 12345
        # Minimal snapshot stand-in: only the attributes the retry path touches.
        cast(Any, worker).snapshot = SimpleNamespace(
            status='started',
            refresh_from_db=lambda: None,
        )
        # Neutralize timeout handling and hook execution; individual tests
        # override _run_hook/_wait_for_hook when they need to observe calls.
        worker._snapshot_exceeded_hard_timeout = lambda: False
        worker._seal_snapshot_due_to_timeout = lambda: None
        worker._run_hook = lambda *args, **kwargs: SimpleNamespace()
        worker._wait_for_hook = lambda process, ar: None
        return worker

    @patch('archivebox.workers.worker.log_worker_event')
    def test_retry_skips_successful_hook_with_only_inline_output(self, mock_log):
        """A succeeded hook with inline-only output must not be retried."""
        worker = self._make_worker()
        archive_result = SimpleNamespace(
            status='succeeded',
            output_files={},  # no files on disk...
            output_str='scrolled 600px',  # ...but it did produce inline output
            output_json=None,
            refresh_from_db=lambda: None,
        )

        worker._retry_failed_empty_foreground_hooks(
            [(Path('/tmp/on_Snapshot__45_infiniscroll.js'), archive_result)],
            config={},
        )

        # No retry was attempted, so no worker event was logged.
        mock_log.assert_not_called()

    @patch('archivebox.workers.worker.log_worker_event')
    def test_retry_replays_failed_hook_with_no_outputs(self, mock_log):
        """A failed hook with no output at all must be re-run exactly once."""
        worker = self._make_worker()
        run_calls = []
        wait_calls = []

        def run_hook(*args, **kwargs):
            # Record the retry invocation and hand back a dummy process handle.
            run_calls.append((args, kwargs))
            return SimpleNamespace()

        def wait_for_hook(process, ar):
            # Simulate the retried hook succeeding and producing a file.
            wait_calls.append((process, ar))
            ar.status = 'succeeded'
            ar.output_files = {'singlefile.html': {}}

        archive_result = SimpleNamespace(
            status='failed',
            output_files={},
            output_str='',
            output_json=None,
            refresh_from_db=lambda: None,
        )

        worker._run_hook = run_hook
        worker._wait_for_hook = wait_for_hook

        worker._retry_failed_empty_foreground_hooks(
            [(Path('/tmp/on_Snapshot__50_singlefile.py'), archive_result)],
            config={},
        )

        # Exactly one replay: one _run_hook call, one _wait_for_hook call,
        # and one logged worker event for the retry.
        assert len(run_calls) == 1
        assert len(wait_calls) == 1
        mock_log.assert_called_once()
||||
Reference in New Issue
Block a user