VCarContainer/tools/WebSocket/websocket.h

765 lines
26 KiB
C++

/*
MIT License
Copyright (c) 2020 Meng Rao <raomeng1@gmail.com>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/
#pragma once
#include <arpa/inet.h>
#include <endian.h>
#include <errno.h>
#include <fcntl.h>
#include <netinet/tcp.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <strings.h>
#include <sys/socket.h>
#include <time.h>
#include <unistd.h>

#include <algorithm>
#include <limits>
#include <memory>
#include <string>
#include <utility>
namespace websocket {
// Non-blocking IPv4 TCP connection with a fixed-size receive buffer of RecvBufSize bytes.
// Errors close the socket and leave a human-readable reason in getLastError().
template<uint32_t RecvBufSize>
class SocketTcpConnection
{
public:
  ~SocketTcpConnection() { close("destruct"); }
  const char* getLastError() { return last_error_; };
  bool isConnected() { return fd_ >= 0; }

  // Blocking connect to server_ip:server_port (dotted-quad IPv4). On success the socket is switched
  // to O_NONBLOCK + TCP_NODELAY (see open()). Returns false and records the reason on failure.
  bool connect(const char* server_ip, uint16_t server_port) {
    struct sockaddr_in server_addr;
    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons(server_port);
    // Validate before creating a socket: an unparsable address used to leave sin_addr uninitialized.
    if (inet_pton(AF_INET, server_ip, &(server_addr.sin_addr)) != 1) {
      saveError("invalid server ip", false);
      return false;
    }
    int fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd < 0) {
      saveError("socket error", true);
      return false;
    }
    if (::connect(fd, (struct sockaddr*)&server_addr, sizeof(server_addr)) < 0) {
      saveError("connect error", true);
      ::close(fd);
      return false;
    }
    return open(fd);
  }

  bool getPeername(struct sockaddr_in& addr) {
    socklen_t addr_len = sizeof(addr);
    return ::getpeername(fd_, (struct sockaddr*)&addr, &addr_len) == 0;
  }

  // Close the socket (no-op if already closed) and record `reason` (+ strerror(errno) if check_errno).
  void close(const char* reason, bool check_errno = false) {
    if (fd_ >= 0) {
      saveError(reason, check_errno);
      ::close(fd_);
      fd_ = -1;
    }
  }

  // Send all `size` bytes, spinning while the non-blocking socket's send buffer is full.
  // `more` sets MSG_MORE so the kernel may coalesce this chunk with the next write.
  // Returns false (and closes the connection) on a hard send error.
  bool write(const uint8_t* data, uint32_t size, bool more = false) {
    int flags = MSG_NOSIGNAL;
    if (more) flags |= MSG_MORE;
    do {
      int sent = ::send(fd_, data, size, flags);
      if (sent < 0) {
        // EAGAIN: send buffer full; EINTR: interrupted by a signal. Both are retryable.
        if (errno != EAGAIN && errno != EINTR) {
          close("send error", true);
          return false;
        }
        continue;
      }
      data += sent;
      size -= sent;
    } while (size != 0);
    return true;
  }

  // Read whatever is available and pass all buffered, not-yet-consumed bytes to
  // `handler(const char* data, uint32_t size)`, which returns how many trailing bytes it did NOT
  // consume; those are kept for the next call. Returns false if nothing new was read.
  template<typename Handler>
  bool read(Handler handler) {
    int ret = ::read(fd_, recvbuf_ + tail_, RecvBufSize - tail_);
    if (ret <= 0) {
      if (ret < 0 && (errno == EAGAIN || errno == EINTR)) return false;
      if (ret < 0) {
        close("read error", true);
      }
      else {
        close("remote close");
      }
      return false;
    }
    tail_ += ret;
    uint32_t remaining = handler(recvbuf_ + head_, tail_ - head_);
    if (remaining == 0) {
      head_ = tail_ = 0;
    }
    else {
      head_ = tail_ - remaining;
      if (head_ >= RecvBufSize / 2) {
        // compact the unconsumed tail to the front of the buffer
        memmove(recvbuf_, recvbuf_ + head_, remaining);
        head_ = 0;
        tail_ = remaining;
      }
      else if (tail_ == RecvBufSize) {
        close("recv buf full");
      }
    }
    return true;
  }

protected:
  template<uint32_t>
  friend class SocketTcpServer;

  // Adopt `fd` as this connection's socket and make it non-blocking with Nagle disabled.
  bool open(int fd) {
    // Reusing the object (e.g. reconnecting a client) must not leak the previous socket.
    if (fd_ >= 0 && fd_ != fd) ::close(fd_);
    fd_ = fd;
    head_ = tail_ = 0;
    int flags = fcntl(fd_, F_GETFL, 0);
    if (fcntl(fd_, F_SETFL, flags | O_NONBLOCK) < 0) {
      close("fcntl O_NONBLOCK error", true);
      return false;
    }
    int yes = 1;
    if (setsockopt(fd_, IPPROTO_TCP, TCP_NODELAY, &yes, sizeof(yes)) < 0) {
      close("setsockopt TCP_NODELAY error", true);
      return false;
    }
    return true;
  }

  void saveError(const char* msg, bool check_errno) {
    snprintf(last_error_, sizeof(last_error_), "%s %s", msg, check_errno ? (const char*)strerror(errno) : "");
  }

  int fd_ = -1;
  uint32_t head_; // offset of the first unconsumed byte in recvbuf_
  uint32_t tail_; // offset one past the last received byte in recvbuf_
  char recvbuf_[RecvBufSize];
  char last_error_[64] = "";
};
// Non-blocking IPv4 TCP listener that hands accepted sockets to SocketTcpConnection objects.
template<uint32_t RecvBufSize = 4096>
class SocketTcpServer
{
public:
  using TcpConnection = SocketTcpConnection<RecvBufSize>;

  // Create a non-blocking listening socket bound to server_ip:server_port with SO_REUSEADDR.
  // `interface` is not used. Returns false and records the reason (getLastError()) on failure.
  bool init(const char* interface, const char* server_ip, uint16_t server_port) {
    struct sockaddr_in local_addr;
    memset(&local_addr, 0, sizeof(local_addr));
    local_addr.sin_family = AF_INET;
    local_addr.sin_port = htons(server_port);
    // An unparsable address used to leave sin_addr uninitialized and bind to garbage.
    if (inet_pton(AF_INET, server_ip, &(local_addr.sin_addr)) != 1) {
      errno = EINVAL; // inet_pton does not set errno for a malformed string
      saveError("invalid server ip");
      return false;
    }
    // Re-initializing must not leak the previous listening socket.
    if (listenfd_ >= 0) {
      ::close(listenfd_);
      listenfd_ = -1;
    }
    listenfd_ = socket(AF_INET, SOCK_STREAM, 0);
    if (listenfd_ < 0) {
      saveError("socket error");
      return false;
    }
    int flags = fcntl(listenfd_, F_GETFL, 0);
    if (fcntl(listenfd_, F_SETFL, flags | O_NONBLOCK) < 0) {
      close("fcntl O_NONBLOCK error");
      return false;
    }
    int yes = 1;
    if (setsockopt(listenfd_, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof(yes)) < 0) {
      close("setsockopt SO_REUSEADDR error");
      return false;
    }
    if (bind(listenfd_, (struct sockaddr*)&local_addr, sizeof(local_addr)) < 0) {
      close("bind error");
      return false;
    }
    if (listen(listenfd_, 5) < 0) {
      close("listen error");
      return false;
    }
    return true;
  };

  // Close the listening socket (no-op if not open) and record `reason` + strerror(errno).
  void close(const char* reason) {
    if (listenfd_ >= 0) {
      saveError(reason);
      ::close(listenfd_);
      listenfd_ = -1;
    }
  }

  const char* getLastError() { return last_error_; };

  ~SocketTcpServer() { close("destruct"); }

  // Accept at most one pending connection into `conn`. Returns false when none is pending
  // (non-blocking accept) or when configuring the new socket fails (open() closes it then).
  bool accept2(TcpConnection& conn) {
    struct sockaddr_in clientaddr;
    socklen_t addr_len = sizeof(clientaddr);
    int fd = ::accept(listenfd_, (struct sockaddr*)&(clientaddr), &addr_len);
    if (fd < 0) {
      return false;
    }
    if (!conn.open(fd)) {
      return false;
    }
    return true;
  }

private:
  void saveError(const char* msg) { snprintf(last_error_, sizeof(last_error_), "%s %s", msg, strerror(errno)); }

  int listenfd_ = -1;
  char last_error_[64] = "";
};
// Current CLOCK_REALTIME time in nanoseconds since the Unix epoch.
// The multiplication is done in uint64_t so it cannot overflow where time_t/long is 32-bit.
// NOTE(review): callers in this file use it for timeouts; a wall-clock jump shifts those timeouts.
inline uint64_t getns() {
  timespec ts;
  ::clock_gettime(CLOCK_REALTIME, &ts);
  return (uint64_t)ts.tv_sec * 1000000000ULL + (uint64_t)ts.tv_nsec;
}
// WebSocket frame opcodes (low 4 bits of the first header byte).
// Opcodes with bit 3 set (0x8 and above) are control frames.
static constexpr uint8_t OPCODE_CONT = 0x0;   // continuation of a fragmented message
static constexpr uint8_t OPCODE_TEXT = 0x1;   // text data frame
static constexpr uint8_t OPCODE_BINARY = 0x2; // binary data frame
static constexpr uint8_t OPCODE_CLOSE = 0x8;  // connection close
static constexpr uint8_t OPCODE_PING = 0x9;   // ping
static constexpr uint8_t OPCODE_PONG = 0xA;   // pong
// One WebSocket endpoint on top of a TCP connection: frame encoding (send/close) and frame
// decoding (handleWSMsg), dispatching to EventHandler::onWSMsg / onWSSegment / onWSClose.
// RecvSegment=true delivers each frame as it arrives; otherwise fragmented messages are
// reassembled in `frame` (at most RecvBufSize bytes). SendMask sets the MASK bit on sent frames.
template<typename EventHandler, typename ConnUserData, bool RecvSegment, uint32_t RecvBufSize, bool SendMask>
class WSConnection
{
public:
  ConnUserData user_data;
  // get remote network address
  bool getPeername(struct sockaddr_in& addr) { return conn.getPeername(addr); }
  bool isConnected() { return conn.isConnected(); }

  // Send one frame. Control opcodes (>= 8) always go out with FIN set.
  // if sending a msg of multiple segments, only set fin to true for the last one;
  // segments after the first are re-tagged as OPCODE_CONT automatically.
  void send(uint8_t opcode, const uint8_t* payload, uint32_t pl_len, bool fin = true) {
    uint8_t h[14]; // 2-byte base header + up to 8 bytes extended length + 4-byte masking key
    uint32_t h_len = 2;
    if (opcode >> 3) // if control
      fin = true;
    else {
      if (!send_fin) opcode = OPCODE_CONT;
      send_fin = fin;
    }
    h[0] = (opcode & 15) | ((uint8_t)fin << 7);
    h[1] = (uint8_t)SendMask << 7;
    if (pl_len < 126) {
      h[1] |= (uint8_t)pl_len;
    }
    else if (pl_len < 65536) {
      h[1] |= 126;
      uint16_t len_be = htobe16((uint16_t)pl_len);
      memcpy(h + 2, &len_be, sizeof(len_be)); // h + 2 is not aligned for a direct 16-bit store
      h_len += 2;
    }
    else {
      h[1] |= 127;
      uint64_t len_be = htobe64((uint64_t)pl_len);
      memcpy(h + 2, &len_be, sizeof(len_be));
      h_len += 8;
    }
    if (SendMask) {
      // for efficiency and simplicity masking-key is always set to 0
      // NOTE(review): RFC 6455 section 5.3 asks clients for an unpredictable key.
      memset(h + h_len, 0, 4);
      h_len += 4;
    }
    if (!conn.write(h, h_len, true)) return; // write() already closed the socket
    conn.write(payload, pl_len, false);
  }

  // clean close the connection with optional status_code and reason.
  // Status 1005 means "no status" and is sent as an empty close frame. The reason is truncated
  // so the close payload (2-byte code + reason) stays within the 125-byte control-frame limit.
  void close(uint16_t status_code = 1005, const char* reason = "") {
    uint16_t code_be = htobe16(status_code);
    memcpy(close_reason, &code_be, sizeof(code_be));
    const uint32_t max_reason_len = 125 - 2;
    // snprintf returns the untruncated length, so clamp it before using it as a payload size.
    int n = snprintf((char*)close_reason + 2, max_reason_len + 1, "%s", reason ? reason : "");
    uint32_t reason_len = n < 0 ? 0 : std::min<uint32_t>((uint32_t)n, max_reason_len);
    close_reason[2 + reason_len] = 0;
    if (status_code != 1005) {
      send(OPCODE_CLOSE, close_reason, 2 + reason_len);
    }
    else
      send(OPCODE_CLOSE, nullptr, 0);
    conn.close("clean close");
  }

protected:
  template<typename, typename, bool, uint32_t, uint32_t>
  friend class WSServer;

  // Reset per-connection state for a fresh TCP connection that must time out at `expire` (ns).
  void init(uint64_t expire) {
    open = false;
    send_fin = true;
    // 1006 (abnormal closure) stays recorded unless a close() stores another status.
    uint16_t code_be = htobe16(1006);
    memcpy(close_reason, &code_be, sizeof(code_be));
    close_reason[2] = 0;
    frame_size = 0;
    expire_time = expire;
  }

  // Consume every complete frame at the front of [data, data + size) and return the number of
  // trailing bytes (an incomplete frame) that must be kept for the next read. Looping matters:
  // frames that already sit in the receive buffer would otherwise wait for the next read, and
  // SocketTcpConnection::read does not call its handler again until new bytes arrive.
  uint32_t handleWSMsg(EventHandler* handler, uint8_t* data, uint32_t size) {
    uint32_t remaining = size;
    while (remaining > 0 && conn.isConnected()) {
      uint32_t left = handleWSFrame(handler, data, remaining);
      if (left >= remaining) break; // incomplete frame, or the connection was closed (1009)
      data += remaining - left;
      remaining = left;
    }
    return remaining;
  }

  // Decode and dispatch at most one frame at the front of [data, data + size). Unmasks the
  // payload in place. Returns the bytes left after the frame, or `size` if the frame is incomplete.
  uint32_t handleWSFrame(EventHandler* handler, uint8_t* data, uint32_t size) {
    const uint8_t* data_end = data + size;
    if (size < 2) return size; // not even the base header yet
    uint8_t opcode = data[0] & 15;
    bool beg = opcode != OPCODE_CONT, fin = data[0] >> 7; //, control = opcode >> 3;
    bool mask = data[1] >> 7;
    uint64_t pl_len = data[1] & 127;
    // Make sure the whole header is buffered before reading the extended length / masking key.
    uint32_t hdr_len = 2 + (pl_len == 126 ? 2 : pl_len == 127 ? 8 : 0) + (mask ? 4 : 0);
    if (size < hdr_len) return size;
    data += 2;
    if (pl_len == 126) {
      uint16_t len_be;
      memcpy(&len_be, data, sizeof(len_be));
      pl_len = be16toh(len_be);
      data += 2;
    }
    else if (pl_len == 127) {
      uint64_t len_be;
      memcpy(&len_be, data, sizeof(len_be));
      pl_len = be64toh(len_be) & ~(1ULL << 63);
      data += 8;
    }
    uint8_t mask_key[4];
    if (mask) {
      memcpy(mask_key, data, 4);
      data += 4;
    }
    if ((uint64_t)(data_end - data) < pl_len) {
      // incomplete: wait for more data unless the frame can never fit in the receive buffer
      if (hdr_len + pl_len > RecvBufSize) close(1009);
      return size;
    }
    if (mask) {
      for (uint64_t i = 0; i < pl_len; i++) data[i] ^= mask_key[i & 3];
    }
    if (RecvSegment || (beg && fin)) {
      if (opcode == OPCODE_CLOSE) {
        uint16_t status_code = 1005;
        char reason[128] = {0};
        if (pl_len >= 2) {
          uint16_t code_be;
          memcpy(&code_be, data, sizeof(code_be));
          status_code = be16toh(code_be);
          uint64_t reason_len = std::min<uint64_t>(sizeof(reason) - 1, pl_len - 2);
          memcpy(reason, data + 2, reason_len);
          reason[reason_len] = 0;
        }
        close(status_code, reason); // echo the close back, then drop the socket
      }
      else {
#if __cplusplus >= 201703L
        if constexpr (RecvSegment) {
#else
        if (RecvSegment) {
#endif
          if (beg) recv_opcode = opcode;
          handler->onWSSegment(*this, recv_opcode, data, pl_len, frame_size, fin);
          if (fin)
            frame_size = 0;
          else
            frame_size += pl_len;
        }
        else
          handler->onWSMsg(*this, opcode, data, pl_len);
      }
    }
#if __cplusplus >= 201703L
    else if constexpr (!RecvSegment) {
#else
    else {
#endif
      // fragment of a multi-frame message: accumulate into `frame` until FIN
      if (frame_size + pl_len > RecvBufSize)
        close(1009);
      else {
        memcpy(frame + frame_size, data, pl_len);
        frame_size += pl_len;
        if (beg) recv_opcode = opcode;
        if (fin) {
          handler->onWSMsg(*this, recv_opcode, frame, frame_size);
          frame_size = 0;
        }
      }
    }
    return data_end - (data + pl_len);
  }

  // Report the recorded close status to the handler; for 1006 the TCP-level error is the reason.
  void handleWSClose(EventHandler* handler) {
    uint16_t code_be;
    memcpy(&code_be, close_reason, sizeof(code_be));
    uint16_t status_code = be16toh(code_be);
    const char* reason = (const char*)close_reason + 2;
    if (status_code == 1006) reason = conn.getLastError();
    handler->onWSClose(*this, status_code, reason);
  }

  bool open;            // WebSocket handshake completed
  bool send_fin;        // last data frame sent had FIN set
  uint8_t recv_opcode;  // opcode of the message currently being received
  uint32_t frame_size;  // bytes received so far of a fragmented message
  uint64_t expire_time; // inactivity deadline in getns() nanoseconds
  uint8_t frame[RecvSegment ? 0 : RecvBufSize];
  typename SocketTcpServer<RecvBufSize>::TcpConnection conn;
  uint8_t close_reason[128]; // first 2 bytes are status_code(big endian)
};
// WebSocket client: performs the HTTP upgrade handshake in wsConnect(), then poll() dispatches
// incoming frames to EventHandler. Frames are sent with the MASK bit (SendMask=true).
template<typename EventHandler, typename ConnUserData = char, bool RecvSegment = false, uint32_t RecvBufSize = 4096,
         typename ConnectionType = WSConnection<EventHandler, ConnUserData, RecvSegment, RecvBufSize, true>>
class WSClient : public ConnectionType
{
public:
  using Connection = ConnectionType;
  // using Connection = WSConnection<EventHandler, ConnUserData, RecvSegment, RecvBufSize, true>;
  const char* getLastError() { return this->conn.getLastError(); }

  // Connect over TCP and complete the WebSocket handshake.
  // timeout: connect timeout in milliseconds, 0 means no limit
  // origin/protocol/extensions: optional request headers (nullptr to omit)
  // resp_protocol/resp_extensions: optional buffers receiving the server's chosen values
  // if failed, call getLastError() for the reason
  // NOTE(review): the Sec-WebSocket-Key is the fixed RFC 6455 sample nonce, so the expected
  // Sec-WebSocket-Accept can be a constant.
  bool wsConnect(uint64_t timeout, const char* server_ip, uint16_t server_port, const char* request_uri,
                 const char* host, const char* origin = nullptr, const char* protocol = nullptr,
                 const char* extensions = nullptr, char* resp_protocol = nullptr, uint32_t resp_protocol_size = 0,
                 char* resp_extensions = nullptr, uint32_t resp_extensions_size = 0) {
    uint64_t now = getns();
    uint64_t expire = timeout > 0 ? now + timeout * 1000000 : std::numeric_limits<uint64_t>::max();
    if (!this->conn.connect(server_ip, server_port)) return false;
    if (getns() > expire) {
      this->conn.close("timeout");
      return false;
    }
    this->init(expire);
    char req[2048];
    uint32_t req_len = 0;
    // Every append is bounds-checked: snprintf returns the untruncated length, and reusing it as
    // an offset (with `sizeof(req) - req_len` underflowing) would write past the end of `req`.
    int n = snprintf(req, sizeof(req),
                     "GET %s HTTP/1.1\r\nHost: %s\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: "
                     "dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n",
                     request_uri, host);
    bool fits = n >= 0 && (uint32_t)n < sizeof(req);
    if (fits) req_len = n;
    if (fits && origin) fits = appendHeaderLine(req, sizeof(req), req_len, "Origin", origin);
    if (fits && protocol) fits = appendHeaderLine(req, sizeof(req), req_len, "Sec-WebSocket-Protocol", protocol);
    if (fits && extensions)
      fits = appendHeaderLine(req, sizeof(req), req_len, "Sec-WebSocket-Extensions", extensions);
    if (fits) fits = appendHeaderLine(req, sizeof(req), req_len, nullptr, nullptr);
    if (!fits) {
      this->conn.close("request msg too long");
      return false;
    }
    this->conn.write((uint8_t*)req, req_len);
    while (!this->open && this->isConnected()) {
      // Re-parses the whole response each time; returns `size` until the header block is complete.
      this->conn.read([&](const char* data, uint32_t size) -> uint32_t {
        const char* data_end = data + size;
        bool status_code_checked = false, upgrade_checked = false, connection_checked = false, accept_checked = false;
        while (true) {
          const char* ln = (const char*)memchr(data, '\n', data_end - data);
          if (!ln) return size;
          // every line must end with "\r\n"; afterwards ln points at the '\r'
          if (ln == data || *--ln != '\r') break;
          if (!status_code_checked) { // first line: "HTTP/x.y 101 ..."
            if (ln - data < 5 || memcmp(data, "HTTP/", 5)) break;
            const char* status_code = (const char*)memchr(data, ' ', ln - data);
            if (!status_code) break;
            while (status_code < ln && *status_code == ' ') status_code++;
            if (ln - status_code < 3 || memcmp(status_code, "101", 3)) break;
            if (status_code + 3 < ln && status_code[3] != ' ') break; // e.g. "1010"
            status_code_checked = true;
          }
          else {
            const char* val_end = ln;
            while (val_end > data && val_end[-1] == ' ') val_end--;
            if (val_end == data) { // end of headers
              if (!upgrade_checked || !connection_checked || !accept_checked) break;
              this->open = true;
              return data_end - ln - 2; // bytes after the "\r\n" that ends the header block
            }
            const char* colon = (const char*)memchr(data, ':', ln - data);
            if (!colon) break;
            const char* val = colon + 1;
            while (val < val_end && *val == ' ') val++;
            uint32_t key_len = colon - data;
            uint32_t val_len = val_end - val;
            // header names are case-insensitive (RFC 7230 section 3.2)
            if (key_len == 7 && !strncasecmp(data, "Upgrade", 7)) {
              if (!hasHeaderToken(val, val_len, "websocket")) break;
              upgrade_checked = true;
            }
            else if (key_len == 10 && !strncasecmp(data, "Connection", 10)) {
              if (hasHeaderToken(val, val_len, "Upgrade")) connection_checked = true;
            }
            else if (key_len == 20 && !strncasecmp(data, "Sec-WebSocket-Accept", 20)) {
              // base64 value: compared case-sensitively
              if (val_len != 28 || memcmp(val, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", 28)) break;
              accept_checked = true;
            }
            else if (key_len == 22 && !strncasecmp(data, "Sec-WebSocket-Protocol", 22) && resp_protocol_size > 0) {
              uint32_t cp_len = std::min(resp_protocol_size - 1, val_len);
              memcpy(resp_protocol, val, cp_len);
              resp_protocol[cp_len] = 0;
            }
            else if (key_len == 24 && !strncasecmp(data, "Sec-WebSocket-Extensions", 24) &&
                     resp_extensions_size > 0) {
              uint32_t cp_len = std::min(resp_extensions_size - 1, val_len);
              memcpy(resp_extensions, val, cp_len);
              resp_extensions[cp_len] = 0;
            }
          }
          data = ln + 2; // skip \r\n
        }
        this->conn.close("request failed");
        return size;
      });
      if (getns() > expire) this->conn.close("timeout");
    }
    return this->isConnected();
  }

  // Read available data, dispatch complete frames to `handler`, and report a close via onWSClose.
  // NOTE(review): frame bytes that arrived together with the handshake response stay buffered in
  // `conn` until more data is read, because conn.read() only calls its handler after a successful read.
  void poll(EventHandler* handler) {
    this->conn.read([&](const char* data, uint32_t size) { return this->handleWSMsg(handler, (uint8_t*)data, size); });
    if (!this->isConnected()) this->handleWSClose(handler);
  }

private:
  // Append "name: value\r\n" to buf (or a bare "\r\n" when name is null), given len < buf_size.
  // Returns false, leaving len unchanged, if the text does not fit.
  static bool appendHeaderLine(char* buf, uint32_t buf_size, uint32_t& len, const char* name, const char* value) {
    uint32_t room = buf_size - len;
    int n = name ? snprintf(buf + len, room, "%s: %s\r\n", name, value) : snprintf(buf + len, room, "\r\n");
    if (n < 0 || (uint32_t)n >= room) return false;
    len += n;
    return true;
  }

  // True if the comma-separated header value [val, val + val_len) contains `token`
  // (case-insensitive, surrounding blanks ignored), e.g. "keep-alive, Upgrade" contains "upgrade".
  static bool hasHeaderToken(const char* val, uint32_t val_len, const char* token) {
    const size_t tok_len = strlen(token);
    const char* end = val + val_len;
    while (val < end) {
      while (val < end && (*val == ' ' || *val == '\t' || *val == ',')) val++;
      const char* item_end = val;
      while (item_end < end && *item_end != ',') item_end++;
      const char* trimmed_end = item_end;
      while (trimmed_end > val && (trimmed_end[-1] == ' ' || trimmed_end[-1] == '\t')) trimmed_end--;
      if ((size_t)(trimmed_end - val) == tok_len && !strncasecmp(val, token, tok_len)) return true;
      val = item_end;
    }
    return false;
  }
};
// WebSocket server handling up to MaxConns concurrent connections, all driven from poll():
// accepts new sockets, runs the HTTP upgrade handshake (EventHandler::onWSConnect decides),
// then dispatches frames (onWSMsg/onWSSegment) and closes (onWSClose). Frames are sent unmasked.
template<typename EventHandler, typename ConnUserData = char, bool RecvSegment = false, uint32_t RecvBufSize = 4096,
         uint32_t MaxConns = 10>
class WSServer
{
public:
  using TcpServer = SocketTcpServer<RecvBufSize>;
  using Connection = WSConnection<EventHandler, ConnUserData, RecvSegment, RecvBufSize, false>;

  WSServer() {
    for (uint32_t i = 0; i < MaxConns; i++) {
      conns_[i] = conns_data_ + i;
    }
  }

  const char* getLastError() { return server_.getLastError(); }

  // newconn_timeout: new tcp connection max inactive time in milliseconds, 0 means no limit
  // openconn_timeout: open ws connection max inactive time in milliseconds, 0 means no limit
  // if failed, call getLastError() for the reason
  bool init(const char* server_ip, uint16_t server_port, uint64_t newconn_timeout = 0, uint64_t openconn_timeout = 0) {
    newconn_timeout_ = newconn_timeout * 1000000;
    openconn_timeout_ = openconn_timeout * 1000000;
    return server_.init("", server_ip, server_port);
  }

  // One non-blocking iteration: accept at most one new connection, read from every connection,
  // enforce inactivity timeouts, and drop closed connections (onWSClose only for opened ones).
  void poll(EventHandler* handler) {
    uint64_t now = getns();
    uint64_t new_expire = newconn_timeout_ ? now + newconn_timeout_ : std::numeric_limits<uint64_t>::max();
    uint64_t open_expire = openconn_timeout_ ? now + openconn_timeout_ : std::numeric_limits<uint64_t>::max();
    if (conns_cnt_ < MaxConns) {
      Connection& new_conn = *conns_[conns_cnt_];
      if (server_.accept2(new_conn.conn)) {
        new_conn.init(new_expire);
        conns_cnt_++;
      }
    }
    for (uint32_t i = 0; i < conns_cnt_;) {
      Connection& conn = *conns_[i];
      conn.conn.read([&](const char* data, uint32_t size) {
        uint32_t remaining;
        if (conn.open)
          remaining = conn.handleWSMsg(handler, (uint8_t*)data, size);
        else {
          remaining = handleHttpRequest(handler, conn, data, size);
          // Frames the client sent right behind its handshake would otherwise stay buffered
          // until it sends more data, since read() only calls this handler after new bytes.
          if (conn.open && remaining > 0 && remaining < size && conn.isConnected())
            remaining = conn.handleWSMsg(handler, (uint8_t*)data + (size - remaining), remaining);
        }
        if (remaining < size) conn.expire_time = conn.open ? open_expire : new_expire;
        return remaining;
      });
      if (now > conn.expire_time) conn.conn.close("timeout");
      if (conn.isConnected())
        i++;
      else {
        if (conn.open) conn.handleWSClose(handler);
        std::swap(conns_[i], conns_[--conns_cnt_]);
      }
    }
  }

  // Broadcast `msg` as a text frame to every connection that has completed the WebSocket handshake.
  // Sockets still in the HTTP handshake are skipped (a frame would corrupt it). Dropped sockets are
  // left for poll(), which reports them through onWSClose with their original close reason.
  void sendMsg(const std::string& msg) {
    for (uint32_t i = 0; i < conns_cnt_; i++) {
      Connection& conn = *conns_[i];
      if (conn.open && conn.isConnected()) {
        conn.send(websocket::OPCODE_TEXT, (const uint8_t*)msg.data(), (uint32_t)msg.size());
      }
    }
  }

private:
  static uint32_t rol(uint32_t value, uint32_t bits) { return (value << bits) | (value >> (32 - bits)); }

  // SHA-1 of in[0, in_len) written to `out` as 28 base64 chars (not NUL-terminated); returns 28.
  // Be cautious that *in* will be modified and up to 64 bytes will be appended, so make sure in buffer is long enough
  static uint32_t sha1base64(uint8_t* in, uint64_t in_len, char* out) {
    uint32_t h0[5] = {0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0};
    uint64_t total_len = in_len;
    in[total_len++] = 0x80;
    int padding_size = (64 - (total_len + 8) % 64) % 64;
    while (padding_size--) in[total_len++] = 0;
    for (uint64_t i = 0; i < total_len; i += 4) {
      uint32_t& w = *(uint32_t*)(in + i);
      w = be32toh(w);
    }
    // message length in bits, as high and low 32-bit words
    *(uint32_t*)(in + total_len) = (uint32_t)(in_len >> 29);
    *(uint32_t*)(in + total_len + 4) = (uint32_t)(in_len << 3);
    for (uint8_t* in_end = in + total_len + 8; in < in_end; in += 64) {
      uint32_t* w = (uint32_t*)in;
      uint32_t h[5];
      memcpy(h, h0, sizeof(h));
      // a..e rotate through h[] via j instead of shuffling values every round
      for (uint32_t i = 0, j = 0; i < 80; i++, j += 4) {
        uint32_t &a = h[j % 5], &b = h[(j + 1) % 5], &c = h[(j + 2) % 5], &d = h[(j + 3) % 5], &e = h[(j + 4) % 5];
        if (i >= 16) w[i & 15] = rol(w[(i + 13) & 15] ^ w[(i + 8) & 15] ^ w[(i + 2) & 15] ^ w[i & 15], 1);
        if (i < 40) {
          if (i < 20)
            e += ((b & (c ^ d)) ^ d) + 0x5A827999;
          else
            e += (b ^ c ^ d) + 0x6ED9EBA1;
        }
        else {
          if (i < 60)
            e += (((b | c) & d) | (b & c)) + 0x8F1BBCDC;
          else
            e += (b ^ c ^ d) + 0xCA62C1D6;
        }
        e += w[i & 15] + rol(a, 5);
        b = rol(b, 30);
      }
      for (int i = 0; i < 5; i++) h0[i] += h[i];
    }
    const char* base64tb = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    uint32_t triples[7] = {h0[0] >> 8,
                           (h0[0] << 16) | (h0[1] >> 16),
                           (h0[1] << 8) | (h0[2] >> 24),
                           h0[2],
                           h0[3] >> 8,
                           (h0[3] << 16) | (h0[4] >> 16),
                           h0[4] << 8};
    for (uint32_t i = 0; i < 7; i++) {
      out[i * 4] = base64tb[(triples[i] >> 18) & 63];
      out[i * 4 + 1] = base64tb[(triples[i] >> 12) & 63];
      out[i * 4 + 2] = base64tb[(triples[i] >> 6) & 63];
      out[i * 4 + 3] = base64tb[triples[i] & 63];
    }
    out[27] = '=';
    return 28;
  }

  // True if the comma-separated header value [val, val + val_len) contains `token`
  // (case-insensitive, surrounding blanks ignored), e.g. "keep-alive, Upgrade" contains "upgrade".
  static bool hasHeaderToken(const char* val, uint32_t val_len, const char* token) {
    const size_t tok_len = strlen(token);
    const char* end = val + val_len;
    while (val < end) {
      while (val < end && (*val == ' ' || *val == '\t' || *val == ',')) val++;
      const char* item_end = val;
      while (item_end < end && *item_end != ',') item_end++;
      const char* trimmed_end = item_end;
      while (trimmed_end > val && (trimmed_end[-1] == ' ' || trimmed_end[-1] == '\t')) trimmed_end--;
      if ((size_t)(trimmed_end - val) == tok_len && !strncasecmp(val, token, tok_len)) return true;
      val = item_end;
    }
    return false;
  }

  // Parse the client's HTTP upgrade request (re-parsed from the start on every call). Returns
  // `size` while the header block is incomplete. On a complete request it asks
  // handler->onWSConnect, replies 101 or 403 (closing the socket on 403), and returns the bytes
  // following the header block. A malformed request gets 400 and the socket is closed.
  uint32_t handleHttpRequest(EventHandler* handler, Connection& conn, const char* data, uint32_t size) {
    const char* data_end = data + size;
    const int ValueBufSize = 128;
    char request_uri[1024] = {0};
    char host[ValueBufSize] = {0};
    char origin[ValueBufSize] = {0};
    char wskey[ValueBufSize] = {0}; // 24-char key + 36-char GUID + sha1base64's padding fits in 128
    char wsprotocol[ValueBufSize] = {0};
    char wsextensions[ValueBufSize] = {0};
    bool upgrade_checked = false, connection_checked = false, wsversion_checked = false;
    while (true) {
      const char* ln = (const char*)memchr(data, '\n', data_end - data);
      if (!ln) return size;
      // every line must end with "\r\n"; afterwards ln points at the '\r'
      if (ln == data || *--ln != '\r') break;
      if (request_uri[0] == 0) { // first line: "GET <uri> HTTP/1.1"
        if (ln - data < 4 || memcmp(data, "GET ", 4)) break;
        data += 4;
        while (data < ln && *data == ' ') data++;
        const char* uri_end = (const char*)memchr(data, ' ', ln - data);
        if (!uri_end) break;
        uint32_t uri_len = uri_end - data;
        if (uri_len >= sizeof(request_uri)) break;
        memcpy(request_uri, data, uri_len);
        request_uri[uri_len] = 0;
      }
      else {
        const char* val_end = ln;
        while (val_end > data && val_end[-1] == ' ') val_end--;
        if (val_end == data) { // end of headers
          if (!host[0] || !wskey[0] || !upgrade_checked || !connection_checked || !wsversion_checked) break;
          char resp_wsprotocol[ValueBufSize] = {0};
          char resp_wsextensions[ValueBufSize] = {0};
          char resp[1024];
          uint32_t resp_len = 0;
          bool accept = handler->onWSConnect(
            conn, request_uri, host, origin[0] ? origin : nullptr, wsprotocol[0] ? wsprotocol : nullptr,
            wsextensions[0] ? wsextensions : nullptr, resp_wsprotocol, ValueBufSize, resp_wsextensions, ValueBufSize);
          // guard against the handler filling the buffers without a terminator
          resp_wsprotocol[ValueBufSize - 1] = 0;
          resp_wsextensions[ValueBufSize - 1] = 0;
          if (accept) {
            conn.open = true;
            memcpy(wskey + 24, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11", 36);
            char accept_str[32];
            accept_str[sha1base64((uint8_t*)wskey, 24 + 36, accept_str)] = 0;
            resp_len = sprintf(resp,
                               "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: "
                               "Upgrade\r\nSec-WebSocket-Accept: %s\r\n",
                               accept_str);
          }
          else {
            resp_len = sprintf(resp, "HTTP/1.1 403 Forbidden\r\nSec-WebSocket-Version: 13\r\n");
          }
          if (resp_wsprotocol[0])
            resp_len += sprintf(resp + resp_len, "Sec-WebSocket-Protocol: %s\r\n", resp_wsprotocol);
          if (resp_wsextensions[0])
            resp_len += sprintf(resp + resp_len, "Sec-WebSocket-Extensions: %s\r\n", resp_wsextensions);
          resp_len += sprintf(resp + resp_len, "\r\n");
          conn.conn.write((uint8_t*)resp, resp_len);
          // a rejected socket would otherwise occupy a connection slot until it times out
          if (!accept) conn.conn.close("ws connect rejected");
          return data_end - ln - 2; // bytes after the "\r\n" that ends the header block
        }
        const char* colon = (const char*)memchr(data, ':', ln - data);
        if (!colon) break;
        const char* val = colon + 1;
        while (val < val_end && *val == ' ') val++;
        uint32_t key_len = colon - data;
        uint32_t val_len = val_end - val;
        if (val_len < ValueBufSize) {
          // header names are case-insensitive (RFC 7230 section 3.2)
          if (key_len == 4 && !strncasecmp(data, "Host", 4)) {
            memcpy(host, val, val_len);
            host[val_len] = 0;
          }
          else if (key_len == 6 && !strncasecmp(data, "Origin", 6)) {
            memcpy(origin, val, val_len);
            origin[val_len] = 0;
          }
          else if (key_len == 7 && !strncasecmp(data, "Upgrade", 7)) {
            if (!hasHeaderToken(val, val_len, "websocket")) break;
            upgrade_checked = true;
          }
          else if (key_len == 10 && !strncasecmp(data, "Connection", 10)) {
            // may be a token list, e.g. Firefox sends "keep-alive, Upgrade"
            if (hasHeaderToken(val, val_len, "Upgrade")) connection_checked = true;
          }
          else if (key_len == 17 && !strncasecmp(data, "Sec-WebSocket-Key", 17)) {
            if (val_len != 24) break;
            memcpy(wskey, val, val_len);
          }
          else if (key_len == 21 && !strncasecmp(data, "Sec-WebSocket-Version", 21)) {
            if (val_len != 2 || memcmp(val, "13", 2)) break;
            wsversion_checked = true;
          }
          else if (key_len == 22 && !strncasecmp(data, "Sec-WebSocket-Protocol", 22)) {
            memcpy(wsprotocol, val, val_len);
            wsprotocol[val_len] = 0;
          }
          else if (key_len == 24 && !strncasecmp(data, "Sec-WebSocket-Extensions", 24)) {
            memcpy(wsextensions, val, val_len);
            wsextensions[val_len] = 0;
          }
        }
      }
      data = ln + 2; // skip \r\n
    }
    const char* resp400 = "HTTP/1.1 400 Bad Request\r\nSec-WebSocket-Version: 13\r\n\r\n";
    conn.conn.write((uint8_t*)resp400, strlen(resp400));
    conn.conn.close("bad request");
    return size;
  }

private:
  uint64_t newconn_timeout_ = 0;  // ns; 0 = no limit
  uint64_t openconn_timeout_ = 0; // ns; 0 = no limit
  TcpServer server_;
  uint32_t conns_cnt_ = 0;        // conns_[0, conns_cnt_) are in use
  Connection* conns_[MaxConns];
  Connection conns_data_[MaxConns];
};
} // namespace websocket