增加sql的waf功能

This commit is contained in:
琴心
2022-03-03 16:16:53 +08:00
parent 5910cb2983
commit fcace799df
54 changed files with 12617 additions and 73 deletions

353
resty/websocket/client.lua Normal file
View File

@@ -0,0 +1,353 @@
-- Copyright (C) Yichun Zhang (agentzh)
-- FIXME: this library is very rough and is currently just for testing
-- the websocket server.
local wbproto = require "resty.websocket.protocol"
local bit = require "bit"
local _recv_frame = wbproto.recv_frame
local _send_frame = wbproto.send_frame
local new_tab = wbproto.new_tab
local tcp = ngx.socket.tcp
local re_match = ngx.re.match
local encode_base64 = ngx.encode_base64
local concat = table.concat
local char = string.char
local str_find = string.find
local rand = math.random
local rshift = bit.rshift
local band = bit.band
local setmetatable = setmetatable
local type = type
local debug = ngx.config.debug
local ngx_log = ngx.log
local ngx_DEBUG = ngx.DEBUG
local ssl_support = true
if not ngx.config
or not ngx.config.ngx_lua_version
or ngx.config.ngx_lua_version < 9011
then
ssl_support = false
end
local _M = new_tab(0, 13)
_M._VERSION = '0.08'
local mt = { __index = _M }
function _M.new(self, opts)
local sock, err = tcp()
if not sock then
return nil, err
end
local max_payload_len, send_unmasked, timeout
if opts then
max_payload_len = opts.max_payload_len
send_unmasked = opts.send_unmasked
timeout = opts.timeout
if timeout then
sock:settimeout(timeout)
end
end
return setmetatable({
sock = sock,
max_payload_len = max_payload_len or 65535,
send_unmasked = send_unmasked,
}, mt)
end
function _M.connect(self, uri, opts)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
local m, err = re_match(uri, [[^(wss?)://([^:/]+)(?::(\d+))?(.*)]], "jo")
if not m then
if err then
return nil, "failed to match the uri: " .. err
end
return nil, "bad websocket uri"
end
local scheme = m[1]
local host = m[2]
local port = m[3]
local path = m[4]
-- ngx.say("host: ", host)
-- ngx.say("port: ", port)
if not port then
port = 80
end
if path == "" then
path = "/"
end
local ssl_verify, headers, proto_header, origin_header, sock_opts = false
if opts then
local protos = opts.protocols
if protos then
if type(protos) == "table" then
proto_header = "\r\nSec-WebSocket-Protocol: "
.. concat(protos, ",")
else
proto_header = "\r\nSec-WebSocket-Protocol: " .. protos
end
end
local origin = opts.origin
if origin then
origin_header = "\r\nOrigin: " .. origin
end
local pool = opts.pool
if pool then
sock_opts = { pool = pool }
end
if opts.ssl_verify then
if not ssl_support then
return nil, "ngx_lua 0.9.11+ required for SSL sockets"
end
ssl_verify = true
end
if opts.headers then
headers = opts.headers
if type(headers) ~= "table" then
return nil, "custom headers must be a table"
end
end
end
local ok, err
if sock_opts then
ok, err = sock:connect(host, port, sock_opts)
else
ok, err = sock:connect(host, port)
end
if not ok then
return nil, "failed to connect: " .. err
end
if scheme == "wss" then
if not ssl_support then
return nil, "ngx_lua 0.9.11+ required for SSL sockets"
end
ok, err = sock:sslhandshake(false, host, ssl_verify)
if not ok then
return nil, "ssl handshake failed: " .. err
end
end
-- check for connections from pool:
local count, err = sock:getreusedtimes()
if not count then
return nil, "failed to get reused times: " .. err
end
if count > 0 then
-- being a reused connection (must have done handshake)
return 1
end
local custom_headers
if headers then
custom_headers = concat(headers, "\r\n")
custom_headers = "\r\n" .. custom_headers
end
-- do the websocket handshake:
local bytes = char(rand(256) - 1, rand(256) - 1, rand(256) - 1,
rand(256) - 1, rand(256) - 1, rand(256) - 1,
rand(256) - 1, rand(256) - 1, rand(256) - 1,
rand(256) - 1, rand(256) - 1, rand(256) - 1,
rand(256) - 1, rand(256) - 1, rand(256) - 1,
rand(256) - 1)
local key = encode_base64(bytes)
local req = "GET " .. path .. " HTTP/1.1\r\nUpgrade: websocket\r\nHost: "
.. host .. ":" .. port
.. "\r\nSec-WebSocket-Key: " .. key
.. (proto_header or "")
.. "\r\nSec-WebSocket-Version: 13"
.. (origin_header or "")
.. "\r\nConnection: Upgrade"
.. (custom_headers or "")
.. "\r\n\r\n"
local bytes, err = sock:send(req)
if not bytes then
return nil, "failed to send the handshake request: " .. err
end
local header_reader = sock:receiveuntil("\r\n\r\n")
-- FIXME: check for too big response headers
local header, err, partial = header_reader()
if not header then
return nil, "failed to receive response header: " .. err
end
-- error("header: " .. header)
-- FIXME: verify the response headers
m, err = re_match(header, [[^\s*HTTP/1\.1\s+]], "jo")
if not m then
return nil, "bad HTTP response status line: " .. header
end
return 1
end
function _M.set_timeout(self, time)
local sock = self.sock
if not sock then
return nil, nil, "not initialized yet"
end
return sock:settimeout(time)
end
function _M.recv_frame(self)
if self.fatal then
return nil, nil, "fatal error already happened"
end
local sock = self.sock
if not sock then
return nil, nil, "not initialized yet"
end
local data, typ, err = _recv_frame(sock, self.max_payload_len, false)
if not data and not str_find(err, ": timeout", 1, true) then
self.fatal = true
end
return data, typ, err
end
local function send_frame(self, fin, opcode, payload)
if self.fatal then
return nil, "fatal error already happened"
end
if self.closed then
return nil, "already closed"
end
local sock = self.sock
if not sock then
return nil, "not initialized yet"
end
local bytes, err = _send_frame(sock, fin, opcode, payload,
self.max_payload_len,
not self.send_unmasked)
if not bytes then
self.fatal = true
end
return bytes, err
end
_M.send_frame = send_frame
function _M.send_text(self, data)
return send_frame(self, true, 0x1, data)
end
function _M.send_binary(self, data)
return send_frame(self, true, 0x2, data)
end
local function send_close(self, code, msg)
local payload
if code then
if type(code) ~= "number" or code > 0x7fff then
return nil, "bad status code"
end
payload = char(band(rshift(code, 8), 0xff), band(code, 0xff))
.. (msg or "")
end
if debug then
ngx_log(ngx_DEBUG, "sending the close frame")
end
local bytes, err = send_frame(self, true, 0x8, payload)
if not bytes then
self.fatal = true
end
self.closed = true
return bytes, err
end
_M.send_close = send_close
function _M.send_ping(self, data)
return send_frame(self, true, 0x9, data)
end
function _M.send_pong(self, data)
return send_frame(self, true, 0xa, data)
end
function _M.close(self)
if self.fatal then
return nil, "fatal error already happened"
end
local sock = self.sock
if not sock then
return nil, "not initialized"
end
if not self.closed then
local bytes, err = send_close(self)
if not bytes then
return nil, "failed to send close frame: " .. err
end
end
return sock:close()
end
function _M.set_keepalive(self, ...)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
return sock:setkeepalive(...)
end
return _M

View File

@@ -0,0 +1,345 @@
-- Copyright (C) Yichun Zhang (agentzh)
local bit = require "bit"
local ffi = require "ffi"
local byte = string.byte
local char = string.char
local sub = string.sub
local band = bit.band
local bor = bit.bor
local bxor = bit.bxor
local lshift = bit.lshift
local rshift = bit.rshift
--local tohex = bit.tohex
local tostring = tostring
local concat = table.concat
local rand = math.random
local type = type
local debug = ngx.config.debug
local ngx_log = ngx.log
local ngx_DEBUG = ngx.DEBUG
local ffi_new = ffi.new
local ffi_string = ffi.string
local ok, new_tab = pcall(require, "table.new")
if not ok then
new_tab = function (narr, nrec) return {} end
end
local _M = new_tab(0, 5)
_M.new_tab = new_tab
_M._VERSION = '0.08'
local types = {
[0x0] = "continuation",
[0x1] = "text",
[0x2] = "binary",
[0x8] = "close",
[0x9] = "ping",
[0xa] = "pong",
}
local str_buf_size = 4096
local str_buf
local c_buf_type = ffi.typeof("char[?]")
local function get_string_buf(size)
if size > str_buf_size then
return ffi_new(c_buf_type, size)
end
if not str_buf then
str_buf = ffi_new(c_buf_type, str_buf_size)
end
return str_buf
end
function _M.recv_frame(sock, max_payload_len, force_masking)
local data, err = sock:receive(2)
if not data then
return nil, nil, "failed to receive the first 2 bytes: " .. err
end
local fst, snd = byte(data, 1, 2)
local fin = band(fst, 0x80) ~= 0
-- print("fin: ", fin)
if band(fst, 0x70) ~= 0 then
return nil, nil, "bad RSV1, RSV2, or RSV3 bits"
end
local opcode = band(fst, 0x0f)
-- print("opcode: ", tohex(opcode))
if opcode >= 0x3 and opcode <= 0x7 then
return nil, nil, "reserved non-control frames"
end
if opcode >= 0xb and opcode <= 0xf then
return nil, nil, "reserved control frames"
end
local mask = band(snd, 0x80) ~= 0
if debug then
ngx_log(ngx_DEBUG, "recv_frame: mask bit: ", mask and 1 or 0)
end
if force_masking and not mask then
return nil, nil, "frame unmasked"
end
local payload_len = band(snd, 0x7f)
-- print("payload len: ", payload_len)
if payload_len == 126 then
local data, err = sock:receive(2)
if not data then
return nil, nil, "failed to receive the 2 byte payload length: "
.. (err or "unknown")
end
payload_len = bor(lshift(byte(data, 1), 8), byte(data, 2))
elseif payload_len == 127 then
local data, err = sock:receive(8)
if not data then
return nil, nil, "failed to receive the 8 byte payload length: "
.. (err or "unknown")
end
if byte(data, 1) ~= 0
or byte(data, 2) ~= 0
or byte(data, 3) ~= 0
or byte(data, 4) ~= 0
then
return nil, nil, "payload len too large"
end
local fifth = byte(data, 5)
if band(fifth, 0x80) ~= 0 then
return nil, nil, "payload len too large"
end
payload_len = bor(lshift(fifth, 24),
lshift(byte(data, 6), 16),
lshift(byte(data, 7), 8),
byte(data, 8))
end
if band(opcode, 0x8) ~= 0 then
-- being a control frame
if payload_len > 125 then
return nil, nil, "too long payload for control frame"
end
if not fin then
return nil, nil, "fragmented control frame"
end
end
-- print("payload len: ", payload_len, ", max payload len: ",
-- max_payload_len)
if payload_len > max_payload_len then
return nil, nil, "exceeding max payload len"
end
local rest
if mask then
rest = payload_len + 4
else
rest = payload_len
end
-- print("rest: ", rest)
local data
if rest > 0 then
data, err = sock:receive(rest)
if not data then
return nil, nil, "failed to read masking-len and payload: "
.. (err or "unknown")
end
else
data = ""
end
-- print("received rest")
if opcode == 0x8 then
-- being a close frame
if payload_len > 0 then
if payload_len < 2 then
return nil, nil, "close frame with a body must carry a 2-byte"
.. " status code"
end
local msg, code
if mask then
local fst = bxor(byte(data, 4 + 1), byte(data, 1))
local snd = bxor(byte(data, 4 + 2), byte(data, 2))
code = bor(lshift(fst, 8), snd)
if payload_len > 2 then
-- TODO string.buffer optimizations
local bytes = get_string_buf(payload_len - 2)
for i = 3, payload_len do
bytes[i - 3] = bxor(byte(data, 4 + i),
byte(data, (i - 1) % 4 + 1))
end
msg = ffi_string(bytes, payload_len - 2)
else
msg = ""
end
else
local fst = byte(data, 1)
local snd = byte(data, 2)
code = bor(lshift(fst, 8), snd)
-- print("parsing unmasked close frame payload: ", payload_len)
if payload_len > 2 then
msg = sub(data, 3)
else
msg = ""
end
end
return msg, "close", code
end
return "", "close", nil
end
local msg
if mask then
-- TODO string.buffer optimizations
local bytes = get_string_buf(payload_len)
for i = 1, payload_len do
bytes[i - 1] = bxor(byte(data, 4 + i),
byte(data, (i - 1) % 4 + 1))
end
msg = ffi_string(bytes, payload_len)
else
msg = data
end
return msg, types[opcode], not fin and "again" or nil
end
local function build_frame(fin, opcode, payload_len, payload, masking)
-- XXX optimize this when we have string.buffer in LuaJIT 2.1
local fst
if fin then
fst = bor(0x80, opcode)
else
fst = opcode
end
local snd, extra_len_bytes
if payload_len <= 125 then
snd = payload_len
extra_len_bytes = ""
elseif payload_len <= 65535 then
snd = 126
extra_len_bytes = char(band(rshift(payload_len, 8), 0xff),
band(payload_len, 0xff))
else
if band(payload_len, 0x7fffffff) < payload_len then
return nil, "payload too big"
end
snd = 127
-- XXX we only support 31-bit length here
extra_len_bytes = char(0, 0, 0, 0, band(rshift(payload_len, 24), 0xff),
band(rshift(payload_len, 16), 0xff),
band(rshift(payload_len, 8), 0xff),
band(payload_len, 0xff))
end
local masking_key
if masking then
-- set the mask bit
snd = bor(snd, 0x80)
local key = rand(0xffffffff)
masking_key = char(band(rshift(key, 24), 0xff),
band(rshift(key, 16), 0xff),
band(rshift(key, 8), 0xff),
band(key, 0xff))
-- TODO string.buffer optimizations
local bytes = get_string_buf(payload_len)
for i = 1, payload_len do
bytes[i - 1] = bxor(byte(payload, i),
byte(masking_key, (i - 1) % 4 + 1))
end
payload = ffi_string(bytes, payload_len)
else
masking_key = ""
end
return char(fst, snd) .. extra_len_bytes .. masking_key .. payload
end
_M.build_frame = build_frame
function _M.send_frame(sock, fin, opcode, payload, max_payload_len, masking)
-- ngx.log(ngx.WARN, ngx.var.uri, ": masking: ", masking)
if not payload then
payload = ""
elseif type(payload) ~= "string" then
payload = tostring(payload)
end
local payload_len = #payload
if payload_len > max_payload_len then
return nil, "payload too big"
end
if band(opcode, 0x8) ~= 0 then
-- being a control frame
if payload_len > 125 then
return nil, "too much payload for control frame"
end
if not fin then
return nil, "fragmented control frame"
end
end
local frame, err = build_frame(fin, opcode, payload_len, payload,
masking)
if not frame then
return nil, "failed to build frame: " .. err
end
local bytes, err = sock:send(frame)
if not bytes then
return nil, "failed to send frame: " .. err
end
return bytes
end
return _M

210
resty/websocket/server.lua Normal file
View File

@@ -0,0 +1,210 @@
-- Copyright (C) Yichun Zhang (agentzh)
local bit = require "bit"
local wbproto = require "resty.websocket.protocol"
local new_tab = wbproto.new_tab
local _recv_frame = wbproto.recv_frame
local _send_frame = wbproto.send_frame
local http_ver = ngx.req.http_version
local req_sock = ngx.req.socket
local ngx_header = ngx.header
local req_headers = ngx.req.get_headers
local str_lower = string.lower
local char = string.char
local str_find = string.find
local sha1_bin = ngx.sha1_bin
local base64 = ngx.encode_base64
local ngx = ngx
local read_body = ngx.req.read_body
local band = bit.band
local rshift = bit.rshift
local type = type
local setmetatable = setmetatable
local tostring = tostring
-- local print = print
local _M = new_tab(0, 10)
_M._VERSION = '0.08'
local mt = { __index = _M }
function _M.new(self, opts)
if ngx.headers_sent then
return nil, "response header already sent"
end
read_body()
if http_ver() ~= 1.1 then
return nil, "bad http version"
end
local headers = req_headers()
local val = headers.upgrade
if type(val) == "table" then
val = val[1]
end
if not val or str_lower(val) ~= "websocket" then
return nil, "bad \"upgrade\" request header: " .. tostring(val)
end
val = headers.connection
if type(val) == "table" then
val = val[1]
end
if not val or not str_find(str_lower(val), "upgrade", 1, true) then
return nil, "bad \"connection\" request header"
end
local key = headers["sec-websocket-key"]
if type(key) == "table" then
key = key[1]
end
if not key then
return nil, "bad \"sec-websocket-key\" request header"
end
local ver = headers["sec-websocket-version"]
if type(ver) == "table" then
ver = ver[1]
end
if not ver or ver ~= "13" then
return nil, "bad \"sec-websocket-version\" request header"
end
local protocols = headers["sec-websocket-protocol"]
if type(protocols) == "table" then
protocols = protocols[1]
end
if protocols then
ngx_header["Sec-WebSocket-Protocol"] = protocols
end
ngx_header["Upgrade"] = "websocket"
local sha1 = sha1_bin(key .. "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
ngx_header["Sec-WebSocket-Accept"] = base64(sha1)
ngx_header["Content-Type"] = nil
ngx.status = 101
local ok, err = ngx.send_headers()
if not ok then
return nil, "failed to send response header: " .. (err or "unknonw")
end
ok, err = ngx.flush(true)
if not ok then
return nil, "failed to flush response header: " .. (err or "unknown")
end
local sock
sock, err = req_sock(true)
if not sock then
return nil, err
end
local max_payload_len, send_masked, timeout
if opts then
max_payload_len = opts.max_payload_len
send_masked = opts.send_masked
timeout = opts.timeout
if timeout then
sock:settimeout(timeout)
end
end
return setmetatable({
sock = sock,
max_payload_len = max_payload_len or 65535,
send_masked = send_masked,
}, mt)
end
function _M.set_timeout(self, time)
local sock = self.sock
if not sock then
return nil, nil, "not initialized yet"
end
return sock:settimeout(time)
end
function _M.recv_frame(self)
if self.fatal then
return nil, nil, "fatal error already happened"
end
local sock = self.sock
if not sock then
return nil, nil, "not initialized yet"
end
local data, typ, err = _recv_frame(sock, self.max_payload_len, true)
if not data and not str_find(err, ": timeout", 1, true) then
self.fatal = true
end
return data, typ, err
end
local function send_frame(self, fin, opcode, payload)
if self.fatal then
return nil, "fatal error already happened"
end
local sock = self.sock
if not sock then
return nil, "not initialized yet"
end
local bytes, err = _send_frame(sock, fin, opcode, payload,
self.max_payload_len, self.send_masked)
if not bytes then
self.fatal = true
end
return bytes, err
end
_M.send_frame = send_frame
function _M.send_text(self, data)
return send_frame(self, true, 0x1, data)
end
function _M.send_binary(self, data)
return send_frame(self, true, 0x2, data)
end
function _M.send_close(self, code, msg)
local payload
if code then
if type(code) ~= "number" or code > 0x7fff then
end
payload = char(band(rshift(code, 8), 0xff), band(code, 0xff))
.. (msg or "")
end
return send_frame(self, true, 0x8, payload)
end
function _M.send_ping(self, data)
return send_frame(self, true, 0x9, data)
end
function _M.send_pong(self, data)
return send_frame(self, true, 0xa, data)
end
return _M