From 438384425a96f3f69b4620e25945fd241e617c4a Mon Sep 17 00:00:00 2001
From: ed
Date: Thu, 16 Jun 2022 01:07:15 +0200
Subject: [PATCH] add types, isort, errorhandling
---
README.md | 9 +-
bin/mtag/image-noexif.py | 1 -
bin/up2k.py | 22 +-
contrib/systemd/copyparty.service | 2 +-
copyparty/__init__.py | 31 +-
copyparty/__main__.py | 100 ++--
copyparty/authsrv.py | 574 +++++++++++---------
copyparty/bos/bos.py | 33 +-
copyparty/bos/path.py | 21 +-
copyparty/broker_mp.py | 62 ++-
copyparty/broker_mpw.py | 61 ++-
copyparty/broker_thr.py | 57 +-
copyparty/broker_util.py | 41 +-
copyparty/ftpd.py | 194 ++++---
copyparty/httpcli.py | 696 +++++++++++++++----------
copyparty/httpconn.py | 94 ++--
copyparty/httpsrv.py | 107 ++--
copyparty/ico.py | 24 +-
copyparty/mtag.py | 142 ++---
copyparty/star.py | 54 +-
copyparty/stolen/surrogateescape.py | 23 +-
copyparty/sutil.py | 23 +-
copyparty/svchub.py | 113 ++--
copyparty/szip.py | 77 ++-
copyparty/tcpsrv.py | 75 +--
copyparty/th_cli.py | 28 +-
copyparty/th_srv.py | 120 +++--
copyparty/u2idx.py | 89 ++--
copyparty/up2k.py | 781 ++++++++++++++++------------
copyparty/util.py | 513 ++++++++++--------
scripts/make-pypi-release.sh | 9 +
scripts/make-sfx.sh | 60 ++-
scripts/run-tests.sh | 12 +-
scripts/sfx.py | 13 +-
scripts/sfx.sh | 2 +-
scripts/strip_hints/a.py | 57 ++
scripts/test/race.py | 8 +-
scripts/test/smoketest.py | 2 +
tests/test_vfs.py | 12 +-
tests/util.py | 5 +-
40 files changed, 2597 insertions(+), 1750 deletions(-)
create mode 100644 scripts/strip_hints/a.py
diff --git a/README.md b/README.md
index 38c60727..6493f0d2 100644
--- a/README.md
+++ b/README.md
@@ -1200,15 +1200,18 @@ journalctl -aS '48 hour ago' -u copyparty | grep -C10 FILENAME | tee bug.log
## dev env setup
-mostly optional; if you need a working env for vscode or similar
+you need python 3.9 or newer due to type hints
+
+the rest is mostly optional; if you need a working env for vscode or similar
```sh
python3 -m venv .venv
. .venv/bin/activate
-pip install jinja2 # mandatory
+pip install jinja2 strip_hints # MANDATORY
pip install mutagen # audio metadata
+pip install pyftpdlib # ftp server
pip install Pillow pyheif-pillow-opener pillow-avif-plugin # thumbnails
-pip install black==21.12b0 bandit pylint flake8 # vscode tooling
+pip install black==21.12b0 click==8.0.2 bandit pylint flake8 isort mypy # vscode tooling
```
diff --git a/bin/mtag/image-noexif.py b/bin/mtag/image-noexif.py
index d5392b00..bc009c17 100644
--- a/bin/mtag/image-noexif.py
+++ b/bin/mtag/image-noexif.py
@@ -43,7 +43,6 @@ PS: this requires e2ts to be functional,
import os
import sys
-import time
import filecmp
import subprocess as sp
diff --git a/bin/up2k.py b/bin/up2k.py
index 8cf904c4..23d35ea3 100755
--- a/bin/up2k.py
+++ b/bin/up2k.py
@@ -77,15 +77,15 @@ class File(object):
self.up_b = 0 # type: int
self.up_c = 0 # type: int
- # m = "size({}) lmod({}) top({}) rel({}) abs({}) name({})\n"
- # eprint(m.format(self.size, self.lmod, self.top, self.rel, self.abs, self.name))
+ # t = "size({}) lmod({}) top({}) rel({}) abs({}) name({})\n"
+ # eprint(t.format(self.size, self.lmod, self.top, self.rel, self.abs, self.name))
class FileSlice(object):
"""file-like object providing a fixed window into a file"""
def __init__(self, file, cid):
- # type: (File, str) -> FileSlice
+ # type: (File, str) -> None
self.car, self.len = file.kchunks[cid]
self.cdr = self.car + self.len
@@ -216,8 +216,8 @@ class CTermsize(object):
eprint("\033[s\033[r\033[u")
else:
self.g = 1 + self.h - margin
- m = "{0}\033[{1}A".format("\n" * margin, margin)
- eprint("{0}\033[s\033[1;{1}r\033[u".format(m, self.g - 1))
+ t = "{0}\033[{1}A".format("\n" * margin, margin)
+ eprint("{0}\033[s\033[1;{1}r\033[u".format(t, self.g - 1))
ss = CTermsize()
@@ -597,8 +597,8 @@ class Ctl(object):
if "/" in name:
name = "\033[36m{0}\033[0m/{1}".format(*name.rsplit("/", 1))
- m = "{0:6.1f}% {1} {2}\033[K"
- txt += m.format(p, self.nfiles - f, name)
+ t = "{0:6.1f}% {1} {2}\033[K"
+ txt += t.format(p, self.nfiles - f, name)
txt += "\033[{0}H ".format(ss.g + 2)
else:
@@ -618,8 +618,8 @@ class Ctl(object):
nleft = self.nfiles - self.up_f
tail = "\033[K\033[u" if VT100 else "\r"
- m = "{0} eta @ {1}/s, {2}, {3}# left".format(eta, spd, sleft, nleft)
- eprint(txt + "\033]0;{0}\033\\\r{0}{1}".format(m, tail))
+ t = "{0} eta @ {1}/s, {2}, {3}# left".format(eta, spd, sleft, nleft)
+ eprint(txt + "\033]0;{0}\033\\\r{0}{1}".format(t, tail))
def cleanup_vt100(self):
ss.scroll_region(None)
@@ -721,8 +721,8 @@ class Ctl(object):
if search:
if hs:
for hit in hs:
- m = "found: {0}\n {1}{2}\n"
- print(m.format(upath, burl, hit["rp"]), end="")
+ t = "found: {0}\n {1}{2}\n"
+ print(t.format(upath, burl, hit["rp"]), end="")
else:
print("NOT found: {0}\n".format(upath), end="")
diff --git a/contrib/systemd/copyparty.service b/contrib/systemd/copyparty.service
index e8e02f90..fd3f9efc 100644
--- a/contrib/systemd/copyparty.service
+++ b/contrib/systemd/copyparty.service
@@ -4,7 +4,7 @@
# installation:
# cp -pv copyparty.service /etc/systemd/system
# restorecon -vr /etc/systemd/system/copyparty.service
-# firewall-cmd --permanent --add-port={80,443,3923}/tcp
+# firewall-cmd --permanent --add-port={80,443,3923}/tcp # --zone=libvirt
# firewall-cmd --reload
# systemctl daemon-reload && systemctl enable --now copyparty
#
diff --git a/copyparty/__init__.py b/copyparty/__init__.py
index 17513708..594363eb 100644
--- a/copyparty/__init__.py
+++ b/copyparty/__init__.py
@@ -1,21 +1,30 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import platform
-import time
-import sys
import os
+import platform
+import sys
+import time
+
+try:
+ from collections.abc import Callable
+
+ from typing import TYPE_CHECKING, Any
+except:
+ TYPE_CHECKING = False
PY2 = sys.version_info[0] == 2
if PY2:
sys.dont_write_bytecode = True
- unicode = unicode
+ unicode = unicode # noqa: F821 # pylint: disable=undefined-variable,self-assigning-variable
else:
unicode = str
-WINDOWS = False
-if platform.system() == "Windows":
- WINDOWS = [int(x) for x in platform.version().split(".")]
+WINDOWS: Any = (
+ [int(x) for x in platform.version().split(".")]
+ if platform.system() == "Windows"
+ else False
+)
VT100 = not WINDOWS or WINDOWS >= [10, 0, 14393]
# introduced in anniversary update
@@ -25,8 +34,8 @@ ANYWIN = WINDOWS or sys.platform in ["msys"]
MACOS = platform.system() == "Darwin"
-def get_unixdir():
- paths = [
+def get_unixdir() -> str:
+ paths: list[tuple[Callable[..., str], str]] = [
(os.environ.get, "XDG_CONFIG_HOME"),
(os.path.expanduser, "~/.config"),
(os.environ.get, "TMPDIR"),
@@ -43,7 +52,7 @@ def get_unixdir():
continue
p = os.path.normpath(p)
- chk(p)
+ chk(p) # type: ignore
p = os.path.join(p, "copyparty")
if not os.path.isdir(p):
os.mkdir(p)
@@ -56,7 +65,7 @@ def get_unixdir():
class EnvParams(object):
- def __init__(self):
+ def __init__(self) -> None:
self.t0 = time.time()
self.mod = os.path.dirname(os.path.realpath(__file__))
if self.mod.endswith("__init__"):
diff --git a/copyparty/__main__.py b/copyparty/__main__.py
index 87203f87..5412a6c2 100644
--- a/copyparty/__main__.py
+++ b/copyparty/__main__.py
@@ -8,35 +8,42 @@ __copyright__ = 2019
__license__ = "MIT"
__url__ = "https://github.com/9001/copyparty/"
-import re
-import os
-import sys
-import time
-import shutil
+import argparse
import filecmp
import locale
-import argparse
+import os
+import re
+import shutil
+import sys
import threading
+import time
import traceback
from textwrap import dedent
-from .__init__ import E, WINDOWS, ANYWIN, VT100, PY2, unicode
-from .__version__ import S_VERSION, S_BUILD_DT, CODENAME
-from .svchub import SvcHub
-from .util import py_desc, align_tab, IMPLICATIONS, ansi_re, min_ex
+from .__init__ import ANYWIN, PY2, VT100, WINDOWS, E, unicode
+from .__version__ import CODENAME, S_BUILD_DT, S_VERSION
from .authsrv import re_vol
+from .svchub import SvcHub
+from .util import IMPLICATIONS, align_tab, ansi_re, min_ex, py_desc
-HAVE_SSL = True
try:
+ from types import FrameType
+
+ from typing import Any, Optional
+except:
+ pass
+
+try:
+ HAVE_SSL = True
import ssl
except:
HAVE_SSL = False
-printed = []
+printed: list[str] = []
class RiceFormatter(argparse.HelpFormatter):
- def _get_help_string(self, action):
+ def _get_help_string(self, action: argparse.Action) -> str:
"""
same as ArgumentDefaultsHelpFormatter(HelpFormatter)
except the help += [...] line now has colors
@@ -45,27 +52,27 @@ class RiceFormatter(argparse.HelpFormatter):
if not VT100:
fmt = " (default: %(default)s)"
- ret = action.help
- if "%(default)" not in action.help:
+ ret = str(action.help)
+ if "%(default)" not in ret:
if action.default is not argparse.SUPPRESS:
defaulting_nargs = [argparse.OPTIONAL, argparse.ZERO_OR_MORE]
if action.option_strings or action.nargs in defaulting_nargs:
ret += fmt
return ret
- def _fill_text(self, text, width, indent):
+ def _fill_text(self, text: str, width: int, indent: str) -> str:
"""same as RawDescriptionHelpFormatter(HelpFormatter)"""
return "".join(indent + line + "\n" for line in text.splitlines())
class Dodge11874(RiceFormatter):
- def __init__(self, *args, **kwargs):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs["width"] = 9003
super(Dodge11874, self).__init__(*args, **kwargs)
-def lprint(*a, **ka):
- txt = " ".join(unicode(x) for x in a) + ka.get("end", "\n")
+def lprint(*a: Any, **ka: Any) -> None:
+ txt: str = " ".join(unicode(x) for x in a) + ka.get("end", "\n")
printed.append(txt)
if not VT100:
txt = ansi_re.sub("", txt)
@@ -73,11 +80,11 @@ def lprint(*a, **ka):
print(txt, **ka)
-def warn(msg):
+def warn(msg: str) -> None:
lprint("\033[1mwarning:\033[0;33m {}\033[0m\n".format(msg))
-def ensure_locale():
+def ensure_locale() -> None:
for x in [
"en_US.UTF-8",
"English_United States.UTF8",
@@ -91,7 +98,7 @@ def ensure_locale():
continue
-def ensure_cert():
+def ensure_cert() -> None:
"""
the default cert (and the entire TLS support) is only here to enable the
crypto.subtle javascript API, which is necessary due to the webkit guys
@@ -117,8 +124,8 @@ def ensure_cert():
# printf 'NO\n.\n.\n.\n.\ncopyparty-insecure\n.\n' | faketime '2000-01-01 00:00:00' openssl req -x509 -sha256 -newkey rsa:2048 -keyout insecure.pem -out insecure.pem -days $((($(printf %d 0x7fffffff)-$(date +%s --date=2000-01-01T00:00:00Z))/(60*60*24))) -nodes && ls -al insecure.pem && openssl x509 -in insecure.pem -text -noout
-def configure_ssl_ver(al):
- def terse_sslver(txt):
+def configure_ssl_ver(al: argparse.Namespace) -> None:
+ def terse_sslver(txt: str) -> str:
txt = txt.lower()
for c in ["_", "v", "."]:
txt = txt.replace(c, "")
@@ -133,8 +140,8 @@ def configure_ssl_ver(al):
flags = [k for k in ssl.__dict__ if ptn.match(k)]
# SSLv2 SSLv3 TLSv1 TLSv1_1 TLSv1_2 TLSv1_3
if "help" in sslver:
- avail = [terse_sslver(x[6:]) for x in flags]
- avail = " ".join(sorted(avail) + ["all"])
+ avail1 = [terse_sslver(x[6:]) for x in flags]
+ avail = " ".join(sorted(avail1) + ["all"])
lprint("\navailable ssl/tls versions:\n " + avail)
sys.exit(0)
@@ -160,7 +167,7 @@ def configure_ssl_ver(al):
# think i need that beer now
-def configure_ssl_ciphers(al):
+def configure_ssl_ciphers(al: argparse.Namespace) -> None:
ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
if al.ssl_ver:
ctx.options &= ~al.ssl_flags_en
@@ -184,8 +191,8 @@ def configure_ssl_ciphers(al):
sys.exit(0)
-def args_from_cfg(cfg_path):
- ret = []
+def args_from_cfg(cfg_path: str) -> list[str]:
+ ret: list[str] = []
skip = False
with open(cfg_path, "rb") as f:
for ln in [x.decode("utf-8").strip() for x in f]:
@@ -210,29 +217,30 @@ def args_from_cfg(cfg_path):
return ret
-def sighandler(sig=None, frame=None):
+def sighandler(sig: Optional[int] = None, frame: Optional[FrameType] = None) -> None:
msg = [""] * 5
for th in threading.enumerate():
+ stk = sys._current_frames()[th.ident] # type: ignore
msg.append(str(th))
- msg.extend(traceback.format_stack(sys._current_frames()[th.ident]))
+ msg.extend(traceback.format_stack(stk))
msg.append("\n")
print("\n".join(msg))
-def disable_quickedit():
- import ctypes
+def disable_quickedit() -> None:
import atexit
+ import ctypes
from ctypes import wintypes
- def ecb(ok, fun, args):
+ def ecb(ok: bool, fun: Any, args: list[Any]) -> list[Any]:
if not ok:
- err = ctypes.get_last_error()
+ err: int = ctypes.get_last_error() # type: ignore
if err:
- raise ctypes.WinError(err)
+ raise ctypes.WinError(err) # type: ignore
return args
- k32 = ctypes.WinDLL("kernel32", use_last_error=True)
+ k32 = ctypes.WinDLL("kernel32", use_last_error=True) # type: ignore
if PY2:
wintypes.LPDWORD = ctypes.POINTER(wintypes.DWORD)
@@ -242,14 +250,14 @@ def disable_quickedit():
k32.GetConsoleMode.argtypes = (wintypes.HANDLE, wintypes.LPDWORD)
k32.SetConsoleMode.argtypes = (wintypes.HANDLE, wintypes.DWORD)
- def cmode(out, mode=None):
+ def cmode(out: bool, mode: Optional[int] = None) -> int:
h = k32.GetStdHandle(-11 if out else -10)
if mode:
- return k32.SetConsoleMode(h, mode)
+ return k32.SetConsoleMode(h, mode) # type: ignore
- mode = wintypes.DWORD()
- k32.GetConsoleMode(h, ctypes.byref(mode))
- return mode.value
+ cmode = wintypes.DWORD()
+ k32.GetConsoleMode(h, ctypes.byref(cmode))
+ return cmode.value
# disable quickedit
mode = orig_in = cmode(False)
@@ -268,7 +276,7 @@ def disable_quickedit():
cmode(True, mode | 4)
-def run_argparse(argv, formatter):
+def run_argparse(argv: list[str], formatter: Any) -> argparse.Namespace:
ap = argparse.ArgumentParser(
formatter_class=formatter,
prog="copyparty",
@@ -596,7 +604,7 @@ def run_argparse(argv, formatter):
return ret
-def main(argv=None):
+def main(argv: Optional[list[str]] = None) -> None:
time.strptime("19970815", "%Y%m%d") # python#7980
if WINDOWS:
os.system("rem") # enables colors
@@ -618,7 +626,7 @@ def main(argv=None):
supp = args_from_cfg(v)
argv.extend(supp)
- deprecated = []
+ deprecated: list[tuple[str, str]] = []
for dk, nk in deprecated:
try:
idx = argv.index(dk)
@@ -650,7 +658,7 @@ def main(argv=None):
if not VT100:
al.wintitle = ""
- nstrs = []
+ nstrs: list[str] = []
anymod = False
for ostr in al.v or []:
m = re_vol.match(ostr)
diff --git a/copyparty/authsrv.py b/copyparty/authsrv.py
index 4ca406ee..92612529 100644
--- a/copyparty/authsrv.py
+++ b/copyparty/authsrv.py
@@ -1,44 +1,68 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import re
-import os
-import sys
-import stat
-import time
+import argparse
import base64
import hashlib
+import os
+import re
+import stat
+import sys
import threading
+import time
from datetime import datetime
-from .__init__ import ANYWIN, WINDOWS
+from .__init__ import ANYWIN, TYPE_CHECKING, WINDOWS
+from .bos import bos
from .util import (
IMPLICATIONS,
META_NOBOTS,
+ Pebkac,
+ absreal,
+ fsenc,
+ relchk,
+ statdir,
uncyg,
undot,
- relchk,
unhumanize,
- absreal,
- Pebkac,
- fsenc,
- statdir,
)
-from .bos import bos
+
+try:
+ from collections.abc import Iterable
+
+ import typing
+ from typing import Any, Generator, Optional, Union
+
+ from .util import RootLogger
+except:
+ pass
+
+if TYPE_CHECKING:
+ pass
+ # Vflags: TypeAlias = dict[str, str | bool | float | list[str]]
+ # Vflags: TypeAlias = dict[str, Any]
+ # Mflags: TypeAlias = dict[str, Vflags]
LEELOO_DALLAS = "leeloo_dallas"
class AXS(object):
- def __init__(self, uread=None, uwrite=None, umove=None, udel=None, uget=None):
- self.uread = {} if uread is None else {k: 1 for k in uread}
- self.uwrite = {} if uwrite is None else {k: 1 for k in uwrite}
- self.umove = {} if umove is None else {k: 1 for k in umove}
- self.udel = {} if udel is None else {k: 1 for k in udel}
- self.uget = {} if uget is None else {k: 1 for k in uget}
+ def __init__(
+ self,
+ uread: Optional[Union[list[str], set[str]]] = None,
+ uwrite: Optional[Union[list[str], set[str]]] = None,
+ umove: Optional[Union[list[str], set[str]]] = None,
+ udel: Optional[Union[list[str], set[str]]] = None,
+ uget: Optional[Union[list[str], set[str]]] = None,
+ ) -> None:
+ self.uread: set[str] = set(uread or [])
+ self.uwrite: set[str] = set(uwrite or [])
+ self.umove: set[str] = set(umove or [])
+ self.udel: set[str] = set(udel or [])
+ self.uget: set[str] = set(uget or [])
- def __repr__(self):
+ def __repr__(self) -> str:
return "AXS({})".format(
", ".join(
"{}={!r}".format(k, self.__dict__[k])
@@ -48,33 +72,33 @@ class AXS(object):
class Lim(object):
- def __init__(self):
- self.nups = {} # num tracker
- self.bups = {} # byte tracker list
- self.bupc = {} # byte tracker cache
+ def __init__(self) -> None:
+ self.nups: dict[str, list[float]] = {} # num tracker
+ self.bups: dict[str, list[tuple[float, int]]] = {} # byte tracker list
+ self.bupc: dict[str, int] = {} # byte tracker cache
self.nosub = False # disallow subdirectories
- self.smin = None # filesize min
- self.smax = None # filesize max
+ self.smin = -1 # filesize min
+ self.smax = -1 # filesize max
- self.bwin = None # bytes window
- self.bmax = None # bytes max
- self.nwin = None # num window
- self.nmax = None # num max
+ self.bwin = 0 # bytes window
+ self.bmax = 0 # bytes max
+ self.nwin = 0 # num window
+ self.nmax = 0 # num max
- self.rotn = None # rot num files
- self.rotl = None # rot depth
- self.rotf = None # rot datefmt
- self.rot_re = None # rotf check
+ self.rotn = 0 # rot num files
+ self.rotl = 0 # rot depth
+ self.rotf = "" # rot datefmt
+ self.rot_re = re.compile("") # rotf check
- def set_rotf(self, fmt):
+ def set_rotf(self, fmt: str) -> None:
self.rotf = fmt
r = re.escape(fmt).replace("%Y", "[0-9]{4}").replace("%j", "[0-9]{3}")
r = re.sub("%[mdHMSWU]", "[0-9]{2}", r)
self.rot_re = re.compile("(^|/)" + r + "$")
- def all(self, ip, rem, sz, abspath):
+ def all(self, ip: str, rem: str, sz: float, abspath: str) -> tuple[str, str]:
self.chk_nup(ip)
self.chk_bup(ip)
self.chk_rem(rem)
@@ -87,18 +111,18 @@ class Lim(object):
return ap2, ("{}/{}".format(rem, vp2) if rem else vp2)
- def chk_sz(self, sz):
- if self.smin is not None and sz < self.smin:
+ def chk_sz(self, sz: float) -> None:
+ if self.smin != -1 and sz < self.smin:
raise Pebkac(400, "file too small")
- if self.smax is not None and sz > self.smax:
+ if self.smax != -1 and sz > self.smax:
raise Pebkac(400, "file too big")
- def chk_rem(self, rem):
+ def chk_rem(self, rem: str) -> None:
if self.nosub and rem:
raise Pebkac(500, "no subdirectories allowed")
- def rot(self, path):
+ def rot(self, path: str) -> tuple[str, str]:
if not self.rotf and not self.rotn:
return path, ""
@@ -120,7 +144,7 @@ class Lim(object):
d = ret[len(path) :].strip("/\\").replace("\\", "/")
return ret, d
- def dive(self, path, lvs):
+ def dive(self, path: str, lvs: int) -> Optional[str]:
items = bos.listdir(path)
if not lvs:
@@ -155,14 +179,14 @@ class Lim(object):
return os.path.join(sub, ret)
- def nup(self, ip):
+ def nup(self, ip: str) -> None:
try:
self.nups[ip].append(time.time())
except:
self.nups[ip] = [time.time()]
- def bup(self, ip, nbytes):
- v = [time.time(), nbytes]
+ def bup(self, ip: str, nbytes: int) -> None:
+ v = (time.time(), nbytes)
try:
self.bups[ip].append(v)
self.bupc[ip] += nbytes
@@ -170,7 +194,7 @@ class Lim(object):
self.bups[ip] = [v]
self.bupc[ip] = nbytes
- def chk_nup(self, ip):
+ def chk_nup(self, ip: str) -> None:
if not self.nmax or ip not in self.nups:
return
@@ -182,7 +206,7 @@ class Lim(object):
if len(nups) >= self.nmax:
raise Pebkac(429, "too many uploads")
- def chk_bup(self, ip):
+ def chk_bup(self, ip: str) -> None:
if not self.bmax or ip not in self.bups:
return
@@ -200,35 +224,37 @@ class Lim(object):
class VFS(object):
"""single level in the virtual fs"""
- def __init__(self, log, realpath, vpath, axs, flags):
+ def __init__(
+ self,
+ log: Optional[RootLogger],
+ realpath: str,
+ vpath: str,
+ axs: AXS,
+ flags: dict[str, Any],
+ ) -> None:
self.log = log
self.realpath = realpath # absolute path on host filesystem
self.vpath = vpath # absolute path in the virtual filesystem
- self.axs = axs # type: AXS
+ self.axs = axs
self.flags = flags # config options
- self.nodes = {} # child nodes
- self.histtab = None # all realpath->histpath
- self.dbv = None # closest full/non-jump parent
- self.lim = None # type: Lim # upload limits; only set for dbv
+ self.nodes: dict[str, VFS] = {} # child nodes
+ self.histtab: dict[str, str] = {} # all realpath->histpath
+ self.dbv: Optional[VFS] = None # closest full/non-jump parent
+ self.lim: Optional[Lim] = None # upload limits; only set for dbv
+ self.aread: dict[str, list[str]] = {}
+ self.awrite: dict[str, list[str]] = {}
+ self.amove: dict[str, list[str]] = {}
+ self.adel: dict[str, list[str]] = {}
+ self.aget: dict[str, list[str]] = {}
if realpath:
self.histpath = os.path.join(realpath, ".hist") # db / thumbcache
self.all_vols = {vpath: self} # flattened recursive
- self.aread = {}
- self.awrite = {}
- self.amove = {}
- self.adel = {}
- self.aget = {}
else:
- self.histpath = None
- self.all_vols = None
- self.aread = None
- self.awrite = None
- self.amove = None
- self.adel = None
- self.aget = None
+ self.histpath = ""
+ self.all_vols = {}
- def __repr__(self):
+ def __repr__(self) -> str:
return "VFS({})".format(
", ".join(
"{}={!r}".format(k, self.__dict__[k])
@@ -236,14 +262,14 @@ class VFS(object):
)
)
- def get_all_vols(self, outdict):
+ def get_all_vols(self, outdict: dict[str, "VFS"]) -> None:
if self.realpath:
outdict[self.vpath] = self
for v in self.nodes.values():
v.get_all_vols(outdict)
- def add(self, src, dst):
+ def add(self, src: str, dst: str) -> "VFS":
"""get existing, or add new path to the vfs"""
assert not src.endswith("/") # nosec
assert not dst.endswith("/") # nosec
@@ -257,7 +283,7 @@ class VFS(object):
vn = VFS(
self.log,
- os.path.join(self.realpath, name) if self.realpath else None,
+ os.path.join(self.realpath, name) if self.realpath else "",
"{}/{}".format(self.vpath, name).lstrip("/"),
self.axs,
self._copy_flags(name),
@@ -277,7 +303,7 @@ class VFS(object):
self.nodes[dst] = vn
return vn
- def _copy_flags(self, name):
+ def _copy_flags(self, name: str) -> dict[str, Any]:
flags = {k: v for k, v in self.flags.items()}
hist = flags.get("hist")
if hist and hist != "-":
@@ -285,20 +311,20 @@ class VFS(object):
return flags
- def bubble_flags(self):
+ def bubble_flags(self) -> None:
if self.dbv:
for k, v in self.dbv.flags.items():
if k not in ["hist"]:
self.flags[k] = v
- for v in self.nodes.values():
- v.bubble_flags()
+ for n in self.nodes.values():
+ n.bubble_flags()
- def _find(self, vpath):
+ def _find(self, vpath: str) -> tuple["VFS", str]:
"""return [vfs,remainder]"""
vpath = undot(vpath)
if vpath == "":
- return [self, ""]
+ return self, ""
if "/" in vpath:
name, rem = vpath.split("/", 1)
@@ -309,66 +335,64 @@ class VFS(object):
if name in self.nodes:
return self.nodes[name]._find(rem)
- return [self, vpath]
+ return self, vpath
- def can_access(self, vpath, uname):
- # type: (str, str) -> tuple[bool, bool, bool, bool]
+ def can_access(self, vpath: str, uname: str) -> tuple[bool, bool, bool, bool, bool]:
"""can Read,Write,Move,Delete,Get"""
vn, _ = self._find(vpath)
c = vn.axs
- return [
+ return (
uname in c.uread or "*" in c.uread,
uname in c.uwrite or "*" in c.uwrite,
uname in c.umove or "*" in c.umove,
uname in c.udel or "*" in c.udel,
uname in c.uget or "*" in c.uget,
- ]
+ )
def get(
self,
- vpath,
- uname,
- will_read,
- will_write,
- will_move=False,
- will_del=False,
- will_get=False,
- ):
- # type: (str, str, bool, bool, bool, bool, bool) -> tuple[VFS, str]
+ vpath: str,
+ uname: str,
+ will_read: bool,
+ will_write: bool,
+ will_move: bool = False,
+ will_del: bool = False,
+ will_get: bool = False,
+ ) -> tuple["VFS", str]:
"""returns [vfsnode,fs_remainder] if user has the requested permissions"""
if ANYWIN:
mod = relchk(vpath)
if mod:
- self.log("vfs", "invalid relpath [{}]".format(vpath))
+ if self.log:
+ self.log("vfs", "invalid relpath [{}]".format(vpath))
raise Pebkac(404)
vn, rem = self._find(vpath)
- c = vn.axs
+ c: AXS = vn.axs
for req, d, msg in [
- [will_read, c.uread, "read"],
- [will_write, c.uwrite, "write"],
- [will_move, c.umove, "move"],
- [will_del, c.udel, "delete"],
- [will_get, c.uget, "get"],
+ (will_read, c.uread, "read"),
+ (will_write, c.uwrite, "write"),
+ (will_move, c.umove, "move"),
+ (will_del, c.udel, "delete"),
+ (will_get, c.uget, "get"),
]:
if req and (uname not in d and "*" not in d) and uname != LEELOO_DALLAS:
- m = "you don't have {}-access for this location"
- raise Pebkac(403, m.format(msg))
+ t = "you don't have {}-access for this location"
+ raise Pebkac(403, t.format(msg))
return vn, rem
- def get_dbv(self, vrem):
- # type: (str) -> tuple[VFS, str]
+ def get_dbv(self, vrem: str) -> tuple["VFS", str]:
dbv = self.dbv
if not dbv:
return self, vrem
- vrem = [self.vpath[len(dbv.vpath) :].lstrip("/"), vrem]
- vrem = "/".join([x for x in vrem if x])
+ tv = [self.vpath[len(dbv.vpath) :].lstrip("/"), vrem]
+ vrem = "/".join([x for x in tv if x])
return dbv, vrem
- def canonical(self, rem, resolve=True):
+ def canonical(self, rem: str, resolve: bool = True) -> str:
"""returns the canonical path (fully-resolved absolute fs path)"""
rp = self.realpath
if rem:
@@ -376,8 +400,14 @@ class VFS(object):
return absreal(rp) if resolve else rp
- def ls(self, rem, uname, scandir, permsets, lstat=False):
- # type: (str, str, bool, list[list[bool]], bool) -> tuple[str, str, dict[str, VFS]]
+ def ls(
+ self,
+ rem: str,
+ uname: str,
+ scandir: bool,
+ permsets: list[list[bool]],
+ lstat: bool = False,
+ ) -> tuple[str, list[tuple[str, os.stat_result]], dict[str, "VFS"]]:
"""return user-readable [fsdir,real,virt] items at vpath"""
virt_vis = {} # nodes readable by user
abspath = self.canonical(rem)
@@ -389,8 +419,8 @@ class VFS(object):
for name, vn2 in sorted(self.nodes.items()):
ok = False
- axs = vn2.axs
- axs = [axs.uread, axs.uwrite, axs.umove, axs.udel, axs.uget]
+ zx = vn2.axs
+ axs = [zx.uread, zx.uwrite, zx.umove, zx.udel, zx.uget]
for pset in permsets:
ok = True
for req, lst in zip(pset, axs):
@@ -409,9 +439,32 @@ class VFS(object):
elif "/.hist/th/" in p:
real = [x for x in real if not x[0].endswith("dir.txt")]
- return [abspath, real, virt_vis]
+ return abspath, real, virt_vis
- def walk(self, rel, rem, seen, uname, permsets, dots, scandir, lstat, subvols=True):
+ def walk(
+ self,
+ rel: str,
+ rem: str,
+ seen: list[str],
+ uname: str,
+ permsets: list[list[bool]],
+ dots: bool,
+ scandir: bool,
+ lstat: bool,
+ subvols: bool = True,
+ ) -> Generator[
+ tuple[
+ "VFS",
+ str,
+ str,
+ str,
+ list[tuple[str, os.stat_result]],
+ list[tuple[str, os.stat_result]],
+ dict[str, "VFS"],
+ ],
+ None,
+ None,
+ ]:
"""
recursively yields from ./rem;
rel is a unix-style user-defined vpath (not vfs-related)
@@ -425,8 +478,9 @@ class VFS(object):
and (not fsroot.startswith(seen[-1]) or fsroot == seen[-1])
and fsroot in seen
):
- m = "bailing from symlink loop,\n prev: {}\n curr: {}\n from: {}/{}"
- self.log("vfs.walk", m.format(seen[-1], fsroot, self.vpath, rem), 3)
+ if self.log:
+ t = "bailing from symlink loop,\n prev: {}\n curr: {}\n from: {}/{}"
+ self.log("vfs.walk", t.format(seen[-1], fsroot, self.vpath, rem), 3)
return
seen = seen[:] + [fsroot]
@@ -460,9 +514,9 @@ class VFS(object):
for x in vfs.walk(wrel, "", seen, uname, permsets, dots, scandir, lstat):
yield x
- def zipgen(self, vrem, flt, uname, dots, scandir):
- if flt:
- flt = {k: True for k in flt}
+ def zipgen(
+ self, vrem: str, flt: set[str], uname: str, dots: bool, scandir: bool
+ ) -> Generator[dict[str, Any], None, None]:
# if multiselect: add all items to archive root
# if single folder: the folder itself is the top-level item
@@ -473,33 +527,33 @@ class VFS(object):
if flt:
files = [x for x in files if x[0] in flt]
- rm = [x for x in rd if x[0] not in flt]
- [rd.remove(x) for x in rm]
+ rm1 = [x for x in rd if x[0] not in flt]
+ _ = [rd.remove(x) for x in rm1] # type: ignore
- rm = [x for x in vd.keys() if x not in flt]
- [vd.pop(x) for x in rm]
+ rm2 = [x for x in vd.keys() if x not in flt]
+ _ = [vd.pop(x) for x in rm2]
- flt = None
+ flt = set()
# print(repr([vpath, apath, [x[0] for x in files]]))
fnames = [n[0] for n in files]
vpaths = [vpath + "/" + n for n in fnames] if vpath else fnames
apaths = [os.path.join(apath, n) for n in fnames]
- files = list(zip(vpaths, apaths, files))
+ ret = list(zip(vpaths, apaths, files))
if not dots:
# dotfile filtering based on vpath (intended visibility)
- files = [x for x in files if "/." not in "/" + x[0]]
+ ret = [x for x in ret if "/." not in "/" + x[0]]
- rm = [x for x in rd if x[0].startswith(".")]
- for x in rm:
- rd.remove(x)
+ zel = [ze for ze in rd if ze[0].startswith(".")]
+ for ze in zel:
+ rd.remove(ze)
- rm = [k for k in vd.keys() if k.startswith(".")]
- for x in rm:
- del vd[x]
+ zsl = [zs for zs in vd.keys() if zs.startswith(".")]
+ for zs in zsl:
+ del vd[zs]
- for f in [{"vp": v, "ap": a, "st": n[1]} for v, a, n in files]:
+ for f in [{"vp": v, "ap": a, "st": n[1]} for v, a, n in ret]:
yield f
@@ -512,7 +566,12 @@ else:
class AuthSrv(object):
"""verifies users against given paths"""
- def __init__(self, args, log_func, warn_anonwrite=True):
+ def __init__(
+ self,
+ args: argparse.Namespace,
+ log_func: Optional[RootLogger],
+ warn_anonwrite: bool = True,
+ ) -> None:
self.args = args
self.log_func = log_func
self.warn_anonwrite = warn_anonwrite
@@ -521,11 +580,11 @@ class AuthSrv(object):
self.mutex = threading.Lock()
self.reload()
- def log(self, msg, c=0):
+ def log(self, msg: str, c: Union[int, str] = 0) -> None:
if self.log_func:
self.log_func("auth", msg, c)
- def laggy_iter(self, iterable):
+ def laggy_iter(self, iterable: Iterable[Any]) -> Generator[Any, None, None]:
"""returns [value,isFinalValue]"""
it = iter(iterable)
prev = next(it)
@@ -535,26 +594,39 @@ class AuthSrv(object):
yield prev, True
- def _map_volume(self, src, dst, mount, daxs, mflags):
+ def _map_volume(
+ self,
+ src: str,
+ dst: str,
+ mount: dict[str, str],
+ daxs: dict[str, AXS],
+ mflags: dict[str, dict[str, Any]],
+ ) -> None:
if dst in mount:
- m = "multiple filesystem-paths mounted at [/{}]:\n [{}]\n [{}]"
- self.log(m.format(dst, mount[dst], src), c=1)
+ t = "multiple filesystem-paths mounted at [/{}]:\n [{}]\n [{}]"
+ self.log(t.format(dst, mount[dst], src), c=1)
raise Exception("invalid config")
if src in mount.values():
- m = "warning: filesystem-path [{}] mounted in multiple locations:"
- m = m.format(src)
+ t = "warning: filesystem-path [{}] mounted in multiple locations:"
+ t = t.format(src)
for v in [k for k, v in mount.items() if v == src] + [dst]:
- m += "\n /{}".format(v)
+ t += "\n /{}".format(v)
- self.log(m, c=3)
+ self.log(t, c=3)
mount[dst] = src
daxs[dst] = AXS()
mflags[dst] = {}
- def _parse_config_file(self, fd, acct, daxs, mflags, mount):
- # type: (any, str, dict[str, AXS], any, str) -> None
+ def _parse_config_file(
+ self,
+ fd: typing.BinaryIO,
+ acct: dict[str, str],
+ daxs: dict[str, AXS],
+ mflags: dict[str, dict[str, Any]],
+ mount: dict[str, str],
+ ) -> None:
skip = False
vol_src = None
vol_dst = None
@@ -601,23 +673,25 @@ class AuthSrv(object):
uname = "*"
if lvl == "a":
- m = "WARNING (config-file): permission flag 'a' is deprecated; please use 'rw' instead"
- self.log(m, 1)
+ t = "WARNING (config-file): permission flag 'a' is deprecated; please use 'rw' instead"
+ self.log(t, 1)
self._read_vol_str(lvl, uname, daxs[vol_dst], mflags[vol_dst])
- def _read_vol_str(self, lvl, uname, axs, flags):
- # type: (str, str, AXS, any) -> None
+ def _read_vol_str(
+ self, lvl: str, uname: str, axs: AXS, flags: dict[str, Any]
+ ) -> None:
if lvl.strip("crwmdg"):
raise Exception("invalid volume flag: {},{}".format(lvl, uname))
if lvl == "c":
+ cval: Union[bool, str] = True
try:
# volume flag with arguments, possibly with a preceding list of bools
uname, cval = uname.split("=", 1)
except:
# just one or more bools
- cval = True
+ pass
while "," in uname:
# one or more bools before the final flag; eat them
@@ -631,34 +705,38 @@ class AuthSrv(object):
uname = "*"
for un in uname.replace(",", " ").strip().split():
- if "r" in lvl:
- axs.uread[un] = 1
+ for ch, al in [
+ ("r", axs.uread),
+ ("w", axs.uwrite),
+ ("m", axs.umove),
+ ("d", axs.udel),
+ ("g", axs.uget),
+ ]:
+ if ch in lvl:
+ al.add(un)
- if "w" in lvl:
- axs.uwrite[un] = 1
-
- if "m" in lvl:
- axs.umove[un] = 1
-
- if "d" in lvl:
- axs.udel[un] = 1
-
- if "g" in lvl:
- axs.uget[un] = 1
-
- def _read_volflag(self, flags, name, value, is_list):
+ def _read_volflag(
+ self,
+ flags: dict[str, Any],
+ name: str,
+ value: Union[str, bool, list[str]],
+ is_list: bool,
+ ) -> None:
if name not in ["mtp"]:
flags[name] = value
return
- if not is_list:
- value = [value]
- elif not value:
+ vals = flags.get(name, [])
+ if not value:
return
+ elif is_list:
+ vals += value
+ else:
+ vals += [value]
- flags[name] = flags.get(name, []) + value
+ flags[name] = vals
- def reload(self):
+ def reload(self) -> None:
"""
construct a flat list of mountpoints and usernames
first from the commandline arguments
@@ -666,10 +744,10 @@ class AuthSrv(object):
before finally building the VFS
"""
- acct = {} # username:password
- daxs = {} # type: dict[str, AXS]
- mflags = {} # mountpoint:[flag]
- mount = {} # dst:src (mountpoint:realpath)
+ acct: dict[str, str] = {} # username:password
+ daxs: dict[str, AXS] = {}
+ mflags: dict[str, dict[str, Any]] = {} # mountpoint:flags
+ mount: dict[str, str] = {} # dst:src (mountpoint:realpath)
if self.args.a:
# list of username:password
@@ -678,8 +756,8 @@ class AuthSrv(object):
u, p = x.split(":", 1)
acct[u] = p
except:
- m = '\n invalid value "{}" for argument -a, must be username:password'
- raise Exception(m.format(x))
+ t = '\n invalid value "{}" for argument -a, must be username:password'
+ raise Exception(t.format(x))
if self.args.v:
# list of src:dst:permset:permset:...
@@ -708,8 +786,8 @@ class AuthSrv(object):
try:
self._parse_config_file(f, acct, daxs, mflags, mount)
except:
- m = "\n\033[1;31m\nerror in config file {} on line {}:\n\033[0m"
- self.log(m.format(cfg_fn, self.line_ctr), 1)
+ t = "\n\033[1;31m\nerror in config file {} on line {}:\n\033[0m"
+ self.log(t.format(cfg_fn, self.line_ctr), 1)
raise
# case-insensitive; normalize
@@ -726,7 +804,7 @@ class AuthSrv(object):
vfs = VFS(self.log_func, bos.path.abspath("."), "", axs, {})
elif "" not in mount:
# there's volumes but no root; make root inaccessible
- vfs = VFS(self.log_func, None, "", AXS(), {})
+ vfs = VFS(self.log_func, "", "", AXS(), {})
vfs.flags["d2d"] = True
maxdepth = 0
@@ -740,10 +818,10 @@ class AuthSrv(object):
vfs = VFS(self.log_func, mount[dst], dst, daxs[dst], mflags[dst])
continue
- v = vfs.add(mount[dst], dst)
- v.axs = daxs[dst]
- v.flags = mflags[dst]
- v.dbv = None
+ zv = vfs.add(mount[dst], dst)
+ zv.axs = daxs[dst]
+ zv.flags = mflags[dst]
+ zv.dbv = None
vfs.all_vols = {}
vfs.get_all_vols(vfs.all_vols)
@@ -751,11 +829,11 @@ class AuthSrv(object):
for perm in "read write move del get".split():
axs_key = "u" + perm
unames = ["*"] + list(acct.keys())
- umap = {x: [] for x in unames}
+ umap: dict[str, list[str]] = {x: [] for x in unames}
for usr in unames:
for vp, vol in vfs.all_vols.items():
- axs = getattr(vol.axs, axs_key)
- if usr in axs or "*" in axs:
+ zx = getattr(vol.axs, axs_key)
+ if usr in zx or "*" in zx:
umap[usr].append(vp)
umap[usr].sort()
setattr(vfs, "a" + perm, umap)
@@ -764,7 +842,7 @@ class AuthSrv(object):
missing_users = {}
for axs in daxs.values():
for d in [axs.uread, axs.uwrite, axs.umove, axs.udel, axs.uget]:
- for usr in d.keys():
+ for usr in d:
all_users[usr] = 1
if usr != "*" and usr not in acct:
missing_users[usr] = 1
@@ -783,8 +861,8 @@ class AuthSrv(object):
promote = []
demote = []
for vol in vfs.all_vols.values():
- hid = hashlib.sha512(fsenc(vol.realpath)).digest()
- hid = base64.b32encode(hid).decode("ascii").lower()
+ zb = hashlib.sha512(fsenc(vol.realpath)).digest()
+ hid = base64.b32encode(zb).decode("ascii").lower()
vflag = vol.flags.get("hist")
if vflag == "-":
pass
@@ -822,21 +900,21 @@ class AuthSrv(object):
demote.append(vol)
# discard jump-vols
- for v in demote:
- vfs.all_vols.pop(v.vpath)
+ for zv in demote:
+ vfs.all_vols.pop(zv.vpath)
if promote:
- msg = [
+ ta = [
"\n the following jump-volumes were generated to assist the vfs.\n As they contain a database (probably from v0.11.11 or older),\n they are promoted to full volumes:"
]
for vol in promote:
- msg.append(
+ ta.append(
" /{} ({}) ({})".format(vol.vpath, vol.realpath, vol.histpath)
)
- self.log("\n\n".join(msg) + "\n", c=3)
+ self.log("\n\n".join(ta) + "\n", c=3)
- vfs.histtab = {v.realpath: v.histpath for v in vfs.all_vols.values()}
+ vfs.histtab = {zv.realpath: zv.histpath for zv in vfs.all_vols.values()}
for vol in vfs.all_vols.values():
lim = Lim()
@@ -846,30 +924,30 @@ class AuthSrv(object):
use = True
lim.nosub = True
- v = vol.flags.get("sz")
- if v:
+ zs = vol.flags.get("sz")
+ if zs:
use = True
- lim.smin, lim.smax = [unhumanize(x) for x in v.split("-")]
+ lim.smin, lim.smax = [unhumanize(x) for x in zs.split("-")]
- v = vol.flags.get("rotn")
- if v:
+ zs = vol.flags.get("rotn")
+ if zs:
use = True
- lim.rotn, lim.rotl = [int(x) for x in v.split(",")]
+ lim.rotn, lim.rotl = [int(x) for x in zs.split(",")]
- v = vol.flags.get("rotf")
- if v:
+ zs = vol.flags.get("rotf")
+ if zs:
use = True
- lim.set_rotf(v)
+ lim.set_rotf(zs)
- v = vol.flags.get("maxn")
- if v:
+ zs = vol.flags.get("maxn")
+ if zs:
use = True
- lim.nmax, lim.nwin = [int(x) for x in v.split(",")]
+ lim.nmax, lim.nwin = [int(x) for x in zs.split(",")]
- v = vol.flags.get("maxb")
- if v:
+ zs = vol.flags.get("maxb")
+ if zs:
use = True
- lim.bmax, lim.bwin = [unhumanize(x) for x in v.split(",")]
+ lim.bmax, lim.bwin = [unhumanize(x) for x in zs.split(",")]
if use:
vol.lim = lim
@@ -1005,8 +1083,8 @@ class AuthSrv(object):
for mtp in local_only_mtp:
if mtp not in local_mte:
- m = 'volume "/{}" defines metadata tag "{}", but doesnt use it in "-mte" (or with "cmte" in its volume-flags)'
- self.log(m.format(vol.vpath, mtp), 1)
+ t = 'volume "/{}" defines metadata tag "{}", but doesnt use it in "-mte" (or with "cmte" in its volume-flags)'
+ self.log(t.format(vol.vpath, mtp), 1)
errors = True
tags = self.args.mtp or []
@@ -1014,8 +1092,8 @@ class AuthSrv(object):
tags = [y for x in tags for y in x.split(",")]
for mtp in tags:
if mtp not in all_mte:
- m = 'metadata tag "{}" is defined by "-mtm" or "-mtp", but is not used by "-mte" (or by any "cmte" volume-flag)'
- self.log(m.format(mtp), 1)
+ t = 'metadata tag "{}" is defined by "-mtm" or "-mtp", but is not used by "-mte" (or by any "cmte" volume-flag)'
+ self.log(t.format(mtp), 1)
errors = True
if errors:
@@ -1023,12 +1101,12 @@ class AuthSrv(object):
vfs.bubble_flags()
- m = "volumes and permissions:\n"
- for v in vfs.all_vols.values():
+ t = "volumes and permissions:\n"
+ for zv in vfs.all_vols.values():
if not self.warn_anonwrite:
break
- m += '\n\033[36m"/{}" \033[33m{}\033[0m'.format(v.vpath, v.realpath)
+ t += '\n\033[36m"/{}" \033[33m{}\033[0m'.format(zv.vpath, zv.realpath)
for txt, attr in [
[" read", "uread"],
[" write", "uwrite"],
@@ -1036,21 +1114,21 @@ class AuthSrv(object):
["delete", "udel"],
[" get", "uget"],
]:
- u = list(sorted(getattr(v.axs, attr).keys()))
+ u = list(sorted(getattr(zv.axs, attr)))
u = ", ".join("\033[35meverybody\033[0m" if x == "*" else x for x in u)
u = u if u else "\033[36m--none--\033[0m"
- m += "\n| {}: {}".format(txt, u)
- m += "\n"
+ t += "\n| {}: {}".format(txt, u)
+ t += "\n"
if self.warn_anonwrite and not self.args.no_voldump:
- self.log(m)
+ self.log(t)
try:
- v, _ = vfs.get("/", "*", False, True)
- if self.warn_anonwrite and os.getcwd() == v.realpath:
+ zv, _ = vfs.get("/", "*", False, True)
+ if self.warn_anonwrite and os.getcwd() == zv.realpath:
self.warn_anonwrite = False
- msg = "anyone can read/write the current directory: {}\n"
- self.log(msg.format(v.realpath), c=1)
+ t = "anyone can read/write the current directory: {}\n"
+ self.log(t.format(zv.realpath), c=1)
except Pebkac:
self.warn_anonwrite = True
@@ -1064,19 +1142,19 @@ class AuthSrv(object):
if pwds:
self.re_pwd = re.compile("=(" + "|".join(pwds) + ")([]&; ]|$)")
- def dbg_ls(self):
+ def dbg_ls(self) -> None:
users = self.args.ls
- vols = "*"
- flags = []
+ vol = "*"
+ flags: list[str] = []
try:
- users, vols = users.split(",", 1)
+ users, vol = users.split(",", 1)
except:
pass
try:
- vols, flags = vols.split(",", 1)
- flags = flags.split(",")
+ vol, zf = vol.split(",", 1)
+ flags = zf.split(",")
except:
pass
@@ -1089,23 +1167,23 @@ class AuthSrv(object):
if u not in self.acct and u != "*":
raise Exception("user not found: " + u)
- if vols == "*":
+ if vol == "*":
vols = ["/" + x for x in self.vfs.all_vols]
else:
- vols = [vols]
+ vols = [vol]
- for v in vols:
- if not v.startswith("/"):
+ for zs in vols:
+ if not zs.startswith("/"):
raise Exception("volumes must start with /")
- if v[1:] not in self.vfs.all_vols:
- raise Exception("volume not found: " + v)
+ if zs[1:] not in self.vfs.all_vols:
+ raise Exception("volume not found: " + zs)
- self.log({"users": users, "vols": vols, "flags": flags})
- m = "/{}: read({}) write({}) move({}) del({}) get({})"
- for k, v in self.vfs.all_vols.items():
- vc = v.axs
- self.log(m.format(k, vc.uread, vc.uwrite, vc.umove, vc.udel, vc.uget))
+ self.log(str({"users": users, "vols": vols, "flags": flags}))
+ t = "/{}: read({}) write({}) move({}) del({}) get({})"
+ for k, zv in self.vfs.all_vols.items():
+ vc = zv.axs
+ self.log(t.format(k, vc.uread, vc.uwrite, vc.umove, vc.udel, vc.uget))
flag_v = "v" in flags
flag_ln = "ln" in flags
@@ -1136,12 +1214,14 @@ class AuthSrv(object):
False,
False,
)
- for _, _, vpath, apath, files, dirs, _ in g:
- fnames = [n[0] for n in files]
- vpaths = [vpath + "/" + n for n in fnames] if vpath else fnames
- vpaths = [vtop + x for x in vpaths]
+ for _, _, vpath, apath, files1, dirs, _ in g:
+ fnames = [n[0] for n in files1]
+ zsl = [vpath + "/" + n for n in fnames] if vpath else fnames
+ vpaths = [vtop + x for x in zsl]
apaths = [os.path.join(apath, n) for n in fnames]
- files = [[vpath + "/", apath + os.sep]] + list(zip(vpaths, apaths))
+ files = [(vpath + "/", apath + os.sep)] + list(
+ [(zs1, zs2) for zs1, zs2 in zip(vpaths, apaths)]
+ )
if flag_ln:
files = [x for x in files if not x[1].startswith(safeabs)]
@@ -1152,21 +1232,23 @@ class AuthSrv(object):
if not files:
continue
elif flag_v:
- msg = [""] + [
+ ta = [""] + [
'# user "{}", vpath "{}"\n{}'.format(u, vp, ap)
for vp, ap in files
]
else:
- msg = ["user {}, vol {}: {} =>".format(u, vtop, files[0][0])]
- msg += [x[1] for x in files]
+ ta = ["user {}, vol {}: {} =>".format(u, vtop, files[0][0])]
+ ta += [x[1] for x in files]
- self.log("\n".join(msg))
+ self.log("\n".join(ta))
if bads:
self.log("\n ".join(["found symlinks leaving volume:"] + bads))
if bads and flag_p:
- raise Exception("found symlink leaving volume, and strict is set")
+ raise Exception(
+ "\033[31m\n [--ls] found a safety issue and prevented startup:\n found symlinks leaving volume, and strict is set\n\033[0m"
+ )
if not flag_r:
sys.exit(0)
diff --git a/copyparty/bos/bos.py b/copyparty/bos/bos.py
index d5e003cf..617545af 100644
--- a/copyparty/bos/bos.py
+++ b/copyparty/bos/bos.py
@@ -2,23 +2,30 @@
from __future__ import print_function, unicode_literals
import os
-from ..util import fsenc, fsdec, SYMTIME
+
+from ..util import SYMTIME, fsdec, fsenc
from . import path
+try:
+ from typing import Optional
+except:
+ pass
+
+_ = (path,)
# grep -hRiE '(^|[^a-zA-Z_\.-])os\.' . | gsed -r 's/ /\n/g;s/\(/(\n/g' | grep -hRiE '(^|[^a-zA-Z_\.-])os\.' | sort | uniq -c
# printf 'os\.(%s)' "$(grep ^def bos/__init__.py | gsed -r 's/^def //;s/\(.*//' | tr '\n' '|' | gsed -r 's/.$//')"
-def chmod(p, mode):
+def chmod(p: str, mode: int) -> None:
return os.chmod(fsenc(p), mode)
-def listdir(p="."):
+def listdir(p: str = ".") -> list[str]:
return [fsdec(x) for x in os.listdir(fsenc(p))]
-def makedirs(name, mode=0o755, exist_ok=True):
+def makedirs(name: str, mode: int = 0o755, exist_ok: bool = True) -> None:
bname = fsenc(name)
try:
os.makedirs(bname, mode)
@@ -27,31 +34,33 @@ def makedirs(name, mode=0o755, exist_ok=True):
raise
-def mkdir(p, mode=0o755):
+def mkdir(p: str, mode: int = 0o755) -> None:
return os.mkdir(fsenc(p), mode)
-def rename(src, dst):
+def rename(src: str, dst: str) -> None:
return os.rename(fsenc(src), fsenc(dst))
-def replace(src, dst):
+def replace(src: str, dst: str) -> None:
return os.replace(fsenc(src), fsenc(dst))
-def rmdir(p):
+def rmdir(p: str) -> None:
return os.rmdir(fsenc(p))
-def stat(p):
+def stat(p: str) -> os.stat_result:
return os.stat(fsenc(p))
-def unlink(p):
+def unlink(p: str) -> None:
return os.unlink(fsenc(p))
-def utime(p, times=None, follow_symlinks=True):
+def utime(
+ p: str, times: Optional[tuple[float, float]] = None, follow_symlinks: bool = True
+) -> None:
if SYMTIME:
return os.utime(fsenc(p), times, follow_symlinks=follow_symlinks)
else:
@@ -60,7 +69,7 @@ def utime(p, times=None, follow_symlinks=True):
if hasattr(os, "lstat"):
- def lstat(p):
+ def lstat(p: str) -> os.stat_result:
return os.lstat(fsenc(p))
else:
diff --git a/copyparty/bos/path.py b/copyparty/bos/path.py
index 066453b0..c5769d84 100644
--- a/copyparty/bos/path.py
+++ b/copyparty/bos/path.py
@@ -2,43 +2,44 @@
from __future__ import print_function, unicode_literals
import os
-from ..util import fsenc, fsdec, SYMTIME
+
+from ..util import SYMTIME, fsdec, fsenc
-def abspath(p):
+def abspath(p: str) -> str:
return fsdec(os.path.abspath(fsenc(p)))
-def exists(p):
+def exists(p: str) -> bool:
return os.path.exists(fsenc(p))
-def getmtime(p, follow_symlinks=True):
+def getmtime(p: str, follow_symlinks: bool = True) -> float:
if not follow_symlinks and SYMTIME:
return os.lstat(fsenc(p)).st_mtime
else:
return os.path.getmtime(fsenc(p))
-def getsize(p):
+def getsize(p: str) -> int:
return os.path.getsize(fsenc(p))
-def isfile(p):
+def isfile(p: str) -> bool:
return os.path.isfile(fsenc(p))
-def isdir(p):
+def isdir(p: str) -> bool:
return os.path.isdir(fsenc(p))
-def islink(p):
+def islink(p: str) -> bool:
return os.path.islink(fsenc(p))
-def lexists(p):
+def lexists(p: str) -> bool:
return os.path.lexists(fsenc(p))
-def realpath(p):
+def realpath(p: str) -> str:
return fsdec(os.path.realpath(fsenc(p)))
diff --git a/copyparty/broker_mp.py b/copyparty/broker_mp.py
index df73658d..c7bfed70 100644
--- a/copyparty/broker_mp.py
+++ b/copyparty/broker_mp.py
@@ -1,37 +1,56 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import time
import threading
+import time
-from .broker_util import try_exec
+import queue
+
+from .__init__ import TYPE_CHECKING
from .broker_mpw import MpWorker
+from .broker_util import try_exec
from .util import mp
+if TYPE_CHECKING:
+ from .svchub import SvcHub
+
+try:
+ from typing import Any
+except:
+ pass
+
+
+class MProcess(mp.Process):
+ def __init__(
+ self,
+ q_pend: queue.Queue[tuple[int, str, list[Any]]],
+ q_yield: queue.Queue[tuple[int, str, list[Any]]],
+ target: Any,
+ args: Any,
+ ) -> None:
+ super(MProcess, self).__init__(target=target, args=args)
+ self.q_pend = q_pend
+ self.q_yield = q_yield
+
class BrokerMp(object):
"""external api; manages MpWorkers"""
- def __init__(self, hub):
+ def __init__(self, hub: "SvcHub") -> None:
self.hub = hub
self.log = hub.log
self.args = hub.args
self.procs = []
- self.retpend = {}
- self.retpend_mutex = threading.Lock()
self.mutex = threading.Lock()
self.num_workers = self.args.j or mp.cpu_count()
self.log("broker", "booting {} subprocesses".format(self.num_workers))
for n in range(1, self.num_workers + 1):
- q_pend = mp.Queue(1)
- q_yield = mp.Queue(64)
+ q_pend: queue.Queue[tuple[int, str, list[Any]]] = mp.Queue(1)
+ q_yield: queue.Queue[tuple[int, str, list[Any]]] = mp.Queue(64)
- proc = mp.Process(target=MpWorker, args=(q_pend, q_yield, self.args, n))
- proc.q_pend = q_pend
- proc.q_yield = q_yield
- proc.clients = {}
+ proc = MProcess(q_pend, q_yield, MpWorker, (q_pend, q_yield, self.args, n))
thr = threading.Thread(
target=self.collector, args=(proc,), name="mp-sink-{}".format(n)
@@ -42,11 +61,11 @@ class BrokerMp(object):
self.procs.append(proc)
proc.start()
- def shutdown(self):
+ def shutdown(self) -> None:
self.log("broker", "shutting down")
for n, proc in enumerate(self.procs):
thr = threading.Thread(
- target=proc.q_pend.put([0, "shutdown", []]),
+ target=proc.q_pend.put((0, "shutdown", [])),
name="mp-shutdown-{}-{}".format(n, len(self.procs)),
)
thr.start()
@@ -62,12 +81,12 @@ class BrokerMp(object):
procs.pop()
- def reload(self):
+ def reload(self) -> None:
self.log("broker", "reloading")
for _, proc in enumerate(self.procs):
- proc.q_pend.put([0, "reload", []])
+ proc.q_pend.put((0, "reload", []))
- def collector(self, proc):
+ def collector(self, proc: MProcess) -> None:
"""receive message from hub in other process"""
while True:
msg = proc.q_yield.get()
@@ -78,10 +97,7 @@ class BrokerMp(object):
elif dest == "retq":
# response from previous ipc call
- with self.retpend_mutex:
- retq = self.retpend.pop(retq_id)
-
- retq.put(args)
+ raise Exception("invalid broker_mp usage")
else:
# new ipc invoking managed service in hub
@@ -93,9 +109,9 @@ class BrokerMp(object):
rv = try_exec(retq_id, obj, *args)
if retq_id:
- proc.q_pend.put([retq_id, "retq", rv])
+ proc.q_pend.put((retq_id, "retq", rv))
- def put(self, want_retval, dest, *args):
+ def say(self, dest: str, *args: Any) -> None:
"""
send message to non-hub component in other process,
returns a Queue object which eventually contains the response if want_retval
@@ -103,7 +119,7 @@ class BrokerMp(object):
"""
if dest == "listen":
for p in self.procs:
- p.q_pend.put([0, dest, [args[0], len(self.procs)]])
+ p.q_pend.put((0, dest, [args[0], len(self.procs)]))
elif dest == "cb_httpsrv_up":
self.hub.cb_httpsrv_up()
diff --git a/copyparty/broker_mpw.py b/copyparty/broker_mpw.py
index c4a1054c..4dbec5b1 100644
--- a/copyparty/broker_mpw.py
+++ b/copyparty/broker_mpw.py
@@ -1,20 +1,38 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import sys
+import argparse
import signal
+import sys
import threading
-from .broker_util import ExceptionalQueue
+import queue
+
+from .authsrv import AuthSrv
+from .broker_util import BrokerCli, ExceptionalQueue
from .httpsrv import HttpSrv
from .util import FAKE_MP
-from .authsrv import AuthSrv
+
+try:
+ from types import FrameType
+
+ from typing import Any, Optional, Union
+except:
+ pass
-class MpWorker(object):
+class MpWorker(BrokerCli):
"""one single mp instance"""
- def __init__(self, q_pend, q_yield, args, n):
+ def __init__(
+ self,
+ q_pend: queue.Queue[tuple[int, str, list[Any]]],
+ q_yield: queue.Queue[tuple[int, str, list[Any]]],
+ args: argparse.Namespace,
+ n: int,
+ ) -> None:
+ super(MpWorker, self).__init__()
+
self.q_pend = q_pend
self.q_yield = q_yield
self.args = args
@@ -22,7 +40,7 @@ class MpWorker(object):
self.log = self._log_disabled if args.q and not args.lo else self._log_enabled
- self.retpend = {}
+ self.retpend: dict[int, Any] = {}
self.retpend_mutex = threading.Lock()
self.mutex = threading.Lock()
@@ -45,20 +63,20 @@ class MpWorker(object):
thr.start()
thr.join()
- def signal_handler(self, sig, frame):
+ def signal_handler(self, sig: Optional[int], frame: Optional[FrameType]) -> None:
# print('k')
pass
- def _log_enabled(self, src, msg, c=0):
- self.q_yield.put([0, "log", [src, msg, c]])
+ def _log_enabled(self, src: str, msg: str, c: Union[int, str] = 0) -> None:
+ self.q_yield.put((0, "log", [src, msg, c]))
- def _log_disabled(self, src, msg, c=0):
+ def _log_disabled(self, src: str, msg: str, c: Union[int, str] = 0) -> None:
pass
- def logw(self, msg, c=0):
+ def logw(self, msg: str, c: Union[int, str] = 0) -> None:
self.log("mp{}".format(self.n), msg, c)
- def main(self):
+ def main(self) -> None:
while True:
retq_id, dest, args = self.q_pend.get()
@@ -87,15 +105,14 @@ class MpWorker(object):
else:
raise Exception("what is " + str(dest))
- def put(self, want_retval, dest, *args):
- if want_retval:
- retq = ExceptionalQueue(1)
- retq_id = id(retq)
- with self.retpend_mutex:
- self.retpend[retq_id] = retq
- else:
- retq = None
- retq_id = 0
+ def ask(self, dest: str, *args: Any) -> ExceptionalQueue:
+ retq = ExceptionalQueue(1)
+ retq_id = id(retq)
+ with self.retpend_mutex:
+ self.retpend[retq_id] = retq
- self.q_yield.put([retq_id, dest, args])
+ self.q_yield.put((retq_id, dest, list(args)))
return retq
+
+ def say(self, dest: str, *args: Any) -> None:
+ self.q_yield.put((0, dest, list(args)))
diff --git a/copyparty/broker_thr.py b/copyparty/broker_thr.py
index 1c7a1abf..51c25d41 100644
--- a/copyparty/broker_thr.py
+++ b/copyparty/broker_thr.py
@@ -3,14 +3,25 @@ from __future__ import print_function, unicode_literals
import threading
+from .__init__ import TYPE_CHECKING
+from .broker_util import BrokerCli, ExceptionalQueue, try_exec
from .httpsrv import HttpSrv
-from .broker_util import ExceptionalQueue, try_exec
+
+if TYPE_CHECKING:
+ from .svchub import SvcHub
+
+try:
+ from typing import Any
+except:
+ pass
-class BrokerThr(object):
+class BrokerThr(BrokerCli):
"""external api; behaves like BrokerMP but using plain threads"""
- def __init__(self, hub):
+ def __init__(self, hub: "SvcHub") -> None:
+ super(BrokerThr, self).__init__()
+
self.hub = hub
self.log = hub.log
self.args = hub.args
@@ -23,29 +34,35 @@ class BrokerThr(object):
self.httpsrv = HttpSrv(self, None)
self.reload = self.noop
- def shutdown(self):
+ def shutdown(self) -> None:
# self.log("broker", "shutting down")
self.httpsrv.shutdown()
- def noop(self):
+ def noop(self) -> None:
pass
- def put(self, want_retval, dest, *args):
+ def ask(self, dest: str, *args: Any) -> ExceptionalQueue:
+
+ # new ipc invoking managed service in hub
+ obj = self.hub
+ for node in dest.split("."):
+ obj = getattr(obj, node)
+
+ rv = try_exec(True, obj, *args)
+
+ # pretend we're broker_mp
+ retq = ExceptionalQueue(1)
+ retq.put(rv)
+ return retq
+
+ def say(self, dest: str, *args: Any) -> None:
if dest == "listen":
self.httpsrv.listen(args[0], 1)
+ return
- else:
- # new ipc invoking managed service in hub
- obj = self.hub
- for node in dest.split("."):
- obj = getattr(obj, node)
+ # new ipc invoking managed service in hub
+ obj = self.hub
+ for node in dest.split("."):
+ obj = getattr(obj, node)
- # TODO will deadlock if dest performs another ipc
- rv = try_exec(want_retval, obj, *args)
- if not want_retval:
- return
-
- # pretend we're broker_mp
- retq = ExceptionalQueue(1)
- retq.put(rv)
- return retq
+ try_exec(False, obj, *args)
diff --git a/copyparty/broker_util.py b/copyparty/broker_util.py
index 896cba2c..b0d44575 100644
--- a/copyparty/broker_util.py
+++ b/copyparty/broker_util.py
@@ -1,14 +1,28 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-
+import argparse
import traceback
-from .util import Pebkac, Queue
+from queue import Queue
+
+from .__init__ import TYPE_CHECKING
+from .authsrv import AuthSrv
+from .util import Pebkac
+
+try:
+ from typing import Any, Optional, Union
+
+ from .util import RootLogger
+except:
+ pass
+
+if TYPE_CHECKING:
+ from .httpsrv import HttpSrv
class ExceptionalQueue(Queue, object):
- def get(self, block=True, timeout=None):
+ def get(self, block: bool = True, timeout: Optional[float] = None) -> Any:
rv = super(ExceptionalQueue, self).get(block, timeout)
if isinstance(rv, list):
@@ -21,7 +35,26 @@ class ExceptionalQueue(Queue, object):
return rv
-def try_exec(want_retval, func, *args):
+class BrokerCli(object):
+ """
+ helps mypy understand httpsrv.broker but still fails a few levels deeper,
+ for example resolving httpconn.* in httpcli -- see lines tagged #mypy404
+ """
+
+ def __init__(self) -> None:
+ self.log: RootLogger = None
+ self.args: argparse.Namespace = None
+ self.asrv: AuthSrv = None
+ self.httpsrv: "HttpSrv" = None
+
+ def ask(self, dest: str, *args: Any) -> ExceptionalQueue:
+ return ExceptionalQueue(1)
+
+ def say(self, dest: str, *args: Any) -> None:
+ pass
+
+
+def try_exec(want_retval: Union[bool, int], func: Any, *args: list[Any]) -> Any:
try:
return func(*args)
diff --git a/copyparty/ftpd.py b/copyparty/ftpd.py
index 828c0db6..a47586de 100644
--- a/copyparty/ftpd.py
+++ b/copyparty/ftpd.py
@@ -1,23 +1,23 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import os
-import sys
-import stat
-import time
-import logging
import argparse
+import logging
+import os
+import stat
+import sys
import threading
+import time
-from pyftpdlib.authorizers import DummyAuthorizer, AuthenticationFailed
+from pyftpdlib.authorizers import AuthenticationFailed, DummyAuthorizer
from pyftpdlib.filesystems import AbstractedFS, FilesystemError
from pyftpdlib.handlers import FTPHandler
-from pyftpdlib.servers import FTPServer
from pyftpdlib.log import config_logging
+from pyftpdlib.servers import FTPServer
-from .__init__ import E, PY2
-from .util import Pebkac, fsenc, exclude_dotfiles
+from .__init__ import PY2, TYPE_CHECKING, E
from .bos import bos
+from .util import Pebkac, exclude_dotfiles, fsenc
try:
from pyftpdlib.ioloop import IOLoop
@@ -28,58 +28,63 @@ except ImportError:
from pyftpdlib.ioloop import IOLoop
-try:
- from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+ from .svchub import SvcHub
- if TYPE_CHECKING:
- from .svchub import SvcHub
-except ImportError:
+try:
+ import typing
+ from typing import Any, Optional
+except:
pass
class FtpAuth(DummyAuthorizer):
- def __init__(self):
+ def __init__(self, hub: "SvcHub") -> None:
super(FtpAuth, self).__init__()
- self.hub = None # type: SvcHub
+ self.hub = hub
- def validate_authentication(self, username, password, handler):
+ def validate_authentication(
+ self, username: str, password: str, handler: Any
+ ) -> None:
asrv = self.hub.asrv
if username == "anonymous":
password = ""
uname = "*"
if password:
- uname = asrv.iacct.get(password, None)
+ uname = asrv.iacct.get(password, "")
handler.username = uname
if password and not uname:
raise AuthenticationFailed("Authentication failed.")
- def get_home_dir(self, username):
+ def get_home_dir(self, username: str) -> str:
return "/"
- def has_user(self, username):
+ def has_user(self, username: str) -> bool:
asrv = self.hub.asrv
return username in asrv.acct
- def has_perm(self, username, perm, path=None):
+ def has_perm(self, username: str, perm: int, path: Optional[str] = None) -> bool:
return True # handled at filesystem layer
- def get_perms(self, username):
+ def get_perms(self, username: str) -> str:
return "elradfmwMT"
- def get_msg_login(self, username):
+ def get_msg_login(self, username: str) -> str:
return "sup {}".format(username)
- def get_msg_quit(self, username):
+ def get_msg_quit(self, username: str) -> str:
return "cya"
class FtpFs(AbstractedFS):
- def __init__(self, root, cmd_channel):
+ def __init__(
+ self, root: str, cmd_channel: Any
+ ) -> None: # pylint: disable=super-init-not-called
self.h = self.cmd_channel = cmd_channel # type: FTPHandler
- self.hub = cmd_channel.hub # type: SvcHub
+ self.hub: "SvcHub" = cmd_channel.hub
self.args = cmd_channel.args
self.uname = self.hub.asrv.iacct.get(cmd_channel.password, "*")
@@ -90,7 +95,14 @@ class FtpFs(AbstractedFS):
self.listdirinfo = self.listdir
self.chdir(".")
- def v2a(self, vpath, r=False, w=False, m=False, d=False):
+ def v2a(
+ self,
+ vpath: str,
+ r: bool = False,
+ w: bool = False,
+ m: bool = False,
+ d: bool = False,
+ ) -> str:
try:
vpath = vpath.replace("\\", "/").lstrip("/")
vfs, rem = self.hub.asrv.vfs.get(vpath, self.uname, r, w, m, d)
@@ -101,25 +113,32 @@ class FtpFs(AbstractedFS):
except Pebkac as ex:
raise FilesystemError(str(ex))
- def rv2a(self, vpath, r=False, w=False, m=False, d=False):
+ def rv2a(
+ self,
+ vpath: str,
+ r: bool = False,
+ w: bool = False,
+ m: bool = False,
+ d: bool = False,
+ ) -> str:
return self.v2a(os.path.join(self.cwd, vpath), r, w, m, d)
- def ftp2fs(self, ftppath):
+ def ftp2fs(self, ftppath: str) -> str:
# return self.v2a(ftppath)
return ftppath # self.cwd must be vpath
- def fs2ftp(self, fspath):
+ def fs2ftp(self, fspath: str) -> str:
# raise NotImplementedError()
return fspath
- def validpath(self, path):
+ def validpath(self, path: str) -> bool:
if "/.hist/" in path:
if "/up2k." in path or path.endswith("/dir.txt"):
raise FilesystemError("access to this file is forbidden")
return True
- def open(self, filename, mode):
+ def open(self, filename: str, mode: str) -> typing.IO[Any]:
r = "r" in mode
w = "w" in mode or "a" in mode or "+" in mode
@@ -130,24 +149,24 @@ class FtpFs(AbstractedFS):
self.validpath(ap)
return open(fsenc(ap), mode)
- def chdir(self, path):
+ def chdir(self, path: str) -> None:
self.cwd = join(self.cwd, path)
x = self.hub.asrv.vfs.can_access(self.cwd.lstrip("/"), self.h.username)
self.can_read, self.can_write, self.can_move, self.can_delete, self.can_get = x
- def mkdir(self, path):
+ def mkdir(self, path: str) -> None:
ap = self.rv2a(path, w=True)
bos.mkdir(ap)
- def listdir(self, path):
+ def listdir(self, path: str) -> list[str]:
vpath = join(self.cwd, path).lstrip("/")
try:
vfs, rem = self.hub.asrv.vfs.get(vpath, self.uname, True, False)
- fsroot, vfs_ls, vfs_virt = vfs.ls(
+ fsroot, vfs_ls1, vfs_virt = vfs.ls(
rem, self.uname, not self.args.no_scandir, [[True], [False, True]]
)
- vfs_ls = [x[0] for x in vfs_ls]
+ vfs_ls = [x[0] for x in vfs_ls1]
vfs_ls.extend(vfs_virt.keys())
if not self.args.ed:
@@ -164,11 +183,11 @@ class FtpFs(AbstractedFS):
r = {x.split("/")[0]: 1 for x in self.hub.asrv.vfs.all_vols.keys()}
return list(sorted(list(r.keys())))
- def rmdir(self, path):
+ def rmdir(self, path: str) -> None:
ap = self.rv2a(path, d=True)
bos.rmdir(ap)
- def remove(self, path):
+ def remove(self, path: str) -> None:
if self.args.no_del:
raise FilesystemError("the delete feature is disabled in server config")
@@ -178,13 +197,13 @@ class FtpFs(AbstractedFS):
except Exception as ex:
raise FilesystemError(str(ex))
- def rename(self, src, dst):
+ def rename(self, src: str, dst: str) -> None:
if not self.can_move:
raise FilesystemError("not allowed for user " + self.h.username)
if self.args.no_mv:
- m = "the rename/move feature is disabled in server config"
- raise FilesystemError(m)
+ t = "the rename/move feature is disabled in server config"
+ raise FilesystemError(t)
svp = join(self.cwd, src).lstrip("/")
dvp = join(self.cwd, dst).lstrip("/")
@@ -193,10 +212,10 @@ class FtpFs(AbstractedFS):
except Exception as ex:
raise FilesystemError(str(ex))
- def chmod(self, path, mode):
+ def chmod(self, path: str, mode: str) -> None:
pass
- def stat(self, path):
+ def stat(self, path: str) -> os.stat_result:
try:
ap = self.rv2a(path, r=True)
return bos.stat(ap)
@@ -208,59 +227,59 @@ class FtpFs(AbstractedFS):
return st
- def utime(self, path, timeval):
+ def utime(self, path: str, timeval: float) -> None:
ap = self.rv2a(path, w=True)
return bos.utime(ap, (timeval, timeval))
- def lstat(self, path):
+ def lstat(self, path: str) -> os.stat_result:
ap = self.rv2a(path)
return bos.lstat(ap)
- def isfile(self, path):
+ def isfile(self, path: str) -> bool:
st = self.stat(path)
return stat.S_ISREG(st.st_mode)
- def islink(self, path):
+ def islink(self, path: str) -> bool:
ap = self.rv2a(path)
return bos.path.islink(ap)
- def isdir(self, path):
+ def isdir(self, path: str) -> bool:
try:
st = self.stat(path)
return stat.S_ISDIR(st.st_mode)
except:
return True
- def getsize(self, path):
+ def getsize(self, path: str) -> int:
ap = self.rv2a(path)
return bos.path.getsize(ap)
- def getmtime(self, path):
+ def getmtime(self, path: str) -> float:
ap = self.rv2a(path)
return bos.path.getmtime(ap)
- def realpath(self, path):
+ def realpath(self, path: str) -> str:
return path
- def lexists(self, path):
+ def lexists(self, path: str) -> bool:
ap = self.rv2a(path)
return bos.path.lexists(ap)
- def get_user_by_uid(self, uid):
+ def get_user_by_uid(self, uid: int) -> str:
return "root"
- def get_group_by_uid(self, gid):
+ def get_group_by_uid(self, gid: int) -> str:
return "root"
class FtpHandler(FTPHandler):
abstracted_fs = FtpFs
- hub = None # type: SvcHub
- args = None # type: argparse.Namespace
+ hub: "SvcHub" = None
+ args: argparse.Namespace = None
- def __init__(self, conn, server, ioloop=None):
- self.hub = FtpHandler.hub # type: SvcHub
- self.args = FtpHandler.args # type: argparse.Namespace
+ def __init__(self, conn: Any, server: Any, ioloop: Any = None) -> None:
+ self.hub: "SvcHub" = FtpHandler.hub
+ self.args: argparse.Namespace = FtpHandler.args
if PY2:
FTPHandler.__init__(self, conn, server, ioloop)
@@ -268,9 +287,10 @@ class FtpHandler(FTPHandler):
super(FtpHandler, self).__init__(conn, server, ioloop)
# abspath->vpath mapping to resolve log_transfer paths
- self.vfs_map = {}
+ self.vfs_map: dict[str, str] = {}
- def ftp_STOR(self, file, mode="w"):
+ def ftp_STOR(self, file: str, mode: str = "w") -> Any:
+ # Optional[str]
vp = join(self.fs.cwd, file).lstrip("/")
ap = self.fs.v2a(vp)
self.vfs_map[ap] = vp
@@ -279,7 +299,16 @@ class FtpHandler(FTPHandler):
# print("ftp_STOR: {} {} OK".format(vp, mode))
return ret
- def log_transfer(self, cmd, filename, receive, completed, elapsed, bytes):
+ def log_transfer(
+ self,
+ cmd: str,
+ filename: bytes,
+ receive: bool,
+ completed: bool,
+ elapsed: float,
+ bytes: int,
+ ) -> Any:
+ # None
ap = filename.decode("utf-8", "replace")
vp = self.vfs_map.pop(ap, None)
# print("xfer_end: {} => {}".format(ap, vp))
@@ -312,7 +341,7 @@ except:
class Ftpd(object):
- def __init__(self, hub):
+ def __init__(self, hub: "SvcHub") -> None:
self.hub = hub
self.args = hub.args
@@ -321,24 +350,23 @@ class Ftpd(object):
hs.append([FtpHandler, self.args.ftp])
if self.args.ftps:
try:
- h = SftpHandler
+ h1 = SftpHandler
except:
- m = "\nftps requires pyopenssl;\nplease run the following:\n\n {} -m pip install --user pyopenssl\n"
- print(m.format(sys.executable))
+ t = "\nftps requires pyopenssl;\nplease run the following:\n\n {} -m pip install --user pyopenssl\n"
+ print(t.format(sys.executable))
sys.exit(1)
- h.certfile = os.path.join(E.cfg, "cert.pem")
- h.tls_control_required = True
- h.tls_data_required = True
+ h1.certfile = os.path.join(E.cfg, "cert.pem")
+ h1.tls_control_required = True
+ h1.tls_data_required = True
- hs.append([h, self.args.ftps])
+ hs.append([h1, self.args.ftps])
- for h in hs:
- h, lp = h
- h.hub = hub
- h.args = hub.args
- h.authorizer = FtpAuth()
- h.authorizer.hub = hub
+ for h_lp in hs:
+ h2, lp = h_lp
+ h2.hub = hub
+ h2.args = hub.args
+ h2.authorizer = FtpAuth(hub)
if self.args.ftp_pr:
p1, p2 = [int(x) for x in self.args.ftp_pr.split("-")]
@@ -350,10 +378,10 @@ class Ftpd(object):
else:
p1 += d + 1
- h.passive_ports = list(range(p1, p2 + 1))
+ h2.passive_ports = list(range(p1, p2 + 1))
if self.args.ftp_nat:
- h.masquerade_address = self.args.ftp_nat
+ h2.masquerade_address = self.args.ftp_nat
if self.args.ftp_dbg:
config_logging(level=logging.DEBUG)
@@ -363,11 +391,11 @@ class Ftpd(object):
for h, lp in hs:
FTPServer((ip, int(lp)), h, ioloop)
- t = threading.Thread(target=ioloop.loop)
- t.daemon = True
- t.start()
+ thr = threading.Thread(target=ioloop.loop)
+ thr.daemon = True
+ thr.start()
-def join(p1, p2):
+def join(p1: str, p2: str) -> str:
w = os.path.join(p1, p2.replace("\\", "/"))
return os.path.normpath(w).replace("\\", "/")
diff --git a/copyparty/httpcli.py b/copyparty/httpcli.py
index 5368bf62..cfc2bbf7 100644
--- a/copyparty/httpcli.py
+++ b/copyparty/httpcli.py
@@ -1,18 +1,23 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import os
-import stat
-import gzip
-import time
-import copy
-import json
+import argparse # typechk
import base64
-import string
+import calendar
+import copy
+import gzip
+import json
+import os
+import re
import socket
+import stat
+import string
+import threading # typechk
+import time
from datetime import datetime
from operator import itemgetter
-import calendar
+
+import jinja2 # typechk
try:
import lzma
@@ -24,13 +29,62 @@ try:
except:
pass
-from .__init__ import E, PY2, WINDOWS, ANYWIN, unicode
-from .util import * # noqa # pylint: disable=unused-wildcard-import
+from .__init__ import ANYWIN, PY2, TYPE_CHECKING, WINDOWS, E, unicode
+from .authsrv import VFS # typechk
from .bos import bos
-from .authsrv import AuthSrv
-from .szip import StreamZip
from .star import StreamTar
+from .sutil import StreamArc # typechk
+from .szip import StreamZip
+from .util import (
+ HTTP_TS_FMT,
+ HTTPCODE,
+ META_NOBOTS,
+ MultipartParser,
+ Pebkac,
+ UnrecvEOF,
+ alltrace,
+ exclude_dotfiles,
+ fsenc,
+ gen_filekey,
+ gencookie,
+ get_spd,
+ guess_mime,
+ gzip_orig_sz,
+ hashcopy,
+ html_bescape,
+ html_escape,
+ http_ts,
+ humansize,
+ min_ex,
+ quotep,
+ read_header,
+ read_socket,
+ read_socket_chunked,
+ read_socket_unbounded,
+ relchk,
+ ren_open,
+ s3enc,
+ sanitize_fn,
+ sendfile_kern,
+ sendfile_py,
+ undot,
+ unescape_cookie,
+ unquote,
+ unquotep,
+ vol_san,
+ vsplit,
+ yieldfile,
+)
+try:
+ from typing import Any, Generator, Match, Optional, Pattern, Type, Union
+except:
+ pass
+
+if TYPE_CHECKING:
+ from .httpconn import HttpConn
+
+_ = (argparse, threading)
NO_CACHE = {"Cache-Control": "no-cache"}
@@ -40,27 +94,60 @@ class HttpCli(object):
Spawned by HttpConn to process one http transaction
"""
- def __init__(self, conn):
+ def __init__(self, conn: "HttpConn") -> None:
+ assert conn.sr
+
self.t0 = time.time()
self.conn = conn
- self.mutex = conn.mutex
- self.s = conn.s # type: socket
- self.sr = conn.sr # type: Unrecv
+ self.mutex = conn.mutex # mypy404
+ self.s = conn.s
+ self.sr = conn.sr
self.ip = conn.addr[0]
- self.addr = conn.addr # type: tuple[str, int]
- self.args = conn.args
- self.asrv = conn.asrv # type: AuthSrv
- self.ico = conn.ico
- self.thumbcli = conn.thumbcli
- self.u2fh = conn.u2fh
- self.log_func = conn.log_func
- self.log_src = conn.log_src
- self.tls = hasattr(self.s, "cipher")
+ self.addr: tuple[str, int] = conn.addr
+ self.args = conn.args # mypy404
+ self.asrv = conn.asrv # mypy404
+ self.ico = conn.ico # mypy404
+ self.thumbcli = conn.thumbcli # mypy404
+ self.u2fh = conn.u2fh # mypy404
+ self.log_func = conn.log_func # mypy404
+ self.log_src = conn.log_src # mypy404
+ self.tls: bool = hasattr(self.s, "cipher")
+
+ # placeholders; assigned by run()
+ self.keepalive = False
+ self.is_https = False
+ self.headers: dict[str, str] = {}
+ self.mode = " "
+ self.req = " "
+ self.http_ver = " "
+ self.ua = " "
+ self.is_rclone = False
+ self.is_ancient = False
+ self.dip = " "
+ self.ouparam: dict[str, str] = {}
+ self.uparam: dict[str, str] = {}
+ self.cookies: dict[str, str] = {}
+ self.vpath = " "
+ self.uname = " "
+ self.rvol = [" "]
+ self.wvol = [" "]
+ self.mvol = [" "]
+ self.dvol = [" "]
+ self.gvol = [" "]
+ self.do_log = True
+ self.can_read = False
+ self.can_write = False
+ self.can_move = False
+ self.can_delete = False
+ self.can_get = False
+ # post
+ self.parser: Optional[MultipartParser] = None
+ # end placeholders
self.bufsz = 1024 * 32
- self.hint = None
+ self.hint = ""
self.trailing_slash = True
- self.out_headerlist = []
+ self.out_headerlist: list[tuple[str, str]] = []
self.out_headers = {
"Access-Control-Allow-Origin": "*",
"Cache-Control": "no-store; max-age=0",
@@ -71,44 +158,44 @@ class HttpCli(object):
self.out_headers["X-Robots-Tag"] = "noindex, nofollow"
self.html_head = h
- def log(self, msg, c=0):
+ def log(self, msg: str, c: Union[int, str] = 0) -> None:
ptn = self.asrv.re_pwd
if ptn and ptn.search(msg):
msg = ptn.sub(self.unpwd, msg)
self.log_func(self.log_src, msg, c)
- def unpwd(self, m):
+ def unpwd(self, m: Match[str]) -> str:
a, b = m.groups()
return "=\033[7m {} \033[27m{}".format(self.asrv.iacct[a], b)
- def _check_nonfatal(self, ex, post):
+ def _check_nonfatal(self, ex: Pebkac, post: bool) -> bool:
if post:
return ex.code < 300
return ex.code < 400 or ex.code in [404, 429]
- def _assert_safe_rem(self, rem):
+ def _assert_safe_rem(self, rem: str) -> None:
# sanity check to prevent any disasters
if rem.startswith("/") or rem.startswith("../") or "/../" in rem:
raise Exception("that was close")
- def j2(self, name, **ka):
+ def j2s(self, name: str, **ka: Any) -> str:
tpl = self.conn.hsrv.j2[name]
- if ka:
- ka["ts"] = self.conn.hsrv.cachebuster()
- ka["svcname"] = self.args.doctitle
- ka["html_head"] = self.html_head
- return tpl.render(**ka)
+ ka["ts"] = self.conn.hsrv.cachebuster()
+ ka["svcname"] = self.args.doctitle
+ ka["html_head"] = self.html_head
+ return tpl.render(**ka) # type: ignore
- return tpl
+ def j2j(self, name: str) -> jinja2.Template:
+ return self.conn.hsrv.j2[name]
- def run(self):
+ def run(self) -> bool:
"""returns true if connection can be reused"""
self.keepalive = False
self.is_https = False
self.headers = {}
- self.hint = None
+ self.hint = ""
try:
headerlines = read_header(self.sr)
if not headerlines:
@@ -125,8 +212,8 @@ class HttpCli(object):
# normalize incoming headers to lowercase;
# outgoing headers however are Correct-Case
for header_line in headerlines[1:]:
- k, v = header_line.split(":", 1)
- self.headers[k.lower()] = v.strip()
+ k, zs = header_line.split(":", 1)
+ self.headers[k.lower()] = zs.strip()
except:
msg = " ]\n#[ ".join(headerlines)
raise Pebkac(400, "bad headers:\n#[ " + msg + " ]")
@@ -147,23 +234,26 @@ class HttpCli(object):
self.is_rclone = self.ua.startswith("rclone/")
self.is_ancient = self.ua.startswith("Mozilla/4.")
- v = self.headers.get("connection", "").lower()
- self.keepalive = not v.startswith("close") and self.http_ver != "HTTP/1.0"
- self.is_https = (self.headers.get("x-forwarded-proto", "").lower() == "https" or self.tls)
+ zs = self.headers.get("connection", "").lower()
+ self.keepalive = not zs.startswith("close") and self.http_ver != "HTTP/1.0"
+ self.is_https = (
+ self.headers.get("x-forwarded-proto", "").lower() == "https" or self.tls
+ )
n = self.args.rproxy
if n:
- v = self.headers.get("x-forwarded-for")
- if v and self.conn.addr[0] in ["127.0.0.1", "::1"]:
+ zso = self.headers.get("x-forwarded-for")
+ if zso and self.conn.addr[0] in ["127.0.0.1", "::1"]:
if n > 0:
n -= 1
- vs = v.split(",")
+ zsl = zso.split(",")
try:
- self.ip = vs[n].strip()
+ self.ip = zsl[n].strip()
except:
- self.ip = vs[0].strip()
- self.log("rproxy={} oob x-fwd {}".format(self.args.rproxy, v), c=3)
+ self.ip = zsl[0].strip()
+ t = "rproxy={} oob x-fwd {}"
+ self.log(t.format(self.args.rproxy, zso), c=3)
self.log_src = self.conn.set_rproxy(self.ip)
@@ -175,9 +265,9 @@ class HttpCli(object):
keys = list(sorted(self.headers.keys()))
for k in keys:
- v = self.headers.get(k)
- if v is not None:
- self.log("[H] {}: \033[33m[{}]".format(k, v), 6)
+ zso = self.headers.get(k)
+ if zso is not None:
+ self.log("[H] {}: \033[33m[{}]".format(k, zso), 6)
if "&" in self.req and "?" not in self.req:
self.hint = "did you mean '?' instead of '&'"
@@ -193,20 +283,22 @@ class HttpCli(object):
vpath = undot(vpath)
for k in arglist.split("&"):
if "=" in k:
- k, v = k.split("=", 1)
- uparam[k.lower()] = v.strip()
+ k, zs = k.split("=", 1)
+ uparam[k.lower()] = zs.strip()
else:
- uparam[k.lower()] = False
+ uparam[k.lower()] = ""
- self.ouparam = {k: v for k, v in uparam.items()}
+ self.ouparam = {k: zs for k, zs in uparam.items()}
- cookies = self.headers.get("cookie") or {}
- if cookies:
- cookies = [x.split("=", 1) for x in cookies.split(";") if "=" in x]
- cookies = {k.strip(): unescape_cookie(v) for k, v in cookies}
+ zso = self.headers.get("cookie")
+ if zso:
+ zsll = [x.split("=", 1) for x in zso.split(";") if "=" in x]
+ cookies = {k.strip(): unescape_cookie(zs) for k, zs in zsll}
for kc, ku in [["cppwd", "pw"], ["b", "b"]]:
if kc in cookies and ku not in uparam:
uparam[ku] = cookies[kc]
+ else:
+ cookies = {}
if len(uparam) > 10 or len(cookies) > 50:
raise Pebkac(400, "u wot m8")
@@ -223,22 +315,22 @@ class HttpCli(object):
self.log("invalid relpath [{}]".format(self.vpath))
return self.tx_404() and self.keepalive
- pwd = None
- ba = self.headers.get("authorization")
- if ba:
+ pwd = ""
+ zso = self.headers.get("authorization")
+ if zso:
try:
- ba = ba.split(" ")[1].encode("ascii")
- ba = base64.b64decode(ba).decode("utf-8")
+ zb = zso.split(" ")[1].encode("ascii")
+ zs = base64.b64decode(zb).decode("utf-8")
# try "pwd", "x:pwd", "pwd:x"
- for ba in [ba] + ba.split(":", 1)[::-1]:
- if self.asrv.iacct.get(ba):
- pwd = ba
+ for zs in [zs] + zs.split(":", 1)[::-1]:
+ if self.asrv.iacct.get(zs):
+ pwd = zs
break
except:
pass
pwd = uparam.get("pw") or pwd
- self.uname = self.asrv.iacct.get(pwd, "*")
+ self.uname = self.asrv.iacct.get(pwd) or "*"
self.rvol = self.asrv.vfs.aread[self.uname]
self.wvol = self.asrv.vfs.awrite[self.uname]
self.mvol = self.asrv.vfs.amove[self.uname]
@@ -249,12 +341,13 @@ class HttpCli(object):
self.out_headerlist.append(("Set-Cookie", self.get_pwd_cookie(pwd)[0]))
if self.is_rclone:
- uparam["raw"] = False
- uparam["dots"] = False
- uparam["b"] = False
- cookies["b"] = False
+ uparam["raw"] = ""
+ uparam["dots"] = ""
+ uparam["b"] = ""
+ cookies["b"] = ""
- self.do_log = not self.conn.lf_url or not self.conn.lf_url.search(self.req)
+ ptn: Optional[Pattern[str]] = self.conn.lf_url # mypy404
+ self.do_log = not ptn or not ptn.search(self.req)
x = self.asrv.vfs.can_access(self.vpath, self.uname)
self.can_read, self.can_write, self.can_move, self.can_delete, self.can_get = x
@@ -272,9 +365,10 @@ class HttpCli(object):
raise Pebkac(400, 'invalid HTTP mode "{0}"'.format(self.mode))
except Exception as ex:
- pex = ex
if not hasattr(ex, "code"):
pex = Pebkac(500)
+ else:
+ pex = ex # type: ignore
try:
post = self.mode in ["POST", "PUT"] or "content-length" in self.headers
@@ -294,7 +388,7 @@ class HttpCli(object):
except Pebkac:
return False
- def permit_caching(self):
+ def permit_caching(self) -> None:
cache = self.uparam.get("cache")
if cache is None:
self.out_headers.update(NO_CACHE)
@@ -303,11 +397,17 @@ class HttpCli(object):
n = "604800" if cache == "i" else cache or "69"
self.out_headers["Cache-Control"] = "max-age=" + n
- def k304(self):
+ def k304(self) -> bool:
k304 = self.cookies.get("k304")
return k304 == "y" or ("; Trident/" in self.ua and not k304)
- def send_headers(self, length, status=200, mime=None, headers=None):
+ def send_headers(
+ self,
+ length: Optional[int],
+ status: int = 200,
+ mime: Optional[str] = None,
+ headers: Optional[dict[str, str]] = None,
+ ) -> None:
response = ["{} {} {}".format(self.http_ver, status, HTTPCODE[status])]
if length is not None:
@@ -327,10 +427,11 @@ class HttpCli(object):
if not mime:
mime = self.out_headers.get("Content-Type", "text/html; charset=utf-8")
+ assert mime
self.out_headers["Content-Type"] = mime
- for k, v in list(self.out_headers.items()) + self.out_headerlist:
- response.append("{}: {}".format(k, v))
+ for k, zs in list(self.out_headers.items()) + self.out_headerlist:
+ response.append("{}: {}".format(k, zs))
try:
# best practice to separate headers and body into different packets
@@ -338,11 +439,19 @@ class HttpCli(object):
except:
raise Pebkac(400, "client d/c while replying headers")
- def reply(self, body, status=200, mime=None, headers=None, volsan=False):
+ def reply(
+ self,
+ body: bytes,
+ status: int = 200,
+ mime: Optional[str] = None,
+ headers: Optional[dict[str, str]] = None,
+ volsan: bool = False,
+ ) -> bytes:
# TODO something to reply with user-supplied values safely
if volsan:
- body = vol_san(self.asrv.vfs.all_vols.values(), body)
+ vols = list(self.asrv.vfs.all_vols.values())
+ body = vol_san(vols, body)
self.send_headers(len(body), status, mime, headers)
@@ -354,17 +463,19 @@ class HttpCli(object):
return body
- def loud_reply(self, body, *args, **kwargs):
+ def loud_reply(self, body: str, *args: Any, **kwargs: Any) -> None:
if not kwargs.get("mime"):
kwargs["mime"] = "text/plain; charset=utf-8"
self.log(body.rstrip())
self.reply(body.encode("utf-8") + b"\r\n", *list(args), **kwargs)
- def urlq(self, add, rm):
+ def urlq(self, add: dict[str, str], rm: list[str]) -> str:
"""
generates url query based on uparam (b, pw, all others)
removing anything in rm, adding pairs in add
+
+ also list faster than set until ~20 items
"""
if self.is_rclone:
@@ -372,28 +483,28 @@ class HttpCli(object):
cmap = {"pw": "cppwd"}
kv = {
- k: v
- for k, v in self.uparam.items()
- if k not in rm and self.cookies.get(cmap.get(k, k)) != v
+ k: zs
+ for k, zs in self.uparam.items()
+ if k not in rm and self.cookies.get(cmap.get(k, k)) != zs
}
kv.update(add)
if not kv:
return ""
- r = ["{}={}".format(k, quotep(v)) if v else k for k, v in kv.items()]
+ r = ["{}={}".format(k, quotep(zs)) if zs else k for k, zs in kv.items()]
return "?" + "&".join(r)
def redirect(
self,
- vpath,
- suf="",
- msg="aight",
- flavor="go to",
- click=True,
- status=200,
- use302=False,
- ):
- html = self.j2(
+ vpath: str,
+ suf: str = "",
+ msg: str = "aight",
+ flavor: str = "go to",
+ click: bool = True,
+ status: int = 200,
+ use302: bool = False,
+ ) -> bool:
+ html = self.j2s(
"msg",
h2='{} /{}'.format(
quotep(vpath) + suf, flavor, html_escape(vpath, crlf=True) + suf
@@ -407,7 +518,9 @@ class HttpCli(object):
else:
self.reply(html, status=status)
- def handle_get(self):
+ return True
+
+ def handle_get(self) -> bool:
if self.do_log:
logmsg = "{:4} {}".format(self.mode, self.req)
@@ -434,13 +547,13 @@ class HttpCli(object):
self.log("inaccessible: [{}]".format(self.vpath))
return self.tx_404(True)
- self.uparam["h"] = False
+ self.uparam["h"] = ""
if "tree" in self.uparam:
return self.tx_tree()
if "delete" in self.uparam:
- return self.handle_rm()
+ return self.handle_rm([])
if "move" in self.uparam:
return self.handle_mv()
@@ -486,7 +599,7 @@ class HttpCli(object):
return self.tx_browser()
- def handle_options(self):
+ def handle_options(self) -> bool:
if self.do_log:
self.log("OPTIONS " + self.req)
@@ -501,7 +614,7 @@ class HttpCli(object):
)
return True
- def handle_put(self):
+ def handle_put(self) -> bool:
self.log("PUT " + self.req)
if self.headers.get("expect", "").lower() == "100-continue":
@@ -512,7 +625,7 @@ class HttpCli(object):
return self.handle_stash()
- def handle_post(self):
+ def handle_post(self) -> bool:
self.log("POST " + self.req)
if self.headers.get("expect", "").lower() == "100-continue":
@@ -549,16 +662,16 @@ class HttpCli(object):
reader, _ = self.get_body_reader()
for buf in reader:
orig = buf.decode("utf-8", "replace")
- m = "urlform_raw {} @ {}\n {}\n"
- self.log(m.format(len(orig), self.vpath, orig))
+ t = "urlform_raw {} @ {}\n {}\n"
+ self.log(t.format(len(orig), self.vpath, orig))
try:
- plain = unquote(buf.replace(b"+", b" "))
- plain = plain.decode("utf-8", "replace")
+ zb = unquote(buf.replace(b"+", b" "))
+ plain = zb.decode("utf-8", "replace")
if buf.startswith(b"msg="):
plain = plain[4:]
- m = "urlform_dec {} @ {}\n {}\n"
- self.log(m.format(len(plain), self.vpath, plain))
+ t = "urlform_dec {} @ {}\n {}\n"
+ self.log(t.format(len(plain), self.vpath, plain))
except Exception as ex:
self.log(repr(ex))
@@ -569,7 +682,7 @@ class HttpCli(object):
raise Pebkac(405, "don't know how to handle POST({})".format(ctype))
- def get_body_reader(self):
+ def get_body_reader(self) -> tuple[Generator[bytes, None, None], int]:
if "chunked" in self.headers.get("transfer-encoding", "").lower():
return read_socket_chunked(self.sr), -1
@@ -580,7 +693,8 @@ class HttpCli(object):
else:
return read_socket(self.sr, remains), remains
- def dump_to_file(self):
+ def dump_to_file(self) -> tuple[int, str, str, int, str, str]:
+ # post_sz, sha_hex, sha_b64, remains, path, url
reader, remains = self.get_body_reader()
vfs, rem = self.asrv.vfs.get(self.vpath, self.uname, False, True)
lim = vfs.get_dbv(rem)[0].lim
@@ -595,7 +709,7 @@ class HttpCli(object):
bos.makedirs(fdir)
- open_ka = {"fun": open}
+ open_ka: dict[str, Any] = {"fun": open}
open_a = ["wb", 512 * 1024]
# user-request || config-force
@@ -607,7 +721,7 @@ class HttpCli(object):
):
fb = {"gz": 9, "xz": 0} # default/fallback level
lv = {} # selected level
- alg = None # selected algo (gz=preferred)
+ alg = "" # selected algo (gz=preferred)
# user-prefs first
if "gz" in self.uparam or "pk" in self.uparam: # def.pk
@@ -615,8 +729,8 @@ class HttpCli(object):
if "xz" in self.uparam:
alg = "xz"
if alg:
- v = self.uparam.get(alg)
- lv[alg] = fb[alg] if v is None else int(v)
+ zso = self.uparam.get(alg)
+ lv[alg] = fb[alg] if zso is None else int(zso)
if alg not in vfs.flags:
alg = "gz" if "gz" in vfs.flags else "xz"
@@ -633,7 +747,7 @@ class HttpCli(object):
except:
pass
- lv[alg] = lv.get(alg) or fb.get(alg)
+ lv[alg] = lv.get(alg) or fb.get(alg) or 0
self.log("compressing with {} level {}".format(alg, lv.get(alg)))
if alg == "gz":
@@ -656,8 +770,8 @@ class HttpCli(object):
if not fn:
fn = "put" + suffix
- with ren_open(fn, *open_a, **params) as f:
- f, fn = f["orz"]
+ with ren_open(fn, *open_a, **params) as zfw:
+ f, fn = zfw["orz"]
path = os.path.join(fdir, fn)
post_sz, sha_hex, sha_b64 = hashcopy(reader, f, self.args.s_wr_slp)
@@ -674,8 +788,7 @@ class HttpCli(object):
return post_sz, sha_hex, sha_b64, remains, path, ""
vfs, rem = vfs.get_dbv(rem)
- self.conn.hsrv.broker.put(
- False,
+ self.conn.hsrv.broker.say(
"up2k.hash_file",
vfs.realpath,
vfs.flags,
@@ -705,16 +818,16 @@ class HttpCli(object):
return post_sz, sha_hex, sha_b64, remains, path, url
- def handle_stash(self):
+ def handle_stash(self) -> bool:
post_sz, sha_hex, sha_b64, remains, path, url = self.dump_to_file()
spd = self._spd(post_sz)
- m = "{} wrote {}/{} bytes to {} # {}"
- self.log(m.format(spd, post_sz, remains, path, sha_b64[:28])) # 21
- m = "{}\n{}\n{}\n{}\n".format(post_sz, sha_b64, sha_hex[:56], url)
- self.reply(m.encode("utf-8"))
+ t = "{} wrote {}/{} bytes to {} # {}"
+ self.log(t.format(spd, post_sz, remains, path, sha_b64[:28])) # 21
+ t = "{}\n{}\n{}\n{}\n".format(post_sz, sha_b64, sha_hex[:56], url)
+ self.reply(t.encode("utf-8"))
return True
- def _spd(self, nbytes, add=True):
+ def _spd(self, nbytes: int, add: bool = True) -> str:
if add:
self.conn.nbyte += nbytes
@@ -722,7 +835,7 @@ class HttpCli(object):
spd2 = get_spd(self.conn.nbyte, self.conn.t0)
return "{} {} n{}".format(spd1, spd2, self.conn.nreq)
- def handle_post_multipart(self):
+ def handle_post_multipart(self) -> bool:
self.parser = MultipartParser(self.log, self.sr, self.headers)
self.parser.parse()
@@ -749,7 +862,8 @@ class HttpCli(object):
raise Pebkac(422, 'invalid action "{}"'.format(act))
- def handle_zip_post(self):
+ def handle_zip_post(self) -> bool:
+ assert self.parser
for k in ["zip", "tar"]:
v = self.uparam.get(k)
if v is not None:
@@ -759,17 +873,17 @@ class HttpCli(object):
raise Pebkac(422, "need zip or tar keyword")
vn, rem = self.asrv.vfs.get(self.vpath, self.uname, True, False)
- items = self.parser.require("files", 1024 * 1024)
- if not items:
+ zs = self.parser.require("files", 1024 * 1024)
+ if not zs:
raise Pebkac(422, "need files list")
- items = items.replace("\r", "").split("\n")
+ items = zs.replace("\r", "").split("\n")
items = [unquotep(x) for x in items if items]
self.parser.drop()
return self.tx_zip(k, v, vn, rem, items, self.args.ed)
- def handle_post_json(self):
+ def handle_post_json(self) -> bool:
try:
remains = int(self.headers["content-length"])
except:
@@ -836,14 +950,14 @@ class HttpCli(object):
except:
raise Pebkac(500, min_ex())
- x = self.conn.hsrv.broker.put(True, "up2k.handle_json", body)
+ x = self.conn.hsrv.broker.ask("up2k.handle_json", body)
ret = x.get()
ret = json.dumps(ret)
self.log(ret)
self.reply(ret.encode("utf-8"), mime="application/json")
return True
- def handle_search(self, body):
+ def handle_search(self, body: dict[str, Any]) -> bool:
idx = self.conn.get_u2idx()
if not hasattr(idx, "p_end"):
raise Pebkac(500, "sqlite3 is not available on the server; cannot search")
@@ -857,15 +971,15 @@ class HttpCli(object):
continue
seen[vfs] = True
- vols.append([vfs.vpath, vfs.realpath, vfs.flags])
+ vols.append((vfs.vpath, vfs.realpath, vfs.flags))
t0 = time.time()
if idx.p_end:
penalty = 0.7
t_idle = t0 - idx.p_end
if idx.p_dur > 0.7 and t_idle < penalty:
- m = "rate-limit {:.1f} sec, cost {:.2f}, idle {:.2f}"
- raise Pebkac(429, m.format(penalty, idx.p_dur, t_idle))
+ t = "rate-limit {:.1f} sec, cost {:.2f}, idle {:.2f}"
+ raise Pebkac(429, t.format(penalty, idx.p_dur, t_idle))
if "srch" in body:
# search by up2k hashlist
@@ -873,8 +987,8 @@ class HttpCli(object):
vbody["hash"] = len(vbody["hash"])
self.log("qj: " + repr(vbody))
hits = idx.fsearch(vols, body)
- msg = repr(hits)
- taglist = {}
+ msg: Any = repr(hits)
+ taglist: list[str] = []
else:
# search by query params
q = body["q"]
@@ -900,7 +1014,7 @@ class HttpCli(object):
self.reply(r, mime="application/json")
return True
- def handle_post_binary(self):
+ def handle_post_binary(self) -> bool:
try:
remains = int(self.headers["content-length"])
except:
@@ -915,7 +1029,7 @@ class HttpCli(object):
vfs, _ = self.asrv.vfs.get(self.vpath, self.uname, False, True)
ptop = (vfs.dbv or vfs).realpath
- x = self.conn.hsrv.broker.put(True, "up2k.handle_chunk", ptop, wark, chash)
+ x = self.conn.hsrv.broker.ask("up2k.handle_chunk", ptop, wark, chash)
response = x.get()
chunksize, cstart, path, lastmod = response
@@ -946,8 +1060,8 @@ class HttpCli(object):
post_sz, _, sha_b64 = hashcopy(reader, f, self.args.s_wr_slp)
if sha_b64 != chash:
- m = "your chunk got corrupted somehow (received {} bytes); expected vs received hash:\n{}\n{}"
- raise Pebkac(400, m.format(post_sz, chash, sha_b64))
+ t = "your chunk got corrupted somehow (received {} bytes); expected vs received hash:\n{}\n{}"
+ raise Pebkac(400, t.format(post_sz, chash, sha_b64))
if len(cstart) > 1 and path != os.devnull:
self.log(
@@ -974,15 +1088,15 @@ class HttpCli(object):
with self.mutex:
self.u2fh.put(path, f)
finally:
- x = self.conn.hsrv.broker.put(True, "up2k.release_chunk", ptop, wark, chash)
+ x = self.conn.hsrv.broker.ask("up2k.release_chunk", ptop, wark, chash)
x.get() # block client until released
- x = self.conn.hsrv.broker.put(True, "up2k.confirm_chunk", ptop, wark, chash)
- x = x.get()
+ x = self.conn.hsrv.broker.ask("up2k.confirm_chunk", ptop, wark, chash)
+ ztis = x.get()
try:
- num_left, fin_path = x
+ num_left, fin_path = ztis
except:
- self.loud_reply(x, status=500)
+ self.loud_reply(ztis, status=500)
return False
if not num_left and fpool:
@@ -991,7 +1105,7 @@ class HttpCli(object):
# windows cant rename open files
if ANYWIN and path != fin_path and not self.args.nw:
- self.conn.hsrv.broker.put(True, "up2k.finish_upload", ptop, wark).get()
+ self.conn.hsrv.broker.ask("up2k.finish_upload", ptop, wark).get()
if not ANYWIN and not num_left:
times = (int(time.time()), int(lastmod))
@@ -1006,7 +1120,8 @@ class HttpCli(object):
self.reply(b"thank")
return True
- def handle_login(self):
+ def handle_login(self) -> bool:
+ assert self.parser
pwd = self.parser.require("cppwd", 64)
self.parser.drop()
@@ -1021,11 +1136,11 @@ class HttpCli(object):
dst += quotep(self.vpath)
ck, msg = self.get_pwd_cookie(pwd)
- html = self.j2("msg", h1=msg, h2='ack', redir=dst)
+ html = self.j2s("msg", h1=msg, h2='ack', redir=dst)
self.reply(html.encode("utf-8"), headers={"Set-Cookie": ck})
return True
- def get_pwd_cookie(self, pwd):
+ def get_pwd_cookie(self, pwd: str) -> tuple[str, str]:
if pwd in self.asrv.iacct:
msg = "login ok"
dur = int(60 * 60 * self.args.logout)
@@ -1038,9 +1153,10 @@ class HttpCli(object):
if self.is_ancient:
r = r.rsplit(" ", 1)[0]
- return [r, msg]
+ return r, msg
- def handle_mkdir(self):
+ def handle_mkdir(self) -> bool:
+ assert self.parser
new_dir = self.parser.require("name", 512)
self.parser.drop()
@@ -1075,7 +1191,8 @@ class HttpCli(object):
self.redirect(vpath)
return True
- def handle_new_md(self):
+ def handle_new_md(self) -> bool:
+ assert self.parser
new_file = self.parser.require("name", 512)
self.parser.drop()
@@ -1102,7 +1219,8 @@ class HttpCli(object):
self.redirect(vpath, "?edit")
return True
- def handle_plain_upload(self):
+ def handle_plain_upload(self) -> bool:
+ assert self.parser
nullwrite = self.args.nw
vfs, rem = self.asrv.vfs.get(self.vpath, self.uname, False, True)
self._assert_safe_rem(rem)
@@ -1116,17 +1234,21 @@ class HttpCli(object):
if not nullwrite:
bos.makedirs(fdir_base)
- files = []
+ files: list[tuple[int, str, str, str, str, str]] = []
+ # sz, sha_hex, sha_b64, p_file, fname, abspath
errmsg = ""
t0 = time.time()
try:
+ assert self.parser.gen
for nfile, (p_field, p_file, p_data) in enumerate(self.parser.gen):
if not p_file:
self.log("discarding incoming file without filename")
# fallthrough
fdir = fdir_base
- fname = sanitize_fn(p_file, "", [".prologue.html", ".epilogue.html"])
+ fname = sanitize_fn(
+ p_file or "", "", [".prologue.html", ".epilogue.html"]
+ )
if p_file and not nullwrite:
if not bos.path.isdir(fdir):
raise Pebkac(404, "that folder does not exist")
@@ -1143,8 +1265,8 @@ class HttpCli(object):
lim.chk_nup(self.ip)
try:
- with ren_open(fname, "wb", 512 * 1024, **open_args) as f:
- f, fname = f["orz"]
+ with ren_open(fname, "wb", 512 * 1024, **open_args) as zfw:
+ f, fname = zfw["orz"]
abspath = os.path.join(fdir, fname)
self.log("writing to {}".format(abspath))
sz, sha_hex, sha_b64 = hashcopy(p_data, f, self.args.s_wr_slp)
@@ -1158,12 +1280,14 @@ class HttpCli(object):
lim.chk_sz(sz)
except:
bos.unlink(abspath)
+ fname = os.devnull
raise
- files.append([sz, sha_hex, sha_b64, p_file, fname, abspath])
+ files.append(
+ (sz, sha_hex, sha_b64, p_file or "(discarded)", fname, abspath)
+ )
dbv, vrem = vfs.get_dbv(rem)
- self.conn.hsrv.broker.put(
- False,
+ self.conn.hsrv.broker.say(
"up2k.hash_file",
dbv.realpath,
dbv.flags,
@@ -1192,7 +1316,7 @@ class HttpCli(object):
except Pebkac as ex:
errmsg = vol_san(
- self.asrv.vfs.all_vols.values(), unicode(ex).encode("utf-8")
+ list(self.asrv.vfs.all_vols.values()), unicode(ex).encode("utf-8")
).decode("utf-8")
td = max(0.1, time.time() - t0)
@@ -1205,7 +1329,12 @@ class HttpCli(object):
status = "ERROR"
msg = "{} // {} bytes // {:.3f} MiB/s\n".format(status, sz_total, spd)
- jmsg = {"status": status, "sz": sz_total, "mbps": round(spd, 3), "files": []}
+ jmsg: dict[str, Any] = {
+ "status": status,
+ "sz": sz_total,
+ "mbps": round(spd, 3),
+ "files": [],
+ }
if errmsg:
msg += errmsg + "\n"
@@ -1260,17 +1389,17 @@ class HttpCli(object):
ft = "{}\n{}\n{}\n".format(ft, msg.rstrip(), errmsg)
f.write(ft.encode("utf-8"))
- status = 400 if errmsg else 200
+ sc = 400 if errmsg else 200
if "j" in self.uparam:
jtxt = json.dumps(jmsg, indent=2, sort_keys=True).encode("utf-8", "replace")
- self.reply(jtxt, mime="application/json", status=status)
+ self.reply(jtxt, mime="application/json", status=sc)
else:
self.redirect(
self.vpath,
msg=msg,
flavor="return to",
click=False,
- status=status,
+ status=sc,
)
if errmsg:
@@ -1279,7 +1408,8 @@ class HttpCli(object):
self.parser.drop()
return True
- def handle_text_upload(self):
+ def handle_text_upload(self) -> bool:
+ assert self.parser
try:
cli_lastmod3 = int(self.parser.require("lastmod", 16))
except:
@@ -1314,7 +1444,8 @@ class HttpCli(object):
self.reply(response.encode("utf-8"))
return True
- srv_lastmod = srv_lastmod3 = -1
+ srv_lastmod = -1.0
+ srv_lastmod3 = -1
try:
st = bos.stat(fp)
srv_lastmod = st.st_mtime
@@ -1329,7 +1460,7 @@ class HttpCli(object):
if not same_lastmod:
# some filesystems/transports limit precision to 1sec, hopefully floored
same_lastmod = (
- srv_lastmod == int(srv_lastmod)
+ srv_lastmod == int(cli_lastmod3 / 1000)
and cli_lastmod3 > srv_lastmod3
and cli_lastmod3 - srv_lastmod3 < 1000
)
@@ -1360,6 +1491,7 @@ class HttpCli(object):
pass
bos.rename(fp, os.path.join(mdir, ".hist", mfile2))
+ assert self.parser.gen
p_field, _, p_data = next(self.parser.gen)
if p_field != "body":
raise Pebkac(400, "expected body, got {}".format(p_field))
@@ -1388,7 +1520,7 @@ class HttpCli(object):
self.reply(response.encode("utf-8"))
return True
- def _chk_lastmod(self, file_ts):
+ def _chk_lastmod(self, file_ts: int) -> tuple[str, bool]:
file_lastmod = http_ts(file_ts)
cli_lastmod = self.headers.get("if-modified-since")
if cli_lastmod:
@@ -1408,7 +1540,7 @@ class HttpCli(object):
return file_lastmod, True
- def tx_file(self, req_path):
+ def tx_file(self, req_path: str) -> bool:
status = 200
logmsg = "{:4} {} ".format("", self.req)
logtail = ""
@@ -1417,7 +1549,7 @@ class HttpCli(object):
# if request is for foo.js, check if we have foo.js.{gz,br}
file_ts = 0
- editions = {}
+ editions: dict[str, tuple[str, int]] = {}
for ext in ["", ".gz", ".br"]:
try:
fs_path = req_path + ext
@@ -1425,8 +1557,8 @@ class HttpCli(object):
if stat.S_ISDIR(st.st_mode):
continue
- file_ts = max(file_ts, st.st_mtime)
- editions[ext or "plain"] = [fs_path, st.st_size]
+ file_ts = max(file_ts, int(st.st_mtime))
+ editions[ext or "plain"] = (fs_path, st.st_size)
except:
pass
if not self.vpath.startswith(".cpr/"):
@@ -1526,8 +1658,8 @@ class HttpCli(object):
use_sendfile = False
if decompress:
- open_func = gzip.open
- open_args = [fsenc(fs_path), "rb"]
+ open_func: Any = gzip.open
+ open_args: list[Any] = [fsenc(fs_path), "rb"]
# Content-Length := original file size
upper = gzip_orig_sz(fs_path)
else:
@@ -1551,7 +1683,7 @@ class HttpCli(object):
if "txt" in self.uparam:
mime = "text/plain; charset={}".format(self.uparam["txt"] or "utf-8")
elif "mime" in self.uparam:
- mime = self.uparam.get("mime")
+ mime = str(self.uparam.get("mime"))
else:
mime = guess_mime(req_path)
@@ -1583,19 +1715,18 @@ class HttpCli(object):
return ret
- def tx_zip(self, fmt, uarg, vn, rem, items, dots):
+ def tx_zip(
+ self, fmt: str, uarg: str, vn: VFS, rem: str, items: list[str], dots: bool
+ ) -> bool:
if self.args.no_zip:
raise Pebkac(400, "not enabled")
logmsg = "{:4} {} ".format("", self.req)
self.keepalive = False
- if not uarg:
- uarg = ""
-
if fmt == "tar":
mime = "application/x-tar"
- packer = StreamTar
+ packer: Type[StreamArc] = StreamTar
else:
mime = "application/zip"
packer = StreamZip
@@ -1609,24 +1740,25 @@ class HttpCli(object):
safe = (string.ascii_letters + string.digits).replace("%", "")
afn = "".join([x if x in safe.replace('"', "") else "_" for x in fn])
bascii = unicode(safe).encode("utf-8")
- ufn = fn.encode("utf-8", "xmlcharrefreplace")
- if PY2:
- ufn = [unicode(x) if x in bascii else "%{:02x}".format(ord(x)) for x in ufn]
- else:
- ufn = [
+ zb = fn.encode("utf-8", "xmlcharrefreplace")
+ if not PY2:
+ zbl = [
chr(x).encode("utf-8")
if x in bascii
else "%{:02x}".format(x).encode("ascii")
- for x in ufn
+ for x in zb
]
- ufn = b"".join(ufn).decode("ascii")
+ else:
+ zbl = [unicode(x) if x in bascii else "%{:02x}".format(ord(x)) for x in zb]
+
+ ufn = b"".join(zbl).decode("ascii")
cdis = "attachment; filename=\"{}.{}\"; filename*=UTF-8''{}.{}"
cdis = cdis.format(afn, fmt, ufn, fmt)
self.log(cdis)
self.send_headers(None, mime=mime, headers={"Content-Disposition": cdis})
- fgen = vn.zipgen(rem, items, self.uname, dots, not self.args.no_scandir)
+ fgen = vn.zipgen(rem, set(items), self.uname, dots, not self.args.no_scandir)
# for f in fgen: print(repr({k: f[k] for k in ["vp", "ap"]}))
bgen = packer(self.log, fgen, utf8="utf" in uarg, pre_crc="crc" in uarg)
bsent = 0
@@ -1645,7 +1777,7 @@ class HttpCli(object):
self.log("{}, {}".format(logmsg, spd))
return True
- def tx_ico(self, ext, exact=False):
+ def tx_ico(self, ext: str, exact: bool = False) -> bool:
self.permit_caching()
if ext.endswith("/"):
ext = "folder"
@@ -1674,7 +1806,7 @@ class HttpCli(object):
self.reply(ico, mime=mime, headers={"Last-Modified": lm})
return True
- def tx_md(self, fs_path):
+ def tx_md(self, fs_path: str) -> bool:
logmsg = "{:4} {} ".format("", self.req)
if not self.can_write:
@@ -1683,7 +1815,7 @@ class HttpCli(object):
tpl = "mde" if "edit2" in self.uparam else "md"
html_path = os.path.join(E.mod, "web", "{}.html".format(tpl))
- template = self.j2(tpl)
+ template = self.j2j(tpl)
st = bos.stat(fs_path)
ts_md = st.st_mtime
@@ -1694,7 +1826,7 @@ class HttpCli(object):
sz_md = 0
for buf in yieldfile(fs_path):
sz_md += len(buf)
- for c, v in [[b"&", 4], [b"<", 3], [b">", 3]]:
+ for c, v in [(b"&", 4), (b"<", 3), (b">", 3)]:
sz_md += (len(buf) - len(buf.replace(c, b""))) * v
file_ts = max(ts_md, ts_html, E.t0)
@@ -1720,8 +1852,8 @@ class HttpCli(object):
"md": boundary,
"arg_base": arg_base,
}
- html = template.render(**targs).encode("utf-8", "replace")
- html = html.split(boundary.encode("utf-8"))
+ zs = template.render(**targs).encode("utf-8", "replace")
+ html = zs.split(boundary.encode("utf-8"))
if len(html) != 2:
raise Exception("boundary appears in " + html_path)
@@ -1750,7 +1882,7 @@ class HttpCli(object):
return True
- def tx_mounts(self):
+ def tx_mounts(self) -> bool:
suf = self.urlq({}, ["h"])
avol = [x for x in self.wvol if x in self.rvol]
rvol, wvol, avol = [
@@ -1759,7 +1891,7 @@ class HttpCli(object):
]
if avol and not self.args.no_rescan:
- x = self.conn.hsrv.broker.put(True, "up2k.get_state")
+ x = self.conn.hsrv.broker.ask("up2k.get_state")
vs = json.loads(x.get())
vstate = {("/" + k).rstrip("/") + "/": v for k, v in vs["volstate"].items()}
else:
@@ -1787,11 +1919,11 @@ class HttpCli(object):
for v in wvol:
txt += "\n " + v
- txt = txt.encode("utf-8", "replace") + b"\n"
- self.reply(txt, mime="text/plain; charset=utf-8")
+ zb = txt.encode("utf-8", "replace") + b"\n"
+ self.reply(zb, mime="text/plain; charset=utf-8")
return True
- html = self.j2(
+ html = self.j2s(
"splash",
this=self,
qvpath=quotep(self.vpath),
@@ -1809,38 +1941,41 @@ class HttpCli(object):
self.reply(html.encode("utf-8"))
return True
- def set_k304(self):
+ def set_k304(self) -> bool:
ck = gencookie("k304", self.uparam["k304"], 60 * 60 * 24 * 365)
self.out_headerlist.append(("Set-Cookie", ck))
self.redirect("", "?h#cc")
+ return True
- def set_am_js(self):
+ def set_am_js(self) -> bool:
v = "n" if self.uparam["am_js"] == "n" else "y"
ck = gencookie("js", v, 60 * 60 * 24 * 365)
self.out_headerlist.append(("Set-Cookie", ck))
self.reply(b"promoted\n")
+ return True
- def set_cfg_reset(self):
+ def set_cfg_reset(self) -> bool:
for k in ("k304", "js", "cppwd"):
self.out_headerlist.append(("Set-Cookie", gencookie(k, "x", None)))
self.redirect("", "?h#cc")
+ return True
- def tx_404(self, is_403=False):
+ def tx_404(self, is_403: bool = False) -> bool:
rc = 404
if self.args.vague_403:
- m = '404 not found ┐( ´ -`)┌
or maybe you don\'t have access -- try logging in or go home
' + t = 'or maybe you don\'t have access -- try logging in or go home
' elif is_403: - m = 'you\'ll have to log in or go home
' + t = 'you\'ll have to log in or go home
' rc = 403 else: - m = '{}\n{}".format(time.time(), html_escape(alltrace()))
self.reply(ret.encode("utf-8"))
+ return True
- def tx_tree(self):
+ def tx_tree(self) -> bool:
top = self.uparam["tree"] or ""
dst = self.vpath
if top in [".", ".."]:
@@ -1898,12 +2034,12 @@ class HttpCli(object):
dst = dst[len(top) + 1 :]
ret = self.gen_tree(top, dst)
- ret = json.dumps(ret)
- self.reply(ret.encode("utf-8"), mime="application/json")
+ zs = json.dumps(ret)
+ self.reply(zs.encode("utf-8"), mime="application/json")
return True
- def gen_tree(self, top, target):
- ret = {}
+ def gen_tree(self, top: str, target: str) -> dict[str, Any]:
+ ret: dict[str, Any] = {}
excl = None
if target:
excl, target = (target.split("/", 1) + [""])[:2]
@@ -1921,26 +2057,26 @@ class HttpCli(object):
for v in self.rvol:
d1, d2 = v.rsplit("/", 1) if "/" in v else ["", v]
if d1 == top:
- vfs_virt[d2] = 0
+ vfs_virt[d2] = vn # typechk, value never read
dirs = []
- vfs_ls = [x[0] for x in vfs_ls if stat.S_ISDIR(x[1].st_mode)]
+ dirnames = [x[0] for x in vfs_ls if stat.S_ISDIR(x[1].st_mode)]
if not self.args.ed or "dots" not in self.uparam:
- vfs_ls = exclude_dotfiles(vfs_ls)
+ dirnames = exclude_dotfiles(dirnames)
- for fn in [x for x in vfs_ls if x != excl]:
+ for fn in [x for x in dirnames if x != excl]:
dirs.append(quotep(fn))
- for x in vfs_virt.keys():
+ for x in vfs_virt:
if x != excl:
dirs.append(x)
ret["a"] = dirs
return ret
- def tx_ups(self):
+ def tx_ups(self) -> bool:
if not self.args.unpost:
raise Pebkac(400, "the unpost feature is disabled in server config")
@@ -1952,7 +2088,7 @@ class HttpCli(object):
lm = "ups [{}]".format(filt)
self.log(lm)
- ret = []
+ ret: list[dict[str, Any]] = []
t0 = time.time()
lim = time.time() - self.args.unpost
for vol in self.asrv.vfs.all_vols.values():
@@ -1968,17 +2104,18 @@ class HttpCli(object):
ret.append({"vp": quotep(vp), "sz": sz, "at": at})
if len(ret) > 3000:
- ret.sort(key=lambda x: x["at"], reverse=True)
+ ret.sort(key=lambda x: x["at"], reverse=True) # type: ignore
ret = ret[:2000]
- ret.sort(key=lambda x: x["at"], reverse=True)
+ ret.sort(key=lambda x: x["at"], reverse=True) # type: ignore
ret = ret[:2000]
jtxt = json.dumps(ret, indent=2, sort_keys=True).encode("utf-8", "replace")
self.log("{} #{} {:.2f}sec".format(lm, len(ret), time.time() - t0))
self.reply(jtxt, mime="application/json")
+ return True
- def handle_rm(self, req=None):
+ def handle_rm(self, req: list[str]) -> bool:
if not req and not self.can_delete:
raise Pebkac(403, "not allowed for user " + self.uname)
@@ -1988,10 +2125,11 @@ class HttpCli(object):
if not req:
req = [self.vpath]
- x = self.conn.hsrv.broker.put(True, "up2k.handle_rm", self.uname, self.ip, req)
+ x = self.conn.hsrv.broker.ask("up2k.handle_rm", self.uname, self.ip, req)
self.loud_reply(x.get())
+ return True
- def handle_mv(self):
+ def handle_mv(self) -> bool:
if not self.can_move:
raise Pebkac(403, "not allowed for user " + self.uname)
@@ -2006,12 +2144,11 @@ class HttpCli(object):
# x-www-form-urlencoded (url query part) uses
# either + or %20 for 0x20 so handle both
dst = unquotep(dst.replace("+", " "))
- x = self.conn.hsrv.broker.put(
- True, "up2k.handle_mv", self.uname, self.vpath, dst
- )
+ x = self.conn.hsrv.broker.ask("up2k.handle_mv", self.uname, self.vpath, dst)
self.loud_reply(x.get())
+ return True
- def tx_ls(self, ls):
+ def tx_ls(self, ls: dict[str, Any]) -> bool:
dirs = ls["dirs"]
files = ls["files"]
arg = self.uparam["ls"]
@@ -2055,17 +2192,17 @@ class HttpCli(object):
x["name"] = n
fmt = fmt.format(len(nfmt.format(biggest)))
- ret = [
+ retl = [
"# {}: {}".format(x, ls[x])
for x in ["acct", "perms", "srvinf"]
if x in ls
]
- ret += [
+ retl += [
fmt.format(x["dt"], x["sz"], x["name"])
for y in [dirs, files]
for x in y
]
- ret = "\n".join(ret)
+ ret = "\n".join(retl)
mime = "text/plain; charset=utf-8"
else:
[x.pop(k) for k in ["name", "dt"] for y in [dirs, files] for x in y]
@@ -2076,7 +2213,7 @@ class HttpCli(object):
self.reply(ret.encode("utf-8", "replace") + b"\n", mime=mime)
return True
- def tx_browser(self):
+ def tx_browser(self) -> bool:
vpath = ""
vpnodes = [["", "/"]]
if self.vpath:
@@ -2164,7 +2301,7 @@ class HttpCli(object):
if WINDOWS:
try:
bfree = ctypes.c_ulonglong(0)
- ctypes.windll.kernel32.GetDiskFreeSpaceExW(
+ ctypes.windll.kernel32.GetDiskFreeSpaceExW( # type: ignore
ctypes.c_wchar_p(abspath), None, None, ctypes.pointer(bfree)
)
srv_info.append(humansize(bfree.value) + " free")
@@ -2179,7 +2316,7 @@ class HttpCli(object):
except:
pass
- srv_info = " // ".join(srv_info)
+ srv_infot = " // ".join(srv_info)
perms = []
if self.can_read:
@@ -2223,7 +2360,7 @@ class HttpCli(object):
"dirs": [],
"files": [],
"taglist": [],
- "srvinf": srv_info,
+ "srvinf": srv_infot,
"acct": self.uname,
"idx": ("e2d" in vn.flags),
"perms": perms,
@@ -2251,7 +2388,7 @@ class HttpCli(object):
"logues": logues,
"readme": readme,
"title": html_escape(self.vpath, crlf=True),
- "srv_info": srv_info,
+ "srv_info": srv_infot,
"lang": self.args.lang,
"dtheme": self.args.theme,
"themes": self.args.themes,
@@ -2267,7 +2404,7 @@ class HttpCli(object):
if "zip" in self.uparam or "tar" in self.uparam:
raise Pebkac(403)
- html = self.j2(tpl, **j2a)
+ html = self.j2s(tpl, **j2a)
self.reply(html.encode("utf-8", "replace"))
return True
@@ -2280,11 +2417,12 @@ class HttpCli(object):
rem, self.uname, not self.args.no_scandir, [[True], [False, True]]
)
stats = {k: v for k, v in vfs_ls}
- vfs_ls = [x[0] for x in vfs_ls]
- vfs_ls.extend(vfs_virt.keys())
+ ls_names = [x[0] for x in vfs_ls]
+ ls_names.extend(list(vfs_virt.keys()))
# check for old versions of files,
- hist = {} # [num-backups, most-recent, hist-path]
+ # [num-backups, most-recent, hist-path]
+ hist: dict[str, tuple[int, float, str]] = {}
histdir = os.path.join(fsroot, ".hist")
ptn = re.compile(r"(.*)\.([0-9]+\.[0-9]{3})(\.[^\.]+)$")
try:
@@ -2294,14 +2432,14 @@ class HttpCli(object):
continue
fn = m.group(1) + m.group(3)
- n, ts, _ = hist.get(fn, [0, 0, ""])
- hist[fn] = [n + 1, max(ts, float(m.group(2))), hfn]
+ n, ts, _ = hist.get(fn, (0, 0, ""))
+ hist[fn] = (n + 1, max(ts, float(m.group(2))), hfn)
except:
pass
# show dotfiles if permitted and requested
if not self.args.ed or "dots" not in self.uparam:
- vfs_ls = exclude_dotfiles(vfs_ls)
+ ls_names = exclude_dotfiles(ls_names)
icur = None
if "e2t" in vn.flags:
@@ -2312,7 +2450,7 @@ class HttpCli(object):
dirs = []
files = []
- for fn in vfs_ls:
+ for fn in ls_names:
base = ""
href = fn
if not is_ls and not is_js and not self.trailing_slash and vpath:
@@ -2339,14 +2477,14 @@ class HttpCli(object):
margin = 'zip'.format(quotep(href))
elif fn in hist:
margin = '#{}'.format(
- base, html_escape(hist[fn][2], quote=True, crlf=True), hist[fn][0]
+ base, html_escape(hist[fn][2], quot=True, crlf=True), hist[fn][0]
)
else:
margin = "-"
sz = inf.st_size
- dt = datetime.utcfromtimestamp(inf.st_mtime)
- dt = dt.strftime("%Y-%m-%d %H:%M:%S")
+ zd = datetime.utcfromtimestamp(inf.st_mtime)
+ dt = zd.strftime("%Y-%m-%d %H:%M:%S")
try:
ext = "---" if is_dir else fn.rsplit(".", 1)[1]
@@ -2380,11 +2518,11 @@ class HttpCli(object):
files.append(item)
item["rd"] = rem
- taglist = {}
- for f in files:
- fn = f["name"]
- rd = f["rd"]
- del f["rd"]
+ tagset: set[str] = set()
+ for fe in files:
+ fn = fe["name"]
+ rd = fe["rd"]
+ del fe["rd"]
if not icur:
break
@@ -2403,12 +2541,12 @@ class HttpCli(object):
args = s3enc(idx.mem_cur, rd, fn)
r = icur.execute(q, args).fetchone()
except:
- m = "tag list error, {}/{}\n{}"
- self.log(m.format(rd, fn, min_ex()))
+ t = "tag list error, {}/{}\n{}"
+ self.log(t.format(rd, fn, min_ex()))
break
- tags = {}
- f["tags"] = tags
+ tags: dict[str, Any] = {}
+ fe["tags"] = tags
if not r:
continue
@@ -2417,17 +2555,19 @@ class HttpCli(object):
q = "select k, v from mt where w = ? and +k != 'x'"
try:
for k, v in icur.execute(q, (w,)):
- taglist[k] = True
+ tagset.add(k)
tags[k] = v
except:
- m = "tag read error, {}/{} [{}]:\n{}"
- self.log(m.format(rd, fn, w, min_ex()))
+ t = "tag read error, {}/{} [{}]:\n{}"
+ self.log(t.format(rd, fn, w, min_ex()))
break
if icur:
- taglist = [k for k in vn.flags.get("mte", "").split(",") if k in taglist]
- for f in dirs:
- f["tags"] = {}
+ taglist = [k for k in vn.flags.get("mte", "").split(",") if k in tagset]
+ for fe in dirs:
+ fe["tags"] = {}
+ else:
+ taglist = list(tagset)
if is_ls:
ls_ret["dirs"] = dirs
@@ -2480,6 +2620,6 @@ class HttpCli(object):
if self.args.css_browser:
j2a["css"] = self.args.css_browser
- html = self.j2(tpl, **j2a)
+ html = self.j2s(tpl, **j2a)
self.reply(html.encode("utf-8", "replace"))
return True
diff --git a/copyparty/httpconn.py b/copyparty/httpconn.py
index 85f5f701..067e2d31 100644
--- a/copyparty/httpconn.py
+++ b/copyparty/httpconn.py
@@ -1,25 +1,36 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import re
+import argparse # typechk
import os
-import time
+import re
import socket
+import threading # typechk
+import time
-HAVE_SSL = True
try:
+ HAVE_SSL = True
import ssl
except:
HAVE_SSL = False
-from .__init__ import E
-from .util import Unrecv
+from . import util as Util
+from .__init__ import TYPE_CHECKING, E
+from .authsrv import AuthSrv # typechk
from .httpcli import HttpCli
-from .u2idx import U2idx
+from .ico import Ico
+from .mtag import HAVE_FFMPEG
from .th_cli import ThumbCli
from .th_srv import HAVE_PIL, HAVE_VIPS
-from .mtag import HAVE_FFMPEG
-from .ico import Ico
+from .u2idx import U2idx
+
+try:
+ from typing import Optional, Pattern, Union
+except:
+ pass
+
+if TYPE_CHECKING:
+ from .httpsrv import HttpSrv
class HttpConn(object):
@@ -28,32 +39,37 @@ class HttpConn(object):
creates an HttpCli for each request (Connection: Keep-Alive)
"""
- def __init__(self, sck, addr, hsrv):
+ def __init__(
+ self, sck: socket.socket, addr: tuple[str, int], hsrv: "HttpSrv"
+ ) -> None:
self.s = sck
- self.sr = None # Type: Unrecv
+ self.sr: Optional[Util._Unrecv] = None
self.addr = addr
self.hsrv = hsrv
- self.mutex = hsrv.mutex
- self.args = hsrv.args
- self.asrv = hsrv.asrv
+ self.mutex: threading.Lock = hsrv.mutex # mypy404
+ self.args: argparse.Namespace = hsrv.args # mypy404
+ self.asrv: AuthSrv = hsrv.asrv # mypy404
self.cert_path = hsrv.cert_path
- self.u2fh = hsrv.u2fh
+ self.u2fh: Util.FHC = hsrv.u2fh # mypy404
enth = (HAVE_PIL or HAVE_VIPS or HAVE_FFMPEG) and not self.args.no_thumb
- self.thumbcli = ThumbCli(hsrv) if enth else None
- self.ico = Ico(self.args)
+ self.thumbcli: Optional[ThumbCli] = ThumbCli(hsrv) if enth else None # mypy404
+ self.ico: Ico = Ico(self.args) # mypy404
- self.t0 = time.time()
+ self.t0: float = time.time() # mypy404
self.stopping = False
- self.nreq = 0
- self.nbyte = 0
- self.u2idx = None
- self.log_func = hsrv.log
- self.lf_url = re.compile(self.args.lf_url) if self.args.lf_url else None
+ self.nreq: int = 0 # mypy404
+ self.nbyte: int = 0 # mypy404
+ self.u2idx: Optional[U2idx] = None
+ self.log_func: Util.RootLogger = hsrv.log # mypy404
+ self.log_src: str = "httpconn" # mypy404
+ self.lf_url: Optional[Pattern[str]] = (
+ re.compile(self.args.lf_url) if self.args.lf_url else None
+ ) # mypy404
self.set_rproxy()
- def shutdown(self):
+ def shutdown(self) -> None:
self.stopping = True
try:
self.s.shutdown(socket.SHUT_RDWR)
@@ -61,7 +77,7 @@ class HttpConn(object):
except:
pass
- def set_rproxy(self, ip=None):
+ def set_rproxy(self, ip: Optional[str] = None) -> str:
if ip is None:
color = 36
ip = self.addr[0]
@@ -74,35 +90,35 @@ class HttpConn(object):
self.log_src = "{} \033[{}m{}".format(ip, color, self.addr[1]).ljust(26)
return self.log_src
- def respath(self, res_name):
+ def respath(self, res_name: str) -> str:
return os.path.join(E.mod, "web", res_name)
- def log(self, msg, c=0):
+ def log(self, msg: str, c: Union[int, str] = 0) -> None:
self.log_func(self.log_src, msg, c)
- def get_u2idx(self):
+ def get_u2idx(self) -> U2idx:
if not self.u2idx:
self.u2idx = U2idx(self)
return self.u2idx
- def _detect_https(self):
+ def _detect_https(self) -> bool:
method = None
if self.cert_path:
try:
method = self.s.recv(4, socket.MSG_PEEK)
except socket.timeout:
- return
+ return False
except AttributeError:
# jython does not support msg_peek; forget about https
method = self.s.recv(4)
- self.sr = Unrecv(self.s, self.log)
+ self.sr = Util.Unrecv(self.s, self.log)
self.sr.buf = method
# jython used to do this, they stopped since it's broken
# but reimplementing sendall is out of scope for now
if not getattr(self.s, "sendall", None):
- self.s.sendall = self.s.send
+ self.s.sendall = self.s.send # type: ignore
if len(method) != 4:
err = "need at least 4 bytes in the first packet; got {}".format(
@@ -112,17 +128,18 @@ class HttpConn(object):
self.log(err)
self.s.send(b"HTTP/1.1 400 Bad Request\r\n\r\n" + err.encode("utf-8"))
- return
+ return False
return method not in [None, b"GET ", b"HEAD", b"POST", b"PUT ", b"OPTI"]
- def run(self):
+ def run(self) -> None:
self.sr = None
if self.args.https_only:
is_https = True
elif self.args.http_only or not HAVE_SSL:
is_https = False
else:
+ # raise Exception("asdf")
is_https = self._detect_https()
if is_https:
@@ -151,14 +168,15 @@ class HttpConn(object):
self.s = ctx.wrap_socket(self.s, server_side=True)
msg = [
"\033[1;3{:d}m{}".format(c, s)
- for c, s in zip([0, 5, 0], self.s.cipher())
+ for c, s in zip([0, 5, 0], self.s.cipher()) # type: ignore
]
self.log(" ".join(msg) + "\033[0m")
if self.args.ssl_dbg and hasattr(self.s, "shared_ciphers"):
- overlap = [y[::-1] for y in self.s.shared_ciphers()]
- lines = [str(x) for x in (["TLS cipher overlap:"] + overlap)]
- self.log("\n".join(lines))
+ ciphers = self.s.shared_ciphers()
+ assert ciphers
+ overlap = [str(y[::-1]) for y in ciphers]
+ self.log("TLS cipher overlap:" + "\n".join(overlap))
for k, v in [
["compression", self.s.compression()],
["ALPN proto", self.s.selected_alpn_protocol()],
@@ -183,7 +201,7 @@ class HttpConn(object):
return
if not self.sr:
- self.sr = Unrecv(self.s, self.log)
+ self.sr = Util.Unrecv(self.s, self.log)
while not self.stopping:
self.nreq += 1
diff --git a/copyparty/httpsrv.py b/copyparty/httpsrv.py
index 04d2146c..fd449b5f 100644
--- a/copyparty/httpsrv.py
+++ b/copyparty/httpsrv.py
@@ -1,13 +1,15 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import os
-import sys
-import time
-import math
import base64
+import math
+import os
import socket
+import sys
import threading
+import time
+
+import queue
try:
import jinja2
@@ -26,15 +28,18 @@ except ImportError:
)
sys.exit(1)
-from .__init__ import E, PY2, MACOS
-from .util import FHC, spack, min_ex, start_stackmon, start_log_thrs
+from .__init__ import MACOS, TYPE_CHECKING, E
from .bos import bos
from .httpconn import HttpConn
+from .util import FHC, min_ex, spack, start_log_thrs, start_stackmon
-if PY2:
- import Queue as queue
-else:
- import queue
+if TYPE_CHECKING:
+ from .broker_util import BrokerCli
+
+try:
+ from typing import Any, Optional
+except:
+ pass
class HttpSrv(object):
@@ -43,7 +48,7 @@ class HttpSrv(object):
relying on MpSrv for performance (HttpSrv is just plain threads)
"""
- def __init__(self, broker, nid):
+ def __init__(self, broker: "BrokerCli", nid: Optional[int]) -> None:
self.broker = broker
self.nid = nid
self.args = broker.args
@@ -58,17 +63,19 @@ class HttpSrv(object):
self.tp_nthr = 0 # actual
self.tp_ncli = 0 # fading
- self.tp_time = None # latest worker collect
- self.tp_q = None if self.args.no_htp else queue.LifoQueue()
- self.t_periodic = None
+ self.tp_time = 0.0 # latest worker collect
+ self.tp_q: Optional[queue.LifoQueue[Any]] = (
+ None if self.args.no_htp else queue.LifoQueue()
+ )
+ self.t_periodic: Optional[threading.Thread] = None
self.u2fh = FHC()
- self.srvs = []
+ self.srvs: list[socket.socket] = []
self.ncli = 0 # exact
- self.clients = {} # laggy
+ self.clients: set[HttpConn] = set() # laggy
self.nclimax = 0
- self.cb_ts = 0
- self.cb_v = 0
+ self.cb_ts = 0.0
+ self.cb_v = ""
env = jinja2.Environment()
env.loader = jinja2.FileSystemLoader(os.path.join(E.mod, "web"))
@@ -82,7 +89,7 @@ class HttpSrv(object):
if bos.path.exists(cert_path):
self.cert_path = cert_path
else:
- self.cert_path = None
+ self.cert_path = ""
if self.tp_q:
self.start_threads(4)
@@ -94,19 +101,19 @@ class HttpSrv(object):
if self.args.log_thrs:
start_log_thrs(self.log, self.args.log_thrs, nid)
- self.th_cfg = {} # type: dict[str, Any]
+ self.th_cfg: dict[str, Any] = {}
t = threading.Thread(target=self.post_init)
t.daemon = True
t.start()
- def post_init(self):
+ def post_init(self) -> None:
try:
- x = self.broker.put(True, "thumbsrv.getcfg")
+ x = self.broker.ask("thumbsrv.getcfg")
self.th_cfg = x.get()
except:
pass
- def start_threads(self, n):
+ def start_threads(self, n: int) -> None:
self.tp_nthr += n
if self.args.log_htp:
self.log(self.name, "workers += {} = {}".format(n, self.tp_nthr), 6)
@@ -119,15 +126,16 @@ class HttpSrv(object):
thr.daemon = True
thr.start()
- def stop_threads(self, n):
+ def stop_threads(self, n: int) -> None:
self.tp_nthr -= n
if self.args.log_htp:
self.log(self.name, "workers -= {} = {}".format(n, self.tp_nthr), 6)
+ assert self.tp_q
for _ in range(n):
self.tp_q.put(None)
- def periodic(self):
+ def periodic(self) -> None:
while True:
time.sleep(2 if self.tp_ncli or self.ncli else 10)
with self.mutex:
@@ -141,7 +149,7 @@ class HttpSrv(object):
self.t_periodic = None
return
- def listen(self, sck, nlisteners):
+ def listen(self, sck: socket.socket, nlisteners: int) -> None:
ip, port = sck.getsockname()
self.srvs.append(sck)
self.nclimax = math.ceil(self.args.nc * 1.0 / nlisteners)
@@ -153,15 +161,15 @@ class HttpSrv(object):
t.daemon = True
t.start()
- def thr_listen(self, srv_sck):
+ def thr_listen(self, srv_sck: socket.socket) -> None:
"""listens on a shared tcp server"""
ip, port = srv_sck.getsockname()
fno = srv_sck.fileno()
msg = "subscribed @ {}:{} f{}".format(ip, port, fno)
self.log(self.name, msg)
- def fun():
- self.broker.put(False, "cb_httpsrv_up")
+ def fun() -> None:
+ self.broker.say("cb_httpsrv_up")
threading.Thread(target=fun).start()
@@ -185,21 +193,21 @@ class HttpSrv(object):
continue
if self.args.log_conn:
- m = "|{}C-acc2 \033[0;36m{} \033[3{}m{}".format(
+ t = "|{}C-acc2 \033[0;36m{} \033[3{}m{}".format(
"-" * 3, ip, port % 8, port
)
- self.log("%s %s" % addr, m, c="1;30")
+ self.log("%s %s" % addr, t, c="1;30")
self.accept(sck, addr)
- def accept(self, sck, addr):
+ def accept(self, sck: socket.socket, addr: tuple[str, int]) -> None:
"""takes an incoming tcp connection and creates a thread to handle it"""
now = time.time()
if now - (self.tp_time or now) > 300:
- m = "httpserver threadpool died: tpt {:.2f}, now {:.2f}, nthr {}, ncli {}"
- self.log(self.name, m.format(self.tp_time, now, self.tp_nthr, self.ncli), 1)
- self.tp_time = None
+ t = "httpserver threadpool died: tpt {:.2f}, now {:.2f}, nthr {}, ncli {}"
+ self.log(self.name, t.format(self.tp_time, now, self.tp_nthr, self.ncli), 1)
+ self.tp_time = 0
self.tp_q = None
with self.mutex:
@@ -209,10 +217,10 @@ class HttpSrv(object):
if self.nid:
name += "-{}".format(self.nid)
- t = threading.Thread(target=self.periodic, name=name)
- self.t_periodic = t
- t.daemon = True
- t.start()
+ thr = threading.Thread(target=self.periodic, name=name)
+ self.t_periodic = thr
+ thr.daemon = True
+ thr.start()
if self.tp_q:
self.tp_time = self.tp_time or now
@@ -224,8 +232,8 @@ class HttpSrv(object):
return
if not self.args.no_htp:
- m = "looks like the httpserver threadpool died; please make an issue on github and tell me the story of how you pulled that off, thanks and dog bless\n"
- self.log(self.name, m, 1)
+ t = "looks like the httpserver threadpool died; please make an issue on github and tell me the story of how you pulled that off, thanks and dog bless\n"
+ self.log(self.name, t, 1)
thr = threading.Thread(
target=self.thr_client,
@@ -235,14 +243,15 @@ class HttpSrv(object):
thr.daemon = True
thr.start()
- def thr_poolw(self):
+ def thr_poolw(self) -> None:
+ assert self.tp_q
while True:
task = self.tp_q.get()
if not task:
break
with self.mutex:
- self.tp_time = None
+ self.tp_time = 0
try:
sck, addr = task
@@ -255,7 +264,7 @@ class HttpSrv(object):
except:
self.log(self.name, "thr_client: " + min_ex(), 3)
- def shutdown(self):
+ def shutdown(self) -> None:
self.stopping = True
for srv in self.srvs:
try:
@@ -263,7 +272,7 @@ class HttpSrv(object):
except:
pass
- clients = list(self.clients.keys())
+ clients = list(self.clients)
for cli in clients:
try:
cli.shutdown()
@@ -279,13 +288,13 @@ class HttpSrv(object):
self.log(self.name, "ok bye")
- def thr_client(self, sck, addr):
+ def thr_client(self, sck: socket.socket, addr: tuple[str, int]) -> None:
"""thread managing one tcp client"""
sck.settimeout(120)
cli = HttpConn(sck, addr, self)
with self.mutex:
- self.clients[cli] = 0
+ self.clients.add(cli)
fno = sck.fileno()
try:
@@ -328,10 +337,10 @@ class HttpSrv(object):
raise
finally:
with self.mutex:
- del self.clients[cli]
+ self.clients.remove(cli)
self.ncli -= 1
- def cachebuster(self):
+ def cachebuster(self) -> str:
if time.time() - self.cb_ts < 1:
return self.cb_v
diff --git a/copyparty/ico.py b/copyparty/ico.py
index 58076c89..f403e4b5 100644
--- a/copyparty/ico.py
+++ b/copyparty/ico.py
@@ -1,28 +1,28 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import hashlib
+import argparse # typechk
import colorsys
+import hashlib
from .__init__ import PY2
class Ico(object):
- def __init__(self, args):
+ def __init__(self, args: argparse.Namespace) -> None:
self.args = args
- def get(self, ext, as_thumb):
+ def get(self, ext: str, as_thumb: bool) -> tuple[str, bytes]:
"""placeholder to make thumbnails not break"""
- h = hashlib.md5(ext.encode("utf-8")).digest()[:2]
+ zb = hashlib.md5(ext.encode("utf-8")).digest()[:2]
if PY2:
- h = [ord(x) for x in h]
+ zb = [ord(x) for x in zb]
- c1 = colorsys.hsv_to_rgb(h[0] / 256.0, 1, 0.3)
- c2 = colorsys.hsv_to_rgb(h[0] / 256.0, 1, 1)
- c = list(c1) + list(c2)
- c = [int(x * 255) for x in c]
- c = "".join(["{:02x}".format(x) for x in c])
+ c1 = colorsys.hsv_to_rgb(zb[0] / 256.0, 1, 0.3)
+ c2 = colorsys.hsv_to_rgb(zb[0] / 256.0, 1, 1)
+ ci = [int(x * 255) for x in list(c1) + list(c2)]
+ c = "".join(["{:02x}".format(x) for x in ci])
h = 30
if not self.args.th_no_crop and as_thumb:
@@ -37,6 +37,6 @@ class Ico(object):
fill="#{}" font-family="monospace" font-size="14px" style="letter-spacing:.5px">{}
"""
- svg = svg.format(h, c[:6], c[6:], ext).encode("utf-8")
+ svg = svg.format(h, c[:6], c[6:], ext)
- return ["image/svg+xml", svg]
+ return "image/svg+xml", svg.encode("utf-8")
diff --git a/copyparty/mtag.py b/copyparty/mtag.py
index 60a7584f..dd1e2ec3 100644
--- a/copyparty/mtag.py
+++ b/copyparty/mtag.py
@@ -1,18 +1,26 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import os
-import sys
+import argparse
import json
+import os
import shutil
import subprocess as sp
+import sys
from .__init__ import PY2, WINDOWS, unicode
-from .util import fsenc, uncyg, runcmd, retchk, REKOBO_LKEY
from .bos import bos
+from .util import REKOBO_LKEY, fsenc, retchk, runcmd, uncyg
+
+try:
+ from typing import Any, Union
+
+ from .util import RootLogger
+except:
+ pass
-def have_ff(cmd):
+def have_ff(cmd: str) -> bool:
if PY2:
print("# checking {}".format(cmd))
cmd = (cmd + " -version").encode("ascii").split(b" ")
@@ -30,7 +38,7 @@ HAVE_FFPROBE = have_ff("ffprobe")
class MParser(object):
- def __init__(self, cmdline):
+ def __init__(self, cmdline: str) -> None:
self.tag, args = cmdline.split("=", 1)
self.tags = self.tag.split(",")
@@ -73,7 +81,9 @@ class MParser(object):
raise Exception()
-def ffprobe(abspath, timeout=10):
+def ffprobe(
+ abspath: str, timeout: int = 10
+) -> tuple[dict[str, tuple[int, Any]], dict[str, list[Any]]]:
cmd = [
b"ffprobe",
b"-hide_banner",
@@ -87,15 +97,15 @@ def ffprobe(abspath, timeout=10):
return parse_ffprobe(so)
-def parse_ffprobe(txt):
+def parse_ffprobe(txt: str) -> tuple[dict[str, tuple[int, Any]], dict[str, list[Any]]]:
"""ffprobe -show_format -show_streams"""
streams = []
fmt = {}
g = {}
for ln in [x.rstrip("\r") for x in txt.split("\n")]:
try:
- k, v = ln.split("=", 1)
- g[k] = v
+ sk, sv = ln.split("=", 1)
+ g[sk] = sv
continue
except:
pass
@@ -109,8 +119,8 @@ def parse_ffprobe(txt):
fmt = g
streams = [fmt] + streams
- ret = {} # processed
- md = {} # raw tags
+ ret: dict[str, Any] = {} # processed
+ md: dict[str, list[Any]] = {} # raw tags
is_audio = fmt.get("format_name") in ["mp3", "ogg", "flac", "wav"]
if fmt.get("filename", "").split(".")[-1].lower() in ["m4a", "aac"]:
@@ -161,43 +171,43 @@ def parse_ffprobe(txt):
kvm = [["duration", ".dur"], ["bit_rate", ".q"]]
for sk, rk in kvm:
- v = strm.get(sk)
- if v is None:
+ v1 = strm.get(sk)
+ if v1 is None:
continue
if rk.startswith("."):
try:
- v = float(v)
+ zf = float(v1)
v2 = ret.get(rk)
- if v2 is None or v > v2:
- ret[rk] = v
+ if v2 is None or zf > v2:
+ ret[rk] = zf
except:
# sqlite doesnt care but the code below does
- if v not in ["N/A"]:
- ret[rk] = v
+ if v1 not in ["N/A"]:
+ ret[rk] = v1
else:
- ret[rk] = v
+ ret[rk] = v1
if ret.get("vc") == "ansi": # shellscript
return {}, {}
for strm in streams:
- for k, v in strm.items():
- if not k.startswith("TAG:"):
+ for sk, sv in strm.items():
+ if not sk.startswith("TAG:"):
continue
- k = k[4:].strip()
- v = v.strip()
- if k and v and k not in md:
- md[k] = [v]
+ sk = sk[4:].strip()
+ sv = sv.strip()
+ if sk and sv and sk not in md:
+ md[sk] = [sv]
- for k in [".q", ".vq", ".aq"]:
- if k in ret:
- ret[k] /= 1000 # bit_rate=320000
+ for sk in [".q", ".vq", ".aq"]:
+ if sk in ret:
+ ret[sk] /= 1000 # bit_rate=320000
- for k in [".q", ".vq", ".aq", ".resw", ".resh"]:
- if k in ret:
- ret[k] = int(ret[k])
+ for sk in [".q", ".vq", ".aq", ".resw", ".resh"]:
+ if sk in ret:
+ ret[sk] = int(ret[sk])
if ".fps" in ret:
fps = ret[".fps"]
@@ -219,13 +229,13 @@ def parse_ffprobe(txt):
if ".resw" in ret and ".resh" in ret:
ret["res"] = "{}x{}".format(ret[".resw"], ret[".resh"])
- ret = {k: [0, v] for k, v in ret.items()}
+ zd = {k: (0, v) for k, v in ret.items()}
- return ret, md
+ return zd, md
class MTag(object):
- def __init__(self, log_func, args):
+ def __init__(self, log_func: RootLogger, args: argparse.Namespace) -> None:
self.log_func = log_func
self.args = args
self.usable = True
@@ -242,7 +252,7 @@ class MTag(object):
if self.backend == "mutagen":
self.get = self.get_mutagen
try:
- import mutagen
+ import mutagen # noqa: F401 # pylint: disable=unused-import,import-outside-toplevel
except:
self.log("could not load Mutagen, trying FFprobe instead", c=3)
self.backend = "ffprobe"
@@ -339,31 +349,33 @@ class MTag(object):
}
# self.get = self.compare
- def log(self, msg, c=0):
+ def log(self, msg: str, c: Union[int, str] = 0) -> None:
self.log_func("mtag", msg, c)
- def normalize_tags(self, ret, md):
- for k, v in dict(md).items():
- if not v:
+ def normalize_tags(
+ self, parser_output: dict[str, tuple[int, Any]], md: dict[str, list[Any]]
+ ) -> dict[str, Union[str, float]]:
+ for sk, tv in dict(md).items():
+ if not tv:
continue
- k = k.lower().split("::")[0].strip()
- mk = self.rmap.get(k)
- if not mk:
+ sk = sk.lower().split("::")[0].strip()
+ key_mapping = self.rmap.get(sk)
+ if not key_mapping:
continue
- pref, mk = mk
- if mk not in ret or ret[mk][0] > pref:
- ret[mk] = [pref, v[0]]
+ priority, alias = key_mapping
+ if alias not in parser_output or parser_output[alias][0] > priority:
+ parser_output[alias] = (priority, tv[0])
- # take first value
- ret = {k: unicode(v[1]).strip() for k, v in ret.items()}
+ # take first value (lowest priority / most preferred)
+ ret = {sk: unicode(tv[1]).strip() for sk, tv in parser_output.items()}
# track 3/7 => track 3
- for k, v in ret.items():
- if k[0] == ".":
- v = v.split("/")[0].strip().lstrip("0")
- ret[k] = v or 0
+ for sk, tv in ret.items():
+ if sk[0] == ".":
+ sv = str(tv).split("/")[0].strip().lstrip("0")
+ ret[sk] = sv or 0
# normalize key notation to rkeobo
okey = ret.get("key")
@@ -373,7 +385,7 @@ class MTag(object):
return ret
- def compare(self, abspath):
+ def compare(self, abspath: str) -> dict[str, Union[str, float]]:
if abspath.endswith(".au"):
return {}
@@ -411,7 +423,7 @@ class MTag(object):
return r1
- def get_mutagen(self, abspath):
+ def get_mutagen(self, abspath: str) -> dict[str, Union[str, float]]:
if not bos.path.isfile(abspath):
return {}
@@ -425,7 +437,7 @@ class MTag(object):
return self.get_ffprobe(abspath) if self.can_ffprobe else {}
sz = bos.path.getsize(abspath)
- ret = {".q": [0, int((sz / md.info.length) / 128)]}
+ ret = {".q": (0, int((sz / md.info.length) / 128))}
for attr, k, norm in [
["codec", "ac", unicode],
@@ -456,24 +468,24 @@ class MTag(object):
if k == "ac" and v.startswith("mp4a.40."):
v = "aac"
- ret[k] = [0, norm(v)]
+ ret[k] = (0, norm(v))
return self.normalize_tags(ret, md)
- def get_ffprobe(self, abspath):
+ def get_ffprobe(self, abspath: str) -> dict[str, Union[str, float]]:
if not bos.path.isfile(abspath):
return {}
ret, md = ffprobe(abspath)
return self.normalize_tags(ret, md)
- def get_bin(self, parsers, abspath):
+ def get_bin(self, parsers: dict[str, MParser], abspath: str) -> dict[str, Any]:
if not bos.path.isfile(abspath):
return {}
pypath = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
- pypath = [str(pypath)] + [str(x) for x in sys.path if x]
- pypath = str(os.pathsep.join(pypath))
+ zsl = [str(pypath)] + [str(x) for x in sys.path if x]
+ pypath = str(os.pathsep.join(zsl))
env = os.environ.copy()
env["PYTHONPATH"] = pypath
@@ -491,9 +503,9 @@ class MTag(object):
else:
cmd = ["nice"] + cmd
- cmd = [fsenc(x) for x in cmd]
- rc, v, err = runcmd(cmd, **args)
- retchk(rc, cmd, err, self.log, 5, self.args.mtag_v)
+ bcmd = [fsenc(x) for x in cmd]
+ rc, v, err = runcmd(bcmd, **args) # type: ignore
+ retchk(rc, bcmd, err, self.log, 5, self.args.mtag_v)
v = v.strip()
if not v:
continue
@@ -501,10 +513,10 @@ class MTag(object):
if "," not in tagname:
ret[tagname] = v
else:
- v = json.loads(v)
+ zj = json.loads(v)
for tag in tagname.split(","):
- if tag and tag in v:
- ret[tag] = v[tag]
+ if tag and tag in zj:
+ ret[tag] = zj[tag]
except:
pass
diff --git a/copyparty/star.py b/copyparty/star.py
index 804ad724..21c2703c 100644
--- a/copyparty/star.py
+++ b/copyparty/star.py
@@ -4,20 +4,29 @@ from __future__ import print_function, unicode_literals
import tarfile
import threading
-from .sutil import errdesc
-from .util import Queue, fsenc, min_ex
+from queue import Queue
+
from .bos import bos
+from .sutil import StreamArc, errdesc
+from .util import fsenc, min_ex
+
+try:
+ from typing import Any, Generator, Optional
+
+ from .util import NamedLogger
+except:
+ pass
-class QFile(object):
+class QFile(object): # inherit io.StringIO for painful typing
"""file-like object which buffers writes into a queue"""
- def __init__(self):
- self.q = Queue(64)
- self.bq = []
+ def __init__(self) -> None:
+ self.q: Queue[Optional[bytes]] = Queue(64)
+ self.bq: list[bytes] = []
self.nq = 0
- def write(self, buf):
+ def write(self, buf: Optional[bytes]) -> None:
if buf is None or self.nq >= 240 * 1024:
self.q.put(b"".join(self.bq))
self.bq = []
@@ -30,27 +39,32 @@ class QFile(object):
self.nq += len(buf)
-class StreamTar(object):
+class StreamTar(StreamArc):
"""construct in-memory tar file from the given path"""
- def __init__(self, log, fgen, **kwargs):
+ def __init__(
+ self,
+ log: NamedLogger,
+ fgen: Generator[dict[str, Any], None, None],
+ **kwargs: Any
+ ):
+ super(StreamTar, self).__init__(log, fgen)
+
self.ci = 0
self.co = 0
self.qfile = QFile()
- self.log = log
- self.fgen = fgen
- self.errf = None
+ self.errf: dict[str, Any] = {}
# python 3.8 changed to PAX_FORMAT as default,
# waste of space and don't care about the new features
fmt = tarfile.GNU_FORMAT
- self.tar = tarfile.open(fileobj=self.qfile, mode="w|", format=fmt)
+ self.tar = tarfile.open(fileobj=self.qfile, mode="w|", format=fmt) # type: ignore
w = threading.Thread(target=self._gen, name="star-gen")
w.daemon = True
w.start()
- def gen(self):
+ def gen(self) -> Generator[Optional[bytes], None, None]:
while True:
buf = self.qfile.q.get()
if not buf:
@@ -63,7 +77,7 @@ class StreamTar(object):
if self.errf:
bos.unlink(self.errf["ap"])
- def ser(self, f):
+ def ser(self, f: dict[str, Any]) -> None:
name = f["vp"]
src = f["ap"]
fsi = f["st"]
@@ -76,21 +90,21 @@ class StreamTar(object):
inf.gid = 0
self.ci += inf.size
- with open(fsenc(src), "rb", 512 * 1024) as f:
- self.tar.addfile(inf, f)
+ with open(fsenc(src), "rb", 512 * 1024) as fo:
+ self.tar.addfile(inf, fo)
- def _gen(self):
+ def _gen(self) -> None:
errors = []
for f in self.fgen:
if "err" in f:
- errors.append([f["vp"], f["err"]])
+ errors.append((f["vp"], f["err"]))
continue
try:
self.ser(f)
except:
ex = min_ex(5, True).replace("\n", "\n-- ")
- errors.append([f["vp"], ex])
+ errors.append((f["vp"], ex))
if errors:
self.errf, txt = errdesc(errors)
diff --git a/copyparty/stolen/surrogateescape.py b/copyparty/stolen/surrogateescape.py
index 4b06ed28..b1ff8886 100644
--- a/copyparty/stolen/surrogateescape.py
+++ b/copyparty/stolen/surrogateescape.py
@@ -12,23 +12,28 @@ Original source: misc/python/surrogateescape.py in https://bitbucket.org/haypo/m
# This code is released under the Python license and the BSD 2-clause license
-import platform
import codecs
+import platform
import sys
PY3 = sys.version_info[0] > 2
WINDOWS = platform.system() == "Windows"
FS_ERRORS = "surrogateescape"
+try:
+ from typing import Any
+except:
+ pass
-def u(text):
+
+def u(text: Any) -> str:
if PY3:
return text
else:
return text.decode("unicode_escape")
-def b(data):
+def b(data: Any) -> bytes:
if PY3:
return data.encode("latin1")
else:
@@ -43,7 +48,7 @@ else:
bytes_chr = chr
-def surrogateescape_handler(exc):
+def surrogateescape_handler(exc: Any) -> tuple[str, int]:
"""
Pure Python implementation of the PEP 383: the "surrogateescape" error
handler of Python 3. Undecodable bytes will be replaced by a Unicode
@@ -74,7 +79,7 @@ class NotASurrogateError(Exception):
pass
-def replace_surrogate_encode(mystring):
+def replace_surrogate_encode(mystring: str) -> str:
"""
Returns a (unicode) string, not the more logical bytes, because the codecs
register_error functionality expects this.
@@ -100,7 +105,7 @@ def replace_surrogate_encode(mystring):
return str().join(decoded)
-def replace_surrogate_decode(mybytes):
+def replace_surrogate_decode(mybytes: bytes) -> str:
"""
Returns a (unicode) string
"""
@@ -121,7 +126,7 @@ def replace_surrogate_decode(mybytes):
return str().join(decoded)
-def encodefilename(fn):
+def encodefilename(fn: str) -> bytes:
if FS_ENCODING == "ascii":
# ASCII encoder of Python 2 expects that the error handler returns a
# Unicode string encodable to ASCII, whereas our surrogateescape error
@@ -161,7 +166,7 @@ def encodefilename(fn):
return fn.encode(FS_ENCODING, FS_ERRORS)
-def decodefilename(fn):
+def decodefilename(fn: bytes) -> str:
return fn.decode(FS_ENCODING, FS_ERRORS)
@@ -181,7 +186,7 @@ if WINDOWS and not PY3:
FS_ENCODING = codecs.lookup(FS_ENCODING).name
-def register_surrogateescape():
+def register_surrogateescape() -> None:
"""
Registers the surrogateescape error handler on Python 2 (only)
"""
diff --git a/copyparty/sutil.py b/copyparty/sutil.py
index 22de0066..506e389f 100644
--- a/copyparty/sutil.py
+++ b/copyparty/sutil.py
@@ -6,8 +6,29 @@ from datetime import datetime
from .bos import bos
+try:
+ from typing import Any, Generator, Optional
-def errdesc(errors):
+ from .util import NamedLogger
+except:
+ pass
+
+
+class StreamArc(object):
+ def __init__(
+ self,
+ log: NamedLogger,
+ fgen: Generator[dict[str, Any], None, None],
+ **kwargs: Any
+ ):
+ self.log = log
+ self.fgen = fgen
+
+ def gen(self) -> Generator[Optional[bytes], None, None]:
+ pass
+
+
+def errdesc(errors: list[tuple[str, str]]) -> tuple[dict[str, Any], list[str]]:
report = ["copyparty failed to add the following files to the archive:", ""]
for fn, err in errors:
diff --git a/copyparty/svchub.py b/copyparty/svchub.py
index 44680513..e9b5aadf 100644
--- a/copyparty/svchub.py
+++ b/copyparty/svchub.py
@@ -1,41 +1,51 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
+import argparse
+import calendar
import os
-import sys
-import time
import shlex
-import string
import signal
import socket
+import string
+import sys
import threading
+import time
from datetime import datetime, timedelta
-import calendar
-from .__init__ import E, PY2, WINDOWS, ANYWIN, MACOS, VT100, unicode
-from .util import mp, start_log_thrs, start_stackmon, min_ex, ansi_re
+try:
+ from types import FrameType
+
+ import typing
+ from typing import Optional, Union
+except:
+ pass
+
+from .__init__ import ANYWIN, MACOS, PY2, VT100, WINDOWS, E, unicode
from .authsrv import AuthSrv
-from .tcpsrv import TcpSrv
-from .up2k import Up2k
-from .th_srv import ThumbSrv, HAVE_PIL, HAVE_VIPS, HAVE_WEBP
from .mtag import HAVE_FFMPEG, HAVE_FFPROBE
+from .tcpsrv import TcpSrv
+from .th_srv import HAVE_PIL, HAVE_VIPS, HAVE_WEBP, ThumbSrv
+from .up2k import Up2k
+from .util import ansi_re, min_ex, mp, start_log_thrs, start_stackmon
class SvcHub(object):
"""
Hosts all services which cannot be parallelized due to reliance on monolithic resources.
Creates a Broker which does most of the heavy stuff; hosted services can use this to perform work:
- hub.broker.put(want_reply, destination, args_list).
+ hub.broker.<say|ask>(destination, args_list).
Either BrokerThr (plain threads) or BrokerMP (multiprocessing) is used depending on configuration.
Nothing is returned synchronously; if you want any value returned from the call,
put() can return a queue (if want_reply=True) which has a blocking get() with the response.
"""
- def __init__(self, args, argv, printed):
+ def __init__(self, args: argparse.Namespace, argv: list[str], printed: str) -> None:
self.args = args
self.argv = argv
- self.logf = None
+ self.logf: Optional[typing.TextIO] = None
+ self.logf_base_fn = ""
self.stop_req = False
self.reload_req = False
self.stopping = False
@@ -59,16 +69,16 @@ class SvcHub(object):
if not args.use_fpool and args.j != 1:
args.no_fpool = True
- m = "multithreading enabled with -j {}, so disabling fpool -- this can reduce upload performance on some filesystems"
- self.log("root", m.format(args.j))
+ t = "multithreading enabled with -j {}, so disabling fpool -- this can reduce upload performance on some filesystems"
+ self.log("root", t.format(args.j))
if not args.no_fpool and args.j != 1:
- m = "WARNING: --use-fpool combined with multithreading is untested and can probably cause undefined behavior"
+ t = "WARNING: --use-fpool combined with multithreading is untested and can probably cause undefined behavior"
if ANYWIN:
- m = 'windows cannot do multithreading without --no-fpool, so enabling that -- note that upload performance will suffer if you have microsoft defender "real-time protection" enabled, so you probably want to use -j 1 instead'
+ t = 'windows cannot do multithreading without --no-fpool, so enabling that -- note that upload performance will suffer if you have microsoft defender "real-time protection" enabled, so you probably want to use -j 1 instead'
args.no_fpool = True
- self.log("root", m, c=3)
+ self.log("root", t, c=3)
bri = "zy"[args.theme % 2 :][:1]
ch = "abcdefghijklmnopqrstuvwx"[int(args.theme / 2)]
@@ -96,8 +106,8 @@ class SvcHub(object):
self.args.th_dec = list(decs.keys())
self.thumbsrv = None
if not args.no_thumb:
- m = "decoder preference: {}".format(", ".join(self.args.th_dec))
- self.log("thumb", m)
+ t = "decoder preference: {}".format(", ".join(self.args.th_dec))
+ self.log("thumb", t)
if "pil" in self.args.th_dec and not HAVE_WEBP:
msg = "disabling webp thumbnails because either libwebp is not available or your Pillow is too old"
@@ -131,11 +141,11 @@ class SvcHub(object):
if self.check_mp_enable():
from .broker_mp import BrokerMp as Broker
else:
- from .broker_thr import BrokerThr as Broker
+ from .broker_thr import BrokerThr as Broker # type: ignore
self.broker = Broker(self)
- def thr_httpsrv_up(self):
+ def thr_httpsrv_up(self) -> None:
time.sleep(1 if self.args.ign_ebind_all else 5)
expected = self.broker.num_workers * self.tcpsrv.nsrv
failed = expected - self.httpsrv_up
@@ -145,20 +155,20 @@ class SvcHub(object):
if self.args.ign_ebind_all:
if not self.tcpsrv.srv:
for _ in range(self.broker.num_workers):
- self.broker.put(False, "cb_httpsrv_up")
+ self.broker.say("cb_httpsrv_up")
return
if self.args.ign_ebind and self.tcpsrv.srv:
return
- m = "{}/{} workers failed to start"
- m = m.format(failed, expected)
- self.log("root", m, 1)
+ t = "{}/{} workers failed to start"
+ t = t.format(failed, expected)
+ self.log("root", t, 1)
self.retcode = 1
os.kill(os.getpid(), signal.SIGTERM)
- def cb_httpsrv_up(self):
+ def cb_httpsrv_up(self) -> None:
self.httpsrv_up += 1
if self.httpsrv_up != self.broker.num_workers:
return
@@ -171,9 +181,9 @@ class SvcHub(object):
thr.daemon = True
thr.start()
- def _logname(self):
+ def _logname(self) -> str:
dt = datetime.utcnow()
- fn = self.args.lo
+ fn = str(self.args.lo)
for fs in "YmdHMS":
fs = "%" + fs
if fs in fn:
@@ -181,7 +191,7 @@ class SvcHub(object):
return fn
- def _setup_logfile(self, printed):
+ def _setup_logfile(self, printed: str) -> None:
base_fn = fn = sel_fn = self._logname()
if fn != self.args.lo:
ctr = 0
@@ -203,8 +213,6 @@ class SvcHub(object):
lh = codecs.open(fn, "w", encoding="utf-8", errors="replace")
- lh.base_fn = base_fn
-
argv = [sys.executable] + self.argv
if hasattr(shlex, "quote"):
argv = [shlex.quote(x) for x in argv]
@@ -215,9 +223,10 @@ class SvcHub(object):
printed += msg
lh.write("t0: {:.3f}\nargv: {}\n\n{}".format(E.t0, " ".join(argv), printed))
self.logf = lh
+ self.logf_base_fn = base_fn
print(msg, end="")
- def run(self):
+ def run(self) -> None:
self.tcpsrv.run()
thr = threading.Thread(target=self.thr_httpsrv_up)
@@ -252,7 +261,7 @@ class SvcHub(object):
else:
self.stop_thr()
- def reload(self):
+ def reload(self) -> str:
if self.reloading:
return "cannot reload; already in progress"
@@ -262,7 +271,7 @@ class SvcHub(object):
t.start()
return "reload initiated"
- def _reload(self):
+ def _reload(self) -> None:
self.log("root", "reload scheduled")
with self.up2k.mutex:
self.asrv.reload()
@@ -271,7 +280,7 @@ class SvcHub(object):
self.reloading = False
- def stop_thr(self):
+ def stop_thr(self) -> None:
while not self.stop_req:
with self.stop_cond:
self.stop_cond.wait(9001)
@@ -282,7 +291,7 @@ class SvcHub(object):
self.shutdown()
- def signal_handler(self, sig, frame):
+ def signal_handler(self, sig: int, frame: Optional[FrameType]) -> None:
if self.stopping:
return
@@ -294,7 +303,7 @@ class SvcHub(object):
with self.stop_cond:
self.stop_cond.notify_all()
- def shutdown(self):
+ def shutdown(self) -> None:
if self.stopping:
return
@@ -337,7 +346,7 @@ class SvcHub(object):
sys.exit(ret)
- def _log_disabled(self, src, msg, c=0):
+ def _log_disabled(self, src: str, msg: str, c: Union[int, str] = 0) -> None:
if not self.logf:
return
@@ -349,8 +358,8 @@ class SvcHub(object):
if now >= self.next_day:
self._set_next_day()
- def _set_next_day(self):
- if self.next_day and self.logf and self.logf.base_fn != self._logname():
+ def _set_next_day(self) -> None:
+ if self.next_day and self.logf and self.logf_base_fn != self._logname():
self.logf.close()
self._setup_logfile("")
@@ -364,7 +373,7 @@ class SvcHub(object):
dt = dt.replace(hour=0, minute=0, second=0)
self.next_day = calendar.timegm(dt.utctimetuple())
- def _log_enabled(self, src, msg, c=0):
+ def _log_enabled(self, src: str, msg: str, c: Union[int, str] = 0) -> None:
"""handles logging from all components"""
with self.log_mutex:
now = time.time()
@@ -401,7 +410,7 @@ class SvcHub(object):
if self.logf:
self.logf.write(msg)
- def check_mp_support(self):
+ def check_mp_support(self) -> str:
vmin = sys.version_info[1]
if WINDOWS:
msg = "need python 3.3 or newer for multiprocessing;"
@@ -415,16 +424,16 @@ class SvcHub(object):
return msg
try:
- x = mp.Queue(1)
- x.put(["foo", "bar"])
+ x: mp.Queue[tuple[str, str]] = mp.Queue(1)
+ x.put(("foo", "bar"))
if x.get()[0] != "foo":
raise Exception()
except:
return "multiprocessing is not supported on your platform;"
- return None
+ return ""
- def check_mp_enable(self):
+ def check_mp_enable(self) -> bool:
if self.args.j == 1:
return False
@@ -447,18 +456,18 @@ class SvcHub(object):
self.log("svchub", "cannot efficiently use multiple CPU cores")
return False
- def sd_notify(self):
+ def sd_notify(self) -> None:
try:
- addr = os.getenv("NOTIFY_SOCKET")
- if not addr:
+ zb = os.getenv("NOTIFY_SOCKET")
+ if not zb:
return
- addr = unicode(addr)
+ addr = unicode(zb)
if addr.startswith("@"):
addr = "\0" + addr[1:]
- m = "".join(x for x in addr if x in string.printable)
- self.log("sd_notify", m)
+ t = "".join(x for x in addr if x in string.printable)
+ self.log("sd_notify", t)
sck = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
sck.connect(addr)
diff --git a/copyparty/szip.py b/copyparty/szip.py
index 177ff3ea..178f8a61 100644
--- a/copyparty/szip.py
+++ b/copyparty/szip.py
@@ -1,16 +1,23 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
+import calendar
import time
import zlib
-import calendar
-from .sutil import errdesc
-from .util import yieldfile, sanitize_fn, spack, sunpack, min_ex
from .bos import bos
+from .sutil import StreamArc, errdesc
+from .util import min_ex, sanitize_fn, spack, sunpack, yieldfile
+
+try:
+ from typing import Any, Generator, Optional
+
+ from .util import NamedLogger
+except:
+ pass
-def dostime2unix(buf):
+def dostime2unix(buf: bytes) -> int:
t, d = sunpack(b" bytes:
tt = time.gmtime(ts + 1)
dy, dm, dd, th, tm, ts = list(tt)[:6]
@@ -41,14 +48,22 @@ def unixtime2dos(ts):
return b"\x00\x00\x21\x00"
-def gen_fdesc(sz, crc32, z64):
+def gen_fdesc(sz: int, crc32: int, z64: bool) -> bytes:
ret = b"\x50\x4b\x07\x08"
fmt = b" bytes:
"""
does regular file headers
and the central directory meme if h_pos is set
@@ -67,8 +82,8 @@ def gen_hdr(h_pos, fn, sz, lastmod, utf8, crc32, pre_crc):
# confusingly this doesn't bump if h_pos
req_ver = b"\x2d\x00" if z64 else b"\x0a\x00"
- if crc32:
- crc32 = spack(b" tuple[bytes, bool]:
"""
summary of all file headers,
usually the zipfile footer unless something clamps
@@ -154,10 +171,12 @@ def gen_ecdr(items, cdir_pos, cdir_end):
# 2b comment length
ret += b"\x00\x00"
- return [ret, need_64]
+ return ret, need_64
-def gen_ecdr64(items, cdir_pos, cdir_end):
+def gen_ecdr64(
+ items: list[tuple[str, int, int, int, int]], cdir_pos: int, cdir_end: int
+) -> bytes:
"""
z64 end of central directory
added when numfiles or a headerptr clamps
@@ -181,7 +200,7 @@ def gen_ecdr64(items, cdir_pos, cdir_end):
return ret
-def gen_ecdr64_loc(ecdr64_pos):
+def gen_ecdr64_loc(ecdr64_pos: int) -> bytes:
"""
z64 end of central directory locator
points to ecdr64
@@ -196,21 +215,27 @@ def gen_ecdr64_loc(ecdr64_pos):
return ret
-class StreamZip(object):
- def __init__(self, log, fgen, utf8=False, pre_crc=False):
- self.log = log
- self.fgen = fgen
+class StreamZip(StreamArc):
+ def __init__(
+ self,
+ log: NamedLogger,
+ fgen: Generator[dict[str, Any], None, None],
+ utf8: bool = False,
+ pre_crc: bool = False,
+ ) -> None:
+ super(StreamZip, self).__init__(log, fgen)
+
self.utf8 = utf8
self.pre_crc = pre_crc
self.pos = 0
- self.items = []
+ self.items: list[tuple[str, int, int, int, int]] = []
- def _ct(self, buf):
+ def _ct(self, buf: bytes) -> bytes:
self.pos += len(buf)
return buf
- def ser(self, f):
+ def ser(self, f: dict[str, Any]) -> Generator[bytes, None, None]:
name = f["vp"]
src = f["ap"]
st = f["st"]
@@ -218,9 +243,8 @@ class StreamZip(object):
sz = st.st_size
ts = st.st_mtime
- crc = None
+ crc = 0
if self.pre_crc:
- crc = 0
for buf in yieldfile(src):
crc = zlib.crc32(buf, crc)
@@ -230,7 +254,6 @@ class StreamZip(object):
buf = gen_hdr(None, name, sz, ts, self.utf8, crc, self.pre_crc)
yield self._ct(buf)
- crc = crc or 0
for buf in yieldfile(src):
if not self.pre_crc:
crc = zlib.crc32(buf, crc)
@@ -239,7 +262,7 @@ class StreamZip(object):
crc &= 0xFFFFFFFF
- self.items.append([name, sz, ts, crc, h_pos])
+ self.items.append((name, sz, ts, crc, h_pos))
z64 = sz >= 4 * 1024 * 1024 * 1024
@@ -247,11 +270,11 @@ class StreamZip(object):
buf = gen_fdesc(sz, crc, z64)
yield self._ct(buf)
- def gen(self):
+ def gen(self) -> Generator[bytes, None, None]:
errors = []
for f in self.fgen:
if "err" in f:
- errors.append([f["vp"], f["err"]])
+ errors.append((f["vp"], f["err"]))
continue
try:
@@ -259,7 +282,7 @@ class StreamZip(object):
yield x
except:
ex = min_ex(5, True).replace("\n", "\n-- ")
- errors.append([f["vp"], ex])
+ errors.append((f["vp"], ex))
if errors:
errf, txt = errdesc(errors)
diff --git a/copyparty/tcpsrv.py b/copyparty/tcpsrv.py
index 7d2cb3e4..ae4c5f86 100644
--- a/copyparty/tcpsrv.py
+++ b/copyparty/tcpsrv.py
@@ -2,12 +2,15 @@
from __future__ import print_function, unicode_literals
import re
-import sys
import socket
+import sys
-from .__init__ import MACOS, ANYWIN, unicode
+from .__init__ import ANYWIN, MACOS, TYPE_CHECKING, unicode
from .util import chkcmd
+if TYPE_CHECKING:
+ from .svchub import SvcHub
+
class TcpSrv(object):
"""
@@ -15,16 +18,16 @@ class TcpSrv(object):
which then uses the least busy HttpSrv to handle it
"""
- def __init__(self, hub):
+ def __init__(self, hub: "SvcHub"):
self.hub = hub
self.args = hub.args
self.log = hub.log
self.stopping = False
- self.srv = []
+ self.srv: list[socket.socket] = []
self.nsrv = 0
- ok = {}
+ ok: dict[str, list[int]] = {}
for ip in self.args.i:
ok[ip] = []
for port in self.args.p:
@@ -34,8 +37,8 @@ class TcpSrv(object):
ok[ip].append(port)
except Exception as ex:
if self.args.ign_ebind or self.args.ign_ebind_all:
- m = "could not listen on {}:{}: {}"
- self.log("tcpsrv", m.format(ip, port, ex), c=3)
+ t = "could not listen on {}:{}: {}"
+ self.log("tcpsrv", t.format(ip, port, ex), c=3)
else:
raise
@@ -55,9 +58,9 @@ class TcpSrv(object):
eps[x] = "external"
msgs = []
- title_tab = {}
+ title_tab: dict[str, dict[str, int]] = {}
title_vars = [x[1:] for x in self.args.wintitle.split(" ") if x.startswith("$")]
- m = "available @ {}://{}:{}/ (\033[33m{}\033[0m)"
+ t = "available @ {}://{}:{}/ (\033[33m{}\033[0m)"
for ip, desc in sorted(eps.items(), key=lambda x: x[1]):
for port in sorted(self.args.p):
if port not in ok.get(ip, ok.get("0.0.0.0", [])):
@@ -69,7 +72,7 @@ class TcpSrv(object):
elif self.args.https_only or port == 443:
proto = "https"
- msgs.append(m.format(proto, ip, port, desc))
+ msgs.append(t.format(proto, ip, port, desc))
if not self.args.wintitle:
continue
@@ -98,13 +101,13 @@ class TcpSrv(object):
if msgs:
msgs[-1] += "\n"
- for m in msgs:
- self.log("tcpsrv", m)
+ for t in msgs:
+ self.log("tcpsrv", t)
if self.args.wintitle:
self._set_wintitle(title_tab)
- def _listen(self, ip, port):
+ def _listen(self, ip: str, port: int) -> None:
srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
srv.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
@@ -120,7 +123,7 @@ class TcpSrv(object):
raise
raise Exception(e)
- def run(self):
+ def run(self) -> None:
for srv in self.srv:
srv.listen(self.args.nc)
ip, port = srv.getsockname()
@@ -130,9 +133,9 @@ class TcpSrv(object):
if self.args.q:
print(msg)
- self.hub.broker.put(False, "listen", srv)
+ self.hub.broker.say("listen", srv)
- def shutdown(self):
+ def shutdown(self) -> None:
self.stopping = True
try:
for srv in self.srv:
@@ -142,14 +145,14 @@ class TcpSrv(object):
self.log("tcpsrv", "ok bye")
- def ips_linux_ifconfig(self):
+ def ips_linux_ifconfig(self) -> dict[str, str]:
# for termux
try:
txt, _ = chkcmd(["ifconfig"])
except:
return {}
- eps = {}
+ eps: dict[str, str] = {}
dev = None
ip = None
up = None
@@ -171,7 +174,7 @@ class TcpSrv(object):
return eps
- def ips_linux(self):
+ def ips_linux(self) -> dict[str, str]:
try:
txt, _ = chkcmd(["ip", "addr"])
except:
@@ -180,21 +183,21 @@ class TcpSrv(object):
r = re.compile(r"^\s+inet ([^ ]+)/.* (.*)")
ri = re.compile(r"^\s*[0-9]+\s*:.*")
up = False
- eps = {}
+ eps: dict[str, str] = {}
for ln in txt.split("\n"):
if ri.match(ln):
up = "UP" in re.split("[>,< ]", ln)
try:
- ip, dev = r.match(ln.rstrip()).groups()
+ ip, dev = r.match(ln.rstrip()).groups() # type: ignore
eps[ip] = dev + ("" if up else ", \033[31mLINK-DOWN")
except:
pass
return eps
- def ips_macos(self):
- eps = {}
+ def ips_macos(self) -> dict[str, str]:
+ eps: dict[str, str] = {}
try:
txt, _ = chkcmd(["ifconfig"])
except:
@@ -202,7 +205,7 @@ class TcpSrv(object):
rdev = re.compile(r"^([^ ]+):")
rip = re.compile(r"^\tinet ([0-9\.]+) ")
- dev = None
+ dev = "UNKNOWN"
for ln in txt.split("\n"):
m = rdev.match(ln)
if m:
@@ -211,17 +214,17 @@ class TcpSrv(object):
m = rip.match(ln)
if m:
eps[m.group(1)] = dev
- dev = None
+ dev = "UNKNOWN"
return eps
- def ips_windows_ipconfig(self):
- eps = {}
- offs = {}
+ def ips_windows_ipconfig(self) -> tuple[dict[str, str], set[str]]:
+ eps: dict[str, str] = {}
+ offs: set[str] = set()
try:
txt, _ = chkcmd(["ipconfig"])
except:
- return eps
+ return eps, offs
rdev = re.compile(r"(^[^ ].*):$")
rip = re.compile(r"^ +IPv?4? [^:]+: *([0-9\.]{7,15})$")
@@ -231,12 +234,12 @@ class TcpSrv(object):
m = rdev.match(ln)
if m:
if dev and dev not in eps.values():
- offs[dev] = 1
+ offs.add(dev)
dev = m.group(1).split(" adapter ", 1)[-1]
if dev and roff.match(ln):
- offs[dev] = 1
+ offs.add(dev)
dev = None
m = rip.match(ln)
@@ -245,12 +248,12 @@ class TcpSrv(object):
dev = None
if dev and dev not in eps.values():
- offs[dev] = 1
+ offs.add(dev)
return eps, offs
- def ips_windows_netsh(self):
- eps = {}
+ def ips_windows_netsh(self) -> dict[str, str]:
+ eps: dict[str, str] = {}
try:
txt, _ = chkcmd("netsh interface ip show address".split())
except:
@@ -270,7 +273,7 @@ class TcpSrv(object):
return eps
- def detect_interfaces(self, listen_ips):
+ def detect_interfaces(self, listen_ips: list[str]) -> dict[str, str]:
if MACOS:
eps = self.ips_macos()
elif ANYWIN:
@@ -317,7 +320,7 @@ class TcpSrv(object):
return eps
- def _set_wintitle(self, vs):
+ def _set_wintitle(self, vs: dict[str, dict[str, int]]) -> None:
vs["all"] = vs.get("all", {"Local-Only": 1})
vs["pub"] = vs.get("pub", vs["all"])
diff --git a/copyparty/th_cli.py b/copyparty/th_cli.py
index e8a18e0e..9eb49a4f 100644
--- a/copyparty/th_cli.py
+++ b/copyparty/th_cli.py
@@ -3,13 +3,23 @@ from __future__ import print_function, unicode_literals
import os
-from .util import Cooldown
-from .th_srv import thumb_path, HAVE_WEBP
+from .__init__ import TYPE_CHECKING
+from .authsrv import VFS
from .bos import bos
+from .th_srv import HAVE_WEBP, thumb_path
+from .util import Cooldown
+
+try:
+ from typing import Optional, Union
+except:
+ pass
+
+if TYPE_CHECKING:
+ from .httpsrv import HttpSrv
class ThumbCli(object):
- def __init__(self, hsrv):
+ def __init__(self, hsrv: "HttpSrv") -> None:
self.broker = hsrv.broker
self.log_func = hsrv.log
self.args = hsrv.args
@@ -34,10 +44,10 @@ class ThumbCli(object):
d = next((x for x in self.args.th_dec if x in ("vips", "pil")), None)
self.can_webp = HAVE_WEBP or d == "vips"
- def log(self, msg, c=0):
+ def log(self, msg: str, c: Union[int, str] = 0) -> None:
self.log_func("thumbcli", msg, c)
- def get(self, dbv, rem, mtime, fmt):
+ def get(self, dbv: VFS, rem: str, mtime: float, fmt: str) -> Optional[str]:
ptop = dbv.realpath
ext = rem.rsplit(".")[-1].lower()
if ext not in self.thumbable or "dthumb" in dbv.flags:
@@ -106,17 +116,17 @@ class ThumbCli(object):
if ret:
tdir = os.path.dirname(tpath)
if self.cooldown.poke(tdir):
- self.broker.put(False, "thumbsrv.poke", tdir)
+ self.broker.say("thumbsrv.poke", tdir)
if want_opus:
# audio files expire individually
if self.cooldown.poke(tpath):
- self.broker.put(False, "thumbsrv.poke", tpath)
+ self.broker.say("thumbsrv.poke", tpath)
return ret
if abort:
return None
- x = self.broker.put(True, "thumbsrv.get", ptop, rem, mtime, fmt)
- return x.get()
+ x = self.broker.ask("thumbsrv.get", ptop, rem, mtime, fmt)
+ return x.get() # type: ignore
diff --git a/copyparty/th_srv.py b/copyparty/th_srv.py
index dc3965c7..2a9e5f02 100644
--- a/copyparty/th_srv.py
+++ b/copyparty/th_srv.py
@@ -1,18 +1,28 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import os
-import time
-import shutil
import base64
import hashlib
-import threading
+import os
+import shutil
import subprocess as sp
+import threading
+import time
-from .util import fsenc, vsplit, statdir, runcmd, Queue, Cooldown, BytesIO, min_ex
+from queue import Queue
+
+from .__init__ import TYPE_CHECKING
from .bos import bos
from .mtag import HAVE_FFMPEG, HAVE_FFPROBE, ffprobe
+from .util import BytesIO, Cooldown, fsenc, min_ex, runcmd, statdir, vsplit
+try:
+ from typing import Optional, Union
+except:
+ pass
+
+if TYPE_CHECKING:
+ from .svchub import SvcHub
HAVE_PIL = False
HAVE_HEIF = False
@@ -20,7 +30,7 @@ HAVE_AVIF = False
HAVE_WEBP = False
try:
- from PIL import Image, ImageOps, ExifTags
+ from PIL import ExifTags, Image, ImageOps
HAVE_PIL = True
try:
@@ -47,14 +57,13 @@ except:
pass
try:
- import pyvips
-
HAVE_VIPS = True
+ import pyvips
except:
HAVE_VIPS = False
-def thumb_path(histpath, rem, mtime, fmt):
+def thumb_path(histpath: str, rem: str, mtime: float, fmt: str) -> str:
# base16 = 16 = 256
# b64-lc = 38 = 1444
# base64 = 64 = 4096
@@ -80,7 +89,7 @@ def thumb_path(histpath, rem, mtime, fmt):
class ThumbSrv(object):
- def __init__(self, hub):
+ def __init__(self, hub: "SvcHub") -> None:
self.hub = hub
self.asrv = hub.asrv
self.args = hub.args
@@ -91,17 +100,17 @@ class ThumbSrv(object):
self.poke_cd = Cooldown(self.args.th_poke)
self.mutex = threading.Lock()
- self.busy = {}
+ self.busy: dict[str, list[threading.Condition]] = {}
self.stopping = False
self.nthr = max(1, self.args.th_mt)
- self.q = Queue(self.nthr * 4)
+ self.q: Queue[Optional[tuple[str, str]]] = Queue(self.nthr * 4)
for n in range(self.nthr):
- t = threading.Thread(
+ thr = threading.Thread(
target=self.worker, name="thumb-{}-{}".format(n, self.nthr)
)
- t.daemon = True
- t.start()
+ thr.daemon = True
+ thr.start()
want_ff = not self.args.no_vthumb or not self.args.no_athumb
if want_ff and (not HAVE_FFMPEG or not HAVE_FFPROBE):
@@ -122,7 +131,7 @@ class ThumbSrv(object):
t.start()
self.fmt_pil, self.fmt_vips, self.fmt_ffi, self.fmt_ffv, self.fmt_ffa = [
- {x: True for x in y.split(",")}
+ set(y.split(","))
for y in [
self.args.th_r_pil,
self.args.th_r_vips,
@@ -134,37 +143,37 @@ class ThumbSrv(object):
if not HAVE_HEIF:
for f in "heif heifs heic heics".split(" "):
- self.fmt_pil.pop(f, None)
+ self.fmt_pil.discard(f)
if not HAVE_AVIF:
for f in "avif avifs".split(" "):
- self.fmt_pil.pop(f, None)
+ self.fmt_pil.discard(f)
- self.thumbable = {}
+ self.thumbable: set[str] = set()
if "pil" in self.args.th_dec:
- self.thumbable.update(self.fmt_pil)
+ self.thumbable |= self.fmt_pil
if "vips" in self.args.th_dec:
- self.thumbable.update(self.fmt_vips)
+ self.thumbable |= self.fmt_vips
if "ff" in self.args.th_dec:
- for t in [self.fmt_ffi, self.fmt_ffv, self.fmt_ffa]:
- self.thumbable.update(t)
+ for zss in [self.fmt_ffi, self.fmt_ffv, self.fmt_ffa]:
+ self.thumbable |= zss
- def log(self, msg, c=0):
+ def log(self, msg: str, c: Union[int, str] = 0) -> None:
self.log_func("thumb", msg, c)
- def shutdown(self):
+ def shutdown(self) -> None:
self.stopping = True
for _ in range(self.nthr):
self.q.put(None)
- def stopped(self):
+ def stopped(self) -> bool:
with self.mutex:
return not self.nthr
- def get(self, ptop, rem, mtime, fmt):
+ def get(self, ptop: str, rem: str, mtime: float, fmt: str) -> Optional[str]:
histpath = self.asrv.vfs.histtab.get(ptop)
if not histpath:
self.log("no histpath for [{}]".format(ptop))
@@ -191,7 +200,7 @@ class ThumbSrv(object):
do_conv = True
if do_conv:
- self.q.put([abspath, tpath])
+ self.q.put((abspath, tpath))
self.log("conv {} \033[0m{}".format(tpath, abspath), c=6)
while not self.stopping:
@@ -212,7 +221,7 @@ class ThumbSrv(object):
return None
- def getcfg(self):
+ def getcfg(self) -> dict[str, set[str]]:
return {
"thumbable": self.thumbable,
"pil": self.fmt_pil,
@@ -222,7 +231,7 @@ class ThumbSrv(object):
"ffa": self.fmt_ffa,
}
- def worker(self):
+ def worker(self) -> None:
while not self.stopping:
task = self.q.get()
if not task:
@@ -253,7 +262,7 @@ class ThumbSrv(object):
except:
msg = "{} could not create thumbnail of {}\n{}"
msg = msg.format(fun.__name__, abspath, min_ex())
- c = 1 if " "Image.Image":
# exif_transpose is expensive (loads full image + unconditional copy)
r = max(*self.res) * 2
im.thumbnail((r, r), resample=Image.LANCZOS)
@@ -295,7 +304,7 @@ class ThumbSrv(object):
return im
- def conv_pil(self, abspath, tpath):
+ def conv_pil(self, abspath: str, tpath: str) -> None:
with Image.open(fsenc(abspath)) as im:
try:
im = self.fancy_pillow(im)
@@ -324,7 +333,7 @@ class ThumbSrv(object):
im.save(tpath, **args)
- def conv_vips(self, abspath, tpath):
+ def conv_vips(self, abspath: str, tpath: str) -> None:
crops = ["centre", "none"]
if self.args.th_no_crop:
crops = ["none"]
@@ -342,18 +351,17 @@ class ThumbSrv(object):
img.write_to_file(tpath, Q=40)
- def conv_ffmpeg(self, abspath, tpath):
+ def conv_ffmpeg(self, abspath: str, tpath: str) -> None:
ret, _ = ffprobe(abspath)
if not ret:
return
ext = abspath.rsplit(".")[-1].lower()
if ext in ["h264", "h265"] or ext in self.fmt_ffi:
- seek = []
+ seek: list[bytes] = []
else:
dur = ret[".dur"][1] if ".dur" in ret else 4
- seek = "{:.0f}".format(dur / 3)
- seek = [b"-ss", seek.encode("utf-8")]
+ seek = [b"-ss", "{:.0f}".format(dur / 3).encode("utf-8")]
scale = "scale={0}:{1}:force_original_aspect_ratio="
if self.args.th_no_crop:
@@ -361,7 +369,7 @@ class ThumbSrv(object):
else:
scale += "increase,crop={0}:{1},setsar=1:1"
- scale = scale.format(*list(self.res)).encode("utf-8")
+ bscale = scale.format(*list(self.res)).encode("utf-8")
# fmt: off
cmd = [
b"ffmpeg",
@@ -373,7 +381,7 @@ class ThumbSrv(object):
cmd += [
b"-i", fsenc(abspath),
b"-map", b"0:v:0",
- b"-vf", scale,
+ b"-vf", bscale,
b"-frames:v", b"1",
b"-metadata:s:v:0", b"rotate=0",
]
@@ -395,14 +403,14 @@ class ThumbSrv(object):
cmd += [fsenc(tpath)]
self._run_ff(cmd)
- def _run_ff(self, cmd):
+ def _run_ff(self, cmd: list[bytes]) -> None:
# self.log((b" ".join(cmd)).decode("utf-8"))
ret, _, serr = runcmd(cmd, timeout=self.args.th_convt)
if not ret:
return
- c = "1;30"
- m = "FFmpeg failed (probably a corrupt video file):\n"
+ c: Union[str, int] = "1;30"
+ t = "FFmpeg failed (probably a corrupt video file):\n"
if cmd[-1].lower().endswith(b".webp") and (
"Error selecting an encoder" in serr
or "Automatic encoder selection failed" in serr
@@ -410,14 +418,14 @@ class ThumbSrv(object):
or "Please choose an encoder manually" in serr
):
self.args.th_ff_jpg = True
- m = "FFmpeg failed because it was compiled without libwebp; enabling --th-ff-jpg to force jpeg output:\n"
+ t = "FFmpeg failed because it was compiled without libwebp; enabling --th-ff-jpg to force jpeg output:\n"
c = 1
if (
"Requested resampling engine is unavailable" in serr
or "output pad on Parsed_aresample_" in serr
):
- m = "FFmpeg failed because it was compiled without libsox; you must set --th-ff-swr to force swr resampling:\n"
+ t = "FFmpeg failed because it was compiled without libsox; you must set --th-ff-swr to force swr resampling:\n"
c = 1
lines = serr.strip("\n").split("\n")
@@ -428,10 +436,10 @@ class ThumbSrv(object):
if len(txt) > 5000:
txt = txt[:2500] + "...\nff: [...]\nff: ..." + txt[-2500:]
- self.log(m + txt, c=c)
+ self.log(t + txt, c=c)
raise sp.CalledProcessError(ret, (cmd[0], b"...", cmd[-1]))
- def conv_spec(self, abspath, tpath):
+ def conv_spec(self, abspath: str, tpath: str) -> None:
ret, _ = ffprobe(abspath)
if "ac" not in ret:
raise Exception("not audio")
@@ -473,7 +481,7 @@ class ThumbSrv(object):
cmd += [fsenc(tpath)]
self._run_ff(cmd)
- def conv_opus(self, abspath, tpath):
+ def conv_opus(self, abspath: str, tpath: str) -> None:
if self.args.no_acode:
raise Exception("disabled in server config")
@@ -521,7 +529,7 @@ class ThumbSrv(object):
# fmt: on
self._run_ff(cmd)
- def poke(self, tdir):
+ def poke(self, tdir: str) -> None:
if not self.poke_cd.poke(tdir):
return
@@ -533,7 +541,7 @@ class ThumbSrv(object):
except:
pass
- def cleaner(self):
+ def cleaner(self) -> None:
interval = self.args.th_clean
while True:
time.sleep(interval)
@@ -548,14 +556,14 @@ class ThumbSrv(object):
self.log("\033[Jcln ok; rm {} dirs".format(ndirs))
- def clean(self, histpath):
+ def clean(self, histpath: str) -> int:
ret = 0
for cat in ["th", "ac"]:
- ret += self._clean(histpath, cat, None)
+ ret += self._clean(histpath, cat, "")
return ret
- def _clean(self, histpath, cat, thumbpath):
+ def _clean(self, histpath: str, cat: str, thumbpath: str) -> int:
if not thumbpath:
thumbpath = os.path.join(histpath, cat)
@@ -564,10 +572,10 @@ class ThumbSrv(object):
maxage = getattr(self.args, cat + "_maxage")
now = time.time()
prev_b64 = None
- prev_fp = None
+ prev_fp = ""
try:
- ents = statdir(self.log, not self.args.no_scandir, False, thumbpath)
- ents = sorted(list(ents))
+ t1 = statdir(self.log_func, not self.args.no_scandir, False, thumbpath)
+ ents = sorted(list(t1))
except:
return 0
diff --git a/copyparty/u2idx.py b/copyparty/u2idx.py
index 073ed3d1..8a342042 100644
--- a/copyparty/u2idx.py
+++ b/copyparty/u2idx.py
@@ -1,34 +1,37 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import re
-import os
-import time
import calendar
+import os
+import re
import threading
+import time
from operator import itemgetter
-from .__init__ import ANYWIN, unicode
-from .util import absreal, s3dec, Pebkac, min_ex, gen_filekey, quotep
+from .__init__ import ANYWIN, TYPE_CHECKING, unicode
from .bos import bos
from .up2k import up2k_wark_from_hashlist
+from .util import HAVE_SQLITE3, Pebkac, absreal, gen_filekey, min_ex, quotep, s3dec
-
-try:
- HAVE_SQLITE3 = True
+if HAVE_SQLITE3:
import sqlite3
-except:
- HAVE_SQLITE3 = False
-
try:
from pathlib import Path
except:
pass
+try:
+ from typing import Any, Optional, Union
+except:
+ pass
+
+if TYPE_CHECKING:
+ from .httpconn import HttpConn
+
class U2idx(object):
- def __init__(self, conn):
+ def __init__(self, conn: "HttpConn") -> None:
self.log_func = conn.log_func
self.asrv = conn.asrv
self.args = conn.args
@@ -38,19 +41,21 @@ class U2idx(object):
self.log("your python does not have sqlite3; searching will be disabled")
return
- self.active_id = None
- self.active_cur = None
- self.cur = {}
- self.mem_cur = sqlite3.connect(":memory:")
+ self.active_id = ""
+ self.active_cur: Optional["sqlite3.Cursor"] = None
+ self.cur: dict[str, "sqlite3.Cursor"] = {}
+ self.mem_cur = sqlite3.connect(":memory:").cursor()
self.mem_cur.execute(r"create table a (b text)")
- self.p_end = None
- self.p_dur = 0
+ self.p_end = 0.0
+ self.p_dur = 0.0
- def log(self, msg, c=0):
+ def log(self, msg: str, c: Union[int, str] = 0) -> None:
self.log_func("u2idx", msg, c)
- def fsearch(self, vols, body):
+ def fsearch(
+ self, vols: list[tuple[str, str, dict[str, Any]]], body: dict[str, Any]
+ ) -> list[dict[str, Any]]:
"""search by up2k hashlist"""
if not HAVE_SQLITE3:
return []
@@ -60,14 +65,14 @@ class U2idx(object):
wark = up2k_wark_from_hashlist(self.args.salt, fsize, fhash)
uq = "substr(w,1,16) = ? and w = ?"
- uv = [wark[:16], wark]
+ uv: list[Union[str, int]] = [wark[:16], wark]
try:
return self.run_query(vols, uq, uv, True, False, 99999)[0]
except:
raise Pebkac(500, min_ex())
- def get_cur(self, ptop):
+ def get_cur(self, ptop: str) -> Optional["sqlite3.Cursor"]:
if not HAVE_SQLITE3:
return None
@@ -103,13 +108,16 @@ class U2idx(object):
self.cur[ptop] = cur
return cur
- def search(self, vols, uq, lim):
+ def search(
+ self, vols: list[tuple[str, str, dict[str, Any]]], uq: str, lim: int
+ ) -> tuple[list[dict[str, Any]], list[str]]:
"""search by query params"""
if not HAVE_SQLITE3:
- return []
+ return [], []
q = ""
- va = []
+ v: Union[str, int] = ""
+ va: list[Union[str, int]] = []
have_up = False # query has up.* operands
have_mt = False
is_key = True
@@ -202,7 +210,7 @@ class U2idx(object):
"%Y",
]:
try:
- v = calendar.timegm(time.strptime(v, fmt))
+ v = calendar.timegm(time.strptime(str(v), fmt))
break
except:
pass
@@ -230,11 +238,12 @@ class U2idx(object):
# lowercase tag searches
m = ptn_lc.search(q)
- if not m or not ptn_lcv.search(unicode(v)):
+ zs = unicode(v)
+ if not m or not ptn_lcv.search(zs):
continue
va.pop()
- va.append(v.lower())
+ va.append(zs.lower())
q = q[: m.start()]
field, oper = m.groups()
@@ -248,8 +257,16 @@ class U2idx(object):
except Exception as ex:
raise Pebkac(500, repr(ex))
- def run_query(self, vols, uq, uv, have_up, have_mt, lim):
- done_flag = []
+ def run_query(
+ self,
+ vols: list[tuple[str, str, dict[str, Any]]],
+ uq: str,
+ uv: list[Union[str, int]],
+ have_up: bool,
+ have_mt: bool,
+ lim: int,
+ ) -> tuple[list[dict[str, Any]], list[str]]:
+ done_flag: list[bool] = []
self.active_id = "{:.6f}_{}".format(
time.time(), threading.current_thread().ident
)
@@ -266,13 +283,11 @@ class U2idx(object):
if not uq or not uv:
uq = "select * from up"
- uv = ()
+ uv = []
elif have_mt:
uq = "select up.*, substr(up.w,1,16) mtw from up where " + uq
- uv = tuple(uv)
else:
uq = "select up.* from up where " + uq
- uv = tuple(uv)
self.log("qs: {!r} {!r}".format(uq, uv))
@@ -292,11 +307,10 @@ class U2idx(object):
v = vtop + "/"
vuv.append(v)
- vuv = tuple(vuv)
sret = []
fk = flags.get("fk")
- c = cur.execute(uq, vuv)
+ c = cur.execute(uq, tuple(vuv))
for hit in c:
w, ts, sz, rd, fn, ip, at = hit[:7]
lim -= 1
@@ -340,7 +354,7 @@ class U2idx(object):
# print("[{}] {}".format(ptop, sret))
done_flag.append(True)
- self.active_id = None
+ self.active_id = ""
# undupe hits from multiple metadata keys
if len(ret) > 1:
@@ -354,11 +368,12 @@ class U2idx(object):
return ret, list(taglist.keys())
- def terminator(self, identifier, done_flag):
+ def terminator(self, identifier: str, done_flag: list[bool]) -> None:
for _ in range(self.timeout):
time.sleep(1)
if done_flag:
return
if identifier == self.active_id:
+ assert self.active_cur
self.active_cur.connection.interrupt()
diff --git a/copyparty/up2k.py b/copyparty/up2k.py
index 3f4e02e4..aff68917 100644
--- a/copyparty/up2k.py
+++ b/copyparty/up2k.py
@@ -1,60 +1,83 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import re
-import os
-import time
-import math
-import json
-import gzip
-import stat
-import shutil
import base64
+import gzip
import hashlib
-import threading
-import traceback
+import json
+import math
+import os
+import re
+import shutil
+import stat
import subprocess as sp
+import threading
+import time
+import traceback
from copy import deepcopy
-from .__init__ import WINDOWS, ANYWIN, PY2
-from .util import (
- Pebkac,
- Queue,
- ProgressPrinter,
- SYMTIME,
- fsenc,
- absreal,
- sanitize_fn,
- ren_open,
- atomic_move,
- quotep,
- vsplit,
- w8b64enc,
- w8b64dec,
- s3enc,
- s3dec,
- rmdirs,
- statdir,
- s2hms,
- min_ex,
-)
-from .bos import bos
-from .authsrv import AuthSrv, LEELOO_DALLAS
-from .mtag import MTag, MParser
+from queue import Queue
-try:
- HAVE_SQLITE3 = True
+from .__init__ import ANYWIN, PY2, TYPE_CHECKING, WINDOWS
+from .authsrv import LEELOO_DALLAS, VFS, AuthSrv
+from .bos import bos
+from .mtag import MParser, MTag
+from .util import (
+ HAVE_SQLITE3,
+ SYMTIME,
+ Pebkac,
+ ProgressPrinter,
+ absreal,
+ atomic_move,
+ fsenc,
+ min_ex,
+ quotep,
+ ren_open,
+ rmdirs,
+ s2hms,
+ s3dec,
+ s3enc,
+ sanitize_fn,
+ statdir,
+ vsplit,
+ w8b64dec,
+ w8b64enc,
+)
+
+if HAVE_SQLITE3:
import sqlite3
-except:
- HAVE_SQLITE3 = False
DB_VER = 5
+try:
+ from typing import Any, Optional, Pattern, Union
+except:
+ pass
+
+if TYPE_CHECKING:
+ from .svchub import SvcHub
+
+
+class Dbw(object):
+ def __init__(self, c: "sqlite3.Cursor", n: int, t: float) -> None:
+ self.c = c
+ self.n = n
+ self.t = t
+
+
+class Mpqe(object):
+ def __init__(self, mtp: dict[str, MParser], entags: set[str], w: str, abspath: str):
+ # mtp empty = mtag
+ self.mtp = mtp
+ self.entags = entags
+ self.w = w
+ self.abspath = abspath
+
class Up2k(object):
- def __init__(self, hub):
+ def __init__(self, hub: "SvcHub") -> None:
self.hub = hub
- self.asrv = hub.asrv # type: AuthSrv
+ self.asrv: AuthSrv = hub.asrv
self.args = hub.args
self.log_func = hub.log
@@ -62,25 +85,32 @@ class Up2k(object):
self.salt = self.args.salt
# state
+ self.gid = 0
self.mutex = threading.Lock()
+ self.pp: Optional[ProgressPrinter] = None
self.rescan_cond = threading.Condition()
- self.hashq = Queue()
- self.tagq = Queue()
+ self.need_rescan: set[str] = set()
+
+ self.registry: dict[str, dict[str, dict[str, Any]]] = {}
+ self.flags: dict[str, dict[str, Any]] = {}
+ self.droppable: dict[str, list[str]] = {}
+ self.volstate: dict[str, str] = {}
+ self.dupesched: dict[str, list[tuple[str, str, float]]] = {}
+ self.snap_persist_interval = 300 # persist unfinished index every 5 min
+ self.snap_discard_interval = 21600 # drop unfinished after 6 hours inactivity
+ self.snap_prev: dict[str, Optional[tuple[int, float]]] = {}
+
+ self.mtag: Optional[MTag] = None
+ self.entags: dict[str, set[str]] = {}
+ self.mtp_parsers: dict[str, dict[str, MParser]] = {}
+ self.pending_tags: list[tuple[set[str], str, str, dict[str, Any]]] = []
+ self.hashq: Queue[tuple[str, str, str, str, float]] = Queue()
+ self.tagq: Queue[tuple[str, str, str, str]] = Queue()
self.n_hashq = 0
self.n_tagq = 0
- self.gid = 0
- self.volstate = {}
- self.need_rescan = {}
- self.dupesched = {}
- self.registry = {}
- self.droppable = {}
- self.entags = {}
- self.flags = {}
- self.cur = {}
- self.mtag = None
- self.pending_tags = None
- self.mtp_parsers = {}
+ self.mpool_used = False
+ self.cur: dict[str, "sqlite3.Cursor"] = {}
self.mem_cur = None
self.sqlite_ver = None
self.no_expr_idx = False
@@ -94,7 +124,7 @@ class Up2k(object):
if ANYWIN:
# usually fails to set lastmod too quickly
- self.lastmod_q = []
+ self.lastmod_q: list[tuple[str, int, tuple[int, int]]] = []
thr = threading.Thread(target=self._lastmodder, name="up2k-lastmod")
thr.daemon = True
thr.start()
@@ -108,7 +138,7 @@ class Up2k(object):
if self.args.no_fastboot:
self.deferred_init()
- def init_vols(self):
+ def init_vols(self) -> None:
if self.args.no_fastboot:
return
@@ -116,15 +146,15 @@ class Up2k(object):
t.daemon = True
t.start()
- def reload(self):
+ def reload(self) -> None:
self.gid += 1
self.log("reload #{} initiated".format(self.gid))
all_vols = self.asrv.vfs.all_vols
self.rescan(all_vols, list(all_vols.keys()), True)
- def deferred_init(self):
+ def deferred_init(self) -> None:
all_vols = self.asrv.vfs.all_vols
- have_e2d = self.init_indexes(all_vols)
+ have_e2d = self.init_indexes(all_vols, [])
thr = threading.Thread(target=self._snapshot, name="up2k-snapshot")
thr.daemon = True
@@ -150,11 +180,11 @@ class Up2k(object):
thr.daemon = True
thr.start()
- def log(self, msg, c=0):
+ def log(self, msg: str, c: Union[int, str] = 0) -> None:
self.log_func("up2k", msg + "\033[K", c)
- def get_state(self):
- mtpq = 0
+ def get_state(self) -> str:
+ mtpq: Union[int, str] = 0
q = "select count(w) from mt where k = 't:mtp'"
got_lock = False if PY2 else self.mutex.acquire(timeout=0.5)
if got_lock:
@@ -165,19 +195,19 @@ class Up2k(object):
pass
self.mutex.release()
else:
- mtpq = "?"
+ mtpq = "(?)"
ret = {
"volstate": self.volstate,
- "scanning": hasattr(self, "pp"),
+ "scanning": bool(self.pp),
"hashq": self.n_hashq,
"tagq": self.n_tagq,
"mtpq": mtpq,
}
return json.dumps(ret, indent=4)
- def rescan(self, all_vols, scan_vols, wait):
- if not wait and hasattr(self, "pp"):
+ def rescan(self, all_vols: dict[str, VFS], scan_vols: list[str], wait: bool) -> str:
+ if not wait and self.pp:
return "cannot initiate; scan is already in progress"
args = (all_vols, scan_vols)
@@ -188,11 +218,11 @@ class Up2k(object):
)
t.daemon = True
t.start()
- return None
+ return ""
- def _sched_rescan(self):
+ def _sched_rescan(self) -> None:
volage = {}
- cooldown = 0
+ cooldown = 0.0
timeout = time.time() + 3
while True:
timeout = max(timeout, cooldown)
@@ -204,7 +234,7 @@ class Up2k(object):
if now < cooldown:
continue
- if hasattr(self, "pp"):
+ if self.pp:
cooldown = now + 5
continue
@@ -220,19 +250,19 @@ class Up2k(object):
deadline = volage[vp] + maxage
if deadline <= now:
- self.need_rescan[vp] = 1
+ self.need_rescan.add(vp)
timeout = min(timeout, deadline)
- vols = list(sorted(self.need_rescan.keys()))
- self.need_rescan = {}
+ vols = list(sorted(self.need_rescan))
+ self.need_rescan.clear()
if vols:
cooldown = now + 10
err = self.rescan(self.asrv.vfs.all_vols, vols, False)
if err:
for v in vols:
- self.need_rescan[v] = True
+ self.need_rescan.add(v)
continue
@@ -272,7 +302,7 @@ class Up2k(object):
if vp:
fvp = "{}/{}".format(vp, fvp)
- self._handle_rm(LEELOO_DALLAS, None, fvp)
+ self._handle_rm(LEELOO_DALLAS, "", fvp)
nrm += 1
if nrm:
@@ -288,12 +318,12 @@ class Up2k(object):
if hits:
timeout = min(timeout, now + lifetime - (now - hits[0]))
- def _vis_job_progress(self, job):
+ def _vis_job_progress(self, job: dict[str, Any]) -> str:
perc = 100 - (len(job["need"]) * 100.0 / len(job["hash"]))
path = os.path.join(job["ptop"], job["prel"], job["name"])
return "{:5.1f}% {}".format(perc, path)
- def _vis_reg_progress(self, reg):
+ def _vis_reg_progress(self, reg: dict[str, dict[str, Any]]) -> list[str]:
ret = []
for _, job in reg.items():
if job["need"]:
@@ -301,7 +331,7 @@ class Up2k(object):
return ret
- def _expr_idx_filter(self, flags):
+ def _expr_idx_filter(self, flags: dict[str, Any]) -> tuple[bool, dict[str, Any]]:
if not self.no_expr_idx:
return False, flags
@@ -311,19 +341,19 @@ class Up2k(object):
return True, ret
- def init_indexes(self, all_vols, scan_vols=None):
+ def init_indexes(self, all_vols: dict[str, VFS], scan_vols: list[str]) -> bool:
gid = self.gid
- while hasattr(self, "pp") and gid == self.gid:
+ while self.pp and gid == self.gid:
time.sleep(0.1)
if gid != self.gid:
- return
+ return False
if gid:
self.log("reload #{} running".format(self.gid))
self.pp = ProgressPrinter()
- vols = all_vols.values()
+ vols = list(all_vols.values())
t0 = time.time()
have_e2d = False
@@ -377,9 +407,9 @@ class Up2k(object):
# e2ds(a) volumes first
for vol in vols:
- en = {}
+ en: set[str] = set()
if "mte" in vol.flags:
- en = {k: True for k in vol.flags["mte"].split(",")}
+ en = set(vol.flags["mte"].split(","))
self.entags[vol.realpath] = en
@@ -393,11 +423,11 @@ class Up2k(object):
need_vac[vol] = True
if "e2ts" not in vol.flags:
- m = "online, idle"
+ t = "online, idle"
else:
- m = "online (tags pending)"
+ t = "online (tags pending)"
- self.volstate[vol.vpath] = m
+ self.volstate[vol.vpath] = t
# open the rest + do any e2ts(a)
needed_mutagen = False
@@ -405,9 +435,9 @@ class Up2k(object):
if "e2ts" not in vol.flags:
continue
- m = "online (reading tags)"
- self.volstate[vol.vpath] = m
- self.log("{} [{}]".format(m, vol.realpath))
+ t = "online (reading tags)"
+ self.volstate[vol.vpath] = t
+ self.log("{} [{}]".format(t, vol.realpath))
nadd, nrm, success = self._build_tags_index(vol)
if not success:
@@ -419,7 +449,9 @@ class Up2k(object):
self.volstate[vol.vpath] = "online (mtp soon)"
for vol in need_vac:
- cur, _ = self.register_vpath(vol.realpath, vol.flags)
+ reg = self.register_vpath(vol.realpath, vol.flags)
+ assert reg
+ cur, _ = reg
with self.mutex:
cur.connection.commit()
cur.execute("vacuum")
@@ -435,23 +467,25 @@ class Up2k(object):
thr = None
if self.mtag:
- m = "online (running mtp)"
+ t = "online (running mtp)"
if scan_vols:
thr = threading.Thread(target=self._run_all_mtp, name="up2k-mtp-scan")
thr.daemon = True
else:
- del self.pp
- m = "online, idle"
+ self.pp = None
+ t = "online, idle"
for vol in vols:
- self.volstate[vol.vpath] = m
+ self.volstate[vol.vpath] = t
if thr:
thr.start()
return have_e2d
- def register_vpath(self, ptop, flags):
+ def register_vpath(
+ self, ptop: str, flags: dict[str, Any]
+ ) -> Optional[tuple["sqlite3.Cursor", str]]:
histpath = self.asrv.vfs.histtab.get(ptop)
if not histpath:
self.log("no histpath for [{}]".format(ptop))
@@ -460,7 +494,7 @@ class Up2k(object):
db_path = os.path.join(histpath, "up2k.db")
if ptop in self.registry:
try:
- return [self.cur[ptop], db_path]
+ return self.cur[ptop], db_path
except:
return None
@@ -514,14 +548,14 @@ class Up2k(object):
else:
drp = [x for x in drp if x in reg]
- m = "loaded snap {} |{}| ({})".format(path, len(reg.keys()), len(drp or []))
- m = [m] + self._vis_reg_progress(reg)
- self.log("\n".join(m))
+ t = "loaded snap {} |{}| ({})".format(path, len(reg.keys()), len(drp or []))
+ ta = [t] + self._vis_reg_progress(reg)
+ self.log("\n".join(ta))
self.flags[ptop] = flags
self.registry[ptop] = reg
self.droppable[ptop] = drp or []
- self.regdrop(ptop, None)
+ self.regdrop(ptop, "")
if not HAVE_SQLITE3 or "e2d" not in flags or "d2d" in flags:
return None
@@ -530,23 +564,25 @@ class Up2k(object):
try:
cur = self._open_db(db_path)
self.cur[ptop] = cur
- return [cur, db_path]
+ return cur, db_path
except:
msg = "cannot use database at [{}]:\n{}"
self.log(msg.format(ptop, traceback.format_exc()))
return None
- def _build_file_index(self, vol, all_vols):
+ def _build_file_index(self, vol: VFS, all_vols: list[VFS]) -> tuple[bool, bool]:
do_vac = False
top = vol.realpath
rei = vol.flags.get("noidx")
reh = vol.flags.get("nohash")
with self.mutex:
- cur, _ = self.register_vpath(top, vol.flags)
+ reg = self.register_vpath(top, vol.flags)
+ assert reg and self.pp
+ cur, _ = reg
- dbw = [cur, 0, time.time()]
- self.pp.n = next(dbw[0].execute("select count(w) from up"))[0]
+ db = Dbw(cur, 0, time.time())
+ self.pp.n = next(db.c.execute("select count(w) from up"))[0]
excl = [
vol.realpath + "/" + d.vpath[len(vol.vpath) :].lstrip("/")
@@ -558,37 +594,47 @@ class Up2k(object):
if WINDOWS:
excl = [x.replace("/", "\\") for x in excl]
- excl = set(excl)
rtop = absreal(top)
n_add = n_rm = 0
try:
- n_add = self._build_dir(dbw, top, excl, top, rtop, rei, reh, [])
- n_rm = self._drop_lost(dbw[0], top)
+ n_add = self._build_dir(db, top, set(excl), top, rtop, rei, reh, [])
+ n_rm = self._drop_lost(db.c, top)
except:
- m = "failed to index volume [{}]:\n{}"
- self.log(m.format(top, min_ex()), c=1)
+ t = "failed to index volume [{}]:\n{}"
+ self.log(t.format(top, min_ex()), c=1)
- if dbw[1]:
- self.log("commit {} new files".format(dbw[1]))
+ if db.n:
+ self.log("commit {} new files".format(db.n))
- dbw[0].connection.commit()
+ db.c.connection.commit()
- return True, n_add or n_rm or do_vac
+ return True, bool(n_add or n_rm or do_vac)
- def _build_dir(self, dbw, top, excl, cdir, rcdir, rei, reh, seen):
+ def _build_dir(
+ self,
+ db: Dbw,
+ top: str,
+ excl: set[str],
+ cdir: str,
+ rcdir: str,
+ rei: Optional[Pattern[str]],
+ reh: Optional[Pattern[str]],
+ seen: list[str],
+ ) -> int:
if rcdir in seen:
- m = "bailing from symlink loop,\n prev: {}\n curr: {}\n from: {}"
- self.log(m.format(seen[-1], rcdir, cdir), 3)
+ t = "bailing from symlink loop,\n prev: {}\n curr: {}\n from: {}"
+ self.log(t.format(seen[-1], rcdir, cdir), 3)
return 0
seen = seen + [rcdir]
+ assert self.pp and self.mem_cur
self.pp.msg = "a{} {}".format(self.pp.n, cdir)
ret = 0
seen_files = {} # != inames; files-only for dropcheck
g = statdir(self.log_func, not self.args.no_scandir, False, cdir)
- g = sorted(g)
- inames = {x[0]: 1 for x in g}
- for iname, inf in g:
+ gl = sorted(g)
+ inames = {x[0]: 1 for x in gl}
+ for iname, inf in gl:
abspath = os.path.join(cdir, iname)
if rei and rei.search(abspath):
continue
@@ -605,10 +651,10 @@ class Up2k(object):
continue
# self.log(" dir: {}".format(abspath))
try:
- ret += self._build_dir(dbw, top, excl, abspath, rap, rei, reh, seen)
+ ret += self._build_dir(db, top, excl, abspath, rap, rei, reh, seen)
except:
- m = "failed to index subdir [{}]:\n{}"
- self.log(m.format(abspath, min_ex()), c=1)
+ t = "failed to index subdir [{}]:\n{}"
+ self.log(t.format(abspath, min_ex()), c=1)
elif not stat.S_ISREG(inf.st_mode):
self.log("skip type-{:x} file [{}]".format(inf.st_mode, abspath))
else:
@@ -632,31 +678,31 @@ class Up2k(object):
rd, fn = rp.rsplit("/", 1) if "/" in rp else ["", rp]
sql = "select w, mt, sz from up where rd = ? and fn = ?"
try:
- c = dbw[0].execute(sql, (rd, fn))
+ c = db.c.execute(sql, (rd, fn))
except:
- c = dbw[0].execute(sql, s3enc(self.mem_cur, rd, fn))
+ c = db.c.execute(sql, s3enc(self.mem_cur, rd, fn))
in_db = list(c.fetchall())
if in_db:
self.pp.n -= 1
dw, dts, dsz = in_db[0]
if len(in_db) > 1:
- m = "WARN: multiple entries: [{}] => [{}] |{}|\n{}"
+ t = "WARN: multiple entries: [{}] => [{}] |{}|\n{}"
rep_db = "\n".join([repr(x) for x in in_db])
- self.log(m.format(top, rp, len(in_db), rep_db))
+ self.log(t.format(top, rp, len(in_db), rep_db))
dts = -1
if dts == lmod and dsz == sz and (nohash or dw[0] != "#"):
continue
- m = "reindex [{}] => [{}] ({}/{}) ({}/{})".format(
+ t = "reindex [{}] => [{}] ({}/{}) ({}/{})".format(
top, rp, dts, lmod, dsz, sz
)
- self.log(m)
- self.db_rm(dbw[0], rd, fn)
+ self.log(t)
+ self.db_rm(db.c, rd, fn)
ret += 1
- dbw[1] += 1
- in_db = None
+ db.n += 1
+ in_db = []
self.pp.msg = "a{} {}".format(self.pp.n, abspath)
@@ -674,15 +720,15 @@ class Up2k(object):
wark = up2k_wark_from_hashlist(self.salt, sz, hashes)
- self.db_add(dbw[0], wark, rd, fn, lmod, sz, "", 0)
- dbw[1] += 1
+ self.db_add(db.c, wark, rd, fn, lmod, sz, "", 0)
+ db.n += 1
ret += 1
- td = time.time() - dbw[2]
- if dbw[1] >= 4096 or td >= 60:
- self.log("commit {} new files".format(dbw[1]))
- dbw[0].connection.commit()
- dbw[1] = 0
- dbw[2] = time.time()
+ td = time.time() - db.t
+ if db.n >= 4096 or td >= 60:
+ self.log("commit {} new files".format(db.n))
+ db.c.connection.commit()
+ db.n = 0
+ db.t = time.time()
# drop missing files
rd = cdir[len(top) + 1 :].strip("/")
@@ -691,25 +737,26 @@ class Up2k(object):
q = "select fn from up where rd = ?"
try:
- c = dbw[0].execute(q, (rd,))
+ c = db.c.execute(q, (rd,))
except:
- c = dbw[0].execute(q, ("//" + w8b64enc(rd),))
+ c = db.c.execute(q, ("//" + w8b64enc(rd),))
hits = [w8b64dec(x[2:]) if x.startswith("//") else x for (x,) in c]
rm_files = [x for x in hits if x not in seen_files]
n_rm = len(rm_files)
for fn in rm_files:
- self.db_rm(dbw[0], rd, fn)
+ self.db_rm(db.c, rd, fn)
if n_rm:
self.log("forgot {} deleted files".format(n_rm))
return ret
- def _drop_lost(self, cur, top):
+ def _drop_lost(self, cur: "sqlite3.Cursor", top: str) -> int:
rm = []
n_rm = 0
nchecked = 0
+ assert self.pp
# `_build_dir` did all the files, now do dirs
ndirs = next(cur.execute("select count(distinct rd) from up"))[0]
c = cur.execute("select distinct rd from up order by rd desc")
@@ -743,13 +790,16 @@ class Up2k(object):
return n_rm
- def _build_tags_index(self, vol):
+ def _build_tags_index(self, vol: VFS) -> tuple[int, int, bool]:
ptop = vol.realpath
with self.mutex:
- _, db_path = self.register_vpath(ptop, vol.flags)
- entags = self.entags[ptop]
- flags = self.flags[ptop]
- cur = self.cur[ptop]
+ reg = self.register_vpath(ptop, vol.flags)
+
+ assert reg and self.pp and self.mtag
+ _, db_path = reg
+ entags = self.entags[ptop]
+ flags = self.flags[ptop]
+ cur = self.cur[ptop]
n_add = 0
n_rm = 0
@@ -794,7 +844,8 @@ class Up2k(object):
if not self.mtag:
return n_add, n_rm, False
- mpool = False
+ mpool: Optional[Queue[Mpqe]] = None
+
if self.mtag.prefer_mt and self.args.mtag_mt > 1:
mpool = self._start_mpool()
@@ -819,11 +870,10 @@ class Up2k(object):
abspath = os.path.join(ptop, rd, fn)
self.pp.msg = "c{} {}".format(n_left, abspath)
- args = [entags, w, abspath]
if not mpool:
- n_tags = self._tag_file(c3, *args)
+ n_tags = self._tag_file(c3, entags, w, abspath)
else:
- mpool.put(["mtag"] + args)
+ mpool.put(Mpqe({}, entags, w, abspath))
# not registry cursor; do not self.mutex:
n_tags = len(self._flush_mpool(c3))
@@ -850,7 +900,7 @@ class Up2k(object):
return n_add, n_rm, True
- def _flush_mpool(self, wcur):
+ def _flush_mpool(self, wcur: "sqlite3.Cursor") -> list[str]:
ret = []
for x in self.pending_tags:
self._tag_file(wcur, *x)
@@ -859,7 +909,7 @@ class Up2k(object):
self.pending_tags = []
return ret
- def _run_all_mtp(self):
+ def _run_all_mtp(self) -> None:
gid = self.gid
t0 = time.time()
for ptop, flags in self.flags.items():
@@ -870,12 +920,12 @@ class Up2k(object):
msg = "mtp finished in {:.2f} sec ({})"
self.log(msg.format(td, s2hms(td, True)))
- del self.pp
+ self.pp = None
for k in list(self.volstate.keys()):
if "OFFLINE" not in self.volstate[k]:
self.volstate[k] = "online, idle"
- def _run_one_mtp(self, ptop, gid):
+ def _run_one_mtp(self, ptop: str, gid: int) -> None:
if gid != self.gid:
return
@@ -915,8 +965,8 @@ class Up2k(object):
break
q = "select w from mt where k = 't:mtp' limit ?"
- warks = cur.execute(q, (batch_sz,)).fetchall()
- warks = [x[0] for x in warks]
+ zq = cur.execute(q, (batch_sz,)).fetchall()
+ warks = [str(x[0]) for x in zq]
jobs = []
for w in warks:
q = "select rd, fn from up where substr(w,1,16)=? limit 1"
@@ -925,8 +975,8 @@ class Up2k(object):
abspath = os.path.join(ptop, rd, fn)
q = "select k from mt where w = ?"
- have = cur.execute(q, (w,)).fetchall()
- have = [x[0] for x in have]
+ zq2 = cur.execute(q, (w,)).fetchall()
+ have: dict[str, Union[str, float]] = {x[0]: 1 for x in zq2}
parsers = self._get_parsers(ptop, have, abspath)
if not parsers:
@@ -937,7 +987,7 @@ class Up2k(object):
if w in in_progress:
continue
- jobs.append([parsers, None, w, abspath])
+ jobs.append(Mpqe(parsers, set(), w, abspath))
in_progress[w] = True
with self.mutex:
@@ -997,7 +1047,9 @@ class Up2k(object):
wcur.close()
cur.close()
- def _get_parsers(self, ptop, have, abspath):
+ def _get_parsers(
+ self, ptop: str, have: dict[str, Union[str, float]], abspath: str
+ ) -> dict[str, MParser]:
try:
all_parsers = self.mtp_parsers[ptop]
except:
@@ -1030,16 +1082,16 @@ class Up2k(object):
parsers = {k: v for k, v in parsers.items() if v.force or k not in have}
return parsers
- def _start_mpool(self):
+ def _start_mpool(self) -> Queue[Mpqe]:
# mp.pool.ThreadPool and concurrent.futures.ThreadPoolExecutor
# both do crazy runahead so lets reinvent another wheel
nw = max(1, self.args.mtag_mt)
-
- if self.pending_tags is None:
+ assert self.mtag
+ if not self.mpool_used:
+ self.mpool_used = True
self.log("using {}x {}".format(nw, self.mtag.backend))
- self.pending_tags = []
- mpool = Queue(nw)
+ mpool: Queue[Mpqe] = Queue(nw)
for _ in range(nw):
thr = threading.Thread(
target=self._tag_thr, args=(mpool,), name="up2k-mpool"
@@ -1049,50 +1101,55 @@ class Up2k(object):
return mpool
- def _stop_mpool(self, mpool):
+ def _stop_mpool(self, mpool: Queue[Mpqe]) -> None:
if not mpool:
return
for _ in range(mpool.maxsize):
- mpool.put(None)
+ mpool.put(Mpqe({}, set(), "", ""))
mpool.join()
- def _tag_thr(self, q):
+ def _tag_thr(self, q: Queue[Mpqe]) -> None:
+ assert self.mtag
while True:
- task = q.get()
- if not task:
+ qe = q.get()
+ if not qe.w:
q.task_done()
return
try:
- parser, entags, wark, abspath = task
- if parser == "mtag":
- tags = self.mtag.get(abspath)
+ if not qe.mtp:
+ tags = self.mtag.get(qe.abspath)
else:
- tags = self.mtag.get_bin(parser, abspath)
+ tags = self.mtag.get_bin(qe.mtp, qe.abspath)
vtags = [
"\033[36m{} \033[33m{}".format(k, v) for k, v in tags.items()
]
if vtags:
- self.log("{}\033[0m [{}]".format(" ".join(vtags), abspath))
+ self.log("{}\033[0m [{}]".format(" ".join(vtags), qe.abspath))
with self.mutex:
- self.pending_tags.append([entags, wark, abspath, tags])
+ self.pending_tags.append((qe.entags, qe.w, qe.abspath, tags))
except:
ex = traceback.format_exc()
- if parser == "mtag":
- parser = self.mtag.backend
-
- self._log_tag_err(parser, abspath, ex)
+ self._log_tag_err(qe.mtp or self.mtag.backend, qe.abspath, ex)
q.task_done()
- def _log_tag_err(self, parser, abspath, ex):
+ def _log_tag_err(self, parser: Any, abspath: str, ex: Any) -> None:
msg = "{} failed to read tags from {}:\n{}".format(parser, abspath, ex)
self.log(msg.lstrip(), c=1 if " int:
+ assert self.mtag
if tags is None:
try:
tags = self.mtag.get(abspath)
@@ -1127,12 +1184,12 @@ class Up2k(object):
return ret
- def _orz(self, db_path):
+ def _orz(self, db_path: str) -> "sqlite3.Cursor":
timeout = int(max(self.args.srch_time, 5) * 1.2)
return sqlite3.connect(db_path, timeout, check_same_thread=False).cursor()
# x.set_trace_callback(trace)
- def _open_db(self, db_path):
+ def _open_db(self, db_path: str) -> "sqlite3.Cursor":
existed = bos.path.exists(db_path)
cur = self._orz(db_path)
ver = self._read_ver(cur)
@@ -1141,8 +1198,8 @@ class Up2k(object):
if ver == 4:
try:
- m = "creating backup before upgrade: "
- cur = self._backup_db(db_path, cur, ver, m)
+ t = "creating backup before upgrade: "
+ cur = self._backup_db(db_path, cur, ver, t)
self._upgrade_v4(cur)
ver = 5
except:
@@ -1157,8 +1214,8 @@ class Up2k(object):
self.log("WARN: could not list files; DB corrupt?\n" + min_ex())
if (ver or 0) > DB_VER:
- m = "database is version {}, this copyparty only supports versions <= {}"
- raise Exception(m.format(ver, DB_VER))
+ t = "database is version {}, this copyparty only supports versions <= {}"
+ raise Exception(t.format(ver, DB_VER))
msg = "creating new DB (old is bad); backup: "
if ver:
@@ -1171,7 +1228,9 @@ class Up2k(object):
bos.unlink(db_path)
return self._create_db(db_path, None)
- def _backup_db(self, db_path, cur, ver, msg):
+ def _backup_db(
+ self, db_path: str, cur: "sqlite3.Cursor", ver: Optional[int], msg: str
+ ) -> "sqlite3.Cursor":
bak = "{}.bak.{:x}.v{}".format(db_path, int(time.time()), ver)
self.log(msg + bak)
try:
@@ -1180,8 +1239,8 @@ class Up2k(object):
cur.connection.backup(c2)
return cur
except:
- m = "native sqlite3 backup failed; using fallback method:\n"
- self.log(m + min_ex())
+ t = "native sqlite3 backup failed; using fallback method:\n"
+ self.log(t + min_ex())
finally:
c2.close()
@@ -1192,7 +1251,7 @@ class Up2k(object):
shutil.copy2(fsenc(db_path), fsenc(bak))
return self._orz(db_path)
- def _read_ver(self, cur):
+ def _read_ver(self, cur: "sqlite3.Cursor") -> Optional[int]:
for tab in ["ki", "kv"]:
try:
c = cur.execute(r"select v from {} where k = 'sver'".format(tab))
@@ -1202,8 +1261,11 @@ class Up2k(object):
rows = c.fetchall()
if rows:
return int(rows[0][0])
+ return None
- def _create_db(self, db_path, cur):
+ def _create_db(
+ self, db_path: str, cur: Optional["sqlite3.Cursor"]
+ ) -> "sqlite3.Cursor":
"""
collision in 2^(n/2) files where n = bits (6 bits/ch)
10*6/2 = 2^30 = 1'073'741'824, 24.1mb idx 1<<(3*10)
@@ -1236,7 +1298,7 @@ class Up2k(object):
self.log("created DB at {}".format(db_path))
return cur
- def _upgrade_v4(self, cur):
+ def _upgrade_v4(self, cur: "sqlite3.Cursor") -> None:
for cmd in [
r"alter table up add column ip text",
r"alter table up add column at int",
@@ -1247,7 +1309,7 @@ class Up2k(object):
cur.connection.commit()
- def handle_json(self, cj):
+ def handle_json(self, cj: dict[str, Any]) -> dict[str, Any]:
with self.mutex:
if not self.register_vpath(cj["ptop"], cj["vcfg"]):
if cj["ptop"] not in self.registry:
@@ -1269,13 +1331,13 @@ class Up2k(object):
if cur:
if self.no_expr_idx:
q = r"select * from up where w = ?"
- argv = (wark,)
+ argv = [wark]
else:
q = r"select * from up where substr(w,1,16) = ? and w = ?"
- argv = (wark[:16], wark)
+ argv = [wark[:16], wark]
- alts = []
- cur = cur.execute(q, argv)
+ alts: list[tuple[int, int, dict[str, Any]]] = []
+ cur = cur.execute(q, tuple(argv))
for _, dtime, dsize, dp_dir, dp_fn, ip, at in cur:
if dp_dir.startswith("//") or dp_fn.startswith("//"):
dp_dir, dp_fn = s3dec(dp_dir, dp_fn)
@@ -1307,7 +1369,7 @@ class Up2k(object):
+ (2 if dp_dir == cj["prel"] else 0)
+ (1 if dp_fn == cj["name"] else 0)
)
- alts.append([score, -len(alts), j])
+ alts.append((score, -len(alts), j))
job = sorted(alts, reverse=True)[0][2] if alts else None
if job and wark in reg:
@@ -1344,7 +1406,7 @@ class Up2k(object):
# registry is size-constrained + can only contain one unique wark;
# let want_recheck trigger symlink (if still in reg) or reupload
if cur:
- dupe = [cj["prel"], cj["name"], cj["lmod"]]
+ dupe = (cj["prel"], cj["name"], cj["lmod"])
try:
self.dupesched[src].append(dupe)
except:
@@ -1431,17 +1493,19 @@ class Up2k(object):
"wark": wark,
}
- def _untaken(self, fdir, fname, ts, ip):
+ def _untaken(self, fdir: str, fname: str, ts: float, ip: str) -> str:
if self.args.nw:
return fname
# TODO broker which avoid this race and
# provides a new filename if taken (same as bup)
suffix = "-{:.6f}-{}".format(ts, ip.replace(":", "."))
- with ren_open(fname, "wb", fdir=fdir, suffix=suffix) as f:
- return f["orz"][1]
+ with ren_open(fname, "wb", fdir=fdir, suffix=suffix) as zfw:
+ return zfw["orz"][1]
- def _symlink(self, src, dst, verbose=True, lmod=None):
+ def _symlink(
+ self, src: str, dst: str, verbose: bool = True, lmod: float = 0
+ ) -> None:
if verbose:
self.log("linking dupe:\n {0}\n {1}".format(src, dst))
@@ -1475,9 +1539,9 @@ class Up2k(object):
break
nc += 1
if nc > 1:
- lsrc = nsrc[nc:]
+ zsl = nsrc[nc:]
hops = len(ndst[nc:]) - 1
- lsrc = "../" * hops + "/".join(lsrc)
+ lsrc = "../" * hops + "/".join(zsl)
try:
if self.args.hardlink:
@@ -1498,11 +1562,13 @@ class Up2k(object):
if lmod and (not linked or SYMTIME):
times = (int(time.time()), int(lmod))
if ANYWIN:
- self.lastmod_q.append([dst, 0, times])
+ self.lastmod_q.append((dst, 0, times))
else:
bos.utime(dst, times, False)
- def handle_chunk(self, ptop, wark, chash):
+ def handle_chunk(
+ self, ptop: str, wark: str, chash: str
+ ) -> tuple[int, list[int], str, float]:
with self.mutex:
job = self.registry[ptop].get(wark)
if not job:
@@ -1523,8 +1589,8 @@ class Up2k(object):
if chash in job["busy"]:
nh = len(job["hash"])
idx = job["hash"].index(chash)
- m = "that chunk is already being written to:\n {}\n {} {}/{}\n {}"
- raise Pebkac(400, m.format(wark, chash, idx, nh, job["name"]))
+ t = "that chunk is already being written to:\n {}\n {} {}/{}\n {}"
+ raise Pebkac(400, t.format(wark, chash, idx, nh, job["name"]))
job["busy"][chash] = 1
@@ -1535,17 +1601,17 @@ class Up2k(object):
path = os.path.join(job["ptop"], job["prel"], job["tnam"])
- return [chunksize, ofs, path, job["lmod"]]
+ return chunksize, ofs, path, job["lmod"]
- def release_chunk(self, ptop, wark, chash):
+ def release_chunk(self, ptop: str, wark: str, chash: str) -> bool:
with self.mutex:
job = self.registry[ptop].get(wark)
if job:
job["busy"].pop(chash, None)
- return [True]
+ return True
- def confirm_chunk(self, ptop, wark, chash):
+ def confirm_chunk(self, ptop: str, wark: str, chash: str) -> tuple[int, str]:
with self.mutex:
try:
job = self.registry[ptop][wark]
@@ -1553,14 +1619,14 @@ class Up2k(object):
src = os.path.join(pdir, job["tnam"])
dst = os.path.join(pdir, job["name"])
except Exception as ex:
- return "confirm_chunk, wark, " + repr(ex)
+ return "confirm_chunk, wark, " + repr(ex) # type: ignore
job["busy"].pop(chash, None)
try:
job["need"].remove(chash)
except Exception as ex:
- return "confirm_chunk, chash, " + repr(ex)
+ return "confirm_chunk, chash, " + repr(ex) # type: ignore
ret = len(job["need"])
if ret > 0:
@@ -1576,35 +1642,35 @@ class Up2k(object):
return ret, dst
- def finish_upload(self, ptop, wark):
+ def finish_upload(self, ptop: str, wark: str) -> None:
with self.mutex:
self._finish_upload(ptop, wark)
- def _finish_upload(self, ptop, wark):
+ def _finish_upload(self, ptop: str, wark: str) -> None:
try:
job = self.registry[ptop][wark]
pdir = os.path.join(job["ptop"], job["prel"])
src = os.path.join(pdir, job["tnam"])
dst = os.path.join(pdir, job["name"])
except Exception as ex:
- return "finish_upload, wark, " + repr(ex)
+ raise Pebkac(500, "finish_upload, wark, " + repr(ex))
# self.log("--- " + wark + " " + dst + " finish_upload atomic " + dst, 4)
atomic_move(src, dst)
times = (int(time.time()), int(job["lmod"]))
if ANYWIN:
- a = [dst, job["size"], times]
- self.lastmod_q.append(a)
+ z1 = (dst, job["size"], times)
+ self.lastmod_q.append(z1)
elif not job["hash"]:
try:
bos.utime(dst, times)
except:
pass
- a = [job[x] for x in "ptop wark prel name lmod size addr".split()]
- a += [job.get("at") or time.time()]
- if self.idx_wark(*a):
+ z2 = [job[x] for x in "ptop wark prel name lmod size addr".split()]
+ z2 += [job.get("at") or time.time()]
+ if self.idx_wark(*z2):
del self.registry[ptop][wark]
else:
self.regdrop(ptop, wark)
@@ -1622,27 +1688,37 @@ class Up2k(object):
self._symlink(dst, d2, lmod=lmod)
if cur:
self.db_rm(cur, rd, fn)
- self.db_add(cur, wark, rd, fn, *a[-4:])
+ self.db_add(cur, wark, rd, fn, *z2[-4:])
if cur:
cur.connection.commit()
- def regdrop(self, ptop, wark):
- t = self.droppable[ptop]
+ def regdrop(self, ptop: str, wark: str) -> None:
+ olds = self.droppable[ptop]
if wark:
- t.append(wark)
+ olds.append(wark)
- if len(t) <= self.args.reg_cap:
+ if len(olds) <= self.args.reg_cap:
return
- n = len(t) - int(self.args.reg_cap / 2)
- m = "up2k-registry [{}] has {} droppables; discarding {}"
- self.log(m.format(ptop, len(t), n))
- for k in t[:n]:
+ n = len(olds) - int(self.args.reg_cap / 2)
+ t = "up2k-registry [{}] has {} droppables; discarding {}"
+ self.log(t.format(ptop, len(olds), n))
+ for k in olds[:n]:
self.registry[ptop].pop(k, None)
- self.droppable[ptop] = t[n:]
+ self.droppable[ptop] = olds[n:]
- def idx_wark(self, ptop, wark, rd, fn, lmod, sz, ip, at):
+ def idx_wark(
+ self,
+ ptop: str,
+ wark: str,
+ rd: str,
+ fn: str,
+ lmod: float,
+ sz: int,
+ ip: str,
+ at: float,
+ ) -> bool:
cur = self.cur.get(ptop)
if not cur:
return False
@@ -1652,29 +1728,41 @@ class Up2k(object):
cur.connection.commit()
if "e2t" in self.flags[ptop]:
- self.tagq.put([ptop, wark, rd, fn])
+ self.tagq.put((ptop, wark, rd, fn))
self.n_tagq += 1
return True
- def db_rm(self, db, rd, fn):
+ def db_rm(self, db: "sqlite3.Cursor", rd: str, fn: str) -> None:
sql = "delete from up where rd = ? and fn = ?"
try:
db.execute(sql, (rd, fn))
except:
+ assert self.mem_cur
db.execute(sql, s3enc(self.mem_cur, rd, fn))
- def db_add(self, db, wark, rd, fn, ts, sz, ip, at):
+ def db_add(
+ self,
+ db: "sqlite3.Cursor",
+ wark: str,
+ rd: str,
+ fn: str,
+ ts: float,
+ sz: int,
+ ip: str,
+ at: float,
+ ) -> None:
sql = "insert into up values (?,?,?,?,?,?,?)"
v = (wark, int(ts), sz, rd, fn, ip or "", int(at or 0))
try:
db.execute(sql, v)
except:
+ assert self.mem_cur
rd, fn = s3enc(self.mem_cur, rd, fn)
v = (wark, int(ts), sz, rd, fn, ip or "", int(at or 0))
db.execute(sql, v)
- def handle_rm(self, uname, ip, vpaths):
+ def handle_rm(self, uname: str, ip: str, vpaths: list[str]) -> str:
n_files = 0
ok = {}
ng = {}
@@ -1687,12 +1775,14 @@ class Up2k(object):
ng[k] = 1
ng = {k: 1 for k in ng if k not in ok}
- ok = len(ok)
- ng = len(ng)
+ iok = len(ok)
+ ing = len(ng)
- return "deleted {} files (and {}/{} folders)".format(n_files, ok, ok + ng)
+ return "deleted {} files (and {}/{} folders)".format(n_files, iok, iok + ing)
- def _handle_rm(self, uname, ip, vpath):
+ def _handle_rm(
+ self, uname: str, ip: str, vpath: str
+ ) -> tuple[int, list[str], list[str]]:
try:
permsets = [[True, False, False, True]]
vn, rem = self.asrv.vfs.get(vpath, uname, *permsets[0])
@@ -1709,18 +1799,18 @@ class Up2k(object):
vn, rem = vn.get_dbv(rem)
_, _, _, _, dip, dat = self._find_from_vpath(vn.realpath, rem)
- m = "you cannot delete this: "
+ t = "you cannot delete this: "
if not dip:
- m += "file not found"
+ t += "file not found"
elif dip != ip:
- m += "not uploaded by (You)"
+ t += "not uploaded by (You)"
elif dat < time.time() - self.args.unpost:
- m += "uploaded too long ago"
+ t += "uploaded too long ago"
else:
- m = None
+ t = ""
- if m:
- raise Pebkac(400, m)
+ if t:
+ raise Pebkac(400, t)
ptop = vn.realpath
atop = vn.canonical(rem, False)
@@ -1731,16 +1821,19 @@ class Up2k(object):
raise Pebkac(400, "file not found on disk (already deleted?)")
scandir = not self.args.no_scandir
- if stat.S_ISLNK(st.st_mode) or stat.S_ISREG(st.st_mode):
+ if stat.S_ISDIR(st.st_mode):
+ g = vn.walk("", rem, [], uname, permsets, True, scandir, True)
+ if unpost:
+ raise Pebkac(400, "cannot unpost folders")
+ elif stat.S_ISLNK(st.st_mode) or stat.S_ISREG(st.st_mode):
dbv, vrem = self.asrv.vfs.get(vpath, uname, *permsets[0])
dbv, vrem = dbv.get_dbv(vrem)
voldir = vsplit(vrem)[0]
vpath_dir = vsplit(vpath)[0]
- g = [[dbv, voldir, vpath_dir, adir, [[fn, 0]], [], []]]
+ g = [(dbv, voldir, vpath_dir, adir, [(fn, 0)], [], {})] # type: ignore
else:
- g = vn.walk("", rem, [], uname, permsets, True, scandir, True)
- if unpost:
- raise Pebkac(400, "cannot unpost folders")
+ self.log("rm: skip type-{:x} file [{}]".format(st.st_mode, atop))
+ return 0, [], []
n_files = 0
for dbv, vrem, _, adir, files, rd, vd in g:
@@ -1766,7 +1859,7 @@ class Up2k(object):
rm = rmdirs(self.log_func, scandir, True, atop, 1)
return n_files, rm[0], rm[1]
- def handle_mv(self, uname, svp, dvp):
+ def handle_mv(self, uname: str, svp: str, dvp: str) -> str:
svn, srem = self.asrv.vfs.get(svp, uname, True, False, True)
svn, srem = svn.get_dbv(srem)
sabs = svn.canonical(srem, False)
@@ -1808,7 +1901,7 @@ class Up2k(object):
rmdirs(self.log_func, scandir, True, sabs, 1)
return "k"
- def _mv_file(self, uname, svp, dvp):
+ def _mv_file(self, uname: str, svp: str, dvp: str) -> str:
svn, srem = self.asrv.vfs.get(svp, uname, True, False, True)
svn, srem = svn.get_dbv(srem)
@@ -1834,29 +1927,33 @@ class Up2k(object):
if bos.path.islink(sabs):
dlabs = absreal(sabs)
- m = "moving symlink from [{}] to [{}], target [{}]"
- self.log(m.format(sabs, dabs, dlabs))
+ t = "moving symlink from [{}] to [{}], target [{}]"
+ self.log(t.format(sabs, dabs, dlabs))
mt = bos.path.getmtime(sabs, False)
bos.unlink(sabs)
self._symlink(dlabs, dabs, False, lmod=mt)
# folders are too scary, schedule rescan of both vols
- self.need_rescan[svn.vpath] = 1
- self.need_rescan[dvn.vpath] = 1
+ self.need_rescan.add(svn.vpath)
+ self.need_rescan.add(dvn.vpath)
with self.rescan_cond:
self.rescan_cond.notify_all()
return "k"
- c1, w, ftime, fsize, ip, at = self._find_from_vpath(svn.realpath, srem)
+ c1, w, ftime_, fsize_, ip, at = self._find_from_vpath(svn.realpath, srem)
c2 = self.cur.get(dvn.realpath)
- if ftime is None:
+ if ftime_ is None:
st = bos.stat(sabs)
ftime = st.st_mtime
fsize = st.st_size
+ else:
+ ftime = ftime_
+ fsize = fsize_ or 0
if w:
+ assert c1
if c2 and c2 != c1:
self._copy_tags(c1, c2, w)
@@ -1865,7 +1962,7 @@ class Up2k(object):
c1.connection.commit()
if c2:
- self.db_add(c2, w, drd, dfn, ftime, fsize, ip, at)
+ self.db_add(c2, w, drd, dfn, ftime, fsize, ip or "", at or 0)
c2.connection.commit()
else:
self.log("not found in src db: [{}]".format(svp))
@@ -1873,7 +1970,9 @@ class Up2k(object):
bos.rename(sabs, dabs)
return "k"
- def _copy_tags(self, csrc, cdst, wark):
+ def _copy_tags(
+ self, csrc: "sqlite3.Cursor", cdst: "sqlite3.Cursor", wark: str
+ ) -> None:
"""copy all tags for wark from src-db to dst-db"""
w = wark[:16]
@@ -1883,16 +1982,26 @@ class Up2k(object):
for _, k, v in csrc.execute("select * from mt where w=?", (w,)):
cdst.execute("insert into mt values(?,?,?)", (w, k, v))
- def _find_from_vpath(self, ptop, vrem):
+ def _find_from_vpath(
+ self, ptop: str, vrem: str
+ ) -> tuple[
+ Optional["sqlite3.Cursor"],
+ Optional[str],
+ Optional[int],
+ Optional[int],
+ Optional[str],
+ Optional[int],
+ ]:
cur = self.cur.get(ptop)
if not cur:
- return [None] * 6
+ return None, None, None, None, None, None
rd, fn = vsplit(vrem)
q = "select w, mt, sz, ip, at from up where rd=? and fn=? limit 1"
try:
c = cur.execute(q, (rd, fn))
except:
+ assert self.mem_cur
c = cur.execute(q, s3enc(self.mem_cur, rd, fn))
hit = c.fetchone()
@@ -1901,14 +2010,21 @@ class Up2k(object):
return cur, wark, ftime, fsize, ip, at
return cur, None, None, None, None, None
- def _forget_file(self, ptop, vrem, cur, wark, drop_tags):
+ def _forget_file(
+ self,
+ ptop: str,
+ vrem: str,
+ cur: Optional["sqlite3.Cursor"],
+ wark: Optional[str],
+ drop_tags: bool,
+ ) -> None:
"""forgets file in db, fixes symlinks, does not delete"""
srd, sfn = vsplit(vrem)
self.log("forgetting {}".format(vrem))
- if wark:
+ if wark and cur:
self.log("found {} in db".format(wark))
if drop_tags:
- if self._relink(wark, ptop, vrem, None):
+ if self._relink(wark, ptop, vrem, ""):
drop_tags = False
if drop_tags:
@@ -1919,20 +2035,24 @@ class Up2k(object):
reg = self.registry.get(ptop)
if reg:
- if not wark:
- wark = [
+ vdir = vsplit(vrem)[0]
+ wark = wark or next(
+ (
x
for x, y in reg.items()
- if sfn in [y["name"], y.get("tnam")] and y["prel"] == vrem
- ]
-
- if wark and wark in reg:
- m = "forgetting partial upload {} ({})"
- p = self._vis_job_progress(wark)
- self.log(m.format(wark, p))
+ if sfn in [y["name"], y.get("tnam")] and y["prel"] == vdir
+ ),
+ "",
+ )
+ job = reg.get(wark) if wark else None
+ if job:
+ t = "forgetting partial upload {} ({})"
+ p = self._vis_job_progress(job)
+ self.log(t.format(wark, p))
+ assert wark
del reg[wark]
- def _relink(self, wark, sptop, srem, dabs):
+ def _relink(self, wark: str, sptop: str, srem: str, dabs: str) -> int:
"""
update symlinks from file at svn/srem to dabs (rename),
or to first remaining full if no dabs (delete)
@@ -1953,13 +2073,13 @@ class Up2k(object):
if not dupes:
return 0
- full = {}
- links = {}
+ full: dict[str, tuple[str, str]] = {}
+ links: dict[str, tuple[str, str]] = {}
for ptop, vp in dupes:
ap = os.path.join(ptop, vp)
try:
d = links if bos.path.islink(ap) else full
- d[ap] = [ptop, vp]
+ d[ap] = (ptop, vp)
except:
self.log("relink: not found: [{}]".format(ap))
@@ -1973,13 +2093,13 @@ class Up2k(object):
bos.rename(sabs, slabs)
bos.utime(slabs, (int(time.time()), int(mt)), False)
self._symlink(slabs, sabs, False)
- full[slabs] = [ptop, rem]
+ full[slabs] = (ptop, rem)
sabs = slabs
if not dabs:
dabs = list(sorted(full.keys()))[0]
- for alink in links.keys():
+ for alink in links:
lmod = None
try:
if alink != sabs and absreal(alink) != sabs:
@@ -1991,11 +2111,11 @@ class Up2k(object):
except:
pass
- self._symlink(dabs, alink, False, lmod=lmod)
+ self._symlink(dabs, alink, False, lmod=lmod or 0)
return len(full) + len(links)
- def _get_wark(self, cj):
+ def _get_wark(self, cj: dict[str, Any]) -> str:
if len(cj["name"]) > 1024 or len(cj["hash"]) > 512 * 1024: # 16TiB
raise Pebkac(400, "name or numchunks not according to spec")
@@ -2020,15 +2140,14 @@ class Up2k(object):
return wark
- def _hashlist_from_file(self, path):
- pp = self.pp if hasattr(self, "pp") else None
+ def _hashlist_from_file(self, path: str) -> list[str]:
fsz = bos.path.getsize(path)
csz = up2k_chunksize(fsz)
ret = []
with open(fsenc(path), "rb", 512 * 1024) as f:
while fsz > 0:
- if pp:
- pp.msg = "{} MB, {}".format(int(fsz / 1024 / 1024), path)
+ if self.pp:
+ self.pp.msg = "{} MB, {}".format(int(fsz / 1024 / 1024), path)
hashobj = hashlib.sha512()
rem = min(csz, fsz)
@@ -2047,7 +2166,7 @@ class Up2k(object):
return ret
- def _new_upload(self, job):
+ def _new_upload(self, job: dict[str, Any]) -> None:
self.registry[job["ptop"]][job["wark"]] = job
pdir = os.path.join(job["ptop"], job["prel"])
job["name"] = self._untaken(pdir, job["name"], job["t0"], job["addr"])
@@ -2066,8 +2185,8 @@ class Up2k(object):
dip = job["addr"].replace(":", ".")
suffix = "-{:.6f}-{}".format(job["t0"], dip)
- with ren_open(tnam, "wb", fdir=pdir, suffix=suffix) as f:
- f, job["tnam"] = f["orz"]
+ with ren_open(tnam, "wb", fdir=pdir, suffix=suffix) as zfw:
+ f, job["tnam"] = zfw["orz"]
if (
ANYWIN
and self.args.sparse
@@ -2086,7 +2205,7 @@ class Up2k(object):
if not job["hash"]:
self._finish_upload(job["ptop"], job["wark"])
- def _lastmodder(self):
+ def _lastmodder(self) -> None:
while True:
ready = self.lastmod_q
self.lastmod_q = []
@@ -2098,8 +2217,8 @@ class Up2k(object):
try:
bos.utime(path, times, False)
except:
- m = "lmod: failed to utime ({}, {}):\n{}"
- self.log(m.format(path, times, min_ex()))
+ t = "lmod: failed to utime ({}, {}):\n{}"
+ self.log(t.format(path, times, min_ex()))
if self.args.sparse and self.args.sparse * 1024 * 1024 <= sz:
try:
@@ -2107,21 +2226,22 @@ class Up2k(object):
except:
self.log("could not unsparse [{}]".format(path), 3)
- def _snapshot(self):
- self.snap_persist_interval = 300 # persist unfinished index every 5 min
- self.snap_discard_interval = 21600 # drop unfinished after 6 hours inactivity
- self.snap_prev = {}
+ def _snapshot(self) -> None:
+ slp = self.snap_persist_interval
while True:
- time.sleep(self.snap_persist_interval)
- if not hasattr(self, "pp"):
+ time.sleep(slp)
+ if self.pp:
+ slp = 5
+ else:
+ slp = self.snap_persist_interval
self.do_snapshot()
- def do_snapshot(self):
+ def do_snapshot(self) -> None:
with self.mutex:
for k, reg in self.registry.items():
self._snap_reg(k, reg)
- def _snap_reg(self, ptop, reg):
+ def _snap_reg(self, ptop: str, reg: dict[str, dict[str, Any]]) -> None:
now = time.time()
histpath = self.asrv.vfs.histtab.get(ptop)
if not histpath:
@@ -2133,9 +2253,9 @@ class Up2k(object):
if x["need"] and now - x["poke"] > self.snap_discard_interval
]
if rm:
- m = "dropping {} abandoned uploads in {}".format(len(rm), ptop)
+ t = "dropping {} abandoned uploads in {}".format(len(rm), ptop)
vis = [self._vis_job_progress(x) for x in rm]
- self.log("\n".join([m] + vis))
+ self.log("\n".join([t] + vis))
for job in rm:
del reg[job["wark"]]
try:
@@ -2159,8 +2279,8 @@ class Up2k(object):
bos.unlink(path)
return
- newest = max(x["poke"] for _, x in reg.items()) if reg else 0
- etag = [len(reg), newest]
+ newest = float(max(x["poke"] for _, x in reg.items()) if reg else 0)
+ etag = (len(reg), newest)
if etag == self.snap_prev.get(ptop):
return
@@ -2177,10 +2297,11 @@ class Up2k(object):
self.log("snap: {} |{}|".format(path, len(reg.keys())))
self.snap_prev[ptop] = etag
- def _tagger(self):
+ def _tagger(self) -> None:
with self.mutex:
self.n_tagq += 1
+ assert self.mtag
while True:
with self.mutex:
self.n_tagq -= 1
@@ -2218,7 +2339,7 @@ class Up2k(object):
self.log("tagged {} ({}+{})".format(abspath, ntags1, len(tags) - ntags1))
- def _hasher(self):
+ def _hasher(self) -> None:
with self.mutex:
self.n_hashq += 1
@@ -2240,20 +2361,21 @@ class Up2k(object):
with self.mutex:
self.idx_wark(ptop, wark, rd, fn, inf.st_mtime, inf.st_size, ip, at)
- def hash_file(self, ptop, flags, rd, fn, ip, at):
+ def hash_file(
+ self, ptop: str, flags: dict[str, Any], rd: str, fn: str, ip: str, at: float
+ ) -> None:
with self.mutex:
self.register_vpath(ptop, flags)
- self.hashq.put([ptop, rd, fn, ip, at])
+ self.hashq.put((ptop, rd, fn, ip, at))
self.n_hashq += 1
# self.log("hashq {} push {}/{}/{}".format(self.n_hashq, ptop, rd, fn))
- def shutdown(self):
- if hasattr(self, "snap_prev"):
- self.log("writing snapshot")
- self.do_snapshot()
+ def shutdown(self) -> None:
+ self.log("writing snapshot")
+ self.do_snapshot()
-def up2k_chunksize(filesize):
+def up2k_chunksize(filesize: int) -> int:
chunksize = 1024 * 1024
stepsize = 512 * 1024
while True:
@@ -2266,18 +2388,17 @@ def up2k_chunksize(filesize):
stepsize *= mul
-def up2k_wark_from_hashlist(salt, filesize, hashes):
+def up2k_wark_from_hashlist(salt: str, filesize: int, hashes: list[str]) -> str:
"""server-reproducible file identifier, independent of name or location"""
- ident = [salt, str(filesize)]
- ident.extend(hashes)
- ident = "\n".join(ident)
+ values = [salt, str(filesize)] + hashes
+ vstr = "\n".join(values)
- wark = hashlib.sha512(ident.encode("utf-8")).digest()[:33]
+ wark = hashlib.sha512(vstr.encode("utf-8")).digest()[:33]
wark = base64.urlsafe_b64encode(wark)
return wark.decode("ascii")
-def up2k_wark_from_metadata(salt, sz, lastmod, rd, fn):
+def up2k_wark_from_metadata(salt: str, sz: int, lastmod: int, rd: str, fn: str) -> str:
ret = fsenc("{}\n{}\n{}\n{}\n{}".format(salt, lastmod, sz, rd, fn))
ret = base64.urlsafe_b64encode(hashlib.sha512(ret).digest())
return "#{}".format(ret.decode("ascii"))[:44]
diff --git a/copyparty/util.py b/copyparty/util.py
index a01fdfdb..2ca710f5 100644
--- a/copyparty/util.py
+++ b/copyparty/util.py
@@ -1,51 +1,79 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import re
-import os
-import sys
-import stat
-import time
import base64
+import contextlib
+import hashlib
+import mimetypes
+import os
+import platform
+import re
import select
-import struct
import signal
import socket
-import hashlib
-import platform
-import traceback
-import threading
-import mimetypes
-import contextlib
+import stat
+import struct
import subprocess as sp # nosec
-from datetime import datetime
+import sys
+import threading
+import time
+import traceback
from collections import Counter
+from datetime import datetime
-from .__init__ import PY2, WINDOWS, ANYWIN, VT100
+from .__init__ import ANYWIN, PY2, TYPE_CHECKING, VT100, WINDOWS
from .stolen import surrogateescape
+try:
+ HAVE_SQLITE3 = True
+ import sqlite3 # pylint: disable=unused-import # typechk
+except:
+ HAVE_SQLITE3 = False
+
+try:
+ import types
+ from collections.abc import Callable, Iterable
+
+ import typing
+ from typing import Any, Generator, Optional, Protocol, Union
+
+ class RootLogger(Protocol):
+ def __call__(self, src: str, msg: str, c: Union[int, str] = 0) -> None:
+ return None
+
+ class NamedLogger(Protocol):
+ def __call__(self, msg: str, c: Union[int, str] = 0) -> None:
+ return None
+
+except:
+ pass
+
+if TYPE_CHECKING:
+ from .authsrv import VFS
+
+
FAKE_MP = False
try:
- if FAKE_MP:
- import multiprocessing.dummy as mp # noqa: F401 # pylint: disable=unused-import
+ if not FAKE_MP:
+ import multiprocessing as mp
else:
- import multiprocessing as mp # noqa: F401 # pylint: disable=unused-import
+ import multiprocessing.dummy as mp # type: ignore
except ImportError:
# support jython
- mp = None
+ mp = None # type: ignore
if not PY2:
- from urllib.parse import unquote_to_bytes as unquote
+ from io import BytesIO
from urllib.parse import quote_from_bytes as quote
- from queue import Queue # pylint: disable=unused-import
- from io import BytesIO # pylint: disable=unused-import
+ from urllib.parse import unquote_to_bytes as unquote
else:
- from urllib import unquote # pylint: disable=no-name-in-module
+ from StringIO import StringIO as BytesIO
from urllib import quote # pylint: disable=no-name-in-module
- from Queue import Queue # pylint: disable=import-error,no-name-in-module
- from StringIO import StringIO as BytesIO # pylint: disable=unused-import
+ from urllib import unquote # pylint: disable=no-name-in-module
+_: Any = (mp, BytesIO, quote, unquote)
+__all__ = ["mp", "BytesIO", "quote", "unquote"]
try:
struct.unpack(b">i", b"idgi")
@@ -53,20 +81,21 @@ try:
sunpack = struct.unpack
except:
- def spack(f, *a, **ka):
- return struct.pack(f.decode("ascii"), *a, **ka)
+ def spack(fmt: bytes, *a: Any) -> bytes:
+ return struct.pack(fmt.decode("ascii"), *a)
- def sunpack(f, *a, **ka):
- return struct.unpack(f.decode("ascii"), *a, **ka)
+ def sunpack(fmt: bytes, a: bytes) -> tuple[Any, ...]:
+ return struct.unpack(fmt.decode("ascii"), a)
ansi_re = re.compile("\033\\[[^mK]*[mK]")
surrogateescape.register_surrogateescape()
-FS_ENCODING = sys.getfilesystemencoding()
if WINDOWS and PY2:
FS_ENCODING = "utf-8"
+else:
+ FS_ENCODING = sys.getfilesystemencoding()
SYMTIME = sys.version_info >= (3, 6) and os.utime in os.supports_follow_symlinks
@@ -116,7 +145,7 @@ MIMES = {
}
-def _add_mimes():
+def _add_mimes() -> None:
for ln in """text css html csv
application json wasm xml pdf rtf zip
image webp jpeg png gif bmp
@@ -170,18 +199,18 @@ REKOBO_LKEY = {k.lower(): v for k, v in REKOBO_KEY.items()}
class Cooldown(object):
- def __init__(self, maxage):
+ def __init__(self, maxage: float) -> None:
self.maxage = maxage
self.mutex = threading.Lock()
- self.hist = {}
- self.oldest = 0
+ self.hist: dict[str, float] = {}
+ self.oldest = 0.0
- def poke(self, key):
+ def poke(self, key: str) -> bool:
with self.mutex:
now = time.time()
ret = False
- pv = self.hist.get(key, 0)
+ pv: float = self.hist.get(key, 0)
if now - pv > self.maxage:
self.hist[key] = now
ret = True
@@ -204,12 +233,12 @@ class _Unrecv(object):
undo any number of socket recv ops
"""
- def __init__(self, s, log):
- self.s = s # type: socket.socket
+ def __init__(self, s: socket.socket, log: Optional[NamedLogger]) -> None:
+ self.s = s
self.log = log
- self.buf = b""
+ self.buf: bytes = b""
- def recv(self, nbytes):
+ def recv(self, nbytes: int) -> bytes:
if self.buf:
ret = self.buf[:nbytes]
self.buf = self.buf[nbytes:]
@@ -221,25 +250,25 @@ class _Unrecv(object):
return ret
- def recv_ex(self, nbytes, raise_on_trunc=True):
+ def recv_ex(self, nbytes: int, raise_on_trunc: bool = True) -> bytes:
"""read an exact number of bytes"""
ret = b""
try:
while nbytes > len(ret):
ret += self.recv(nbytes - len(ret))
except OSError:
- m = "client only sent {} of {} expected bytes".format(len(ret), nbytes)
+ t = "client only sent {} of {} expected bytes".format(len(ret), nbytes)
if len(ret) <= 16:
- m += "; got {!r}".format(ret)
+ t += "; got {!r}".format(ret)
if raise_on_trunc:
- raise UnrecvEOF(5, m)
+ raise UnrecvEOF(5, t)
elif self.log:
- self.log(m, 3)
+ self.log(t, 3)
return ret
- def unrecv(self, buf):
+ def unrecv(self, buf: bytes) -> None:
self.buf = buf + self.buf
@@ -248,28 +277,28 @@ class _LUnrecv(object):
with expensive debug logging
"""
- def __init__(self, s, log):
+ def __init__(self, s: socket.socket, log: Optional[NamedLogger]) -> None:
self.s = s
self.log = log
self.buf = b""
- def recv(self, nbytes):
+ def recv(self, nbytes: int) -> bytes:
if self.buf:
ret = self.buf[:nbytes]
self.buf = self.buf[nbytes:]
- m = "\033[0;7mur:pop:\033[0;1;32m {}\n\033[0;7mur:rem:\033[0;1;35m {}\033[0m"
- self.log(m.format(ret, self.buf))
+ t = "\033[0;7mur:pop:\033[0;1;32m {}\n\033[0;7mur:rem:\033[0;1;35m {}\033[0m"
+ print(t.format(ret, self.buf))
return ret
ret = self.s.recv(nbytes)
- m = "\033[0;7mur:recv\033[0;1;33m {}\033[0m"
- self.log(m.format(ret))
+ t = "\033[0;7mur:recv\033[0;1;33m {}\033[0m"
+ print(t.format(ret))
if not ret:
raise UnrecvEOF("client stopped sending data")
return ret
- def recv_ex(self, nbytes, raise_on_trunc=True):
+ def recv_ex(self, nbytes: int, raise_on_trunc: bool = True) -> bytes:
"""read an exact number of bytes"""
try:
ret = self.recv(nbytes)
@@ -285,18 +314,18 @@ class _LUnrecv(object):
err = True
if err:
- m = "client only sent {} of {} expected bytes".format(len(ret), nbytes)
+ t = "client only sent {} of {} expected bytes".format(len(ret), nbytes)
if raise_on_trunc:
- raise UnrecvEOF(m)
+ raise UnrecvEOF(t)
elif self.log:
- self.log(m, 3)
+ self.log(t, 3)
return ret
- def unrecv(self, buf):
+ def unrecv(self, buf: bytes) -> None:
self.buf = buf + self.buf
- m = "\033[0;7mur:push\033[0;1;31m {}\n\033[0;7mur:rem:\033[0;1;35m {}\033[0m"
- self.log(m.format(buf, self.buf))
+ t = "\033[0;7mur:push\033[0;1;31m {}\n\033[0;7mur:rem:\033[0;1;35m {}\033[0m"
+ print(t.format(buf, self.buf))
Unrecv = _Unrecv
@@ -304,14 +333,14 @@ Unrecv = _Unrecv
class FHC(object):
class CE(object):
- def __init__(self, fh):
- self.ts = 0
+ def __init__(self, fh: typing.BinaryIO) -> None:
+ self.ts: float = 0
self.fhs = [fh]
- def __init__(self):
- self.cache = {}
+ def __init__(self) -> None:
+ self.cache: dict[str, FHC.CE] = {}
- def close(self, path):
+ def close(self, path: str) -> None:
try:
ce = self.cache[path]
except:
@@ -322,7 +351,7 @@ class FHC(object):
del self.cache[path]
- def clean(self):
+ def clean(self) -> None:
if not self.cache:
return
@@ -337,10 +366,10 @@ class FHC(object):
self.cache = keep
- def pop(self, path):
+ def pop(self, path: str) -> typing.BinaryIO:
return self.cache[path].fhs.pop()
- def put(self, path, fh):
+ def put(self, path: str, fh: typing.BinaryIO) -> None:
try:
ce = self.cache[path]
ce.fhs.append(fh)
@@ -356,14 +385,15 @@ class ProgressPrinter(threading.Thread):
periodically print progress info without linefeeds
"""
- def __init__(self):
+ def __init__(self) -> None:
threading.Thread.__init__(self, name="pp")
self.daemon = True
- self.msg = None
+ self.msg = ""
self.end = False
+ self.n = -1
self.start()
- def run(self):
+ def run(self) -> None:
msg = None
fmt = " {}\033[K\r" if VT100 else " {} $\r"
while not self.end:
@@ -384,7 +414,7 @@ class ProgressPrinter(threading.Thread):
sys.stdout.flush() # necessary on win10 even w/ stderr btw
-def uprint(msg):
+def uprint(msg: str) -> None:
try:
print(msg, end="")
except UnicodeEncodeError:
@@ -394,17 +424,17 @@ def uprint(msg):
print(msg.encode("ascii", "replace").decode(), end="")
-def nuprint(msg):
+def nuprint(msg: str) -> None:
uprint("{}\n".format(msg))
-def rice_tid():
+def rice_tid() -> str:
tid = threading.current_thread().ident
c = sunpack(b"B" * 5, spack(b">Q", tid)[-5:])
return "".join("\033[1;37;48;5;{0}m{0:02x}".format(x) for x in c) + "\033[0m"
-def trace(*args, **kwargs):
+def trace(*args: Any, **kwargs: Any) -> None:
t = time.time()
stack = "".join(
"\033[36m{}\033[33m{}".format(x[0].split(os.sep)[-1][:-3], x[1])
@@ -423,15 +453,15 @@ def trace(*args, **kwargs):
nuprint(msg)
-def alltrace():
- threads = {}
+def alltrace() -> str:
+ threads: dict[str, types.FrameType] = {}
names = dict([(t.ident, t.name) for t in threading.enumerate()])
for tid, stack in sys._current_frames().items():
name = "{} ({:x})".format(names.get(tid), tid)
threads[name] = stack
- rret = []
- bret = []
+ rret: list[str] = []
+ bret: list[str] = []
for name, stack in sorted(threads.items()):
ret = ["\n\n# {}".format(name)]
pad = None
@@ -451,20 +481,20 @@ def alltrace():
return "\n".join(rret + bret)
-def start_stackmon(arg_str, nid):
+def start_stackmon(arg_str: str, nid: int) -> None:
suffix = "-{}".format(nid) if nid else ""
fp, f = arg_str.rsplit(",", 1)
- f = int(f)
+ zi = int(f)
t = threading.Thread(
target=stackmon,
- args=(fp, f, suffix),
+ args=(fp, zi, suffix),
name="stackmon" + suffix,
)
t.daemon = True
t.start()
-def stackmon(fp, ival, suffix):
+def stackmon(fp: str, ival: float, suffix: str) -> None:
ctr = 0
while True:
ctr += 1
@@ -474,7 +504,9 @@ def stackmon(fp, ival, suffix):
f.write(st.encode("utf-8", "replace"))
-def start_log_thrs(logger, ival, nid):
+def start_log_thrs(
+ logger: Callable[[str, str, int], None], ival: float, nid: int
+) -> None:
ival = float(ival)
tname = lname = "log-thrs"
if nid:
@@ -490,7 +522,7 @@ def start_log_thrs(logger, ival, nid):
t.start()
-def log_thrs(log, ival, name):
+def log_thrs(log: Callable[[str, str, int], None], ival: float, name: str) -> None:
while True:
time.sleep(ival)
tv = [x.name for x in threading.enumerate()]
@@ -507,7 +539,7 @@ def log_thrs(log, ival, name):
log(name, "\033[0m \033[33m".join(tv), 3)
-def vol_san(vols, txt):
+def vol_san(vols: list["VFS"], txt: bytes) -> bytes:
for vol in vols:
txt = txt.replace(vol.realpath.encode("utf-8"), vol.vpath.encode("utf-8"))
txt = txt.replace(
@@ -518,24 +550,26 @@ def vol_san(vols, txt):
return txt
-def min_ex(max_lines=8, reverse=False):
+def min_ex(max_lines: int = 8, reverse: bool = False) -> str:
et, ev, tb = sys.exc_info()
- tb = traceback.extract_tb(tb)
+ stb = traceback.extract_tb(tb)
fmt = "{} @ {} <{}>: {}"
- ex = [fmt.format(fp.split(os.sep)[-1], ln, fun, txt) for fp, ln, fun, txt in tb]
- ex.append("[{}] {}".format(et.__name__, ev))
+ ex = [fmt.format(fp.split(os.sep)[-1], ln, fun, txt) for fp, ln, fun, txt in stb]
+ ex.append("[{}] {}".format(et.__name__ if et else "(anonymous)", ev))
return "\n".join(ex[-max_lines:][:: -1 if reverse else 1])
@contextlib.contextmanager
-def ren_open(fname, *args, **kwargs):
+def ren_open(
+ fname: str, *args: Any, **kwargs: Any
+) -> Generator[dict[str, tuple[typing.IO[Any], str]], None, None]:
fun = kwargs.pop("fun", open)
fdir = kwargs.pop("fdir", None)
suffix = kwargs.pop("suffix", None)
if fname == os.devnull:
with fun(fname, *args, **kwargs) as f:
- yield {"orz": [f, fname]}
+ yield {"orz": (f, fname)}
return
if suffix:
@@ -575,7 +609,7 @@ def ren_open(fname, *args, **kwargs):
with open(fsenc(fp2), "wb") as f2:
f2.write(orig_name.encode("utf-8"))
- yield {"orz": [f, fname]}
+ yield {"orz": (f, fname)}
return
except OSError as ex_:
@@ -584,9 +618,9 @@ def ren_open(fname, *args, **kwargs):
raise
if not b64:
- b64 = (bname + ext).encode("utf-8", "replace")
- b64 = hashlib.sha512(b64).digest()[:12]
- b64 = base64.urlsafe_b64encode(b64).decode("utf-8")
+ zs = (bname + ext).encode("utf-8", "replace")
+ zs = hashlib.sha512(zs).digest()[:12]
+ b64 = base64.urlsafe_b64encode(zs).decode("utf-8")
badlen = len(fname)
while len(fname) >= badlen:
@@ -608,8 +642,8 @@ def ren_open(fname, *args, **kwargs):
class MultipartParser(object):
- def __init__(self, log_func, sr, http_headers):
- self.sr = sr # type: Unrecv
+ def __init__(self, log_func: NamedLogger, sr: Unrecv, http_headers: dict[str, str]):
+ self.sr = sr
self.log = log_func
self.headers = http_headers
@@ -622,10 +656,14 @@ class MultipartParser(object):
r'^content-disposition:(?: *|.*; *)filename="(.*)"', re.IGNORECASE
)
- self.boundary = None
- self.gen = None
+ self.boundary = b""
+ self.gen: Optional[
+ Generator[
+ tuple[str, Optional[str], Generator[bytes, None, None]], None, None
+ ]
+ ] = None
- def _read_header(self):
+ def _read_header(self) -> tuple[str, Optional[str]]:
"""
returns [fieldname, filename] after eating a block of multipart headers
while doing a decent job at dealing with the absolute mess that is
@@ -641,7 +679,8 @@ class MultipartParser(object):
# rfc-7578 overrides rfc-2388 so this is not-impl
# (opera >=9 <11.10 is the only thing i've ever seen use it)
raise Pebkac(
- "you can't use that browser to upload multiple files at once"
+ 400,
+ "you can't use that browser to upload multiple files at once",
)
continue
@@ -655,12 +694,12 @@ class MultipartParser(object):
raise Pebkac(400, "not form-data: {}".format(ln))
try:
- field = self.re_cdisp_field.match(ln).group(1)
+ field = self.re_cdisp_field.match(ln).group(1) # type: ignore
except:
raise Pebkac(400, "missing field name: {}".format(ln))
try:
- fn = self.re_cdisp_file.match(ln).group(1)
+ fn = self.re_cdisp_file.match(ln).group(1) # type: ignore
except:
# this is not a file upload, we're done
return field, None
@@ -687,11 +726,10 @@ class MultipartParser(object):
esc = False
for ch in fn:
if esc:
- if ch in ['"', "\\"]:
- ret += '"'
- else:
- ret += esc + ch
esc = False
+ if ch not in ['"', "\\"]:
+ ret += "\\"
+ ret += ch
elif ch == "\\":
esc = True
elif ch == '"':
@@ -699,9 +737,11 @@ class MultipartParser(object):
else:
ret += ch
- return [field, ret]
+ return field, ret
- def _read_data(self):
+ raise Pebkac(400, "server expected a multipart header but you never sent one")
+
+ def _read_data(self) -> Generator[bytes, None, None]:
blen = len(self.boundary)
bufsz = 32 * 1024
while True:
@@ -748,7 +788,9 @@ class MultipartParser(object):
yield buf
- def _run_gen(self):
+ def _run_gen(
+ self,
+ ) -> Generator[tuple[str, Optional[str], Generator[bytes, None, None]], None, None]:
"""
yields [fieldname, unsanitized_filename, fieldvalue]
where fieldvalue yields chunks of data
@@ -756,7 +798,7 @@ class MultipartParser(object):
run = True
while run:
fieldname, filename = self._read_header()
- yield [fieldname, filename, self._read_data()]
+ yield (fieldname, filename, self._read_data())
tail = self.sr.recv_ex(2, False)
@@ -766,19 +808,19 @@ class MultipartParser(object):
run = False
if tail != b"\r\n":
- m = "protocol error after field value: want b'\\r\\n', got {!r}"
- raise Pebkac(400, m.format(tail))
+ t = "protocol error after field value: want b'\\r\\n', got {!r}"
+ raise Pebkac(400, t.format(tail))
- def _read_value(self, iterator, max_len):
+ def _read_value(self, iterable: Iterable[bytes], max_len: int) -> bytes:
ret = b""
- for buf in iterator:
+ for buf in iterable:
ret += buf
if len(ret) > max_len:
raise Pebkac(400, "field length is too long")
return ret
- def parse(self):
+ def parse(self) -> None:
# spec says there might be junk before the first boundary,
# can't have the leading \r\n if that's not the case
self.boundary = b"--" + get_boundary(self.headers).encode("utf-8")
@@ -793,11 +835,12 @@ class MultipartParser(object):
self.boundary = b"\r\n" + self.boundary
self.gen = self._run_gen()
- def require(self, field_name, max_len):
+ def require(self, field_name: str, max_len: int) -> str:
"""
returns the value of the next field in the multipart body,
raises if the field name is not as expected
"""
+ assert self.gen
p_field, _, p_data = next(self.gen)
if p_field != field_name:
raise Pebkac(
@@ -806,14 +849,15 @@ class MultipartParser(object):
return self._read_value(p_data, max_len).decode("utf-8", "surrogateescape")
- def drop(self):
+ def drop(self) -> None:
"""discards the remaining multipart body"""
+ assert self.gen
for _, _, data in self.gen:
for _ in data:
pass
-def get_boundary(headers):
+def get_boundary(headers: dict[str, str]) -> str:
# boundaries contain a-z A-Z 0-9 ' ( ) + _ , - . / : = ?
# (whitespace allowed except as the last char)
ptn = r"^multipart/form-data *; *(.*; *)?boundary=([^;]+)"
@@ -825,14 +869,14 @@ def get_boundary(headers):
return m.group(2)
-def read_header(sr):
+def read_header(sr: Unrecv) -> list[str]:
ret = b""
while True:
try:
ret += sr.recv(1024)
except:
if not ret:
- return None
+ return []
raise Pebkac(
400,
@@ -853,7 +897,7 @@ def read_header(sr):
return ret[:ofs].decode("utf-8", "surrogateescape").lstrip("\r\n").split("\r\n")
-def gen_filekey(salt, fspath, fsize, inode):
+def gen_filekey(salt: str, fspath: str, fsize: int, inode: int) -> str:
return base64.urlsafe_b64encode(
hashlib.sha512(
"{} {} {} {}".format(salt, fspath, fsize, inode).encode("utf-8", "replace")
@@ -861,7 +905,7 @@ def gen_filekey(salt, fspath, fsize, inode):
).decode("ascii")
-def gencookie(k, v, dur):
+def gencookie(k: str, v: str, dur: Optional[int]) -> str:
v = v.replace(";", "")
if dur:
dt = datetime.utcfromtimestamp(time.time() + dur)
@@ -872,7 +916,7 @@ def gencookie(k, v, dur):
return "{}={}; Path=/; Expires={}; SameSite=Lax".format(k, v, exp)
-def humansize(sz, terse=False):
+def humansize(sz: float, terse: bool = False) -> str:
for unit in ["B", "KiB", "MiB", "GiB", "TiB"]:
if sz < 1024:
break
@@ -887,18 +931,18 @@ def humansize(sz, terse=False):
return ret.replace("iB", "").replace(" ", "")
-def unhumanize(sz):
+def unhumanize(sz: str) -> int:
try:
- return float(sz)
+ return int(sz)
except:
pass
- mul = sz[-1:].lower()
- mul = {"k": 1024, "m": 1024 * 1024, "g": 1024 * 1024 * 1024}.get(mul, 1)
- return float(sz[:-1]) * mul
+ mc = sz[-1:].lower()
+ mi = {"k": 1024, "m": 1024 * 1024, "g": 1024 * 1024 * 1024}.get(mc, 1)
+ return int(float(sz[:-1]) * mi)
-def get_spd(nbyte, t0, t=None):
+def get_spd(nbyte: int, t0: float, t: Optional[float] = None) -> str:
if t is None:
t = time.time()
@@ -908,7 +952,7 @@ def get_spd(nbyte, t0, t=None):
return "{} \033[0m{}/s\033[0m".format(s1, s2)
-def s2hms(s, optional_h=False):
+def s2hms(s: float, optional_h: bool = False) -> str:
s = int(s)
h, s = divmod(s, 3600)
m, s = divmod(s, 60)
@@ -918,7 +962,7 @@ def s2hms(s, optional_h=False):
return "{}:{:02}:{:02}".format(h, m, s)
-def uncyg(path):
+def uncyg(path: str) -> str:
if len(path) < 2 or not path.startswith("/"):
return path
@@ -928,8 +972,8 @@ def uncyg(path):
return "{}:\\{}".format(path[1], path[3:])
-def undot(path):
- ret = []
+def undot(path: str) -> str:
+ ret: list[str] = []
for node in path.split("/"):
if node in ["", "."]:
continue
@@ -944,7 +988,7 @@ def undot(path):
return "/".join(ret)
-def sanitize_fn(fn, ok, bad):
+def sanitize_fn(fn: str, ok: str, bad: list[str]) -> str:
if "/" not in ok:
fn = fn.replace("\\", "/").split("/")[-1]
@@ -976,7 +1020,7 @@ def sanitize_fn(fn, ok, bad):
return fn.strip()
-def relchk(rp):
+def relchk(rp: str) -> str:
if ANYWIN:
if "\n" in rp or "\r" in rp:
return "x\nx"
@@ -985,8 +1029,10 @@ def relchk(rp):
if p != rp:
return "[{}]".format(p)
+ return ""
-def absreal(fpath):
+
+def absreal(fpath: str) -> str:
try:
return fsdec(os.path.abspath(os.path.realpath(fsenc(fpath))))
except:
@@ -999,26 +1045,26 @@ def absreal(fpath):
return os.path.abspath(os.path.realpath(fpath))
-def u8safe(txt):
+def u8safe(txt: str) -> str:
try:
return txt.encode("utf-8", "xmlcharrefreplace").decode("utf-8", "replace")
except:
return txt.encode("utf-8", "replace").decode("utf-8", "replace")
-def exclude_dotfiles(filepaths):
+def exclude_dotfiles(filepaths: list[str]) -> list[str]:
return [x for x in filepaths if not x.split("/")[-1].startswith(".")]
-def http_ts(ts):
+def http_ts(ts: int) -> str:
file_dt = datetime.utcfromtimestamp(ts)
return file_dt.strftime(HTTP_TS_FMT)
-def html_escape(s, quote=False, crlf=False):
+def html_escape(s: str, quot: bool = False, crlf: bool = False) -> str:
"""html.escape but also newlines"""
s = s.replace("&", "&").replace("<", "<").replace(">", ">")
- if quote:
+ if quot:
s = s.replace('"', """).replace("'", "'")
if crlf:
s = s.replace("\r", "
").replace("\n", "
")
@@ -1026,10 +1072,10 @@ def html_escape(s, quote=False, crlf=False):
return s
-def html_bescape(s, quote=False, crlf=False):
+def html_bescape(s: bytes, quot: bool = False, crlf: bool = False) -> bytes:
"""html.escape but bytestrings"""
s = s.replace(b"&", b"&").replace(b"<", b"<").replace(b">", b">")
- if quote:
+ if quot:
s = s.replace(b'"', b""").replace(b"'", b"'")
if crlf:
s = s.replace(b"\r", b"
").replace(b"\n", b"
")
@@ -1037,18 +1083,20 @@ def html_bescape(s, quote=False, crlf=False):
return s
-def quotep(txt):
+def quotep(txt: str) -> str:
"""url quoter which deals with bytes correctly"""
btxt = w8enc(txt)
quot1 = quote(btxt, safe=b"/")
if not PY2:
- quot1 = quot1.encode("ascii")
+ quot2 = quot1.encode("ascii")
+ else:
+ quot2 = quot1
- quot2 = quot1.replace(b" ", b"+")
- return w8dec(quot2)
+ quot3 = quot2.replace(b" ", b"+")
+ return w8dec(quot3)
-def unquotep(txt):
+def unquotep(txt: str) -> str:
"""url unquoter which deals with bytes correctly"""
btxt = w8enc(txt)
# btxt = btxt.replace(b"+", b" ")
@@ -1056,14 +1104,14 @@ def unquotep(txt):
return w8dec(unq2)
-def vsplit(vpath):
+def vsplit(vpath: str) -> tuple[str, str]:
if "/" not in vpath:
return "", vpath
- return vpath.rsplit("/", 1)
+ return vpath.rsplit("/", 1) # type: ignore
-def w8dec(txt):
+def w8dec(txt: bytes) -> str:
"""decodes filesystem-bytes to wtf8"""
if PY2:
return surrogateescape.decodefilename(txt)
@@ -1071,7 +1119,7 @@ def w8dec(txt):
return txt.decode(FS_ENCODING, "surrogateescape")
-def w8enc(txt):
+def w8enc(txt: str) -> bytes:
"""encodes wtf8 to filesystem-bytes"""
if PY2:
return surrogateescape.encodefilename(txt)
@@ -1079,12 +1127,12 @@ def w8enc(txt):
return txt.encode(FS_ENCODING, "surrogateescape")
-def w8b64dec(txt):
+def w8b64dec(txt: str) -> str:
"""decodes base64(filesystem-bytes) to wtf8"""
return w8dec(base64.urlsafe_b64decode(txt.encode("ascii")))
-def w8b64enc(txt):
+def w8b64enc(txt: str) -> str:
"""encodes wtf8 to base64(filesystem-bytes)"""
return base64.urlsafe_b64encode(w8enc(txt)).decode("ascii")
@@ -1102,8 +1150,8 @@ else:
fsdec = w8dec
-def s3enc(mem_cur, rd, fn):
- ret = []
+def s3enc(mem_cur: "sqlite3.Cursor", rd: str, fn: str) -> tuple[str, str]:
+ ret: list[str] = []
for v in [rd, fn]:
try:
mem_cur.execute("select * from a where b = ?", (v,))
@@ -1112,10 +1160,10 @@ def s3enc(mem_cur, rd, fn):
ret.append("//" + w8b64enc(v))
# self.log("mojien [{}] {}".format(v, ret[-1][2:]))
- return tuple(ret)
+ return ret[0], ret[1]
-def s3dec(rd, fn):
+def s3dec(rd: str, fn: str) -> tuple[str, str]:
ret = []
for v in [rd, fn]:
if v.startswith("//"):
@@ -1124,12 +1172,12 @@ def s3dec(rd, fn):
else:
ret.append(v)
- return tuple(ret)
+ return ret[0], ret[1]
-def atomic_move(src, dst):
- src = fsenc(src)
- dst = fsenc(dst)
+def atomic_move(usrc: str, udst: str) -> None:
+ src = fsenc(usrc)
+ dst = fsenc(udst)
if not PY2:
os.replace(src, dst)
else:
@@ -1139,7 +1187,7 @@ def atomic_move(src, dst):
os.rename(src, dst)
-def read_socket(sr, total_size):
+def read_socket(sr: Unrecv, total_size: int) -> Generator[bytes, None, None]:
remains = total_size
while remains > 0:
bufsz = 32 * 1024
@@ -1149,14 +1197,14 @@ def read_socket(sr, total_size):
try:
buf = sr.recv(bufsz)
except OSError:
- m = "client d/c during binary post after {} bytes, {} bytes remaining"
- raise Pebkac(400, m.format(total_size - remains, remains))
+ t = "client d/c during binary post after {} bytes, {} bytes remaining"
+ raise Pebkac(400, t.format(total_size - remains, remains))
remains -= len(buf)
yield buf
-def read_socket_unbounded(sr):
+def read_socket_unbounded(sr: Unrecv) -> Generator[bytes, None, None]:
try:
while True:
yield sr.recv(32 * 1024)
@@ -1164,7 +1212,9 @@ def read_socket_unbounded(sr):
return
-def read_socket_chunked(sr, log=None):
+def read_socket_chunked(
+ sr: Unrecv, log: Optional[NamedLogger] = None
+) -> Generator[bytes, None, None]:
err = "upload aborted: expected chunk length, got [{}] |{}| instead"
while True:
buf = b""
@@ -1191,8 +1241,8 @@ def read_socket_chunked(sr, log=None):
if x == b"\r\n":
return
- m = "protocol error after final chunk: want b'\\r\\n', got {!r}"
- raise Pebkac(400, m.format(x))
+ t = "protocol error after final chunk: want b'\\r\\n', got {!r}"
+ raise Pebkac(400, t.format(x))
if log:
log("receiving {} byte chunk".format(chunklen))
@@ -1202,11 +1252,11 @@ def read_socket_chunked(sr, log=None):
x = sr.recv_ex(2, False)
if x != b"\r\n":
- m = "protocol error in chunk separator: want b'\\r\\n', got {!r}"
- raise Pebkac(400, m.format(x))
+ t = "protocol error in chunk separator: want b'\\r\\n', got {!r}"
+ raise Pebkac(400, t.format(x))
-def yieldfile(fn):
+def yieldfile(fn: str) -> Generator[bytes, None, None]:
with open(fsenc(fn), "rb", 512 * 1024) as f:
while True:
buf = f.read(64 * 1024)
@@ -1216,7 +1266,11 @@ def yieldfile(fn):
yield buf
-def hashcopy(fin, fout, slp=0):
+def hashcopy(
+ fin: Union[typing.BinaryIO, Generator[bytes, None, None]],
+ fout: Union[typing.BinaryIO, typing.IO[Any]],
+ slp: int = 0,
+) -> tuple[int, str, str]:
hashobj = hashlib.sha512()
tlen = 0
for buf in fin:
@@ -1232,7 +1286,15 @@ def hashcopy(fin, fout, slp=0):
return tlen, hashobj.hexdigest(), digest_b64
-def sendfile_py(log, lower, upper, f, s, bufsz, slp):
+def sendfile_py(
+ log: NamedLogger,
+ lower: int,
+ upper: int,
+ f: typing.BinaryIO,
+ s: socket.socket,
+ bufsz: int,
+ slp: int,
+) -> int:
remains = upper - lower
f.seek(lower)
while remains > 0:
@@ -1252,25 +1314,37 @@ def sendfile_py(log, lower, upper, f, s, bufsz, slp):
return 0
-def sendfile_kern(log, lower, upper, f, s, bufsz, slp):
+def sendfile_kern(
+ log: NamedLogger,
+ lower: int,
+ upper: int,
+ f: typing.BinaryIO,
+ s: socket.socket,
+ bufsz: int,
+ slp: int,
+) -> int:
out_fd = s.fileno()
in_fd = f.fileno()
ofs = lower
- stuck = None
+ stuck = 0.0
while ofs < upper:
stuck = stuck or time.time()
try:
req = min(2 ** 30, upper - ofs)
select.select([], [out_fd], [], 10)
n = os.sendfile(out_fd, in_fd, ofs, req)
- stuck = None
- except Exception as ex:
+ stuck = 0
+ except OSError as ex:
d = time.time() - stuck
log("sendfile stuck for {:.3f} sec: {!r}".format(d, ex))
if d < 3600 and ex.errno == 11: # eagain
continue
n = 0
+ except Exception as ex:
+ n = 0
+ d = time.time() - stuck
+ log("sendfile failed after {:.3f} sec: {!r}".format(d, ex))
if n <= 0:
return upper - ofs
@@ -1281,7 +1355,9 @@ def sendfile_kern(log, lower, upper, f, s, bufsz, slp):
return 0
-def statdir(logger, scandir, lstat, top):
+def statdir(
+ logger: Optional[RootLogger], scandir: bool, lstat: bool, top: str
+) -> Generator[tuple[str, os.stat_result], None, None]:
if lstat and ANYWIN:
lstat = False
@@ -1295,30 +1371,42 @@ def statdir(logger, scandir, lstat, top):
with os.scandir(btop) as dh:
for fh in dh:
try:
- yield [fsdec(fh.name), fh.stat(follow_symlinks=not lstat)]
+ yield (fsdec(fh.name), fh.stat(follow_symlinks=not lstat))
except Exception as ex:
+ if not logger:
+ continue
+
logger(src, "[s] {} @ {}".format(repr(ex), fsdec(fh.path)), 6)
else:
src = "listdir"
- fun = os.lstat if lstat else os.stat
+ fun: Any = os.lstat if lstat else os.stat
for name in os.listdir(btop):
abspath = os.path.join(btop, name)
try:
- yield [fsdec(name), fun(abspath)]
+ yield (fsdec(name), fun(abspath))
except Exception as ex:
+ if not logger:
+ continue
+
logger(src, "[s] {} @ {}".format(repr(ex), fsdec(abspath)), 6)
except Exception as ex:
- logger(src, "{} @ {}".format(repr(ex), top), 1)
+ t = "{} @ {}".format(repr(ex), top)
+ if logger:
+ logger(src, t, 1)
+ else:
+ print(t)
-def rmdirs(logger, scandir, lstat, top, depth):
+def rmdirs(
+ logger: RootLogger, scandir: bool, lstat: bool, top: str, depth: int
+) -> tuple[list[str], list[str]]:
if not os.path.exists(fsenc(top)) or not os.path.isdir(fsenc(top)):
top = os.path.dirname(top)
depth -= 1
- dirs = statdir(logger, scandir, lstat, top)
- dirs = [x[0] for x in dirs if stat.S_ISDIR(x[1].st_mode)]
+ stats = statdir(logger, scandir, lstat, top)
+ dirs = [x[0] for x in stats if stat.S_ISDIR(x[1].st_mode)]
dirs = [os.path.join(top, x) for x in dirs]
ok = []
ng = []
@@ -1337,7 +1425,7 @@ def rmdirs(logger, scandir, lstat, top, depth):
return ok, ng
-def unescape_cookie(orig):
+def unescape_cookie(orig: str) -> str:
# mw=idk; doot=qwe%2Crty%3Basd+fgh%2Bjkl%25zxc%26vbn # qwe,rty;asd fgh+jkl%zxc&vbn
ret = ""
esc = ""
@@ -1365,7 +1453,7 @@ def unescape_cookie(orig):
return ret
-def guess_mime(url, fallback="application/octet-stream"):
+def guess_mime(url: str, fallback: str = "application/octet-stream") -> str:
try:
_, ext = url.rsplit(".", 1)
except:
@@ -1387,7 +1475,9 @@ def guess_mime(url, fallback="application/octet-stream"):
return ret
-def runcmd(argv, timeout=None, **ka):
+def runcmd(
+ argv: Union[list[bytes], list[str]], timeout: Optional[int] = None, **ka: Any
+) -> tuple[int, str, str]:
p = sp.Popen(argv, stdout=sp.PIPE, stderr=sp.PIPE, **ka)
if not timeout or PY2:
stdout, stderr = p.communicate()
@@ -1400,10 +1490,10 @@ def runcmd(argv, timeout=None, **ka):
stdout = stdout.decode("utf-8", "replace")
stderr = stderr.decode("utf-8", "replace")
- return [p.returncode, stdout, stderr]
+ return p.returncode, stdout, stderr
-def chkcmd(argv, **ka):
+def chkcmd(argv: Union[list[bytes], list[str]], **ka: Any) -> tuple[str, str]:
ok, sout, serr = runcmd(argv, **ka)
if ok != 0:
retchk(ok, argv, serr)
@@ -1412,7 +1502,7 @@ def chkcmd(argv, **ka):
return sout, serr
-def mchkcmd(argv, timeout=10):
+def mchkcmd(argv: Union[list[bytes], list[str]], timeout: int = 10) -> None:
if PY2:
with open(os.devnull, "wb") as f:
rv = sp.call(argv, stdout=f, stderr=f)
@@ -1423,7 +1513,14 @@ def mchkcmd(argv, timeout=10):
raise sp.CalledProcessError(rv, (argv[0], b"...", argv[-1]))
-def retchk(rc, cmd, serr, logger=None, color=None, verbose=False):
+def retchk(
+ rc: int,
+ cmd: Union[list[bytes], list[str]],
+ serr: str,
+ logger: Optional[NamedLogger] = None,
+ color: Union[int, str] = 0,
+ verbose: bool = False,
+) -> None:
if rc < 0:
rc = 128 - rc
@@ -1446,33 +1543,33 @@ def retchk(rc, cmd, serr, logger=None, color=None, verbose=False):
s = "invalid retcode"
if s:
- m = "{} <{}>".format(rc, s)
+ t = "{} <{}>".format(rc, s)
else:
- m = str(rc)
+ t = str(rc)
try:
- c = " ".join([fsdec(x) for x in cmd])
+ c = " ".join([fsdec(x) for x in cmd]) # type: ignore
except:
c = str(cmd)
- m = "error {} from [{}]".format(m, c)
+ t = "error {} from [{}]".format(t, c)
if serr:
- m += "\n" + serr
+ t += "\n" + serr
if logger:
- logger(m, color)
+ logger(t, color)
else:
- raise Exception(m)
+ raise Exception(t)
-def gzip_orig_sz(fn):
+def gzip_orig_sz(fn: str) -> int:
with open(fsenc(fn), "rb") as f:
f.seek(-4, 2)
rv = f.read(4)
- return sunpack(b"I", rv)[0]
+ return sunpack(b"I", rv)[0] # type: ignore
-def py_desc():
+def py_desc() -> str:
interp = platform.python_implementation()
py_ver = ".".join([str(x) for x in sys.version_info])
ofs = py_ver.find(".final.")
@@ -1487,15 +1584,15 @@ def py_desc():
host_os = platform.system()
compiler = platform.python_compiler()
- os_ver = re.search(r"([0-9]+\.[0-9\.]+)", platform.version())
- os_ver = os_ver.group(1) if os_ver else ""
+ m = re.search(r"([0-9]+\.[0-9\.]+)", platform.version())
+ os_ver = m.group(1) if m else ""
return "{:>9} v{} on {}{} {} [{}]".format(
interp, py_ver, host_os, bitness, os_ver, compiler
)
-def align_tab(lines):
+def align_tab(lines: list[str]) -> list[str]:
rows = []
ncols = 0
for ln in lines:
@@ -1512,9 +1609,9 @@ def align_tab(lines):
class Pebkac(Exception):
- def __init__(self, code, msg=None):
+ def __init__(self, code: int, msg: Optional[str] = None) -> None:
super(Pebkac, self).__init__(msg or HTTPCODE[code])
self.code = code
- def __repr__(self):
+ def __repr__(self) -> str:
return "Pebkac({}, {})".format(self.code, repr(self.args))
diff --git a/scripts/make-pypi-release.sh b/scripts/make-pypi-release.sh
index 6a568051..7e723d4b 100755
--- a/scripts/make-pypi-release.sh
+++ b/scripts/make-pypi-release.sh
@@ -90,6 +90,15 @@ function have() {
have setuptools
have wheel
have twine
+
+# remove type hints to support python < 3.9
+rm -rf build/pypi
+mkdir -p build/pypi
+cp -pR setup.py README.md LICENSE copyparty tests bin scripts/strip_hints build/pypi/
+cd build/pypi
+tar --strip-components=2 -xf ../strip-hints-0.1.10.tar.gz strip-hints-0.1.10/src/strip_hints
+python3 -c 'from strip_hints.a import uh; uh("copyparty")'
+
./setup.py clean2
./setup.py sdist bdist_wheel --universal
diff --git a/scripts/make-sfx.sh b/scripts/make-sfx.sh
index 8f9856f3..7402c517 100755
--- a/scripts/make-sfx.sh
+++ b/scripts/make-sfx.sh
@@ -76,7 +76,7 @@ while [ ! -z "$1" ]; do
no-hl) no_hl=1 ; ;;
no-dd) no_dd=1 ; ;;
no-cm) no_cm=1 ; ;;
- fast) zopf=100 ; ;;
+ fast) zopf= ; ;;
lang) shift;langs="$1"; ;;
*) help ; ;;
esac
@@ -106,7 +106,7 @@ tmpdir="$(
[ $repack ] && {
old="$tmpdir/pe-copyparty"
echo "repack of files in $old"
- cp -pR "$old/"*{dep-j2,dep-ftp,copyparty} .
+ cp -pR "$old/"*{j2,ftp,copyparty} .
}
[ $repack ] || {
@@ -130,8 +130,8 @@ tmpdir="$(
mv MarkupSafe-*/src/markupsafe .
rm -rf MarkupSafe-* markupsafe/_speedups.c
- mkdir dep-j2/
- mv {markupsafe,jinja2} dep-j2/
+ mkdir j2/
+ mv {markupsafe,jinja2} j2/
echo collecting pyftpdlib
f="../build/pyftpdlib-1.5.6.tar.gz"
@@ -143,8 +143,8 @@ tmpdir="$(
mv pyftpdlib-release-*/pyftpdlib .
rm -rf pyftpdlib-release-* pyftpdlib/test
- mkdir dep-ftp/
- mv pyftpdlib dep-ftp/
+ mkdir ftp/
+ mv pyftpdlib ftp/
echo collecting asyncore, asynchat
for n in asyncore.py asynchat.py; do
@@ -154,6 +154,24 @@ tmpdir="$(
wget -O$f "$url" || curl -L "$url" >$f)
done
+ # enable this to dynamically remove type hints at startup,
+ # in case a future python version can use them for performance
+ true || (
+ echo collecting strip-hints
+ f=../build/strip-hints-0.1.10.tar.gz
+ [ -e $f ] ||
+ (url=https://files.pythonhosted.org/packages/9c/d4/312ddce71ee10f7e0ab762afc027e07a918f1c0e1be5b0069db5b0e7542d/strip-hints-0.1.10.tar.gz;
+ wget -O$f "$url" || curl -L "$url" >$f)
+
+ tar -zxf $f
+ mv strip-hints-0.1.10/src/strip_hints .
+ rm -rf strip-hints-* strip_hints/import_hooks*
+ sed -ri 's/[a-z].* as import_hooks$/"""a"""/' strip_hints/*.py
+
+ cp -pR ../scripts/strip_hints/ .
+ )
+ cp -pR ../scripts/py2/ .
+
# msys2 tar is bad, make the best of it
echo collecting source
[ $clean ] && {
@@ -170,6 +188,9 @@ tmpdir="$(
for n in asyncore.py asynchat.py; do
awk 'NR<4||NR>27;NR==4{print"# license: https://opensource.org/licenses/ISC\n"}' ../build/$n >copyparty/vend/$n
done
+
+ # remove type hints before build instead
+ (cd copyparty; python3 ../../scripts/strip_hints/a.py; rm uh)
}
ver=
@@ -274,17 +295,23 @@ rm have
tmv "$f"
done
-[ $repack ] ||
-find | grep -E '\.py$' |
- grep -vE '__version__' |
- tr '\n' '\0' |
- xargs -0 "$pybin" ../scripts/uncomment.py
+[ $repack ] || {
+ # uncomment
+ find | grep -E '\.py$' |
+ grep -vE '__version__' |
+ tr '\n' '\0' |
+ xargs -0 "$pybin" ../scripts/uncomment.py
-f=dep-j2/jinja2/constants.py
+ # py2-compat
+ #find | grep -E '\.py$' | while IFS= read -r x; do
+ # sed -ri '/: TypeAlias = /d' "$x"; done
+}
+
+f=j2/jinja2/constants.py
awk '/^LOREM_IPSUM_WORDS/{o=1;print "LOREM_IPSUM_WORDS = u\"a\"";next} !o; /"""/{o=0}' <$f >t
tmv "$f"
-grep -rLE '^#[^a-z]*coding: utf-8' dep-j2 |
+grep -rLE '^#[^a-z]*coding: utf-8' j2 |
while IFS= read -r f; do
(echo "# coding: utf-8"; cat "$f") >t
tmv "$f"
@@ -313,7 +340,7 @@ find | grep -E '\.(js|html)$' | while IFS= read -r f; do
done
gzres() {
- command -v pigz &&
+ command -v pigz && [ $zopf ] &&
pk="pigz -11 -I $zopf" ||
pk='gzip'
@@ -354,7 +381,8 @@ nf=$(ls -1 "$zdir"/arc.* | wc -l)
}
[ $use_zdir ] && {
arcs=("$zdir"/arc.*)
- arc="${arcs[$RANDOM % ${#arcs[@]} ] }"
+ n=$(( $RANDOM % ${#arcs[@]} ))
+ arc="${arcs[n]}"
echo "using $arc"
tar -xf "$arc"
for f in copyparty/web/*.gz; do
@@ -364,7 +392,7 @@ nf=$(ls -1 "$zdir"/arc.* | wc -l)
echo gen tarlist
-for d in copyparty dep-j2 dep-ftp; do find $d -type f; done |
+for d in copyparty j2 ftp py2; do find $d -type f; done | # strip_hints
sed -r 's/(.*)\.(.*)/\2 \1/' | LC_ALL=C sort |
sed -r 's/([^ ]*) (.*)/\2.\1/' | grep -vE '/list1?$' > list1
diff --git a/scripts/run-tests.sh b/scripts/run-tests.sh
index 48aaf075..1977a55a 100755
--- a/scripts/run-tests.sh
+++ b/scripts/run-tests.sh
@@ -1,13 +1,23 @@
#!/bin/bash
set -ex
+rm -rf unt
+mkdir -p unt/srv
+cp -pR copyparty tests unt/
+cd unt
+python3 ../scripts/strip_hints/a.py
+
pids=()
for py in python{2,3}; do
+ PYTHONPATH=
+ [ $py = python2 ] && PYTHONPATH=../scripts/py2
+ export PYTHONPATH
+
nice $py -m unittest discover -s tests >/dev/null &
pids+=($!)
done
-python3 scripts/test/smoketest.py &
+python3 ../scripts/test/smoketest.py &
pids+=($!)
for pid in ${pids[@]}; do
diff --git a/scripts/sfx.py b/scripts/sfx.py
index 4723a0da..2cb20c8a 100644
--- a/scripts/sfx.py
+++ b/scripts/sfx.py
@@ -379,9 +379,20 @@ def run(tmp, j2, ftp):
t.daemon = True
t.start()
- ld = (("", ""), (j2, "dep-j2"), (ftp, "dep-ftp"))
+ ld = (("", ""), (j2, "j2"), (ftp, "ftp"), (not PY2, "py2"))
ld = [os.path.join(tmp, b) for a, b in ld if not a]
+ # skip 1
+ # enable this to dynamically remove type hints at startup,
+ # in case a future python version can use them for performance
+ if sys.version_info < (3, 10) and False:
+ sys.path.insert(0, ld[0])
+
+ from strip_hints.a import uh
+
+ uh(tmp + "/copyparty")
+ # skip 0
+
if any([re.match(r"^-.*j[0-9]", x) for x in sys.argv]):
run_s(ld)
else:
diff --git a/scripts/sfx.sh b/scripts/sfx.sh
index 1f53c8db..1496f970 100644
--- a/scripts/sfx.sh
+++ b/scripts/sfx.sh
@@ -47,7 +47,7 @@ grep -E '/(python|pypy)[0-9\.-]*$' >$dir/pys || true
printf '\033[1;30mlooking for jinja2 in [%s]\033[0m\n' "$_py" >&2
$_py -c 'import jinja2' 2>/dev/null || continue
printf '%s\n' "$_py"
- mv $dir/{,x.}dep-j2
+ mv $dir/{,x.}j2
break
done)"
diff --git a/scripts/strip_hints/a.py b/scripts/strip_hints/a.py
new file mode 100644
index 00000000..06bf4f6b
--- /dev/null
+++ b/scripts/strip_hints/a.py
@@ -0,0 +1,57 @@
+# coding: utf-8
+from __future__ import print_function, unicode_literals
+
+import re
+import os
+import sys
+from strip_hints import strip_file_to_string
+
+
+# list unique types used in hints:
+# rm -rf unt && cp -pR copyparty unt && (cd unt && python3 ../scripts/strip_hints/a.py)
+# diff -wNarU1 copyparty unt | grep -E '^\-' | sed -r 's/[^][, ]+://g; s/[^][, ]+[[(]//g; s/[],()<>{} -]/\n/g' | grep -E .. | sort | uniq -c | sort -n
+
+
+def pr(m):
+ sys.stderr.write(m)
+ sys.stderr.flush()
+
+
+def uh(top):
+ if os.path.exists(top + "/uh"):
+ return
+
+ libs = "typing|types|collections\.abc"
+ ptn = re.compile(r"^(\s*)(from (?:{0}) import |import (?:{0})\b).*".format(libs))
+
+ # pr("building support for your python ver")
+ pr("unhinting")
+ for (dp, _, fns) in os.walk(top):
+ for fn in fns:
+ if not fn.endswith(".py"):
+ continue
+
+ pr(".")
+ fp = os.path.join(dp, fn)
+ cs = strip_file_to_string(fp, no_ast=True, to_empty=True)
+
+ # remove expensive imports too
+ lns = []
+ for ln in cs.split("\n"):
+ m = ptn.match(ln)
+ if m:
+ ln = m.group(1) + "raise Exception()"
+
+ lns.append(ln)
+
+ cs = "\n".join(lns)
+ with open(fp, "wb") as f:
+ f.write(cs.encode("utf-8"))
+
+ pr("k\n\n")
+ with open(top + "/uh", "wb") as f:
+ f.write(b"a")
+
+
+if __name__ == "__main__":
+ uh(".")
diff --git a/scripts/test/race.py b/scripts/test/race.py
index 09922e60..77ef0e46 100644
--- a/scripts/test/race.py
+++ b/scripts/test/race.py
@@ -58,13 +58,13 @@ class CState(threading.Thread):
remotes.append("?")
remotes_ok = False
- m = []
+ ta = []
for conn, remote in zip(self.cs, remotes):
stage = len(conn.st)
- m.append(f"\033[3{colors[stage]}m{remote}")
+ ta.append(f"\033[3{colors[stage]}m{remote}")
- m = " ".join(m)
- print(f"{m}\033[0m\n\033[A", end="")
+ t = " ".join(ta)
+ print(f"{t}\033[0m\n\033[A", end="")
def allget(cs, urls):
diff --git a/scripts/test/smoketest.py b/scripts/test/smoketest.py
index 8508f127..a37fc44a 100644
--- a/scripts/test/smoketest.py
+++ b/scripts/test/smoketest.py
@@ -72,6 +72,8 @@ def tc1(vflags):
for _ in range(10):
try:
os.mkdir(td)
+ if os.path.exists(td):
+ break
except:
time.sleep(0.1) # win10
diff --git a/tests/test_vfs.py b/tests/test_vfs.py
index cb7f08fe..4818d180 100644
--- a/tests/test_vfs.py
+++ b/tests/test_vfs.py
@@ -85,7 +85,7 @@ class TestVFS(unittest.TestCase):
pass
def assertAxs(self, dct, lst):
- t1 = list(sorted(dct.keys()))
+ t1 = list(sorted(dct))
t2 = list(sorted(lst))
self.assertEqual(t1, t2)
@@ -208,10 +208,10 @@ class TestVFS(unittest.TestCase):
self.assertEqual(n.realpath, os.path.join(td, "a"))
self.assertAxs(n.axs.uread, ["*"])
self.assertAxs(n.axs.uwrite, [])
- self.assertEqual(vfs.can_access("/", "*"), [False, False, False, False, False])
- self.assertEqual(vfs.can_access("/", "k"), [True, True, False, False, False])
- self.assertEqual(vfs.can_access("/a", "*"), [True, False, False, False, False])
- self.assertEqual(vfs.can_access("/a", "k"), [True, False, False, False, False])
+ self.assertEqual(vfs.can_access("/", "*"), (False, False, False, False, False))
+ self.assertEqual(vfs.can_access("/", "k"), (True, True, False, False, False))
+ self.assertEqual(vfs.can_access("/a", "*"), (True, False, False, False, False))
+ self.assertEqual(vfs.can_access("/a", "k"), (True, False, False, False, False))
# breadth-first construction
vfs = AuthSrv(
@@ -279,7 +279,7 @@ class TestVFS(unittest.TestCase):
n = au.vfs
# root was not defined, so PWD with no access to anyone
self.assertEqual(n.vpath, "")
- self.assertEqual(n.realpath, None)
+ self.assertEqual(n.realpath, "")
self.assertAxs(n.axs.uread, [])
self.assertAxs(n.axs.uwrite, [])
self.assertEqual(len(n.nodes), 1)
diff --git a/tests/util.py b/tests/util.py
index 55ec58ea..92ecaa88 100644
--- a/tests/util.py
+++ b/tests/util.py
@@ -90,7 +90,10 @@ def get_ramdisk():
class NullBroker(object):
- def put(*args):
+ def say(*args):
+ pass
+
+ def ask(*args):
pass