diff --git a/README.md b/README.md index 38c60727..6493f0d2 100644 --- a/README.md +++ b/README.md @@ -1200,15 +1200,18 @@ journalctl -aS '48 hour ago' -u copyparty | grep -C10 FILENAME | tee bug.log ## dev env setup -mostly optional; if you need a working env for vscode or similar +you need python 3.9 or newer due to type hints + +the rest is mostly optional; if you need a working env for vscode or similar ```sh python3 -m venv .venv . .venv/bin/activate -pip install jinja2 # mandatory +pip install jinja2 strip_hints # MANDATORY pip install mutagen # audio metadata +pip install pyftpdlib # ftp server pip install Pillow pyheif-pillow-opener pillow-avif-plugin # thumbnails -pip install black==21.12b0 bandit pylint flake8 # vscode tooling +pip install black==21.12b0 click==8.0.2 bandit pylint flake8 isort mypy # vscode tooling ``` diff --git a/bin/mtag/image-noexif.py b/bin/mtag/image-noexif.py index d5392b00..bc009c17 100644 --- a/bin/mtag/image-noexif.py +++ b/bin/mtag/image-noexif.py @@ -43,7 +43,6 @@ PS: this requires e2ts to be functional, import os import sys -import time import filecmp import subprocess as sp diff --git a/bin/up2k.py b/bin/up2k.py index 8cf904c4..23d35ea3 100755 --- a/bin/up2k.py +++ b/bin/up2k.py @@ -77,15 +77,15 @@ class File(object): self.up_b = 0 # type: int self.up_c = 0 # type: int - # m = "size({}) lmod({}) top({}) rel({}) abs({}) name({})\n" - # eprint(m.format(self.size, self.lmod, self.top, self.rel, self.abs, self.name)) + # t = "size({}) lmod({}) top({}) rel({}) abs({}) name({})\n" + # eprint(t.format(self.size, self.lmod, self.top, self.rel, self.abs, self.name)) class FileSlice(object): """file-like object providing a fixed window into a file""" def __init__(self, file, cid): - # type: (File, str) -> FileSlice + # type: (File, str) -> None self.car, self.len = file.kchunks[cid] self.cdr = self.car + self.len @@ -216,8 +216,8 @@ class CTermsize(object): eprint("\033[s\033[r\033[u") else: self.g = 1 + self.h - margin - m = 
"{0}\033[{1}A".format("\n" * margin, margin) - eprint("{0}\033[s\033[1;{1}r\033[u".format(m, self.g - 1)) + t = "{0}\033[{1}A".format("\n" * margin, margin) + eprint("{0}\033[s\033[1;{1}r\033[u".format(t, self.g - 1)) ss = CTermsize() @@ -597,8 +597,8 @@ class Ctl(object): if "/" in name: name = "\033[36m{0}\033[0m/{1}".format(*name.rsplit("/", 1)) - m = "{0:6.1f}% {1} {2}\033[K" - txt += m.format(p, self.nfiles - f, name) + t = "{0:6.1f}% {1} {2}\033[K" + txt += t.format(p, self.nfiles - f, name) txt += "\033[{0}H ".format(ss.g + 2) else: @@ -618,8 +618,8 @@ class Ctl(object): nleft = self.nfiles - self.up_f tail = "\033[K\033[u" if VT100 else "\r" - m = "{0} eta @ {1}/s, {2}, {3}# left".format(eta, spd, sleft, nleft) - eprint(txt + "\033]0;{0}\033\\\r{0}{1}".format(m, tail)) + t = "{0} eta @ {1}/s, {2}, {3}# left".format(eta, spd, sleft, nleft) + eprint(txt + "\033]0;{0}\033\\\r{0}{1}".format(t, tail)) def cleanup_vt100(self): ss.scroll_region(None) @@ -721,8 +721,8 @@ class Ctl(object): if search: if hs: for hit in hs: - m = "found: {0}\n {1}{2}\n" - print(m.format(upath, burl, hit["rp"]), end="") + t = "found: {0}\n {1}{2}\n" + print(t.format(upath, burl, hit["rp"]), end="") else: print("NOT found: {0}\n".format(upath), end="") diff --git a/contrib/systemd/copyparty.service b/contrib/systemd/copyparty.service index e8e02f90..fd3f9efc 100644 --- a/contrib/systemd/copyparty.service +++ b/contrib/systemd/copyparty.service @@ -4,7 +4,7 @@ # installation: # cp -pv copyparty.service /etc/systemd/system # restorecon -vr /etc/systemd/system/copyparty.service -# firewall-cmd --permanent --add-port={80,443,3923}/tcp +# firewall-cmd --permanent --add-port={80,443,3923}/tcp # --zone=libvirt # firewall-cmd --reload # systemctl daemon-reload && systemctl enable --now copyparty # diff --git a/copyparty/__init__.py b/copyparty/__init__.py index 17513708..594363eb 100644 --- a/copyparty/__init__.py +++ b/copyparty/__init__.py @@ -1,21 +1,30 @@ # coding: utf-8 from __future__ 
import print_function, unicode_literals -import platform -import time -import sys import os +import platform +import sys +import time + +try: + from collections.abc import Callable + + from typing import TYPE_CHECKING, Any +except: + TYPE_CHECKING = False PY2 = sys.version_info[0] == 2 if PY2: sys.dont_write_bytecode = True - unicode = unicode + unicode = unicode # noqa: F821 # pylint: disable=undefined-variable,self-assigning-variable else: unicode = str -WINDOWS = False -if platform.system() == "Windows": - WINDOWS = [int(x) for x in platform.version().split(".")] +WINDOWS: Any = ( + [int(x) for x in platform.version().split(".")] + if platform.system() == "Windows" + else False +) VT100 = not WINDOWS or WINDOWS >= [10, 0, 14393] # introduced in anniversary update @@ -25,8 +34,8 @@ ANYWIN = WINDOWS or sys.platform in ["msys"] MACOS = platform.system() == "Darwin" -def get_unixdir(): - paths = [ +def get_unixdir() -> str: + paths: list[tuple[Callable[..., str], str]] = [ (os.environ.get, "XDG_CONFIG_HOME"), (os.path.expanduser, "~/.config"), (os.environ.get, "TMPDIR"), @@ -43,7 +52,7 @@ def get_unixdir(): continue p = os.path.normpath(p) - chk(p) + chk(p) # type: ignore p = os.path.join(p, "copyparty") if not os.path.isdir(p): os.mkdir(p) @@ -56,7 +65,7 @@ def get_unixdir(): class EnvParams(object): - def __init__(self): + def __init__(self) -> None: self.t0 = time.time() self.mod = os.path.dirname(os.path.realpath(__file__)) if self.mod.endswith("__init__"): diff --git a/copyparty/__main__.py b/copyparty/__main__.py index 87203f87..5412a6c2 100644 --- a/copyparty/__main__.py +++ b/copyparty/__main__.py @@ -8,35 +8,42 @@ __copyright__ = 2019 __license__ = "MIT" __url__ = "https://github.com/9001/copyparty/" -import re -import os -import sys -import time -import shutil +import argparse import filecmp import locale -import argparse +import os +import re +import shutil +import sys import threading +import time import traceback from textwrap import dedent -from 
.__init__ import E, WINDOWS, ANYWIN, VT100, PY2, unicode -from .__version__ import S_VERSION, S_BUILD_DT, CODENAME -from .svchub import SvcHub -from .util import py_desc, align_tab, IMPLICATIONS, ansi_re, min_ex +from .__init__ import ANYWIN, PY2, VT100, WINDOWS, E, unicode +from .__version__ import CODENAME, S_BUILD_DT, S_VERSION from .authsrv import re_vol +from .svchub import SvcHub +from .util import IMPLICATIONS, align_tab, ansi_re, min_ex, py_desc -HAVE_SSL = True try: + from types import FrameType + + from typing import Any, Optional +except: + pass + +try: + HAVE_SSL = True import ssl except: HAVE_SSL = False -printed = [] +printed: list[str] = [] class RiceFormatter(argparse.HelpFormatter): - def _get_help_string(self, action): + def _get_help_string(self, action: argparse.Action) -> str: """ same as ArgumentDefaultsHelpFormatter(HelpFormatter) except the help += [...] line now has colors @@ -45,27 +52,27 @@ class RiceFormatter(argparse.HelpFormatter): if not VT100: fmt = " (default: %(default)s)" - ret = action.help - if "%(default)" not in action.help: + ret = str(action.help) + if "%(default)" not in ret: if action.default is not argparse.SUPPRESS: defaulting_nargs = [argparse.OPTIONAL, argparse.ZERO_OR_MORE] if action.option_strings or action.nargs in defaulting_nargs: ret += fmt return ret - def _fill_text(self, text, width, indent): + def _fill_text(self, text: str, width: int, indent: str) -> str: """same as RawDescriptionHelpFormatter(HelpFormatter)""" return "".join(indent + line + "\n" for line in text.splitlines()) class Dodge11874(RiceFormatter): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: kwargs["width"] = 9003 super(Dodge11874, self).__init__(*args, **kwargs) -def lprint(*a, **ka): - txt = " ".join(unicode(x) for x in a) + ka.get("end", "\n") +def lprint(*a: Any, **ka: Any) -> None: + txt: str = " ".join(unicode(x) for x in a) + ka.get("end", "\n") printed.append(txt) if not VT100: txt = 
ansi_re.sub("", txt) @@ -73,11 +80,11 @@ def lprint(*a, **ka): print(txt, **ka) -def warn(msg): +def warn(msg: str) -> None: lprint("\033[1mwarning:\033[0;33m {}\033[0m\n".format(msg)) -def ensure_locale(): +def ensure_locale() -> None: for x in [ "en_US.UTF-8", "English_United States.UTF8", @@ -91,7 +98,7 @@ def ensure_locale(): continue -def ensure_cert(): +def ensure_cert() -> None: """ the default cert (and the entire TLS support) is only here to enable the crypto.subtle javascript API, which is necessary due to the webkit guys @@ -117,8 +124,8 @@ def ensure_cert(): # printf 'NO\n.\n.\n.\n.\ncopyparty-insecure\n.\n' | faketime '2000-01-01 00:00:00' openssl req -x509 -sha256 -newkey rsa:2048 -keyout insecure.pem -out insecure.pem -days $((($(printf %d 0x7fffffff)-$(date +%s --date=2000-01-01T00:00:00Z))/(60*60*24))) -nodes && ls -al insecure.pem && openssl x509 -in insecure.pem -text -noout -def configure_ssl_ver(al): - def terse_sslver(txt): +def configure_ssl_ver(al: argparse.Namespace) -> None: + def terse_sslver(txt: str) -> str: txt = txt.lower() for c in ["_", "v", "."]: txt = txt.replace(c, "") @@ -133,8 +140,8 @@ def configure_ssl_ver(al): flags = [k for k in ssl.__dict__ if ptn.match(k)] # SSLv2 SSLv3 TLSv1 TLSv1_1 TLSv1_2 TLSv1_3 if "help" in sslver: - avail = [terse_sslver(x[6:]) for x in flags] - avail = " ".join(sorted(avail) + ["all"]) + avail1 = [terse_sslver(x[6:]) for x in flags] + avail = " ".join(sorted(avail1) + ["all"]) lprint("\navailable ssl/tls versions:\n " + avail) sys.exit(0) @@ -160,7 +167,7 @@ def configure_ssl_ver(al): # think i need that beer now -def configure_ssl_ciphers(al): +def configure_ssl_ciphers(al: argparse.Namespace) -> None: ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) if al.ssl_ver: ctx.options &= ~al.ssl_flags_en @@ -184,8 +191,8 @@ def configure_ssl_ciphers(al): sys.exit(0) -def args_from_cfg(cfg_path): - ret = [] +def args_from_cfg(cfg_path: str) -> list[str]: + ret: list[str] = [] skip = False with 
open(cfg_path, "rb") as f: for ln in [x.decode("utf-8").strip() for x in f]: @@ -210,29 +217,30 @@ def args_from_cfg(cfg_path): return ret -def sighandler(sig=None, frame=None): +def sighandler(sig: Optional[int] = None, frame: Optional[FrameType] = None) -> None: msg = [""] * 5 for th in threading.enumerate(): + stk = sys._current_frames()[th.ident] # type: ignore msg.append(str(th)) - msg.extend(traceback.format_stack(sys._current_frames()[th.ident])) + msg.extend(traceback.format_stack(stk)) msg.append("\n") print("\n".join(msg)) -def disable_quickedit(): - import ctypes +def disable_quickedit() -> None: import atexit + import ctypes from ctypes import wintypes - def ecb(ok, fun, args): + def ecb(ok: bool, fun: Any, args: list[Any]) -> list[Any]: if not ok: - err = ctypes.get_last_error() + err: int = ctypes.get_last_error() # type: ignore if err: - raise ctypes.WinError(err) + raise ctypes.WinError(err) # type: ignore return args - k32 = ctypes.WinDLL("kernel32", use_last_error=True) + k32 = ctypes.WinDLL("kernel32", use_last_error=True) # type: ignore if PY2: wintypes.LPDWORD = ctypes.POINTER(wintypes.DWORD) @@ -242,14 +250,14 @@ def disable_quickedit(): k32.GetConsoleMode.argtypes = (wintypes.HANDLE, wintypes.LPDWORD) k32.SetConsoleMode.argtypes = (wintypes.HANDLE, wintypes.DWORD) - def cmode(out, mode=None): + def cmode(out: bool, mode: Optional[int] = None) -> int: h = k32.GetStdHandle(-11 if out else -10) if mode: - return k32.SetConsoleMode(h, mode) + return k32.SetConsoleMode(h, mode) # type: ignore - mode = wintypes.DWORD() - k32.GetConsoleMode(h, ctypes.byref(mode)) - return mode.value + cmode = wintypes.DWORD() + k32.GetConsoleMode(h, ctypes.byref(cmode)) + return cmode.value # disable quickedit mode = orig_in = cmode(False) @@ -268,7 +276,7 @@ def disable_quickedit(): cmode(True, mode | 4) -def run_argparse(argv, formatter): +def run_argparse(argv: list[str], formatter: Any) -> argparse.Namespace: ap = argparse.ArgumentParser( 
formatter_class=formatter, prog="copyparty", @@ -596,7 +604,7 @@ def run_argparse(argv, formatter): return ret -def main(argv=None): +def main(argv: Optional[list[str]] = None) -> None: time.strptime("19970815", "%Y%m%d") # python#7980 if WINDOWS: os.system("rem") # enables colors @@ -618,7 +626,7 @@ def main(argv=None): supp = args_from_cfg(v) argv.extend(supp) - deprecated = [] + deprecated: list[tuple[str, str]] = [] for dk, nk in deprecated: try: idx = argv.index(dk) @@ -650,7 +658,7 @@ def main(argv=None): if not VT100: al.wintitle = "" - nstrs = [] + nstrs: list[str] = [] anymod = False for ostr in al.v or []: m = re_vol.match(ostr) diff --git a/copyparty/authsrv.py b/copyparty/authsrv.py index 4ca406ee..92612529 100644 --- a/copyparty/authsrv.py +++ b/copyparty/authsrv.py @@ -1,44 +1,68 @@ # coding: utf-8 from __future__ import print_function, unicode_literals -import re -import os -import sys -import stat -import time +import argparse import base64 import hashlib +import os +import re +import stat +import sys import threading +import time from datetime import datetime -from .__init__ import ANYWIN, WINDOWS +from .__init__ import ANYWIN, TYPE_CHECKING, WINDOWS +from .bos import bos from .util import ( IMPLICATIONS, META_NOBOTS, + Pebkac, + absreal, + fsenc, + relchk, + statdir, uncyg, undot, - relchk, unhumanize, - absreal, - Pebkac, - fsenc, - statdir, ) -from .bos import bos + +try: + from collections.abc import Iterable + + import typing + from typing import Any, Generator, Optional, Union + + from .util import RootLogger +except: + pass + +if TYPE_CHECKING: + pass + # Vflags: TypeAlias = dict[str, str | bool | float | list[str]] + # Vflags: TypeAlias = dict[str, Any] + # Mflags: TypeAlias = dict[str, Vflags] LEELOO_DALLAS = "leeloo_dallas" class AXS(object): - def __init__(self, uread=None, uwrite=None, umove=None, udel=None, uget=None): - self.uread = {} if uread is None else {k: 1 for k in uread} - self.uwrite = {} if uwrite is None else {k: 1 for k in 
uwrite} - self.umove = {} if umove is None else {k: 1 for k in umove} - self.udel = {} if udel is None else {k: 1 for k in udel} - self.uget = {} if uget is None else {k: 1 for k in uget} + def __init__( + self, + uread: Optional[Union[list[str], set[str]]] = None, + uwrite: Optional[Union[list[str], set[str]]] = None, + umove: Optional[Union[list[str], set[str]]] = None, + udel: Optional[Union[list[str], set[str]]] = None, + uget: Optional[Union[list[str], set[str]]] = None, + ) -> None: + self.uread: set[str] = set(uread or []) + self.uwrite: set[str] = set(uwrite or []) + self.umove: set[str] = set(umove or []) + self.udel: set[str] = set(udel or []) + self.uget: set[str] = set(uget or []) - def __repr__(self): + def __repr__(self) -> str: return "AXS({})".format( ", ".join( "{}={!r}".format(k, self.__dict__[k]) @@ -48,33 +72,33 @@ class AXS(object): class Lim(object): - def __init__(self): - self.nups = {} # num tracker - self.bups = {} # byte tracker list - self.bupc = {} # byte tracker cache + def __init__(self) -> None: + self.nups: dict[str, list[float]] = {} # num tracker + self.bups: dict[str, list[tuple[float, int]]] = {} # byte tracker list + self.bupc: dict[str, int] = {} # byte tracker cache self.nosub = False # disallow subdirectories - self.smin = None # filesize min - self.smax = None # filesize max + self.smin = -1 # filesize min + self.smax = -1 # filesize max - self.bwin = None # bytes window - self.bmax = None # bytes max - self.nwin = None # num window - self.nmax = None # num max + self.bwin = 0 # bytes window + self.bmax = 0 # bytes max + self.nwin = 0 # num window + self.nmax = 0 # num max - self.rotn = None # rot num files - self.rotl = None # rot depth - self.rotf = None # rot datefmt - self.rot_re = None # rotf check + self.rotn = 0 # rot num files + self.rotl = 0 # rot depth + self.rotf = "" # rot datefmt + self.rot_re = re.compile("") # rotf check - def set_rotf(self, fmt): + def set_rotf(self, fmt: str) -> None: self.rotf = fmt r = 
re.escape(fmt).replace("%Y", "[0-9]{4}").replace("%j", "[0-9]{3}") r = re.sub("%[mdHMSWU]", "[0-9]{2}", r) self.rot_re = re.compile("(^|/)" + r + "$") - def all(self, ip, rem, sz, abspath): + def all(self, ip: str, rem: str, sz: float, abspath: str) -> tuple[str, str]: self.chk_nup(ip) self.chk_bup(ip) self.chk_rem(rem) @@ -87,18 +111,18 @@ class Lim(object): return ap2, ("{}/{}".format(rem, vp2) if rem else vp2) - def chk_sz(self, sz): - if self.smin is not None and sz < self.smin: + def chk_sz(self, sz: float) -> None: + if self.smin != -1 and sz < self.smin: raise Pebkac(400, "file too small") - if self.smax is not None and sz > self.smax: + if self.smax != -1 and sz > self.smax: raise Pebkac(400, "file too big") - def chk_rem(self, rem): + def chk_rem(self, rem: str) -> None: if self.nosub and rem: raise Pebkac(500, "no subdirectories allowed") - def rot(self, path): + def rot(self, path: str) -> tuple[str, str]: if not self.rotf and not self.rotn: return path, "" @@ -120,7 +144,7 @@ class Lim(object): d = ret[len(path) :].strip("/\\").replace("\\", "/") return ret, d - def dive(self, path, lvs): + def dive(self, path: str, lvs: int) -> Optional[str]: items = bos.listdir(path) if not lvs: @@ -155,14 +179,14 @@ class Lim(object): return os.path.join(sub, ret) - def nup(self, ip): + def nup(self, ip: str) -> None: try: self.nups[ip].append(time.time()) except: self.nups[ip] = [time.time()] - def bup(self, ip, nbytes): - v = [time.time(), nbytes] + def bup(self, ip: str, nbytes: int) -> None: + v = (time.time(), nbytes) try: self.bups[ip].append(v) self.bupc[ip] += nbytes @@ -170,7 +194,7 @@ class Lim(object): self.bups[ip] = [v] self.bupc[ip] = nbytes - def chk_nup(self, ip): + def chk_nup(self, ip: str) -> None: if not self.nmax or ip not in self.nups: return @@ -182,7 +206,7 @@ class Lim(object): if len(nups) >= self.nmax: raise Pebkac(429, "too many uploads") - def chk_bup(self, ip): + def chk_bup(self, ip: str) -> None: if not self.bmax or ip not in 
self.bups: return @@ -200,35 +224,37 @@ class Lim(object): class VFS(object): """single level in the virtual fs""" - def __init__(self, log, realpath, vpath, axs, flags): + def __init__( + self, + log: Optional[RootLogger], + realpath: str, + vpath: str, + axs: AXS, + flags: dict[str, Any], + ) -> None: self.log = log self.realpath = realpath # absolute path on host filesystem self.vpath = vpath # absolute path in the virtual filesystem - self.axs = axs # type: AXS + self.axs = axs self.flags = flags # config options - self.nodes = {} # child nodes - self.histtab = None # all realpath->histpath - self.dbv = None # closest full/non-jump parent - self.lim = None # type: Lim # upload limits; only set for dbv + self.nodes: dict[str, VFS] = {} # child nodes + self.histtab: dict[str, str] = {} # all realpath->histpath + self.dbv: Optional[VFS] = None # closest full/non-jump parent + self.lim: Optional[Lim] = None # upload limits; only set for dbv + self.aread: dict[str, list[str]] = {} + self.awrite: dict[str, list[str]] = {} + self.amove: dict[str, list[str]] = {} + self.adel: dict[str, list[str]] = {} + self.aget: dict[str, list[str]] = {} if realpath: self.histpath = os.path.join(realpath, ".hist") # db / thumbcache self.all_vols = {vpath: self} # flattened recursive - self.aread = {} - self.awrite = {} - self.amove = {} - self.adel = {} - self.aget = {} else: - self.histpath = None - self.all_vols = None - self.aread = None - self.awrite = None - self.amove = None - self.adel = None - self.aget = None + self.histpath = "" + self.all_vols = {} - def __repr__(self): + def __repr__(self) -> str: return "VFS({})".format( ", ".join( "{}={!r}".format(k, self.__dict__[k]) @@ -236,14 +262,14 @@ class VFS(object): ) ) - def get_all_vols(self, outdict): + def get_all_vols(self, outdict: dict[str, "VFS"]) -> None: if self.realpath: outdict[self.vpath] = self for v in self.nodes.values(): v.get_all_vols(outdict) - def add(self, src, dst): + def add(self, src: str, dst: str) -> 
"VFS": """get existing, or add new path to the vfs""" assert not src.endswith("/") # nosec assert not dst.endswith("/") # nosec @@ -257,7 +283,7 @@ class VFS(object): vn = VFS( self.log, - os.path.join(self.realpath, name) if self.realpath else None, + os.path.join(self.realpath, name) if self.realpath else "", "{}/{}".format(self.vpath, name).lstrip("/"), self.axs, self._copy_flags(name), @@ -277,7 +303,7 @@ class VFS(object): self.nodes[dst] = vn return vn - def _copy_flags(self, name): + def _copy_flags(self, name: str) -> dict[str, Any]: flags = {k: v for k, v in self.flags.items()} hist = flags.get("hist") if hist and hist != "-": @@ -285,20 +311,20 @@ class VFS(object): return flags - def bubble_flags(self): + def bubble_flags(self) -> None: if self.dbv: for k, v in self.dbv.flags.items(): if k not in ["hist"]: self.flags[k] = v - for v in self.nodes.values(): - v.bubble_flags() + for n in self.nodes.values(): + n.bubble_flags() - def _find(self, vpath): + def _find(self, vpath: str) -> tuple["VFS", str]: """return [vfs,remainder]""" vpath = undot(vpath) if vpath == "": - return [self, ""] + return self, "" if "/" in vpath: name, rem = vpath.split("/", 1) @@ -309,66 +335,64 @@ class VFS(object): if name in self.nodes: return self.nodes[name]._find(rem) - return [self, vpath] + return self, vpath - def can_access(self, vpath, uname): - # type: (str, str) -> tuple[bool, bool, bool, bool] + def can_access(self, vpath: str, uname: str) -> tuple[bool, bool, bool, bool, bool]: """can Read,Write,Move,Delete,Get""" vn, _ = self._find(vpath) c = vn.axs - return [ + return ( uname in c.uread or "*" in c.uread, uname in c.uwrite or "*" in c.uwrite, uname in c.umove or "*" in c.umove, uname in c.udel or "*" in c.udel, uname in c.uget or "*" in c.uget, - ] + ) def get( self, - vpath, - uname, - will_read, - will_write, - will_move=False, - will_del=False, - will_get=False, - ): - # type: (str, str, bool, bool, bool, bool, bool) -> tuple[VFS, str] + vpath: str, + uname: 
str, + will_read: bool, + will_write: bool, + will_move: bool = False, + will_del: bool = False, + will_get: bool = False, + ) -> tuple["VFS", str]: """returns [vfsnode,fs_remainder] if user has the requested permissions""" if ANYWIN: mod = relchk(vpath) if mod: - self.log("vfs", "invalid relpath [{}]".format(vpath)) + if self.log: + self.log("vfs", "invalid relpath [{}]".format(vpath)) raise Pebkac(404) vn, rem = self._find(vpath) - c = vn.axs + c: AXS = vn.axs for req, d, msg in [ - [will_read, c.uread, "read"], - [will_write, c.uwrite, "write"], - [will_move, c.umove, "move"], - [will_del, c.udel, "delete"], - [will_get, c.uget, "get"], + (will_read, c.uread, "read"), + (will_write, c.uwrite, "write"), + (will_move, c.umove, "move"), + (will_del, c.udel, "delete"), + (will_get, c.uget, "get"), ]: if req and (uname not in d and "*" not in d) and uname != LEELOO_DALLAS: - m = "you don't have {}-access for this location" - raise Pebkac(403, m.format(msg)) + t = "you don't have {}-access for this location" + raise Pebkac(403, t.format(msg)) return vn, rem - def get_dbv(self, vrem): - # type: (str) -> tuple[VFS, str] + def get_dbv(self, vrem: str) -> tuple["VFS", str]: dbv = self.dbv if not dbv: return self, vrem - vrem = [self.vpath[len(dbv.vpath) :].lstrip("/"), vrem] - vrem = "/".join([x for x in vrem if x]) + tv = [self.vpath[len(dbv.vpath) :].lstrip("/"), vrem] + vrem = "/".join([x for x in tv if x]) return dbv, vrem - def canonical(self, rem, resolve=True): + def canonical(self, rem: str, resolve: bool = True) -> str: """returns the canonical path (fully-resolved absolute fs path)""" rp = self.realpath if rem: @@ -376,8 +400,14 @@ class VFS(object): return absreal(rp) if resolve else rp - def ls(self, rem, uname, scandir, permsets, lstat=False): - # type: (str, str, bool, list[list[bool]], bool) -> tuple[str, str, dict[str, VFS]] + def ls( + self, + rem: str, + uname: str, + scandir: bool, + permsets: list[list[bool]], + lstat: bool = False, + ) -> tuple[str, 
list[tuple[str, os.stat_result]], dict[str, "VFS"]]: """return user-readable [fsdir,real,virt] items at vpath""" virt_vis = {} # nodes readable by user abspath = self.canonical(rem) @@ -389,8 +419,8 @@ class VFS(object): for name, vn2 in sorted(self.nodes.items()): ok = False - axs = vn2.axs - axs = [axs.uread, axs.uwrite, axs.umove, axs.udel, axs.uget] + zx = vn2.axs + axs = [zx.uread, zx.uwrite, zx.umove, zx.udel, zx.uget] for pset in permsets: ok = True for req, lst in zip(pset, axs): @@ -409,9 +439,32 @@ class VFS(object): elif "/.hist/th/" in p: real = [x for x in real if not x[0].endswith("dir.txt")] - return [abspath, real, virt_vis] + return abspath, real, virt_vis - def walk(self, rel, rem, seen, uname, permsets, dots, scandir, lstat, subvols=True): + def walk( + self, + rel: str, + rem: str, + seen: list[str], + uname: str, + permsets: list[list[bool]], + dots: bool, + scandir: bool, + lstat: bool, + subvols: bool = True, + ) -> Generator[ + tuple[ + "VFS", + str, + str, + str, + list[tuple[str, os.stat_result]], + list[tuple[str, os.stat_result]], + dict[str, "VFS"], + ], + None, + None, + ]: """ recursively yields from ./rem; rel is a unix-style user-defined vpath (not vfs-related) @@ -425,8 +478,9 @@ class VFS(object): and (not fsroot.startswith(seen[-1]) or fsroot == seen[-1]) and fsroot in seen ): - m = "bailing from symlink loop,\n prev: {}\n curr: {}\n from: {}/{}" - self.log("vfs.walk", m.format(seen[-1], fsroot, self.vpath, rem), 3) + if self.log: + t = "bailing from symlink loop,\n prev: {}\n curr: {}\n from: {}/{}" + self.log("vfs.walk", t.format(seen[-1], fsroot, self.vpath, rem), 3) return seen = seen[:] + [fsroot] @@ -460,9 +514,9 @@ class VFS(object): for x in vfs.walk(wrel, "", seen, uname, permsets, dots, scandir, lstat): yield x - def zipgen(self, vrem, flt, uname, dots, scandir): - if flt: - flt = {k: True for k in flt} + def zipgen( + self, vrem: str, flt: set[str], uname: str, dots: bool, scandir: bool + ) -> Generator[dict[str, Any], 
None, None]: # if multiselect: add all items to archive root # if single folder: the folder itself is the top-level item @@ -473,33 +527,33 @@ class VFS(object): if flt: files = [x for x in files if x[0] in flt] - rm = [x for x in rd if x[0] not in flt] - [rd.remove(x) for x in rm] + rm1 = [x for x in rd if x[0] not in flt] + _ = [rd.remove(x) for x in rm1] # type: ignore - rm = [x for x in vd.keys() if x not in flt] - [vd.pop(x) for x in rm] + rm2 = [x for x in vd.keys() if x not in flt] + _ = [vd.pop(x) for x in rm2] - flt = None + flt = set() # print(repr([vpath, apath, [x[0] for x in files]])) fnames = [n[0] for n in files] vpaths = [vpath + "/" + n for n in fnames] if vpath else fnames apaths = [os.path.join(apath, n) for n in fnames] - files = list(zip(vpaths, apaths, files)) + ret = list(zip(vpaths, apaths, files)) if not dots: # dotfile filtering based on vpath (intended visibility) - files = [x for x in files if "/." not in "/" + x[0]] + ret = [x for x in ret if "/." not in "/" + x[0]] - rm = [x for x in rd if x[0].startswith(".")] - for x in rm: - rd.remove(x) + zel = [ze for ze in rd if ze[0].startswith(".")] + for ze in zel: + rd.remove(ze) - rm = [k for k in vd.keys() if k.startswith(".")] - for x in rm: - del vd[x] + zsl = [zs for zs in vd.keys() if zs.startswith(".")] + for zs in zsl: + del vd[zs] - for f in [{"vp": v, "ap": a, "st": n[1]} for v, a, n in files]: + for f in [{"vp": v, "ap": a, "st": n[1]} for v, a, n in ret]: yield f @@ -512,7 +566,12 @@ else: class AuthSrv(object): """verifies users against given paths""" - def __init__(self, args, log_func, warn_anonwrite=True): + def __init__( + self, + args: argparse.Namespace, + log_func: Optional[RootLogger], + warn_anonwrite: bool = True, + ) -> None: self.args = args self.log_func = log_func self.warn_anonwrite = warn_anonwrite @@ -521,11 +580,11 @@ class AuthSrv(object): self.mutex = threading.Lock() self.reload() - def log(self, msg, c=0): + def log(self, msg: str, c: Union[int, str] = 0) -> 
None: if self.log_func: self.log_func("auth", msg, c) - def laggy_iter(self, iterable): + def laggy_iter(self, iterable: Iterable[Any]) -> Generator[Any, None, None]: """returns [value,isFinalValue]""" it = iter(iterable) prev = next(it) @@ -535,26 +594,39 @@ class AuthSrv(object): yield prev, True - def _map_volume(self, src, dst, mount, daxs, mflags): + def _map_volume( + self, + src: str, + dst: str, + mount: dict[str, str], + daxs: dict[str, AXS], + mflags: dict[str, dict[str, Any]], + ) -> None: if dst in mount: - m = "multiple filesystem-paths mounted at [/{}]:\n [{}]\n [{}]" - self.log(m.format(dst, mount[dst], src), c=1) + t = "multiple filesystem-paths mounted at [/{}]:\n [{}]\n [{}]" + self.log(t.format(dst, mount[dst], src), c=1) raise Exception("invalid config") if src in mount.values(): - m = "warning: filesystem-path [{}] mounted in multiple locations:" - m = m.format(src) + t = "warning: filesystem-path [{}] mounted in multiple locations:" + t = t.format(src) for v in [k for k, v in mount.items() if v == src] + [dst]: - m += "\n /{}".format(v) + t += "\n /{}".format(v) - self.log(m, c=3) + self.log(t, c=3) mount[dst] = src daxs[dst] = AXS() mflags[dst] = {} - def _parse_config_file(self, fd, acct, daxs, mflags, mount): - # type: (any, str, dict[str, AXS], any, str) -> None + def _parse_config_file( + self, + fd: typing.BinaryIO, + acct: dict[str, str], + daxs: dict[str, AXS], + mflags: dict[str, dict[str, Any]], + mount: dict[str, str], + ) -> None: skip = False vol_src = None vol_dst = None @@ -601,23 +673,25 @@ class AuthSrv(object): uname = "*" if lvl == "a": - m = "WARNING (config-file): permission flag 'a' is deprecated; please use 'rw' instead" - self.log(m, 1) + t = "WARNING (config-file): permission flag 'a' is deprecated; please use 'rw' instead" + self.log(t, 1) self._read_vol_str(lvl, uname, daxs[vol_dst], mflags[vol_dst]) - def _read_vol_str(self, lvl, uname, axs, flags): - # type: (str, str, AXS, any) -> None + def _read_vol_str( + self, 
lvl: str, uname: str, axs: AXS, flags: dict[str, Any] + ) -> None: if lvl.strip("crwmdg"): raise Exception("invalid volume flag: {},{}".format(lvl, uname)) if lvl == "c": + cval: Union[bool, str] = True try: # volume flag with arguments, possibly with a preceding list of bools uname, cval = uname.split("=", 1) except: # just one or more bools - cval = True + pass while "," in uname: # one or more bools before the final flag; eat them @@ -631,34 +705,38 @@ class AuthSrv(object): uname = "*" for un in uname.replace(",", " ").strip().split(): - if "r" in lvl: - axs.uread[un] = 1 + for ch, al in [ + ("r", axs.uread), + ("w", axs.uwrite), + ("m", axs.umove), + ("d", axs.udel), + ("g", axs.uget), + ]: + if ch in lvl: + al.add(un) - if "w" in lvl: - axs.uwrite[un] = 1 - - if "m" in lvl: - axs.umove[un] = 1 - - if "d" in lvl: - axs.udel[un] = 1 - - if "g" in lvl: - axs.uget[un] = 1 - - def _read_volflag(self, flags, name, value, is_list): + def _read_volflag( + self, + flags: dict[str, Any], + name: str, + value: Union[str, bool, list[str]], + is_list: bool, + ) -> None: if name not in ["mtp"]: flags[name] = value return - if not is_list: - value = [value] - elif not value: + vals = flags.get(name, []) + if not value: return + elif is_list: + vals += value + else: + vals += [value] - flags[name] = flags.get(name, []) + value + flags[name] = vals - def reload(self): + def reload(self) -> None: """ construct a flat list of mountpoints and usernames first from the commandline arguments @@ -666,10 +744,10 @@ class AuthSrv(object): before finally building the VFS """ - acct = {} # username:password - daxs = {} # type: dict[str, AXS] - mflags = {} # mountpoint:[flag] - mount = {} # dst:src (mountpoint:realpath) + acct: dict[str, str] = {} # username:password + daxs: dict[str, AXS] = {} + mflags: dict[str, dict[str, Any]] = {} # moutpoint:flags + mount: dict[str, str] = {} # dst:src (mountpoint:realpath) if self.args.a: # list of username:password @@ -678,8 +756,8 @@ class 
AuthSrv(object): u, p = x.split(":", 1) acct[u] = p except: - m = '\n invalid value "{}" for argument -a, must be username:password' - raise Exception(m.format(x)) + t = '\n invalid value "{}" for argument -a, must be username:password' + raise Exception(t.format(x)) if self.args.v: # list of src:dst:permset:permset:... @@ -708,8 +786,8 @@ class AuthSrv(object): try: self._parse_config_file(f, acct, daxs, mflags, mount) except: - m = "\n\033[1;31m\nerror in config file {} on line {}:\n\033[0m" - self.log(m.format(cfg_fn, self.line_ctr), 1) + t = "\n\033[1;31m\nerror in config file {} on line {}:\n\033[0m" + self.log(t.format(cfg_fn, self.line_ctr), 1) raise # case-insensitive; normalize @@ -726,7 +804,7 @@ class AuthSrv(object): vfs = VFS(self.log_func, bos.path.abspath("."), "", axs, {}) elif "" not in mount: # there's volumes but no root; make root inaccessible - vfs = VFS(self.log_func, None, "", AXS(), {}) + vfs = VFS(self.log_func, "", "", AXS(), {}) vfs.flags["d2d"] = True maxdepth = 0 @@ -740,10 +818,10 @@ class AuthSrv(object): vfs = VFS(self.log_func, mount[dst], dst, daxs[dst], mflags[dst]) continue - v = vfs.add(mount[dst], dst) - v.axs = daxs[dst] - v.flags = mflags[dst] - v.dbv = None + zv = vfs.add(mount[dst], dst) + zv.axs = daxs[dst] + zv.flags = mflags[dst] + zv.dbv = None vfs.all_vols = {} vfs.get_all_vols(vfs.all_vols) @@ -751,11 +829,11 @@ class AuthSrv(object): for perm in "read write move del get".split(): axs_key = "u" + perm unames = ["*"] + list(acct.keys()) - umap = {x: [] for x in unames} + umap: dict[str, list[str]] = {x: [] for x in unames} for usr in unames: for vp, vol in vfs.all_vols.items(): - axs = getattr(vol.axs, axs_key) - if usr in axs or "*" in axs: + zx = getattr(vol.axs, axs_key) + if usr in zx or "*" in zx: umap[usr].append(vp) umap[usr].sort() setattr(vfs, "a" + perm, umap) @@ -764,7 +842,7 @@ class AuthSrv(object): missing_users = {} for axs in daxs.values(): for d in [axs.uread, axs.uwrite, axs.umove, axs.udel, 
axs.uget]: - for usr in d.keys(): + for usr in d: all_users[usr] = 1 if usr != "*" and usr not in acct: missing_users[usr] = 1 @@ -783,8 +861,8 @@ class AuthSrv(object): promote = [] demote = [] for vol in vfs.all_vols.values(): - hid = hashlib.sha512(fsenc(vol.realpath)).digest() - hid = base64.b32encode(hid).decode("ascii").lower() + zb = hashlib.sha512(fsenc(vol.realpath)).digest() + hid = base64.b32encode(zb).decode("ascii").lower() vflag = vol.flags.get("hist") if vflag == "-": pass @@ -822,21 +900,21 @@ class AuthSrv(object): demote.append(vol) # discard jump-vols - for v in demote: - vfs.all_vols.pop(v.vpath) + for zv in demote: + vfs.all_vols.pop(zv.vpath) if promote: - msg = [ + ta = [ "\n the following jump-volumes were generated to assist the vfs.\n As they contain a database (probably from v0.11.11 or older),\n they are promoted to full volumes:" ] for vol in promote: - msg.append( + ta.append( " /{} ({}) ({})".format(vol.vpath, vol.realpath, vol.histpath) ) - self.log("\n\n".join(msg) + "\n", c=3) + self.log("\n\n".join(ta) + "\n", c=3) - vfs.histtab = {v.realpath: v.histpath for v in vfs.all_vols.values()} + vfs.histtab = {zv.realpath: zv.histpath for zv in vfs.all_vols.values()} for vol in vfs.all_vols.values(): lim = Lim() @@ -846,30 +924,30 @@ class AuthSrv(object): use = True lim.nosub = True - v = vol.flags.get("sz") - if v: + zs = vol.flags.get("sz") + if zs: use = True - lim.smin, lim.smax = [unhumanize(x) for x in v.split("-")] + lim.smin, lim.smax = [unhumanize(x) for x in zs.split("-")] - v = vol.flags.get("rotn") - if v: + zs = vol.flags.get("rotn") + if zs: use = True - lim.rotn, lim.rotl = [int(x) for x in v.split(",")] + lim.rotn, lim.rotl = [int(x) for x in zs.split(",")] - v = vol.flags.get("rotf") - if v: + zs = vol.flags.get("rotf") + if zs: use = True - lim.set_rotf(v) + lim.set_rotf(zs) - v = vol.flags.get("maxn") - if v: + zs = vol.flags.get("maxn") + if zs: use = True - lim.nmax, lim.nwin = [int(x) for x in v.split(",")] + 
lim.nmax, lim.nwin = [int(x) for x in zs.split(",")] - v = vol.flags.get("maxb") - if v: + zs = vol.flags.get("maxb") + if zs: use = True - lim.bmax, lim.bwin = [unhumanize(x) for x in v.split(",")] + lim.bmax, lim.bwin = [unhumanize(x) for x in zs.split(",")] if use: vol.lim = lim @@ -1005,8 +1083,8 @@ class AuthSrv(object): for mtp in local_only_mtp: if mtp not in local_mte: - m = 'volume "/{}" defines metadata tag "{}", but doesnt use it in "-mte" (or with "cmte" in its volume-flags)' - self.log(m.format(vol.vpath, mtp), 1) + t = 'volume "/{}" defines metadata tag "{}", but doesnt use it in "-mte" (or with "cmte" in its volume-flags)' + self.log(t.format(vol.vpath, mtp), 1) errors = True tags = self.args.mtp or [] @@ -1014,8 +1092,8 @@ class AuthSrv(object): tags = [y for x in tags for y in x.split(",")] for mtp in tags: if mtp not in all_mte: - m = 'metadata tag "{}" is defined by "-mtm" or "-mtp", but is not used by "-mte" (or by any "cmte" volume-flag)' - self.log(m.format(mtp), 1) + t = 'metadata tag "{}" is defined by "-mtm" or "-mtp", but is not used by "-mte" (or by any "cmte" volume-flag)' + self.log(t.format(mtp), 1) errors = True if errors: @@ -1023,12 +1101,12 @@ class AuthSrv(object): vfs.bubble_flags() - m = "volumes and permissions:\n" - for v in vfs.all_vols.values(): + t = "volumes and permissions:\n" + for zv in vfs.all_vols.values(): if not self.warn_anonwrite: break - m += '\n\033[36m"/{}" \033[33m{}\033[0m'.format(v.vpath, v.realpath) + t += '\n\033[36m"/{}" \033[33m{}\033[0m'.format(zv.vpath, zv.realpath) for txt, attr in [ [" read", "uread"], [" write", "uwrite"], @@ -1036,21 +1114,21 @@ class AuthSrv(object): ["delete", "udel"], [" get", "uget"], ]: - u = list(sorted(getattr(v.axs, attr).keys())) + u = list(sorted(getattr(zv.axs, attr))) u = ", ".join("\033[35meverybody\033[0m" if x == "*" else x for x in u) u = u if u else "\033[36m--none--\033[0m" - m += "\n| {}: {}".format(txt, u) - m += "\n" + t += "\n| {}: {}".format(txt, u) + t += 
"\n" if self.warn_anonwrite and not self.args.no_voldump: - self.log(m) + self.log(t) try: - v, _ = vfs.get("/", "*", False, True) - if self.warn_anonwrite and os.getcwd() == v.realpath: + zv, _ = vfs.get("/", "*", False, True) + if self.warn_anonwrite and os.getcwd() == zv.realpath: self.warn_anonwrite = False - msg = "anyone can read/write the current directory: {}\n" - self.log(msg.format(v.realpath), c=1) + t = "anyone can read/write the current directory: {}\n" + self.log(t.format(zv.realpath), c=1) except Pebkac: self.warn_anonwrite = True @@ -1064,19 +1142,19 @@ class AuthSrv(object): if pwds: self.re_pwd = re.compile("=(" + "|".join(pwds) + ")([]&; ]|$)") - def dbg_ls(self): + def dbg_ls(self) -> None: users = self.args.ls - vols = "*" - flags = [] + vol = "*" + flags: list[str] = [] try: - users, vols = users.split(",", 1) + users, vol = users.split(",", 1) except: pass try: - vols, flags = vols.split(",", 1) - flags = flags.split(",") + vol, zf = vol.split(",", 1) + flags = zf.split(",") except: pass @@ -1089,23 +1167,23 @@ class AuthSrv(object): if u not in self.acct and u != "*": raise Exception("user not found: " + u) - if vols == "*": + if vol == "*": vols = ["/" + x for x in self.vfs.all_vols] else: - vols = [vols] + vols = [vol] - for v in vols: - if not v.startswith("/"): + for zs in vols: + if not zs.startswith("/"): raise Exception("volumes must start with /") - if v[1:] not in self.vfs.all_vols: - raise Exception("volume not found: " + v) + if zs[1:] not in self.vfs.all_vols: + raise Exception("volume not found: " + zs) - self.log({"users": users, "vols": vols, "flags": flags}) - m = "/{}: read({}) write({}) move({}) del({}) get({})" - for k, v in self.vfs.all_vols.items(): - vc = v.axs - self.log(m.format(k, vc.uread, vc.uwrite, vc.umove, vc.udel, vc.uget)) + self.log(str({"users": users, "vols": vols, "flags": flags})) + t = "/{}: read({}) write({}) move({}) del({}) get({})" + for k, zv in self.vfs.all_vols.items(): + vc = zv.axs + 
self.log(t.format(k, vc.uread, vc.uwrite, vc.umove, vc.udel, vc.uget)) flag_v = "v" in flags flag_ln = "ln" in flags @@ -1136,12 +1214,14 @@ class AuthSrv(object): False, False, ) - for _, _, vpath, apath, files, dirs, _ in g: - fnames = [n[0] for n in files] - vpaths = [vpath + "/" + n for n in fnames] if vpath else fnames - vpaths = [vtop + x for x in vpaths] + for _, _, vpath, apath, files1, dirs, _ in g: + fnames = [n[0] for n in files1] + zsl = [vpath + "/" + n for n in fnames] if vpath else fnames + vpaths = [vtop + x for x in zsl] apaths = [os.path.join(apath, n) for n in fnames] - files = [[vpath + "/", apath + os.sep]] + list(zip(vpaths, apaths)) + files = [(vpath + "/", apath + os.sep)] + list( + [(zs1, zs2) for zs1, zs2 in zip(vpaths, apaths)] + ) if flag_ln: files = [x for x in files if not x[1].startswith(safeabs)] @@ -1152,21 +1232,23 @@ class AuthSrv(object): if not files: continue elif flag_v: - msg = [""] + [ + ta = [""] + [ '# user "{}", vpath "{}"\n{}'.format(u, vp, ap) for vp, ap in files ] else: - msg = ["user {}, vol {}: {} =>".format(u, vtop, files[0][0])] - msg += [x[1] for x in files] + ta = ["user {}, vol {}: {} =>".format(u, vtop, files[0][0])] + ta += [x[1] for x in files] - self.log("\n".join(msg)) + self.log("\n".join(ta)) if bads: self.log("\n ".join(["found symlinks leaving volume:"] + bads)) if bads and flag_p: - raise Exception("found symlink leaving volume, and strict is set") + raise Exception( + "\033[31m\n [--ls] found a safety issue and prevented startup:\n found symlinks leaving volume, and strict is set\n\033[0m" + ) if not flag_r: sys.exit(0) diff --git a/copyparty/bos/bos.py b/copyparty/bos/bos.py index d5e003cf..617545af 100644 --- a/copyparty/bos/bos.py +++ b/copyparty/bos/bos.py @@ -2,23 +2,30 @@ from __future__ import print_function, unicode_literals import os -from ..util import fsenc, fsdec, SYMTIME + +from ..util import SYMTIME, fsdec, fsenc from . 
import path +try: + from typing import Optional +except: + pass + +_ = (path,) # grep -hRiE '(^|[^a-zA-Z_\.-])os\.' . | gsed -r 's/ /\n/g;s/\(/(\n/g' | grep -hRiE '(^|[^a-zA-Z_\.-])os\.' | sort | uniq -c # printf 'os\.(%s)' "$(grep ^def bos/__init__.py | gsed -r 's/^def //;s/\(.*//' | tr '\n' '|' | gsed -r 's/.$//')" -def chmod(p, mode): +def chmod(p: str, mode: int) -> None: return os.chmod(fsenc(p), mode) -def listdir(p="."): +def listdir(p: str = ".") -> list[str]: return [fsdec(x) for x in os.listdir(fsenc(p))] -def makedirs(name, mode=0o755, exist_ok=True): +def makedirs(name: str, mode: int = 0o755, exist_ok: bool = True) -> None: bname = fsenc(name) try: os.makedirs(bname, mode) @@ -27,31 +34,33 @@ def makedirs(name, mode=0o755, exist_ok=True): raise -def mkdir(p, mode=0o755): +def mkdir(p: str, mode: int = 0o755) -> None: return os.mkdir(fsenc(p), mode) -def rename(src, dst): +def rename(src: str, dst: str) -> None: return os.rename(fsenc(src), fsenc(dst)) -def replace(src, dst): +def replace(src: str, dst: str) -> None: return os.replace(fsenc(src), fsenc(dst)) -def rmdir(p): +def rmdir(p: str) -> None: return os.rmdir(fsenc(p)) -def stat(p): +def stat(p: str) -> os.stat_result: return os.stat(fsenc(p)) -def unlink(p): +def unlink(p: str) -> None: return os.unlink(fsenc(p)) -def utime(p, times=None, follow_symlinks=True): +def utime( + p: str, times: Optional[tuple[float, float]] = None, follow_symlinks: bool = True +) -> None: if SYMTIME: return os.utime(fsenc(p), times, follow_symlinks=follow_symlinks) else: @@ -60,7 +69,7 @@ def utime(p, times=None, follow_symlinks=True): if hasattr(os, "lstat"): - def lstat(p): + def lstat(p: str) -> os.stat_result: return os.lstat(fsenc(p)) else: diff --git a/copyparty/bos/path.py b/copyparty/bos/path.py index 066453b0..c5769d84 100644 --- a/copyparty/bos/path.py +++ b/copyparty/bos/path.py @@ -2,43 +2,44 @@ from __future__ import print_function, unicode_literals import os -from ..util import fsenc, fsdec, SYMTIME + 
+from ..util import SYMTIME, fsdec, fsenc -def abspath(p): +def abspath(p: str) -> str: return fsdec(os.path.abspath(fsenc(p))) -def exists(p): +def exists(p: str) -> bool: return os.path.exists(fsenc(p)) -def getmtime(p, follow_symlinks=True): +def getmtime(p: str, follow_symlinks: bool = True) -> float: if not follow_symlinks and SYMTIME: return os.lstat(fsenc(p)).st_mtime else: return os.path.getmtime(fsenc(p)) -def getsize(p): +def getsize(p: str) -> int: return os.path.getsize(fsenc(p)) -def isfile(p): +def isfile(p: str) -> bool: return os.path.isfile(fsenc(p)) -def isdir(p): +def isdir(p: str) -> bool: return os.path.isdir(fsenc(p)) -def islink(p): +def islink(p: str) -> bool: return os.path.islink(fsenc(p)) -def lexists(p): +def lexists(p: str) -> bool: return os.path.lexists(fsenc(p)) -def realpath(p): +def realpath(p: str) -> str: return fsdec(os.path.realpath(fsenc(p))) diff --git a/copyparty/broker_mp.py b/copyparty/broker_mp.py index df73658d..c7bfed70 100644 --- a/copyparty/broker_mp.py +++ b/copyparty/broker_mp.py @@ -1,37 +1,56 @@ # coding: utf-8 from __future__ import print_function, unicode_literals -import time import threading +import time -from .broker_util import try_exec +import queue + +from .__init__ import TYPE_CHECKING from .broker_mpw import MpWorker +from .broker_util import try_exec from .util import mp +if TYPE_CHECKING: + from .svchub import SvcHub + +try: + from typing import Any +except: + pass + + +class MProcess(mp.Process): + def __init__( + self, + q_pend: queue.Queue[tuple[int, str, list[Any]]], + q_yield: queue.Queue[tuple[int, str, list[Any]]], + target: Any, + args: Any, + ) -> None: + super(MProcess, self).__init__(target=target, args=args) + self.q_pend = q_pend + self.q_yield = q_yield + class BrokerMp(object): """external api; manages MpWorkers""" - def __init__(self, hub): + def __init__(self, hub: "SvcHub") -> None: self.hub = hub self.log = hub.log self.args = hub.args self.procs = [] - self.retpend = {} - 
self.retpend_mutex = threading.Lock() self.mutex = threading.Lock() self.num_workers = self.args.j or mp.cpu_count() self.log("broker", "booting {} subprocesses".format(self.num_workers)) for n in range(1, self.num_workers + 1): - q_pend = mp.Queue(1) - q_yield = mp.Queue(64) + q_pend: queue.Queue[tuple[int, str, list[Any]]] = mp.Queue(1) + q_yield: queue.Queue[tuple[int, str, list[Any]]] = mp.Queue(64) - proc = mp.Process(target=MpWorker, args=(q_pend, q_yield, self.args, n)) - proc.q_pend = q_pend - proc.q_yield = q_yield - proc.clients = {} + proc = MProcess(q_pend, q_yield, MpWorker, (q_pend, q_yield, self.args, n)) thr = threading.Thread( target=self.collector, args=(proc,), name="mp-sink-{}".format(n) @@ -42,11 +61,11 @@ class BrokerMp(object): self.procs.append(proc) proc.start() - def shutdown(self): + def shutdown(self) -> None: self.log("broker", "shutting down") for n, proc in enumerate(self.procs): thr = threading.Thread( - target=proc.q_pend.put([0, "shutdown", []]), + target=proc.q_pend.put((0, "shutdown", [])), name="mp-shutdown-{}-{}".format(n, len(self.procs)), ) thr.start() @@ -62,12 +81,12 @@ class BrokerMp(object): procs.pop() - def reload(self): + def reload(self) -> None: self.log("broker", "reloading") for _, proc in enumerate(self.procs): - proc.q_pend.put([0, "reload", []]) + proc.q_pend.put((0, "reload", [])) - def collector(self, proc): + def collector(self, proc: MProcess) -> None: """receive message from hub in other process""" while True: msg = proc.q_yield.get() @@ -78,10 +97,7 @@ class BrokerMp(object): elif dest == "retq": # response from previous ipc call - with self.retpend_mutex: - retq = self.retpend.pop(retq_id) - - retq.put(args) + raise Exception("invalid broker_mp usage") else: # new ipc invoking managed service in hub @@ -93,9 +109,9 @@ class BrokerMp(object): rv = try_exec(retq_id, obj, *args) if retq_id: - proc.q_pend.put([retq_id, "retq", rv]) + proc.q_pend.put((retq_id, "retq", rv)) - def put(self, want_retval, dest, 
*args): + def say(self, dest: str, *args: Any) -> None: """ send message to non-hub component in other process, returns a Queue object which eventually contains the response if want_retval @@ -103,7 +119,7 @@ class BrokerMp(object): """ if dest == "listen": for p in self.procs: - p.q_pend.put([0, dest, [args[0], len(self.procs)]]) + p.q_pend.put((0, dest, [args[0], len(self.procs)])) elif dest == "cb_httpsrv_up": self.hub.cb_httpsrv_up() diff --git a/copyparty/broker_mpw.py b/copyparty/broker_mpw.py index c4a1054c..4dbec5b1 100644 --- a/copyparty/broker_mpw.py +++ b/copyparty/broker_mpw.py @@ -1,20 +1,38 @@ # coding: utf-8 from __future__ import print_function, unicode_literals -import sys +import argparse import signal +import sys import threading -from .broker_util import ExceptionalQueue +import queue + +from .authsrv import AuthSrv +from .broker_util import BrokerCli, ExceptionalQueue from .httpsrv import HttpSrv from .util import FAKE_MP -from .authsrv import AuthSrv + +try: + from types import FrameType + + from typing import Any, Optional, Union +except: + pass -class MpWorker(object): +class MpWorker(BrokerCli): """one single mp instance""" - def __init__(self, q_pend, q_yield, args, n): + def __init__( + self, + q_pend: queue.Queue[tuple[int, str, list[Any]]], + q_yield: queue.Queue[tuple[int, str, list[Any]]], + args: argparse.Namespace, + n: int, + ) -> None: + super(MpWorker, self).__init__() + self.q_pend = q_pend self.q_yield = q_yield self.args = args @@ -22,7 +40,7 @@ class MpWorker(object): self.log = self._log_disabled if args.q and not args.lo else self._log_enabled - self.retpend = {} + self.retpend: dict[int, Any] = {} self.retpend_mutex = threading.Lock() self.mutex = threading.Lock() @@ -45,20 +63,20 @@ class MpWorker(object): thr.start() thr.join() - def signal_handler(self, sig, frame): + def signal_handler(self, sig: Optional[int], frame: Optional[FrameType]) -> None: # print('k') pass - def _log_enabled(self, src, msg, c=0): - 
self.q_yield.put([0, "log", [src, msg, c]]) + def _log_enabled(self, src: str, msg: str, c: Union[int, str] = 0) -> None: + self.q_yield.put((0, "log", [src, msg, c])) - def _log_disabled(self, src, msg, c=0): + def _log_disabled(self, src: str, msg: str, c: Union[int, str] = 0) -> None: pass - def logw(self, msg, c=0): + def logw(self, msg: str, c: Union[int, str] = 0) -> None: self.log("mp{}".format(self.n), msg, c) - def main(self): + def main(self) -> None: while True: retq_id, dest, args = self.q_pend.get() @@ -87,15 +105,14 @@ class MpWorker(object): else: raise Exception("what is " + str(dest)) - def put(self, want_retval, dest, *args): - if want_retval: - retq = ExceptionalQueue(1) - retq_id = id(retq) - with self.retpend_mutex: - self.retpend[retq_id] = retq - else: - retq = None - retq_id = 0 + def ask(self, dest: str, *args: Any) -> ExceptionalQueue: + retq = ExceptionalQueue(1) + retq_id = id(retq) + with self.retpend_mutex: + self.retpend[retq_id] = retq - self.q_yield.put([retq_id, dest, args]) + self.q_yield.put((retq_id, dest, list(args))) return retq + + def say(self, dest: str, *args: Any) -> None: + self.q_yield.put((0, dest, list(args))) diff --git a/copyparty/broker_thr.py b/copyparty/broker_thr.py index 1c7a1abf..51c25d41 100644 --- a/copyparty/broker_thr.py +++ b/copyparty/broker_thr.py @@ -3,14 +3,25 @@ from __future__ import print_function, unicode_literals import threading +from .__init__ import TYPE_CHECKING +from .broker_util import BrokerCli, ExceptionalQueue, try_exec from .httpsrv import HttpSrv -from .broker_util import ExceptionalQueue, try_exec + +if TYPE_CHECKING: + from .svchub import SvcHub + +try: + from typing import Any +except: + pass -class BrokerThr(object): +class BrokerThr(BrokerCli): """external api; behaves like BrokerMP but using plain threads""" - def __init__(self, hub): + def __init__(self, hub: "SvcHub") -> None: + super(BrokerThr, self).__init__() + self.hub = hub self.log = hub.log self.args = hub.args @@ -23,29 
+34,35 @@ class BrokerThr(object): self.httpsrv = HttpSrv(self, None) self.reload = self.noop - def shutdown(self): + def shutdown(self) -> None: # self.log("broker", "shutting down") self.httpsrv.shutdown() - def noop(self): + def noop(self) -> None: pass - def put(self, want_retval, dest, *args): + def ask(self, dest: str, *args: Any) -> ExceptionalQueue: + + # new ipc invoking managed service in hub + obj = self.hub + for node in dest.split("."): + obj = getattr(obj, node) + + rv = try_exec(True, obj, *args) + + # pretend we're broker_mp + retq = ExceptionalQueue(1) + retq.put(rv) + return retq + + def say(self, dest: str, *args: Any) -> None: if dest == "listen": self.httpsrv.listen(args[0], 1) + return - else: - # new ipc invoking managed service in hub - obj = self.hub - for node in dest.split("."): - obj = getattr(obj, node) + # new ipc invoking managed service in hub + obj = self.hub + for node in dest.split("."): + obj = getattr(obj, node) - # TODO will deadlock if dest performs another ipc - rv = try_exec(want_retval, obj, *args) - if not want_retval: - return - - # pretend we're broker_mp - retq = ExceptionalQueue(1) - retq.put(rv) - return retq + try_exec(False, obj, *args) diff --git a/copyparty/broker_util.py b/copyparty/broker_util.py index 896cba2c..b0d44575 100644 --- a/copyparty/broker_util.py +++ b/copyparty/broker_util.py @@ -1,14 +1,28 @@ # coding: utf-8 from __future__ import print_function, unicode_literals - +import argparse import traceback -from .util import Pebkac, Queue +from queue import Queue + +from .__init__ import TYPE_CHECKING +from .authsrv import AuthSrv +from .util import Pebkac + +try: + from typing import Any, Optional, Union + + from .util import RootLogger +except: + pass + +if TYPE_CHECKING: + from .httpsrv import HttpSrv class ExceptionalQueue(Queue, object): - def get(self, block=True, timeout=None): + def get(self, block: bool = True, timeout: Optional[float] = None) -> Any: rv = super(ExceptionalQueue, self).get(block, 
timeout) if isinstance(rv, list): @@ -21,7 +35,26 @@ class ExceptionalQueue(Queue, object): return rv -def try_exec(want_retval, func, *args): +class BrokerCli(object): + """ + helps mypy understand httpsrv.broker but still fails a few levels deeper, + for example resolving httpconn.* in httpcli -- see lines tagged #mypy404 + """ + + def __init__(self) -> None: + self.log: RootLogger = None + self.args: argparse.Namespace = None + self.asrv: AuthSrv = None + self.httpsrv: "HttpSrv" = None + + def ask(self, dest: str, *args: Any) -> ExceptionalQueue: + return ExceptionalQueue(1) + + def say(self, dest: str, *args: Any) -> None: + pass + + +def try_exec(want_retval: Union[bool, int], func: Any, *args: list[Any]) -> Any: try: return func(*args) diff --git a/copyparty/ftpd.py b/copyparty/ftpd.py index 828c0db6..a47586de 100644 --- a/copyparty/ftpd.py +++ b/copyparty/ftpd.py @@ -1,23 +1,23 @@ # coding: utf-8 from __future__ import print_function, unicode_literals -import os -import sys -import stat -import time -import logging import argparse +import logging +import os +import stat +import sys import threading +import time -from pyftpdlib.authorizers import DummyAuthorizer, AuthenticationFailed +from pyftpdlib.authorizers import AuthenticationFailed, DummyAuthorizer from pyftpdlib.filesystems import AbstractedFS, FilesystemError from pyftpdlib.handlers import FTPHandler -from pyftpdlib.servers import FTPServer from pyftpdlib.log import config_logging +from pyftpdlib.servers import FTPServer -from .__init__ import E, PY2 -from .util import Pebkac, fsenc, exclude_dotfiles +from .__init__ import PY2, TYPE_CHECKING, E from .bos import bos +from .util import Pebkac, exclude_dotfiles, fsenc try: from pyftpdlib.ioloop import IOLoop @@ -28,58 +28,63 @@ except ImportError: from pyftpdlib.ioloop import IOLoop -try: - from typing import TYPE_CHECKING +if TYPE_CHECKING: + from .svchub import SvcHub - if TYPE_CHECKING: - from .svchub import SvcHub -except ImportError: +try: + import 
typing + from typing import Any, Optional +except: pass class FtpAuth(DummyAuthorizer): - def __init__(self): + def __init__(self, hub: "SvcHub") -> None: super(FtpAuth, self).__init__() - self.hub = None # type: SvcHub + self.hub = hub - def validate_authentication(self, username, password, handler): + def validate_authentication( + self, username: str, password: str, handler: Any + ) -> None: asrv = self.hub.asrv if username == "anonymous": password = "" uname = "*" if password: - uname = asrv.iacct.get(password, None) + uname = asrv.iacct.get(password, "") handler.username = uname if password and not uname: raise AuthenticationFailed("Authentication failed.") - def get_home_dir(self, username): + def get_home_dir(self, username: str) -> str: return "/" - def has_user(self, username): + def has_user(self, username: str) -> bool: asrv = self.hub.asrv return username in asrv.acct - def has_perm(self, username, perm, path=None): + def has_perm(self, username: str, perm: int, path: Optional[str] = None) -> bool: return True # handled at filesystem layer - def get_perms(self, username): + def get_perms(self, username: str) -> str: return "elradfmwMT" - def get_msg_login(self, username): + def get_msg_login(self, username: str) -> str: return "sup {}".format(username) - def get_msg_quit(self, username): + def get_msg_quit(self, username: str) -> str: return "cya" class FtpFs(AbstractedFS): - def __init__(self, root, cmd_channel): + def __init__( + self, root: str, cmd_channel: Any + ) -> None: # pylint: disable=super-init-not-called self.h = self.cmd_channel = cmd_channel # type: FTPHandler - self.hub = cmd_channel.hub # type: SvcHub + self.hub: "SvcHub" = cmd_channel.hub self.args = cmd_channel.args self.uname = self.hub.asrv.iacct.get(cmd_channel.password, "*") @@ -90,7 +95,14 @@ class FtpFs(AbstractedFS): self.listdirinfo = self.listdir self.chdir(".") - def v2a(self, vpath, r=False, w=False, m=False, d=False): + def v2a( + self, + vpath: str, + r: bool = False, + 
w: bool = False, + m: bool = False, + d: bool = False, + ) -> str: try: vpath = vpath.replace("\\", "/").lstrip("/") vfs, rem = self.hub.asrv.vfs.get(vpath, self.uname, r, w, m, d) @@ -101,25 +113,32 @@ class FtpFs(AbstractedFS): except Pebkac as ex: raise FilesystemError(str(ex)) - def rv2a(self, vpath, r=False, w=False, m=False, d=False): + def rv2a( + self, + vpath: str, + r: bool = False, + w: bool = False, + m: bool = False, + d: bool = False, + ) -> str: return self.v2a(os.path.join(self.cwd, vpath), r, w, m, d) - def ftp2fs(self, ftppath): + def ftp2fs(self, ftppath: str) -> str: # return self.v2a(ftppath) return ftppath # self.cwd must be vpath - def fs2ftp(self, fspath): + def fs2ftp(self, fspath: str) -> str: # raise NotImplementedError() return fspath - def validpath(self, path): + def validpath(self, path: str) -> bool: if "/.hist/" in path: if "/up2k." in path or path.endswith("/dir.txt"): raise FilesystemError("access to this file is forbidden") return True - def open(self, filename, mode): + def open(self, filename: str, mode: str) -> typing.IO[Any]: r = "r" in mode w = "w" in mode or "a" in mode or "+" in mode @@ -130,24 +149,24 @@ class FtpFs(AbstractedFS): self.validpath(ap) return open(fsenc(ap), mode) - def chdir(self, path): + def chdir(self, path: str) -> None: self.cwd = join(self.cwd, path) x = self.hub.asrv.vfs.can_access(self.cwd.lstrip("/"), self.h.username) self.can_read, self.can_write, self.can_move, self.can_delete, self.can_get = x - def mkdir(self, path): + def mkdir(self, path: str) -> None: ap = self.rv2a(path, w=True) bos.mkdir(ap) - def listdir(self, path): + def listdir(self, path: str) -> list[str]: vpath = join(self.cwd, path).lstrip("/") try: vfs, rem = self.hub.asrv.vfs.get(vpath, self.uname, True, False) - fsroot, vfs_ls, vfs_virt = vfs.ls( + fsroot, vfs_ls1, vfs_virt = vfs.ls( rem, self.uname, not self.args.no_scandir, [[True], [False, True]] ) - vfs_ls = [x[0] for x in vfs_ls] + vfs_ls = [x[0] for x in vfs_ls1] 
vfs_ls.extend(vfs_virt.keys()) if not self.args.ed: @@ -164,11 +183,11 @@ class FtpFs(AbstractedFS): r = {x.split("/")[0]: 1 for x in self.hub.asrv.vfs.all_vols.keys()} return list(sorted(list(r.keys()))) - def rmdir(self, path): + def rmdir(self, path: str) -> None: ap = self.rv2a(path, d=True) bos.rmdir(ap) - def remove(self, path): + def remove(self, path: str) -> None: if self.args.no_del: raise FilesystemError("the delete feature is disabled in server config") @@ -178,13 +197,13 @@ class FtpFs(AbstractedFS): except Exception as ex: raise FilesystemError(str(ex)) - def rename(self, src, dst): + def rename(self, src: str, dst: str) -> None: if not self.can_move: raise FilesystemError("not allowed for user " + self.h.username) if self.args.no_mv: - m = "the rename/move feature is disabled in server config" - raise FilesystemError(m) + t = "the rename/move feature is disabled in server config" + raise FilesystemError(t) svp = join(self.cwd, src).lstrip("/") dvp = join(self.cwd, dst).lstrip("/") @@ -193,10 +212,10 @@ class FtpFs(AbstractedFS): except Exception as ex: raise FilesystemError(str(ex)) - def chmod(self, path, mode): + def chmod(self, path: str, mode: str) -> None: pass - def stat(self, path): + def stat(self, path: str) -> os.stat_result: try: ap = self.rv2a(path, r=True) return bos.stat(ap) @@ -208,59 +227,59 @@ class FtpFs(AbstractedFS): return st - def utime(self, path, timeval): + def utime(self, path: str, timeval: float) -> None: ap = self.rv2a(path, w=True) return bos.utime(ap, (timeval, timeval)) - def lstat(self, path): + def lstat(self, path: str) -> os.stat_result: ap = self.rv2a(path) return bos.lstat(ap) - def isfile(self, path): + def isfile(self, path: str) -> bool: st = self.stat(path) return stat.S_ISREG(st.st_mode) - def islink(self, path): + def islink(self, path: str) -> bool: ap = self.rv2a(path) return bos.path.islink(ap) - def isdir(self, path): + def isdir(self, path: str) -> bool: try: st = self.stat(path) return 
stat.S_ISDIR(st.st_mode) except: return True - def getsize(self, path): + def getsize(self, path: str) -> int: ap = self.rv2a(path) return bos.path.getsize(ap) - def getmtime(self, path): + def getmtime(self, path: str) -> float: ap = self.rv2a(path) return bos.path.getmtime(ap) - def realpath(self, path): + def realpath(self, path: str) -> str: return path - def lexists(self, path): + def lexists(self, path: str) -> bool: ap = self.rv2a(path) return bos.path.lexists(ap) - def get_user_by_uid(self, uid): + def get_user_by_uid(self, uid: int) -> str: return "root" - def get_group_by_uid(self, gid): + def get_group_by_uid(self, gid: int) -> str: return "root" class FtpHandler(FTPHandler): abstracted_fs = FtpFs - hub = None # type: SvcHub - args = None # type: argparse.Namespace + hub: "SvcHub" = None + args: argparse.Namespace = None - def __init__(self, conn, server, ioloop=None): - self.hub = FtpHandler.hub # type: SvcHub - self.args = FtpHandler.args # type: argparse.Namespace + def __init__(self, conn: Any, server: Any, ioloop: Any = None) -> None: + self.hub: "SvcHub" = FtpHandler.hub + self.args: argparse.Namespace = FtpHandler.args if PY2: FTPHandler.__init__(self, conn, server, ioloop) @@ -268,9 +287,10 @@ class FtpHandler(FTPHandler): super(FtpHandler, self).__init__(conn, server, ioloop) # abspath->vpath mapping to resolve log_transfer paths - self.vfs_map = {} + self.vfs_map: dict[str, str] = {} - def ftp_STOR(self, file, mode="w"): + def ftp_STOR(self, file: str, mode: str = "w") -> Any: + # Optional[str] vp = join(self.fs.cwd, file).lstrip("/") ap = self.fs.v2a(vp) self.vfs_map[ap] = vp @@ -279,7 +299,16 @@ class FtpHandler(FTPHandler): # print("ftp_STOR: {} {} OK".format(vp, mode)) return ret - def log_transfer(self, cmd, filename, receive, completed, elapsed, bytes): + def log_transfer( + self, + cmd: str, + filename: bytes, + receive: bool, + completed: bool, + elapsed: float, + bytes: int, + ) -> Any: + # None ap = filename.decode("utf-8", "replace") 
vp = self.vfs_map.pop(ap, None) # print("xfer_end: {} => {}".format(ap, vp)) @@ -312,7 +341,7 @@ except: class Ftpd(object): - def __init__(self, hub): + def __init__(self, hub: "SvcHub") -> None: self.hub = hub self.args = hub.args @@ -321,24 +350,23 @@ class Ftpd(object): hs.append([FtpHandler, self.args.ftp]) if self.args.ftps: try: - h = SftpHandler + h1 = SftpHandler except: - m = "\nftps requires pyopenssl;\nplease run the following:\n\n {} -m pip install --user pyopenssl\n" - print(m.format(sys.executable)) + t = "\nftps requires pyopenssl;\nplease run the following:\n\n {} -m pip install --user pyopenssl\n" + print(t.format(sys.executable)) sys.exit(1) - h.certfile = os.path.join(E.cfg, "cert.pem") - h.tls_control_required = True - h.tls_data_required = True + h1.certfile = os.path.join(E.cfg, "cert.pem") + h1.tls_control_required = True + h1.tls_data_required = True - hs.append([h, self.args.ftps]) + hs.append([h1, self.args.ftps]) - for h in hs: - h, lp = h - h.hub = hub - h.args = hub.args - h.authorizer = FtpAuth() - h.authorizer.hub = hub + for h_lp in hs: + h2, lp = h_lp + h2.hub = hub + h2.args = hub.args + h2.authorizer = FtpAuth(hub) if self.args.ftp_pr: p1, p2 = [int(x) for x in self.args.ftp_pr.split("-")] @@ -350,10 +378,10 @@ class Ftpd(object): else: p1 += d + 1 - h.passive_ports = list(range(p1, p2 + 1)) + h2.passive_ports = list(range(p1, p2 + 1)) if self.args.ftp_nat: - h.masquerade_address = self.args.ftp_nat + h2.masquerade_address = self.args.ftp_nat if self.args.ftp_dbg: config_logging(level=logging.DEBUG) @@ -363,11 +391,11 @@ class Ftpd(object): for h, lp in hs: FTPServer((ip, int(lp)), h, ioloop) - t = threading.Thread(target=ioloop.loop) - t.daemon = True - t.start() + thr = threading.Thread(target=ioloop.loop) + thr.daemon = True + thr.start() -def join(p1, p2): +def join(p1: str, p2: str) -> str: w = os.path.join(p1, p2.replace("\\", "/")) return os.path.normpath(w).replace("\\", "/") diff --git a/copyparty/httpcli.py 
b/copyparty/httpcli.py index 5368bf62..cfc2bbf7 100644 --- a/copyparty/httpcli.py +++ b/copyparty/httpcli.py @@ -1,18 +1,23 @@ # coding: utf-8 from __future__ import print_function, unicode_literals -import os -import stat -import gzip -import time -import copy -import json +import argparse # typechk import base64 -import string +import calendar +import copy +import gzip +import json +import os +import re import socket +import stat +import string +import threading # typechk +import time from datetime import datetime from operator import itemgetter -import calendar + +import jinja2 # typechk try: import lzma @@ -24,13 +29,62 @@ try: except: pass -from .__init__ import E, PY2, WINDOWS, ANYWIN, unicode -from .util import * # noqa # pylint: disable=unused-wildcard-import +from .__init__ import ANYWIN, PY2, TYPE_CHECKING, WINDOWS, E, unicode +from .authsrv import VFS # typechk from .bos import bos -from .authsrv import AuthSrv -from .szip import StreamZip from .star import StreamTar +from .sutil import StreamArc # typechk +from .szip import StreamZip +from .util import ( + HTTP_TS_FMT, + HTTPCODE, + META_NOBOTS, + MultipartParser, + Pebkac, + UnrecvEOF, + alltrace, + exclude_dotfiles, + fsenc, + gen_filekey, + gencookie, + get_spd, + guess_mime, + gzip_orig_sz, + hashcopy, + html_bescape, + html_escape, + http_ts, + humansize, + min_ex, + quotep, + read_header, + read_socket, + read_socket_chunked, + read_socket_unbounded, + relchk, + ren_open, + s3enc, + sanitize_fn, + sendfile_kern, + sendfile_py, + undot, + unescape_cookie, + unquote, + unquotep, + vol_san, + vsplit, + yieldfile, +) +try: + from typing import Any, Generator, Match, Optional, Pattern, Type, Union +except: + pass + +if TYPE_CHECKING: + from .httpconn import HttpConn + +_ = (argparse, threading) NO_CACHE = {"Cache-Control": "no-cache"} @@ -40,27 +94,60 @@ class HttpCli(object): Spawned by HttpConn to process one http transaction """ - def __init__(self, conn): + def __init__(self, conn: "HttpConn") -> 
None: + assert conn.sr + self.t0 = time.time() self.conn = conn - self.mutex = conn.mutex - self.s = conn.s # type: socket - self.sr = conn.sr # type: Unrecv + self.mutex = conn.mutex # mypy404 + self.s = conn.s + self.sr = conn.sr self.ip = conn.addr[0] - self.addr = conn.addr # type: tuple[str, int] - self.args = conn.args - self.asrv = conn.asrv # type: AuthSrv - self.ico = conn.ico - self.thumbcli = conn.thumbcli - self.u2fh = conn.u2fh - self.log_func = conn.log_func - self.log_src = conn.log_src - self.tls = hasattr(self.s, "cipher") + self.addr: tuple[str, int] = conn.addr + self.args = conn.args # mypy404 + self.asrv = conn.asrv # mypy404 + self.ico = conn.ico # mypy404 + self.thumbcli = conn.thumbcli # mypy404 + self.u2fh = conn.u2fh # mypy404 + self.log_func = conn.log_func # mypy404 + self.log_src = conn.log_src # mypy404 + self.tls: bool = hasattr(self.s, "cipher") + + # placeholders; assigned by run() + self.keepalive = False + self.is_https = False + self.headers: dict[str, str] = {} + self.mode = " " + self.req = " " + self.http_ver = " " + self.ua = " " + self.is_rclone = False + self.is_ancient = False + self.dip = " " + self.ouparam: dict[str, str] = {} + self.uparam: dict[str, str] = {} + self.cookies: dict[str, str] = {} + self.vpath = " " + self.uname = " " + self.rvol = [" "] + self.wvol = [" "] + self.mvol = [" "] + self.dvol = [" "] + self.gvol = [" "] + self.do_log = True + self.can_read = False + self.can_write = False + self.can_move = False + self.can_delete = False + self.can_get = False + # post + self.parser: Optional[MultipartParser] = None + # end placeholders self.bufsz = 1024 * 32 - self.hint = None + self.hint = "" self.trailing_slash = True - self.out_headerlist = [] + self.out_headerlist: list[tuple[str, str]] = [] self.out_headers = { "Access-Control-Allow-Origin": "*", "Cache-Control": "no-store; max-age=0", @@ -71,44 +158,44 @@ class HttpCli(object): self.out_headers["X-Robots-Tag"] = "noindex, nofollow" self.html_head = h - 
def log(self, msg, c=0): + def log(self, msg: str, c: Union[int, str] = 0) -> None: ptn = self.asrv.re_pwd if ptn and ptn.search(msg): msg = ptn.sub(self.unpwd, msg) self.log_func(self.log_src, msg, c) - def unpwd(self, m): + def unpwd(self, m: Match[str]) -> str: a, b = m.groups() return "=\033[7m {} \033[27m{}".format(self.asrv.iacct[a], b) - def _check_nonfatal(self, ex, post): + def _check_nonfatal(self, ex: Pebkac, post: bool) -> bool: if post: return ex.code < 300 return ex.code < 400 or ex.code in [404, 429] - def _assert_safe_rem(self, rem): + def _assert_safe_rem(self, rem: str) -> None: # sanity check to prevent any disasters if rem.startswith("/") or rem.startswith("../") or "/../" in rem: raise Exception("that was close") - def j2(self, name, **ka): + def j2s(self, name: str, **ka: Any) -> str: tpl = self.conn.hsrv.j2[name] - if ka: - ka["ts"] = self.conn.hsrv.cachebuster() - ka["svcname"] = self.args.doctitle - ka["html_head"] = self.html_head - return tpl.render(**ka) + ka["ts"] = self.conn.hsrv.cachebuster() + ka["svcname"] = self.args.doctitle + ka["html_head"] = self.html_head + return tpl.render(**ka) # type: ignore - return tpl + def j2j(self, name: str) -> jinja2.Template: + return self.conn.hsrv.j2[name] - def run(self): + def run(self) -> bool: """returns true if connection can be reused""" self.keepalive = False self.is_https = False self.headers = {} - self.hint = None + self.hint = "" try: headerlines = read_header(self.sr) if not headerlines: @@ -125,8 +212,8 @@ class HttpCli(object): # normalize incoming headers to lowercase; # outgoing headers however are Correct-Case for header_line in headerlines[1:]: - k, v = header_line.split(":", 1) - self.headers[k.lower()] = v.strip() + k, zs = header_line.split(":", 1) + self.headers[k.lower()] = zs.strip() except: msg = " ]\n#[ ".join(headerlines) raise Pebkac(400, "bad headers:\n#[ " + msg + " ]") @@ -147,23 +234,26 @@ class HttpCli(object): self.is_rclone = self.ua.startswith("rclone/") 
self.is_ancient = self.ua.startswith("Mozilla/4.") - v = self.headers.get("connection", "").lower() - self.keepalive = not v.startswith("close") and self.http_ver != "HTTP/1.0" - self.is_https = (self.headers.get("x-forwarded-proto", "").lower() == "https" or self.tls) + zs = self.headers.get("connection", "").lower() + self.keepalive = not zs.startswith("close") and self.http_ver != "HTTP/1.0" + self.is_https = ( + self.headers.get("x-forwarded-proto", "").lower() == "https" or self.tls + ) n = self.args.rproxy if n: - v = self.headers.get("x-forwarded-for") - if v and self.conn.addr[0] in ["127.0.0.1", "::1"]: + zso = self.headers.get("x-forwarded-for") + if zso and self.conn.addr[0] in ["127.0.0.1", "::1"]: if n > 0: n -= 1 - vs = v.split(",") + zsl = zso.split(",") try: - self.ip = vs[n].strip() + self.ip = zsl[n].strip() except: - self.ip = vs[0].strip() - self.log("rproxy={} oob x-fwd {}".format(self.args.rproxy, v), c=3) + self.ip = zsl[0].strip() + t = "rproxy={} oob x-fwd {}" + self.log(t.format(self.args.rproxy, zso), c=3) self.log_src = self.conn.set_rproxy(self.ip) @@ -175,9 +265,9 @@ class HttpCli(object): keys = list(sorted(self.headers.keys())) for k in keys: - v = self.headers.get(k) - if v is not None: - self.log("[H] {}: \033[33m[{}]".format(k, v), 6) + zso = self.headers.get(k) + if zso is not None: + self.log("[H] {}: \033[33m[{}]".format(k, zso), 6) if "&" in self.req and "?" not in self.req: self.hint = "did you mean '?' 
instead of '&'" @@ -193,20 +283,22 @@ class HttpCli(object): vpath = undot(vpath) for k in arglist.split("&"): if "=" in k: - k, v = k.split("=", 1) - uparam[k.lower()] = v.strip() + k, zs = k.split("=", 1) + uparam[k.lower()] = zs.strip() else: - uparam[k.lower()] = False + uparam[k.lower()] = "" - self.ouparam = {k: v for k, v in uparam.items()} + self.ouparam = {k: zs for k, zs in uparam.items()} - cookies = self.headers.get("cookie") or {} - if cookies: - cookies = [x.split("=", 1) for x in cookies.split(";") if "=" in x] - cookies = {k.strip(): unescape_cookie(v) for k, v in cookies} + zso = self.headers.get("cookie") + if zso: + zsll = [x.split("=", 1) for x in zso.split(";") if "=" in x] + cookies = {k.strip(): unescape_cookie(zs) for k, zs in zsll} for kc, ku in [["cppwd", "pw"], ["b", "b"]]: if kc in cookies and ku not in uparam: uparam[ku] = cookies[kc] + else: + cookies = {} if len(uparam) > 10 or len(cookies) > 50: raise Pebkac(400, "u wot m8") @@ -223,22 +315,22 @@ class HttpCli(object): self.log("invalid relpath [{}]".format(self.vpath)) return self.tx_404() and self.keepalive - pwd = None - ba = self.headers.get("authorization") - if ba: + pwd = "" + zso = self.headers.get("authorization") + if zso: try: - ba = ba.split(" ")[1].encode("ascii") - ba = base64.b64decode(ba).decode("utf-8") + zb = zso.split(" ")[1].encode("ascii") + zs = base64.b64decode(zb).decode("utf-8") # try "pwd", "x:pwd", "pwd:x" - for ba in [ba] + ba.split(":", 1)[::-1]: - if self.asrv.iacct.get(ba): - pwd = ba + for zs in [zs] + zs.split(":", 1)[::-1]: + if self.asrv.iacct.get(zs): + pwd = zs break except: pass pwd = uparam.get("pw") or pwd - self.uname = self.asrv.iacct.get(pwd, "*") + self.uname = self.asrv.iacct.get(pwd) or "*" self.rvol = self.asrv.vfs.aread[self.uname] self.wvol = self.asrv.vfs.awrite[self.uname] self.mvol = self.asrv.vfs.amove[self.uname] @@ -249,12 +341,13 @@ class HttpCli(object): self.out_headerlist.append(("Set-Cookie", self.get_pwd_cookie(pwd)[0])) if 
self.is_rclone: - uparam["raw"] = False - uparam["dots"] = False - uparam["b"] = False - cookies["b"] = False + uparam["raw"] = "" + uparam["dots"] = "" + uparam["b"] = "" + cookies["b"] = "" - self.do_log = not self.conn.lf_url or not self.conn.lf_url.search(self.req) + ptn: Optional[Pattern[str]] = self.conn.lf_url # mypy404 + self.do_log = not ptn or not ptn.search(self.req) x = self.asrv.vfs.can_access(self.vpath, self.uname) self.can_read, self.can_write, self.can_move, self.can_delete, self.can_get = x @@ -272,9 +365,10 @@ class HttpCli(object): raise Pebkac(400, 'invalid HTTP mode "{0}"'.format(self.mode)) except Exception as ex: - pex = ex if not hasattr(ex, "code"): pex = Pebkac(500) + else: + pex = ex # type: ignore try: post = self.mode in ["POST", "PUT"] or "content-length" in self.headers @@ -294,7 +388,7 @@ class HttpCli(object): except Pebkac: return False - def permit_caching(self): + def permit_caching(self) -> None: cache = self.uparam.get("cache") if cache is None: self.out_headers.update(NO_CACHE) @@ -303,11 +397,17 @@ class HttpCli(object): n = "604800" if cache == "i" else cache or "69" self.out_headers["Cache-Control"] = "max-age=" + n - def k304(self): + def k304(self) -> bool: k304 = self.cookies.get("k304") return k304 == "y" or ("; Trident/" in self.ua and not k304) - def send_headers(self, length, status=200, mime=None, headers=None): + def send_headers( + self, + length: Optional[int], + status: int = 200, + mime: Optional[str] = None, + headers: Optional[dict[str, str]] = None, + ) -> None: response = ["{} {} {}".format(self.http_ver, status, HTTPCODE[status])] if length is not None: @@ -327,10 +427,11 @@ class HttpCli(object): if not mime: mime = self.out_headers.get("Content-Type", "text/html; charset=utf-8") + assert mime self.out_headers["Content-Type"] = mime - for k, v in list(self.out_headers.items()) + self.out_headerlist: - response.append("{}: {}".format(k, v)) + for k, zs in list(self.out_headers.items()) + 
self.out_headerlist: + response.append("{}: {}".format(k, zs)) try: # best practice to separate headers and body into different packets @@ -338,11 +439,19 @@ class HttpCli(object): except: raise Pebkac(400, "client d/c while replying headers") - def reply(self, body, status=200, mime=None, headers=None, volsan=False): + def reply( + self, + body: bytes, + status: int = 200, + mime: Optional[str] = None, + headers: Optional[dict[str, str]] = None, + volsan: bool = False, + ) -> bytes: # TODO something to reply with user-supplied values safely if volsan: - body = vol_san(self.asrv.vfs.all_vols.values(), body) + vols = list(self.asrv.vfs.all_vols.values()) + body = vol_san(vols, body) self.send_headers(len(body), status, mime, headers) @@ -354,17 +463,19 @@ class HttpCli(object): return body - def loud_reply(self, body, *args, **kwargs): + def loud_reply(self, body: str, *args: Any, **kwargs: Any) -> None: if not kwargs.get("mime"): kwargs["mime"] = "text/plain; charset=utf-8" self.log(body.rstrip()) self.reply(body.encode("utf-8") + b"\r\n", *list(args), **kwargs) - def urlq(self, add, rm): + def urlq(self, add: dict[str, str], rm: list[str]) -> str: """ generates url query based on uparam (b, pw, all others) removing anything in rm, adding pairs in add + + also list faster than set until ~20 items """ if self.is_rclone: @@ -372,28 +483,28 @@ class HttpCli(object): cmap = {"pw": "cppwd"} kv = { - k: v - for k, v in self.uparam.items() - if k not in rm and self.cookies.get(cmap.get(k, k)) != v + k: zs + for k, zs in self.uparam.items() + if k not in rm and self.cookies.get(cmap.get(k, k)) != zs } kv.update(add) if not kv: return "" - r = ["{}={}".format(k, quotep(v)) if v else k for k, v in kv.items()] + r = ["{}={}".format(k, quotep(zs)) if zs else k for k, zs in kv.items()] return "?" 
+ "&".join(r) def redirect( self, - vpath, - suf="", - msg="aight", - flavor="go to", - click=True, - status=200, - use302=False, - ): - html = self.j2( + vpath: str, + suf: str = "", + msg: str = "aight", + flavor: str = "go to", + click: bool = True, + status: int = 200, + use302: bool = False, + ) -> bool: + html = self.j2s( "msg", h2='{} /{}'.format( quotep(vpath) + suf, flavor, html_escape(vpath, crlf=True) + suf @@ -407,7 +518,9 @@ class HttpCli(object): else: self.reply(html, status=status) - def handle_get(self): + return True + + def handle_get(self) -> bool: if self.do_log: logmsg = "{:4} {}".format(self.mode, self.req) @@ -434,13 +547,13 @@ class HttpCli(object): self.log("inaccessible: [{}]".format(self.vpath)) return self.tx_404(True) - self.uparam["h"] = False + self.uparam["h"] = "" if "tree" in self.uparam: return self.tx_tree() if "delete" in self.uparam: - return self.handle_rm() + return self.handle_rm([]) if "move" in self.uparam: return self.handle_mv() @@ -486,7 +599,7 @@ class HttpCli(object): return self.tx_browser() - def handle_options(self): + def handle_options(self) -> bool: if self.do_log: self.log("OPTIONS " + self.req) @@ -501,7 +614,7 @@ class HttpCli(object): ) return True - def handle_put(self): + def handle_put(self) -> bool: self.log("PUT " + self.req) if self.headers.get("expect", "").lower() == "100-continue": @@ -512,7 +625,7 @@ class HttpCli(object): return self.handle_stash() - def handle_post(self): + def handle_post(self) -> bool: self.log("POST " + self.req) if self.headers.get("expect", "").lower() == "100-continue": @@ -549,16 +662,16 @@ class HttpCli(object): reader, _ = self.get_body_reader() for buf in reader: orig = buf.decode("utf-8", "replace") - m = "urlform_raw {} @ {}\n {}\n" - self.log(m.format(len(orig), self.vpath, orig)) + t = "urlform_raw {} @ {}\n {}\n" + self.log(t.format(len(orig), self.vpath, orig)) try: - plain = unquote(buf.replace(b"+", b" ")) - plain = plain.decode("utf-8", "replace") + zb = 
unquote(buf.replace(b"+", b" ")) + plain = zb.decode("utf-8", "replace") if buf.startswith(b"msg="): plain = plain[4:] - m = "urlform_dec {} @ {}\n {}\n" - self.log(m.format(len(plain), self.vpath, plain)) + t = "urlform_dec {} @ {}\n {}\n" + self.log(t.format(len(plain), self.vpath, plain)) except Exception as ex: self.log(repr(ex)) @@ -569,7 +682,7 @@ class HttpCli(object): raise Pebkac(405, "don't know how to handle POST({})".format(ctype)) - def get_body_reader(self): + def get_body_reader(self) -> tuple[Generator[bytes, None, None], int]: if "chunked" in self.headers.get("transfer-encoding", "").lower(): return read_socket_chunked(self.sr), -1 @@ -580,7 +693,8 @@ class HttpCli(object): else: return read_socket(self.sr, remains), remains - def dump_to_file(self): + def dump_to_file(self) -> tuple[int, str, str, int, str, str]: + # post_sz, sha_hex, sha_b64, remains, path, url reader, remains = self.get_body_reader() vfs, rem = self.asrv.vfs.get(self.vpath, self.uname, False, True) lim = vfs.get_dbv(rem)[0].lim @@ -595,7 +709,7 @@ class HttpCli(object): bos.makedirs(fdir) - open_ka = {"fun": open} + open_ka: dict[str, Any] = {"fun": open} open_a = ["wb", 512 * 1024] # user-request || config-force @@ -607,7 +721,7 @@ class HttpCli(object): ): fb = {"gz": 9, "xz": 0} # default/fallback level lv = {} # selected level - alg = None # selected algo (gz=preferred) + alg = "" # selected algo (gz=preferred) # user-prefs first if "gz" in self.uparam or "pk" in self.uparam: # def.pk @@ -615,8 +729,8 @@ class HttpCli(object): if "xz" in self.uparam: alg = "xz" if alg: - v = self.uparam.get(alg) - lv[alg] = fb[alg] if v is None else int(v) + zso = self.uparam.get(alg) + lv[alg] = fb[alg] if zso is None else int(zso) if alg not in vfs.flags: alg = "gz" if "gz" in vfs.flags else "xz" @@ -633,7 +747,7 @@ class HttpCli(object): except: pass - lv[alg] = lv.get(alg) or fb.get(alg) + lv[alg] = lv.get(alg) or fb.get(alg) or 0 self.log("compressing with {} level {}".format(alg, 
lv.get(alg))) if alg == "gz": @@ -656,8 +770,8 @@ class HttpCli(object): if not fn: fn = "put" + suffix - with ren_open(fn, *open_a, **params) as f: - f, fn = f["orz"] + with ren_open(fn, *open_a, **params) as zfw: + f, fn = zfw["orz"] path = os.path.join(fdir, fn) post_sz, sha_hex, sha_b64 = hashcopy(reader, f, self.args.s_wr_slp) @@ -674,8 +788,7 @@ class HttpCli(object): return post_sz, sha_hex, sha_b64, remains, path, "" vfs, rem = vfs.get_dbv(rem) - self.conn.hsrv.broker.put( - False, + self.conn.hsrv.broker.say( "up2k.hash_file", vfs.realpath, vfs.flags, @@ -705,16 +818,16 @@ class HttpCli(object): return post_sz, sha_hex, sha_b64, remains, path, url - def handle_stash(self): + def handle_stash(self) -> bool: post_sz, sha_hex, sha_b64, remains, path, url = self.dump_to_file() spd = self._spd(post_sz) - m = "{} wrote {}/{} bytes to {} # {}" - self.log(m.format(spd, post_sz, remains, path, sha_b64[:28])) # 21 - m = "{}\n{}\n{}\n{}\n".format(post_sz, sha_b64, sha_hex[:56], url) - self.reply(m.encode("utf-8")) + t = "{} wrote {}/{} bytes to {} # {}" + self.log(t.format(spd, post_sz, remains, path, sha_b64[:28])) # 21 + t = "{}\n{}\n{}\n{}\n".format(post_sz, sha_b64, sha_hex[:56], url) + self.reply(t.encode("utf-8")) return True - def _spd(self, nbytes, add=True): + def _spd(self, nbytes: int, add: bool = True) -> str: if add: self.conn.nbyte += nbytes @@ -722,7 +835,7 @@ class HttpCli(object): spd2 = get_spd(self.conn.nbyte, self.conn.t0) return "{} {} n{}".format(spd1, spd2, self.conn.nreq) - def handle_post_multipart(self): + def handle_post_multipart(self) -> bool: self.parser = MultipartParser(self.log, self.sr, self.headers) self.parser.parse() @@ -749,7 +862,8 @@ class HttpCli(object): raise Pebkac(422, 'invalid action "{}"'.format(act)) - def handle_zip_post(self): + def handle_zip_post(self) -> bool: + assert self.parser for k in ["zip", "tar"]: v = self.uparam.get(k) if v is not None: @@ -759,17 +873,17 @@ class HttpCli(object): raise Pebkac(422, "need 
zip or tar keyword") vn, rem = self.asrv.vfs.get(self.vpath, self.uname, True, False) - items = self.parser.require("files", 1024 * 1024) - if not items: + zs = self.parser.require("files", 1024 * 1024) + if not zs: raise Pebkac(422, "need files list") - items = items.replace("\r", "").split("\n") + items = zs.replace("\r", "").split("\n") items = [unquotep(x) for x in items if items] self.parser.drop() return self.tx_zip(k, v, vn, rem, items, self.args.ed) - def handle_post_json(self): + def handle_post_json(self) -> bool: try: remains = int(self.headers["content-length"]) except: @@ -836,14 +950,14 @@ class HttpCli(object): except: raise Pebkac(500, min_ex()) - x = self.conn.hsrv.broker.put(True, "up2k.handle_json", body) + x = self.conn.hsrv.broker.ask("up2k.handle_json", body) ret = x.get() ret = json.dumps(ret) self.log(ret) self.reply(ret.encode("utf-8"), mime="application/json") return True - def handle_search(self, body): + def handle_search(self, body: dict[str, Any]) -> bool: idx = self.conn.get_u2idx() if not hasattr(idx, "p_end"): raise Pebkac(500, "sqlite3 is not available on the server; cannot search") @@ -857,15 +971,15 @@ class HttpCli(object): continue seen[vfs] = True - vols.append([vfs.vpath, vfs.realpath, vfs.flags]) + vols.append((vfs.vpath, vfs.realpath, vfs.flags)) t0 = time.time() if idx.p_end: penalty = 0.7 t_idle = t0 - idx.p_end if idx.p_dur > 0.7 and t_idle < penalty: - m = "rate-limit {:.1f} sec, cost {:.2f}, idle {:.2f}" - raise Pebkac(429, m.format(penalty, idx.p_dur, t_idle)) + t = "rate-limit {:.1f} sec, cost {:.2f}, idle {:.2f}" + raise Pebkac(429, t.format(penalty, idx.p_dur, t_idle)) if "srch" in body: # search by up2k hashlist @@ -873,8 +987,8 @@ class HttpCli(object): vbody["hash"] = len(vbody["hash"]) self.log("qj: " + repr(vbody)) hits = idx.fsearch(vols, body) - msg = repr(hits) - taglist = {} + msg: Any = repr(hits) + taglist: list[str] = [] else: # search by query params q = body["q"] @@ -900,7 +1014,7 @@ class 
HttpCli(object): self.reply(r, mime="application/json") return True - def handle_post_binary(self): + def handle_post_binary(self) -> bool: try: remains = int(self.headers["content-length"]) except: @@ -915,7 +1029,7 @@ class HttpCli(object): vfs, _ = self.asrv.vfs.get(self.vpath, self.uname, False, True) ptop = (vfs.dbv or vfs).realpath - x = self.conn.hsrv.broker.put(True, "up2k.handle_chunk", ptop, wark, chash) + x = self.conn.hsrv.broker.ask("up2k.handle_chunk", ptop, wark, chash) response = x.get() chunksize, cstart, path, lastmod = response @@ -946,8 +1060,8 @@ class HttpCli(object): post_sz, _, sha_b64 = hashcopy(reader, f, self.args.s_wr_slp) if sha_b64 != chash: - m = "your chunk got corrupted somehow (received {} bytes); expected vs received hash:\n{}\n{}" - raise Pebkac(400, m.format(post_sz, chash, sha_b64)) + t = "your chunk got corrupted somehow (received {} bytes); expected vs received hash:\n{}\n{}" + raise Pebkac(400, t.format(post_sz, chash, sha_b64)) if len(cstart) > 1 and path != os.devnull: self.log( @@ -974,15 +1088,15 @@ class HttpCli(object): with self.mutex: self.u2fh.put(path, f) finally: - x = self.conn.hsrv.broker.put(True, "up2k.release_chunk", ptop, wark, chash) + x = self.conn.hsrv.broker.ask("up2k.release_chunk", ptop, wark, chash) x.get() # block client until released - x = self.conn.hsrv.broker.put(True, "up2k.confirm_chunk", ptop, wark, chash) - x = x.get() + x = self.conn.hsrv.broker.ask("up2k.confirm_chunk", ptop, wark, chash) + ztis = x.get() try: - num_left, fin_path = x + num_left, fin_path = ztis except: - self.loud_reply(x, status=500) + self.loud_reply(ztis, status=500) return False if not num_left and fpool: @@ -991,7 +1105,7 @@ class HttpCli(object): # windows cant rename open files if ANYWIN and path != fin_path and not self.args.nw: - self.conn.hsrv.broker.put(True, "up2k.finish_upload", ptop, wark).get() + self.conn.hsrv.broker.ask("up2k.finish_upload", ptop, wark).get() if not ANYWIN and not num_left: times = 
(int(time.time()), int(lastmod)) @@ -1006,7 +1120,8 @@ class HttpCli(object): self.reply(b"thank") return True - def handle_login(self): + def handle_login(self) -> bool: + assert self.parser pwd = self.parser.require("cppwd", 64) self.parser.drop() @@ -1021,11 +1136,11 @@ class HttpCli(object): dst += quotep(self.vpath) ck, msg = self.get_pwd_cookie(pwd) - html = self.j2("msg", h1=msg, h2='ack', redir=dst) + html = self.j2s("msg", h1=msg, h2='ack', redir=dst) self.reply(html.encode("utf-8"), headers={"Set-Cookie": ck}) return True - def get_pwd_cookie(self, pwd): + def get_pwd_cookie(self, pwd: str) -> tuple[str, str]: if pwd in self.asrv.iacct: msg = "login ok" dur = int(60 * 60 * self.args.logout) @@ -1038,9 +1153,10 @@ class HttpCli(object): if self.is_ancient: r = r.rsplit(" ", 1)[0] - return [r, msg] + return r, msg - def handle_mkdir(self): + def handle_mkdir(self) -> bool: + assert self.parser new_dir = self.parser.require("name", 512) self.parser.drop() @@ -1075,7 +1191,8 @@ class HttpCli(object): self.redirect(vpath) return True - def handle_new_md(self): + def handle_new_md(self) -> bool: + assert self.parser new_file = self.parser.require("name", 512) self.parser.drop() @@ -1102,7 +1219,8 @@ class HttpCli(object): self.redirect(vpath, "?edit") return True - def handle_plain_upload(self): + def handle_plain_upload(self) -> bool: + assert self.parser nullwrite = self.args.nw vfs, rem = self.asrv.vfs.get(self.vpath, self.uname, False, True) self._assert_safe_rem(rem) @@ -1116,17 +1234,21 @@ class HttpCli(object): if not nullwrite: bos.makedirs(fdir_base) - files = [] + files: list[tuple[int, str, str, str, str, str]] = [] + # sz, sha_hex, sha_b64, p_file, fname, abspath errmsg = "" t0 = time.time() try: + assert self.parser.gen for nfile, (p_field, p_file, p_data) in enumerate(self.parser.gen): if not p_file: self.log("discarding incoming file without filename") # fallthrough fdir = fdir_base - fname = sanitize_fn(p_file, "", [".prologue.html", 
".epilogue.html"]) + fname = sanitize_fn( + p_file or "", "", [".prologue.html", ".epilogue.html"] + ) if p_file and not nullwrite: if not bos.path.isdir(fdir): raise Pebkac(404, "that folder does not exist") @@ -1143,8 +1265,8 @@ class HttpCli(object): lim.chk_nup(self.ip) try: - with ren_open(fname, "wb", 512 * 1024, **open_args) as f: - f, fname = f["orz"] + with ren_open(fname, "wb", 512 * 1024, **open_args) as zfw: + f, fname = zfw["orz"] abspath = os.path.join(fdir, fname) self.log("writing to {}".format(abspath)) sz, sha_hex, sha_b64 = hashcopy(p_data, f, self.args.s_wr_slp) @@ -1158,12 +1280,14 @@ class HttpCli(object): lim.chk_sz(sz) except: bos.unlink(abspath) + fname = os.devnull raise - files.append([sz, sha_hex, sha_b64, p_file, fname, abspath]) + files.append( + (sz, sha_hex, sha_b64, p_file or "(discarded)", fname, abspath) + ) dbv, vrem = vfs.get_dbv(rem) - self.conn.hsrv.broker.put( - False, + self.conn.hsrv.broker.say( "up2k.hash_file", dbv.realpath, dbv.flags, @@ -1192,7 +1316,7 @@ class HttpCli(object): except Pebkac as ex: errmsg = vol_san( - self.asrv.vfs.all_vols.values(), unicode(ex).encode("utf-8") + list(self.asrv.vfs.all_vols.values()), unicode(ex).encode("utf-8") ).decode("utf-8") td = max(0.1, time.time() - t0) @@ -1205,7 +1329,12 @@ class HttpCli(object): status = "ERROR" msg = "{} // {} bytes // {:.3f} MiB/s\n".format(status, sz_total, spd) - jmsg = {"status": status, "sz": sz_total, "mbps": round(spd, 3), "files": []} + jmsg: dict[str, Any] = { + "status": status, + "sz": sz_total, + "mbps": round(spd, 3), + "files": [], + } if errmsg: msg += errmsg + "\n" @@ -1260,17 +1389,17 @@ class HttpCli(object): ft = "{}\n{}\n{}\n".format(ft, msg.rstrip(), errmsg) f.write(ft.encode("utf-8")) - status = 400 if errmsg else 200 + sc = 400 if errmsg else 200 if "j" in self.uparam: jtxt = json.dumps(jmsg, indent=2, sort_keys=True).encode("utf-8", "replace") - self.reply(jtxt, mime="application/json", status=status) + self.reply(jtxt, 
mime="application/json", status=sc) else: self.redirect( self.vpath, msg=msg, flavor="return to", click=False, - status=status, + status=sc, ) if errmsg: @@ -1279,7 +1408,8 @@ class HttpCli(object): self.parser.drop() return True - def handle_text_upload(self): + def handle_text_upload(self) -> bool: + assert self.parser try: cli_lastmod3 = int(self.parser.require("lastmod", 16)) except: @@ -1314,7 +1444,8 @@ class HttpCli(object): self.reply(response.encode("utf-8")) return True - srv_lastmod = srv_lastmod3 = -1 + srv_lastmod = -1.0 + srv_lastmod3 = -1 try: st = bos.stat(fp) srv_lastmod = st.st_mtime @@ -1329,7 +1460,7 @@ class HttpCli(object): if not same_lastmod: # some filesystems/transports limit precision to 1sec, hopefully floored same_lastmod = ( - srv_lastmod == int(srv_lastmod) + srv_lastmod == int(cli_lastmod3 / 1000) and cli_lastmod3 > srv_lastmod3 and cli_lastmod3 - srv_lastmod3 < 1000 ) @@ -1360,6 +1491,7 @@ class HttpCli(object): pass bos.rename(fp, os.path.join(mdir, ".hist", mfile2)) + assert self.parser.gen p_field, _, p_data = next(self.parser.gen) if p_field != "body": raise Pebkac(400, "expected body, got {}".format(p_field)) @@ -1388,7 +1520,7 @@ class HttpCli(object): self.reply(response.encode("utf-8")) return True - def _chk_lastmod(self, file_ts): + def _chk_lastmod(self, file_ts: int) -> tuple[str, bool]: file_lastmod = http_ts(file_ts) cli_lastmod = self.headers.get("if-modified-since") if cli_lastmod: @@ -1408,7 +1540,7 @@ class HttpCli(object): return file_lastmod, True - def tx_file(self, req_path): + def tx_file(self, req_path: str) -> bool: status = 200 logmsg = "{:4} {} ".format("", self.req) logtail = "" @@ -1417,7 +1549,7 @@ class HttpCli(object): # if request is for foo.js, check if we have foo.js.{gz,br} file_ts = 0 - editions = {} + editions: dict[str, tuple[str, int]] = {} for ext in ["", ".gz", ".br"]: try: fs_path = req_path + ext @@ -1425,8 +1557,8 @@ class HttpCli(object): if stat.S_ISDIR(st.st_mode): continue - file_ts = 
max(file_ts, st.st_mtime) - editions[ext or "plain"] = [fs_path, st.st_size] + file_ts = max(file_ts, int(st.st_mtime)) + editions[ext or "plain"] = (fs_path, st.st_size) except: pass if not self.vpath.startswith(".cpr/"): @@ -1526,8 +1658,8 @@ class HttpCli(object): use_sendfile = False if decompress: - open_func = gzip.open - open_args = [fsenc(fs_path), "rb"] + open_func: Any = gzip.open + open_args: list[Any] = [fsenc(fs_path), "rb"] # Content-Length := original file size upper = gzip_orig_sz(fs_path) else: @@ -1551,7 +1683,7 @@ class HttpCli(object): if "txt" in self.uparam: mime = "text/plain; charset={}".format(self.uparam["txt"] or "utf-8") elif "mime" in self.uparam: - mime = self.uparam.get("mime") + mime = str(self.uparam.get("mime")) else: mime = guess_mime(req_path) @@ -1583,19 +1715,18 @@ class HttpCli(object): return ret - def tx_zip(self, fmt, uarg, vn, rem, items, dots): + def tx_zip( + self, fmt: str, uarg: str, vn: VFS, rem: str, items: list[str], dots: bool + ) -> bool: if self.args.no_zip: raise Pebkac(400, "not enabled") logmsg = "{:4} {} ".format("", self.req) self.keepalive = False - if not uarg: - uarg = "" - if fmt == "tar": mime = "application/x-tar" - packer = StreamTar + packer: Type[StreamArc] = StreamTar else: mime = "application/zip" packer = StreamZip @@ -1609,24 +1740,25 @@ class HttpCli(object): safe = (string.ascii_letters + string.digits).replace("%", "") afn = "".join([x if x in safe.replace('"', "") else "_" for x in fn]) bascii = unicode(safe).encode("utf-8") - ufn = fn.encode("utf-8", "xmlcharrefreplace") - if PY2: - ufn = [unicode(x) if x in bascii else "%{:02x}".format(ord(x)) for x in ufn] - else: - ufn = [ + zb = fn.encode("utf-8", "xmlcharrefreplace") + if not PY2: + zbl = [ chr(x).encode("utf-8") if x in bascii else "%{:02x}".format(x).encode("ascii") - for x in ufn + for x in zb ] - ufn = b"".join(ufn).decode("ascii") + else: + zbl = [unicode(x) if x in bascii else "%{:02x}".format(ord(x)) for x in zb] + + ufn = 
b"".join(zbl).decode("ascii") cdis = "attachment; filename=\"{}.{}\"; filename*=UTF-8''{}.{}" cdis = cdis.format(afn, fmt, ufn, fmt) self.log(cdis) self.send_headers(None, mime=mime, headers={"Content-Disposition": cdis}) - fgen = vn.zipgen(rem, items, self.uname, dots, not self.args.no_scandir) + fgen = vn.zipgen(rem, set(items), self.uname, dots, not self.args.no_scandir) # for f in fgen: print(repr({k: f[k] for k in ["vp", "ap"]})) bgen = packer(self.log, fgen, utf8="utf" in uarg, pre_crc="crc" in uarg) bsent = 0 @@ -1645,7 +1777,7 @@ class HttpCli(object): self.log("{}, {}".format(logmsg, spd)) return True - def tx_ico(self, ext, exact=False): + def tx_ico(self, ext: str, exact: bool = False) -> bool: self.permit_caching() if ext.endswith("/"): ext = "folder" @@ -1674,7 +1806,7 @@ class HttpCli(object): self.reply(ico, mime=mime, headers={"Last-Modified": lm}) return True - def tx_md(self, fs_path): + def tx_md(self, fs_path: str) -> bool: logmsg = "{:4} {} ".format("", self.req) if not self.can_write: @@ -1683,7 +1815,7 @@ class HttpCli(object): tpl = "mde" if "edit2" in self.uparam else "md" html_path = os.path.join(E.mod, "web", "{}.html".format(tpl)) - template = self.j2(tpl) + template = self.j2j(tpl) st = bos.stat(fs_path) ts_md = st.st_mtime @@ -1694,7 +1826,7 @@ class HttpCli(object): sz_md = 0 for buf in yieldfile(fs_path): sz_md += len(buf) - for c, v in [[b"&", 4], [b"<", 3], [b">", 3]]: + for c, v in [(b"&", 4), (b"<", 3), (b">", 3)]: sz_md += (len(buf) - len(buf.replace(c, b""))) * v file_ts = max(ts_md, ts_html, E.t0) @@ -1720,8 +1852,8 @@ class HttpCli(object): "md": boundary, "arg_base": arg_base, } - html = template.render(**targs).encode("utf-8", "replace") - html = html.split(boundary.encode("utf-8")) + zs = template.render(**targs).encode("utf-8", "replace") + html = zs.split(boundary.encode("utf-8")) if len(html) != 2: raise Exception("boundary appears in " + html_path) @@ -1750,7 +1882,7 @@ class HttpCli(object): return True - def 
tx_mounts(self): + def tx_mounts(self) -> bool: suf = self.urlq({}, ["h"]) avol = [x for x in self.wvol if x in self.rvol] rvol, wvol, avol = [ @@ -1759,7 +1891,7 @@ class HttpCli(object): ] if avol and not self.args.no_rescan: - x = self.conn.hsrv.broker.put(True, "up2k.get_state") + x = self.conn.hsrv.broker.ask("up2k.get_state") vs = json.loads(x.get()) vstate = {("/" + k).rstrip("/") + "/": v for k, v in vs["volstate"].items()} else: @@ -1787,11 +1919,11 @@ class HttpCli(object): for v in wvol: txt += "\n " + v - txt = txt.encode("utf-8", "replace") + b"\n" - self.reply(txt, mime="text/plain; charset=utf-8") + zb = txt.encode("utf-8", "replace") + b"\n" + self.reply(zb, mime="text/plain; charset=utf-8") return True - html = self.j2( + html = self.j2s( "splash", this=self, qvpath=quotep(self.vpath), @@ -1809,38 +1941,41 @@ class HttpCli(object): self.reply(html.encode("utf-8")) return True - def set_k304(self): + def set_k304(self) -> bool: ck = gencookie("k304", self.uparam["k304"], 60 * 60 * 24 * 365) self.out_headerlist.append(("Set-Cookie", ck)) self.redirect("", "?h#cc") + return True - def set_am_js(self): + def set_am_js(self) -> bool: v = "n" if self.uparam["am_js"] == "n" else "y" ck = gencookie("js", v, 60 * 60 * 24 * 365) self.out_headerlist.append(("Set-Cookie", ck)) self.reply(b"promoted\n") + return True - def set_cfg_reset(self): + def set_cfg_reset(self) -> bool: for k in ("k304", "js", "cppwd"): self.out_headerlist.append(("Set-Cookie", gencookie(k, "x", None))) self.redirect("", "?h#cc") + return True - def tx_404(self, is_403=False): + def tx_404(self, is_403: bool = False) -> bool: rc = 404 if self.args.vague_403: - m = '
or maybe you don\'t have access -- try logging in or go home
' + t = 'or maybe you don\'t have access -- try logging in or go home
' elif is_403: - m = 'you\'ll have to log in or go home
' + t = 'you\'ll have to log in or go home
' rc = 403 else: - m = '{}\n{}".format(time.time(), html_escape(alltrace()))
self.reply(ret.encode("utf-8"))
+ return True
- def tx_tree(self):
+ def tx_tree(self) -> bool:
top = self.uparam["tree"] or ""
dst = self.vpath
if top in [".", ".."]:
@@ -1898,12 +2034,12 @@ class HttpCli(object):
dst = dst[len(top) + 1 :]
ret = self.gen_tree(top, dst)
- ret = json.dumps(ret)
- self.reply(ret.encode("utf-8"), mime="application/json")
+ zs = json.dumps(ret)
+ self.reply(zs.encode("utf-8"), mime="application/json")
return True
- def gen_tree(self, top, target):
- ret = {}
+ def gen_tree(self, top: str, target: str) -> dict[str, Any]:
+ ret: dict[str, Any] = {}
excl = None
if target:
excl, target = (target.split("/", 1) + [""])[:2]
@@ -1921,26 +2057,26 @@ class HttpCli(object):
for v in self.rvol:
d1, d2 = v.rsplit("/", 1) if "/" in v else ["", v]
if d1 == top:
- vfs_virt[d2] = 0
+ vfs_virt[d2] = vn # typechk, value never read
dirs = []
- vfs_ls = [x[0] for x in vfs_ls if stat.S_ISDIR(x[1].st_mode)]
+ dirnames = [x[0] for x in vfs_ls if stat.S_ISDIR(x[1].st_mode)]
if not self.args.ed or "dots" not in self.uparam:
- vfs_ls = exclude_dotfiles(vfs_ls)
+ dirnames = exclude_dotfiles(dirnames)
- for fn in [x for x in vfs_ls if x != excl]:
+ for fn in [x for x in dirnames if x != excl]:
dirs.append(quotep(fn))
- for x in vfs_virt.keys():
+ for x in vfs_virt:
if x != excl:
dirs.append(x)
ret["a"] = dirs
return ret
- def tx_ups(self):
+ def tx_ups(self) -> bool:
if not self.args.unpost:
raise Pebkac(400, "the unpost feature is disabled in server config")
@@ -1952,7 +2088,7 @@ class HttpCli(object):
lm = "ups [{}]".format(filt)
self.log(lm)
- ret = []
+ ret: list[dict[str, Any]] = []
t0 = time.time()
lim = time.time() - self.args.unpost
for vol in self.asrv.vfs.all_vols.values():
@@ -1968,17 +2104,18 @@ class HttpCli(object):
ret.append({"vp": quotep(vp), "sz": sz, "at": at})
if len(ret) > 3000:
- ret.sort(key=lambda x: x["at"], reverse=True)
+ ret.sort(key=lambda x: x["at"], reverse=True) # type: ignore
ret = ret[:2000]
- ret.sort(key=lambda x: x["at"], reverse=True)
+ ret.sort(key=lambda x: x["at"], reverse=True) # type: ignore
ret = ret[:2000]
jtxt = json.dumps(ret, indent=2, sort_keys=True).encode("utf-8", "replace")
self.log("{} #{} {:.2f}sec".format(lm, len(ret), time.time() - t0))
self.reply(jtxt, mime="application/json")
+ return True
- def handle_rm(self, req=None):
+ def handle_rm(self, req: list[str]) -> bool:
if not req and not self.can_delete:
raise Pebkac(403, "not allowed for user " + self.uname)
@@ -1988,10 +2125,11 @@ class HttpCli(object):
if not req:
req = [self.vpath]
- x = self.conn.hsrv.broker.put(True, "up2k.handle_rm", self.uname, self.ip, req)
+ x = self.conn.hsrv.broker.ask("up2k.handle_rm", self.uname, self.ip, req)
self.loud_reply(x.get())
+ return True
- def handle_mv(self):
+ def handle_mv(self) -> bool:
if not self.can_move:
raise Pebkac(403, "not allowed for user " + self.uname)
@@ -2006,12 +2144,11 @@ class HttpCli(object):
# x-www-form-urlencoded (url query part) uses
# either + or %20 for 0x20 so handle both
dst = unquotep(dst.replace("+", " "))
- x = self.conn.hsrv.broker.put(
- True, "up2k.handle_mv", self.uname, self.vpath, dst
- )
+ x = self.conn.hsrv.broker.ask("up2k.handle_mv", self.uname, self.vpath, dst)
self.loud_reply(x.get())
+ return True
- def tx_ls(self, ls):
+ def tx_ls(self, ls: dict[str, Any]) -> bool:
dirs = ls["dirs"]
files = ls["files"]
arg = self.uparam["ls"]
@@ -2055,17 +2192,17 @@ class HttpCli(object):
x["name"] = n
fmt = fmt.format(len(nfmt.format(biggest)))
- ret = [
+ retl = [
"# {}: {}".format(x, ls[x])
for x in ["acct", "perms", "srvinf"]
if x in ls
]
- ret += [
+ retl += [
fmt.format(x["dt"], x["sz"], x["name"])
for y in [dirs, files]
for x in y
]
- ret = "\n".join(ret)
+ ret = "\n".join(retl)
mime = "text/plain; charset=utf-8"
else:
[x.pop(k) for k in ["name", "dt"] for y in [dirs, files] for x in y]
@@ -2076,7 +2213,7 @@ class HttpCli(object):
self.reply(ret.encode("utf-8", "replace") + b"\n", mime=mime)
return True
- def tx_browser(self):
+ def tx_browser(self) -> bool:
vpath = ""
vpnodes = [["", "/"]]
if self.vpath:
@@ -2164,7 +2301,7 @@ class HttpCli(object):
if WINDOWS:
try:
bfree = ctypes.c_ulonglong(0)
- ctypes.windll.kernel32.GetDiskFreeSpaceExW(
+ ctypes.windll.kernel32.GetDiskFreeSpaceExW( # type: ignore
ctypes.c_wchar_p(abspath), None, None, ctypes.pointer(bfree)
)
srv_info.append(humansize(bfree.value) + " free")
@@ -2179,7 +2316,7 @@ class HttpCli(object):
except:
pass
- srv_info = " // ".join(srv_info)
+ srv_infot = " // ".join(srv_info)
perms = []
if self.can_read:
@@ -2223,7 +2360,7 @@ class HttpCli(object):
"dirs": [],
"files": [],
"taglist": [],
- "srvinf": srv_info,
+ "srvinf": srv_infot,
"acct": self.uname,
"idx": ("e2d" in vn.flags),
"perms": perms,
@@ -2251,7 +2388,7 @@ class HttpCli(object):
"logues": logues,
"readme": readme,
"title": html_escape(self.vpath, crlf=True),
- "srv_info": srv_info,
+ "srv_info": srv_infot,
"lang": self.args.lang,
"dtheme": self.args.theme,
"themes": self.args.themes,
@@ -2267,7 +2404,7 @@ class HttpCli(object):
if "zip" in self.uparam or "tar" in self.uparam:
raise Pebkac(403)
- html = self.j2(tpl, **j2a)
+ html = self.j2s(tpl, **j2a)
self.reply(html.encode("utf-8", "replace"))
return True
@@ -2280,11 +2417,12 @@ class HttpCli(object):
rem, self.uname, not self.args.no_scandir, [[True], [False, True]]
)
stats = {k: v for k, v in vfs_ls}
- vfs_ls = [x[0] for x in vfs_ls]
- vfs_ls.extend(vfs_virt.keys())
+ ls_names = [x[0] for x in vfs_ls]
+ ls_names.extend(list(vfs_virt.keys()))
# check for old versions of files,
- hist = {} # [num-backups, most-recent, hist-path]
+ # [num-backups, most-recent, hist-path]
+ hist: dict[str, tuple[int, float, str]] = {}
histdir = os.path.join(fsroot, ".hist")
ptn = re.compile(r"(.*)\.([0-9]+\.[0-9]{3})(\.[^\.]+)$")
try:
@@ -2294,14 +2432,14 @@ class HttpCli(object):
continue
fn = m.group(1) + m.group(3)
- n, ts, _ = hist.get(fn, [0, 0, ""])
- hist[fn] = [n + 1, max(ts, float(m.group(2))), hfn]
+ n, ts, _ = hist.get(fn, (0, 0, ""))
+ hist[fn] = (n + 1, max(ts, float(m.group(2))), hfn)
except:
pass
# show dotfiles if permitted and requested
if not self.args.ed or "dots" not in self.uparam:
- vfs_ls = exclude_dotfiles(vfs_ls)
+ ls_names = exclude_dotfiles(ls_names)
icur = None
if "e2t" in vn.flags:
@@ -2312,7 +2450,7 @@ class HttpCli(object):
dirs = []
files = []
- for fn in vfs_ls:
+ for fn in ls_names:
base = ""
href = fn
if not is_ls and not is_js and not self.trailing_slash and vpath:
@@ -2339,14 +2477,14 @@ class HttpCli(object):
margin = 'zip'.format(quotep(href))
elif fn in hist:
margin = '#{}'.format(
- base, html_escape(hist[fn][2], quote=True, crlf=True), hist[fn][0]
+ base, html_escape(hist[fn][2], quot=True, crlf=True), hist[fn][0]
)
else:
margin = "-"
sz = inf.st_size
- dt = datetime.utcfromtimestamp(inf.st_mtime)
- dt = dt.strftime("%Y-%m-%d %H:%M:%S")
+ zd = datetime.utcfromtimestamp(inf.st_mtime)
+ dt = zd.strftime("%Y-%m-%d %H:%M:%S")
try:
ext = "---" if is_dir else fn.rsplit(".", 1)[1]
@@ -2380,11 +2518,11 @@ class HttpCli(object):
files.append(item)
item["rd"] = rem
- taglist = {}
- for f in files:
- fn = f["name"]
- rd = f["rd"]
- del f["rd"]
+ tagset: set[str] = set()
+ for fe in files:
+ fn = fe["name"]
+ rd = fe["rd"]
+ del fe["rd"]
if not icur:
break
@@ -2403,12 +2541,12 @@ class HttpCli(object):
args = s3enc(idx.mem_cur, rd, fn)
r = icur.execute(q, args).fetchone()
except:
- m = "tag list error, {}/{}\n{}"
- self.log(m.format(rd, fn, min_ex()))
+ t = "tag list error, {}/{}\n{}"
+ self.log(t.format(rd, fn, min_ex()))
break
- tags = {}
- f["tags"] = tags
+ tags: dict[str, Any] = {}
+ fe["tags"] = tags
if not r:
continue
@@ -2417,17 +2555,19 @@ class HttpCli(object):
q = "select k, v from mt where w = ? and +k != 'x'"
try:
for k, v in icur.execute(q, (w,)):
- taglist[k] = True
+ tagset.add(k)
tags[k] = v
except:
- m = "tag read error, {}/{} [{}]:\n{}"
- self.log(m.format(rd, fn, w, min_ex()))
+ t = "tag read error, {}/{} [{}]:\n{}"
+ self.log(t.format(rd, fn, w, min_ex()))
break
if icur:
- taglist = [k for k in vn.flags.get("mte", "").split(",") if k in taglist]
- for f in dirs:
- f["tags"] = {}
+ taglist = [k for k in vn.flags.get("mte", "").split(",") if k in tagset]
+ for fe in dirs:
+ fe["tags"] = {}
+ else:
+ taglist = list(tagset)
if is_ls:
ls_ret["dirs"] = dirs
@@ -2480,6 +2620,6 @@ class HttpCli(object):
if self.args.css_browser:
j2a["css"] = self.args.css_browser
- html = self.j2(tpl, **j2a)
+ html = self.j2s(tpl, **j2a)
self.reply(html.encode("utf-8", "replace"))
return True
diff --git a/copyparty/httpconn.py b/copyparty/httpconn.py
index 85f5f701..067e2d31 100644
--- a/copyparty/httpconn.py
+++ b/copyparty/httpconn.py
@@ -1,25 +1,36 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import re
+import argparse # typechk
import os
-import time
+import re
import socket
+import threading # typechk
+import time
-HAVE_SSL = True
try:
+ HAVE_SSL = True
import ssl
except:
HAVE_SSL = False
-from .__init__ import E
-from .util import Unrecv
+from . import util as Util
+from .__init__ import TYPE_CHECKING, E
+from .authsrv import AuthSrv # typechk
from .httpcli import HttpCli
-from .u2idx import U2idx
+from .ico import Ico
+from .mtag import HAVE_FFMPEG
from .th_cli import ThumbCli
from .th_srv import HAVE_PIL, HAVE_VIPS
-from .mtag import HAVE_FFMPEG
-from .ico import Ico
+from .u2idx import U2idx
+
+try:
+ from typing import Optional, Pattern, Union
+except:
+ pass
+
+if TYPE_CHECKING:
+ from .httpsrv import HttpSrv
class HttpConn(object):
@@ -28,32 +39,37 @@ class HttpConn(object):
creates an HttpCli for each request (Connection: Keep-Alive)
"""
- def __init__(self, sck, addr, hsrv):
+ def __init__(
+ self, sck: socket.socket, addr: tuple[str, int], hsrv: "HttpSrv"
+ ) -> None:
self.s = sck
- self.sr = None # Type: Unrecv
+ self.sr: Optional[Util._Unrecv] = None
self.addr = addr
self.hsrv = hsrv
- self.mutex = hsrv.mutex
- self.args = hsrv.args
- self.asrv = hsrv.asrv
+ self.mutex: threading.Lock = hsrv.mutex # mypy404
+ self.args: argparse.Namespace = hsrv.args # mypy404
+ self.asrv: AuthSrv = hsrv.asrv # mypy404
self.cert_path = hsrv.cert_path
- self.u2fh = hsrv.u2fh
+ self.u2fh: Util.FHC = hsrv.u2fh # mypy404
enth = (HAVE_PIL or HAVE_VIPS or HAVE_FFMPEG) and not self.args.no_thumb
- self.thumbcli = ThumbCli(hsrv) if enth else None
- self.ico = Ico(self.args)
+ self.thumbcli: Optional[ThumbCli] = ThumbCli(hsrv) if enth else None # mypy404
+ self.ico: Ico = Ico(self.args) # mypy404
- self.t0 = time.time()
+ self.t0: float = time.time() # mypy404
self.stopping = False
- self.nreq = 0
- self.nbyte = 0
- self.u2idx = None
- self.log_func = hsrv.log
- self.lf_url = re.compile(self.args.lf_url) if self.args.lf_url else None
+ self.nreq: int = 0 # mypy404
+ self.nbyte: int = 0 # mypy404
+ self.u2idx: Optional[U2idx] = None
+ self.log_func: Util.RootLogger = hsrv.log # mypy404
+ self.log_src: str = "httpconn" # mypy404
+ self.lf_url: Optional[Pattern[str]] = (
+ re.compile(self.args.lf_url) if self.args.lf_url else None
+ ) # mypy404
self.set_rproxy()
- def shutdown(self):
+ def shutdown(self) -> None:
self.stopping = True
try:
self.s.shutdown(socket.SHUT_RDWR)
@@ -61,7 +77,7 @@ class HttpConn(object):
except:
pass
- def set_rproxy(self, ip=None):
+ def set_rproxy(self, ip: Optional[str] = None) -> str:
if ip is None:
color = 36
ip = self.addr[0]
@@ -74,35 +90,35 @@ class HttpConn(object):
self.log_src = "{} \033[{}m{}".format(ip, color, self.addr[1]).ljust(26)
return self.log_src
- def respath(self, res_name):
+ def respath(self, res_name: str) -> str:
return os.path.join(E.mod, "web", res_name)
- def log(self, msg, c=0):
+ def log(self, msg: str, c: Union[int, str] = 0) -> None:
self.log_func(self.log_src, msg, c)
- def get_u2idx(self):
+ def get_u2idx(self) -> U2idx:
if not self.u2idx:
self.u2idx = U2idx(self)
return self.u2idx
- def _detect_https(self):
+ def _detect_https(self) -> bool:
method = None
if self.cert_path:
try:
method = self.s.recv(4, socket.MSG_PEEK)
except socket.timeout:
- return
+ return False
except AttributeError:
# jython does not support msg_peek; forget about https
method = self.s.recv(4)
- self.sr = Unrecv(self.s, self.log)
+ self.sr = Util.Unrecv(self.s, self.log)
self.sr.buf = method
# jython used to do this, they stopped since it's broken
# but reimplementing sendall is out of scope for now
if not getattr(self.s, "sendall", None):
- self.s.sendall = self.s.send
+ self.s.sendall = self.s.send # type: ignore
if len(method) != 4:
err = "need at least 4 bytes in the first packet; got {}".format(
@@ -112,17 +128,18 @@ class HttpConn(object):
self.log(err)
self.s.send(b"HTTP/1.1 400 Bad Request\r\n\r\n" + err.encode("utf-8"))
- return
+ return False
return method not in [None, b"GET ", b"HEAD", b"POST", b"PUT ", b"OPTI"]
- def run(self):
+ def run(self) -> None:
self.sr = None
if self.args.https_only:
is_https = True
elif self.args.http_only or not HAVE_SSL:
is_https = False
else:
+ # raise Exception("asdf")
is_https = self._detect_https()
if is_https:
@@ -151,14 +168,15 @@ class HttpConn(object):
self.s = ctx.wrap_socket(self.s, server_side=True)
msg = [
"\033[1;3{:d}m{}".format(c, s)
- for c, s in zip([0, 5, 0], self.s.cipher())
+ for c, s in zip([0, 5, 0], self.s.cipher()) # type: ignore
]
self.log(" ".join(msg) + "\033[0m")
if self.args.ssl_dbg and hasattr(self.s, "shared_ciphers"):
- overlap = [y[::-1] for y in self.s.shared_ciphers()]
- lines = [str(x) for x in (["TLS cipher overlap:"] + overlap)]
- self.log("\n".join(lines))
+ ciphers = self.s.shared_ciphers()
+ assert ciphers
+ overlap = [str(y[::-1]) for y in ciphers]
+ self.log("TLS cipher overlap:" + "\n".join(overlap))
for k, v in [
["compression", self.s.compression()],
["ALPN proto", self.s.selected_alpn_protocol()],
@@ -183,7 +201,7 @@ class HttpConn(object):
return
if not self.sr:
- self.sr = Unrecv(self.s, self.log)
+ self.sr = Util.Unrecv(self.s, self.log)
while not self.stopping:
self.nreq += 1
diff --git a/copyparty/httpsrv.py b/copyparty/httpsrv.py
index 04d2146c..fd449b5f 100644
--- a/copyparty/httpsrv.py
+++ b/copyparty/httpsrv.py
@@ -1,13 +1,15 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import os
-import sys
-import time
-import math
import base64
+import math
+import os
import socket
+import sys
import threading
+import time
+
+import queue
try:
import jinja2
@@ -26,15 +28,18 @@ except ImportError:
)
sys.exit(1)
-from .__init__ import E, PY2, MACOS
-from .util import FHC, spack, min_ex, start_stackmon, start_log_thrs
+from .__init__ import MACOS, TYPE_CHECKING, E
from .bos import bos
from .httpconn import HttpConn
+from .util import FHC, min_ex, spack, start_log_thrs, start_stackmon
-if PY2:
- import Queue as queue
-else:
- import queue
+if TYPE_CHECKING:
+ from .broker_util import BrokerCli
+
+try:
+ from typing import Any, Optional
+except:
+ pass
class HttpSrv(object):
@@ -43,7 +48,7 @@ class HttpSrv(object):
relying on MpSrv for performance (HttpSrv is just plain threads)
"""
- def __init__(self, broker, nid):
+ def __init__(self, broker: "BrokerCli", nid: Optional[int]) -> None:
self.broker = broker
self.nid = nid
self.args = broker.args
@@ -58,17 +63,19 @@ class HttpSrv(object):
self.tp_nthr = 0 # actual
self.tp_ncli = 0 # fading
- self.tp_time = None # latest worker collect
- self.tp_q = None if self.args.no_htp else queue.LifoQueue()
- self.t_periodic = None
+ self.tp_time = 0.0 # latest worker collect
+ self.tp_q: Optional[queue.LifoQueue[Any]] = (
+ None if self.args.no_htp else queue.LifoQueue()
+ )
+ self.t_periodic: Optional[threading.Thread] = None
self.u2fh = FHC()
- self.srvs = []
+ self.srvs: list[socket.socket] = []
self.ncli = 0 # exact
- self.clients = {} # laggy
+ self.clients: set[HttpConn] = set() # laggy
self.nclimax = 0
- self.cb_ts = 0
- self.cb_v = 0
+ self.cb_ts = 0.0
+ self.cb_v = ""
env = jinja2.Environment()
env.loader = jinja2.FileSystemLoader(os.path.join(E.mod, "web"))
@@ -82,7 +89,7 @@ class HttpSrv(object):
if bos.path.exists(cert_path):
self.cert_path = cert_path
else:
- self.cert_path = None
+ self.cert_path = ""
if self.tp_q:
self.start_threads(4)
@@ -94,19 +101,19 @@ class HttpSrv(object):
if self.args.log_thrs:
start_log_thrs(self.log, self.args.log_thrs, nid)
- self.th_cfg = {} # type: dict[str, Any]
+ self.th_cfg: dict[str, Any] = {}
t = threading.Thread(target=self.post_init)
t.daemon = True
t.start()
- def post_init(self):
+ def post_init(self) -> None:
try:
- x = self.broker.put(True, "thumbsrv.getcfg")
+ x = self.broker.ask("thumbsrv.getcfg")
self.th_cfg = x.get()
except:
pass
- def start_threads(self, n):
+ def start_threads(self, n: int) -> None:
self.tp_nthr += n
if self.args.log_htp:
self.log(self.name, "workers += {} = {}".format(n, self.tp_nthr), 6)
@@ -119,15 +126,16 @@ class HttpSrv(object):
thr.daemon = True
thr.start()
- def stop_threads(self, n):
+ def stop_threads(self, n: int) -> None:
self.tp_nthr -= n
if self.args.log_htp:
self.log(self.name, "workers -= {} = {}".format(n, self.tp_nthr), 6)
+ assert self.tp_q
for _ in range(n):
self.tp_q.put(None)
- def periodic(self):
+ def periodic(self) -> None:
while True:
time.sleep(2 if self.tp_ncli or self.ncli else 10)
with self.mutex:
@@ -141,7 +149,7 @@ class HttpSrv(object):
self.t_periodic = None
return
- def listen(self, sck, nlisteners):
+ def listen(self, sck: socket.socket, nlisteners: int) -> None:
ip, port = sck.getsockname()
self.srvs.append(sck)
self.nclimax = math.ceil(self.args.nc * 1.0 / nlisteners)
@@ -153,15 +161,15 @@ class HttpSrv(object):
t.daemon = True
t.start()
- def thr_listen(self, srv_sck):
+ def thr_listen(self, srv_sck: socket.socket) -> None:
"""listens on a shared tcp server"""
ip, port = srv_sck.getsockname()
fno = srv_sck.fileno()
msg = "subscribed @ {}:{} f{}".format(ip, port, fno)
self.log(self.name, msg)
- def fun():
- self.broker.put(False, "cb_httpsrv_up")
+ def fun() -> None:
+ self.broker.say("cb_httpsrv_up")
threading.Thread(target=fun).start()
@@ -185,21 +193,21 @@ class HttpSrv(object):
continue
if self.args.log_conn:
- m = "|{}C-acc2 \033[0;36m{} \033[3{}m{}".format(
+ t = "|{}C-acc2 \033[0;36m{} \033[3{}m{}".format(
"-" * 3, ip, port % 8, port
)
- self.log("%s %s" % addr, m, c="1;30")
+ self.log("%s %s" % addr, t, c="1;30")
self.accept(sck, addr)
- def accept(self, sck, addr):
+ def accept(self, sck: socket.socket, addr: tuple[str, int]) -> None:
"""takes an incoming tcp connection and creates a thread to handle it"""
now = time.time()
if now - (self.tp_time or now) > 300:
- m = "httpserver threadpool died: tpt {:.2f}, now {:.2f}, nthr {}, ncli {}"
- self.log(self.name, m.format(self.tp_time, now, self.tp_nthr, self.ncli), 1)
- self.tp_time = None
+ t = "httpserver threadpool died: tpt {:.2f}, now {:.2f}, nthr {}, ncli {}"
+ self.log(self.name, t.format(self.tp_time, now, self.tp_nthr, self.ncli), 1)
+ self.tp_time = 0
self.tp_q = None
with self.mutex:
@@ -209,10 +217,10 @@ class HttpSrv(object):
if self.nid:
name += "-{}".format(self.nid)
- t = threading.Thread(target=self.periodic, name=name)
- self.t_periodic = t
- t.daemon = True
- t.start()
+ thr = threading.Thread(target=self.periodic, name=name)
+ self.t_periodic = thr
+ thr.daemon = True
+ thr.start()
if self.tp_q:
self.tp_time = self.tp_time or now
@@ -224,8 +232,8 @@ class HttpSrv(object):
return
if not self.args.no_htp:
- m = "looks like the httpserver threadpool died; please make an issue on github and tell me the story of how you pulled that off, thanks and dog bless\n"
- self.log(self.name, m, 1)
+ t = "looks like the httpserver threadpool died; please make an issue on github and tell me the story of how you pulled that off, thanks and dog bless\n"
+ self.log(self.name, t, 1)
thr = threading.Thread(
target=self.thr_client,
@@ -235,14 +243,15 @@ class HttpSrv(object):
thr.daemon = True
thr.start()
- def thr_poolw(self):
+ def thr_poolw(self) -> None:
+ assert self.tp_q
while True:
task = self.tp_q.get()
if not task:
break
with self.mutex:
- self.tp_time = None
+ self.tp_time = 0
try:
sck, addr = task
@@ -255,7 +264,7 @@ class HttpSrv(object):
except:
self.log(self.name, "thr_client: " + min_ex(), 3)
- def shutdown(self):
+ def shutdown(self) -> None:
self.stopping = True
for srv in self.srvs:
try:
@@ -263,7 +272,7 @@ class HttpSrv(object):
except:
pass
- clients = list(self.clients.keys())
+ clients = list(self.clients)
for cli in clients:
try:
cli.shutdown()
@@ -279,13 +288,13 @@ class HttpSrv(object):
self.log(self.name, "ok bye")
- def thr_client(self, sck, addr):
+ def thr_client(self, sck: socket.socket, addr: tuple[str, int]) -> None:
"""thread managing one tcp client"""
sck.settimeout(120)
cli = HttpConn(sck, addr, self)
with self.mutex:
- self.clients[cli] = 0
+ self.clients.add(cli)
fno = sck.fileno()
try:
@@ -328,10 +337,10 @@ class HttpSrv(object):
raise
finally:
with self.mutex:
- del self.clients[cli]
+ self.clients.remove(cli)
self.ncli -= 1
- def cachebuster(self):
+ def cachebuster(self) -> str:
if time.time() - self.cb_ts < 1:
return self.cb_v
diff --git a/copyparty/ico.py b/copyparty/ico.py
index 58076c89..f403e4b5 100644
--- a/copyparty/ico.py
+++ b/copyparty/ico.py
@@ -1,28 +1,28 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import hashlib
+import argparse # typechk
import colorsys
+import hashlib
from .__init__ import PY2
class Ico(object):
- def __init__(self, args):
+ def __init__(self, args: argparse.Namespace) -> None:
self.args = args
- def get(self, ext, as_thumb):
+ def get(self, ext: str, as_thumb: bool) -> tuple[str, bytes]:
"""placeholder to make thumbnails not break"""
- h = hashlib.md5(ext.encode("utf-8")).digest()[:2]
+ zb = hashlib.md5(ext.encode("utf-8")).digest()[:2]
if PY2:
- h = [ord(x) for x in h]
+ zb = [ord(x) for x in zb]
- c1 = colorsys.hsv_to_rgb(h[0] / 256.0, 1, 0.3)
- c2 = colorsys.hsv_to_rgb(h[0] / 256.0, 1, 1)
- c = list(c1) + list(c2)
- c = [int(x * 255) for x in c]
- c = "".join(["{:02x}".format(x) for x in c])
+ c1 = colorsys.hsv_to_rgb(zb[0] / 256.0, 1, 0.3)
+ c2 = colorsys.hsv_to_rgb(zb[0] / 256.0, 1, 1)
+ ci = [int(x * 255) for x in list(c1) + list(c2)]
+ c = "".join(["{:02x}".format(x) for x in ci])
h = 30
if not self.args.th_no_crop and as_thumb:
@@ -37,6 +37,6 @@ class Ico(object):
fill="#{}" font-family="monospace" font-size="14px" style="letter-spacing:.5px">{}
"""
- svg = svg.format(h, c[:6], c[6:], ext).encode("utf-8")
+ svg = svg.format(h, c[:6], c[6:], ext)
- return ["image/svg+xml", svg]
+ return "image/svg+xml", svg.encode("utf-8")
diff --git a/copyparty/mtag.py b/copyparty/mtag.py
index 60a7584f..dd1e2ec3 100644
--- a/copyparty/mtag.py
+++ b/copyparty/mtag.py
@@ -1,18 +1,26 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import os
-import sys
+import argparse
import json
+import os
import shutil
import subprocess as sp
+import sys
from .__init__ import PY2, WINDOWS, unicode
-from .util import fsenc, uncyg, runcmd, retchk, REKOBO_LKEY
from .bos import bos
+from .util import REKOBO_LKEY, fsenc, retchk, runcmd, uncyg
+
+try:
+ from typing import Any, Union
+
+ from .util import RootLogger
+except:
+ pass
-def have_ff(cmd):
+def have_ff(cmd: str) -> bool:
if PY2:
print("# checking {}".format(cmd))
cmd = (cmd + " -version").encode("ascii").split(b" ")
@@ -30,7 +38,7 @@ HAVE_FFPROBE = have_ff("ffprobe")
class MParser(object):
- def __init__(self, cmdline):
+ def __init__(self, cmdline: str) -> None:
self.tag, args = cmdline.split("=", 1)
self.tags = self.tag.split(",")
@@ -73,7 +81,9 @@ class MParser(object):
raise Exception()
-def ffprobe(abspath, timeout=10):
+def ffprobe(
+ abspath: str, timeout: int = 10
+) -> tuple[dict[str, tuple[int, Any]], dict[str, list[Any]]]:
cmd = [
b"ffprobe",
b"-hide_banner",
@@ -87,15 +97,15 @@ def ffprobe(abspath, timeout=10):
return parse_ffprobe(so)
-def parse_ffprobe(txt):
+def parse_ffprobe(txt: str) -> tuple[dict[str, tuple[int, Any]], dict[str, list[Any]]]:
"""ffprobe -show_format -show_streams"""
streams = []
fmt = {}
g = {}
for ln in [x.rstrip("\r") for x in txt.split("\n")]:
try:
- k, v = ln.split("=", 1)
- g[k] = v
+ sk, sv = ln.split("=", 1)
+ g[sk] = sv
continue
except:
pass
@@ -109,8 +119,8 @@ def parse_ffprobe(txt):
fmt = g
streams = [fmt] + streams
- ret = {} # processed
- md = {} # raw tags
+ ret: dict[str, Any] = {} # processed
+ md: dict[str, list[Any]] = {} # raw tags
is_audio = fmt.get("format_name") in ["mp3", "ogg", "flac", "wav"]
if fmt.get("filename", "").split(".")[-1].lower() in ["m4a", "aac"]:
@@ -161,43 +171,43 @@ def parse_ffprobe(txt):
kvm = [["duration", ".dur"], ["bit_rate", ".q"]]
for sk, rk in kvm:
- v = strm.get(sk)
- if v is None:
+ v1 = strm.get(sk)
+ if v1 is None:
continue
if rk.startswith("."):
try:
- v = float(v)
+ zf = float(v1)
v2 = ret.get(rk)
- if v2 is None or v > v2:
- ret[rk] = v
+ if v2 is None or zf > v2:
+ ret[rk] = zf
except:
# sqlite doesnt care but the code below does
- if v not in ["N/A"]:
- ret[rk] = v
+ if v1 not in ["N/A"]:
+ ret[rk] = v1
else:
- ret[rk] = v
+ ret[rk] = v1
if ret.get("vc") == "ansi": # shellscript
return {}, {}
for strm in streams:
- for k, v in strm.items():
- if not k.startswith("TAG:"):
+ for sk, sv in strm.items():
+ if not sk.startswith("TAG:"):
continue
- k = k[4:].strip()
- v = v.strip()
- if k and v and k not in md:
- md[k] = [v]
+ sk = sk[4:].strip()
+ sv = sv.strip()
+ if sk and sv and sk not in md:
+ md[sk] = [sv]
- for k in [".q", ".vq", ".aq"]:
- if k in ret:
- ret[k] /= 1000 # bit_rate=320000
+ for sk in [".q", ".vq", ".aq"]:
+ if sk in ret:
+ ret[sk] /= 1000 # bit_rate=320000
- for k in [".q", ".vq", ".aq", ".resw", ".resh"]:
- if k in ret:
- ret[k] = int(ret[k])
+ for sk in [".q", ".vq", ".aq", ".resw", ".resh"]:
+ if sk in ret:
+ ret[sk] = int(ret[sk])
if ".fps" in ret:
fps = ret[".fps"]
@@ -219,13 +229,13 @@ def parse_ffprobe(txt):
if ".resw" in ret and ".resh" in ret:
ret["res"] = "{}x{}".format(ret[".resw"], ret[".resh"])
- ret = {k: [0, v] for k, v in ret.items()}
+ zd = {k: (0, v) for k, v in ret.items()}
- return ret, md
+ return zd, md
class MTag(object):
- def __init__(self, log_func, args):
+ def __init__(self, log_func: RootLogger, args: argparse.Namespace) -> None:
self.log_func = log_func
self.args = args
self.usable = True
@@ -242,7 +252,7 @@ class MTag(object):
if self.backend == "mutagen":
self.get = self.get_mutagen
try:
- import mutagen
+ import mutagen # noqa: F401 # pylint: disable=unused-import,import-outside-toplevel
except:
self.log("could not load Mutagen, trying FFprobe instead", c=3)
self.backend = "ffprobe"
@@ -339,31 +349,33 @@ class MTag(object):
}
# self.get = self.compare
- def log(self, msg, c=0):
+ def log(self, msg: str, c: Union[int, str] = 0) -> None:
self.log_func("mtag", msg, c)
- def normalize_tags(self, ret, md):
- for k, v in dict(md).items():
- if not v:
+ def normalize_tags(
+ self, parser_output: dict[str, tuple[int, Any]], md: dict[str, list[Any]]
+ ) -> dict[str, Union[str, float]]:
+ for sk, tv in dict(md).items():
+ if not tv:
continue
- k = k.lower().split("::")[0].strip()
- mk = self.rmap.get(k)
- if not mk:
+ sk = sk.lower().split("::")[0].strip()
+ key_mapping = self.rmap.get(sk)
+ if not key_mapping:
continue
- pref, mk = mk
- if mk not in ret or ret[mk][0] > pref:
- ret[mk] = [pref, v[0]]
+ priority, alias = key_mapping
+ if alias not in parser_output or parser_output[alias][0] > priority:
+ parser_output[alias] = (priority, tv[0])
- # take first value
- ret = {k: unicode(v[1]).strip() for k, v in ret.items()}
+ # take first value (lowest priority / most preferred)
+ ret = {sk: unicode(tv[1]).strip() for sk, tv in parser_output.items()}
# track 3/7 => track 3
- for k, v in ret.items():
- if k[0] == ".":
- v = v.split("/")[0].strip().lstrip("0")
- ret[k] = v or 0
+ for sk, tv in ret.items():
+ if sk[0] == ".":
+ sv = str(tv).split("/")[0].strip().lstrip("0")
+ ret[sk] = sv or 0
# normalize key notation to rkeobo
okey = ret.get("key")
@@ -373,7 +385,7 @@ class MTag(object):
return ret
- def compare(self, abspath):
+ def compare(self, abspath: str) -> dict[str, Union[str, float]]:
if abspath.endswith(".au"):
return {}
@@ -411,7 +423,7 @@ class MTag(object):
return r1
- def get_mutagen(self, abspath):
+ def get_mutagen(self, abspath: str) -> dict[str, Union[str, float]]:
if not bos.path.isfile(abspath):
return {}
@@ -425,7 +437,7 @@ class MTag(object):
return self.get_ffprobe(abspath) if self.can_ffprobe else {}
sz = bos.path.getsize(abspath)
- ret = {".q": [0, int((sz / md.info.length) / 128)]}
+ ret = {".q": (0, int((sz / md.info.length) / 128))}
for attr, k, norm in [
["codec", "ac", unicode],
@@ -456,24 +468,24 @@ class MTag(object):
if k == "ac" and v.startswith("mp4a.40."):
v = "aac"
- ret[k] = [0, norm(v)]
+ ret[k] = (0, norm(v))
return self.normalize_tags(ret, md)
- def get_ffprobe(self, abspath):
+ def get_ffprobe(self, abspath: str) -> dict[str, Union[str, float]]:
if not bos.path.isfile(abspath):
return {}
ret, md = ffprobe(abspath)
return self.normalize_tags(ret, md)
- def get_bin(self, parsers, abspath):
+ def get_bin(self, parsers: dict[str, MParser], abspath: str) -> dict[str, Any]:
if not bos.path.isfile(abspath):
return {}
pypath = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
- pypath = [str(pypath)] + [str(x) for x in sys.path if x]
- pypath = str(os.pathsep.join(pypath))
+ zsl = [str(pypath)] + [str(x) for x in sys.path if x]
+ pypath = str(os.pathsep.join(zsl))
env = os.environ.copy()
env["PYTHONPATH"] = pypath
@@ -491,9 +503,9 @@ class MTag(object):
else:
cmd = ["nice"] + cmd
- cmd = [fsenc(x) for x in cmd]
- rc, v, err = runcmd(cmd, **args)
- retchk(rc, cmd, err, self.log, 5, self.args.mtag_v)
+ bcmd = [fsenc(x) for x in cmd]
+ rc, v, err = runcmd(bcmd, **args) # type: ignore
+ retchk(rc, bcmd, err, self.log, 5, self.args.mtag_v)
v = v.strip()
if not v:
continue
@@ -501,10 +513,10 @@ class MTag(object):
if "," not in tagname:
ret[tagname] = v
else:
- v = json.loads(v)
+ zj = json.loads(v)
for tag in tagname.split(","):
- if tag and tag in v:
- ret[tag] = v[tag]
+ if tag and tag in zj:
+ ret[tag] = zj[tag]
except:
pass
diff --git a/copyparty/star.py b/copyparty/star.py
index 804ad724..21c2703c 100644
--- a/copyparty/star.py
+++ b/copyparty/star.py
@@ -4,20 +4,29 @@ from __future__ import print_function, unicode_literals
import tarfile
import threading
-from .sutil import errdesc
-from .util import Queue, fsenc, min_ex
+from queue import Queue
+
from .bos import bos
+from .sutil import StreamArc, errdesc
+from .util import fsenc, min_ex
+
+try:
+ from typing import Any, Generator, Optional
+
+ from .util import NamedLogger
+except:
+ pass
-class QFile(object):
+class QFile(object): # inherit io.StringIO for painful typing
"""file-like object which buffers writes into a queue"""
- def __init__(self):
- self.q = Queue(64)
- self.bq = []
+ def __init__(self) -> None:
+ self.q: Queue[Optional[bytes]] = Queue(64)
+ self.bq: list[bytes] = []
self.nq = 0
- def write(self, buf):
+ def write(self, buf: Optional[bytes]) -> None:
if buf is None or self.nq >= 240 * 1024:
self.q.put(b"".join(self.bq))
self.bq = []
@@ -30,27 +39,32 @@ class QFile(object):
self.nq += len(buf)
-class StreamTar(object):
+class StreamTar(StreamArc):
"""construct in-memory tar file from the given path"""
- def __init__(self, log, fgen, **kwargs):
+ def __init__(
+ self,
+ log: NamedLogger,
+ fgen: Generator[dict[str, Any], None, None],
+ **kwargs: Any
+ ):
+ super(StreamTar, self).__init__(log, fgen)
+
self.ci = 0
self.co = 0
self.qfile = QFile()
- self.log = log
- self.fgen = fgen
- self.errf = None
+ self.errf: dict[str, Any] = {}
# python 3.8 changed to PAX_FORMAT as default,
# waste of space and don't care about the new features
fmt = tarfile.GNU_FORMAT
- self.tar = tarfile.open(fileobj=self.qfile, mode="w|", format=fmt)
+ self.tar = tarfile.open(fileobj=self.qfile, mode="w|", format=fmt) # type: ignore
w = threading.Thread(target=self._gen, name="star-gen")
w.daemon = True
w.start()
- def gen(self):
+ def gen(self) -> Generator[Optional[bytes], None, None]:
while True:
buf = self.qfile.q.get()
if not buf:
@@ -63,7 +77,7 @@ class StreamTar(object):
if self.errf:
bos.unlink(self.errf["ap"])
- def ser(self, f):
+ def ser(self, f: dict[str, Any]) -> None:
name = f["vp"]
src = f["ap"]
fsi = f["st"]
@@ -76,21 +90,21 @@ class StreamTar(object):
inf.gid = 0
self.ci += inf.size
- with open(fsenc(src), "rb", 512 * 1024) as f:
- self.tar.addfile(inf, f)
+ with open(fsenc(src), "rb", 512 * 1024) as fo:
+ self.tar.addfile(inf, fo)
- def _gen(self):
+ def _gen(self) -> None:
errors = []
for f in self.fgen:
if "err" in f:
- errors.append([f["vp"], f["err"]])
+ errors.append((f["vp"], f["err"]))
continue
try:
self.ser(f)
except:
ex = min_ex(5, True).replace("\n", "\n-- ")
- errors.append([f["vp"], ex])
+ errors.append((f["vp"], ex))
if errors:
self.errf, txt = errdesc(errors)
diff --git a/copyparty/stolen/surrogateescape.py b/copyparty/stolen/surrogateescape.py
index 4b06ed28..b1ff8886 100644
--- a/copyparty/stolen/surrogateescape.py
+++ b/copyparty/stolen/surrogateescape.py
@@ -12,23 +12,28 @@ Original source: misc/python/surrogateescape.py in https://bitbucket.org/haypo/m
# This code is released under the Python license and the BSD 2-clause license
-import platform
import codecs
+import platform
import sys
PY3 = sys.version_info[0] > 2
WINDOWS = platform.system() == "Windows"
FS_ERRORS = "surrogateescape"
+try:
+ from typing import Any
+except:
+ pass
-def u(text):
+
+def u(text: Any) -> str:
if PY3:
return text
else:
return text.decode("unicode_escape")
-def b(data):
+def b(data: Any) -> bytes:
if PY3:
return data.encode("latin1")
else:
@@ -43,7 +48,7 @@ else:
bytes_chr = chr
-def surrogateescape_handler(exc):
+def surrogateescape_handler(exc: Any) -> tuple[str, int]:
"""
Pure Python implementation of the PEP 383: the "surrogateescape" error
handler of Python 3. Undecodable bytes will be replaced by a Unicode
@@ -74,7 +79,7 @@ class NotASurrogateError(Exception):
pass
-def replace_surrogate_encode(mystring):
+def replace_surrogate_encode(mystring: str) -> str:
"""
Returns a (unicode) string, not the more logical bytes, because the codecs
register_error functionality expects this.
@@ -100,7 +105,7 @@ def replace_surrogate_encode(mystring):
return str().join(decoded)
-def replace_surrogate_decode(mybytes):
+def replace_surrogate_decode(mybytes: bytes) -> str:
"""
Returns a (unicode) string
"""
@@ -121,7 +126,7 @@ def replace_surrogate_decode(mybytes):
return str().join(decoded)
-def encodefilename(fn):
+def encodefilename(fn: str) -> bytes:
if FS_ENCODING == "ascii":
# ASCII encoder of Python 2 expects that the error handler returns a
# Unicode string encodable to ASCII, whereas our surrogateescape error
@@ -161,7 +166,7 @@ def encodefilename(fn):
return fn.encode(FS_ENCODING, FS_ERRORS)
-def decodefilename(fn):
+def decodefilename(fn: bytes) -> str:
return fn.decode(FS_ENCODING, FS_ERRORS)
@@ -181,7 +186,7 @@ if WINDOWS and not PY3:
FS_ENCODING = codecs.lookup(FS_ENCODING).name
-def register_surrogateescape():
+def register_surrogateescape() -> None:
"""
Registers the surrogateescape error handler on Python 2 (only)
"""
diff --git a/copyparty/sutil.py b/copyparty/sutil.py
index 22de0066..506e389f 100644
--- a/copyparty/sutil.py
+++ b/copyparty/sutil.py
@@ -6,8 +6,29 @@ from datetime import datetime
from .bos import bos
+try:
+ from typing import Any, Generator, Optional
-def errdesc(errors):
+ from .util import NamedLogger
+except:
+ pass
+
+
+class StreamArc(object):
+ def __init__(
+ self,
+ log: NamedLogger,
+ fgen: Generator[dict[str, Any], None, None],
+ **kwargs: Any
+ ):
+ self.log = log
+ self.fgen = fgen
+
+ def gen(self) -> Generator[Optional[bytes], None, None]:
+ pass
+
+
+def errdesc(errors: list[tuple[str, str]]) -> tuple[dict[str, Any], list[str]]:
report = ["copyparty failed to add the following files to the archive:", ""]
for fn, err in errors:
diff --git a/copyparty/svchub.py b/copyparty/svchub.py
index 44680513..e9b5aadf 100644
--- a/copyparty/svchub.py
+++ b/copyparty/svchub.py
@@ -1,41 +1,51 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
+import argparse
+import calendar
import os
-import sys
-import time
import shlex
-import string
import signal
import socket
+import string
+import sys
import threading
+import time
from datetime import datetime, timedelta
-import calendar
-from .__init__ import E, PY2, WINDOWS, ANYWIN, MACOS, VT100, unicode
-from .util import mp, start_log_thrs, start_stackmon, min_ex, ansi_re
+try:
+ from types import FrameType
+
+ import typing
+ from typing import Optional, Union
+except:
+ pass
+
+from .__init__ import ANYWIN, MACOS, PY2, VT100, WINDOWS, E, unicode
from .authsrv import AuthSrv
-from .tcpsrv import TcpSrv
-from .up2k import Up2k
-from .th_srv import ThumbSrv, HAVE_PIL, HAVE_VIPS, HAVE_WEBP
from .mtag import HAVE_FFMPEG, HAVE_FFPROBE
+from .tcpsrv import TcpSrv
+from .th_srv import HAVE_PIL, HAVE_VIPS, HAVE_WEBP, ThumbSrv
+from .up2k import Up2k
+from .util import ansi_re, min_ex, mp, start_log_thrs, start_stackmon
class SvcHub(object):
"""
Hosts all services which cannot be parallelized due to reliance on monolithic resources.
Creates a Broker which does most of the heavy stuff; hosted services can use this to perform work:
- hub.broker.put(want_reply, destination, args_list).
+ hub.broker.<say|ask>(destination, args_list).
Either BrokerThr (plain threads) or BrokerMP (multiprocessing) is used depending on configuration.
Nothing is returned synchronously; if you want any value returned from the call,
put() can return a queue (if want_reply=True) which has a blocking get() with the response.
"""
- def __init__(self, args, argv, printed):
+ def __init__(self, args: argparse.Namespace, argv: list[str], printed: str) -> None:
self.args = args
self.argv = argv
- self.logf = None
+ self.logf: Optional[typing.TextIO] = None
+ self.logf_base_fn = ""
self.stop_req = False
self.reload_req = False
self.stopping = False
@@ -59,16 +69,16 @@ class SvcHub(object):
if not args.use_fpool and args.j != 1:
args.no_fpool = True
- m = "multithreading enabled with -j {}, so disabling fpool -- this can reduce upload performance on some filesystems"
- self.log("root", m.format(args.j))
+ t = "multithreading enabled with -j {}, so disabling fpool -- this can reduce upload performance on some filesystems"
+ self.log("root", t.format(args.j))
if not args.no_fpool and args.j != 1:
- m = "WARNING: --use-fpool combined with multithreading is untested and can probably cause undefined behavior"
+ t = "WARNING: --use-fpool combined with multithreading is untested and can probably cause undefined behavior"
if ANYWIN:
- m = 'windows cannot do multithreading without --no-fpool, so enabling that -- note that upload performance will suffer if you have microsoft defender "real-time protection" enabled, so you probably want to use -j 1 instead'
+ t = 'windows cannot do multithreading without --no-fpool, so enabling that -- note that upload performance will suffer if you have microsoft defender "real-time protection" enabled, so you probably want to use -j 1 instead'
args.no_fpool = True
- self.log("root", m, c=3)
+ self.log("root", t, c=3)
bri = "zy"[args.theme % 2 :][:1]
ch = "abcdefghijklmnopqrstuvwx"[int(args.theme / 2)]
@@ -96,8 +106,8 @@ class SvcHub(object):
self.args.th_dec = list(decs.keys())
self.thumbsrv = None
if not args.no_thumb:
- m = "decoder preference: {}".format(", ".join(self.args.th_dec))
- self.log("thumb", m)
+ t = "decoder preference: {}".format(", ".join(self.args.th_dec))
+ self.log("thumb", t)
if "pil" in self.args.th_dec and not HAVE_WEBP:
msg = "disabling webp thumbnails because either libwebp is not available or your Pillow is too old"
@@ -131,11 +141,11 @@ class SvcHub(object):
if self.check_mp_enable():
from .broker_mp import BrokerMp as Broker
else:
- from .broker_thr import BrokerThr as Broker
+ from .broker_thr import BrokerThr as Broker # type: ignore
self.broker = Broker(self)
- def thr_httpsrv_up(self):
+ def thr_httpsrv_up(self) -> None:
time.sleep(1 if self.args.ign_ebind_all else 5)
expected = self.broker.num_workers * self.tcpsrv.nsrv
failed = expected - self.httpsrv_up
@@ -145,20 +155,20 @@ class SvcHub(object):
if self.args.ign_ebind_all:
if not self.tcpsrv.srv:
for _ in range(self.broker.num_workers):
- self.broker.put(False, "cb_httpsrv_up")
+ self.broker.say("cb_httpsrv_up")
return
if self.args.ign_ebind and self.tcpsrv.srv:
return
- m = "{}/{} workers failed to start"
- m = m.format(failed, expected)
- self.log("root", m, 1)
+ t = "{}/{} workers failed to start"
+ t = t.format(failed, expected)
+ self.log("root", t, 1)
self.retcode = 1
os.kill(os.getpid(), signal.SIGTERM)
- def cb_httpsrv_up(self):
+ def cb_httpsrv_up(self) -> None:
self.httpsrv_up += 1
if self.httpsrv_up != self.broker.num_workers:
return
@@ -171,9 +181,9 @@ class SvcHub(object):
thr.daemon = True
thr.start()
- def _logname(self):
+ def _logname(self) -> str:
dt = datetime.utcnow()
- fn = self.args.lo
+ fn = str(self.args.lo)
for fs in "YmdHMS":
fs = "%" + fs
if fs in fn:
@@ -181,7 +191,7 @@ class SvcHub(object):
return fn
- def _setup_logfile(self, printed):
+ def _setup_logfile(self, printed: str) -> None:
base_fn = fn = sel_fn = self._logname()
if fn != self.args.lo:
ctr = 0
@@ -203,8 +213,6 @@ class SvcHub(object):
lh = codecs.open(fn, "w", encoding="utf-8", errors="replace")
- lh.base_fn = base_fn
-
argv = [sys.executable] + self.argv
if hasattr(shlex, "quote"):
argv = [shlex.quote(x) for x in argv]
@@ -215,9 +223,10 @@ class SvcHub(object):
printed += msg
lh.write("t0: {:.3f}\nargv: {}\n\n{}".format(E.t0, " ".join(argv), printed))
self.logf = lh
+ self.logf_base_fn = base_fn
print(msg, end="")
- def run(self):
+ def run(self) -> None:
self.tcpsrv.run()
thr = threading.Thread(target=self.thr_httpsrv_up)
@@ -252,7 +261,7 @@ class SvcHub(object):
else:
self.stop_thr()
- def reload(self):
+ def reload(self) -> str:
if self.reloading:
return "cannot reload; already in progress"
@@ -262,7 +271,7 @@ class SvcHub(object):
t.start()
return "reload initiated"
- def _reload(self):
+ def _reload(self) -> None:
self.log("root", "reload scheduled")
with self.up2k.mutex:
self.asrv.reload()
@@ -271,7 +280,7 @@ class SvcHub(object):
self.reloading = False
- def stop_thr(self):
+ def stop_thr(self) -> None:
while not self.stop_req:
with self.stop_cond:
self.stop_cond.wait(9001)
@@ -282,7 +291,7 @@ class SvcHub(object):
self.shutdown()
- def signal_handler(self, sig, frame):
+ def signal_handler(self, sig: int, frame: Optional[FrameType]) -> None:
if self.stopping:
return
@@ -294,7 +303,7 @@ class SvcHub(object):
with self.stop_cond:
self.stop_cond.notify_all()
- def shutdown(self):
+ def shutdown(self) -> None:
if self.stopping:
return
@@ -337,7 +346,7 @@ class SvcHub(object):
sys.exit(ret)
- def _log_disabled(self, src, msg, c=0):
+ def _log_disabled(self, src: str, msg: str, c: Union[int, str] = 0) -> None:
if not self.logf:
return
@@ -349,8 +358,8 @@ class SvcHub(object):
if now >= self.next_day:
self._set_next_day()
- def _set_next_day(self):
- if self.next_day and self.logf and self.logf.base_fn != self._logname():
+ def _set_next_day(self) -> None:
+ if self.next_day and self.logf and self.logf_base_fn != self._logname():
self.logf.close()
self._setup_logfile("")
@@ -364,7 +373,7 @@ class SvcHub(object):
dt = dt.replace(hour=0, minute=0, second=0)
self.next_day = calendar.timegm(dt.utctimetuple())
- def _log_enabled(self, src, msg, c=0):
+ def _log_enabled(self, src: str, msg: str, c: Union[int, str] = 0) -> None:
"""handles logging from all components"""
with self.log_mutex:
now = time.time()
@@ -401,7 +410,7 @@ class SvcHub(object):
if self.logf:
self.logf.write(msg)
- def check_mp_support(self):
+ def check_mp_support(self) -> str:
vmin = sys.version_info[1]
if WINDOWS:
msg = "need python 3.3 or newer for multiprocessing;"
@@ -415,16 +424,16 @@ class SvcHub(object):
return msg
try:
- x = mp.Queue(1)
- x.put(["foo", "bar"])
+ x: mp.Queue[tuple[str, str]] = mp.Queue(1)
+ x.put(("foo", "bar"))
if x.get()[0] != "foo":
raise Exception()
except:
return "multiprocessing is not supported on your platform;"
- return None
+ return ""
- def check_mp_enable(self):
+ def check_mp_enable(self) -> bool:
if self.args.j == 1:
return False
@@ -447,18 +456,18 @@ class SvcHub(object):
self.log("svchub", "cannot efficiently use multiple CPU cores")
return False
- def sd_notify(self):
+ def sd_notify(self) -> None:
try:
- addr = os.getenv("NOTIFY_SOCKET")
- if not addr:
+ zb = os.getenv("NOTIFY_SOCKET")
+ if not zb:
return
- addr = unicode(addr)
+ addr = unicode(zb)
if addr.startswith("@"):
addr = "\0" + addr[1:]
- m = "".join(x for x in addr if x in string.printable)
- self.log("sd_notify", m)
+ t = "".join(x for x in addr if x in string.printable)
+ self.log("sd_notify", t)
sck = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
sck.connect(addr)
diff --git a/copyparty/szip.py b/copyparty/szip.py
index 177ff3ea..178f8a61 100644
--- a/copyparty/szip.py
+++ b/copyparty/szip.py
@@ -1,16 +1,23 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
+import calendar
import time
import zlib
-import calendar
-from .sutil import errdesc
-from .util import yieldfile, sanitize_fn, spack, sunpack, min_ex
from .bos import bos
+from .sutil import StreamArc, errdesc
+from .util import min_ex, sanitize_fn, spack, sunpack, yieldfile
+
+try:
+ from typing import Any, Generator, Optional
+
+ from .util import NamedLogger
+except:
+ pass
-def dostime2unix(buf):
+def dostime2unix(buf: bytes) -> int:
t, d = sunpack(b" bytes:
tt = time.gmtime(ts + 1)
dy, dm, dd, th, tm, ts = list(tt)[:6]
@@ -41,14 +48,22 @@ def unixtime2dos(ts):
return b"\x00\x00\x21\x00"
-def gen_fdesc(sz, crc32, z64):
+def gen_fdesc(sz: int, crc32: int, z64: bool) -> bytes:
ret = b"\x50\x4b\x07\x08"
fmt = b" bytes:
"""
does regular file headers
and the central directory meme if h_pos is set
@@ -67,8 +82,8 @@ def gen_hdr(h_pos, fn, sz, lastmod, utf8, crc32, pre_crc):
# confusingly this doesn't bump if h_pos
req_ver = b"\x2d\x00" if z64 else b"\x0a\x00"
- if crc32:
- crc32 = spack(b" tuple[bytes, bool]:
"""
summary of all file headers,
usually the zipfile footer unless something clamps
@@ -154,10 +171,12 @@ def gen_ecdr(items, cdir_pos, cdir_end):
# 2b comment length
ret += b"\x00\x00"
- return [ret, need_64]
+ return ret, need_64
-def gen_ecdr64(items, cdir_pos, cdir_end):
+def gen_ecdr64(
+ items: list[tuple[str, int, int, int, int]], cdir_pos: int, cdir_end: int
+) -> bytes:
"""
z64 end of central directory
added when numfiles or a headerptr clamps
@@ -181,7 +200,7 @@ def gen_ecdr64(items, cdir_pos, cdir_end):
return ret
-def gen_ecdr64_loc(ecdr64_pos):
+def gen_ecdr64_loc(ecdr64_pos: int) -> bytes:
"""
z64 end of central directory locator
points to ecdr64
@@ -196,21 +215,27 @@ def gen_ecdr64_loc(ecdr64_pos):
return ret
-class StreamZip(object):
- def __init__(self, log, fgen, utf8=False, pre_crc=False):
- self.log = log
- self.fgen = fgen
+class StreamZip(StreamArc):
+ def __init__(
+ self,
+ log: NamedLogger,
+ fgen: Generator[dict[str, Any], None, None],
+ utf8: bool = False,
+ pre_crc: bool = False,
+ ) -> None:
+ super(StreamZip, self).__init__(log, fgen)
+
self.utf8 = utf8
self.pre_crc = pre_crc
self.pos = 0
- self.items = []
+ self.items: list[tuple[str, int, int, int, int]] = []
- def _ct(self, buf):
+ def _ct(self, buf: bytes) -> bytes:
self.pos += len(buf)
return buf
- def ser(self, f):
+ def ser(self, f: dict[str, Any]) -> Generator[bytes, None, None]:
name = f["vp"]
src = f["ap"]
st = f["st"]
@@ -218,9 +243,8 @@ class StreamZip(object):
sz = st.st_size
ts = st.st_mtime
- crc = None
+ crc = 0
if self.pre_crc:
- crc = 0
for buf in yieldfile(src):
crc = zlib.crc32(buf, crc)
@@ -230,7 +254,6 @@ class StreamZip(object):
buf = gen_hdr(None, name, sz, ts, self.utf8, crc, self.pre_crc)
yield self._ct(buf)
- crc = crc or 0
for buf in yieldfile(src):
if not self.pre_crc:
crc = zlib.crc32(buf, crc)
@@ -239,7 +262,7 @@ class StreamZip(object):
crc &= 0xFFFFFFFF
- self.items.append([name, sz, ts, crc, h_pos])
+ self.items.append((name, sz, ts, crc, h_pos))
z64 = sz >= 4 * 1024 * 1024 * 1024
@@ -247,11 +270,11 @@ class StreamZip(object):
buf = gen_fdesc(sz, crc, z64)
yield self._ct(buf)
- def gen(self):
+ def gen(self) -> Generator[bytes, None, None]:
errors = []
for f in self.fgen:
if "err" in f:
- errors.append([f["vp"], f["err"]])
+ errors.append((f["vp"], f["err"]))
continue
try:
@@ -259,7 +282,7 @@ class StreamZip(object):
yield x
except:
ex = min_ex(5, True).replace("\n", "\n-- ")
- errors.append([f["vp"], ex])
+ errors.append((f["vp"], ex))
if errors:
errf, txt = errdesc(errors)
diff --git a/copyparty/tcpsrv.py b/copyparty/tcpsrv.py
index 7d2cb3e4..ae4c5f86 100644
--- a/copyparty/tcpsrv.py
+++ b/copyparty/tcpsrv.py
@@ -2,12 +2,15 @@
from __future__ import print_function, unicode_literals
import re
-import sys
import socket
+import sys
-from .__init__ import MACOS, ANYWIN, unicode
+from .__init__ import ANYWIN, MACOS, TYPE_CHECKING, unicode
from .util import chkcmd
+if TYPE_CHECKING:
+ from .svchub import SvcHub
+
class TcpSrv(object):
"""
@@ -15,16 +18,16 @@ class TcpSrv(object):
which then uses the least busy HttpSrv to handle it
"""
- def __init__(self, hub):
+ def __init__(self, hub: "SvcHub"):
self.hub = hub
self.args = hub.args
self.log = hub.log
self.stopping = False
- self.srv = []
+ self.srv: list[socket.socket] = []
self.nsrv = 0
- ok = {}
+ ok: dict[str, list[int]] = {}
for ip in self.args.i:
ok[ip] = []
for port in self.args.p:
@@ -34,8 +37,8 @@ class TcpSrv(object):
ok[ip].append(port)
except Exception as ex:
if self.args.ign_ebind or self.args.ign_ebind_all:
- m = "could not listen on {}:{}: {}"
- self.log("tcpsrv", m.format(ip, port, ex), c=3)
+ t = "could not listen on {}:{}: {}"
+ self.log("tcpsrv", t.format(ip, port, ex), c=3)
else:
raise
@@ -55,9 +58,9 @@ class TcpSrv(object):
eps[x] = "external"
msgs = []
- title_tab = {}
+ title_tab: dict[str, dict[str, int]] = {}
title_vars = [x[1:] for x in self.args.wintitle.split(" ") if x.startswith("$")]
- m = "available @ {}://{}:{}/ (\033[33m{}\033[0m)"
+ t = "available @ {}://{}:{}/ (\033[33m{}\033[0m)"
for ip, desc in sorted(eps.items(), key=lambda x: x[1]):
for port in sorted(self.args.p):
if port not in ok.get(ip, ok.get("0.0.0.0", [])):
@@ -69,7 +72,7 @@ class TcpSrv(object):
elif self.args.https_only or port == 443:
proto = "https"
- msgs.append(m.format(proto, ip, port, desc))
+ msgs.append(t.format(proto, ip, port, desc))
if not self.args.wintitle:
continue
@@ -98,13 +101,13 @@ class TcpSrv(object):
if msgs:
msgs[-1] += "\n"
- for m in msgs:
- self.log("tcpsrv", m)
+ for t in msgs:
+ self.log("tcpsrv", t)
if self.args.wintitle:
self._set_wintitle(title_tab)
- def _listen(self, ip, port):
+ def _listen(self, ip: str, port: int) -> None:
srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
srv.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
@@ -120,7 +123,7 @@ class TcpSrv(object):
raise
raise Exception(e)
- def run(self):
+ def run(self) -> None:
for srv in self.srv:
srv.listen(self.args.nc)
ip, port = srv.getsockname()
@@ -130,9 +133,9 @@ class TcpSrv(object):
if self.args.q:
print(msg)
- self.hub.broker.put(False, "listen", srv)
+ self.hub.broker.say("listen", srv)
- def shutdown(self):
+ def shutdown(self) -> None:
self.stopping = True
try:
for srv in self.srv:
@@ -142,14 +145,14 @@ class TcpSrv(object):
self.log("tcpsrv", "ok bye")
- def ips_linux_ifconfig(self):
+ def ips_linux_ifconfig(self) -> dict[str, str]:
# for termux
try:
txt, _ = chkcmd(["ifconfig"])
except:
return {}
- eps = {}
+ eps: dict[str, str] = {}
dev = None
ip = None
up = None
@@ -171,7 +174,7 @@ class TcpSrv(object):
return eps
- def ips_linux(self):
+ def ips_linux(self) -> dict[str, str]:
try:
txt, _ = chkcmd(["ip", "addr"])
except:
@@ -180,21 +183,21 @@ class TcpSrv(object):
r = re.compile(r"^\s+inet ([^ ]+)/.* (.*)")
ri = re.compile(r"^\s*[0-9]+\s*:.*")
up = False
- eps = {}
+ eps: dict[str, str] = {}
for ln in txt.split("\n"):
if ri.match(ln):
up = "UP" in re.split("[>,< ]", ln)
try:
- ip, dev = r.match(ln.rstrip()).groups()
+ ip, dev = r.match(ln.rstrip()).groups() # type: ignore
eps[ip] = dev + ("" if up else ", \033[31mLINK-DOWN")
except:
pass
return eps
- def ips_macos(self):
- eps = {}
+ def ips_macos(self) -> dict[str, str]:
+ eps: dict[str, str] = {}
try:
txt, _ = chkcmd(["ifconfig"])
except:
@@ -202,7 +205,7 @@ class TcpSrv(object):
rdev = re.compile(r"^([^ ]+):")
rip = re.compile(r"^\tinet ([0-9\.]+) ")
- dev = None
+ dev = "UNKNOWN"
for ln in txt.split("\n"):
m = rdev.match(ln)
if m:
@@ -211,17 +214,17 @@ class TcpSrv(object):
m = rip.match(ln)
if m:
eps[m.group(1)] = dev
- dev = None
+ dev = "UNKNOWN"
return eps
- def ips_windows_ipconfig(self):
- eps = {}
- offs = {}
+ def ips_windows_ipconfig(self) -> tuple[dict[str, str], set[str]]:
+ eps: dict[str, str] = {}
+ offs: set[str] = set()
try:
txt, _ = chkcmd(["ipconfig"])
except:
- return eps
+ return eps, offs
rdev = re.compile(r"(^[^ ].*):$")
rip = re.compile(r"^ +IPv?4? [^:]+: *([0-9\.]{7,15})$")
@@ -231,12 +234,12 @@ class TcpSrv(object):
m = rdev.match(ln)
if m:
if dev and dev not in eps.values():
- offs[dev] = 1
+ offs.add(dev)
dev = m.group(1).split(" adapter ", 1)[-1]
if dev and roff.match(ln):
- offs[dev] = 1
+ offs.add(dev)
dev = None
m = rip.match(ln)
@@ -245,12 +248,12 @@ class TcpSrv(object):
dev = None
if dev and dev not in eps.values():
- offs[dev] = 1
+ offs.add(dev)
return eps, offs
- def ips_windows_netsh(self):
- eps = {}
+ def ips_windows_netsh(self) -> dict[str, str]:
+ eps: dict[str, str] = {}
try:
txt, _ = chkcmd("netsh interface ip show address".split())
except:
@@ -270,7 +273,7 @@ class TcpSrv(object):
return eps
- def detect_interfaces(self, listen_ips):
+ def detect_interfaces(self, listen_ips: list[str]) -> dict[str, str]:
if MACOS:
eps = self.ips_macos()
elif ANYWIN:
@@ -317,7 +320,7 @@ class TcpSrv(object):
return eps
- def _set_wintitle(self, vs):
+ def _set_wintitle(self, vs: dict[str, dict[str, int]]) -> None:
vs["all"] = vs.get("all", {"Local-Only": 1})
vs["pub"] = vs.get("pub", vs["all"])
diff --git a/copyparty/th_cli.py b/copyparty/th_cli.py
index e8a18e0e..9eb49a4f 100644
--- a/copyparty/th_cli.py
+++ b/copyparty/th_cli.py
@@ -3,13 +3,23 @@ from __future__ import print_function, unicode_literals
import os
-from .util import Cooldown
-from .th_srv import thumb_path, HAVE_WEBP
+from .__init__ import TYPE_CHECKING
+from .authsrv import VFS
from .bos import bos
+from .th_srv import HAVE_WEBP, thumb_path
+from .util import Cooldown
+
+try:
+ from typing import Optional, Union
+except:
+ pass
+
+if TYPE_CHECKING:
+ from .httpsrv import HttpSrv
class ThumbCli(object):
- def __init__(self, hsrv):
+ def __init__(self, hsrv: "HttpSrv") -> None:
self.broker = hsrv.broker
self.log_func = hsrv.log
self.args = hsrv.args
@@ -34,10 +44,10 @@ class ThumbCli(object):
d = next((x for x in self.args.th_dec if x in ("vips", "pil")), None)
self.can_webp = HAVE_WEBP or d == "vips"
- def log(self, msg, c=0):
+ def log(self, msg: str, c: Union[int, str] = 0) -> None:
self.log_func("thumbcli", msg, c)
- def get(self, dbv, rem, mtime, fmt):
+ def get(self, dbv: VFS, rem: str, mtime: float, fmt: str) -> Optional[str]:
ptop = dbv.realpath
ext = rem.rsplit(".")[-1].lower()
if ext not in self.thumbable or "dthumb" in dbv.flags:
@@ -106,17 +116,17 @@ class ThumbCli(object):
if ret:
tdir = os.path.dirname(tpath)
if self.cooldown.poke(tdir):
- self.broker.put(False, "thumbsrv.poke", tdir)
+ self.broker.say("thumbsrv.poke", tdir)
if want_opus:
# audio files expire individually
if self.cooldown.poke(tpath):
- self.broker.put(False, "thumbsrv.poke", tpath)
+ self.broker.say("thumbsrv.poke", tpath)
return ret
if abort:
return None
- x = self.broker.put(True, "thumbsrv.get", ptop, rem, mtime, fmt)
- return x.get()
+ x = self.broker.ask("thumbsrv.get", ptop, rem, mtime, fmt)
+ return x.get() # type: ignore
diff --git a/copyparty/th_srv.py b/copyparty/th_srv.py
index dc3965c7..2a9e5f02 100644
--- a/copyparty/th_srv.py
+++ b/copyparty/th_srv.py
@@ -1,18 +1,28 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import os
-import time
-import shutil
import base64
import hashlib
-import threading
+import os
+import shutil
import subprocess as sp
+import threading
+import time
-from .util import fsenc, vsplit, statdir, runcmd, Queue, Cooldown, BytesIO, min_ex
+from queue import Queue
+
+from .__init__ import TYPE_CHECKING
from .bos import bos
from .mtag import HAVE_FFMPEG, HAVE_FFPROBE, ffprobe
+from .util import BytesIO, Cooldown, fsenc, min_ex, runcmd, statdir, vsplit
+try:
+ from typing import Optional, Union
+except:
+ pass
+
+if TYPE_CHECKING:
+ from .svchub import SvcHub
HAVE_PIL = False
HAVE_HEIF = False
@@ -20,7 +30,7 @@ HAVE_AVIF = False
HAVE_WEBP = False
try:
- from PIL import Image, ImageOps, ExifTags
+ from PIL import ExifTags, Image, ImageOps
HAVE_PIL = True
try:
@@ -47,14 +57,13 @@ except:
pass
try:
- import pyvips
-
HAVE_VIPS = True
+ import pyvips
except:
HAVE_VIPS = False
-def thumb_path(histpath, rem, mtime, fmt):
+def thumb_path(histpath: str, rem: str, mtime: float, fmt: str) -> str:
# base16 = 16 = 256
# b64-lc = 38 = 1444
# base64 = 64 = 4096
@@ -80,7 +89,7 @@ def thumb_path(histpath, rem, mtime, fmt):
class ThumbSrv(object):
- def __init__(self, hub):
+ def __init__(self, hub: "SvcHub") -> None:
self.hub = hub
self.asrv = hub.asrv
self.args = hub.args
@@ -91,17 +100,17 @@ class ThumbSrv(object):
self.poke_cd = Cooldown(self.args.th_poke)
self.mutex = threading.Lock()
- self.busy = {}
+ self.busy: dict[str, list[threading.Condition]] = {}
self.stopping = False
self.nthr = max(1, self.args.th_mt)
- self.q = Queue(self.nthr * 4)
+ self.q: Queue[Optional[tuple[str, str]]] = Queue(self.nthr * 4)
for n in range(self.nthr):
- t = threading.Thread(
+ thr = threading.Thread(
target=self.worker, name="thumb-{}-{}".format(n, self.nthr)
)
- t.daemon = True
- t.start()
+ thr.daemon = True
+ thr.start()
want_ff = not self.args.no_vthumb or not self.args.no_athumb
if want_ff and (not HAVE_FFMPEG or not HAVE_FFPROBE):
@@ -122,7 +131,7 @@ class ThumbSrv(object):
t.start()
self.fmt_pil, self.fmt_vips, self.fmt_ffi, self.fmt_ffv, self.fmt_ffa = [
- {x: True for x in y.split(",")}
+ set(y.split(","))
for y in [
self.args.th_r_pil,
self.args.th_r_vips,
@@ -134,37 +143,37 @@ class ThumbSrv(object):
if not HAVE_HEIF:
for f in "heif heifs heic heics".split(" "):
- self.fmt_pil.pop(f, None)
+ self.fmt_pil.discard(f)
if not HAVE_AVIF:
for f in "avif avifs".split(" "):
- self.fmt_pil.pop(f, None)
+ self.fmt_pil.discard(f)
- self.thumbable = {}
+ self.thumbable: set[str] = set()
if "pil" in self.args.th_dec:
- self.thumbable.update(self.fmt_pil)
+ self.thumbable |= self.fmt_pil
if "vips" in self.args.th_dec:
- self.thumbable.update(self.fmt_vips)
+ self.thumbable |= self.fmt_vips
if "ff" in self.args.th_dec:
- for t in [self.fmt_ffi, self.fmt_ffv, self.fmt_ffa]:
- self.thumbable.update(t)
+ for zss in [self.fmt_ffi, self.fmt_ffv, self.fmt_ffa]:
+ self.thumbable |= zss
- def log(self, msg, c=0):
+ def log(self, msg: str, c: Union[int, str] = 0) -> None:
self.log_func("thumb", msg, c)
- def shutdown(self):
+ def shutdown(self) -> None:
self.stopping = True
for _ in range(self.nthr):
self.q.put(None)
- def stopped(self):
+ def stopped(self) -> bool:
with self.mutex:
return not self.nthr
- def get(self, ptop, rem, mtime, fmt):
+ def get(self, ptop: str, rem: str, mtime: float, fmt: str) -> Optional[str]:
histpath = self.asrv.vfs.histtab.get(ptop)
if not histpath:
self.log("no histpath for [{}]".format(ptop))
@@ -191,7 +200,7 @@ class ThumbSrv(object):
do_conv = True
if do_conv:
- self.q.put([abspath, tpath])
+ self.q.put((abspath, tpath))
self.log("conv {} \033[0m{}".format(tpath, abspath), c=6)
while not self.stopping:
@@ -212,7 +221,7 @@ class ThumbSrv(object):
return None
- def getcfg(self):
+ def getcfg(self) -> dict[str, set[str]]:
return {
"thumbable": self.thumbable,
"pil": self.fmt_pil,
@@ -222,7 +231,7 @@ class ThumbSrv(object):
"ffa": self.fmt_ffa,
}
- def worker(self):
+ def worker(self) -> None:
while not self.stopping:
task = self.q.get()
if not task:
@@ -253,7 +262,7 @@ class ThumbSrv(object):
except:
msg = "{} could not create thumbnail of {}\n{}"
msg = msg.format(fun.__name__, abspath, min_ex())
- c = 1 if " "Image.Image":
# exif_transpose is expensive (loads full image + unconditional copy)
r = max(*self.res) * 2
im.thumbnail((r, r), resample=Image.LANCZOS)
@@ -295,7 +304,7 @@ class ThumbSrv(object):
return im
- def conv_pil(self, abspath, tpath):
+ def conv_pil(self, abspath: str, tpath: str) -> None:
with Image.open(fsenc(abspath)) as im:
try:
im = self.fancy_pillow(im)
@@ -324,7 +333,7 @@ class ThumbSrv(object):
im.save(tpath, **args)
- def conv_vips(self, abspath, tpath):
+ def conv_vips(self, abspath: str, tpath: str) -> None:
crops = ["centre", "none"]
if self.args.th_no_crop:
crops = ["none"]
@@ -342,18 +351,17 @@ class ThumbSrv(object):
img.write_to_file(tpath, Q=40)
- def conv_ffmpeg(self, abspath, tpath):
+ def conv_ffmpeg(self, abspath: str, tpath: str) -> None:
ret, _ = ffprobe(abspath)
if not ret:
return
ext = abspath.rsplit(".")[-1].lower()
if ext in ["h264", "h265"] or ext in self.fmt_ffi:
- seek = []
+ seek: list[bytes] = []
else:
dur = ret[".dur"][1] if ".dur" in ret else 4
- seek = "{:.0f}".format(dur / 3)
- seek = [b"-ss", seek.encode("utf-8")]
+ seek = [b"-ss", "{:.0f}".format(dur / 3).encode("utf-8")]
scale = "scale={0}:{1}:force_original_aspect_ratio="
if self.args.th_no_crop:
@@ -361,7 +369,7 @@ class ThumbSrv(object):
else:
scale += "increase,crop={0}:{1},setsar=1:1"
- scale = scale.format(*list(self.res)).encode("utf-8")
+ bscale = scale.format(*list(self.res)).encode("utf-8")
# fmt: off
cmd = [
b"ffmpeg",
@@ -373,7 +381,7 @@ class ThumbSrv(object):
cmd += [
b"-i", fsenc(abspath),
b"-map", b"0:v:0",
- b"-vf", scale,
+ b"-vf", bscale,
b"-frames:v", b"1",
b"-metadata:s:v:0", b"rotate=0",
]
@@ -395,14 +403,14 @@ class ThumbSrv(object):
cmd += [fsenc(tpath)]
self._run_ff(cmd)
- def _run_ff(self, cmd):
+ def _run_ff(self, cmd: list[bytes]) -> None:
# self.log((b" ".join(cmd)).decode("utf-8"))
ret, _, serr = runcmd(cmd, timeout=self.args.th_convt)
if not ret:
return
- c = "1;30"
- m = "FFmpeg failed (probably a corrupt video file):\n"
+ c: Union[str, int] = "1;30"
+ t = "FFmpeg failed (probably a corrupt video file):\n"
if cmd[-1].lower().endswith(b".webp") and (
"Error selecting an encoder" in serr
or "Automatic encoder selection failed" in serr
@@ -410,14 +418,14 @@ class ThumbSrv(object):
or "Please choose an encoder manually" in serr
):
self.args.th_ff_jpg = True
- m = "FFmpeg failed because it was compiled without libwebp; enabling --th-ff-jpg to force jpeg output:\n"
+ t = "FFmpeg failed because it was compiled without libwebp; enabling --th-ff-jpg to force jpeg output:\n"
c = 1
if (
"Requested resampling engine is unavailable" in serr
or "output pad on Parsed_aresample_" in serr
):
- m = "FFmpeg failed because it was compiled without libsox; you must set --th-ff-swr to force swr resampling:\n"
+ t = "FFmpeg failed because it was compiled without libsox; you must set --th-ff-swr to force swr resampling:\n"
c = 1
lines = serr.strip("\n").split("\n")
@@ -428,10 +436,10 @@ class ThumbSrv(object):
if len(txt) > 5000:
txt = txt[:2500] + "...\nff: [...]\nff: ..." + txt[-2500:]
- self.log(m + txt, c=c)
+ self.log(t + txt, c=c)
raise sp.CalledProcessError(ret, (cmd[0], b"...", cmd[-1]))
- def conv_spec(self, abspath, tpath):
+ def conv_spec(self, abspath: str, tpath: str) -> None:
ret, _ = ffprobe(abspath)
if "ac" not in ret:
raise Exception("not audio")
@@ -473,7 +481,7 @@ class ThumbSrv(object):
cmd += [fsenc(tpath)]
self._run_ff(cmd)
- def conv_opus(self, abspath, tpath):
+ def conv_opus(self, abspath: str, tpath: str) -> None:
if self.args.no_acode:
raise Exception("disabled in server config")
@@ -521,7 +529,7 @@ class ThumbSrv(object):
# fmt: on
self._run_ff(cmd)
- def poke(self, tdir):
+ def poke(self, tdir: str) -> None:
if not self.poke_cd.poke(tdir):
return
@@ -533,7 +541,7 @@ class ThumbSrv(object):
except:
pass
- def cleaner(self):
+ def cleaner(self) -> None:
interval = self.args.th_clean
while True:
time.sleep(interval)
@@ -548,14 +556,14 @@ class ThumbSrv(object):
self.log("\033[Jcln ok; rm {} dirs".format(ndirs))
- def clean(self, histpath):
+ def clean(self, histpath: str) -> int:
ret = 0
for cat in ["th", "ac"]:
- ret += self._clean(histpath, cat, None)
+ ret += self._clean(histpath, cat, "")
return ret
- def _clean(self, histpath, cat, thumbpath):
+ def _clean(self, histpath: str, cat: str, thumbpath: str) -> int:
if not thumbpath:
thumbpath = os.path.join(histpath, cat)
@@ -564,10 +572,10 @@ class ThumbSrv(object):
maxage = getattr(self.args, cat + "_maxage")
now = time.time()
prev_b64 = None
- prev_fp = None
+ prev_fp = ""
try:
- ents = statdir(self.log, not self.args.no_scandir, False, thumbpath)
- ents = sorted(list(ents))
+ t1 = statdir(self.log_func, not self.args.no_scandir, False, thumbpath)
+ ents = sorted(list(t1))
except:
return 0
diff --git a/copyparty/u2idx.py b/copyparty/u2idx.py
index 073ed3d1..8a342042 100644
--- a/copyparty/u2idx.py
+++ b/copyparty/u2idx.py
@@ -1,34 +1,37 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import re
-import os
-import time
import calendar
+import os
+import re
import threading
+import time
from operator import itemgetter
-from .__init__ import ANYWIN, unicode
-from .util import absreal, s3dec, Pebkac, min_ex, gen_filekey, quotep
+from .__init__ import ANYWIN, TYPE_CHECKING, unicode
from .bos import bos
from .up2k import up2k_wark_from_hashlist
+from .util import HAVE_SQLITE3, Pebkac, absreal, gen_filekey, min_ex, quotep, s3dec
-
-try:
- HAVE_SQLITE3 = True
+if HAVE_SQLITE3:
import sqlite3
-except:
- HAVE_SQLITE3 = False
-
try:
from pathlib import Path
except:
pass
+try:
+ from typing import Any, Optional, Union
+except:
+ pass
+
+if TYPE_CHECKING:
+ from .httpconn import HttpConn
+
class U2idx(object):
- def __init__(self, conn):
+ def __init__(self, conn: "HttpConn") -> None:
self.log_func = conn.log_func
self.asrv = conn.asrv
self.args = conn.args
@@ -38,19 +41,21 @@ class U2idx(object):
self.log("your python does not have sqlite3; searching will be disabled")
return
- self.active_id = None
- self.active_cur = None
- self.cur = {}
- self.mem_cur = sqlite3.connect(":memory:")
+ self.active_id = ""
+ self.active_cur: Optional["sqlite3.Cursor"] = None
+ self.cur: dict[str, "sqlite3.Cursor"] = {}
+ self.mem_cur = sqlite3.connect(":memory:").cursor()
self.mem_cur.execute(r"create table a (b text)")
- self.p_end = None
- self.p_dur = 0
+ self.p_end = 0.0
+ self.p_dur = 0.0
- def log(self, msg, c=0):
+ def log(self, msg: str, c: Union[int, str] = 0) -> None:
self.log_func("u2idx", msg, c)
- def fsearch(self, vols, body):
+ def fsearch(
+ self, vols: list[tuple[str, str, dict[str, Any]]], body: dict[str, Any]
+ ) -> list[dict[str, Any]]:
"""search by up2k hashlist"""
if not HAVE_SQLITE3:
return []
@@ -60,14 +65,14 @@ class U2idx(object):
wark = up2k_wark_from_hashlist(self.args.salt, fsize, fhash)
uq = "substr(w,1,16) = ? and w = ?"
- uv = [wark[:16], wark]
+ uv: list[Union[str, int]] = [wark[:16], wark]
try:
return self.run_query(vols, uq, uv, True, False, 99999)[0]
except:
raise Pebkac(500, min_ex())
- def get_cur(self, ptop):
+ def get_cur(self, ptop: str) -> Optional["sqlite3.Cursor"]:
if not HAVE_SQLITE3:
return None
@@ -103,13 +108,16 @@ class U2idx(object):
self.cur[ptop] = cur
return cur
- def search(self, vols, uq, lim):
+ def search(
+ self, vols: list[tuple[str, str, dict[str, Any]]], uq: str, lim: int
+ ) -> tuple[list[dict[str, Any]], list[str]]:
"""search by query params"""
if not HAVE_SQLITE3:
- return []
+ return [], []
q = ""
- va = []
+ v: Union[str, int] = ""
+ va: list[Union[str, int]] = []
have_up = False # query has up.* operands
have_mt = False
is_key = True
@@ -202,7 +210,7 @@ class U2idx(object):
"%Y",
]:
try:
- v = calendar.timegm(time.strptime(v, fmt))
+ v = calendar.timegm(time.strptime(str(v), fmt))
break
except:
pass
@@ -230,11 +238,12 @@ class U2idx(object):
# lowercase tag searches
m = ptn_lc.search(q)
- if not m or not ptn_lcv.search(unicode(v)):
+ zs = unicode(v)
+ if not m or not ptn_lcv.search(zs):
continue
va.pop()
- va.append(v.lower())
+ va.append(zs.lower())
q = q[: m.start()]
field, oper = m.groups()
@@ -248,8 +257,16 @@ class U2idx(object):
except Exception as ex:
raise Pebkac(500, repr(ex))
- def run_query(self, vols, uq, uv, have_up, have_mt, lim):
- done_flag = []
+ def run_query(
+ self,
+ vols: list[tuple[str, str, dict[str, Any]]],
+ uq: str,
+ uv: list[Union[str, int]],
+ have_up: bool,
+ have_mt: bool,
+ lim: int,
+ ) -> tuple[list[dict[str, Any]], list[str]]:
+ done_flag: list[bool] = []
self.active_id = "{:.6f}_{}".format(
time.time(), threading.current_thread().ident
)
@@ -266,13 +283,11 @@ class U2idx(object):
if not uq or not uv:
uq = "select * from up"
- uv = ()
+ uv = []
elif have_mt:
uq = "select up.*, substr(up.w,1,16) mtw from up where " + uq
- uv = tuple(uv)
else:
uq = "select up.* from up where " + uq
- uv = tuple(uv)
self.log("qs: {!r} {!r}".format(uq, uv))
@@ -292,11 +307,10 @@ class U2idx(object):
v = vtop + "/"
vuv.append(v)
- vuv = tuple(vuv)
sret = []
fk = flags.get("fk")
- c = cur.execute(uq, vuv)
+ c = cur.execute(uq, tuple(vuv))
for hit in c:
w, ts, sz, rd, fn, ip, at = hit[:7]
lim -= 1
@@ -340,7 +354,7 @@ class U2idx(object):
# print("[{}] {}".format(ptop, sret))
done_flag.append(True)
- self.active_id = None
+ self.active_id = ""
# undupe hits from multiple metadata keys
if len(ret) > 1:
@@ -354,11 +368,12 @@ class U2idx(object):
return ret, list(taglist.keys())
- def terminator(self, identifier, done_flag):
+ def terminator(self, identifier: str, done_flag: list[bool]) -> None:
for _ in range(self.timeout):
time.sleep(1)
if done_flag:
return
if identifier == self.active_id:
+ assert self.active_cur
self.active_cur.connection.interrupt()
diff --git a/copyparty/up2k.py b/copyparty/up2k.py
index 3f4e02e4..aff68917 100644
--- a/copyparty/up2k.py
+++ b/copyparty/up2k.py
@@ -1,60 +1,83 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import re
-import os
-import time
-import math
-import json
-import gzip
-import stat
-import shutil
import base64
+import gzip
import hashlib
-import threading
-import traceback
+import json
+import math
+import os
+import re
+import shutil
+import stat
import subprocess as sp
+import threading
+import time
+import traceback
from copy import deepcopy
-from .__init__ import WINDOWS, ANYWIN, PY2
-from .util import (
- Pebkac,
- Queue,
- ProgressPrinter,
- SYMTIME,
- fsenc,
- absreal,
- sanitize_fn,
- ren_open,
- atomic_move,
- quotep,
- vsplit,
- w8b64enc,
- w8b64dec,
- s3enc,
- s3dec,
- rmdirs,
- statdir,
- s2hms,
- min_ex,
-)
-from .bos import bos
-from .authsrv import AuthSrv, LEELOO_DALLAS
-from .mtag import MTag, MParser
+from queue import Queue
-try:
- HAVE_SQLITE3 = True
+from .__init__ import ANYWIN, PY2, TYPE_CHECKING, WINDOWS
+from .authsrv import LEELOO_DALLAS, VFS, AuthSrv
+from .bos import bos
+from .mtag import MParser, MTag
+from .util import (
+ HAVE_SQLITE3,
+ SYMTIME,
+ Pebkac,
+ ProgressPrinter,
+ absreal,
+ atomic_move,
+ fsenc,
+ min_ex,
+ quotep,
+ ren_open,
+ rmdirs,
+ s2hms,
+ s3dec,
+ s3enc,
+ sanitize_fn,
+ statdir,
+ vsplit,
+ w8b64dec,
+ w8b64enc,
+)
+
+if HAVE_SQLITE3:
import sqlite3
-except:
- HAVE_SQLITE3 = False
DB_VER = 5
+try:
+ from typing import Any, Optional, Pattern, Union
+except:
+ pass
+
+if TYPE_CHECKING:
+ from .svchub import SvcHub
+
+
+class Dbw(object):
+ def __init__(self, c: "sqlite3.Cursor", n: int, t: float) -> None:
+ self.c = c
+ self.n = n
+ self.t = t
+
+
+class Mpqe(object):
+ def __init__(self, mtp: dict[str, MParser], entags: set[str], w: str, abspath: str):
+ # mtp empty = mtag
+ self.mtp = mtp
+ self.entags = entags
+ self.w = w
+ self.abspath = abspath
+
class Up2k(object):
- def __init__(self, hub):
+ def __init__(self, hub: "SvcHub") -> None:
self.hub = hub
- self.asrv = hub.asrv # type: AuthSrv
+ self.asrv: AuthSrv = hub.asrv
self.args = hub.args
self.log_func = hub.log
@@ -62,25 +85,32 @@ class Up2k(object):
self.salt = self.args.salt
# state
+ self.gid = 0
self.mutex = threading.Lock()
+ self.pp: Optional[ProgressPrinter] = None
self.rescan_cond = threading.Condition()
- self.hashq = Queue()
- self.tagq = Queue()
+ self.need_rescan: set[str] = set()
+
+ self.registry: dict[str, dict[str, dict[str, Any]]] = {}
+ self.flags: dict[str, dict[str, Any]] = {}
+ self.droppable: dict[str, list[str]] = {}
+ self.volstate: dict[str, str] = {}
+ self.dupesched: dict[str, list[tuple[str, str, float]]] = {}
+ self.snap_persist_interval = 300 # persist unfinished index every 5 min
+ self.snap_discard_interval = 21600 # drop unfinished after 6 hours inactivity
+ self.snap_prev: dict[str, Optional[tuple[int, float]]] = {}
+
+ self.mtag: Optional[MTag] = None
+ self.entags: dict[str, set[str]] = {}
+ self.mtp_parsers: dict[str, dict[str, MParser]] = {}
+ self.pending_tags: list[tuple[set[str], str, str, dict[str, Any]]] = []
+ self.hashq: Queue[tuple[str, str, str, str, float]] = Queue()
+ self.tagq: Queue[tuple[str, str, str, str]] = Queue()
self.n_hashq = 0
self.n_tagq = 0
- self.gid = 0
- self.volstate = {}
- self.need_rescan = {}
- self.dupesched = {}
- self.registry = {}
- self.droppable = {}
- self.entags = {}
- self.flags = {}
- self.cur = {}
- self.mtag = None
- self.pending_tags = None
- self.mtp_parsers = {}
+ self.mpool_used = False
+ self.cur: dict[str, "sqlite3.Cursor"] = {}
self.mem_cur = None
self.sqlite_ver = None
self.no_expr_idx = False
@@ -94,7 +124,7 @@ class Up2k(object):
if ANYWIN:
# usually fails to set lastmod too quickly
- self.lastmod_q = []
+ self.lastmod_q: list[tuple[str, int, tuple[int, int]]] = []
thr = threading.Thread(target=self._lastmodder, name="up2k-lastmod")
thr.daemon = True
thr.start()
@@ -108,7 +138,7 @@ class Up2k(object):
if self.args.no_fastboot:
self.deferred_init()
- def init_vols(self):
+ def init_vols(self) -> None:
if self.args.no_fastboot:
return
@@ -116,15 +146,15 @@ class Up2k(object):
t.daemon = True
t.start()
- def reload(self):
+ def reload(self) -> None:
self.gid += 1
self.log("reload #{} initiated".format(self.gid))
all_vols = self.asrv.vfs.all_vols
self.rescan(all_vols, list(all_vols.keys()), True)
- def deferred_init(self):
+ def deferred_init(self) -> None:
all_vols = self.asrv.vfs.all_vols
- have_e2d = self.init_indexes(all_vols)
+ have_e2d = self.init_indexes(all_vols, [])
thr = threading.Thread(target=self._snapshot, name="up2k-snapshot")
thr.daemon = True
@@ -150,11 +180,11 @@ class Up2k(object):
thr.daemon = True
thr.start()
- def log(self, msg, c=0):
+ def log(self, msg: str, c: Union[int, str] = 0) -> None:
self.log_func("up2k", msg + "\033[K", c)
- def get_state(self):
- mtpq = 0
+ def get_state(self) -> str:
+ mtpq: Union[int, str] = 0
q = "select count(w) from mt where k = 't:mtp'"
got_lock = False if PY2 else self.mutex.acquire(timeout=0.5)
if got_lock:
@@ -165,19 +195,19 @@ class Up2k(object):
pass
self.mutex.release()
else:
- mtpq = "?"
+ mtpq = "(?)"
ret = {
"volstate": self.volstate,
- "scanning": hasattr(self, "pp"),
+ "scanning": bool(self.pp),
"hashq": self.n_hashq,
"tagq": self.n_tagq,
"mtpq": mtpq,
}
return json.dumps(ret, indent=4)
- def rescan(self, all_vols, scan_vols, wait):
- if not wait and hasattr(self, "pp"):
+ def rescan(self, all_vols: dict[str, VFS], scan_vols: list[str], wait: bool) -> str:
+ if not wait and self.pp:
return "cannot initiate; scan is already in progress"
args = (all_vols, scan_vols)
@@ -188,11 +218,11 @@ class Up2k(object):
)
t.daemon = True
t.start()
- return None
+ return ""
- def _sched_rescan(self):
+ def _sched_rescan(self) -> None:
volage = {}
- cooldown = 0
+ cooldown = 0.0
timeout = time.time() + 3
while True:
timeout = max(timeout, cooldown)
@@ -204,7 +234,7 @@ class Up2k(object):
if now < cooldown:
continue
- if hasattr(self, "pp"):
+ if self.pp:
cooldown = now + 5
continue
@@ -220,19 +250,19 @@ class Up2k(object):
deadline = volage[vp] + maxage
if deadline <= now:
- self.need_rescan[vp] = 1
+ self.need_rescan.add(vp)
timeout = min(timeout, deadline)
- vols = list(sorted(self.need_rescan.keys()))
- self.need_rescan = {}
+ vols = list(sorted(self.need_rescan))
+ self.need_rescan.clear()
if vols:
cooldown = now + 10
err = self.rescan(self.asrv.vfs.all_vols, vols, False)
if err:
for v in vols:
- self.need_rescan[v] = True
+ self.need_rescan.add(v)
continue
@@ -272,7 +302,7 @@ class Up2k(object):
if vp:
fvp = "{}/{}".format(vp, fvp)
- self._handle_rm(LEELOO_DALLAS, None, fvp)
+ self._handle_rm(LEELOO_DALLAS, "", fvp)
nrm += 1
if nrm:
@@ -288,12 +318,12 @@ class Up2k(object):
if hits:
timeout = min(timeout, now + lifetime - (now - hits[0]))
- def _vis_job_progress(self, job):
+ def _vis_job_progress(self, job: dict[str, Any]) -> str:
perc = 100 - (len(job["need"]) * 100.0 / len(job["hash"]))
path = os.path.join(job["ptop"], job["prel"], job["name"])
return "{:5.1f}% {}".format(perc, path)
- def _vis_reg_progress(self, reg):
+ def _vis_reg_progress(self, reg: dict[str, dict[str, Any]]) -> list[str]:
ret = []
for _, job in reg.items():
if job["need"]:
@@ -301,7 +331,7 @@ class Up2k(object):
return ret
- def _expr_idx_filter(self, flags):
+ def _expr_idx_filter(self, flags: dict[str, Any]) -> tuple[bool, dict[str, Any]]:
if not self.no_expr_idx:
return False, flags
@@ -311,19 +341,19 @@ class Up2k(object):
return True, ret
- def init_indexes(self, all_vols, scan_vols=None):
+ def init_indexes(self, all_vols: dict[str, VFS], scan_vols: list[str]) -> bool:
gid = self.gid
- while hasattr(self, "pp") and gid == self.gid:
+ while self.pp and gid == self.gid:
time.sleep(0.1)
if gid != self.gid:
- return
+ return False
if gid:
self.log("reload #{} running".format(self.gid))
self.pp = ProgressPrinter()
- vols = all_vols.values()
+ vols = list(all_vols.values())
t0 = time.time()
have_e2d = False
@@ -377,9 +407,9 @@ class Up2k(object):
# e2ds(a) volumes first
for vol in vols:
- en = {}
+ en: set[str] = set()
if "mte" in vol.flags:
- en = {k: True for k in vol.flags["mte"].split(",")}
+ en = set(vol.flags["mte"].split(","))
self.entags[vol.realpath] = en
@@ -393,11 +423,11 @@ class Up2k(object):
need_vac[vol] = True
if "e2ts" not in vol.flags:
- m = "online, idle"
+ t = "online, idle"
else:
- m = "online (tags pending)"
+ t = "online (tags pending)"
- self.volstate[vol.vpath] = m
+ self.volstate[vol.vpath] = t
# open the rest + do any e2ts(a)
needed_mutagen = False
@@ -405,9 +435,9 @@ class Up2k(object):
if "e2ts" not in vol.flags:
continue
- m = "online (reading tags)"
- self.volstate[vol.vpath] = m
- self.log("{} [{}]".format(m, vol.realpath))
+ t = "online (reading tags)"
+ self.volstate[vol.vpath] = t
+ self.log("{} [{}]".format(t, vol.realpath))
nadd, nrm, success = self._build_tags_index(vol)
if not success:
@@ -419,7 +449,9 @@ class Up2k(object):
self.volstate[vol.vpath] = "online (mtp soon)"
for vol in need_vac:
- cur, _ = self.register_vpath(vol.realpath, vol.flags)
+ reg = self.register_vpath(vol.realpath, vol.flags)
+ assert reg
+ cur, _ = reg
with self.mutex:
cur.connection.commit()
cur.execute("vacuum")
@@ -435,23 +467,25 @@ class Up2k(object):
thr = None
if self.mtag:
- m = "online (running mtp)"
+ t = "online (running mtp)"
if scan_vols:
thr = threading.Thread(target=self._run_all_mtp, name="up2k-mtp-scan")
thr.daemon = True
else:
- del self.pp
- m = "online, idle"
+ self.pp = None
+ t = "online, idle"
for vol in vols:
- self.volstate[vol.vpath] = m
+ self.volstate[vol.vpath] = t
if thr:
thr.start()
return have_e2d
- def register_vpath(self, ptop, flags):
+ def register_vpath(
+ self, ptop: str, flags: dict[str, Any]
+ ) -> Optional[tuple["sqlite3.Cursor", str]]:
histpath = self.asrv.vfs.histtab.get(ptop)
if not histpath:
self.log("no histpath for [{}]".format(ptop))
@@ -460,7 +494,7 @@ class Up2k(object):
db_path = os.path.join(histpath, "up2k.db")
if ptop in self.registry:
try:
- return [self.cur[ptop], db_path]
+ return self.cur[ptop], db_path
except:
return None
@@ -514,14 +548,14 @@ class Up2k(object):
else:
drp = [x for x in drp if x in reg]
- m = "loaded snap {} |{}| ({})".format(path, len(reg.keys()), len(drp or []))
- m = [m] + self._vis_reg_progress(reg)
- self.log("\n".join(m))
+ t = "loaded snap {} |{}| ({})".format(path, len(reg.keys()), len(drp or []))
+ ta = [t] + self._vis_reg_progress(reg)
+ self.log("\n".join(ta))
self.flags[ptop] = flags
self.registry[ptop] = reg
self.droppable[ptop] = drp or []
- self.regdrop(ptop, None)
+ self.regdrop(ptop, "")
if not HAVE_SQLITE3 or "e2d" not in flags or "d2d" in flags:
return None
@@ -530,23 +564,25 @@ class Up2k(object):
try:
cur = self._open_db(db_path)
self.cur[ptop] = cur
- return [cur, db_path]
+ return cur, db_path
except:
msg = "cannot use database at [{}]:\n{}"
self.log(msg.format(ptop, traceback.format_exc()))
return None
- def _build_file_index(self, vol, all_vols):
+ def _build_file_index(self, vol: VFS, all_vols: list[VFS]) -> tuple[bool, bool]:
do_vac = False
top = vol.realpath
rei = vol.flags.get("noidx")
reh = vol.flags.get("nohash")
with self.mutex:
- cur, _ = self.register_vpath(top, vol.flags)
+ reg = self.register_vpath(top, vol.flags)
+ assert reg and self.pp
+ cur, _ = reg
- dbw = [cur, 0, time.time()]
- self.pp.n = next(dbw[0].execute("select count(w) from up"))[0]
+ db = Dbw(cur, 0, time.time())
+ self.pp.n = next(db.c.execute("select count(w) from up"))[0]
excl = [
vol.realpath + "/" + d.vpath[len(vol.vpath) :].lstrip("/")
@@ -558,37 +594,47 @@ class Up2k(object):
if WINDOWS:
excl = [x.replace("/", "\\") for x in excl]
- excl = set(excl)
rtop = absreal(top)
n_add = n_rm = 0
try:
- n_add = self._build_dir(dbw, top, excl, top, rtop, rei, reh, [])
- n_rm = self._drop_lost(dbw[0], top)
+ n_add = self._build_dir(db, top, set(excl), top, rtop, rei, reh, [])
+ n_rm = self._drop_lost(db.c, top)
except:
- m = "failed to index volume [{}]:\n{}"
- self.log(m.format(top, min_ex()), c=1)
+ t = "failed to index volume [{}]:\n{}"
+ self.log(t.format(top, min_ex()), c=1)
- if dbw[1]:
- self.log("commit {} new files".format(dbw[1]))
+ if db.n:
+ self.log("commit {} new files".format(db.n))
- dbw[0].connection.commit()
+ db.c.connection.commit()
- return True, n_add or n_rm or do_vac
+ return True, bool(n_add or n_rm or do_vac)
- def _build_dir(self, dbw, top, excl, cdir, rcdir, rei, reh, seen):
+ def _build_dir(
+ self,
+ db: Dbw,
+ top: str,
+ excl: set[str],
+ cdir: str,
+ rcdir: str,
+ rei: Optional[Pattern[str]],
+ reh: Optional[Pattern[str]],
+ seen: list[str],
+ ) -> int:
if rcdir in seen:
- m = "bailing from symlink loop,\n prev: {}\n curr: {}\n from: {}"
- self.log(m.format(seen[-1], rcdir, cdir), 3)
+ t = "bailing from symlink loop,\n prev: {}\n curr: {}\n from: {}"
+ self.log(t.format(seen[-1], rcdir, cdir), 3)
return 0
seen = seen + [rcdir]
+ assert self.pp and self.mem_cur
self.pp.msg = "a{} {}".format(self.pp.n, cdir)
ret = 0
seen_files = {} # != inames; files-only for dropcheck
g = statdir(self.log_func, not self.args.no_scandir, False, cdir)
- g = sorted(g)
- inames = {x[0]: 1 for x in g}
- for iname, inf in g:
+ gl = sorted(g)
+ inames = {x[0]: 1 for x in gl}
+ for iname, inf in gl:
abspath = os.path.join(cdir, iname)
if rei and rei.search(abspath):
continue
@@ -605,10 +651,10 @@ class Up2k(object):
continue
# self.log(" dir: {}".format(abspath))
try:
- ret += self._build_dir(dbw, top, excl, abspath, rap, rei, reh, seen)
+ ret += self._build_dir(db, top, excl, abspath, rap, rei, reh, seen)
except:
- m = "failed to index subdir [{}]:\n{}"
- self.log(m.format(abspath, min_ex()), c=1)
+ t = "failed to index subdir [{}]:\n{}"
+ self.log(t.format(abspath, min_ex()), c=1)
elif not stat.S_ISREG(inf.st_mode):
self.log("skip type-{:x} file [{}]".format(inf.st_mode, abspath))
else:
@@ -632,31 +678,31 @@ class Up2k(object):
rd, fn = rp.rsplit("/", 1) if "/" in rp else ["", rp]
sql = "select w, mt, sz from up where rd = ? and fn = ?"
try:
- c = dbw[0].execute(sql, (rd, fn))
+ c = db.c.execute(sql, (rd, fn))
except:
- c = dbw[0].execute(sql, s3enc(self.mem_cur, rd, fn))
+ c = db.c.execute(sql, s3enc(self.mem_cur, rd, fn))
in_db = list(c.fetchall())
if in_db:
self.pp.n -= 1
dw, dts, dsz = in_db[0]
if len(in_db) > 1:
- m = "WARN: multiple entries: [{}] => [{}] |{}|\n{}"
+ t = "WARN: multiple entries: [{}] => [{}] |{}|\n{}"
rep_db = "\n".join([repr(x) for x in in_db])
- self.log(m.format(top, rp, len(in_db), rep_db))
+ self.log(t.format(top, rp, len(in_db), rep_db))
dts = -1
if dts == lmod and dsz == sz and (nohash or dw[0] != "#"):
continue
- m = "reindex [{}] => [{}] ({}/{}) ({}/{})".format(
+ t = "reindex [{}] => [{}] ({}/{}) ({}/{})".format(
top, rp, dts, lmod, dsz, sz
)
- self.log(m)
- self.db_rm(dbw[0], rd, fn)
+ self.log(t)
+ self.db_rm(db.c, rd, fn)
ret += 1
- dbw[1] += 1
- in_db = None
+ db.n += 1
+ in_db = []
self.pp.msg = "a{} {}".format(self.pp.n, abspath)
@@ -674,15 +720,15 @@ class Up2k(object):
wark = up2k_wark_from_hashlist(self.salt, sz, hashes)
- self.db_add(dbw[0], wark, rd, fn, lmod, sz, "", 0)
- dbw[1] += 1
+ self.db_add(db.c, wark, rd, fn, lmod, sz, "", 0)
+ db.n += 1
ret += 1
- td = time.time() - dbw[2]
- if dbw[1] >= 4096 or td >= 60:
- self.log("commit {} new files".format(dbw[1]))
- dbw[0].connection.commit()
- dbw[1] = 0
- dbw[2] = time.time()
+ td = time.time() - db.t
+ if db.n >= 4096 or td >= 60:
+ self.log("commit {} new files".format(db.n))
+ db.c.connection.commit()
+ db.n = 0
+ db.t = time.time()
# drop missing files
rd = cdir[len(top) + 1 :].strip("/")
@@ -691,25 +737,26 @@ class Up2k(object):
q = "select fn from up where rd = ?"
try:
- c = dbw[0].execute(q, (rd,))
+ c = db.c.execute(q, (rd,))
except:
- c = dbw[0].execute(q, ("//" + w8b64enc(rd),))
+ c = db.c.execute(q, ("//" + w8b64enc(rd),))
hits = [w8b64dec(x[2:]) if x.startswith("//") else x for (x,) in c]
rm_files = [x for x in hits if x not in seen_files]
n_rm = len(rm_files)
for fn in rm_files:
- self.db_rm(dbw[0], rd, fn)
+ self.db_rm(db.c, rd, fn)
if n_rm:
self.log("forgot {} deleted files".format(n_rm))
return ret
- def _drop_lost(self, cur, top):
+ def _drop_lost(self, cur: "sqlite3.Cursor", top: str) -> int:
rm = []
n_rm = 0
nchecked = 0
+ assert self.pp
# `_build_dir` did all the files, now do dirs
ndirs = next(cur.execute("select count(distinct rd) from up"))[0]
c = cur.execute("select distinct rd from up order by rd desc")
@@ -743,13 +790,16 @@ class Up2k(object):
return n_rm
- def _build_tags_index(self, vol):
+ def _build_tags_index(self, vol: VFS) -> tuple[int, int, bool]:
ptop = vol.realpath
with self.mutex:
- _, db_path = self.register_vpath(ptop, vol.flags)
- entags = self.entags[ptop]
- flags = self.flags[ptop]
- cur = self.cur[ptop]
+ reg = self.register_vpath(ptop, vol.flags)
+
+ assert reg and self.pp and self.mtag
+ _, db_path = reg
+ entags = self.entags[ptop]
+ flags = self.flags[ptop]
+ cur = self.cur[ptop]
n_add = 0
n_rm = 0
@@ -794,7 +844,8 @@ class Up2k(object):
if not self.mtag:
return n_add, n_rm, False
- mpool = False
+ mpool: Optional[Queue[Mpqe]] = None
+
if self.mtag.prefer_mt and self.args.mtag_mt > 1:
mpool = self._start_mpool()
@@ -819,11 +870,10 @@ class Up2k(object):
abspath = os.path.join(ptop, rd, fn)
self.pp.msg = "c{} {}".format(n_left, abspath)
- args = [entags, w, abspath]
if not mpool:
- n_tags = self._tag_file(c3, *args)
+ n_tags = self._tag_file(c3, entags, w, abspath)
else:
- mpool.put(["mtag"] + args)
+ mpool.put(Mpqe({}, entags, w, abspath))
# not registry cursor; do not self.mutex:
n_tags = len(self._flush_mpool(c3))
@@ -850,7 +900,7 @@ class Up2k(object):
return n_add, n_rm, True
- def _flush_mpool(self, wcur):
+ def _flush_mpool(self, wcur: "sqlite3.Cursor") -> list[str]:
ret = []
for x in self.pending_tags:
self._tag_file(wcur, *x)
@@ -859,7 +909,7 @@ class Up2k(object):
self.pending_tags = []
return ret
- def _run_all_mtp(self):
+ def _run_all_mtp(self) -> None:
gid = self.gid
t0 = time.time()
for ptop, flags in self.flags.items():
@@ -870,12 +920,12 @@ class Up2k(object):
msg = "mtp finished in {:.2f} sec ({})"
self.log(msg.format(td, s2hms(td, True)))
- del self.pp
+ self.pp = None
for k in list(self.volstate.keys()):
if "OFFLINE" not in self.volstate[k]:
self.volstate[k] = "online, idle"
- def _run_one_mtp(self, ptop, gid):
+ def _run_one_mtp(self, ptop: str, gid: int) -> None:
if gid != self.gid:
return
@@ -915,8 +965,8 @@ class Up2k(object):
break
q = "select w from mt where k = 't:mtp' limit ?"
- warks = cur.execute(q, (batch_sz,)).fetchall()
- warks = [x[0] for x in warks]
+ zq = cur.execute(q, (batch_sz,)).fetchall()
+ warks = [str(x[0]) for x in zq]
jobs = []
for w in warks:
q = "select rd, fn from up where substr(w,1,16)=? limit 1"
@@ -925,8 +975,8 @@ class Up2k(object):
abspath = os.path.join(ptop, rd, fn)
q = "select k from mt where w = ?"
- have = cur.execute(q, (w,)).fetchall()
- have = [x[0] for x in have]
+ zq2 = cur.execute(q, (w,)).fetchall()
+ have: dict[str, Union[str, float]] = {x[0]: 1 for x in zq2}
parsers = self._get_parsers(ptop, have, abspath)
if not parsers:
@@ -937,7 +987,7 @@ class Up2k(object):
if w in in_progress:
continue
- jobs.append([parsers, None, w, abspath])
+ jobs.append(Mpqe(parsers, set(), w, abspath))
in_progress[w] = True
with self.mutex:
@@ -997,7 +1047,9 @@ class Up2k(object):
wcur.close()
cur.close()
- def _get_parsers(self, ptop, have, abspath):
+ def _get_parsers(
+ self, ptop: str, have: dict[str, Union[str, float]], abspath: str
+ ) -> dict[str, MParser]:
try:
all_parsers = self.mtp_parsers[ptop]
except:
@@ -1030,16 +1082,16 @@ class Up2k(object):
parsers = {k: v for k, v in parsers.items() if v.force or k not in have}
return parsers
- def _start_mpool(self):
+ def _start_mpool(self) -> Queue[Mpqe]:
# mp.pool.ThreadPool and concurrent.futures.ThreadPoolExecutor
# both do crazy runahead so lets reinvent another wheel
nw = max(1, self.args.mtag_mt)
-
- if self.pending_tags is None:
+ assert self.mtag
+ if not self.mpool_used:
+ self.mpool_used = True
self.log("using {}x {}".format(nw, self.mtag.backend))
- self.pending_tags = []
- mpool = Queue(nw)
+ mpool: Queue[Mpqe] = Queue(nw)
for _ in range(nw):
thr = threading.Thread(
target=self._tag_thr, args=(mpool,), name="up2k-mpool"
@@ -1049,50 +1101,55 @@ class Up2k(object):
return mpool
- def _stop_mpool(self, mpool):
+ def _stop_mpool(self, mpool: Queue[Mpqe]) -> None:
if not mpool:
return
for _ in range(mpool.maxsize):
- mpool.put(None)
+ mpool.put(Mpqe({}, set(), "", ""))
mpool.join()
- def _tag_thr(self, q):
+ def _tag_thr(self, q: Queue[Mpqe]) -> None:
+ assert self.mtag
while True:
- task = q.get()
- if not task:
+ qe = q.get()
+ if not qe.w:
q.task_done()
return
try:
- parser, entags, wark, abspath = task
- if parser == "mtag":
- tags = self.mtag.get(abspath)
+ if not qe.mtp:
+ tags = self.mtag.get(qe.abspath)
else:
- tags = self.mtag.get_bin(parser, abspath)
+ tags = self.mtag.get_bin(qe.mtp, qe.abspath)
vtags = [
"\033[36m{} \033[33m{}".format(k, v) for k, v in tags.items()
]
if vtags:
- self.log("{}\033[0m [{}]".format(" ".join(vtags), abspath))
+ self.log("{}\033[0m [{}]".format(" ".join(vtags), qe.abspath))
with self.mutex:
- self.pending_tags.append([entags, wark, abspath, tags])
+ self.pending_tags.append((qe.entags, qe.w, qe.abspath, tags))
except:
ex = traceback.format_exc()
- if parser == "mtag":
- parser = self.mtag.backend
-
- self._log_tag_err(parser, abspath, ex)
+ self._log_tag_err(qe.mtp or self.mtag.backend, qe.abspath, ex)
q.task_done()
- def _log_tag_err(self, parser, abspath, ex):
+ def _log_tag_err(self, parser: Any, abspath: str, ex: Any) -> None:
msg = "{} failed to read tags from {}:\n{}".format(parser, abspath, ex)
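self.log(msg.lstrip(), c=1 if "<Signals." in ex else 3)
- def _tag_file(self, write_cur, entags, wark, abspath, tags=None):
+ def _tag_file(
+ self,
+ write_cur: "sqlite3.Cursor",
+ entags: set[str],
+ wark: str,
+ abspath: str,
+ tags: Optional[dict[str, Union[str, float]]] = None,
+ ) -> int: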
+ assert self.mtag
if tags is None:
try:
tags = self.mtag.get(abspath)
@@ -1127,12 +1184,12 @@ class Up2k(object):
return ret
- def _orz(self, db_path):
+ def _orz(self, db_path: str) -> "sqlite3.Cursor":
timeout = int(max(self.args.srch_time, 5) * 1.2)
return sqlite3.connect(db_path, timeout, check_same_thread=False).cursor()
# x.set_trace_callback(trace)
- def _open_db(self, db_path):
+ def _open_db(self, db_path: str) -> "sqlite3.Cursor":
existed = bos.path.exists(db_path)
cur = self._orz(db_path)
ver = self._read_ver(cur)
@@ -1141,8 +1198,8 @@ class Up2k(object):
if ver == 4:
try:
- m = "creating backup before upgrade: "
- cur = self._backup_db(db_path, cur, ver, m)
+ t = "creating backup before upgrade: "
+ cur = self._backup_db(db_path, cur, ver, t)
self._upgrade_v4(cur)
ver = 5
except:
@@ -1157,8 +1214,8 @@ class Up2k(object):
self.log("WARN: could not list files; DB corrupt?\n" + min_ex())
if (ver or 0) > DB_VER:
- m = "database is version {}, this copyparty only supports versions <= {}"
- raise Exception(m.format(ver, DB_VER))
+ t = "database is version {}, this copyparty only supports versions <= {}"
+ raise Exception(t.format(ver, DB_VER))
msg = "creating new DB (old is bad); backup: "
if ver:
@@ -1171,7 +1228,9 @@ class Up2k(object):
bos.unlink(db_path)
return self._create_db(db_path, None)
- def _backup_db(self, db_path, cur, ver, msg):
+ def _backup_db(
+ self, db_path: str, cur: "sqlite3.Cursor", ver: Optional[int], msg: str
+ ) -> "sqlite3.Cursor":
bak = "{}.bak.{:x}.v{}".format(db_path, int(time.time()), ver)
self.log(msg + bak)
try:
@@ -1180,8 +1239,8 @@ class Up2k(object):
cur.connection.backup(c2)
return cur
except:
- m = "native sqlite3 backup failed; using fallback method:\n"
- self.log(m + min_ex())
+ t = "native sqlite3 backup failed; using fallback method:\n"
+ self.log(t + min_ex())
finally:
c2.close()
@@ -1192,7 +1251,7 @@ class Up2k(object):
shutil.copy2(fsenc(db_path), fsenc(bak))
return self._orz(db_path)
- def _read_ver(self, cur):
+ def _read_ver(self, cur: "sqlite3.Cursor") -> Optional[int]:
for tab in ["ki", "kv"]:
try:
c = cur.execute(r"select v from {} where k = 'sver'".format(tab))
@@ -1202,8 +1261,11 @@ class Up2k(object):
rows = c.fetchall()
if rows:
return int(rows[0][0])
+ return None
- def _create_db(self, db_path, cur):
+ def _create_db(
+ self, db_path: str, cur: Optional["sqlite3.Cursor"]
+ ) -> "sqlite3.Cursor":
"""
collision in 2^(n/2) files where n = bits (6 bits/ch)
10*6/2 = 2^30 = 1'073'741'824, 24.1mb idx 1<<(3*10)
@@ -1236,7 +1298,7 @@ class Up2k(object):
self.log("created DB at {}".format(db_path))
return cur
- def _upgrade_v4(self, cur):
+ def _upgrade_v4(self, cur: "sqlite3.Cursor") -> None:
for cmd in [
r"alter table up add column ip text",
r"alter table up add column at int",
@@ -1247,7 +1309,7 @@ class Up2k(object):
cur.connection.commit()
- def handle_json(self, cj):
+ def handle_json(self, cj: dict[str, Any]) -> dict[str, Any]:
with self.mutex:
if not self.register_vpath(cj["ptop"], cj["vcfg"]):
if cj["ptop"] not in self.registry:
@@ -1269,13 +1331,13 @@ class Up2k(object):
if cur:
if self.no_expr_idx:
q = r"select * from up where w = ?"
- argv = (wark,)
+ argv = [wark]
else:
q = r"select * from up where substr(w,1,16) = ? and w = ?"
- argv = (wark[:16], wark)
+ argv = [wark[:16], wark]
- alts = []
- cur = cur.execute(q, argv)
+ alts: list[tuple[int, int, dict[str, Any]]] = []
+ cur = cur.execute(q, tuple(argv))
for _, dtime, dsize, dp_dir, dp_fn, ip, at in cur:
if dp_dir.startswith("//") or dp_fn.startswith("//"):
dp_dir, dp_fn = s3dec(dp_dir, dp_fn)
@@ -1307,7 +1369,7 @@ class Up2k(object):
+ (2 if dp_dir == cj["prel"] else 0)
+ (1 if dp_fn == cj["name"] else 0)
)
- alts.append([score, -len(alts), j])
+ alts.append((score, -len(alts), j))
job = sorted(alts, reverse=True)[0][2] if alts else None
if job and wark in reg:
@@ -1344,7 +1406,7 @@ class Up2k(object):
# registry is size-constrained + can only contain one unique wark;
# let want_recheck trigger symlink (if still in reg) or reupload
if cur:
- dupe = [cj["prel"], cj["name"], cj["lmod"]]
+ dupe = (cj["prel"], cj["name"], cj["lmod"])
try:
self.dupesched[src].append(dupe)
except:
@@ -1431,17 +1493,19 @@ class Up2k(object):
"wark": wark,
}
- def _untaken(self, fdir, fname, ts, ip):
+ def _untaken(self, fdir: str, fname: str, ts: float, ip: str) -> str:
if self.args.nw:
return fname
# TODO broker which avoid this race and
# provides a new filename if taken (same as bup)
suffix = "-{:.6f}-{}".format(ts, ip.replace(":", "."))
- with ren_open(fname, "wb", fdir=fdir, suffix=suffix) as f:
- return f["orz"][1]
+ with ren_open(fname, "wb", fdir=fdir, suffix=suffix) as zfw:
+ return zfw["orz"][1]
- def _symlink(self, src, dst, verbose=True, lmod=None):
+ def _symlink(
+ self, src: str, dst: str, verbose: bool = True, lmod: float = 0
+ ) -> None:
if verbose:
self.log("linking dupe:\n {0}\n {1}".format(src, dst))
@@ -1475,9 +1539,9 @@ class Up2k(object):
break
nc += 1
if nc > 1:
- lsrc = nsrc[nc:]
+ zsl = nsrc[nc:]
hops = len(ndst[nc:]) - 1
- lsrc = "../" * hops + "/".join(lsrc)
+ lsrc = "../" * hops + "/".join(zsl)
try:
if self.args.hardlink:
@@ -1498,11 +1562,13 @@ class Up2k(object):
if lmod and (not linked or SYMTIME):
times = (int(time.time()), int(lmod))
if ANYWIN:
- self.lastmod_q.append([dst, 0, times])
+ self.lastmod_q.append((dst, 0, times))
else:
bos.utime(dst, times, False)
- def handle_chunk(self, ptop, wark, chash):
+ def handle_chunk(
+ self, ptop: str, wark: str, chash: str
+ ) -> tuple[int, list[int], str, float]:
with self.mutex:
job = self.registry[ptop].get(wark)
if not job:
@@ -1523,8 +1589,8 @@ class Up2k(object):
if chash in job["busy"]:
nh = len(job["hash"])
idx = job["hash"].index(chash)
- m = "that chunk is already being written to:\n {}\n {} {}/{}\n {}"
- raise Pebkac(400, m.format(wark, chash, idx, nh, job["name"]))
+ t = "that chunk is already being written to:\n {}\n {} {}/{}\n {}"
+ raise Pebkac(400, t.format(wark, chash, idx, nh, job["name"]))
job["busy"][chash] = 1
@@ -1535,17 +1601,17 @@ class Up2k(object):
path = os.path.join(job["ptop"], job["prel"], job["tnam"])
- return [chunksize, ofs, path, job["lmod"]]
+ return chunksize, ofs, path, job["lmod"]
- def release_chunk(self, ptop, wark, chash):
+ def release_chunk(self, ptop: str, wark: str, chash: str) -> bool:
with self.mutex:
job = self.registry[ptop].get(wark)
if job:
job["busy"].pop(chash, None)
- return [True]
+ return True
- def confirm_chunk(self, ptop, wark, chash):
+ def confirm_chunk(self, ptop: str, wark: str, chash: str) -> tuple[int, str]:
with self.mutex:
try:
job = self.registry[ptop][wark]
@@ -1553,14 +1619,14 @@ class Up2k(object):
src = os.path.join(pdir, job["tnam"])
dst = os.path.join(pdir, job["name"])
except Exception as ex:
- return "confirm_chunk, wark, " + repr(ex)
+ return "confirm_chunk, wark, " + repr(ex) # type: ignore
job["busy"].pop(chash, None)
try:
job["need"].remove(chash)
except Exception as ex:
- return "confirm_chunk, chash, " + repr(ex)
+ return "confirm_chunk, chash, " + repr(ex) # type: ignore
ret = len(job["need"])
if ret > 0:
@@ -1576,35 +1642,35 @@ class Up2k(object):
return ret, dst
- def finish_upload(self, ptop, wark):
+ def finish_upload(self, ptop: str, wark: str) -> None:
with self.mutex:
self._finish_upload(ptop, wark)
- def _finish_upload(self, ptop, wark):
+ def _finish_upload(self, ptop: str, wark: str) -> None:
try:
job = self.registry[ptop][wark]
pdir = os.path.join(job["ptop"], job["prel"])
src = os.path.join(pdir, job["tnam"])
dst = os.path.join(pdir, job["name"])
except Exception as ex:
- return "finish_upload, wark, " + repr(ex)
+ raise Pebkac(500, "finish_upload, wark, " + repr(ex))
# self.log("--- " + wark + " " + dst + " finish_upload atomic " + dst, 4)
atomic_move(src, dst)
times = (int(time.time()), int(job["lmod"]))
if ANYWIN:
- a = [dst, job["size"], times]
- self.lastmod_q.append(a)
+ z1 = (dst, job["size"], times)
+ self.lastmod_q.append(z1)
elif not job["hash"]:
try:
bos.utime(dst, times)
except:
pass
- a = [job[x] for x in "ptop wark prel name lmod size addr".split()]
- a += [job.get("at") or time.time()]
- if self.idx_wark(*a):
+ z2 = [job[x] for x in "ptop wark prel name lmod size addr".split()]
+ z2 += [job.get("at") or time.time()]
+ if self.idx_wark(*z2):
del self.registry[ptop][wark]
else:
self.regdrop(ptop, wark)
@@ -1622,27 +1688,37 @@ class Up2k(object):
self._symlink(dst, d2, lmod=lmod)
if cur:
self.db_rm(cur, rd, fn)
- self.db_add(cur, wark, rd, fn, *a[-4:])
+ self.db_add(cur, wark, rd, fn, *z2[-4:])
if cur:
cur.connection.commit()
- def regdrop(self, ptop, wark):
- t = self.droppable[ptop]
+ def regdrop(self, ptop: str, wark: str) -> None:
+ olds = self.droppable[ptop]
if wark:
- t.append(wark)
+ olds.append(wark)
- if len(t) <= self.args.reg_cap:
+ if len(olds) <= self.args.reg_cap:
return
- n = len(t) - int(self.args.reg_cap / 2)
- m = "up2k-registry [{}] has {} droppables; discarding {}"
- self.log(m.format(ptop, len(t), n))
- for k in t[:n]:
+ n = len(olds) - int(self.args.reg_cap / 2)
+ t = "up2k-registry [{}] has {} droppables; discarding {}"
+ self.log(t.format(ptop, len(olds), n))
+ for k in olds[:n]:
self.registry[ptop].pop(k, None)
- self.droppable[ptop] = t[n:]
+ self.droppable[ptop] = olds[n:]
- def idx_wark(self, ptop, wark, rd, fn, lmod, sz, ip, at):
+ def idx_wark(
+ self,
+ ptop: str,
+ wark: str,
+ rd: str,
+ fn: str,
+ lmod: float,
+ sz: int,
+ ip: str,
+ at: float,
+ ) -> bool:
cur = self.cur.get(ptop)
if not cur:
return False
@@ -1652,29 +1728,41 @@ class Up2k(object):
cur.connection.commit()
if "e2t" in self.flags[ptop]:
- self.tagq.put([ptop, wark, rd, fn])
+ self.tagq.put((ptop, wark, rd, fn))
self.n_tagq += 1
return True
- def db_rm(self, db, rd, fn):
+ def db_rm(self, db: "sqlite3.Cursor", rd: str, fn: str) -> None:
sql = "delete from up where rd = ? and fn = ?"
try:
db.execute(sql, (rd, fn))
except:
+ assert self.mem_cur
db.execute(sql, s3enc(self.mem_cur, rd, fn))
- def db_add(self, db, wark, rd, fn, ts, sz, ip, at):
+ def db_add(
+ self,
+ db: "sqlite3.Cursor",
+ wark: str,
+ rd: str,
+ fn: str,
+ ts: float,
+ sz: int,
+ ip: str,
+ at: float,
+ ) -> None:
sql = "insert into up values (?,?,?,?,?,?,?)"
v = (wark, int(ts), sz, rd, fn, ip or "", int(at or 0))
try:
db.execute(sql, v)
except:
+ assert self.mem_cur
rd, fn = s3enc(self.mem_cur, rd, fn)
v = (wark, int(ts), sz, rd, fn, ip or "", int(at or 0))
db.execute(sql, v)
- def handle_rm(self, uname, ip, vpaths):
+ def handle_rm(self, uname: str, ip: str, vpaths: list[str]) -> str:
n_files = 0
ok = {}
ng = {}
@@ -1687,12 +1775,14 @@ class Up2k(object):
ng[k] = 1
ng = {k: 1 for k in ng if k not in ok}
- ok = len(ok)
- ng = len(ng)
+ iok = len(ok)
+ ing = len(ng)
- return "deleted {} files (and {}/{} folders)".format(n_files, ok, ok + ng)
+ return "deleted {} files (and {}/{} folders)".format(n_files, iok, iok + ing)
- def _handle_rm(self, uname, ip, vpath):
+ def _handle_rm(
+ self, uname: str, ip: str, vpath: str
+ ) -> tuple[int, list[str], list[str]]:
try:
permsets = [[True, False, False, True]]
vn, rem = self.asrv.vfs.get(vpath, uname, *permsets[0])
@@ -1709,18 +1799,18 @@ class Up2k(object):
vn, rem = vn.get_dbv(rem)
_, _, _, _, dip, dat = self._find_from_vpath(vn.realpath, rem)
- m = "you cannot delete this: "
+ t = "you cannot delete this: "
if not dip:
- m += "file not found"
+ t += "file not found"
elif dip != ip:
- m += "not uploaded by (You)"
+ t += "not uploaded by (You)"
elif dat < time.time() - self.args.unpost:
- m += "uploaded too long ago"
+ t += "uploaded too long ago"
else:
- m = None
+ t = ""
- if m:
- raise Pebkac(400, m)
+ if t:
+ raise Pebkac(400, t)
ptop = vn.realpath
atop = vn.canonical(rem, False)
@@ -1731,16 +1821,19 @@ class Up2k(object):
raise Pebkac(400, "file not found on disk (already deleted?)")
scandir = not self.args.no_scandir
- if stat.S_ISLNK(st.st_mode) or stat.S_ISREG(st.st_mode):
+ if stat.S_ISDIR(st.st_mode):
+ g = vn.walk("", rem, [], uname, permsets, True, scandir, True)
+ if unpost:
+ raise Pebkac(400, "cannot unpost folders")
+ elif stat.S_ISLNK(st.st_mode) or stat.S_ISREG(st.st_mode):
dbv, vrem = self.asrv.vfs.get(vpath, uname, *permsets[0])
dbv, vrem = dbv.get_dbv(vrem)
voldir = vsplit(vrem)[0]
vpath_dir = vsplit(vpath)[0]
- g = [[dbv, voldir, vpath_dir, adir, [[fn, 0]], [], []]]
+ g = [(dbv, voldir, vpath_dir, adir, [(fn, 0)], [], {})] # type: ignore
else:
- g = vn.walk("", rem, [], uname, permsets, True, scandir, True)
- if unpost:
- raise Pebkac(400, "cannot unpost folders")
+ self.log("rm: skip type-{:x} file [{}]".format(st.st_mode, atop))
+ return 0, [], []
n_files = 0
for dbv, vrem, _, adir, files, rd, vd in g:
@@ -1766,7 +1859,7 @@ class Up2k(object):
rm = rmdirs(self.log_func, scandir, True, atop, 1)
return n_files, rm[0], rm[1]
- def handle_mv(self, uname, svp, dvp):
+ def handle_mv(self, uname: str, svp: str, dvp: str) -> str:
svn, srem = self.asrv.vfs.get(svp, uname, True, False, True)
svn, srem = svn.get_dbv(srem)
sabs = svn.canonical(srem, False)
@@ -1808,7 +1901,7 @@ class Up2k(object):
rmdirs(self.log_func, scandir, True, sabs, 1)
return "k"
- def _mv_file(self, uname, svp, dvp):
+ def _mv_file(self, uname: str, svp: str, dvp: str) -> str:
svn, srem = self.asrv.vfs.get(svp, uname, True, False, True)
svn, srem = svn.get_dbv(srem)
@@ -1834,29 +1927,33 @@ class Up2k(object):
if bos.path.islink(sabs):
dlabs = absreal(sabs)
- m = "moving symlink from [{}] to [{}], target [{}]"
- self.log(m.format(sabs, dabs, dlabs))
+ t = "moving symlink from [{}] to [{}], target [{}]"
+ self.log(t.format(sabs, dabs, dlabs))
mt = bos.path.getmtime(sabs, False)
bos.unlink(sabs)
self._symlink(dlabs, dabs, False, lmod=mt)
# folders are too scary, schedule rescan of both vols
- self.need_rescan[svn.vpath] = 1
- self.need_rescan[dvn.vpath] = 1
+ self.need_rescan.add(svn.vpath)
+ self.need_rescan.add(dvn.vpath)
with self.rescan_cond:
self.rescan_cond.notify_all()
return "k"
- c1, w, ftime, fsize, ip, at = self._find_from_vpath(svn.realpath, srem)
+ c1, w, ftime_, fsize_, ip, at = self._find_from_vpath(svn.realpath, srem)
c2 = self.cur.get(dvn.realpath)
- if ftime is None:
+ if ftime_ is None:
st = bos.stat(sabs)
ftime = st.st_mtime
fsize = st.st_size
+ else:
+ ftime = ftime_
+ fsize = fsize_ or 0
if w:
+ assert c1
if c2 and c2 != c1:
self._copy_tags(c1, c2, w)
@@ -1865,7 +1962,7 @@ class Up2k(object):
c1.connection.commit()
if c2:
- self.db_add(c2, w, drd, dfn, ftime, fsize, ip, at)
+ self.db_add(c2, w, drd, dfn, ftime, fsize, ip or "", at or 0)
c2.connection.commit()
else:
self.log("not found in src db: [{}]".format(svp))
@@ -1873,7 +1970,9 @@ class Up2k(object):
bos.rename(sabs, dabs)
return "k"
- def _copy_tags(self, csrc, cdst, wark):
+ def _copy_tags(
+ self, csrc: "sqlite3.Cursor", cdst: "sqlite3.Cursor", wark: str
+ ) -> None:
"""copy all tags for wark from src-db to dst-db"""
w = wark[:16]
@@ -1883,16 +1982,26 @@ class Up2k(object):
for _, k, v in csrc.execute("select * from mt where w=?", (w,)):
cdst.execute("insert into mt values(?,?,?)", (w, k, v))
- def _find_from_vpath(self, ptop, vrem):
+ def _find_from_vpath(
+ self, ptop: str, vrem: str
+ ) -> tuple[
+ Optional["sqlite3.Cursor"],
+ Optional[str],
+ Optional[int],
+ Optional[int],
+ Optional[str],
+ Optional[int],
+ ]:
cur = self.cur.get(ptop)
if not cur:
- return [None] * 6
+ return None, None, None, None, None, None
rd, fn = vsplit(vrem)
q = "select w, mt, sz, ip, at from up where rd=? and fn=? limit 1"
try:
c = cur.execute(q, (rd, fn))
except:
+ assert self.mem_cur
c = cur.execute(q, s3enc(self.mem_cur, rd, fn))
hit = c.fetchone()
@@ -1901,14 +2010,21 @@ class Up2k(object):
return cur, wark, ftime, fsize, ip, at
return cur, None, None, None, None, None
- def _forget_file(self, ptop, vrem, cur, wark, drop_tags):
+ def _forget_file(
+ self,
+ ptop: str,
+ vrem: str,
+ cur: Optional["sqlite3.Cursor"],
+ wark: Optional[str],
+ drop_tags: bool,
+ ) -> None:
"""forgets file in db, fixes symlinks, does not delete"""
srd, sfn = vsplit(vrem)
self.log("forgetting {}".format(vrem))
- if wark:
+ if wark and cur:
self.log("found {} in db".format(wark))
if drop_tags:
- if self._relink(wark, ptop, vrem, None):
+ if self._relink(wark, ptop, vrem, ""):
drop_tags = False
if drop_tags:
@@ -1919,20 +2035,24 @@ class Up2k(object):
reg = self.registry.get(ptop)
if reg:
- if not wark:
- wark = [
+ vdir = vsplit(vrem)[0]
+ wark = wark or next(
+ (
x
for x, y in reg.items()
- if sfn in [y["name"], y.get("tnam")] and y["prel"] == vrem
- ]
-
- if wark and wark in reg:
- m = "forgetting partial upload {} ({})"
- p = self._vis_job_progress(wark)
- self.log(m.format(wark, p))
+ if sfn in [y["name"], y.get("tnam")] and y["prel"] == vdir
+ ),
+ "",
+ )
+ job = reg.get(wark) if wark else None
+ if job:
+ t = "forgetting partial upload {} ({})"
+ p = self._vis_job_progress(job)
+ self.log(t.format(wark, p))
+ assert wark
del reg[wark]
- def _relink(self, wark, sptop, srem, dabs):
+ def _relink(self, wark: str, sptop: str, srem: str, dabs: str) -> int:
"""
update symlinks from file at svn/srem to dabs (rename),
or to first remaining full if no dabs (delete)
@@ -1953,13 +2073,13 @@ class Up2k(object):
if not dupes:
return 0
- full = {}
- links = {}
+ full: dict[str, tuple[str, str]] = {}
+ links: dict[str, tuple[str, str]] = {}
for ptop, vp in dupes:
ap = os.path.join(ptop, vp)
try:
d = links if bos.path.islink(ap) else full
- d[ap] = [ptop, vp]
+ d[ap] = (ptop, vp)
except:
self.log("relink: not found: [{}]".format(ap))
@@ -1973,13 +2093,13 @@ class Up2k(object):
bos.rename(sabs, slabs)
bos.utime(slabs, (int(time.time()), int(mt)), False)
self._symlink(slabs, sabs, False)
- full[slabs] = [ptop, rem]
+ full[slabs] = (ptop, rem)
sabs = slabs
if not dabs:
dabs = list(sorted(full.keys()))[0]
- for alink in links.keys():
+ for alink in links:
lmod = None
try:
if alink != sabs and absreal(alink) != sabs:
@@ -1991,11 +2111,11 @@ class Up2k(object):
except:
pass
- self._symlink(dabs, alink, False, lmod=lmod)
+ self._symlink(dabs, alink, False, lmod=lmod or 0)
return len(full) + len(links)
- def _get_wark(self, cj):
+ def _get_wark(self, cj: dict[str, Any]) -> str:
if len(cj["name"]) > 1024 or len(cj["hash"]) > 512 * 1024: # 16TiB
raise Pebkac(400, "name or numchunks not according to spec")
@@ -2020,15 +2140,14 @@ class Up2k(object):
return wark
- def _hashlist_from_file(self, path):
- pp = self.pp if hasattr(self, "pp") else None
+ def _hashlist_from_file(self, path: str) -> list[str]:
fsz = bos.path.getsize(path)
csz = up2k_chunksize(fsz)
ret = []
with open(fsenc(path), "rb", 512 * 1024) as f:
while fsz > 0:
- if pp:
- pp.msg = "{} MB, {}".format(int(fsz / 1024 / 1024), path)
+ if self.pp:
+ self.pp.msg = "{} MB, {}".format(int(fsz / 1024 / 1024), path)
hashobj = hashlib.sha512()
rem = min(csz, fsz)
@@ -2047,7 +2166,7 @@ class Up2k(object):
return ret
- def _new_upload(self, job):
+ def _new_upload(self, job: dict[str, Any]) -> None:
self.registry[job["ptop"]][job["wark"]] = job
pdir = os.path.join(job["ptop"], job["prel"])
job["name"] = self._untaken(pdir, job["name"], job["t0"], job["addr"])
@@ -2066,8 +2185,8 @@ class Up2k(object):
dip = job["addr"].replace(":", ".")
suffix = "-{:.6f}-{}".format(job["t0"], dip)
- with ren_open(tnam, "wb", fdir=pdir, suffix=suffix) as f:
- f, job["tnam"] = f["orz"]
+ with ren_open(tnam, "wb", fdir=pdir, suffix=suffix) as zfw:
+ f, job["tnam"] = zfw["orz"]
if (
ANYWIN
and self.args.sparse
@@ -2086,7 +2205,7 @@ class Up2k(object):
if not job["hash"]:
self._finish_upload(job["ptop"], job["wark"])
- def _lastmodder(self):
+ def _lastmodder(self) -> None:
while True:
ready = self.lastmod_q
self.lastmod_q = []
@@ -2098,8 +2217,8 @@ class Up2k(object):
try:
bos.utime(path, times, False)
except:
- m = "lmod: failed to utime ({}, {}):\n{}"
- self.log(m.format(path, times, min_ex()))
+ t = "lmod: failed to utime ({}, {}):\n{}"
+ self.log(t.format(path, times, min_ex()))
if self.args.sparse and self.args.sparse * 1024 * 1024 <= sz:
try:
@@ -2107,21 +2226,22 @@ class Up2k(object):
except:
self.log("could not unsparse [{}]".format(path), 3)
- def _snapshot(self):
- self.snap_persist_interval = 300 # persist unfinished index every 5 min
- self.snap_discard_interval = 21600 # drop unfinished after 6 hours inactivity
- self.snap_prev = {}
+ def _snapshot(self) -> None:
+ slp = self.snap_persist_interval
while True:
- time.sleep(self.snap_persist_interval)
- if not hasattr(self, "pp"):
+ time.sleep(slp)
+ if self.pp:
+ slp = 5
+ else:
+ slp = self.snap_persist_interval
self.do_snapshot()
- def do_snapshot(self):
+ def do_snapshot(self) -> None:
with self.mutex:
for k, reg in self.registry.items():
self._snap_reg(k, reg)
- def _snap_reg(self, ptop, reg):
+ def _snap_reg(self, ptop: str, reg: dict[str, dict[str, Any]]) -> None:
now = time.time()
histpath = self.asrv.vfs.histtab.get(ptop)
if not histpath:
@@ -2133,9 +2253,9 @@ class Up2k(object):
if x["need"] and now - x["poke"] > self.snap_discard_interval
]
if rm:
- m = "dropping {} abandoned uploads in {}".format(len(rm), ptop)
+ t = "dropping {} abandoned uploads in {}".format(len(rm), ptop)
vis = [self._vis_job_progress(x) for x in rm]
- self.log("\n".join([m] + vis))
+ self.log("\n".join([t] + vis))
for job in rm:
del reg[job["wark"]]
try:
@@ -2159,8 +2279,8 @@ class Up2k(object):
bos.unlink(path)
return
- newest = max(x["poke"] for _, x in reg.items()) if reg else 0
- etag = [len(reg), newest]
+ newest = float(max(x["poke"] for _, x in reg.items()) if reg else 0)
+ etag = (len(reg), newest)
if etag == self.snap_prev.get(ptop):
return
@@ -2177,10 +2297,11 @@ class Up2k(object):
self.log("snap: {} |{}|".format(path, len(reg.keys())))
self.snap_prev[ptop] = etag
- def _tagger(self):
+ def _tagger(self) -> None:
with self.mutex:
self.n_tagq += 1
+ assert self.mtag
while True:
with self.mutex:
self.n_tagq -= 1
@@ -2218,7 +2339,7 @@ class Up2k(object):
self.log("tagged {} ({}+{})".format(abspath, ntags1, len(tags) - ntags1))
- def _hasher(self):
+ def _hasher(self) -> None:
with self.mutex:
self.n_hashq += 1
@@ -2240,20 +2361,21 @@ class Up2k(object):
with self.mutex:
self.idx_wark(ptop, wark, rd, fn, inf.st_mtime, inf.st_size, ip, at)
- def hash_file(self, ptop, flags, rd, fn, ip, at):
+ def hash_file(
+ self, ptop: str, flags: dict[str, Any], rd: str, fn: str, ip: str, at: float
+ ) -> None:
with self.mutex:
self.register_vpath(ptop, flags)
- self.hashq.put([ptop, rd, fn, ip, at])
+ self.hashq.put((ptop, rd, fn, ip, at))
self.n_hashq += 1
# self.log("hashq {} push {}/{}/{}".format(self.n_hashq, ptop, rd, fn))
- def shutdown(self):
- if hasattr(self, "snap_prev"):
- self.log("writing snapshot")
- self.do_snapshot()
+ def shutdown(self) -> None:
+ self.log("writing snapshot")
+ self.do_snapshot()
-def up2k_chunksize(filesize):
+def up2k_chunksize(filesize: int) -> int:
chunksize = 1024 * 1024
stepsize = 512 * 1024
while True:
@@ -2266,18 +2388,17 @@ def up2k_chunksize(filesize):
stepsize *= mul
-def up2k_wark_from_hashlist(salt, filesize, hashes):
+def up2k_wark_from_hashlist(salt: str, filesize: int, hashes: list[str]) -> str:
"""server-reproducible file identifier, independent of name or location"""
- ident = [salt, str(filesize)]
- ident.extend(hashes)
- ident = "\n".join(ident)
+ values = [salt, str(filesize)] + hashes
+ vstr = "\n".join(values)
- wark = hashlib.sha512(ident.encode("utf-8")).digest()[:33]
+ wark = hashlib.sha512(vstr.encode("utf-8")).digest()[:33]
wark = base64.urlsafe_b64encode(wark)
return wark.decode("ascii")
-def up2k_wark_from_metadata(salt, sz, lastmod, rd, fn):
+def up2k_wark_from_metadata(salt: str, sz: int, lastmod: int, rd: str, fn: str) -> str:
ret = fsenc("{}\n{}\n{}\n{}\n{}".format(salt, lastmod, sz, rd, fn))
ret = base64.urlsafe_b64encode(hashlib.sha512(ret).digest())
return "#{}".format(ret.decode("ascii"))[:44]
diff --git a/copyparty/util.py b/copyparty/util.py
index a01fdfdb..2ca710f5 100644
--- a/copyparty/util.py
+++ b/copyparty/util.py
@@ -1,51 +1,79 @@
# coding: utf-8
from __future__ import print_function, unicode_literals
-import re
-import os
-import sys
-import stat
-import time
import base64
+import contextlib
+import hashlib
+import mimetypes
+import os
+import platform
+import re
import select
-import struct
import signal
import socket
-import hashlib
-import platform
-import traceback
-import threading
-import mimetypes
-import contextlib
+import stat
+import struct
import subprocess as sp # nosec
-from datetime import datetime
+import sys
+import threading
+import time
+import traceback
from collections import Counter
+from datetime import datetime
-from .__init__ import PY2, WINDOWS, ANYWIN, VT100
+from .__init__ import ANYWIN, PY2, TYPE_CHECKING, VT100, WINDOWS
from .stolen import surrogateescape
+try:
+ HAVE_SQLITE3 = True
+ import sqlite3 # pylint: disable=unused-import # typechk
+except:
+ HAVE_SQLITE3 = False
+
+try:
+ import types
+ from collections.abc import Callable, Iterable
+
+ import typing
+ from typing import Any, Generator, Optional, Protocol, Union
+
+ class RootLogger(Protocol):
+ def __call__(self, src: str, msg: str, c: Union[int, str] = 0) -> None:
+ return None
+
+ class NamedLogger(Protocol):
+ def __call__(self, msg: str, c: Union[int, str] = 0) -> None:
+ return None
+
+except:
+ pass
+
+if TYPE_CHECKING:
+ from .authsrv import VFS
+
+
FAKE_MP = False
try:
- if FAKE_MP:
- import multiprocessing.dummy as mp # noqa: F401 # pylint: disable=unused-import
+ if not FAKE_MP:
+ import multiprocessing as mp
else:
- import multiprocessing as mp # noqa: F401 # pylint: disable=unused-import
+ import multiprocessing.dummy as mp # type: ignore
except ImportError:
# support jython
- mp = None
+ mp = None # type: ignore
if not PY2:
- from urllib.parse import unquote_to_bytes as unquote
+ from io import BytesIO
from urllib.parse import quote_from_bytes as quote
- from queue import Queue # pylint: disable=unused-import
- from io import BytesIO # pylint: disable=unused-import
+ from urllib.parse import unquote_to_bytes as unquote
else:
- from urllib import unquote # pylint: disable=no-name-in-module
+ from StringIO import StringIO as BytesIO
from urllib import quote # pylint: disable=no-name-in-module
- from Queue import Queue # pylint: disable=import-error,no-name-in-module
- from StringIO import StringIO as BytesIO # pylint: disable=unused-import
+ from urllib import unquote # pylint: disable=no-name-in-module
+_: Any = (mp, BytesIO, quote, unquote)
+__all__ = ["mp", "BytesIO", "quote", "unquote"]
try:
struct.unpack(b">i", b"idgi")
@@ -53,20 +81,21 @@ try:
sunpack = struct.unpack
except:
- def spack(f, *a, **ka):
- return struct.pack(f.decode("ascii"), *a, **ka)
+ def spack(fmt: bytes, *a: Any) -> bytes:
+ return struct.pack(fmt.decode("ascii"), *a)
- def sunpack(f, *a, **ka):
- return struct.unpack(f.decode("ascii"), *a, **ka)
+ def sunpack(fmt: bytes, a: bytes) -> tuple[Any, ...]:
+ return struct.unpack(fmt.decode("ascii"), a)
ansi_re = re.compile("\033\\[[^mK]*[mK]")
surrogateescape.register_surrogateescape()
-FS_ENCODING = sys.getfilesystemencoding()
if WINDOWS and PY2:
FS_ENCODING = "utf-8"
+else:
+ FS_ENCODING = sys.getfilesystemencoding()
SYMTIME = sys.version_info >= (3, 6) and os.utime in os.supports_follow_symlinks
@@ -116,7 +145,7 @@ MIMES = {
}
-def _add_mimes():
+def _add_mimes() -> None:
for ln in """text css html csv
application json wasm xml pdf rtf zip
image webp jpeg png gif bmp
@@ -170,18 +199,18 @@ REKOBO_LKEY = {k.lower(): v for k, v in REKOBO_KEY.items()}
class Cooldown(object):
- def __init__(self, maxage):
+ def __init__(self, maxage: float) -> None:
self.maxage = maxage
self.mutex = threading.Lock()
- self.hist = {}
- self.oldest = 0
+ self.hist: dict[str, float] = {}
+ self.oldest = 0.0
- def poke(self, key):
+ def poke(self, key: str) -> bool:
with self.mutex:
now = time.time()
ret = False
- pv = self.hist.get(key, 0)
+ pv: float = self.hist.get(key, 0)
if now - pv > self.maxage:
self.hist[key] = now
ret = True
@@ -204,12 +233,12 @@ class _Unrecv(object):
undo any number of socket recv ops
"""
- def __init__(self, s, log):
- self.s = s # type: socket.socket
+ def __init__(self, s: socket.socket, log: Optional[NamedLogger]) -> None:
+ self.s = s
self.log = log
- self.buf = b""
+ self.buf: bytes = b""
- def recv(self, nbytes):
+ def recv(self, nbytes: int) -> bytes:
if self.buf:
ret = self.buf[:nbytes]
self.buf = self.buf[nbytes:]
@@ -221,25 +250,25 @@ class _Unrecv(object):
return ret
- def recv_ex(self, nbytes, raise_on_trunc=True):
+ def recv_ex(self, nbytes: int, raise_on_trunc: bool = True) -> bytes:
"""read an exact number of bytes"""
ret = b""
try:
while nbytes > len(ret):
ret += self.recv(nbytes - len(ret))
except OSError:
- m = "client only sent {} of {} expected bytes".format(len(ret), nbytes)
+ t = "client only sent {} of {} expected bytes".format(len(ret), nbytes)
if len(ret) <= 16:
- m += "; got {!r}".format(ret)
+ t += "; got {!r}".format(ret)
if raise_on_trunc:
- raise UnrecvEOF(5, m)
+ raise UnrecvEOF(5, t)
elif self.log:
- self.log(m, 3)
+ self.log(t, 3)
return ret
- def unrecv(self, buf):
+ def unrecv(self, buf: bytes) -> None:
self.buf = buf + self.buf
@@ -248,28 +277,28 @@ class _LUnrecv(object):
with expensive debug logging
"""
- def __init__(self, s, log):
+ def __init__(self, s: socket.socket, log: Optional[NamedLogger]) -> None:
self.s = s
self.log = log
self.buf = b""
- def recv(self, nbytes):
+ def recv(self, nbytes: int) -> bytes:
if self.buf:
ret = self.buf[:nbytes]
self.buf = self.buf[nbytes:]
- m = "\033[0;7mur:pop:\033[0;1;32m {}\n\033[0;7mur:rem:\033[0;1;35m {}\033[0m"
- self.log(m.format(ret, self.buf))
+ t = "\033[0;7mur:pop:\033[0;1;32m {}\n\033[0;7mur:rem:\033[0;1;35m {}\033[0m"
+ print(t.format(ret, self.buf))
return ret
ret = self.s.recv(nbytes)
- m = "\033[0;7mur:recv\033[0;1;33m {}\033[0m"
- self.log(m.format(ret))
+ t = "\033[0;7mur:recv\033[0;1;33m {}\033[0m"
+ print(t.format(ret))
if not ret:
raise UnrecvEOF("client stopped sending data")
return ret
- def recv_ex(self, nbytes, raise_on_trunc=True):
+ def recv_ex(self, nbytes: int, raise_on_trunc: bool = True) -> bytes:
"""read an exact number of bytes"""
try:
ret = self.recv(nbytes)
@@ -285,18 +314,18 @@ class _LUnrecv(object):
err = True
if err:
- m = "client only sent {} of {} expected bytes".format(len(ret), nbytes)
+ t = "client only sent {} of {} expected bytes".format(len(ret), nbytes)
if raise_on_trunc:
- raise UnrecvEOF(m)
+ raise UnrecvEOF(t)
elif self.log:
- self.log(m, 3)
+ self.log(t, 3)
return ret
- def unrecv(self, buf):
+ def unrecv(self, buf: bytes) -> None:
self.buf = buf + self.buf
- m = "\033[0;7mur:push\033[0;1;31m {}\n\033[0;7mur:rem:\033[0;1;35m {}\033[0m"
- self.log(m.format(buf, self.buf))
+ t = "\033[0;7mur:push\033[0;1;31m {}\n\033[0;7mur:rem:\033[0;1;35m {}\033[0m"
+ print(t.format(buf, self.buf))
Unrecv = _Unrecv
@@ -304,14 +333,14 @@ Unrecv = _Unrecv
class FHC(object):
class CE(object):
- def __init__(self, fh):
- self.ts = 0
+ def __init__(self, fh: typing.BinaryIO) -> None:
+ self.ts: float = 0
self.fhs = [fh]
- def __init__(self):
- self.cache = {}
+ def __init__(self) -> None:
+ self.cache: dict[str, FHC.CE] = {}
- def close(self, path):
+ def close(self, path: str) -> None:
try:
ce = self.cache[path]
except:
@@ -322,7 +351,7 @@ class FHC(object):
del self.cache[path]
- def clean(self):
+ def clean(self) -> None:
if not self.cache:
return
@@ -337,10 +366,10 @@ class FHC(object):
self.cache = keep
- def pop(self, path):
+ def pop(self, path: str) -> typing.BinaryIO:
return self.cache[path].fhs.pop()
- def put(self, path, fh):
+ def put(self, path: str, fh: typing.BinaryIO) -> None:
try:
ce = self.cache[path]
ce.fhs.append(fh)
@@ -356,14 +385,15 @@ class ProgressPrinter(threading.Thread):
periodically print progress info without linefeeds
"""
- def __init__(self):
+ def __init__(self) -> None:
threading.Thread.__init__(self, name="pp")
self.daemon = True
- self.msg = None
+ self.msg = ""
self.end = False
+ self.n = -1
self.start()
- def run(self):
+ def run(self) -> None:
msg = None
fmt = " {}\033[K\r" if VT100 else " {} $\r"
while not self.end:
@@ -384,7 +414,7 @@ class ProgressPrinter(threading.Thread):
sys.stdout.flush() # necessary on win10 even w/ stderr btw
-def uprint(msg):
+def uprint(msg: str) -> None:
try:
print(msg, end="")
except UnicodeEncodeError:
@@ -394,17 +424,17 @@ def uprint(msg):
print(msg.encode("ascii", "replace").decode(), end="")
-def nuprint(msg):
+def nuprint(msg: str) -> None:
uprint("{}\n".format(msg))
-def rice_tid():
+def rice_tid() -> str:
tid = threading.current_thread().ident
c = sunpack(b"B" * 5, spack(b">Q", tid)[-5:])
return "".join("\033[1;37;48;5;{0}m{0:02x}".format(x) for x in c) + "\033[0m"
-def trace(*args, **kwargs):
+def trace(*args: Any, **kwargs: Any) -> None:
t = time.time()
stack = "".join(
"\033[36m{}\033[33m{}".format(x[0].split(os.sep)[-1][:-3], x[1])
@@ -423,15 +453,15 @@ def trace(*args, **kwargs):
nuprint(msg)
-def alltrace():
- threads = {}
+def alltrace() -> str:
+ threads: dict[str, types.FrameType] = {}
names = dict([(t.ident, t.name) for t in threading.enumerate()])
for tid, stack in sys._current_frames().items():
name = "{} ({:x})".format(names.get(tid), tid)
threads[name] = stack
- rret = []
- bret = []
+ rret: list[str] = []
+ bret: list[str] = []
for name, stack in sorted(threads.items()):
ret = ["\n\n# {}".format(name)]
pad = None
@@ -451,20 +481,20 @@ def alltrace():
return "\n".join(rret + bret)
-def start_stackmon(arg_str, nid):
+def start_stackmon(arg_str: str, nid: int) -> None:
suffix = "-{}".format(nid) if nid else ""
fp, f = arg_str.rsplit(",", 1)
- f = int(f)
+ zi = int(f)
t = threading.Thread(
target=stackmon,
- args=(fp, f, suffix),
+ args=(fp, zi, suffix),
name="stackmon" + suffix,
)
t.daemon = True
t.start()
-def stackmon(fp, ival, suffix):
+def stackmon(fp: str, ival: float, suffix: str) -> None:
ctr = 0
while True:
ctr += 1
@@ -474,7 +504,9 @@ def stackmon(fp, ival, suffix):
f.write(st.encode("utf-8", "replace"))
-def start_log_thrs(logger, ival, nid):
+def start_log_thrs(
+ logger: Callable[[str, str, int], None], ival: float, nid: int
+) -> None:
ival = float(ival)
tname = lname = "log-thrs"
if nid:
@@ -490,7 +522,7 @@ def start_log_thrs(logger, ival, nid):
t.start()
-def log_thrs(log, ival, name):
+def log_thrs(log: Callable[[str, str, int], None], ival: float, name: str) -> None:
while True:
time.sleep(ival)
tv = [x.name for x in threading.enumerate()]
@@ -507,7 +539,7 @@ def log_thrs(log, ival, name):
log(name, "\033[0m \033[33m".join(tv), 3)
-def vol_san(vols, txt):
+def vol_san(vols: list["VFS"], txt: bytes) -> bytes:
for vol in vols:
txt = txt.replace(vol.realpath.encode("utf-8"), vol.vpath.encode("utf-8"))
txt = txt.replace(
@@ -518,24 +550,26 @@ def vol_san(vols, txt):
return txt
-def min_ex(max_lines=8, reverse=False):
+def min_ex(max_lines: int = 8, reverse: bool = False) -> str:
et, ev, tb = sys.exc_info()
- tb = traceback.extract_tb(tb)
+ stb = traceback.extract_tb(tb)
fmt = "{} @ {} <{}>: {}"
- ex = [fmt.format(fp.split(os.sep)[-1], ln, fun, txt) for fp, ln, fun, txt in tb]
- ex.append("[{}] {}".format(et.__name__, ev))
+ ex = [fmt.format(fp.split(os.sep)[-1], ln, fun, txt) for fp, ln, fun, txt in stb]
+ ex.append("[{}] {}".format(et.__name__ if et else "(anonymous)", ev))
return "\n".join(ex[-max_lines:][:: -1 if reverse else 1])
@contextlib.contextmanager
-def ren_open(fname, *args, **kwargs):
+def ren_open(
+ fname: str, *args: Any, **kwargs: Any
+) -> Generator[dict[str, tuple[typing.IO[Any], str]], None, None]:
fun = kwargs.pop("fun", open)
fdir = kwargs.pop("fdir", None)
suffix = kwargs.pop("suffix", None)
if fname == os.devnull:
with fun(fname, *args, **kwargs) as f:
- yield {"orz": [f, fname]}
+ yield {"orz": (f, fname)}
return
if suffix:
@@ -575,7 +609,7 @@ def ren_open(fname, *args, **kwargs):
with open(fsenc(fp2), "wb") as f2:
f2.write(orig_name.encode("utf-8"))
- yield {"orz": [f, fname]}
+ yield {"orz": (f, fname)}
return
except OSError as ex_:
@@ -584,9 +618,9 @@ def ren_open(fname, *args, **kwargs):
raise
if not b64:
- b64 = (bname + ext).encode("utf-8", "replace")
- b64 = hashlib.sha512(b64).digest()[:12]
- b64 = base64.urlsafe_b64encode(b64).decode("utf-8")
+ zs = (bname + ext).encode("utf-8", "replace")
+ zs = hashlib.sha512(zs).digest()[:12]
+ b64 = base64.urlsafe_b64encode(zs).decode("utf-8")
badlen = len(fname)
while len(fname) >= badlen:
@@ -608,8 +642,8 @@ def ren_open(fname, *args, **kwargs):
class MultipartParser(object):
- def __init__(self, log_func, sr, http_headers):
- self.sr = sr # type: Unrecv
+ def __init__(self, log_func: NamedLogger, sr: Unrecv, http_headers: dict[str, str]):
+ self.sr = sr
self.log = log_func
self.headers = http_headers
@@ -622,10 +656,14 @@ class MultipartParser(object):
r'^content-disposition:(?: *|.*; *)filename="(.*)"', re.IGNORECASE
)
- self.boundary = None
- self.gen = None
+ self.boundary = b""
+ self.gen: Optional[
+ Generator[
+ tuple[str, Optional[str], Generator[bytes, None, None]], None, None
+ ]
+ ] = None
- def _read_header(self):
+ def _read_header(self) -> tuple[str, Optional[str]]:
"""
returns [fieldname, filename] after eating a block of multipart headers
while doing a decent job at dealing with the absolute mess that is
@@ -641,7 +679,8 @@ class MultipartParser(object):
# rfc-7578 overrides rfc-2388 so this is not-impl
# (opera >=9 <11.10 is the only thing i've ever seen use it)
raise Pebkac(
- "you can't use that browser to upload multiple files at once"
+ 400,
+ "you can't use that browser to upload multiple files at once",
)
continue
@@ -655,12 +694,12 @@ class MultipartParser(object):
raise Pebkac(400, "not form-data: {}".format(ln))
try:
- field = self.re_cdisp_field.match(ln).group(1)
+ field = self.re_cdisp_field.match(ln).group(1) # type: ignore
except:
raise Pebkac(400, "missing field name: {}".format(ln))
try:
- fn = self.re_cdisp_file.match(ln).group(1)
+ fn = self.re_cdisp_file.match(ln).group(1) # type: ignore
except:
# this is not a file upload, we're done
return field, None
@@ -687,11 +726,10 @@ class MultipartParser(object):
esc = False
for ch in fn:
if esc:
- if ch in ['"', "\\"]:
- ret += '"'
- else:
- ret += esc + ch
esc = False
+ if ch not in ['"', "\\"]:
+ ret += "\\"
+ ret += ch
elif ch == "\\":
esc = True
elif ch == '"':
@@ -699,9 +737,11 @@ class MultipartParser(object):
else:
ret += ch
- return [field, ret]
+ return field, ret
- def _read_data(self):
+ raise Pebkac(400, "server expected a multipart header but you never sent one")
+
+ def _read_data(self) -> Generator[bytes, None, None]:
blen = len(self.boundary)
bufsz = 32 * 1024
while True:
@@ -748,7 +788,9 @@ class MultipartParser(object):
yield buf
- def _run_gen(self):
+ def _run_gen(
+ self,
+ ) -> Generator[tuple[str, Optional[str], Generator[bytes, None, None]], None, None]:
"""
yields [fieldname, unsanitized_filename, fieldvalue]
where fieldvalue yields chunks of data
@@ -756,7 +798,7 @@ class MultipartParser(object):
run = True
while run:
fieldname, filename = self._read_header()
- yield [fieldname, filename, self._read_data()]
+ yield (fieldname, filename, self._read_data())
tail = self.sr.recv_ex(2, False)
@@ -766,19 +808,19 @@ class MultipartParser(object):
run = False
if tail != b"\r\n":
- m = "protocol error after field value: want b'\\r\\n', got {!r}"
- raise Pebkac(400, m.format(tail))
+ t = "protocol error after field value: want b'\\r\\n', got {!r}"
+ raise Pebkac(400, t.format(tail))
- def _read_value(self, iterator, max_len):
+ def _read_value(self, iterable: Iterable[bytes], max_len: int) -> bytes:
ret = b""
- for buf in iterator:
+ for buf in iterable:
ret += buf
if len(ret) > max_len:
raise Pebkac(400, "field length is too long")
return ret
- def parse(self):
+ def parse(self) -> None:
# spec says there might be junk before the first boundary,
# can't have the leading \r\n if that's not the case
self.boundary = b"--" + get_boundary(self.headers).encode("utf-8")
@@ -793,11 +835,12 @@ class MultipartParser(object):
self.boundary = b"\r\n" + self.boundary
self.gen = self._run_gen()
- def require(self, field_name, max_len):
+ def require(self, field_name: str, max_len: int) -> str:
"""
returns the value of the next field in the multipart body,
raises if the field name is not as expected
"""
+ assert self.gen
p_field, _, p_data = next(self.gen)
if p_field != field_name:
raise Pebkac(
@@ -806,14 +849,15 @@ class MultipartParser(object):
return self._read_value(p_data, max_len).decode("utf-8", "surrogateescape")
- def drop(self):
+ def drop(self) -> None:
"""discards the remaining multipart body"""
+ assert self.gen
for _, _, data in self.gen:
for _ in data:
pass
-def get_boundary(headers):
+def get_boundary(headers: dict[str, str]) -> str:
# boundaries contain a-z A-Z 0-9 ' ( ) + _ , - . / : = ?
# (whitespace allowed except as the last char)
ptn = r"^multipart/form-data *; *(.*; *)?boundary=([^;]+)"
@@ -825,14 +869,14 @@ def get_boundary(headers):
return m.group(2)
-def read_header(sr):
+def read_header(sr: Unrecv) -> list[str]:
ret = b""
while True:
try:
ret += sr.recv(1024)
except:
if not ret:
- return None
+ return []
raise Pebkac(
400,
@@ -853,7 +897,7 @@ def read_header(sr):
return ret[:ofs].decode("utf-8", "surrogateescape").lstrip("\r\n").split("\r\n")
-def gen_filekey(salt, fspath, fsize, inode):
+def gen_filekey(salt: str, fspath: str, fsize: int, inode: int) -> str:
return base64.urlsafe_b64encode(
hashlib.sha512(
"{} {} {} {}".format(salt, fspath, fsize, inode).encode("utf-8", "replace")
@@ -861,7 +905,7 @@ def gen_filekey(salt, fspath, fsize, inode):
).decode("ascii")
-def gencookie(k, v, dur):
+def gencookie(k: str, v: str, dur: Optional[int]) -> str:
v = v.replace(";", "")
if dur:
dt = datetime.utcfromtimestamp(time.time() + dur)
@@ -872,7 +916,7 @@ def gencookie(k, v, dur):
return "{}={}; Path=/; Expires={}; SameSite=Lax".format(k, v, exp)
-def humansize(sz, terse=False):
+def humansize(sz: float, terse: bool = False) -> str:
for unit in ["B", "KiB", "MiB", "GiB", "TiB"]:
if sz < 1024:
break
@@ -887,18 +931,18 @@ def humansize(sz, terse=False):
return ret.replace("iB", "").replace(" ", "")
-def unhumanize(sz):
+def unhumanize(sz: str) -> int:
try:
- return float(sz)
+ return int(sz)
except:
pass
- mul = sz[-1:].lower()
- mul = {"k": 1024, "m": 1024 * 1024, "g": 1024 * 1024 * 1024}.get(mul, 1)
- return float(sz[:-1]) * mul
+ mc = sz[-1:].lower()
+ mi = {"k": 1024, "m": 1024 * 1024, "g": 1024 * 1024 * 1024}.get(mc, 1)
+ return int(float(sz[:-1]) * mi)
-def get_spd(nbyte, t0, t=None):
+def get_spd(nbyte: int, t0: float, t: Optional[float] = None) -> str:
if t is None:
t = time.time()
@@ -908,7 +952,7 @@ def get_spd(nbyte, t0, t=None):
return "{} \033[0m{}/s\033[0m".format(s1, s2)
-def s2hms(s, optional_h=False):
+def s2hms(s: float, optional_h: bool = False) -> str:
s = int(s)
h, s = divmod(s, 3600)
m, s = divmod(s, 60)
@@ -918,7 +962,7 @@ def s2hms(s, optional_h=False):
return "{}:{:02}:{:02}".format(h, m, s)
-def uncyg(path):
+def uncyg(path: str) -> str:
if len(path) < 2 or not path.startswith("/"):
return path
@@ -928,8 +972,8 @@ def uncyg(path):
return "{}:\\{}".format(path[1], path[3:])
-def undot(path):
- ret = []
+def undot(path: str) -> str:
+ ret: list[str] = []
for node in path.split("/"):
if node in ["", "."]:
continue
@@ -944,7 +988,7 @@ def undot(path):
return "/".join(ret)
-def sanitize_fn(fn, ok, bad):
+def sanitize_fn(fn: str, ok: str, bad: list[str]) -> str:
if "/" not in ok:
fn = fn.replace("\\", "/").split("/")[-1]
@@ -976,7 +1020,7 @@ def sanitize_fn(fn, ok, bad):
return fn.strip()
-def relchk(rp):
+def relchk(rp: str) -> str:
if ANYWIN:
if "\n" in rp or "\r" in rp:
return "x\nx"
@@ -985,8 +1029,10 @@ def relchk(rp):
if p != rp:
return "[{}]".format(p)
+ return ""
-def absreal(fpath):
+
+def absreal(fpath: str) -> str:
try:
return fsdec(os.path.abspath(os.path.realpath(fsenc(fpath))))
except:
@@ -999,26 +1045,26 @@ def absreal(fpath):
return os.path.abspath(os.path.realpath(fpath))
-def u8safe(txt):
+def u8safe(txt: str) -> str:
try:
return txt.encode("utf-8", "xmlcharrefreplace").decode("utf-8", "replace")
except:
return txt.encode("utf-8", "replace").decode("utf-8", "replace")
-def exclude_dotfiles(filepaths):
+def exclude_dotfiles(filepaths: list[str]) -> list[str]:
return [x for x in filepaths if not x.split("/")[-1].startswith(".")]
-def http_ts(ts):
+def http_ts(ts: int) -> str:
file_dt = datetime.utcfromtimestamp(ts)
return file_dt.strftime(HTTP_TS_FMT)
-def html_escape(s, quote=False, crlf=False):
+def html_escape(s: str, quot: bool = False, crlf: bool = False) -> str:
"""html.escape but also newlines"""
s = s.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
- if quote:
+ if quot:
s = s.replace('"', "&quot;").replace("'", "&#x27;")
if crlf:
s = s.replace("\r", "&#13;").replace("\n", "&#10;")
@@ -1026,10 +1072,10 @@ def html_escape(s, quote=False, crlf=False):
return s
-def html_bescape(s, quote=False, crlf=False):
+def html_bescape(s: bytes, quot: bool = False, crlf: bool = False) -> bytes:
"""html.escape but bytestrings"""
s = s.replace(b"&", b"&amp;").replace(b"<", b"&lt;").replace(b">", b"&gt;")
- if quote:
+ if quot:
s = s.replace(b'"', b"&quot;").replace(b"'", b"&#x27;")
if crlf:
s = s.replace(b"\r", b"&#13;").replace(b"\n", b"&#10;")
@@ -1037,18 +1083,20 @@ def html_bescape(s, quote=False, crlf=False):
return s
-def quotep(txt):
+def quotep(txt: str) -> str:
"""url quoter which deals with bytes correctly"""
btxt = w8enc(txt)
quot1 = quote(btxt, safe=b"/")
if not PY2:
- quot1 = quot1.encode("ascii")
+ quot2 = quot1.encode("ascii")
+ else:
+ quot2 = quot1
- quot2 = quot1.replace(b" ", b"+")
- return w8dec(quot2)
+ quot3 = quot2.replace(b" ", b"+")
+ return w8dec(quot3)
-def unquotep(txt):
+def unquotep(txt: str) -> str:
"""url unquoter which deals with bytes correctly"""
btxt = w8enc(txt)
# btxt = btxt.replace(b"+", b" ")
@@ -1056,14 +1104,14 @@ def unquotep(txt):
return w8dec(unq2)
-def vsplit(vpath):
+def vsplit(vpath: str) -> tuple[str, str]:
if "/" not in vpath:
return "", vpath
- return vpath.rsplit("/", 1)
+ return vpath.rsplit("/", 1) # type: ignore
-def w8dec(txt):
+def w8dec(txt: bytes) -> str:
"""decodes filesystem-bytes to wtf8"""
if PY2:
return surrogateescape.decodefilename(txt)
@@ -1071,7 +1119,7 @@ def w8dec(txt):
return txt.decode(FS_ENCODING, "surrogateescape")
-def w8enc(txt):
+def w8enc(txt: str) -> bytes:
"""encodes wtf8 to filesystem-bytes"""
if PY2:
return surrogateescape.encodefilename(txt)
@@ -1079,12 +1127,12 @@ def w8enc(txt):
return txt.encode(FS_ENCODING, "surrogateescape")
-def w8b64dec(txt):
+def w8b64dec(txt: str) -> str:
"""decodes base64(filesystem-bytes) to wtf8"""
return w8dec(base64.urlsafe_b64decode(txt.encode("ascii")))
-def w8b64enc(txt):
+def w8b64enc(txt: str) -> str:
"""encodes wtf8 to base64(filesystem-bytes)"""
return base64.urlsafe_b64encode(w8enc(txt)).decode("ascii")
@@ -1102,8 +1150,8 @@ else:
fsdec = w8dec
-def s3enc(mem_cur, rd, fn):
- ret = []
+def s3enc(mem_cur: "sqlite3.Cursor", rd: str, fn: str) -> tuple[str, str]:
+ ret: list[str] = []
for v in [rd, fn]:
try:
mem_cur.execute("select * from a where b = ?", (v,))
@@ -1112,10 +1160,10 @@ def s3enc(mem_cur, rd, fn):
ret.append("//" + w8b64enc(v))
# self.log("mojien [{}] {}".format(v, ret[-1][2:]))
- return tuple(ret)
+ return ret[0], ret[1]
-def s3dec(rd, fn):
+def s3dec(rd: str, fn: str) -> tuple[str, str]:
ret = []
for v in [rd, fn]:
if v.startswith("//"):
@@ -1124,12 +1172,12 @@ def s3dec(rd, fn):
else:
ret.append(v)
- return tuple(ret)
+ return ret[0], ret[1]
-def atomic_move(src, dst):
- src = fsenc(src)
- dst = fsenc(dst)
+def atomic_move(usrc: str, udst: str) -> None:
+ src = fsenc(usrc)
+ dst = fsenc(udst)
if not PY2:
os.replace(src, dst)
else:
@@ -1139,7 +1187,7 @@ def atomic_move(src, dst):
os.rename(src, dst)
-def read_socket(sr, total_size):
+def read_socket(sr: Unrecv, total_size: int) -> Generator[bytes, None, None]:
remains = total_size
while remains > 0:
bufsz = 32 * 1024
@@ -1149,14 +1197,14 @@ def read_socket(sr, total_size):
try:
buf = sr.recv(bufsz)
except OSError:
- m = "client d/c during binary post after {} bytes, {} bytes remaining"
- raise Pebkac(400, m.format(total_size - remains, remains))
+ t = "client d/c during binary post after {} bytes, {} bytes remaining"
+ raise Pebkac(400, t.format(total_size - remains, remains))
remains -= len(buf)
yield buf
-def read_socket_unbounded(sr):
+def read_socket_unbounded(sr: Unrecv) -> Generator[bytes, None, None]:
try:
while True:
yield sr.recv(32 * 1024)
@@ -1164,7 +1212,9 @@ def read_socket_unbounded(sr):
return
-def read_socket_chunked(sr, log=None):
+def read_socket_chunked(
+ sr: Unrecv, log: Optional[NamedLogger] = None
+) -> Generator[bytes, None, None]:
err = "upload aborted: expected chunk length, got [{}] |{}| instead"
while True:
buf = b""
@@ -1191,8 +1241,8 @@ def read_socket_chunked(sr, log=None):
if x == b"\r\n":
return
- m = "protocol error after final chunk: want b'\\r\\n', got {!r}"
- raise Pebkac(400, m.format(x))
+ t = "protocol error after final chunk: want b'\\r\\n', got {!r}"
+ raise Pebkac(400, t.format(x))
if log:
log("receiving {} byte chunk".format(chunklen))
@@ -1202,11 +1252,11 @@ def read_socket_chunked(sr, log=None):
x = sr.recv_ex(2, False)
if x != b"\r\n":
- m = "protocol error in chunk separator: want b'\\r\\n', got {!r}"
- raise Pebkac(400, m.format(x))
+ t = "protocol error in chunk separator: want b'\\r\\n', got {!r}"
+ raise Pebkac(400, t.format(x))
-def yieldfile(fn):
+def yieldfile(fn: str) -> Generator[bytes, None, None]:
with open(fsenc(fn), "rb", 512 * 1024) as f:
while True:
buf = f.read(64 * 1024)
@@ -1216,7 +1266,11 @@ def yieldfile(fn):
yield buf
-def hashcopy(fin, fout, slp=0):
+def hashcopy(
+ fin: Union[typing.BinaryIO, Generator[bytes, None, None]],
+ fout: Union[typing.BinaryIO, typing.IO[Any]],
+ slp: int = 0,
+) -> tuple[int, str, str]:
hashobj = hashlib.sha512()
tlen = 0
for buf in fin:
@@ -1232,7 +1286,15 @@ def hashcopy(fin, fout, slp=0):
return tlen, hashobj.hexdigest(), digest_b64
-def sendfile_py(log, lower, upper, f, s, bufsz, slp):
+def sendfile_py(
+ log: NamedLogger,
+ lower: int,
+ upper: int,
+ f: typing.BinaryIO,
+ s: socket.socket,
+ bufsz: int,
+ slp: int,
+) -> int:
remains = upper - lower
f.seek(lower)
while remains > 0:
@@ -1252,25 +1314,37 @@ def sendfile_py(log, lower, upper, f, s, bufsz, slp):
return 0
-def sendfile_kern(log, lower, upper, f, s, bufsz, slp):
+def sendfile_kern(
+ log: NamedLogger,
+ lower: int,
+ upper: int,
+ f: typing.BinaryIO,
+ s: socket.socket,
+ bufsz: int,
+ slp: int,
+) -> int:
out_fd = s.fileno()
in_fd = f.fileno()
ofs = lower
- stuck = None
+ stuck = 0.0
while ofs < upper:
stuck = stuck or time.time()
try:
req = min(2 ** 30, upper - ofs)
select.select([], [out_fd], [], 10)
n = os.sendfile(out_fd, in_fd, ofs, req)
- stuck = None
- except Exception as ex:
+ stuck = 0
+ except OSError as ex:
d = time.time() - stuck
log("sendfile stuck for {:.3f} sec: {!r}".format(d, ex))
if d < 3600 and ex.errno == 11: # eagain
continue
n = 0
+ except Exception as ex:
+ n = 0
+ d = time.time() - stuck
+ log("sendfile failed after {:.3f} sec: {!r}".format(d, ex))
if n <= 0:
return upper - ofs
@@ -1281,7 +1355,9 @@ def sendfile_kern(log, lower, upper, f, s, bufsz, slp):
return 0
-def statdir(logger, scandir, lstat, top):
+def statdir(
+ logger: Optional[RootLogger], scandir: bool, lstat: bool, top: str
+) -> Generator[tuple[str, os.stat_result], None, None]:
if lstat and ANYWIN:
lstat = False
@@ -1295,30 +1371,42 @@ def statdir(logger, scandir, lstat, top):
with os.scandir(btop) as dh:
for fh in dh:
try:
- yield [fsdec(fh.name), fh.stat(follow_symlinks=not lstat)]
+ yield (fsdec(fh.name), fh.stat(follow_symlinks=not lstat))
except Exception as ex:
+ if not logger:
+ continue
+
logger(src, "[s] {} @ {}".format(repr(ex), fsdec(fh.path)), 6)
else:
src = "listdir"
- fun = os.lstat if lstat else os.stat
+ fun: Any = os.lstat if lstat else os.stat
for name in os.listdir(btop):
abspath = os.path.join(btop, name)
try:
- yield [fsdec(name), fun(abspath)]
+ yield (fsdec(name), fun(abspath))
except Exception as ex:
+ if not logger:
+ continue
+
logger(src, "[s] {} @ {}".format(repr(ex), fsdec(abspath)), 6)
except Exception as ex:
- logger(src, "{} @ {}".format(repr(ex), top), 1)
+ t = "{} @ {}".format(repr(ex), top)
+ if logger:
+ logger(src, t, 1)
+ else:
+ print(t)
-def rmdirs(logger, scandir, lstat, top, depth):
+def rmdirs(
+ logger: RootLogger, scandir: bool, lstat: bool, top: str, depth: int
+) -> tuple[list[str], list[str]]:
if not os.path.exists(fsenc(top)) or not os.path.isdir(fsenc(top)):
top = os.path.dirname(top)
depth -= 1
- dirs = statdir(logger, scandir, lstat, top)
- dirs = [x[0] for x in dirs if stat.S_ISDIR(x[1].st_mode)]
+ stats = statdir(logger, scandir, lstat, top)
+ dirs = [x[0] for x in stats if stat.S_ISDIR(x[1].st_mode)]
dirs = [os.path.join(top, x) for x in dirs]
ok = []
ng = []
@@ -1337,7 +1425,7 @@ def rmdirs(logger, scandir, lstat, top, depth):
return ok, ng
-def unescape_cookie(orig):
+def unescape_cookie(orig: str) -> str:
# mw=idk; doot=qwe%2Crty%3Basd+fgh%2Bjkl%25zxc%26vbn # qwe,rty;asd fgh+jkl%zxc&vbn
ret = ""
esc = ""
@@ -1365,7 +1453,7 @@ def unescape_cookie(orig):
return ret
-def guess_mime(url, fallback="application/octet-stream"):
+def guess_mime(url: str, fallback: str = "application/octet-stream") -> str:
try:
_, ext = url.rsplit(".", 1)
except:
@@ -1387,7 +1475,9 @@ def guess_mime(url, fallback="application/octet-stream"):
return ret
-def runcmd(argv, timeout=None, **ka):
+def runcmd(
+ argv: Union[list[bytes], list[str]], timeout: Optional[int] = None, **ka: Any
+) -> tuple[int, str, str]:
p = sp.Popen(argv, stdout=sp.PIPE, stderr=sp.PIPE, **ka)
if not timeout or PY2:
stdout, stderr = p.communicate()
@@ -1400,10 +1490,10 @@ def runcmd(argv, timeout=None, **ka):
stdout = stdout.decode("utf-8", "replace")
stderr = stderr.decode("utf-8", "replace")
- return [p.returncode, stdout, stderr]
+ return p.returncode, stdout, stderr
-def chkcmd(argv, **ka):
+def chkcmd(argv: Union[list[bytes], list[str]], **ka: Any) -> tuple[str, str]:
ok, sout, serr = runcmd(argv, **ka)
if ok != 0:
retchk(ok, argv, serr)
@@ -1412,7 +1502,7 @@ def chkcmd(argv, **ka):
return sout, serr
-def mchkcmd(argv, timeout=10):
+def mchkcmd(argv: Union[list[bytes], list[str]], timeout: int = 10) -> None:
if PY2:
with open(os.devnull, "wb") as f:
rv = sp.call(argv, stdout=f, stderr=f)
@@ -1423,7 +1513,14 @@ def mchkcmd(argv, timeout=10):
raise sp.CalledProcessError(rv, (argv[0], b"...", argv[-1]))
-def retchk(rc, cmd, serr, logger=None, color=None, verbose=False):
+def retchk(
+ rc: int,
+ cmd: Union[list[bytes], list[str]],
+ serr: str,
+ logger: Optional[NamedLogger] = None,
+ color: Union[int, str] = 0,
+ verbose: bool = False,
+) -> None:
if rc < 0:
rc = 128 - rc
@@ -1446,33 +1543,33 @@ def retchk(rc, cmd, serr, logger=None, color=None, verbose=False):
s = "invalid retcode"
if s:
- m = "{} <{}>".format(rc, s)
+ t = "{} <{}>".format(rc, s)
else:
- m = str(rc)
+ t = str(rc)
try:
- c = " ".join([fsdec(x) for x in cmd])
+ c = " ".join([fsdec(x) for x in cmd]) # type: ignore
except:
c = str(cmd)
- m = "error {} from [{}]".format(m, c)
+ t = "error {} from [{}]".format(t, c)
if serr:
- m += "\n" + serr
+ t += "\n" + serr
if logger:
- logger(m, color)
+ logger(t, color)
else:
- raise Exception(m)
+ raise Exception(t)
-def gzip_orig_sz(fn):
+def gzip_orig_sz(fn: str) -> int:
with open(fsenc(fn), "rb") as f:
f.seek(-4, 2)
rv = f.read(4)
- return sunpack(b"I", rv)[0]
+ return sunpack(b"I", rv)[0] # type: ignore
-def py_desc():
+def py_desc() -> str:
interp = platform.python_implementation()
py_ver = ".".join([str(x) for x in sys.version_info])
ofs = py_ver.find(".final.")
@@ -1487,15 +1584,15 @@ def py_desc():
host_os = platform.system()
compiler = platform.python_compiler()
- os_ver = re.search(r"([0-9]+\.[0-9\.]+)", platform.version())
- os_ver = os_ver.group(1) if os_ver else ""
+ m = re.search(r"([0-9]+\.[0-9\.]+)", platform.version())
+ os_ver = m.group(1) if m else ""
return "{:>9} v{} on {}{} {} [{}]".format(
interp, py_ver, host_os, bitness, os_ver, compiler
)
-def align_tab(lines):
+def align_tab(lines: list[str]) -> list[str]:
rows = []
ncols = 0
for ln in lines:
@@ -1512,9 +1609,9 @@ def align_tab(lines):
class Pebkac(Exception):
- def __init__(self, code, msg=None):
+ def __init__(self, code: int, msg: Optional[str] = None) -> None:
super(Pebkac, self).__init__(msg or HTTPCODE[code])
self.code = code
- def __repr__(self):
+ def __repr__(self) -> str:
return "Pebkac({}, {})".format(self.code, repr(self.args))
diff --git a/scripts/make-pypi-release.sh b/scripts/make-pypi-release.sh
index 6a568051..7e723d4b 100755
--- a/scripts/make-pypi-release.sh
+++ b/scripts/make-pypi-release.sh
@@ -90,6 +90,15 @@ function have() {
have setuptools
have wheel
have twine
+
+# remove type hints to support python < 3.9
+rm -rf build/pypi
+mkdir -p build/pypi
+cp -pR setup.py README.md LICENSE copyparty tests bin scripts/strip_hints build/pypi/
+cd build/pypi
+tar --strip-components=2 -xf ../strip-hints-0.1.10.tar.gz strip-hints-0.1.10/src/strip_hints
+python3 -c 'from strip_hints.a import uh; uh("copyparty")'
+
./setup.py clean2
./setup.py sdist bdist_wheel --universal
diff --git a/scripts/make-sfx.sh b/scripts/make-sfx.sh
index 8f9856f3..7402c517 100755
--- a/scripts/make-sfx.sh
+++ b/scripts/make-sfx.sh
@@ -76,7 +76,7 @@ while [ ! -z "$1" ]; do
no-hl) no_hl=1 ; ;;
no-dd) no_dd=1 ; ;;
no-cm) no_cm=1 ; ;;
- fast) zopf=100 ; ;;
+ fast) zopf= ; ;;
lang) shift;langs="$1"; ;;
*) help ; ;;
esac
@@ -106,7 +106,7 @@ tmpdir="$(
[ $repack ] && {
old="$tmpdir/pe-copyparty"
echo "repack of files in $old"
- cp -pR "$old/"*{dep-j2,dep-ftp,copyparty} .
+ cp -pR "$old/"*{j2,ftp,copyparty} .
}
[ $repack ] || {
@@ -130,8 +130,8 @@ tmpdir="$(
mv MarkupSafe-*/src/markupsafe .
rm -rf MarkupSafe-* markupsafe/_speedups.c
- mkdir dep-j2/
- mv {markupsafe,jinja2} dep-j2/
+ mkdir j2/
+ mv {markupsafe,jinja2} j2/
echo collecting pyftpdlib
f="../build/pyftpdlib-1.5.6.tar.gz"
@@ -143,8 +143,8 @@ tmpdir="$(
mv pyftpdlib-release-*/pyftpdlib .
rm -rf pyftpdlib-release-* pyftpdlib/test
- mkdir dep-ftp/
- mv pyftpdlib dep-ftp/
+ mkdir ftp/
+ mv pyftpdlib ftp/
echo collecting asyncore, asynchat
for n in asyncore.py asynchat.py; do
@@ -154,6 +154,24 @@ tmpdir="$(
wget -O$f "$url" || curl -L "$url" >$f)
done
+ # enable this to dynamically remove type hints at startup,
+ # in case a future python version can use them for performance
+ true || (
+ echo collecting strip-hints
+ f=../build/strip-hints-0.1.10.tar.gz
+ [ -e $f ] ||
+ (url=https://files.pythonhosted.org/packages/9c/d4/312ddce71ee10f7e0ab762afc027e07a918f1c0e1be5b0069db5b0e7542d/strip-hints-0.1.10.tar.gz;
+ wget -O$f "$url" || curl -L "$url" >$f)
+
+ tar -zxf $f
+ mv strip-hints-0.1.10/src/strip_hints .
+ rm -rf strip-hints-* strip_hints/import_hooks*
+ sed -ri 's/[a-z].* as import_hooks$/"""a"""/' strip_hints/*.py
+
+ cp -pR ../scripts/strip_hints/ .
+ )
+ cp -pR ../scripts/py2/ .
+
# msys2 tar is bad, make the best of it
echo collecting source
[ $clean ] && {
@@ -170,6 +188,9 @@ tmpdir="$(
for n in asyncore.py asynchat.py; do
awk 'NR<4||NR>27;NR==4{print"# license: https://opensource.org/licenses/ISC\n"}' ../build/$n >copyparty/vend/$n
done
+
+ # remove type hints before build instead
+ (cd copyparty; python3 ../../scripts/strip_hints/a.py; rm uh)
}
ver=
@@ -274,17 +295,23 @@ rm have
tmv "$f"
done
-[ $repack ] ||
-find | grep -E '\.py$' |
- grep -vE '__version__' |
- tr '\n' '\0' |
- xargs -0 "$pybin" ../scripts/uncomment.py
+[ $repack ] || {
+ # uncomment
+ find | grep -E '\.py$' |
+ grep -vE '__version__' |
+ tr '\n' '\0' |
+ xargs -0 "$pybin" ../scripts/uncomment.py
-f=dep-j2/jinja2/constants.py
+ # py2-compat
+ #find | grep -E '\.py$' | while IFS= read -r x; do
+ # sed -ri '/: TypeAlias = /d' "$x"; done
+}
+
+f=j2/jinja2/constants.py
awk '/^LOREM_IPSUM_WORDS/{o=1;print "LOREM_IPSUM_WORDS = u\"a\"";next} !o; /"""/{o=0}' <$f >t
tmv "$f"
-grep -rLE '^#[^a-z]*coding: utf-8' dep-j2 |
+grep -rLE '^#[^a-z]*coding: utf-8' j2 |
while IFS= read -r f; do
(echo "# coding: utf-8"; cat "$f") >t
tmv "$f"
@@ -313,7 +340,7 @@ find | grep -E '\.(js|html)$' | while IFS= read -r f; do
done
gzres() {
- command -v pigz &&
+ command -v pigz && [ $zopf ] &&
pk="pigz -11 -I $zopf" ||
pk='gzip'
@@ -354,7 +381,8 @@ nf=$(ls -1 "$zdir"/arc.* | wc -l)
}
[ $use_zdir ] && {
arcs=("$zdir"/arc.*)
- arc="${arcs[$RANDOM % ${#arcs[@]} ] }"
+ n=$(( $RANDOM % ${#arcs[@]} ))
+ arc="${arcs[n]}"
echo "using $arc"
tar -xf "$arc"
for f in copyparty/web/*.gz; do
@@ -364,7 +392,7 @@ nf=$(ls -1 "$zdir"/arc.* | wc -l)
echo gen tarlist
-for d in copyparty dep-j2 dep-ftp; do find $d -type f; done |
+for d in copyparty j2 ftp py2; do find $d -type f; done | # strip_hints
sed -r 's/(.*)\.(.*)/\2 \1/' | LC_ALL=C sort |
sed -r 's/([^ ]*) (.*)/\2.\1/' | grep -vE '/list1?$' > list1
diff --git a/scripts/run-tests.sh b/scripts/run-tests.sh
index 48aaf075..1977a55a 100755
--- a/scripts/run-tests.sh
+++ b/scripts/run-tests.sh
@@ -1,13 +1,23 @@
#!/bin/bash
set -ex
+rm -rf unt
+mkdir -p unt/srv
+cp -pR copyparty tests unt/
+cd unt
+python3 ../scripts/strip_hints/a.py
+
pids=()
for py in python{2,3}; do
+ PYTHONPATH=
+ [ $py = python2 ] && PYTHONPATH=../scripts/py2
+ export PYTHONPATH
+
nice $py -m unittest discover -s tests >/dev/null &
pids+=($!)
done
-python3 scripts/test/smoketest.py &
+python3 ../scripts/test/smoketest.py &
pids+=($!)
for pid in ${pids[@]}; do
diff --git a/scripts/sfx.py b/scripts/sfx.py
index 4723a0da..2cb20c8a 100644
--- a/scripts/sfx.py
+++ b/scripts/sfx.py
@@ -379,9 +379,20 @@ def run(tmp, j2, ftp):
t.daemon = True
t.start()
- ld = (("", ""), (j2, "dep-j2"), (ftp, "dep-ftp"))
+ ld = (("", ""), (j2, "j2"), (ftp, "ftp"), (not PY2, "py2"))
ld = [os.path.join(tmp, b) for a, b in ld if not a]
+ # skip 1
+ # enable this to dynamically remove type hints at startup,
+ # in case a future python version can use them for performance
+ if sys.version_info < (3, 10) and False:
+ sys.path.insert(0, ld[0])
+
+ from strip_hints.a import uh
+
+ uh(tmp + "/copyparty")
+ # skip 0
+
if any([re.match(r"^-.*j[0-9]", x) for x in sys.argv]):
run_s(ld)
else:
diff --git a/scripts/sfx.sh b/scripts/sfx.sh
index 1f53c8db..1496f970 100644
--- a/scripts/sfx.sh
+++ b/scripts/sfx.sh
@@ -47,7 +47,7 @@ grep -E '/(python|pypy)[0-9\.-]*$' >$dir/pys || true
printf '\033[1;30mlooking for jinja2 in [%s]\033[0m\n' "$_py" >&2
$_py -c 'import jinja2' 2>/dev/null || continue
printf '%s\n' "$_py"
- mv $dir/{,x.}dep-j2
+ mv $dir/{,x.}j2
break
done)"
diff --git a/scripts/strip_hints/a.py b/scripts/strip_hints/a.py
new file mode 100644
index 00000000..06bf4f6b
--- /dev/null
+++ b/scripts/strip_hints/a.py
@@ -0,0 +1,57 @@
+# coding: utf-8
+from __future__ import print_function, unicode_literals
+
+import re
+import os
+import sys
+from strip_hints import strip_file_to_string
+
+
+# list unique types used in hints:
+# rm -rf unt && cp -pR copyparty unt && (cd unt && python3 ../scripts/strip_hints/a.py)
+# diff -wNarU1 copyparty unt | grep -E '^\-' | sed -r 's/[^][, ]+://g; s/[^][, ]+[[(]//g; s/[],()<>{} -]/\n/g' | grep -E .. | sort | uniq -c | sort -n
+
+
+def pr(m):
+ sys.stderr.write(m)
+ sys.stderr.flush()
+
+
+def uh(top):
+ if os.path.exists(top + "/uh"):
+ return
+
+ libs = "typing|types|collections\.abc"
+ ptn = re.compile(r"^(\s*)(from (?:{0}) import |import (?:{0})\b).*".format(libs))
+
+ # pr("building support for your python ver")
+ pr("unhinting")
+ for (dp, _, fns) in os.walk(top):
+ for fn in fns:
+ if not fn.endswith(".py"):
+ continue
+
+ pr(".")
+ fp = os.path.join(dp, fn)
+ cs = strip_file_to_string(fp, no_ast=True, to_empty=True)
+
+ # remove expensive imports too
+ lns = []
+ for ln in cs.split("\n"):
+ m = ptn.match(ln)
+ if m:
+ ln = m.group(1) + "raise Exception()"
+
+ lns.append(ln)
+
+ cs = "\n".join(lns)
+ with open(fp, "wb") as f:
+ f.write(cs.encode("utf-8"))
+
+ pr("k\n\n")
+ with open(top + "/uh", "wb") as f:
+ f.write(b"a")
+
+
+if __name__ == "__main__":
+ uh(".")
diff --git a/scripts/test/race.py b/scripts/test/race.py
index 09922e60..77ef0e46 100644
--- a/scripts/test/race.py
+++ b/scripts/test/race.py
@@ -58,13 +58,13 @@ class CState(threading.Thread):
remotes.append("?")
remotes_ok = False
- m = []
+ ta = []
for conn, remote in zip(self.cs, remotes):
stage = len(conn.st)
- m.append(f"\033[3{colors[stage]}m{remote}")
+ ta.append(f"\033[3{colors[stage]}m{remote}")
- m = " ".join(m)
- print(f"{m}\033[0m\n\033[A", end="")
+ t = " ".join(ta)
+ print(f"{t}\033[0m\n\033[A", end="")
def allget(cs, urls):
diff --git a/scripts/test/smoketest.py b/scripts/test/smoketest.py
index 8508f127..a37fc44a 100644
--- a/scripts/test/smoketest.py
+++ b/scripts/test/smoketest.py
@@ -72,6 +72,8 @@ def tc1(vflags):
for _ in range(10):
try:
os.mkdir(td)
+ if os.path.exists(td):
+ break
except:
time.sleep(0.1) # win10
diff --git a/tests/test_vfs.py b/tests/test_vfs.py
index cb7f08fe..4818d180 100644
--- a/tests/test_vfs.py
+++ b/tests/test_vfs.py
@@ -85,7 +85,7 @@ class TestVFS(unittest.TestCase):
pass
def assertAxs(self, dct, lst):
- t1 = list(sorted(dct.keys()))
+ t1 = list(sorted(dct))
t2 = list(sorted(lst))
self.assertEqual(t1, t2)
@@ -208,10 +208,10 @@ class TestVFS(unittest.TestCase):
self.assertEqual(n.realpath, os.path.join(td, "a"))
self.assertAxs(n.axs.uread, ["*"])
self.assertAxs(n.axs.uwrite, [])
- self.assertEqual(vfs.can_access("/", "*"), [False, False, False, False, False])
- self.assertEqual(vfs.can_access("/", "k"), [True, True, False, False, False])
- self.assertEqual(vfs.can_access("/a", "*"), [True, False, False, False, False])
- self.assertEqual(vfs.can_access("/a", "k"), [True, False, False, False, False])
+ self.assertEqual(vfs.can_access("/", "*"), (False, False, False, False, False))
+ self.assertEqual(vfs.can_access("/", "k"), (True, True, False, False, False))
+ self.assertEqual(vfs.can_access("/a", "*"), (True, False, False, False, False))
+ self.assertEqual(vfs.can_access("/a", "k"), (True, False, False, False, False))
# breadth-first construction
vfs = AuthSrv(
@@ -279,7 +279,7 @@ class TestVFS(unittest.TestCase):
n = au.vfs
# root was not defined, so PWD with no access to anyone
self.assertEqual(n.vpath, "")
- self.assertEqual(n.realpath, None)
+ self.assertEqual(n.realpath, "")
self.assertAxs(n.axs.uread, [])
self.assertAxs(n.axs.uwrite, [])
self.assertEqual(len(n.nodes), 1)
diff --git a/tests/util.py b/tests/util.py
index 55ec58ea..92ecaa88 100644
--- a/tests/util.py
+++ b/tests/util.py
@@ -90,7 +90,10 @@ def get_ramdisk():
class NullBroker(object):
- def put(*args):
+ def say(*args):
+ pass
+
+ def ask(*args):
pass