Browse Source

Add two-way handshake between ZeroTier peers

pull/3853/head
staphen 4 years ago committed by Anders Jenbo
parent
commit
dcfd057200
  1. 77
      Source/dvlnet/base.cpp
  2. 7
      Source/dvlnet/base.h
  3. 162
      Source/dvlnet/base_protocol.h
  4. 11
      Source/dvlnet/packet.cpp
  5. 33
      Source/dvlnet/packet.h
  6. 5
      Source/dvlnet/protocol_zt.cpp
  7. 1
      Source/dvlnet/protocol_zt.h
  8. 2
      Source/dvlnet/tcp_client.h
  9. 20
      Source/dvlnet/tcp_server.cpp
  10. 1
      Source/dvlnet/tcp_server.h

77
Source/dvlnet/base.cpp

@ -34,6 +34,18 @@ void base::DisconnectNet(plr_t plr)
{
}
void base::SendEchoRequest(plr_t player)
{
if (plr_self == PLR_BROADCAST)
return;
if (player == plr_self)
return;
timestamp_t now = SDL_GetTicks();
auto echo = pktfty->make_packet<PT_ECHO_REQUEST>(plr_self, player, now);
send(*echo);
}
void base::HandleAccept(packet &pkt)
{
if (plr_self != PLR_BROADCAST) {
@ -68,8 +80,6 @@ void base::HandleTurn(packet &pkt)
{
plr_t src = pkt.Source();
PlayerState &playerState = playerStateTable_[src];
playerState.waitForTurns = true;
std::deque<turn_t> &turnQueue = playerState.turnQueue;
const turn_t &turn = pkt.Turn();
turnQueue.push_back(turn);
@ -92,7 +102,6 @@ void base::HandleDisconnect(packet &pkt)
ClearMsg(newPlayer);
PlayerState &playerState = playerStateTable_[newPlayer];
playerState.isConnected = false;
playerState.waitForTurns = false;
playerState.turnQueue.clear();
}
} else {
@ -100,6 +109,20 @@ void base::HandleDisconnect(packet &pkt)
}
}
void base::HandleEchoRequest(packet &pkt)
{
auto reply = pktfty->make_packet<PT_ECHO_REPLY>(plr_self, pkt.Source(), pkt.Time());
send(*reply);
}
void base::HandleEchoReply(packet &pkt)
{
uint32_t now = SDL_GetTicks();
plr_t src = pkt.Source();
PlayerState &playerState = playerStateTable_[src];
playerState.roundTripLatency = now - pkt.Time();
}
void base::ClearMsg(plr_t plr)
{
message_queue.erase(std::remove_if(message_queue.begin(),
@ -113,7 +136,11 @@ void base::ClearMsg(plr_t plr)
void base::Connect(plr_t player)
{
PlayerState &playerState = playerStateTable_[player];
bool wasConnected = playerState.isConnected;
playerState.isConnected = true;
if (!wasConnected)
SendFirstTurnIfReady(player);
}
bool base::IsConnected(plr_t player) const
@ -143,6 +170,12 @@ void base::RecvLocal(packet &pkt)
case PT_DISCONNECT:
HandleDisconnect(pkt);
break;
case PT_ECHO_REQUEST:
HandleEchoRequest(pkt);
break;
case PT_ECHO_REPLY:
HandleEchoReply(pkt);
break;
default:
break;
// otherwise drop
@ -187,7 +220,7 @@ bool base::AllTurnsArrived()
{
for (auto i = 0; i < MAX_PLRS; ++i) {
PlayerState &playerState = playerStateTable_[i];
if (!playerState.waitForTurns)
if (!playerState.isConnected)
continue;
std::deque<turn_t> &turnQueue = playerState.turnQueue;
@ -208,7 +241,7 @@ bool base::SNetReceiveTurns(char **data, size_t *size, uint32_t *status)
status[i] = 0;
PlayerState &playerState = playerStateTable_[i];
if (!playerState.waitForTurns)
if (!playerState.isConnected)
continue;
status[i] |= PS_CONNECTED;
@ -226,7 +259,7 @@ bool base::SNetReceiveTurns(char **data, size_t *size, uint32_t *status)
if (AllTurnsArrived()) {
for (auto i = 0; i < MAX_PLRS; ++i) {
PlayerState &playerState = playerStateTable_[i];
if (!playerState.waitForTurns)
if (!playerState.isConnected)
continue;
std::deque<turn_t> &turnQueue = playerState.turnQueue;
@ -253,7 +286,7 @@ bool base::SNetReceiveTurns(char **data, size_t *size, uint32_t *status)
for (auto i = 0; i < MAX_PLRS; ++i) {
PlayerState &playerState = playerStateTable_[i];
if (!playerState.waitForTurns)
if (!playerState.isConnected)
continue;
std::deque<turn_t> &turnQueue = playerState.turnQueue;
@ -284,27 +317,39 @@ bool base::SNetSendTurn(char *data, unsigned int size)
void base::SendTurnIfReady(turn_t turn)
{
PlayerState &playerState = playerStateTable_[plr_self];
bool &ready = playerState.waitForTurns;
if (!ready)
ready = IsGameHost();
if (awaitingSequenceNumber_)
awaitingSequenceNumber_ = !IsGameHost();
if (ready) {
if (!awaitingSequenceNumber_) {
auto pkt = pktfty->make_packet<PT_TURN>(plr_self, PLR_BROADCAST, turn);
send(*pkt);
}
}
void base::MakeReady(seq_t sequenceNumber)
void base::SendFirstTurnIfReady(plr_t player)
{
if (awaitingSequenceNumber_)
return;
PlayerState &playerState = playerStateTable_[plr_self];
if (playerState.waitForTurns)
std::deque<turn_t> &turnQueue = playerState.turnQueue;
if (turnQueue.empty())
return;
turn_t turn = turnQueue.back();
auto pkt = pktfty->make_packet<PT_TURN>(plr_self, player, turn);
send(*pkt);
}
void base::MakeReady(seq_t sequenceNumber)
{
if (!awaitingSequenceNumber_)
return;
next_turn = sequenceNumber;
playerState.waitForTurns = true;
awaitingSequenceNumber_ = false;
PlayerState &playerState = playerStateTable_[plr_self];
std::deque<turn_t> &turnQueue = playerState.turnQueue;
if (!turnQueue.empty()) {
turn_t &turn = turnQueue.front();

7
Source/dvlnet/base.h

@ -64,9 +64,9 @@ protected:
struct PlayerState {
bool isConnected = {};
bool waitForTurns = {};
std::deque<turn_t> turnQueue;
int32_t lastTurnValue = {};
uint32_t roundTripLatency = {};
};
seq_t next_turn = 0;
@ -81,23 +81,28 @@ protected:
void Connect(plr_t player);
void RecvLocal(packet &pkt);
void RunEventHandler(_SNETEVENT &ev);
void SendEchoRequest(plr_t player);
[[nodiscard]] bool IsConnected(plr_t player) const;
virtual bool IsGameHost() = 0;
private:
std::array<PlayerState, MAX_PLRS> playerStateTable_;
bool awaitingSequenceNumber_ = true;
plr_t GetOwner();
bool AllTurnsArrived();
void MakeReady(seq_t sequenceNumber);
void SendTurnIfReady(turn_t turn);
void SendFirstTurnIfReady(plr_t player);
void ClearMsg(plr_t plr);
void HandleAccept(packet &pkt);
void HandleConnect(packet &pkt);
void HandleTurn(packet &pkt);
void HandleDisconnect(packet &pkt);
void HandleEchoRequest(packet &pkt);
void HandleEchoReply(packet &pkt);
};
} // namespace net

162
Source/dvlnet/base_protocol.h

@ -32,24 +32,32 @@ public:
virtual ~base_protocol() = default;
protected:
virtual bool IsGameHost();
bool IsGameHost() override;
private:
P proto;
typedef typename P::endpoint endpoint;
typedef typename P::endpoint endpoint_t;
endpoint firstpeer;
struct Peer {
endpoint_t endpoint;
std::unique_ptr<std::deque<packet>> sendQueue;
};
endpoint_t firstpeer;
std::string gamename;
std::map<std::string, std::tuple<GameData, std::vector<std::string>, endpoint>> game_list;
std::array<endpoint, MAX_PLRS> peers;
std::map<std::string, std::tuple<GameData, std::vector<std::string>, endpoint_t>> game_list;
std::array<Peer, MAX_PLRS> peers;
bool isGameHost_;
plr_t get_master();
void InitiateHandshake(plr_t player);
void SendTo(plr_t player, packet &pkt);
void DrainSendQueue(plr_t player);
void recv();
void handle_join_request(packet &pkt, endpoint sender);
void recv_decrypted(packet &pkt, endpoint sender);
void recv_ingame(packet &pkt, endpoint sender);
bool is_recognized(endpoint sender);
void handle_join_request(packet &pkt, endpoint_t sender);
void recv_decrypted(packet &pkt, endpoint_t sender);
void recv_ingame(packet &pkt, endpoint_t sender);
bool is_recognized(endpoint_t sender);
bool wait_network();
bool wait_firstpeer();
@ -61,7 +69,7 @@ plr_t base_protocol<P>::get_master()
{
plr_t ret = plr_self;
for (plr_t i = 0; i < MAX_PLRS; ++i)
if (peers[i])
if (peers[i].endpoint)
ret = std::min(ret, i);
return ret;
}
@ -81,8 +89,9 @@ bool base_protocol<P>::wait_network()
template <class P>
void base_protocol<P>::DisconnectNet(plr_t plr)
{
proto.disconnect(peers[plr]);
peers[plr] = endpoint();
Peer &peer = peers[plr];
proto.disconnect(peer.endpoint);
peer = {};
}
template <class P>
@ -164,31 +173,58 @@ void base_protocol<P>::poll()
recv();
}
template <class P>
void base_protocol<P>::InitiateHandshake(plr_t player)
{
Peer &peer = peers[player];
// The first packet sent will initiate the TCP connection over the ZeroTier network.
// It will cause problems if both peers attempt to initiate the handshake simultaneously.
// If the connection is already open, it should be safe to initiate from either end.
// If not, only the player with the smaller player number should initiate the handshake.
if (plr_self < player || proto.is_peer_connected(peer.endpoint))
SendEchoRequest(player);
}
template <class P>
void base_protocol<P>::send(packet &pkt)
{
if (pkt.Destination() < MAX_PLRS) {
if (pkt.Destination() == MyPlayerId)
plr_t destination = pkt.Destination();
if (destination < MAX_PLRS) {
if (destination == MyPlayerId)
return;
if (peers[pkt.Destination()])
proto.send(peers[pkt.Destination()], pkt.Data());
} else if (pkt.Destination() == PLR_BROADCAST) {
for (auto &peer : peers)
if (peer)
proto.send(peer, pkt.Data());
} else if (pkt.Destination() == PLR_MASTER) {
SendTo(destination, pkt);
} else if (destination == PLR_BROADCAST) {
for (plr_t player = 0; player < MAX_PLRS; player++)
SendTo(player, pkt);
} else if (destination == PLR_MASTER) {
throw dvlnet_exception();
} else {
throw dvlnet_exception();
}
}
template <class P>
void base_protocol<P>::SendTo(plr_t player, packet &pkt)
{
Peer &peer = peers[player];
if (!peer.endpoint)
return;
// The handshake uses echo packets so clients know
// when they can safely drain their send queues
if (peer.sendQueue && !IsAnyOf(pkt.Type(), PT_ECHO_REQUEST, PT_ECHO_REPLY))
peer.sendQueue->push_back(pkt);
else
proto.send(peer.endpoint, pkt.Data());
}
template <class P>
void base_protocol<P>::recv()
{
try {
buffer_t pkt_buf;
endpoint sender;
endpoint_t sender;
while (proto.recv(sender, pkt_buf)) { // read until kernel buffer is empty?
try {
auto pkt = pktfty->make_packet(pkt_buf);
@ -201,7 +237,7 @@ void base_protocol<P>::recv()
}
while (proto.get_disconnected(sender)) {
for (plr_t i = 0; i < MAX_PLRS; ++i) {
if (peers[i] == sender) {
if (peers[i].endpoint == sender) {
DisconnectNet(i);
break;
}
@ -214,13 +250,15 @@ void base_protocol<P>::recv()
}
template <class P>
void base_protocol<P>::handle_join_request(packet &pkt, endpoint sender)
void base_protocol<P>::handle_join_request(packet &pkt, endpoint_t sender)
{
plr_t i;
for (i = 0; i < MAX_PLRS; ++i) {
if (i != plr_self && !peers[i]) {
Peer &peer = peers[i];
if (i != plr_self && !peer.endpoint) {
peer.endpoint = sender;
peer.sendQueue = std::make_unique<std::deque<packet>>();
Connect(i);
peers[i] = sender;
break;
}
}
@ -229,14 +267,9 @@ void base_protocol<P>::handle_join_request(packet &pkt, endpoint sender)
return;
}
auto reply = pktfty->make_packet<PT_JOIN_ACCEPT>(plr_self, PLR_BROADCAST,
pkt.Cookie(), i,
game_init_info);
proto.send(sender, reply->Data());
auto senderinfo = sender.serialize();
for (plr_t j = 0; j < MAX_PLRS; ++j) {
endpoint peer = peers[j];
endpoint_t peer = peers[j].endpoint;
if ((j != plr_self) && (j != i) && peer) {
auto peerpkt = pktfty->make_packet<PT_CONNECT>(PLR_MASTER, PLR_BROADCAST, i, senderinfo);
proto.send(peer, peerpkt->Data());
@ -245,10 +278,18 @@ void base_protocol<P>::handle_join_request(packet &pkt, endpoint sender)
proto.send(sender, infopkt->Data());
}
}
// PT_JOIN_ACCEPT must be sent after all PT_CONNECT packets so the new player does
// not resume game logic until after having been notified of all existing players
auto reply = pktfty->make_packet<PT_JOIN_ACCEPT>(plr_self, PLR_BROADCAST,
pkt.Cookie(), i,
game_init_info);
proto.send(sender, reply->Data());
DrainSendQueue(i);
}
template <class P>
void base_protocol<P>::recv_decrypted(packet &pkt, endpoint sender)
void base_protocol<P>::recv_decrypted(packet &pkt, endpoint_t sender)
{
if (pkt.Source() == PLR_BROADCAST && pkt.Destination() == PLR_MASTER && pkt.Type() == PT_INFO_REPLY) {
constexpr size_t sizePlayerName = (sizeof(char) * PLR_NAME_LEN);
@ -275,7 +316,7 @@ void base_protocol<P>::recv_decrypted(packet &pkt, endpoint sender)
}
template <class P>
void base_protocol<P>::recv_ingame(packet &pkt, endpoint sender)
void base_protocol<P>::recv_ingame(packet &pkt, endpoint_t sender)
{
if (pkt.Source() == PLR_BROADCAST && pkt.Destination() == PLR_MASTER) {
if (pkt.Type() == PT_JOIN_REQUEST) {
@ -308,28 +349,65 @@ void base_protocol<P>::recv_ingame(packet &pkt, endpoint sender)
}
// addrinfo packets
Connect(pkt.NewPlayer());
peers[pkt.NewPlayer()].unserialize(pkt.Info());
plr_t newPlayer = pkt.NewPlayer();
Peer &peer = peers[newPlayer];
peer.endpoint.unserialize(pkt.Info());
peer.sendQueue = std::make_unique<std::deque<packet>>();
Connect(newPlayer);
if (plr_self != PLR_BROADCAST)
InitiateHandshake(newPlayer);
return;
} else if (pkt.Source() >= MAX_PLRS) {
// normal packets
LogDebug("Invalid packet: packet source ({}) >= MAX_PLRS", pkt.Source());
return;
} else if (sender == firstpeer && pkt.Type() == PT_JOIN_ACCEPT) {
Connect(pkt.Source());
peers[pkt.Source()] = sender;
firstpeer = endpoint();
} else if (sender != peers[pkt.Source()]) {
plr_t src = pkt.Source();
peers[src].endpoint = sender;
Connect(src);
firstpeer = {};
} else if (sender != peers[pkt.Source()].endpoint) {
LogDebug("Invalid packet: packet source ({}) received from unrecognized endpoint", pkt.Source());
return;
}
if (pkt.Destination() != plr_self && pkt.Destination() != PLR_BROADCAST)
return; // packet not for us, drop
bool wasBroadcast = plr_self == PLR_BROADCAST;
RecvLocal(pkt);
if (plr_self != PLR_BROADCAST) {
if (wasBroadcast) {
// Send a handshake to everyone just after PT_JOIN_ACCEPT
for (plr_t player = 0; player < MAX_PLRS; player++)
InitiateHandshake(player);
}
DrainSendQueue(pkt.Source());
}
}
template <class P>
void base_protocol<P>::DrainSendQueue(plr_t player)
{
Peer &srcPeer = peers[player];
if (!srcPeer.sendQueue)
return;
std::deque<packet> &sendQueue = *srcPeer.sendQueue;
while (!sendQueue.empty()) {
packet &pkt = sendQueue.front();
proto.send(srcPeer.endpoint, pkt.Data());
sendQueue.pop_front();
}
srcPeer.sendQueue = nullptr;
}
template <class P>
bool base_protocol<P>::is_recognized(endpoint sender)
bool base_protocol<P>::is_recognized(endpoint_t sender)
{
if (!sender)
return false;
@ -338,7 +416,7 @@ bool base_protocol<P>::is_recognized(endpoint sender)
return true;
for (auto player = 0; player <= MAX_PLRS; player++) {
if (sender == peers[player])
if (sender == peers[player].endpoint)
return true;
}

11
Source/dvlnet/packet.cpp

@ -70,6 +70,10 @@ const char *packet_type_to_string(uint8_t packetType)
return "PT_INFO_REQUEST";
case PT_INFO_REPLY:
return "PT_INFO_REPLY";
case PT_ECHO_REQUEST:
return "PT_ECHO_REQUEST";
case PT_ECHO_REPLY:
return "PT_ECHO_REQUEST";
default:
return nullptr;
}
@ -162,6 +166,13 @@ plr_t packet::NewPlayer()
return m_newplr;
}
timestamp_t packet::Time()
{
assert(have_decrypted);
CheckPacketTypeOneOf({ PT_ECHO_REQUEST, PT_ECHO_REPLY }, m_type);
return m_time;
}
const buffer_t &packet::Info()
{
assert(have_decrypted);

33
Source/dvlnet/packet.h

@ -26,6 +26,8 @@ enum packet_type : uint8_t {
PT_DISCONNECT = 0x14,
PT_INFO_REQUEST = 0x21,
PT_INFO_REPLY = 0x22,
PT_ECHO_REQUEST = 0x31,
PT_ECHO_REPLY = 0x32,
// clang-format on
};
@ -35,6 +37,7 @@ const char *packet_type_to_string(uint8_t packetType);
typedef uint8_t plr_t;
typedef uint8_t seq_t;
typedef uint32_t cookie_t;
typedef uint32_t timestamp_t;
typedef int leaveinfo_t; // also change later
#ifdef PACKET_ENCRYPTION
typedef std::array<unsigned char, crypto_secretbox_KEYBYTES> key_t;
@ -81,6 +84,7 @@ protected:
turn_t m_turn;
cookie_t m_cookie;
plr_t m_newplr;
timestamp_t m_time;
buffer_t m_info;
leaveinfo_t m_leaveinfo;
@ -103,6 +107,7 @@ public:
turn_t Turn();
cookie_t Cookie();
plr_t NewPlayer();
timestamp_t Time();
const buffer_t &Info();
leaveinfo_t LeaveInfo();
};
@ -179,6 +184,10 @@ void packet_proc<P>::process_data()
break;
case PT_INFO_REQUEST:
break;
case PT_ECHO_REQUEST:
case PT_ECHO_REPLY:
self.process_element(m_time);
break;
}
}
@ -313,6 +322,30 @@ inline void packet_out::create<PT_DISCONNECT>(plr_t s, plr_t d, plr_t n,
m_leaveinfo = l;
}
template <>
inline void packet_out::create<PT_ECHO_REQUEST>(plr_t s, plr_t d, timestamp_t t)
{
if (have_encrypted || have_decrypted)
ABORT();
have_decrypted = true;
m_type = PT_ECHO_REQUEST;
m_src = s;
m_dest = d;
m_time = t;
}
template <>
inline void packet_out::create<PT_ECHO_REPLY>(plr_t s, plr_t d, timestamp_t t)
{
if (have_encrypted || have_decrypted)
ABORT();
have_decrypted = true;
m_type = PT_ECHO_REPLY;
m_src = s;
m_dest = d;
m_time = t;
}
inline void packet_out::process_element(buffer_t &x)
{
decrypted_buffer.insert(decrypted_buffer.end(), x.begin(), x.end());

5
Source/dvlnet/protocol_zt.cpp

@ -307,6 +307,11 @@ uint64_t protocol_zt::current_ms()
return 0;
}
bool protocol_zt::is_peer_connected(endpoint &peer)
{
return peer_list.count(peer) != 0 && peer_list[peer].fd != -1;
}
std::string protocol_zt::make_default_gamename()
{
std::string ret;

1
Source/dvlnet/protocol_zt.h

@ -73,6 +73,7 @@ public:
bool recv(endpoint &peer, buffer_t &data);
bool get_disconnected(endpoint &peer);
bool network_online();
bool is_peer_connected(endpoint &peer);
static std::string make_default_gamename();
private:

2
Source/dvlnet/tcp_client.h

@ -31,7 +31,7 @@ public:
virtual std::string make_default_gamename();
protected:
virtual bool IsGameHost();
bool IsGameHost() override;
private:
frame_queue recv_queue;

20
Source/dvlnet/tcp_server.cpp

@ -99,20 +99,25 @@ void tcp_server::HandleReceive(const scc &con, const asio::error_code &ec,
StartReceive(con);
}
void tcp_server::SendConnect(const scc &con)
{
auto pkt = pktfty.make_packet<PT_CONNECT>(PLR_MASTER, PLR_BROADCAST,
con->plr);
SendPacket(*pkt);
}
void tcp_server::HandleReceiveNewPlayer(const scc &con, packet &pkt)
{
auto newplr = NextFree();
if (newplr == PLR_BROADCAST)
throw server_exception();
if (Empty())
game_init_info = pkt.Info();
for (plr_t player = 0; player < MAX_PLRS; player++) {
if (connections[player]) {
auto playerPacket = pktfty.make_packet<PT_CONNECT>(PLR_MASTER, PLR_BROADCAST, newplr);
StartSend(connections[player], *playerPacket);
auto newplrPacket = pktfty.make_packet<PT_CONNECT>(PLR_MASTER, PLR_BROADCAST, player);
StartSend(con, *newplrPacket);
}
}
auto reply = pktfty.make_packet<PT_JOIN_ACCEPT>(PLR_MASTER, PLR_BROADCAST,
pkt.Cookie(), newplr,
game_init_info);
@ -120,7 +125,6 @@ void tcp_server::HandleReceiveNewPlayer(const scc &con, packet &pkt)
con->plr = newplr;
connections[newplr] = con;
con->timeout = timeout_active;
SendConnect(con);
}
void tcp_server::HandleReceivePacket(packet &pkt)

1
Source/dvlnet/tcp_server.h

@ -68,7 +68,6 @@ private:
void HandleReceive(const scc &con, const asio::error_code &ec, size_t bytesRead);
void HandleReceiveNewPlayer(const scc &con, packet &pkt);
void HandleReceivePacket(packet &pkt);
void SendConnect(const scc &con);
void SendPacket(packet &pkt);
void StartSend(const scc &con, packet &pkt);
void HandleSend(const scc &con, const asio::error_code &ec, size_t bytesSent);

Loading…
Cancel
Save