login, vol.tree, refactor

This commit is contained in:
ed
2019-06-04 22:12:28 +00:00
parent c53413d57c
commit 5d8be84d18
8 changed files with 174 additions and 167 deletions

View File

@@ -5,19 +5,24 @@ from __future__ import print_function
import time
import hashlib
import mimetypes
import jinja2
from .__init__ import *
from .util import *
if not PY2:
if PY2:
from cStringIO import StringIO as BytesIO
else:
unicode = str
from io import BytesIO as BytesIO
class HttpCli(object):
def __init__(self, sck, addr, args, log_func):
def __init__(self, sck, addr, args, auth, log_func):
self.s = sck
self.addr = addr
self.args = args
self.auth = auth
self.sr = Unrecv(sck)
self.bufsz = 1024 * 32
@@ -27,13 +32,21 @@ class HttpCli(object):
self.log_func = log_func
self.log_src = "{} \033[36m{}".format(addr[0], addr[1]).ljust(26)
with open(self.respath("splash.html"), "rb") as f:
self.tpl_mounts = jinja2.Template(f.read().decode("utf-8"))
def respath(self, res_name):
return os.path.join(E.mod, "web", res_name)
def log(self, msg):
self.log_func(self.log_src, msg)
def run(self):
while self.ok:
headerlines = self.read_header()
if not self.ok:
try:
headerlines = read_header(self.sr)
except:
self.ok = False
return
self.headers = {}
@@ -48,7 +61,27 @@ class HttpCli(object):
k, v = header_line.split(":", 1)
self.headers[k.lower()] = v.strip()
# self.bufsz = int(self.req.split('/')[-1]) * 1024
self.uname = "*"
if "cookie" in self.headers:
cookies = self.headers["cookie"].split(";")
for k, v in [x.split("=", 1) for x in cookies]:
if k != "cppwd":
continue
v = unescape_cookie(v)
if not v in self.auth.iuser:
msg = u'bad_cpwd "{}"'.format(v)
nuke = u"Set-Cookie: cppwd=x; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT"
self.loud_reply(msg, headers=[nuke])
return
self.uname = self.auth.iuser[v]
if self.uname:
self.rvol = self.auth.vfs.user_tree(self.uname, readable=True)
self.wvol = self.auth.vfs.user_tree(self.uname, writable=True)
print(self.rvol)
print(self.wvol)
if mode == "GET":
self.handle_get()
@@ -62,132 +95,96 @@ class HttpCli(object):
self.ok = False
self.s.close()
def read_header(self):
ret = b""
while True:
if ret.endswith(b"\r\n\r\n"):
break
elif ret.endswith(b"\r\n\r"):
n = 1
elif ret.endswith(b"\r\n"):
n = 2
elif ret.endswith(b"\r"):
n = 3
else:
n = 4
buf = self.sr.recv(n)
if not buf:
self.panic("headers")
break
ret += buf
return ret[:-4].decode("utf-8", "replace").split("\r\n")
def reply(self, body, status="200 OK", mime="text/html"):
header = "HTTP/1.1 {}\r\nConnection: Keep-Alive\r\nContent-Type: {}\r\nContent-Length: {}\r\n\r\n".format(
status, mime, len(body)
).encode(
"utf-8"
)
def reply(self, body, status="200 OK", mime="text/html", headers=[]):
# TODO something to reply with user-supplied values safely
response = [
u"HTTP/1.1 " + status,
u"Connection: Keep-Alive",
u"Content-Type: " + mime,
u"Content-Length: " + str(len(body)),
]
response.extend(headers)
response_str = u"\r\n".join(response).encode("utf-8")
if self.ok:
self.s.send(header + body)
self.s.send(response_str + b"\r\n\r\n" + body)
return body
def loud_reply(self, body, **kwargs):
def loud_reply(self, body, *args, **kwargs):
self.log(body.rstrip())
self.reply(b"<pre>" + body.encode("utf-8"), **kwargs)
def send_file(self, path):
sz = os.path.getsize(path)
mime = mimetypes.guess_type(path)[0]
header = "HTTP/1.1 200 OK\r\nConnection: Keep-Alive\r\nContent-Type: {}\r\nContent-Length: {}\r\n\r\n".format(
mime, sz
).encode(
"utf-8"
)
if self.ok:
self.s.send(header)
with open(path, "rb") as f:
while self.ok:
buf = f.read(4096)
if not buf:
break
self.s.send(buf)
self.reply(b"<pre>" + body.encode("utf-8"), *list(args), **kwargs)
def handle_get(self):
self.log("")
self.log("GET {0} {1}".format(self.addr[0], self.req))
static_path = os.path.join(E.mod, "web", self.req.split("?")[0][1:])
if self.req.startswith("/.cpr/"):
static_path = os.path.join(E.mod, "web", self.req.split("?")[0][6:])
if os.path.isfile(static_path):
return self.send_file(static_path)
if os.path.isfile(static_path):
return self.tx_file(static_path)
if self.req == "/":
return self.send_file(os.path.join(E.mod, "web/splash.html"))
return self.tx_mounts()
return self.loud_reply("404 not found", status="404 Not Found")
return self.loud_reply("404 not found", "404 Not Found")
def handle_post(self):
self.log("")
self.log("POST {0} {1}".format(self.addr[0], self.req))
nullwrite = self.args.nw
try:
if self.headers["expect"].lower() == "100-continue":
self.s.send(b"HTTP/1.1 100 Continue\r\n\r\n")
except:
pass
form_segm = self.read_header()
if not self.ok:
self.parser = MultipartParser(self.log, self.sr, self.headers)
self.parser.parse()
act = self.parser.require("act", 64)
if act == u"bput":
self.handle_plain_upload()
return
self.boundary = b"\r\n" + form_segm[0].encode("utf-8")
for ln in form_segm[1:]:
self.log(ln)
if act == u"login":
self.handle_login()
return
fn = os.devnull
fn0 = "inc.{0:.6f}".format(time.time())
raise Pebkac('invalid action "{}"'.format(act))
def handle_login(self):
pwd = self.parser.require("cppwd", 64)
if not pwd in self.auth.iuser:
h = [u"Set-Cookie: cppwd=x; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT"]
self.loud_reply(u'bad_ppwd "{}"'.format(pwd), headers=h)
else:
h = ["Set-Cookie: cppwd={}; Path=/".format(pwd)]
self.loud_reply(u"login_ok", headers=h)
def handle_plain_upload(self):
nullwrite = self.args.nw
files = []
t0 = time.time()
for nfile in range(99):
for nfile, (p_field, p_file, p_data) in enumerate(self.parser.gen):
fn = os.devnull
if not nullwrite:
fn = "{0}.{1}".format(fn0, nfile)
fn = sanitize_fn(p_file)
# TODO broker which avoid this race
# and provides a new filename if taken
if os.path.exists(fn):
fn += ".{:.6f}".format(time.time())
with open(fn, "wb") as f:
self.log("writing to {0}".format(fn))
sz, sha512 = self.handle_multipart(f)
sz, sha512 = hashcopy(self, p_data, f)
if sz == 0:
break
files.append([sz, sha512])
buf = self.sr.recv(2)
if buf == b"--":
# end of multipart
break
if buf != b"\r\n":
return self.loud_reply(u"protocol error")
header = self.read_header()
if not self.ok:
break
form_segm += header
for ln in header:
self.log(ln)
td = time.time() - t0
sz_total = sum(x[0] for x in files)
spd = (sz_total / td) / (1024 * 1024)
@@ -206,14 +203,15 @@ class HttpCli(object):
self.loud_reply(msg)
if not nullwrite:
with open(fn0 + ".txt", "wb") as f:
# TODO this is bad
log_fn = "up.{:.6f}.txt".format(t0)
with open(log_fn, "wb") as f:
f.write(
(
u"\n".join(
unicode(x)
for x in [
u":".join(unicode(x) for x in self.addr),
u"\n".join(form_segm),
msg.rstrip(),
]
)
@@ -221,77 +219,26 @@ class HttpCli(object):
).encode("utf-8")
)
try:
# TODO: check if actually part of multipart footer
buf = self.sr.recv(2)
if buf != b"\r\n":
raise Exception("oh")
except:
self.log("client is done")
self.s.close()
def tx_file(self, path):
sz = os.path.getsize(path)
mime = mimetypes.guess_type(path)[0]
header = "HTTP/1.1 200 OK\r\nConnection: Keep-Alive\r\nContent-Type: {}\r\nContent-Length: {}\r\n\r\n".format(
mime, sz
).encode(
"utf-8"
)
def handle_multipart(self, ofd):
tlen = 0
hashobj = hashlib.sha512()
for buf in self.extract_filedata():
tlen += len(buf)
hashobj.update(buf)
ofd.write(buf)
if self.ok:
self.s.send(header)
return tlen, hashobj.hexdigest()
def extract_filedata(self):
u32_lim = int((2 ** 31) * 0.9)
blen = len(self.boundary)
bufsz = self.bufsz
while True:
if self.workload > u32_lim:
# reset to prevent overflow
self.workload = 100
buf = self.sr.recv(bufsz)
self.workload += 1
if not buf:
# abort: client disconnected
self.panic("outer")
return
while True:
ofs = buf.find(self.boundary)
if ofs != -1:
self.sr.unrecv(buf[ofs + blen :])
yield buf[:ofs]
return
d = len(buf) - blen
if d > 0:
# buffer growing large; yield everything except
# the part at the end (maybe start of boundary)
yield buf[:d]
buf = buf[d:]
# look for boundary near the end of the buffer
for n in range(1, len(buf) + 1):
if not buf[-n:] in self.boundary:
n -= 1
break
if n == 0 or not self.boundary.startswith(buf[-n:]):
# no boundary contents near the buffer edge
with open(path, "rb") as f:
while self.ok:
buf = f.read(4096)
if not buf:
break
if blen == n:
# EOF: found boundary
yield buf[:-n]
return
self.s.send(buf)
buf2 = self.sr.recv(bufsz)
self.workload += 1
if not buf2:
# abort: client disconnected
self.panic("inner")
return
buf += buf2
yield buf
def tx_mounts(self):
html = self.tpl_mounts.render(this=self)
self.reply(html.encode("utf-8"))