diff --git a/ext/crow/TinySHA1.hpp b/ext/crow/TinySHA1.hpp new file mode 100755 index 0000000..70af046 --- /dev/null +++ b/ext/crow/TinySHA1.hpp @@ -0,0 +1,196 @@ +/* + * + * TinySHA1 - a header only implementation of the SHA1 algorithm in C++. Based + * on the implementation in boost::uuid::details. + * + * SHA1 Wikipedia Page: http://en.wikipedia.org/wiki/SHA-1 + * + * Copyright (c) 2012-22 SAURAV MOHAPATRA + * + * Permission to use, copy, modify, and distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ +#ifndef _TINY_SHA1_HPP_ +#define _TINY_SHA1_HPP_ +#include +#include +#include +#include +namespace sha1 +{ + class SHA1 + { + public: + typedef uint32_t digest32_t[5]; + typedef uint8_t digest8_t[20]; + inline static uint32_t LeftRotate(uint32_t value, size_t count) { + return (value << count) ^ (value >> (32-count)); + } + SHA1(){ reset(); } + virtual ~SHA1() {} + SHA1(const SHA1& s) { *this = s; } + const SHA1& operator = (const SHA1& s) { + memcpy(m_digest, s.m_digest, 5 * sizeof(uint32_t)); + memcpy(m_block, s.m_block, 64); + m_blockByteIndex = s.m_blockByteIndex; + m_byteCount = s.m_byteCount; + return *this; + } + SHA1& reset() { + m_digest[0] = 0x67452301; + m_digest[1] = 0xEFCDAB89; + m_digest[2] = 0x98BADCFE; + m_digest[3] = 0x10325476; + m_digest[4] = 0xC3D2E1F0; + m_blockByteIndex = 0; + m_byteCount = 0; + return *this; + } + SHA1& processByte(uint8_t octet) { + this->m_block[this->m_blockByteIndex++] = octet; + ++this->m_byteCount; + if(m_blockByteIndex == 64) { + this->m_blockByteIndex = 0; + processBlock(); + } + return *this; + } + SHA1& processBlock(const void* const start, const void* const end) { + const uint8_t* begin = static_cast(start); + const uint8_t* finish = static_cast(end); + while(begin != finish) { + processByte(*begin); + begin++; + } + return *this; + } + SHA1& processBytes(const void* const data, size_t len) { + const uint8_t* block = static_cast(data); + processBlock(block, block + len); + return *this; + } + const uint32_t* getDigest(digest32_t digest) { + size_t bitCount = this->m_byteCount * 8; + processByte(0x80); + if (this->m_blockByteIndex > 56) { + while (m_blockByteIndex != 0) { + processByte(0); + } + while (m_blockByteIndex < 56) { + processByte(0); + } + } else { + while (m_blockByteIndex < 56) { + processByte(0); + } + } + processByte(0); + processByte(0); + processByte(0); + processByte(0); + processByte( static_cast((bitCount>>24) & 0xFF)); + processByte( static_cast((bitCount>>16) & 0xFF)); + processByte( static_cast((bitCount>>8 ) & 0xFF)); + processByte( static_cast((bitCount) & 0xFF)); + + memcpy(digest, m_digest, 5 * sizeof(uint32_t)); + return digest; + } + const uint8_t* getDigestBytes(digest8_t digest) { + digest32_t d32; + getDigest(d32); + size_t di = 0; + digest[di++] = ((d32[0] >> 24) & 0xFF); + digest[di++] = ((d32[0] >> 16) & 0xFF); + digest[di++] = ((d32[0] >> 8) & 0xFF); + digest[di++] = ((d32[0]) & 0xFF); + + digest[di++] = ((d32[1] >> 24) & 0xFF); + digest[di++] = ((d32[1] >> 16) & 0xFF); + digest[di++] = ((d32[1] >> 8) & 0xFF); + digest[di++] = ((d32[1]) & 0xFF); + + digest[di++] = ((d32[2] >> 24) & 0xFF); + digest[di++] = ((d32[2] >> 16) & 0xFF); + digest[di++] = ((d32[2] >> 8) & 0xFF); + digest[di++] = ((d32[2]) & 0xFF); + + digest[di++] = ((d32[3] >> 24) & 0xFF); + digest[di++] = ((d32[3] >> 16) & 0xFF); + digest[di++] = ((d32[3] >> 8) & 0xFF); + digest[di++] = ((d32[3]) & 0xFF); + + digest[di++] = ((d32[4] >> 24) & 0xFF); + digest[di++] = ((d32[4] >> 16) & 0xFF); + digest[di++] = ((d32[4] >> 8) & 0xFF); + digest[di++] = ((d32[4]) & 0xFF); + return digest; + } + + protected: + void processBlock() { + uint32_t w[80]; + for (size_t i = 0; i < 16; i++) { + w[i] = (m_block[i*4 + 0] << 24); + w[i] |= (m_block[i*4 + 1] << 16); + w[i] |= (m_block[i*4 + 2] << 8); + w[i] |= (m_block[i*4 + 3]); + } + for (size_t i = 16; i < 80; i++) { + w[i] = LeftRotate((w[i-3] ^ w[i-8] ^ w[i-14] ^ w[i-16]), 1); + } + + uint32_t a = m_digest[0]; + uint32_t b = m_digest[1]; + uint32_t c = m_digest[2]; + uint32_t d = m_digest[3]; + uint32_t e = m_digest[4]; + + for (std::size_t i=0; i<80; ++i) { + uint32_t f = 0; + uint32_t k = 0; + + if (i<20) { + f = (b & c) | (~b & d); + k = 0x5A827999; + } else if (i<40) { + f = b ^ c ^ d; + k = 0x6ED9EBA1; + } else if (i<60) { + f = (b & c) | (b & d) | (c & d); + k = 0x8F1BBCDC; + } else { + f = b ^ c ^ d; + k = 0xCA62C1D6; + } + uint32_t temp = LeftRotate(a, 5) + f + e + k + w[i]; + e = d; + d = c; + c = LeftRotate(b, 30); + b = a; + a = temp; + } + + m_digest[0] += a; + m_digest[1] += b; + m_digest[2] += c; + m_digest[3] += d; + m_digest[4] += e; + } + private: + digest32_t m_digest; + uint8_t m_block[64]; + size_t m_blockByteIndex; + size_t m_byteCount; + }; +} +#endif diff --git a/ext/crow/websocket.h b/ext/crow/websocket.h new file mode 100755 index 0000000..5299c1a --- /dev/null +++ b/ext/crow/websocket.h @@ -0,0 +1,482 @@ +#pragma once +#include "socket_adaptors.h" +#include "http_request.h" +#include "TinySHA1.hpp" + +namespace crow +{ + namespace websocket + { + enum class WebSocketReadState + { + MiniHeader, + Len16, + Len64, + Mask, + Payload, + }; + + struct connection + { + virtual void send_binary(const std::string& msg) = 0; + virtual void send_text(const std::string& msg) = 0; + virtual void close(const std::string& msg = "quit") = 0; + virtual ~connection(){} + }; + + template + class Connection : public connection + { + public: + Connection(const crow::request& req, Adaptor&& adaptor, + std::function open_handler, + std::function message_handler, + std::function close_handler, + std::function error_handler) + : adaptor_(std::move(adaptor)), open_handler_(std::move(open_handler)), message_handler_(std::move(message_handler)), close_handler_(std::move(close_handler)), error_handler_(std::move(error_handler)) + { + if (req.get_header_value("upgrade") != "websocket") + { + adaptor.close(); + delete this; + return; + } + // Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ== + // Sec-WebSocket-Version: 13 + std::string magic = req.get_header_value("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + sha1::SHA1 s; + s.processBytes(magic.data(), magic.size()); + uint8_t digest[20]; + s.getDigestBytes(digest); + start(crow::utility::base64encode((char*)digest, 20)); + } + + template + void dispatch(CompletionHandler handler) + { + adaptor_.get_io_service().dispatch(handler); + } + + template + void post(CompletionHandler handler) + { + adaptor_.get_io_service().post(handler); + } + + void send_pong(const std::string& msg) + { + dispatch([this, msg]{ + char buf[3] = "\x8A\x00"; + buf[1] += msg.size(); + write_buffers_.emplace_back(buf, buf+2); + write_buffers_.emplace_back(msg); + do_write(); + }); + } + + void send_binary(const std::string& msg) override + { + dispatch([this, msg]{ + auto header = build_header(2, msg.size()); + write_buffers_.emplace_back(std::move(header)); + write_buffers_.emplace_back(msg); + do_write(); + }); + } + + void send_text(const std::string& msg) override + { + dispatch([this, msg]{ + auto header = build_header(1, msg.size()); + write_buffers_.emplace_back(std::move(header)); + write_buffers_.emplace_back(msg); + do_write(); + }); + } + + void close(const std::string& msg) override + { + dispatch([this, msg]{ + has_sent_close_ = true; + if (has_recv_close_ && !is_close_handler_called_) + { + is_close_handler_called_ = true; + if (close_handler_) + close_handler_(*this, msg); + } + auto header = build_header(0x8, msg.size()); + write_buffers_.emplace_back(std::move(header)); + write_buffers_.emplace_back(msg); + do_write(); + }); + } + + protected: + + std::string build_header(int opcode, size_t size) + { + char buf[2+8] = "\x80\x00"; + buf[0] += opcode; + if (size < 126) + { + buf[1] += size; + return {buf, buf+2}; + } + else if (size < 0x10000) + { + buf[1] += 126; + *(uint16_t*)(buf+2) = (uint16_t)size; + return {buf, buf+4}; + } + else + { + buf[1] += 127; + *(uint64_t*)(buf+2) = (uint64_t)size; + return {buf, buf+10}; + } + } + + void start(std::string&& hello) + { + static std::string header = "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: "; + static std::string crlf = "\r\n"; + write_buffers_.emplace_back(header); + write_buffers_.emplace_back(std::move(hello)); + write_buffers_.emplace_back(crlf); + write_buffers_.emplace_back(crlf); + do_write(); + if (open_handler_) + open_handler_(*this); + do_read(); + } + + void do_read() + { + is_reading = true; + switch(state_) + { + case WebSocketReadState::MiniHeader: + { + //boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&mini_header_, 1), + adaptor_.socket().async_read_some(boost::asio::buffer(&mini_header_, 2), + [this](const boost::system::error_code& ec, std::size_t bytes_transferred) + { + is_reading = false; + mini_header_ = htons(mini_header_); +#ifdef CROW_ENABLE_DEBUG + + if (!ec && bytes_transferred != 2) + { + throw std::runtime_error("WebSocket:MiniHeader:async_read fail:asio bug?"); + } +#endif + + if (!ec && ((mini_header_ & 0x80) == 0x80)) + { + if ((mini_header_ & 0x7f) == 127) + { + state_ = WebSocketReadState::Len64; + } + else if ((mini_header_ & 0x7f) == 126) + { + state_ = WebSocketReadState::Len16; + } + else + { + remaining_length_ = mini_header_ & 0x7f; + state_ = WebSocketReadState::Mask; + } + do_read(); + } + else + { + close_connection_ = true; + adaptor_.close(); + if (error_handler_) + error_handler_(*this); + check_destroy(); + } + }); + } + break; + case WebSocketReadState::Len16: + { + remaining_length_ = 0; + boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&remaining_length_, 2), + [this](const boost::system::error_code& ec, std::size_t bytes_transferred) + { + is_reading = false; + remaining_length_ = ntohs(*(uint16_t*)&remaining_length_); +#ifdef CROW_ENABLE_DEBUG + if (!ec && bytes_transferred != 2) + { + throw std::runtime_error("WebSocket:Len16:async_read fail:asio bug?"); + } +#endif + + if (!ec) + { + state_ = WebSocketReadState::Mask; + do_read(); + } + else + { + close_connection_ = true; + adaptor_.close(); + if (error_handler_) + error_handler_(*this); + check_destroy(); + } + }); + } + break; + case WebSocketReadState::Len64: + { + boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&remaining_length_, 8), + [this](const boost::system::error_code& ec, std::size_t bytes_transferred) + { + is_reading = false; + remaining_length_ = ((1==ntohl(1)) ? (remaining_length_) : ((uint64_t)ntohl((remaining_length_) & 0xFFFFFFFF) << 32) | ntohl((remaining_length_) >> 32)); +#ifdef CROW_ENABLE_DEBUG + if (!ec && bytes_transferred != 8) + { + throw std::runtime_error("WebSocket:Len16:async_read fail:asio bug?"); + } +#endif + + if (!ec) + { + state_ = WebSocketReadState::Mask; + do_read(); + } + else + { + close_connection_ = true; + adaptor_.close(); + if (error_handler_) + error_handler_(*this); + check_destroy(); + } + }); + } + break; + case WebSocketReadState::Mask: + boost::asio::async_read(adaptor_.socket(), boost::asio::buffer((char*)&mask_, 4), + [this](const boost::system::error_code& ec, std::size_t bytes_transferred) + { + is_reading = false; +#ifdef CROW_ENABLE_DEBUG + if (!ec && bytes_transferred != 4) + { + throw std::runtime_error("WebSocket:Mask:async_read fail:asio bug?"); + } +#endif + + if (!ec) + { + state_ = WebSocketReadState::Payload; + do_read(); + } + else + { + close_connection_ = true; + if (error_handler_) + error_handler_(*this); + adaptor_.close(); + } + }); + break; + case WebSocketReadState::Payload: + { + size_t to_read = buffer_.size(); + if (remaining_length_ < to_read) + to_read = remaining_length_; + adaptor_.socket().async_read_some( boost::asio::buffer(buffer_, to_read), + [this](const boost::system::error_code& ec, std::size_t bytes_transferred) + { + is_reading = false; + + if (!ec) + { + fragment_.insert(fragment_.end(), buffer_.begin(), buffer_.begin() + bytes_transferred); + remaining_length_ -= bytes_transferred; + if (remaining_length_ == 0) + { + handle_fragment(); + state_ = WebSocketReadState::MiniHeader; + do_read(); + } + } + else + { + close_connection_ = true; + if (error_handler_) + error_handler_(*this); + adaptor_.close(); + } + }); + } + break; + } + } + + bool is_FIN() + { + return mini_header_ & 0x8000; + } + + int opcode() + { + return (mini_header_ & 0x0f00) >> 8; + } + + void handle_fragment() + { + for(decltype(fragment_.length()) i = 0; i < fragment_.length(); i ++) + { + fragment_[i] ^= ((char*)&mask_)[i%4]; + } + switch(opcode()) + { + case 0: // Continuation + { + message_ += fragment_; + if (is_FIN()) + { + if (message_handler_) + message_handler_(*this, message_, is_binary_); + message_.clear(); + } + } + case 1: // Text + { + is_binary_ = false; + message_ += fragment_; + if (is_FIN()) + { + if (message_handler_) + message_handler_(*this, message_, is_binary_); + message_.clear(); + } + } + break; + case 2: // Binary + { + is_binary_ = true; + message_ += fragment_; + if (is_FIN()) + { + if (message_handler_) + message_handler_(*this, message_, is_binary_); + message_.clear(); + } + } + break; + case 0x8: // Close + { + has_recv_close_ = true; + if (!has_sent_close_) + { + close(fragment_); + } + else + { + adaptor_.close(); + close_connection_ = true; + if (!is_close_handler_called_) + { + if (close_handler_) + close_handler_(*this, fragment_); + is_close_handler_called_ = true; + } + check_destroy(); + } + } + break; + case 0x9: // Ping + { + send_pong(fragment_); + } + break; + case 0xA: // Pong + { + pong_received_ = true; + } + break; + } + + fragment_.clear(); + } + + void do_write() + { + if (sending_buffers_.empty()) + { + sending_buffers_.swap(write_buffers_); + std::vector buffers; + buffers.reserve(sending_buffers_.size()); + for(auto& s:sending_buffers_) + { + buffers.emplace_back(boost::asio::buffer(s)); + } + boost::asio::async_write(adaptor_.socket(), buffers, + [&](const boost::system::error_code& ec, std::size_t /*bytes_transferred*/) + { + sending_buffers_.clear(); + if (!ec && !close_connection_) + { + if (!write_buffers_.empty()) + do_write(); + if (has_sent_close_) + close_connection_ = true; + } + else + { + close_connection_ = true; + check_destroy(); + } + }); + } + } + + void check_destroy() + { + //if (has_sent_close_ && has_recv_close_) + if (!is_close_handler_called_) + if (close_handler_) + close_handler_(*this, "uncleanly"); + if (sending_buffers_.empty() && !is_reading) + delete this; + } + private: + Adaptor adaptor_; + + std::vector sending_buffers_; + std::vector write_buffers_; + + boost::array buffer_; + bool is_binary_; + std::string message_; + std::string fragment_; + WebSocketReadState state_{WebSocketReadState::MiniHeader}; + uint64_t remaining_length_{0}; + bool close_connection_{false}; + bool is_reading{false}; + uint32_t mask_; + uint16_t mini_header_; + bool has_sent_close_{false}; + bool has_recv_close_{false}; + bool error_occured_{false}; + bool pong_received_{false}; + bool is_close_handler_called_{false}; + + std::function open_handler_; + std::function message_handler_; + std::function close_handler_; + std::function error_handler_; + }; + } +}