增加sql的waf功能
This commit is contained in:
353
resty/websocket/client.lua
Normal file
353
resty/websocket/client.lua
Normal file
@@ -0,0 +1,353 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
-- FIXME: this library is very rough and is currently just for testing
|
||||
-- the websocket server.
|
||||
|
||||
|
||||
local wbproto = require "resty.websocket.protocol"
|
||||
local bit = require "bit"
|
||||
|
||||
|
||||
local _recv_frame = wbproto.recv_frame
|
||||
local _send_frame = wbproto.send_frame
|
||||
local new_tab = wbproto.new_tab
|
||||
local tcp = ngx.socket.tcp
|
||||
local re_match = ngx.re.match
|
||||
local encode_base64 = ngx.encode_base64
|
||||
local concat = table.concat
|
||||
local char = string.char
|
||||
local str_find = string.find
|
||||
local rand = math.random
|
||||
local rshift = bit.rshift
|
||||
local band = bit.band
|
||||
local setmetatable = setmetatable
|
||||
local type = type
|
||||
local debug = ngx.config.debug
|
||||
local ngx_log = ngx.log
|
||||
local ngx_DEBUG = ngx.DEBUG
|
||||
local ssl_support = true
|
||||
|
||||
if not ngx.config
|
||||
or not ngx.config.ngx_lua_version
|
||||
or ngx.config.ngx_lua_version < 9011
|
||||
then
|
||||
ssl_support = false
|
||||
end
|
||||
|
||||
local _M = new_tab(0, 13)
|
||||
_M._VERSION = '0.08'
|
||||
|
||||
|
||||
local mt = { __index = _M }
|
||||
|
||||
|
||||
function _M.new(self, opts)
|
||||
local sock, err = tcp()
|
||||
if not sock then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
local max_payload_len, send_unmasked, timeout
|
||||
if opts then
|
||||
max_payload_len = opts.max_payload_len
|
||||
send_unmasked = opts.send_unmasked
|
||||
timeout = opts.timeout
|
||||
|
||||
if timeout then
|
||||
sock:settimeout(timeout)
|
||||
end
|
||||
end
|
||||
|
||||
return setmetatable({
|
||||
sock = sock,
|
||||
max_payload_len = max_payload_len or 65535,
|
||||
send_unmasked = send_unmasked,
|
||||
}, mt)
|
||||
end
|
||||
|
||||
|
||||
function _M.connect(self, uri, opts)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
local m, err = re_match(uri, [[^(wss?)://([^:/]+)(?::(\d+))?(.*)]], "jo")
|
||||
if not m then
|
||||
if err then
|
||||
return nil, "failed to match the uri: " .. err
|
||||
end
|
||||
|
||||
return nil, "bad websocket uri"
|
||||
end
|
||||
|
||||
local scheme = m[1]
|
||||
local host = m[2]
|
||||
local port = m[3]
|
||||
local path = m[4]
|
||||
|
||||
-- ngx.say("host: ", host)
|
||||
-- ngx.say("port: ", port)
|
||||
|
||||
if not port then
|
||||
port = 80
|
||||
end
|
||||
|
||||
if path == "" then
|
||||
path = "/"
|
||||
end
|
||||
|
||||
local ssl_verify, headers, proto_header, origin_header, sock_opts = false
|
||||
|
||||
if opts then
|
||||
local protos = opts.protocols
|
||||
if protos then
|
||||
if type(protos) == "table" then
|
||||
proto_header = "\r\nSec-WebSocket-Protocol: "
|
||||
.. concat(protos, ",")
|
||||
|
||||
else
|
||||
proto_header = "\r\nSec-WebSocket-Protocol: " .. protos
|
||||
end
|
||||
end
|
||||
|
||||
local origin = opts.origin
|
||||
if origin then
|
||||
origin_header = "\r\nOrigin: " .. origin
|
||||
end
|
||||
|
||||
local pool = opts.pool
|
||||
if pool then
|
||||
sock_opts = { pool = pool }
|
||||
end
|
||||
|
||||
if opts.ssl_verify then
|
||||
if not ssl_support then
|
||||
return nil, "ngx_lua 0.9.11+ required for SSL sockets"
|
||||
end
|
||||
ssl_verify = true
|
||||
end
|
||||
|
||||
if opts.headers then
|
||||
headers = opts.headers
|
||||
if type(headers) ~= "table" then
|
||||
return nil, "custom headers must be a table"
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local ok, err
|
||||
if sock_opts then
|
||||
ok, err = sock:connect(host, port, sock_opts)
|
||||
else
|
||||
ok, err = sock:connect(host, port)
|
||||
end
|
||||
if not ok then
|
||||
return nil, "failed to connect: " .. err
|
||||
end
|
||||
|
||||
if scheme == "wss" then
|
||||
if not ssl_support then
|
||||
return nil, "ngx_lua 0.9.11+ required for SSL sockets"
|
||||
end
|
||||
ok, err = sock:sslhandshake(false, host, ssl_verify)
|
||||
if not ok then
|
||||
return nil, "ssl handshake failed: " .. err
|
||||
end
|
||||
end
|
||||
|
||||
-- check for connections from pool:
|
||||
|
||||
local count, err = sock:getreusedtimes()
|
||||
if not count then
|
||||
return nil, "failed to get reused times: " .. err
|
||||
end
|
||||
if count > 0 then
|
||||
-- being a reused connection (must have done handshake)
|
||||
return 1
|
||||
end
|
||||
|
||||
local custom_headers
|
||||
if headers then
|
||||
custom_headers = concat(headers, "\r\n")
|
||||
custom_headers = "\r\n" .. custom_headers
|
||||
end
|
||||
|
||||
-- do the websocket handshake:
|
||||
|
||||
local bytes = char(rand(256) - 1, rand(256) - 1, rand(256) - 1,
|
||||
rand(256) - 1, rand(256) - 1, rand(256) - 1,
|
||||
rand(256) - 1, rand(256) - 1, rand(256) - 1,
|
||||
rand(256) - 1, rand(256) - 1, rand(256) - 1,
|
||||
rand(256) - 1, rand(256) - 1, rand(256) - 1,
|
||||
rand(256) - 1)
|
||||
|
||||
local key = encode_base64(bytes)
|
||||
local req = "GET " .. path .. " HTTP/1.1\r\nUpgrade: websocket\r\nHost: "
|
||||
.. host .. ":" .. port
|
||||
.. "\r\nSec-WebSocket-Key: " .. key
|
||||
.. (proto_header or "")
|
||||
.. "\r\nSec-WebSocket-Version: 13"
|
||||
.. (origin_header or "")
|
||||
.. "\r\nConnection: Upgrade"
|
||||
.. (custom_headers or "")
|
||||
.. "\r\n\r\n"
|
||||
|
||||
local bytes, err = sock:send(req)
|
||||
if not bytes then
|
||||
return nil, "failed to send the handshake request: " .. err
|
||||
end
|
||||
|
||||
local header_reader = sock:receiveuntil("\r\n\r\n")
|
||||
-- FIXME: check for too big response headers
|
||||
local header, err, partial = header_reader()
|
||||
if not header then
|
||||
return nil, "failed to receive response header: " .. err
|
||||
end
|
||||
|
||||
-- error("header: " .. header)
|
||||
|
||||
-- FIXME: verify the response headers
|
||||
|
||||
m, err = re_match(header, [[^\s*HTTP/1\.1\s+]], "jo")
|
||||
if not m then
|
||||
return nil, "bad HTTP response status line: " .. header
|
||||
end
|
||||
|
||||
return 1
|
||||
end
|
||||
|
||||
|
||||
function _M.set_timeout(self, time)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, nil, "not initialized yet"
|
||||
end
|
||||
|
||||
return sock:settimeout(time)
|
||||
end
|
||||
|
||||
|
||||
function _M.recv_frame(self)
|
||||
if self.fatal then
|
||||
return nil, nil, "fatal error already happened"
|
||||
end
|
||||
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, nil, "not initialized yet"
|
||||
end
|
||||
|
||||
local data, typ, err = _recv_frame(sock, self.max_payload_len, false)
|
||||
if not data and not str_find(err, ": timeout", 1, true) then
|
||||
self.fatal = true
|
||||
end
|
||||
return data, typ, err
|
||||
end
|
||||
|
||||
|
||||
local function send_frame(self, fin, opcode, payload)
|
||||
if self.fatal then
|
||||
return nil, "fatal error already happened"
|
||||
end
|
||||
|
||||
if self.closed then
|
||||
return nil, "already closed"
|
||||
end
|
||||
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized yet"
|
||||
end
|
||||
|
||||
local bytes, err = _send_frame(sock, fin, opcode, payload,
|
||||
self.max_payload_len,
|
||||
not self.send_unmasked)
|
||||
if not bytes then
|
||||
self.fatal = true
|
||||
end
|
||||
return bytes, err
|
||||
end
|
||||
_M.send_frame = send_frame
|
||||
|
||||
|
||||
function _M.send_text(self, data)
|
||||
return send_frame(self, true, 0x1, data)
|
||||
end
|
||||
|
||||
|
||||
function _M.send_binary(self, data)
|
||||
return send_frame(self, true, 0x2, data)
|
||||
end
|
||||
|
||||
|
||||
local function send_close(self, code, msg)
|
||||
local payload
|
||||
if code then
|
||||
if type(code) ~= "number" or code > 0x7fff then
|
||||
return nil, "bad status code"
|
||||
end
|
||||
payload = char(band(rshift(code, 8), 0xff), band(code, 0xff))
|
||||
.. (msg or "")
|
||||
end
|
||||
|
||||
if debug then
|
||||
ngx_log(ngx_DEBUG, "sending the close frame")
|
||||
end
|
||||
|
||||
local bytes, err = send_frame(self, true, 0x8, payload)
|
||||
|
||||
if not bytes then
|
||||
self.fatal = true
|
||||
end
|
||||
|
||||
self.closed = true
|
||||
|
||||
return bytes, err
|
||||
end
|
||||
_M.send_close = send_close
|
||||
|
||||
|
||||
function _M.send_ping(self, data)
|
||||
return send_frame(self, true, 0x9, data)
|
||||
end
|
||||
|
||||
|
||||
function _M.send_pong(self, data)
|
||||
return send_frame(self, true, 0xa, data)
|
||||
end
|
||||
|
||||
|
||||
function _M.close(self)
|
||||
if self.fatal then
|
||||
return nil, "fatal error already happened"
|
||||
end
|
||||
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
if not self.closed then
|
||||
local bytes, err = send_close(self)
|
||||
if not bytes then
|
||||
return nil, "failed to send close frame: " .. err
|
||||
end
|
||||
end
|
||||
|
||||
return sock:close()
|
||||
end
|
||||
|
||||
|
||||
function _M.set_keepalive(self, ...)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
return sock:setkeepalive(...)
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
345
resty/websocket/protocol.lua
Normal file
345
resty/websocket/protocol.lua
Normal file
@@ -0,0 +1,345 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
local bit = require "bit"
|
||||
local ffi = require "ffi"
|
||||
|
||||
|
||||
local byte = string.byte
|
||||
local char = string.char
|
||||
local sub = string.sub
|
||||
local band = bit.band
|
||||
local bor = bit.bor
|
||||
local bxor = bit.bxor
|
||||
local lshift = bit.lshift
|
||||
local rshift = bit.rshift
|
||||
--local tohex = bit.tohex
|
||||
local tostring = tostring
|
||||
local concat = table.concat
|
||||
local rand = math.random
|
||||
local type = type
|
||||
local debug = ngx.config.debug
|
||||
local ngx_log = ngx.log
|
||||
local ngx_DEBUG = ngx.DEBUG
|
||||
local ffi_new = ffi.new
|
||||
local ffi_string = ffi.string
|
||||
|
||||
|
||||
local ok, new_tab = pcall(require, "table.new")
|
||||
if not ok then
|
||||
new_tab = function (narr, nrec) return {} end
|
||||
end
|
||||
|
||||
|
||||
local _M = new_tab(0, 5)
|
||||
|
||||
_M.new_tab = new_tab
|
||||
_M._VERSION = '0.08'
|
||||
|
||||
|
||||
local types = {
|
||||
[0x0] = "continuation",
|
||||
[0x1] = "text",
|
||||
[0x2] = "binary",
|
||||
[0x8] = "close",
|
||||
[0x9] = "ping",
|
||||
[0xa] = "pong",
|
||||
}
|
||||
|
||||
local str_buf_size = 4096
|
||||
local str_buf
|
||||
local c_buf_type = ffi.typeof("char[?]")
|
||||
|
||||
|
||||
local function get_string_buf(size)
|
||||
if size > str_buf_size then
|
||||
return ffi_new(c_buf_type, size)
|
||||
end
|
||||
if not str_buf then
|
||||
str_buf = ffi_new(c_buf_type, str_buf_size)
|
||||
end
|
||||
|
||||
return str_buf
|
||||
end
|
||||
|
||||
|
||||
function _M.recv_frame(sock, max_payload_len, force_masking)
|
||||
local data, err = sock:receive(2)
|
||||
if not data then
|
||||
return nil, nil, "failed to receive the first 2 bytes: " .. err
|
||||
end
|
||||
|
||||
local fst, snd = byte(data, 1, 2)
|
||||
|
||||
local fin = band(fst, 0x80) ~= 0
|
||||
-- print("fin: ", fin)
|
||||
|
||||
if band(fst, 0x70) ~= 0 then
|
||||
return nil, nil, "bad RSV1, RSV2, or RSV3 bits"
|
||||
end
|
||||
|
||||
local opcode = band(fst, 0x0f)
|
||||
-- print("opcode: ", tohex(opcode))
|
||||
|
||||
if opcode >= 0x3 and opcode <= 0x7 then
|
||||
return nil, nil, "reserved non-control frames"
|
||||
end
|
||||
|
||||
if opcode >= 0xb and opcode <= 0xf then
|
||||
return nil, nil, "reserved control frames"
|
||||
end
|
||||
|
||||
local mask = band(snd, 0x80) ~= 0
|
||||
|
||||
if debug then
|
||||
ngx_log(ngx_DEBUG, "recv_frame: mask bit: ", mask and 1 or 0)
|
||||
end
|
||||
|
||||
if force_masking and not mask then
|
||||
return nil, nil, "frame unmasked"
|
||||
end
|
||||
|
||||
local payload_len = band(snd, 0x7f)
|
||||
-- print("payload len: ", payload_len)
|
||||
|
||||
if payload_len == 126 then
|
||||
local data, err = sock:receive(2)
|
||||
if not data then
|
||||
return nil, nil, "failed to receive the 2 byte payload length: "
|
||||
.. (err or "unknown")
|
||||
end
|
||||
|
||||
payload_len = bor(lshift(byte(data, 1), 8), byte(data, 2))
|
||||
|
||||
elseif payload_len == 127 then
|
||||
local data, err = sock:receive(8)
|
||||
if not data then
|
||||
return nil, nil, "failed to receive the 8 byte payload length: "
|
||||
.. (err or "unknown")
|
||||
end
|
||||
|
||||
if byte(data, 1) ~= 0
|
||||
or byte(data, 2) ~= 0
|
||||
or byte(data, 3) ~= 0
|
||||
or byte(data, 4) ~= 0
|
||||
then
|
||||
return nil, nil, "payload len too large"
|
||||
end
|
||||
|
||||
local fifth = byte(data, 5)
|
||||
if band(fifth, 0x80) ~= 0 then
|
||||
return nil, nil, "payload len too large"
|
||||
end
|
||||
|
||||
payload_len = bor(lshift(fifth, 24),
|
||||
lshift(byte(data, 6), 16),
|
||||
lshift(byte(data, 7), 8),
|
||||
byte(data, 8))
|
||||
end
|
||||
|
||||
if band(opcode, 0x8) ~= 0 then
|
||||
-- being a control frame
|
||||
if payload_len > 125 then
|
||||
return nil, nil, "too long payload for control frame"
|
||||
end
|
||||
|
||||
if not fin then
|
||||
return nil, nil, "fragmented control frame"
|
||||
end
|
||||
end
|
||||
|
||||
-- print("payload len: ", payload_len, ", max payload len: ",
|
||||
-- max_payload_len)
|
||||
|
||||
if payload_len > max_payload_len then
|
||||
return nil, nil, "exceeding max payload len"
|
||||
end
|
||||
|
||||
local rest
|
||||
if mask then
|
||||
rest = payload_len + 4
|
||||
|
||||
else
|
||||
rest = payload_len
|
||||
end
|
||||
-- print("rest: ", rest)
|
||||
|
||||
local data
|
||||
if rest > 0 then
|
||||
data, err = sock:receive(rest)
|
||||
if not data then
|
||||
return nil, nil, "failed to read masking-len and payload: "
|
||||
.. (err or "unknown")
|
||||
end
|
||||
else
|
||||
data = ""
|
||||
end
|
||||
|
||||
-- print("received rest")
|
||||
|
||||
if opcode == 0x8 then
|
||||
-- being a close frame
|
||||
if payload_len > 0 then
|
||||
if payload_len < 2 then
|
||||
return nil, nil, "close frame with a body must carry a 2-byte"
|
||||
.. " status code"
|
||||
end
|
||||
|
||||
local msg, code
|
||||
if mask then
|
||||
local fst = bxor(byte(data, 4 + 1), byte(data, 1))
|
||||
local snd = bxor(byte(data, 4 + 2), byte(data, 2))
|
||||
code = bor(lshift(fst, 8), snd)
|
||||
|
||||
if payload_len > 2 then
|
||||
-- TODO string.buffer optimizations
|
||||
local bytes = get_string_buf(payload_len - 2)
|
||||
for i = 3, payload_len do
|
||||
bytes[i - 3] = bxor(byte(data, 4 + i),
|
||||
byte(data, (i - 1) % 4 + 1))
|
||||
end
|
||||
msg = ffi_string(bytes, payload_len - 2)
|
||||
|
||||
else
|
||||
msg = ""
|
||||
end
|
||||
|
||||
else
|
||||
local fst = byte(data, 1)
|
||||
local snd = byte(data, 2)
|
||||
code = bor(lshift(fst, 8), snd)
|
||||
|
||||
-- print("parsing unmasked close frame payload: ", payload_len)
|
||||
|
||||
if payload_len > 2 then
|
||||
msg = sub(data, 3)
|
||||
|
||||
else
|
||||
msg = ""
|
||||
end
|
||||
end
|
||||
|
||||
return msg, "close", code
|
||||
end
|
||||
|
||||
return "", "close", nil
|
||||
end
|
||||
|
||||
local msg
|
||||
if mask then
|
||||
-- TODO string.buffer optimizations
|
||||
local bytes = get_string_buf(payload_len)
|
||||
for i = 1, payload_len do
|
||||
bytes[i - 1] = bxor(byte(data, 4 + i),
|
||||
byte(data, (i - 1) % 4 + 1))
|
||||
end
|
||||
msg = ffi_string(bytes, payload_len)
|
||||
|
||||
else
|
||||
msg = data
|
||||
end
|
||||
|
||||
return msg, types[opcode], not fin and "again" or nil
|
||||
end
|
||||
|
||||
|
||||
local function build_frame(fin, opcode, payload_len, payload, masking)
|
||||
-- XXX optimize this when we have string.buffer in LuaJIT 2.1
|
||||
local fst
|
||||
if fin then
|
||||
fst = bor(0x80, opcode)
|
||||
else
|
||||
fst = opcode
|
||||
end
|
||||
|
||||
local snd, extra_len_bytes
|
||||
if payload_len <= 125 then
|
||||
snd = payload_len
|
||||
extra_len_bytes = ""
|
||||
|
||||
elseif payload_len <= 65535 then
|
||||
snd = 126
|
||||
extra_len_bytes = char(band(rshift(payload_len, 8), 0xff),
|
||||
band(payload_len, 0xff))
|
||||
|
||||
else
|
||||
if band(payload_len, 0x7fffffff) < payload_len then
|
||||
return nil, "payload too big"
|
||||
end
|
||||
|
||||
snd = 127
|
||||
-- XXX we only support 31-bit length here
|
||||
extra_len_bytes = char(0, 0, 0, 0, band(rshift(payload_len, 24), 0xff),
|
||||
band(rshift(payload_len, 16), 0xff),
|
||||
band(rshift(payload_len, 8), 0xff),
|
||||
band(payload_len, 0xff))
|
||||
end
|
||||
|
||||
local masking_key
|
||||
if masking then
|
||||
-- set the mask bit
|
||||
snd = bor(snd, 0x80)
|
||||
local key = rand(0xffffffff)
|
||||
masking_key = char(band(rshift(key, 24), 0xff),
|
||||
band(rshift(key, 16), 0xff),
|
||||
band(rshift(key, 8), 0xff),
|
||||
band(key, 0xff))
|
||||
|
||||
-- TODO string.buffer optimizations
|
||||
local bytes = get_string_buf(payload_len)
|
||||
for i = 1, payload_len do
|
||||
bytes[i - 1] = bxor(byte(payload, i),
|
||||
byte(masking_key, (i - 1) % 4 + 1))
|
||||
end
|
||||
payload = ffi_string(bytes, payload_len)
|
||||
|
||||
else
|
||||
masking_key = ""
|
||||
end
|
||||
|
||||
return char(fst, snd) .. extra_len_bytes .. masking_key .. payload
|
||||
end
|
||||
_M.build_frame = build_frame
|
||||
|
||||
|
||||
function _M.send_frame(sock, fin, opcode, payload, max_payload_len, masking)
|
||||
-- ngx.log(ngx.WARN, ngx.var.uri, ": masking: ", masking)
|
||||
|
||||
if not payload then
|
||||
payload = ""
|
||||
|
||||
elseif type(payload) ~= "string" then
|
||||
payload = tostring(payload)
|
||||
end
|
||||
|
||||
local payload_len = #payload
|
||||
|
||||
if payload_len > max_payload_len then
|
||||
return nil, "payload too big"
|
||||
end
|
||||
|
||||
if band(opcode, 0x8) ~= 0 then
|
||||
-- being a control frame
|
||||
if payload_len > 125 then
|
||||
return nil, "too much payload for control frame"
|
||||
end
|
||||
if not fin then
|
||||
return nil, "fragmented control frame"
|
||||
end
|
||||
end
|
||||
|
||||
local frame, err = build_frame(fin, opcode, payload_len, payload,
|
||||
masking)
|
||||
if not frame then
|
||||
return nil, "failed to build frame: " .. err
|
||||
end
|
||||
|
||||
local bytes, err = sock:send(frame)
|
||||
if not bytes then
|
||||
return nil, "failed to send frame: " .. err
|
||||
end
|
||||
return bytes
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
210
resty/websocket/server.lua
Normal file
210
resty/websocket/server.lua
Normal file
@@ -0,0 +1,210 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
local bit = require "bit"
|
||||
local wbproto = require "resty.websocket.protocol"
|
||||
|
||||
local new_tab = wbproto.new_tab
|
||||
local _recv_frame = wbproto.recv_frame
|
||||
local _send_frame = wbproto.send_frame
|
||||
local http_ver = ngx.req.http_version
|
||||
local req_sock = ngx.req.socket
|
||||
local ngx_header = ngx.header
|
||||
local req_headers = ngx.req.get_headers
|
||||
local str_lower = string.lower
|
||||
local char = string.char
|
||||
local str_find = string.find
|
||||
local sha1_bin = ngx.sha1_bin
|
||||
local base64 = ngx.encode_base64
|
||||
local ngx = ngx
|
||||
local read_body = ngx.req.read_body
|
||||
local band = bit.band
|
||||
local rshift = bit.rshift
|
||||
local type = type
|
||||
local setmetatable = setmetatable
|
||||
local tostring = tostring
|
||||
-- local print = print
|
||||
|
||||
|
||||
local _M = new_tab(0, 10)
|
||||
_M._VERSION = '0.08'
|
||||
|
||||
local mt = { __index = _M }
|
||||
|
||||
|
||||
function _M.new(self, opts)
|
||||
if ngx.headers_sent then
|
||||
return nil, "response header already sent"
|
||||
end
|
||||
|
||||
read_body()
|
||||
|
||||
if http_ver() ~= 1.1 then
|
||||
return nil, "bad http version"
|
||||
end
|
||||
|
||||
local headers = req_headers()
|
||||
|
||||
local val = headers.upgrade
|
||||
if type(val) == "table" then
|
||||
val = val[1]
|
||||
end
|
||||
if not val or str_lower(val) ~= "websocket" then
|
||||
return nil, "bad \"upgrade\" request header: " .. tostring(val)
|
||||
end
|
||||
|
||||
val = headers.connection
|
||||
if type(val) == "table" then
|
||||
val = val[1]
|
||||
end
|
||||
if not val or not str_find(str_lower(val), "upgrade", 1, true) then
|
||||
return nil, "bad \"connection\" request header"
|
||||
end
|
||||
|
||||
local key = headers["sec-websocket-key"]
|
||||
if type(key) == "table" then
|
||||
key = key[1]
|
||||
end
|
||||
if not key then
|
||||
return nil, "bad \"sec-websocket-key\" request header"
|
||||
end
|
||||
|
||||
local ver = headers["sec-websocket-version"]
|
||||
if type(ver) == "table" then
|
||||
ver = ver[1]
|
||||
end
|
||||
if not ver or ver ~= "13" then
|
||||
return nil, "bad \"sec-websocket-version\" request header"
|
||||
end
|
||||
|
||||
local protocols = headers["sec-websocket-protocol"]
|
||||
if type(protocols) == "table" then
|
||||
protocols = protocols[1]
|
||||
end
|
||||
|
||||
if protocols then
|
||||
ngx_header["Sec-WebSocket-Protocol"] = protocols
|
||||
end
|
||||
ngx_header["Upgrade"] = "websocket"
|
||||
|
||||
local sha1 = sha1_bin(key .. "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
||||
ngx_header["Sec-WebSocket-Accept"] = base64(sha1)
|
||||
|
||||
ngx_header["Content-Type"] = nil
|
||||
|
||||
ngx.status = 101
|
||||
local ok, err = ngx.send_headers()
|
||||
if not ok then
|
||||
return nil, "failed to send response header: " .. (err or "unknonw")
|
||||
end
|
||||
ok, err = ngx.flush(true)
|
||||
if not ok then
|
||||
return nil, "failed to flush response header: " .. (err or "unknown")
|
||||
end
|
||||
|
||||
local sock
|
||||
sock, err = req_sock(true)
|
||||
if not sock then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
local max_payload_len, send_masked, timeout
|
||||
if opts then
|
||||
max_payload_len = opts.max_payload_len
|
||||
send_masked = opts.send_masked
|
||||
timeout = opts.timeout
|
||||
|
||||
if timeout then
|
||||
sock:settimeout(timeout)
|
||||
end
|
||||
end
|
||||
|
||||
return setmetatable({
|
||||
sock = sock,
|
||||
max_payload_len = max_payload_len or 65535,
|
||||
send_masked = send_masked,
|
||||
}, mt)
|
||||
end
|
||||
|
||||
|
||||
function _M.set_timeout(self, time)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, nil, "not initialized yet"
|
||||
end
|
||||
|
||||
return sock:settimeout(time)
|
||||
end
|
||||
|
||||
|
||||
function _M.recv_frame(self)
|
||||
if self.fatal then
|
||||
return nil, nil, "fatal error already happened"
|
||||
end
|
||||
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, nil, "not initialized yet"
|
||||
end
|
||||
|
||||
local data, typ, err = _recv_frame(sock, self.max_payload_len, true)
|
||||
if not data and not str_find(err, ": timeout", 1, true) then
|
||||
self.fatal = true
|
||||
end
|
||||
return data, typ, err
|
||||
end
|
||||
|
||||
|
||||
local function send_frame(self, fin, opcode, payload)
|
||||
if self.fatal then
|
||||
return nil, "fatal error already happened"
|
||||
end
|
||||
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized yet"
|
||||
end
|
||||
|
||||
local bytes, err = _send_frame(sock, fin, opcode, payload,
|
||||
self.max_payload_len, self.send_masked)
|
||||
if not bytes then
|
||||
self.fatal = true
|
||||
end
|
||||
return bytes, err
|
||||
end
|
||||
_M.send_frame = send_frame
|
||||
|
||||
|
||||
function _M.send_text(self, data)
|
||||
return send_frame(self, true, 0x1, data)
|
||||
end
|
||||
|
||||
|
||||
function _M.send_binary(self, data)
|
||||
return send_frame(self, true, 0x2, data)
|
||||
end
|
||||
|
||||
|
||||
function _M.send_close(self, code, msg)
|
||||
local payload
|
||||
if code then
|
||||
if type(code) ~= "number" or code > 0x7fff then
|
||||
end
|
||||
payload = char(band(rshift(code, 8), 0xff), band(code, 0xff))
|
||||
.. (msg or "")
|
||||
end
|
||||
return send_frame(self, true, 0x8, payload)
|
||||
end
|
||||
|
||||
|
||||
function _M.send_ping(self, data)
|
||||
return send_frame(self, true, 0x9, data)
|
||||
end
|
||||
|
||||
|
||||
function _M.send_pong(self, data)
|
||||
return send_frame(self, true, 0xa, data)
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
Reference in New Issue
Block a user