diff --git a/archivebox/machine/tests/__init__.py b/archivebox/machine/tests/__init__.py new file mode 100644 index 00000000..d7ce160b --- /dev/null +++ b/archivebox/machine/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for the machine module (Machine, NetworkInterface, Binary, Process models).""" diff --git a/archivebox/machine/tests/test_machine_models.py b/archivebox/machine/tests/test_machine_models.py new file mode 100644 index 00000000..bfbe2968 --- /dev/null +++ b/archivebox/machine/tests/test_machine_models.py @@ -0,0 +1,474 @@ +""" +Unit tests for machine module models: Machine, NetworkInterface, Binary, Process. + +Tests cover: +1. Machine model creation and current() method +2. NetworkInterface model and network detection +3. Binary model lifecycle and state machine +4. Process model lifecycle and state machine +5. JSONL serialization/deserialization +6. Manager methods +""" + +import os +import tempfile +from pathlib import Path +from datetime import timedelta + +import pytest +from django.test import TestCase, override_settings +from django.utils import timezone + +from archivebox.machine.models import ( + Machine, + NetworkInterface, + Binary, + Process, + BinaryMachine, + ProcessMachine, + MACHINE_RECHECK_INTERVAL, + NETWORK_INTERFACE_RECHECK_INTERVAL, + BINARY_RECHECK_INTERVAL, + _CURRENT_MACHINE, + _CURRENT_INTERFACE, + _CURRENT_BINARIES, +) + + +class TestMachineModel(TestCase): + """Test the Machine model.""" + + def setUp(self): + """Reset cached machine between tests.""" + import archivebox.machine.models as models + models._CURRENT_MACHINE = None + + def test_machine_current_creates_machine(self): + """Machine.current() should create a machine if none exists.""" + machine = Machine.current() + + self.assertIsNotNone(machine) + self.assertIsNotNone(machine.id) + self.assertIsNotNone(machine.guid) + self.assertEqual(machine.hostname, os.uname().nodename) + self.assertIn(machine.os_family, ['linux', 'darwin', 'windows', 'freebsd']) + + def test_machine_current_returns_cached(self): + """Machine.current() should return cached machine within recheck interval.""" + machine1 = Machine.current() + machine2 = Machine.current() + + self.assertEqual(machine1.id, machine2.id) + + def test_machine_current_refreshes_after_interval(self): + """Machine.current() should refresh after recheck interval.""" + import archivebox.machine.models as models + + machine1 = Machine.current() + + # Manually expire the cache by modifying modified_at + machine1.modified_at = timezone.now() - timedelta(seconds=MACHINE_RECHECK_INTERVAL + 1) + machine1.save() + models._CURRENT_MACHINE = machine1 + + machine2 = Machine.current() + + # Should have fetched/updated the machine (same GUID) + self.assertEqual(machine1.guid, machine2.guid) + + def test_machine_to_json(self): + """Machine.to_json() should serialize correctly.""" + machine = Machine.current() + json_data = machine.to_json() + + self.assertEqual(json_data['type'], 'Machine') + self.assertEqual(json_data['id'], str(machine.id)) + self.assertEqual(json_data['guid'], machine.guid) + self.assertEqual(json_data['hostname'], machine.hostname) + self.assertIn('os_arch', json_data) + self.assertIn('os_family', json_data) + + def test_machine_to_jsonl(self): + """Machine.to_jsonl() should yield JSON records.""" + machine = Machine.current() + records = list(machine.to_jsonl()) + + self.assertEqual(len(records), 1) + self.assertEqual(records[0]['type'], 'Machine') + self.assertEqual(records[0]['id'], str(machine.id)) + + def test_machine_to_jsonl_deduplication(self): + """Machine.to_jsonl() should deduplicate with seen set.""" + machine = Machine.current() + seen = set() + + records1 = list(machine.to_jsonl(seen=seen)) + records2 = list(machine.to_jsonl(seen=seen)) + + self.assertEqual(len(records1), 1) + self.assertEqual(len(records2), 0) # Already seen + + def test_machine_from_json_update(self): + """Machine.from_json() should update machine config.""" + machine = Machine.current() + record = { + '_method': 'update', + 'key': 'WGET_BINARY', + 'value': '/usr/bin/wget', + } + + result = Machine.from_json(record) + + self.assertIsNotNone(result) + self.assertEqual(result.config.get('WGET_BINARY'), '/usr/bin/wget') + + def test_machine_from_json_invalid(self): + """Machine.from_json() should return None for invalid records.""" + result = Machine.from_json({'invalid': 'record'}) + self.assertIsNone(result) + + def test_machine_manager_current(self): + """Machine.objects.current() should return current machine.""" + machine = Machine.objects.current() + self.assertIsNotNone(machine) + self.assertEqual(machine.id, Machine.current().id) + + +class TestNetworkInterfaceModel(TestCase): + """Test the NetworkInterface model.""" + + def setUp(self): + """Reset cached interface between tests.""" + import archivebox.machine.models as models + models._CURRENT_MACHINE = None + models._CURRENT_INTERFACE = None + + def test_networkinterface_current_creates_interface(self): + """NetworkInterface.current() should create an interface if none exists.""" + interface = NetworkInterface.current() + + self.assertIsNotNone(interface) + self.assertIsNotNone(interface.id) + self.assertIsNotNone(interface.machine) + # IP addresses should be populated + self.assertIsNotNone(interface.ip_local) + + def test_networkinterface_current_returns_cached(self): + """NetworkInterface.current() should return cached interface within recheck interval.""" + interface1 = NetworkInterface.current() + interface2 = NetworkInterface.current() + + self.assertEqual(interface1.id, interface2.id) + + def test_networkinterface_to_json(self): + """NetworkInterface.to_json() should serialize correctly.""" + interface = NetworkInterface.current() + json_data = interface.to_json() + + self.assertEqual(json_data['type'], 'NetworkInterface') + self.assertEqual(json_data['id'], str(interface.id)) + self.assertEqual(json_data['machine_id'], str(interface.machine_id)) + self.assertIn('ip_local', json_data) + self.assertIn('ip_public', json_data) + + def test_networkinterface_manager_current(self): + """NetworkInterface.objects.current() should return current interface.""" + interface = NetworkInterface.objects.current() + self.assertIsNotNone(interface) + + +class TestBinaryModel(TestCase): + """Test the Binary model and BinaryMachine state machine.""" + + def setUp(self): + """Reset cached binaries and create a machine.""" + import archivebox.machine.models as models + models._CURRENT_MACHINE = None + models._CURRENT_BINARIES = {} + self.machine = Machine.current() + + def test_binary_creation(self): + """Binary should be created with default values.""" + binary = Binary.objects.create( + machine=self.machine, + name='wget', + binproviders='apt,brew,env', + ) + + self.assertIsNotNone(binary.id) + self.assertEqual(binary.name, 'wget') + self.assertEqual(binary.status, Binary.StatusChoices.QUEUED) + self.assertFalse(binary.is_valid) + + def test_binary_is_valid(self): + """Binary.is_valid should be True when abspath and version are set.""" + binary = Binary.objects.create( + machine=self.machine, + name='wget', + abspath='/usr/bin/wget', + version='1.21', + ) + + self.assertTrue(binary.is_valid) + + def test_binary_to_json(self): + """Binary.to_json() should serialize correctly.""" + binary = Binary.objects.create( + machine=self.machine, + name='wget', + abspath='/usr/bin/wget', + version='1.21', + binprovider='apt', + ) + json_data = binary.to_json() + + self.assertEqual(json_data['type'], 'Binary') + self.assertEqual(json_data['name'], 'wget') + self.assertEqual(json_data['abspath'], '/usr/bin/wget') + self.assertEqual(json_data['version'], '1.21') + + def test_binary_from_json_queued(self): + """Binary.from_json() should create queued binary from binaries.jsonl format.""" + record = { + 'name': 'curl', + 'binproviders': 'apt,brew', + 'overrides': {'apt': {'packages': ['curl']}}, + } + + binary = Binary.from_json(record) + + self.assertIsNotNone(binary) + self.assertEqual(binary.name, 'curl') + self.assertEqual(binary.binproviders, 'apt,brew') + self.assertEqual(binary.status, Binary.StatusChoices.QUEUED) + + def test_binary_from_json_installed(self): + """Binary.from_json() should update binary from hook output format.""" + # First create queued binary + Binary.objects.create( + machine=self.machine, + name='node', + ) + + # Then update with hook output + record = { + 'name': 'node', + 'abspath': '/usr/bin/node', + 'version': '18.0.0', + 'binprovider': 'apt', + } + + binary = Binary.from_json(record) + + self.assertIsNotNone(binary) + self.assertEqual(binary.abspath, '/usr/bin/node') + self.assertEqual(binary.version, '18.0.0') + self.assertEqual(binary.status, Binary.StatusChoices.SUCCEEDED) + + def test_binary_manager_get_valid_binary(self): + """BinaryManager.get_valid_binary() should find valid binaries.""" + # Create invalid binary (no abspath) + Binary.objects.create( + machine=self.machine, + name='wget', + ) + + # Create valid binary + Binary.objects.create( + machine=self.machine, + name='wget', + abspath='/usr/bin/wget', + version='1.21', + ) + + result = Binary.objects.get_valid_binary('wget') + + self.assertIsNotNone(result) + self.assertEqual(result.abspath, '/usr/bin/wget') + + def test_binary_update_and_requeue(self): + """Binary.update_and_requeue() should update fields and save.""" + binary = Binary.objects.create( + machine=self.machine, + name='test', + ) + old_modified = binary.modified_at + + binary.update_and_requeue( + status=Binary.StatusChoices.STARTED, + retry_at=timezone.now() + timedelta(seconds=60), + ) + + binary.refresh_from_db() + self.assertEqual(binary.status, Binary.StatusChoices.STARTED) + self.assertGreater(binary.modified_at, old_modified) + + +class TestBinaryStateMachine(TestCase): + """Test the BinaryMachine state machine.""" + + def setUp(self): + """Create a machine and binary for state machine tests.""" + import archivebox.machine.models as models + models._CURRENT_MACHINE = None + self.machine = Machine.current() + self.binary = Binary.objects.create( + machine=self.machine, + name='test-binary', + binproviders='env', + ) + + def test_binary_state_machine_initial_state(self): + """BinaryMachine should start in queued state.""" + sm = BinaryMachine(self.binary) + self.assertEqual(sm.current_state.value, Binary.StatusChoices.QUEUED) + + def test_binary_state_machine_can_start(self): + """BinaryMachine.can_start() should check name and binproviders.""" + sm = BinaryMachine(self.binary) + self.assertTrue(sm.can_start()) + + # Binary without binproviders + self.binary.binproviders = '' + self.binary.save() + sm = BinaryMachine(self.binary) + self.assertFalse(sm.can_start()) + + +class TestProcessModel(TestCase): + """Test the Process model and ProcessMachine state machine.""" + + def setUp(self): + """Create a machine for process tests.""" + import archivebox.machine.models as models + models._CURRENT_MACHINE = None + self.machine = Machine.current() + + def test_process_creation(self): + """Process should be created with default values.""" + process = Process.objects.create( + machine=self.machine, + cmd=['echo', 'hello'], + pwd='/tmp', + ) + + self.assertIsNotNone(process.id) + self.assertEqual(process.cmd, ['echo', 'hello']) + self.assertEqual(process.status, Process.StatusChoices.QUEUED) + self.assertIsNone(process.pid) + self.assertIsNone(process.exit_code) + + def test_process_to_json(self): + """Process.to_json() should serialize correctly.""" + process = Process.objects.create( + machine=self.machine, + cmd=['echo', 'hello'], + pwd='/tmp', + timeout=60, + ) + json_data = process.to_json() + + self.assertEqual(json_data['type'], 'Process') + self.assertEqual(json_data['cmd'], ['echo', 'hello']) + self.assertEqual(json_data['pwd'], '/tmp') + self.assertEqual(json_data['timeout'], 60) + + def test_process_to_jsonl_with_binary(self): + """Process.to_jsonl() should include related binary.""" + binary = Binary.objects.create( + machine=self.machine, + name='echo', + abspath='/bin/echo', + version='1.0', + ) + process = Process.objects.create( + machine=self.machine, + cmd=['echo', 'hello'], + binary=binary, + ) + + records = list(process.to_jsonl(binary=True)) + + self.assertEqual(len(records), 2) + types = {r['type'] for r in records} + self.assertIn('Process', types) + self.assertIn('Binary', types) + + def test_process_manager_create_for_archiveresult(self): + """ProcessManager.create_for_archiveresult() should create process.""" + # This test would require an ArchiveResult, which is complex to set up + # For now, test the direct creation path + process = Process.objects.create( + machine=self.machine, + pwd='/tmp/test', + cmd=['wget', 'http://example.com'], + timeout=120, + ) + + self.assertEqual(process.pwd, '/tmp/test') + self.assertEqual(process.cmd, ['wget', 'http://example.com']) + self.assertEqual(process.timeout, 120) + + def test_process_update_and_requeue(self): + """Process.update_and_requeue() should update fields and save.""" + process = Process.objects.create( + machine=self.machine, + cmd=['test'], + ) + old_modified = process.modified_at + + process.update_and_requeue( + status=Process.StatusChoices.RUNNING, + pid=12345, + started_at=timezone.now(), + ) + + process.refresh_from_db() + self.assertEqual(process.status, Process.StatusChoices.RUNNING) + self.assertEqual(process.pid, 12345) + self.assertIsNotNone(process.started_at) + + +class TestProcessStateMachine(TestCase): + """Test the ProcessMachine state machine.""" + + def setUp(self): + """Create a machine and process for state machine tests.""" + import archivebox.machine.models as models + models._CURRENT_MACHINE = None + self.machine = Machine.current() + self.process = Process.objects.create( + machine=self.machine, + cmd=['echo', 'test'], + pwd='/tmp', + ) + + def test_process_state_machine_initial_state(self): + """ProcessMachine should start in queued state.""" + sm = ProcessMachine(self.process) + self.assertEqual(sm.current_state.value, Process.StatusChoices.QUEUED) + + def test_process_state_machine_can_start(self): + """ProcessMachine.can_start() should check cmd and machine.""" + sm = ProcessMachine(self.process) + self.assertTrue(sm.can_start()) + + # Process without cmd + self.process.cmd = [] + self.process.save() + sm = ProcessMachine(self.process) + self.assertFalse(sm.can_start()) + + def test_process_state_machine_is_exited(self): + """ProcessMachine.is_exited() should check exit_code.""" + sm = ProcessMachine(self.process) + self.assertFalse(sm.is_exited()) + + self.process.exit_code = 0 + self.process.save() + sm = ProcessMachine(self.process) + self.assertTrue(sm.is_exited()) + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/archivebox/plugins/apt/tests/__init__.py b/archivebox/plugins/apt/tests/__init__.py new file mode 100644 index 00000000..fdde694e --- /dev/null +++ b/archivebox/plugins/apt/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for the apt binary provider plugin.""" diff --git a/archivebox/plugins/apt/tests/test_apt_provider.py b/archivebox/plugins/apt/tests/test_apt_provider.py new file mode 100644 index 00000000..a5430a65 --- /dev/null +++ b/archivebox/plugins/apt/tests/test_apt_provider.py @@ -0,0 +1,177 @@ +""" +Tests for the apt binary provider plugin. + +Tests cover: +1. Hook script execution +2. apt package availability detection +3. JSONL output format +""" + +import json +import os +import shutil +import subprocess +import sys +import tempfile +from pathlib import Path + +import pytest +from django.test import TestCase + + +# Get the path to the apt provider hook +PLUGIN_DIR = Path(__file__).parent.parent +INSTALL_HOOK = PLUGIN_DIR / 'on_Binary__install_using_apt_provider.py' + + +def apt_available() -> bool: + """Check if apt is installed.""" + return shutil.which('apt') is not None or shutil.which('apt-get') is not None + + +def is_linux() -> bool: + """Check if running on Linux.""" + import platform + return platform.system().lower() == 'linux' + + +class TestAptProviderHook(TestCase): + """Test the apt binary provider installation hook.""" + + def setUp(self): + """Set up test environment.""" + self.temp_dir = tempfile.mkdtemp() + + def tearDown(self): + """Clean up.""" + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_hook_script_exists(self): + """Hook script should exist.""" + self.assertTrue(INSTALL_HOOK.exists(), f"Hook not found: {INSTALL_HOOK}") + + def test_hook_skips_when_apt_not_allowed(self): + """Hook should skip when apt not in allowed binproviders.""" + result = subprocess.run( + [ + sys.executable, str(INSTALL_HOOK), + '--name=wget', + '--binary-id=test-uuid', + '--machine-id=test-machine', + '--binproviders=pip,npm', # apt not allowed + ], + capture_output=True, + text=True, + timeout=30 + ) + + # Should exit cleanly (code 0) when apt not allowed + self.assertIn('apt provider not allowed', result.stderr) + self.assertEqual(result.returncode, 0) + + @pytest.mark.skipif(not is_linux(), reason="apt only available on Linux") + @pytest.mark.skipif(not apt_available(), reason="apt not installed") + def test_hook_detects_apt(self): + """Hook should detect apt binary when available.""" + result = subprocess.run( + [ + sys.executable, str(INSTALL_HOOK), + '--name=nonexistent-pkg-xyz123', + '--binary-id=test-uuid', + '--machine-id=test-machine', + ], + capture_output=True, + text=True, + timeout=30 + ) + + # Should not say apt is not available + self.assertNotIn('apt not available', result.stderr) + + def test_hook_handles_overrides(self): + """Hook should accept overrides JSON.""" + overrides = json.dumps({ + 'apt': {'packages': ['custom-package-name']} + }) + + result = subprocess.run( + [ + sys.executable, str(INSTALL_HOOK), + '--name=test-pkg', + '--binary-id=test-uuid', + '--machine-id=test-machine', + f'--overrides={overrides}', + ], + capture_output=True, + text=True, + timeout=30 + ) + + # Should not crash parsing overrides + self.assertNotIn('Traceback', result.stderr) + + +class TestAptProviderOutput(TestCase): + """Test JSONL output format from apt provider.""" + + def test_binary_record_format(self): + """Binary JSONL records should have required fields.""" + record = { + 'type': 'Binary', + 'name': 'wget', + 'abspath': '/usr/bin/wget', + 'version': '1.21', + 'binprovider': 'apt', + 'sha256': '', + 'machine_id': 'machine-uuid', + 'binary_id': 'binary-uuid', + } + + self.assertEqual(record['type'], 'Binary') + self.assertEqual(record['binprovider'], 'apt') + self.assertIn('name', record) + self.assertIn('abspath', record) + self.assertIn('version', record) + + +@pytest.mark.skipif(not is_linux(), reason="apt only available on Linux") +@pytest.mark.skipif(not apt_available(), reason="apt not installed") +class TestAptProviderSystemBinaries(TestCase): + """Test apt provider with system binaries.""" + + def test_detect_existing_binary(self): + """apt provider should detect already-installed system binaries.""" + # Check for a binary that's almost certainly installed (like 'ls' or 'bash') + result = subprocess.run( + [ + sys.executable, str(INSTALL_HOOK), + '--name=bash', + '--binary-id=test-uuid', + '--machine-id=test-machine', + ], + capture_output=True, + text=True, + timeout=60 + ) + + # Parse JSONL output + for line in result.stdout.split('\n'): + line = line.strip() + if line.startswith('{'): + try: + record = json.loads(line) + if record.get('type') == 'Binary' and record.get('name') == 'bash': + # Found bash + self.assertTrue(record.get('abspath')) + self.assertTrue(Path(record['abspath']).exists()) + return + except json.JSONDecodeError: + continue + + # apt may not be able to "install" bash (already installed) + # Just verify no crash + self.assertNotIn('Traceback', result.stderr) + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/archivebox/plugins/npm/tests/__init__.py b/archivebox/plugins/npm/tests/__init__.py new file mode 100644 index 00000000..08ccd028 --- /dev/null +++ b/archivebox/plugins/npm/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for the npm binary provider plugin.""" diff --git a/archivebox/plugins/npm/tests/test_npm_provider.py b/archivebox/plugins/npm/tests/test_npm_provider.py new file mode 100644 index 00000000..99057336 --- /dev/null +++ b/archivebox/plugins/npm/tests/test_npm_provider.py @@ -0,0 +1,223 @@ +""" +Tests for the npm binary provider plugin. + +Tests cover: +1. Hook script execution +2. npm package installation +3. PATH and NODE_MODULES_DIR updates +4. JSONL output format +""" + +import json +import os +import shutil +import subprocess +import sys +import tempfile +from pathlib import Path +from unittest.mock import patch + +import pytest +from django.test import TestCase + + +# Get the path to the npm provider hook +PLUGIN_DIR = Path(__file__).parent.parent +INSTALL_HOOK = PLUGIN_DIR / 'on_Binary__install_using_npm_provider.py' + + +def npm_available() -> bool: + """Check if npm is installed.""" + return shutil.which('npm') is not None + + +class TestNpmProviderHook(TestCase): + """Test the npm binary provider installation hook.""" + + def setUp(self): + """Set up test environment.""" + self.temp_dir = tempfile.mkdtemp() + self.lib_dir = Path(self.temp_dir) / 'lib' / 'x86_64-linux' + self.lib_dir.mkdir(parents=True) + + def tearDown(self): + """Clean up.""" + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_hook_script_exists(self): + """Hook script should exist.""" + self.assertTrue(INSTALL_HOOK.exists(), f"Hook not found: {INSTALL_HOOK}") + + def test_hook_requires_lib_dir(self): + """Hook should fail when LIB_DIR is not set.""" + env = os.environ.copy() + env.pop('LIB_DIR', None) # Remove LIB_DIR + + result = subprocess.run( + [ + sys.executable, str(INSTALL_HOOK), + '--name=some-package', + '--binary-id=test-uuid', + '--machine-id=test-machine', + ], + capture_output=True, + text=True, + env=env, + timeout=30 + ) + + self.assertIn('LIB_DIR environment variable not set', result.stderr) + self.assertEqual(result.returncode, 1) + + def test_hook_skips_when_npm_not_allowed(self): + """Hook should skip when npm not in allowed binproviders.""" + env = os.environ.copy() + env['LIB_DIR'] = str(self.lib_dir) + + result = subprocess.run( + [ + sys.executable, str(INSTALL_HOOK), + '--name=some-package', + '--binary-id=test-uuid', + '--machine-id=test-machine', + '--binproviders=pip,apt', # npm not allowed + ], + capture_output=True, + text=True, + env=env, + timeout=30 + ) + + # Should exit cleanly (code 0) when npm not allowed + self.assertIn('npm provider not allowed', result.stderr) + self.assertEqual(result.returncode, 0) + + @pytest.mark.skipif(not npm_available(), reason="npm not installed") + def test_hook_creates_npm_prefix(self): + """Hook should create npm prefix directory.""" + env = os.environ.copy() + env['LIB_DIR'] = str(self.lib_dir) + + # Even if installation fails, the npm prefix should be created + subprocess.run( + [ + sys.executable, str(INSTALL_HOOK), + '--name=nonexistent-xyz123', + '--binary-id=test-uuid', + '--machine-id=test-machine', + ], + capture_output=True, + text=True, + env=env, + timeout=60 + ) + + npm_prefix = self.lib_dir / 'npm' + self.assertTrue(npm_prefix.exists()) + + def test_hook_handles_overrides(self): + """Hook should accept overrides JSON.""" + env = os.environ.copy() + env['LIB_DIR'] = str(self.lib_dir) + + overrides = json.dumps({'npm': {'packages': ['custom-pkg']}}) + + # Just verify it doesn't crash with overrides + result = subprocess.run( + [ + sys.executable, str(INSTALL_HOOK), + '--name=test-pkg', + '--binary-id=test-uuid', + '--machine-id=test-machine', + f'--overrides={overrides}', + ], + capture_output=True, + text=True, + env=env, + timeout=60 + ) + + # May fail to install, but should not crash parsing overrides + self.assertNotIn('Failed to parse overrides JSON', result.stderr) + + +class TestNpmProviderOutput(TestCase): + """Test JSONL output format from npm provider.""" + + def test_binary_record_format(self): + """Binary JSONL records should have required fields.""" + record = { + 'type': 'Binary', + 'name': 'prettier', + 'abspath': '/path/to/node_modules/.bin/prettier', + 'version': '3.0.0', + 'binprovider': 'npm', + 'sha256': '', + 'machine_id': 'machine-uuid', + 'binary_id': 'binary-uuid', + } + + self.assertEqual(record['type'], 'Binary') + self.assertEqual(record['binprovider'], 'npm') + self.assertIn('abspath', record) + + def test_machine_update_record_format(self): + """Machine update records should have correct format.""" + record = { + 'type': 'Machine', + '_method': 'update', + 'key': 'config/PATH', + 'value': '/path/to/npm/bin:/existing/path', + } + + self.assertEqual(record['type'], 'Machine') + self.assertEqual(record['_method'], 'update') + self.assertIn('key', record) + self.assertIn('value', record) + + def test_node_modules_dir_record_format(self): + """NODE_MODULES_DIR update record should have correct format.""" + record = { + 'type': 'Machine', + '_method': 'update', + 'key': 'config/NODE_MODULES_DIR', + 'value': '/path/to/npm/node_modules', + } + + self.assertEqual(record['key'], 'config/NODE_MODULES_DIR') + + +@pytest.mark.skipif(not npm_available(), reason="npm not installed") +class TestNpmProviderIntegration(TestCase): + """Integration tests with real npm installations.""" + + def setUp(self): + """Set up isolated npm environment.""" + self.temp_dir = tempfile.mkdtemp() + self.lib_dir = Path(self.temp_dir) / 'lib' / 'x86_64-linux' + self.lib_dir.mkdir(parents=True) + + def tearDown(self): + """Clean up.""" + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_npm_prefix_structure(self): + """Verify npm creates expected directory structure.""" + npm_prefix = self.lib_dir / 'npm' + npm_prefix.mkdir(parents=True) + + # Expected structure after npm install: + # npm/ + # bin/ (symlinks to binaries) + # node_modules/ (packages) + + expected_dirs = ['bin', 'node_modules'] + for dir_name in expected_dirs: + (npm_prefix / dir_name).mkdir(exist_ok=True) + + for dir_name in expected_dirs: + self.assertTrue((npm_prefix / dir_name).exists()) + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/archivebox/plugins/pip/tests/__init__.py b/archivebox/plugins/pip/tests/__init__.py new file mode 100644 index 00000000..28ac0d82 --- /dev/null +++ b/archivebox/plugins/pip/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for the pip binary provider plugin.""" diff --git a/archivebox/plugins/pip/tests/test_pip_provider.py b/archivebox/plugins/pip/tests/test_pip_provider.py new file mode 100644 index 00000000..3a63f84b --- /dev/null +++ b/archivebox/plugins/pip/tests/test_pip_provider.py @@ -0,0 +1,198 @@ +""" +Tests for the pip binary provider plugin. + +Tests cover: +1. Hook script execution +2. pip package detection +3. Virtual environment handling +4. JSONL output format +""" + +import json +import os +import subprocess +import sys +import tempfile +from pathlib import Path +from unittest.mock import patch, MagicMock + +import pytest +from django.test import TestCase + + +# Get the path to the pip provider hook +PLUGIN_DIR = Path(__file__).parent.parent +INSTALL_HOOK = PLUGIN_DIR / 'on_Binary__install_using_pip_provider.py' + + +class TestPipProviderHook(TestCase): + """Test the pip binary provider installation hook.""" + + def setUp(self): + """Set up test environment.""" + self.temp_dir = tempfile.mkdtemp() + self.output_dir = Path(self.temp_dir) / 'output' + self.output_dir.mkdir() + + def tearDown(self): + """Clean up.""" + import shutil + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_hook_script_exists(self): + """Hook script should exist.""" + self.assertTrue(INSTALL_HOOK.exists(), f"Hook not found: {INSTALL_HOOK}") + + def test_hook_help(self): + """Hook should accept --help without error.""" + result = subprocess.run( + [sys.executable, str(INSTALL_HOOK), '--help'], + capture_output=True, + text=True, + timeout=30 + ) + # May succeed or fail depending on implementation + # At minimum should not crash with Python error + self.assertNotIn('Traceback', result.stderr) + + def test_hook_finds_python(self): + """Hook should find Python binary.""" + env = os.environ.copy() + env['DATA_DIR'] = self.temp_dir + + result = subprocess.run( + [ + sys.executable, str(INSTALL_HOOK), + '--name=python3', + '--binproviders=pip,env', + ], + capture_output=True, + text=True, + cwd=str(self.output_dir), + env=env, + timeout=60 + ) + + # Check for JSONL output + jsonl_found = False + for line in result.stdout.split('\n'): + line = line.strip() + if line.startswith('{'): + try: + record = json.loads(line) + if record.get('type') == 'Binary' and record.get('name') == 'python3': + jsonl_found = True + # Verify structure + self.assertIn('abspath', record) + self.assertIn('version', record) + break + except json.JSONDecodeError: + continue + + # May or may not find python3 via pip, but should not crash + self.assertNotIn('Traceback', result.stderr) + + def test_hook_unknown_package(self): + """Hook should handle unknown packages gracefully.""" + env = os.environ.copy() + env['DATA_DIR'] = self.temp_dir + + result = subprocess.run( + [ + sys.executable, str(INSTALL_HOOK), + '--name=nonexistent_package_xyz123', + '--binproviders=pip', + ], + capture_output=True, + text=True, + cwd=str(self.output_dir), + env=env, + timeout=60 + ) + + # Should not crash + self.assertNotIn('Traceback', result.stderr) + # May have non-zero exit code for missing package + + +class TestPipProviderIntegration(TestCase): + """Integration tests for pip provider with real packages.""" + + def setUp(self): + """Set up test environment.""" + self.temp_dir = tempfile.mkdtemp() + self.output_dir = Path(self.temp_dir) / 'output' + self.output_dir.mkdir() + + def tearDown(self): + """Clean up.""" + import shutil + shutil.rmtree(self.temp_dir, ignore_errors=True) + + @pytest.mark.skipif( + subprocess.run([sys.executable, '-m', 'pip', '--version'], + capture_output=True).returncode != 0, + reason="pip not available" + ) + def test_hook_finds_pip_installed_binary(self): + """Hook should find binaries installed via pip.""" + env = os.environ.copy() + env['DATA_DIR'] = self.temp_dir + + # Try to find 'pip' itself which should be available + result = subprocess.run( + [ + sys.executable, str(INSTALL_HOOK), + '--name=pip', + '--binproviders=pip,env', + ], + capture_output=True, + text=True, + cwd=str(self.output_dir), + env=env, + timeout=60 + ) + + # Look for success in output + for line in result.stdout.split('\n'): + line = line.strip() + if line.startswith('{'): + try: + record = json.loads(line) + if record.get('type') == 'Binary' and 'pip' in record.get('name', ''): + # Found pip binary + self.assertTrue(record.get('abspath')) + return + except json.JSONDecodeError: + continue + + # If we get here without finding pip, that's acceptable + # as long as the hook didn't crash + self.assertNotIn('Traceback', result.stderr) + + +class TestPipProviderOutput(TestCase): + """Test JSONL output format from pip provider.""" + + def test_binary_record_format(self): + """Binary JSONL records should have required fields.""" + # Example of expected format + record = { + 'type': 'Binary', + 'name': 'wget', + 'abspath': '/usr/bin/wget', + 'version': '1.21', + 'binprovider': 'pip', + 'sha256': 'abc123...', + } + + # Validate structure + self.assertEqual(record['type'], 'Binary') + self.assertIn('name', record) + self.assertIn('abspath', record) + self.assertIn('version', record) + self.assertIn('binprovider', record) + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/archivebox/plugins/search_backend_ripgrep/tests/test_ripgrep_search.py b/archivebox/plugins/search_backend_ripgrep/tests/test_ripgrep_search.py new file mode 100644 index 00000000..75513d34 --- /dev/null +++ b/archivebox/plugins/search_backend_ripgrep/tests/test_ripgrep_search.py @@ -0,0 +1,308 @@ +""" +Tests for the ripgrep search backend. + +Tests cover: +1. Search with ripgrep binary +2. Snapshot ID extraction from file paths +3. Timeout handling +4. Error handling +5. Environment variable configuration +""" + +import os +import shutil +import subprocess +import tempfile +from pathlib import Path +from unittest.mock import patch, MagicMock + +import pytest +from django.test import TestCase + +from archivebox.plugins.search_backend_ripgrep.search import ( + search, + flush, + get_env, + get_env_int, + get_env_array, +) + + +class TestEnvHelpers(TestCase): + """Test environment variable helper functions.""" + + def test_get_env_default(self): + """get_env should return default for unset vars.""" + result = get_env('NONEXISTENT_VAR_12345', 'default') + self.assertEqual(result, 'default') + + def test_get_env_set(self): + """get_env should return value for set vars.""" + with patch.dict(os.environ, {'TEST_VAR': 'value'}): + result = get_env('TEST_VAR', 'default') + self.assertEqual(result, 'value') + + def test_get_env_strips_whitespace(self): + """get_env should strip whitespace.""" + with patch.dict(os.environ, {'TEST_VAR': ' value '}): + result = get_env('TEST_VAR', '') + self.assertEqual(result, 'value') + + def test_get_env_int_default(self): + """get_env_int should return default for unset vars.""" + result = get_env_int('NONEXISTENT_VAR_12345', 42) + self.assertEqual(result, 42) + + def test_get_env_int_valid(self): + """get_env_int should parse integer values.""" + with patch.dict(os.environ, {'TEST_INT': '100'}): + result = get_env_int('TEST_INT', 0) + self.assertEqual(result, 100) + + def test_get_env_int_invalid(self): + """get_env_int should return default for invalid integers.""" + with patch.dict(os.environ, {'TEST_INT': 'not a number'}): + result = get_env_int('TEST_INT', 42) + self.assertEqual(result, 42) + + def test_get_env_array_default(self): + """get_env_array should return default for unset vars.""" + result = get_env_array('NONEXISTENT_VAR_12345', ['default']) + self.assertEqual(result, ['default']) + + def test_get_env_array_valid(self): + """get_env_array should parse JSON arrays.""" + with patch.dict(os.environ, {'TEST_ARRAY': '["a", "b", "c"]'}): + result = get_env_array('TEST_ARRAY', []) + self.assertEqual(result, ['a', 'b', 'c']) + + def test_get_env_array_invalid_json(self): + """get_env_array should return default for invalid JSON.""" + with patch.dict(os.environ, {'TEST_ARRAY': 'not json'}): + result = get_env_array('TEST_ARRAY', ['default']) + self.assertEqual(result, ['default']) + + def test_get_env_array_not_array(self): + """get_env_array should return default for non-array JSON.""" + with patch.dict(os.environ, {'TEST_ARRAY': '{"key": "value"}'}): + result = get_env_array('TEST_ARRAY', ['default']) + self.assertEqual(result, ['default']) + + +class TestRipgrepFlush(TestCase): + """Test the flush function.""" + + def test_flush_is_noop(self): + """flush should be a no-op for ripgrep backend.""" + # Should not raise + flush(['snap-001', 'snap-002']) + + +class TestRipgrepSearch(TestCase): + """Test the ripgrep search function.""" + + def setUp(self): + """Create temporary archive directory with test files.""" + self.temp_dir = tempfile.mkdtemp() + self.archive_dir = Path(self.temp_dir) / 'archive' + self.archive_dir.mkdir() + + # Create snapshot directories with searchable content + self._create_snapshot('snap-001', { + 'singlefile/index.html': 'Python programming tutorial', + 'title/title.txt': 'Learn Python Programming', + }) + self._create_snapshot('snap-002', { + 'singlefile/index.html': 'JavaScript guide', + 'title/title.txt': 'JavaScript Basics', + }) + self._create_snapshot('snap-003', { + 'wget/index.html': 'Web archiving best practices', + 'title/title.txt': 'Web Archiving Guide', + }) + + # Patch settings + self.settings_patch = patch( + 'archivebox.plugins.search_backend_ripgrep.search.settings' + ) + self.mock_settings = self.settings_patch.start() + self.mock_settings.ARCHIVE_DIR = str(self.archive_dir) + + def tearDown(self): + """Clean up temporary directory.""" + self.settings_patch.stop() + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def _create_snapshot(self, snapshot_id: str, files: dict): + """Create a snapshot directory with files.""" + snap_dir = self.archive_dir / snapshot_id + for path, content in files.items(): + file_path = snap_dir / path + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_text(content) + + def _has_ripgrep(self) -> bool: + """Check if ripgrep is available.""" + return shutil.which('rg') is not None + + def test_search_no_archive_dir(self): + """search should return empty list when archive dir doesn't exist.""" + self.mock_settings.ARCHIVE_DIR = '/nonexistent/path' + results = search('test') + self.assertEqual(results, []) + + @pytest.mark.skipif(not shutil.which('rg'), reason="ripgrep not installed") + def test_search_single_match(self): + """search should find matching snapshot.""" + results = search('Python programming') + + self.assertIn('snap-001', results) + self.assertNotIn('snap-002', results) + self.assertNotIn('snap-003', results) + + @pytest.mark.skipif(not shutil.which('rg'), reason="ripgrep not installed") + def test_search_multiple_matches(self): + """search should find all matching snapshots.""" + # 'guide' appears in snap-002 (JavaScript guide) and snap-003 (Archiving Guide) + results = search('guide') + + self.assertIn('snap-002', results) + self.assertIn('snap-003', results) + self.assertNotIn('snap-001', results) + + @pytest.mark.skipif(not shutil.which('rg'), reason="ripgrep not installed") + def test_search_case_insensitive_by_default(self): + """search should be case-sensitive (ripgrep default).""" + # By default rg is case-sensitive + results_upper = search('PYTHON') + results_lower = search('python') + + # Depending on ripgrep config, results may differ + self.assertIsInstance(results_upper, list) + self.assertIsInstance(results_lower, list) + + @pytest.mark.skipif(not shutil.which('rg'), reason="ripgrep not installed") + def test_search_no_results(self): + """search should return empty list for no matches.""" + results = search('xyznonexistent123') + self.assertEqual(results, []) + + @pytest.mark.skipif(not shutil.which('rg'), reason="ripgrep not installed") + def test_search_regex(self): + """search should support regex patterns.""" + results = search('(Python|JavaScript)') + + self.assertIn('snap-001', results) + self.assertIn('snap-002', results) + + @pytest.mark.skipif(not shutil.which('rg'), reason="ripgrep not installed") + def test_search_distinct_snapshots(self): + """search should return distinct snapshot IDs.""" + # Query matches both files in snap-001 + results = search('Python') + + # Should only appear once + self.assertEqual(results.count('snap-001'), 1) + + def test_search_missing_binary(self): + """search should raise when ripgrep binary not found.""" + with patch.dict(os.environ, {'RIPGREP_BINARY': '/nonexistent/rg'}): + with patch('shutil.which', return_value=None): + with self.assertRaises(RuntimeError) as context: + search('test') + self.assertIn('ripgrep binary not found', str(context.exception)) + + @pytest.mark.skipif(not shutil.which('rg'), reason="ripgrep not installed") + def test_search_with_custom_args(self): + """search should use custom RIPGREP_ARGS.""" + with patch.dict(os.environ, {'RIPGREP_ARGS': '["-i"]'}): # Case insensitive + results = search('PYTHON') + # With -i flag, should find regardless of case + self.assertIn('snap-001', results) + + @pytest.mark.skipif(not shutil.which('rg'), reason="ripgrep not installed") + def test_search_timeout(self): + """search should handle timeout gracefully.""" + with patch.dict(os.environ, {'RIPGREP_TIMEOUT': '1'}): + # Short timeout, should still complete for small archive + results = search('Python') + self.assertIsInstance(results, list) + + +class TestRipgrepSearchIntegration(TestCase): + """Integration tests with realistic archive structure.""" + + def setUp(self): + """Create archive with realistic structure.""" + self.temp_dir = tempfile.mkdtemp() + self.archive_dir = Path(self.temp_dir) / 'archive' + self.archive_dir.mkdir() + + # Realistic snapshot structure + self._create_snapshot('1704067200.123456', { # 2024-01-01 + 'singlefile.html': ''' + +ArchiveBox Documentation + +

Getting Started with ArchiveBox

+

ArchiveBox is a powerful, self-hosted web archiving tool.

+

Install with: pip install archivebox

+ +''', + 'title/title.txt': 'ArchiveBox Documentation', + 'screenshot/screenshot.png': b'PNG IMAGE DATA', # Binary file + }) + self._create_snapshot('1704153600.654321', { # 2024-01-02 + 'wget/index.html': ''' +Python News + +

Python 3.12 Released

+

New features include improved error messages and performance.

+ +''', + 'readability/content.html': '

Python 3.12 has been released with exciting new features.

', + }) + + self.settings_patch = patch( + 'archivebox.plugins.search_backend_ripgrep.search.settings' + ) + self.mock_settings = self.settings_patch.start() + self.mock_settings.ARCHIVE_DIR = str(self.archive_dir) + + def tearDown(self): + """Clean up.""" + self.settings_patch.stop() + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def _create_snapshot(self, timestamp: str, files: dict): + """Create snapshot with timestamp-based ID.""" + snap_dir = self.archive_dir / timestamp + for path, content in files.items(): + file_path = snap_dir / path + file_path.parent.mkdir(parents=True, exist_ok=True) + if isinstance(content, bytes): + file_path.write_bytes(content) + else: + file_path.write_text(content) + + @pytest.mark.skipif(not shutil.which('rg'), reason="ripgrep not installed") + def test_search_archivebox(self): + """Search for archivebox should find documentation snapshot.""" + results = search('archivebox') + self.assertIn('1704067200.123456', results) + + @pytest.mark.skipif(not shutil.which('rg'), reason="ripgrep not installed") + def test_search_python(self): + """Search for python should find Python news snapshot.""" + results = search('Python') + self.assertIn('1704153600.654321', results) + + @pytest.mark.skipif(not shutil.which('rg'), reason="ripgrep not installed") + def test_search_pip_install(self): + """Search for installation command.""" + results = search('pip install') + self.assertIn('1704067200.123456', results) + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/archivebox/plugins/search_backend_sqlite/tests/__init__.py b/archivebox/plugins/search_backend_sqlite/tests/__init__.py new file mode 100644 index 00000000..6bef82e4 --- /dev/null +++ b/archivebox/plugins/search_backend_sqlite/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for the SQLite FTS5 search backend.""" diff --git a/archivebox/plugins/search_backend_sqlite/tests/test_sqlite_search.py b/archivebox/plugins/search_backend_sqlite/tests/test_sqlite_search.py new file mode 100644 index 00000000..ea12b85f --- /dev/null +++ b/archivebox/plugins/search_backend_sqlite/tests/test_sqlite_search.py @@ -0,0 +1,351 @@ +""" +Tests for the SQLite FTS5 search backend. + +Tests cover: +1. Search index creation +2. Indexing snapshots +3. Search queries with real test data +4. Flush operations +5. Edge cases (empty index, special characters) +""" + +import os +import sqlite3 +import tempfile +from pathlib import Path +from unittest.mock import patch + +import pytest +from django.test import TestCase, override_settings + +from archivebox.plugins.search_backend_sqlite.search import ( + get_db_path, + search, + flush, + SQLITEFTS_DB, + FTS_TOKENIZERS, +) + + +class TestSqliteSearchBackend(TestCase): + """Test SQLite FTS5 search backend.""" + + def setUp(self): + """Create a temporary data directory with search index.""" + self.temp_dir = tempfile.mkdtemp() + self.db_path = Path(self.temp_dir) / 'search.sqlite3' + + # Patch DATA_DIR + self.settings_patch = patch( + 'archivebox.plugins.search_backend_sqlite.search.settings' + ) + self.mock_settings = self.settings_patch.start() + self.mock_settings.DATA_DIR = self.temp_dir + + # Create FTS5 table + self._create_index() + + def tearDown(self): + """Clean up temporary directory.""" + self.settings_patch.stop() + import shutil + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def _create_index(self): + """Create the FTS5 search index table.""" + conn = sqlite3.connect(str(self.db_path)) + try: + conn.execute(f''' + CREATE VIRTUAL TABLE IF NOT EXISTS search_index + USING fts5( + snapshot_id, + url, + title, + content, + tokenize = '{FTS_TOKENIZERS}' + ) + ''') + conn.commit() + finally: + conn.close() + + def _index_snapshot(self, snapshot_id: str, url: str, title: str, content: str): + """Add a snapshot to the index.""" + conn = sqlite3.connect(str(self.db_path)) + try: + conn.execute( + 'INSERT INTO search_index (snapshot_id, url, title, content) VALUES (?, ?, ?, ?)', + (snapshot_id, url, title, content) + ) + conn.commit() + finally: + conn.close() + + def test_get_db_path(self): + """get_db_path should return correct path.""" + path = get_db_path() + self.assertEqual(path, Path(self.temp_dir) / SQLITEFTS_DB) + + def test_search_empty_index(self): + """search should return empty list for empty index.""" + results = search('nonexistent') + self.assertEqual(results, []) + + def test_search_no_index_file(self): + """search should return empty list when index file doesn't exist.""" + os.remove(self.db_path) + results = search('test') + self.assertEqual(results, []) + + def test_search_single_result(self): + """search should find matching snapshot.""" + self._index_snapshot( + 'snap-001', + 'https://example.com/page1', + 'Example Page', + 'This is example content about testing.' + ) + + results = search('example') + self.assertEqual(len(results), 1) + self.assertEqual(results[0], 'snap-001') + + def test_search_multiple_results(self): + """search should find all matching snapshots.""" + self._index_snapshot('snap-001', 'https://example.com/1', 'Python Tutorial', 'Learn Python programming') + self._index_snapshot('snap-002', 'https://example.com/2', 'Python Guide', 'Advanced Python concepts') + self._index_snapshot('snap-003', 'https://example.com/3', 'JavaScript Basics', 'Learn JavaScript') + + results = search('Python') + self.assertEqual(len(results), 2) + self.assertIn('snap-001', results) + self.assertIn('snap-002', results) + self.assertNotIn('snap-003', results) + + def test_search_title_match(self): + """search should match against title.""" + self._index_snapshot('snap-001', 'https://example.com', 'Django Web Framework', 'Content here') + + results = search('Django') + self.assertEqual(len(results), 1) + self.assertEqual(results[0], 'snap-001') + + def test_search_url_match(self): + """search should match against URL.""" + self._index_snapshot('snap-001', 'https://archivebox.io/docs', 'Title', 'Content') + + results = search('archivebox') + self.assertEqual(len(results), 1) + + def test_search_content_match(self): + """search should match against content.""" + self._index_snapshot( + 'snap-001', + 'https://example.com', + 'Generic Title', + 'This document contains information about cryptography and security.' + ) + + results = search('cryptography') + self.assertEqual(len(results), 1) + + def test_search_case_insensitive(self): + """search should be case insensitive.""" + self._index_snapshot('snap-001', 'https://example.com', 'Title', 'PYTHON programming') + + results = search('python') + self.assertEqual(len(results), 1) + + def test_search_stemming(self): + """search should use porter stemmer for word stems.""" + self._index_snapshot('snap-001', 'https://example.com', 'Title', 'Programming concepts') + + # 'program' should match 'programming' with porter stemmer + results = search('program') + self.assertEqual(len(results), 1) + + def test_search_multiple_words(self): + """search should match documents with all words.""" + self._index_snapshot('snap-001', 'https://example.com', 'Web Development', 'Learn web development skills') + self._index_snapshot('snap-002', 'https://example.com', 'Web Design', 'Design beautiful websites') + + results = search('web development') + # FTS5 defaults to OR, so both might match + # With porter stemmer, both should match 'web' + self.assertIn('snap-001', results) + + def test_search_phrase(self): + """search should support phrase queries.""" + self._index_snapshot('snap-001', 'https://example.com', 'Title', 'machine learning algorithms') + self._index_snapshot('snap-002', 'https://example.com', 'Title', 'machine algorithms learning') + + # Phrase search with quotes + results = search('"machine learning"') + self.assertEqual(len(results), 1) + self.assertEqual(results[0], 'snap-001') + + def test_search_distinct_results(self): + """search should return distinct snapshot IDs.""" + # Index same snapshot twice (could happen with multiple fields matching) + self._index_snapshot('snap-001', 'https://python.org', 'Python', 'Python programming language') + + results = search('Python') + self.assertEqual(len(results), 1) + + def test_flush_single(self): + """flush should remove snapshot from index.""" + self._index_snapshot('snap-001', 'https://example.com', 'Title', 'Content') + self._index_snapshot('snap-002', 'https://example.com', 'Title', 'Content') + + flush(['snap-001']) + + results = search('Content') + self.assertEqual(len(results), 1) + self.assertEqual(results[0], 'snap-002') + + def test_flush_multiple(self): + """flush should remove multiple snapshots.""" + self._index_snapshot('snap-001', 'https://example.com', 'Title', 'Test') + self._index_snapshot('snap-002', 'https://example.com', 'Title', 'Test') + self._index_snapshot('snap-003', 'https://example.com', 'Title', 'Test') + + flush(['snap-001', 'snap-003']) + + results = search('Test') + self.assertEqual(len(results), 1) + self.assertEqual(results[0], 'snap-002') + + def test_flush_nonexistent(self): + """flush should not raise for nonexistent snapshots.""" + # Should not raise + flush(['nonexistent-snap']) + + def test_flush_no_index(self): + """flush should not raise when index doesn't exist.""" + os.remove(self.db_path) + # Should not raise + flush(['snap-001']) + + def test_search_special_characters(self): + """search should handle special characters in queries.""" + self._index_snapshot('snap-001', 'https://example.com', 'C++ Programming', 'Learn C++ basics') + + # FTS5 handles special chars + results = search('C++') + # May or may not match depending on tokenizer config + # At minimum, should not raise + self.assertIsInstance(results, list) + + def test_search_unicode(self): + """search should handle unicode content.""" + self._index_snapshot('snap-001', 'https://example.com', 'Titre Francais', 'cafe resume') + self._index_snapshot('snap-002', 'https://example.com', 'Japanese', 'Hello world') + + # With remove_diacritics, 'cafe' should match + results = search('cafe') + self.assertEqual(len(results), 1) + + +class TestSqliteSearchWithRealData(TestCase): + """Integration tests with realistic archived content.""" + + def setUp(self): + """Create index with realistic test data.""" + self.temp_dir = tempfile.mkdtemp() + self.db_path = Path(self.temp_dir) / 'search.sqlite3' + + self.settings_patch = patch( + 'archivebox.plugins.search_backend_sqlite.search.settings' + ) + self.mock_settings = self.settings_patch.start() + self.mock_settings.DATA_DIR = self.temp_dir + + # Create index + conn = sqlite3.connect(str(self.db_path)) + try: + conn.execute(f''' + CREATE VIRTUAL TABLE IF NOT EXISTS search_index + USING fts5( + snapshot_id, + url, + title, + content, + tokenize = '{FTS_TOKENIZERS}' + ) + ''') + # Index realistic data + test_data = [ + ('snap-001', 'https://github.com/ArchiveBox/ArchiveBox', + 'ArchiveBox - Self-hosted web archiving', + 'Open source self-hosted web archiving. Collects, saves, and displays various types of content.'), + ('snap-002', 'https://docs.python.org/3/tutorial/', + 'Python 3 Tutorial', + 'An informal introduction to Python. Python is an easy to learn, powerful programming language.'), + ('snap-003', 'https://developer.mozilla.org/docs/Web/JavaScript', + 'JavaScript - MDN Web Docs', + 'JavaScript (JS) is a lightweight, interpreted programming language with first-class functions.'), + ('snap-004', 'https://news.ycombinator.com', + 'Hacker News', + 'Social news website focusing on computer science and entrepreneurship.'), + ('snap-005', 'https://en.wikipedia.org/wiki/Web_archiving', + 'Web archiving - Wikipedia', + 'Web archiving is the process of collecting portions of the World Wide Web to ensure the information is preserved.'), + ] + conn.executemany( + 'INSERT INTO search_index (snapshot_id, url, title, content) VALUES (?, ?, ?, ?)', + test_data + ) + conn.commit() + finally: + conn.close() + + def tearDown(self): + """Clean up.""" + self.settings_patch.stop() + import shutil + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_search_archivebox(self): + """Search for 'archivebox' should find relevant results.""" + results = search('archivebox') + self.assertIn('snap-001', results) + + def test_search_programming(self): + """Search for 'programming' should find Python and JS docs.""" + results = search('programming') + self.assertIn('snap-002', results) + self.assertIn('snap-003', results) + + def test_search_web_archiving(self): + """Search for 'web archiving' should find relevant results.""" + results = search('web archiving') + # Both ArchiveBox and Wikipedia should match + self.assertIn('snap-001', results) + self.assertIn('snap-005', results) + + def test_search_github(self): + """Search for 'github' should find URL match.""" + results = search('github') + self.assertIn('snap-001', results) + + def test_search_tutorial(self): + """Search for 'tutorial' should find Python tutorial.""" + results = search('tutorial') + self.assertIn('snap-002', results) + + def test_flush_and_search(self): + """Flushing a snapshot should remove it from search results.""" + # Verify it's there first + results = search('archivebox') + self.assertIn('snap-001', results) + + # Flush it + flush(['snap-001']) + + # Should no longer be found + results = search('archivebox') + self.assertNotIn('snap-001', results) + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/archivebox/workers/tests/__init__.py b/archivebox/workers/tests/__init__.py new file mode 100644 index 00000000..f798b10f --- /dev/null +++ b/archivebox/workers/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for the workers module (Orchestrator, Worker, pid_utils).""" diff --git a/archivebox/workers/tests/test_orchestrator.py b/archivebox/workers/tests/test_orchestrator.py new file mode 100644 index 00000000..033ac087 --- /dev/null +++ b/archivebox/workers/tests/test_orchestrator.py @@ -0,0 +1,364 @@ +""" +Unit tests for the Orchestrator and Worker classes. + +Tests cover: +1. Orchestrator lifecycle (startup, shutdown) +2. Queue polling and worker spawning +3. Idle detection and exit logic +4. Worker registration and management +5. PID file utilities +""" + +import os +import tempfile +import time +import signal +from pathlib import Path +from unittest.mock import patch, MagicMock + +import pytest +from django.test import TestCase, override_settings + +from archivebox.workers.pid_utils import ( + get_pid_dir, + write_pid_file, + read_pid_file, + remove_pid_file, + is_process_alive, + get_all_pid_files, + get_all_worker_pids, + cleanup_stale_pid_files, + get_running_worker_count, + get_next_worker_id, + stop_worker, +) +from archivebox.workers.orchestrator import Orchestrator + + +class TestPidUtils(TestCase): + """Test PID file utility functions.""" + + def setUp(self): + """Create a temporary directory for PID files.""" + self.temp_dir = tempfile.mkdtemp() + self.pid_dir_patch = patch( + 'archivebox.workers.pid_utils.get_pid_dir', + return_value=Path(self.temp_dir) + ) + self.pid_dir_patch.start() + + def tearDown(self): + """Clean up temporary directory.""" + self.pid_dir_patch.stop() + import shutil + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_write_pid_file_orchestrator(self): + """write_pid_file should create orchestrator.pid for orchestrator.""" + pid_file = write_pid_file('orchestrator') + + self.assertTrue(pid_file.exists()) + self.assertEqual(pid_file.name, 'orchestrator.pid') + + content = pid_file.read_text().strip().split('\n') + self.assertEqual(int(content[0]), os.getpid()) + self.assertEqual(content[1], 'orchestrator') + + def test_write_pid_file_worker(self): + """write_pid_file should create numbered pid file for workers.""" + pid_file = write_pid_file('snapshot', worker_id=3) + + self.assertTrue(pid_file.exists()) + self.assertEqual(pid_file.name, 'snapshot_worker_3.pid') + + def test_write_pid_file_with_extractor(self): + """write_pid_file should include extractor in content.""" + pid_file = write_pid_file('archiveresult', worker_id=0, extractor='singlefile') + + content = pid_file.read_text().strip().split('\n') + self.assertEqual(content[2], 'singlefile') + + def test_read_pid_file_valid(self): + """read_pid_file should parse valid PID files.""" + pid_file = write_pid_file('snapshot', worker_id=1) + info = read_pid_file(pid_file) + + self.assertIsNotNone(info) + self.assertEqual(info['pid'], os.getpid()) + self.assertEqual(info['worker_type'], 'snapshot') + self.assertEqual(info['pid_file'], pid_file) + self.assertIsNotNone(info['started_at']) + + def test_read_pid_file_invalid(self): + """read_pid_file should return None for invalid files.""" + invalid_file = Path(self.temp_dir) / 'invalid.pid' + invalid_file.write_text('not valid') + + info = read_pid_file(invalid_file) + self.assertIsNone(info) + + def test_read_pid_file_nonexistent(self): + """read_pid_file should return None for nonexistent files.""" + info = read_pid_file(Path(self.temp_dir) / 'nonexistent.pid') + self.assertIsNone(info) + + def test_remove_pid_file(self): + """remove_pid_file should delete the file.""" + pid_file = write_pid_file('test', worker_id=0) + self.assertTrue(pid_file.exists()) + + remove_pid_file(pid_file) + self.assertFalse(pid_file.exists()) + + def test_remove_pid_file_nonexistent(self): + """remove_pid_file should not raise for nonexistent files.""" + # Should not raise + remove_pid_file(Path(self.temp_dir) / 'nonexistent.pid') + + def test_is_process_alive_current(self): + """is_process_alive should return True for current process.""" + self.assertTrue(is_process_alive(os.getpid())) + + def test_is_process_alive_dead(self): + """is_process_alive should return False for dead processes.""" + # PID 999999 is unlikely to exist + self.assertFalse(is_process_alive(999999)) + + def test_get_all_pid_files(self): + """get_all_pid_files should return all .pid files.""" + write_pid_file('orchestrator') + write_pid_file('snapshot', worker_id=0) + write_pid_file('crawl', worker_id=1) + + files = get_all_pid_files() + self.assertEqual(len(files), 3) + + def test_get_all_worker_pids(self): + """get_all_worker_pids should return info for live workers.""" + write_pid_file('snapshot', worker_id=0) + write_pid_file('crawl', worker_id=1) + + workers = get_all_worker_pids() + # All should be alive since they're current process PID + self.assertEqual(len(workers), 2) + + def test_get_all_worker_pids_filtered(self): + """get_all_worker_pids should filter by worker type.""" + write_pid_file('snapshot', worker_id=0) + write_pid_file('snapshot', worker_id=1) + write_pid_file('crawl', worker_id=0) + + snapshot_workers = get_all_worker_pids('snapshot') + self.assertEqual(len(snapshot_workers), 2) + + crawl_workers = get_all_worker_pids('crawl') + self.assertEqual(len(crawl_workers), 1) + + def test_cleanup_stale_pid_files(self): + """cleanup_stale_pid_files should remove files for dead processes.""" + # Create a PID file with a dead PID + stale_file = Path(self.temp_dir) / 'stale_worker_0.pid' + stale_file.write_text('999999\nstale\n\n2024-01-01T00:00:00+00:00\n') + + # Create a valid PID file (current process) + write_pid_file('valid', worker_id=0) + + removed = cleanup_stale_pid_files() + + self.assertEqual(removed, 1) + self.assertFalse(stale_file.exists()) + + def test_get_running_worker_count(self): + """get_running_worker_count should count workers of a type.""" + write_pid_file('snapshot', worker_id=0) + write_pid_file('snapshot', worker_id=1) + write_pid_file('crawl', worker_id=0) + + self.assertEqual(get_running_worker_count('snapshot'), 2) + self.assertEqual(get_running_worker_count('crawl'), 1) + self.assertEqual(get_running_worker_count('archiveresult'), 0) + + def test_get_next_worker_id(self): + """get_next_worker_id should find lowest unused ID.""" + write_pid_file('snapshot', worker_id=0) + write_pid_file('snapshot', worker_id=1) + write_pid_file('snapshot', worker_id=3) # Skip 2 + + next_id = get_next_worker_id('snapshot') + self.assertEqual(next_id, 2) + + def test_get_next_worker_id_empty(self): + """get_next_worker_id should return 0 if no workers exist.""" + next_id = get_next_worker_id('snapshot') + self.assertEqual(next_id, 0) + + +class TestOrchestratorUnit(TestCase): + """Unit tests for Orchestrator class (mocked dependencies).""" + + def test_orchestrator_creation(self): + """Orchestrator should initialize with correct defaults.""" + orchestrator = Orchestrator(exit_on_idle=True) + + self.assertTrue(orchestrator.exit_on_idle) + self.assertEqual(orchestrator.idle_count, 0) + self.assertIsNone(orchestrator.pid_file) + + def test_orchestrator_repr(self): + """Orchestrator __repr__ should include PID.""" + orchestrator = Orchestrator() + repr_str = repr(orchestrator) + + self.assertIn('Orchestrator', repr_str) + self.assertIn(str(os.getpid()), repr_str) + + def test_has_pending_work(self): + """has_pending_work should check if any queue has items.""" + orchestrator = Orchestrator() + + self.assertFalse(orchestrator.has_pending_work({'crawl': 0, 'snapshot': 0})) + self.assertTrue(orchestrator.has_pending_work({'crawl': 0, 'snapshot': 5})) + self.assertTrue(orchestrator.has_pending_work({'crawl': 10, 'snapshot': 0})) + + def test_should_exit_not_exit_on_idle(self): + """should_exit should return False when exit_on_idle is False.""" + orchestrator = Orchestrator(exit_on_idle=False) + orchestrator.idle_count = 100 + + self.assertFalse(orchestrator.should_exit({'crawl': 0})) + + def test_should_exit_pending_work(self): + """should_exit should return False when there's pending work.""" + orchestrator = Orchestrator(exit_on_idle=True) + orchestrator.idle_count = 100 + + self.assertFalse(orchestrator.should_exit({'crawl': 5})) + + @patch.object(Orchestrator, 'has_running_workers') + def test_should_exit_running_workers(self, mock_has_workers): + """should_exit should return False when workers are running.""" + mock_has_workers.return_value = True + orchestrator = Orchestrator(exit_on_idle=True) + orchestrator.idle_count = 100 + + self.assertFalse(orchestrator.should_exit({'crawl': 0})) + + @patch.object(Orchestrator, 'has_running_workers') + @patch.object(Orchestrator, 'has_future_work') + def test_should_exit_idle_timeout(self, mock_future, mock_workers): + """should_exit should return True after idle timeout with no work.""" + mock_workers.return_value = False + mock_future.return_value = False + + orchestrator = Orchestrator(exit_on_idle=True) + orchestrator.idle_count = orchestrator.IDLE_TIMEOUT + + self.assertTrue(orchestrator.should_exit({'crawl': 0, 'snapshot': 0})) + + @patch.object(Orchestrator, 'has_running_workers') + @patch.object(Orchestrator, 'has_future_work') + def test_should_exit_below_idle_timeout(self, mock_future, mock_workers): + """should_exit should return False below idle timeout.""" + mock_workers.return_value = False + mock_future.return_value = False + + orchestrator = Orchestrator(exit_on_idle=True) + orchestrator.idle_count = orchestrator.IDLE_TIMEOUT - 1 + + self.assertFalse(orchestrator.should_exit({'crawl': 0})) + + def test_should_spawn_worker_no_queue(self): + """should_spawn_worker should return False when queue is empty.""" + orchestrator = Orchestrator() + + # Create a mock worker class + mock_worker = MagicMock() + mock_worker.get_running_workers.return_value = [] + + self.assertFalse(orchestrator.should_spawn_worker(mock_worker, 0)) + + def test_should_spawn_worker_at_limit(self): + """should_spawn_worker should return False when at per-type limit.""" + orchestrator = Orchestrator() + + mock_worker = MagicMock() + mock_worker.get_running_workers.return_value = [{}] * orchestrator.MAX_WORKERS_PER_TYPE + + self.assertFalse(orchestrator.should_spawn_worker(mock_worker, 10)) + + @patch.object(Orchestrator, 'get_total_worker_count') + def test_should_spawn_worker_at_total_limit(self, mock_total): + """should_spawn_worker should return False when at total limit.""" + orchestrator = Orchestrator() + mock_total.return_value = orchestrator.MAX_TOTAL_WORKERS + + mock_worker = MagicMock() + mock_worker.get_running_workers.return_value = [] + + self.assertFalse(orchestrator.should_spawn_worker(mock_worker, 10)) + + @patch.object(Orchestrator, 'get_total_worker_count') + def test_should_spawn_worker_success(self, mock_total): + """should_spawn_worker should return True when conditions are met.""" + orchestrator = Orchestrator() + mock_total.return_value = 0 + + mock_worker = MagicMock() + mock_worker.get_running_workers.return_value = [] + mock_worker.MAX_CONCURRENT_TASKS = 5 + + self.assertTrue(orchestrator.should_spawn_worker(mock_worker, 10)) + + @patch.object(Orchestrator, 'get_total_worker_count') + def test_should_spawn_worker_enough_workers(self, mock_total): + """should_spawn_worker should return False when enough workers for queue.""" + orchestrator = Orchestrator() + mock_total.return_value = 2 + + mock_worker = MagicMock() + mock_worker.get_running_workers.return_value = [{}] # 1 worker running + mock_worker.MAX_CONCURRENT_TASKS = 5 # Can handle 5 items + + # Queue size (3) <= running_workers (1) * MAX_CONCURRENT_TASKS (5) + self.assertFalse(orchestrator.should_spawn_worker(mock_worker, 3)) + + +class TestOrchestratorIsRunning(TestCase): + """Test Orchestrator.is_running() class method.""" + + def setUp(self): + """Create a temporary directory for PID files.""" + self.temp_dir = tempfile.mkdtemp() + self.pid_dir_patch = patch( + 'archivebox.workers.pid_utils.get_pid_dir', + return_value=Path(self.temp_dir) + ) + self.pid_dir_patch.start() + + def tearDown(self): + """Clean up.""" + self.pid_dir_patch.stop() + import shutil + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_is_running_no_pid_file(self): + """is_running should return False when no orchestrator PID file.""" + self.assertFalse(Orchestrator.is_running()) + + def test_is_running_with_live_orchestrator(self): + """is_running should return True when orchestrator PID file exists.""" + write_pid_file('orchestrator') + self.assertTrue(Orchestrator.is_running()) + + def test_is_running_with_dead_orchestrator(self): + """is_running should return False when orchestrator process is dead.""" + # Create a PID file with a dead PID + pid_file = Path(self.temp_dir) / 'orchestrator.pid' + pid_file.write_text('999999\norchestrator\n\n2024-01-01T00:00:00+00:00\n') + + # The get_all_worker_pids filters out dead processes + self.assertFalse(Orchestrator.is_running()) + + +if __name__ == '__main__': + pytest.main([__file__, '-v'])