Skip to content

Commit 076d692

Browse files
committed
Adding a new VPN server
1 parent 4d76730 commit 076d692

24 files changed

+8645
-9
lines changed

app/src/main/java/com/pcapplusplus/toyvpn/ToyVpnService.kt

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import kotlinx.coroutines.Dispatchers
1818
import kotlinx.coroutines.SupervisorJob
1919
import kotlinx.coroutines.async
2020
import kotlinx.coroutines.launch
21+
import kotlinx.coroutines.runBlocking
2122
import java.io.FileInputStream
2223
import java.io.FileOutputStream
2324
import java.io.IOException
@@ -29,7 +30,9 @@ import java.util.concurrent.atomic.AtomicBoolean
2930

3031
class ToyVpnService : VpnService() {
3132
private var vpnInterface: ParcelFileDescriptor? = null
32-
private var vpnConnected: AtomicBoolean = AtomicBoolean(false)
33+
private lateinit var vpnTunnel: DatagramChannel
34+
private var vpnConnected = AtomicBoolean(false)
35+
private var isForwardingTraffic = AtomicBoolean(false)
3336
private val packetProcessor = PacketProcessor()
3437
private val packetDataList: MutableList<PacketData> = mutableListOf()
3538
private var lastPacketDataSentTimestamp: Long = 0
@@ -46,6 +49,7 @@ class ToyVpnService : VpnService() {
4649
const val MAX_HANDSHAKE_ATTEMPTS = 50
4750
const val MAX_PACKET_SIZE = 32767
4851
const val MAX_SECRET_LENGTH = 1024
52+
const val DISCONNECT_MESSAGE = "DISCONNECT"
4953
}
5054

5155
@SuppressLint("SyntheticAccessor")
@@ -71,7 +75,7 @@ class ToyVpnService : VpnService() {
7175
registerReceiver(
7276
broadcastReceiver,
7377
IntentFilter(BroadcastActions.VPN_SERVICE_STOP),
74-
Context.RECEIVER_EXPORTED,
78+
RECEIVER_EXPORTED,
7579
)
7680
} else {
7781
registerReceiver(broadcastReceiver, IntentFilter(BroadcastActions.VPN_SERVICE_STOP))
@@ -99,14 +103,17 @@ class ToyVpnService : VpnService() {
99103
}
100104
}
101105

102-
val vpnTunnel = establishVpnResult.await() ?: return@launch
106+
vpnTunnel = establishVpnResult.await() ?: return@launch
103107

104108
try {
109+
isForwardingTraffic.set(true)
105110
val inputSteam = FileInputStream(vpnInterface?.fileDescriptor)
106111
val outputStream = FileOutputStream(vpnInterface?.fileDescriptor)
107112
forwardTraffic(inputSteam, outputStream, vpnTunnel)
108113
} catch (e: Exception) {
109114
stopSelfOnError("Error while forwarding traffic to the VPN server", e)
115+
} finally {
116+
isForwardingTraffic.set(false)
110117
}
111118
}
112119
} catch (e: IllegalArgumentException) {
@@ -181,6 +188,12 @@ class ToyVpnService : VpnService() {
181188
processPacket(packet.array())
182189
outputStream.write(packet.array(), 0, length)
183190
}
191+
else if (length > 1) {
192+
val controlMessage = String(packet.array(), 1, length - 1, US_ASCII)
193+
if (controlMessage == DISCONNECT_MESSAGE) {
194+
stopSelfOnError("Server disconnected")
195+
}
196+
}
184197

185198
idle = false
186199
}
@@ -289,18 +302,39 @@ class ToyVpnService : VpnService() {
289302
private fun disconnectVpn() {
290303
try {
291304
vpnConnected.set(false)
292-
vpnInterface?.close()
293-
vpnInterface = null
305+
306+
runBlocking {
307+
vpnServiceScope.async {
308+
while (isForwardingTraffic.get()) {
309+
Thread.sleep(100)
310+
}
311+
312+
if (::vpnTunnel.isInitialized) {
313+
val packet = ByteBuffer.allocate(DISCONNECT_MESSAGE.length + 1)
314+
val disconnectAsByteArray = byteArrayOf(0) + DISCONNECT_MESSAGE.encodeToByteArray()
315+
packet.put(disconnectAsByteArray).flip()
316+
317+
if (vpnTunnel.write(packet) != disconnectAsByteArray.size) {
318+
Log.e(LOG_TAG, "Couldn't send disconnect message to the server")
319+
}
320+
vpnTunnel.close()
321+
vpnInterface?.close()
322+
vpnInterface = null
323+
}
324+
}.await()
325+
}
294326
} catch (e: IOException) {
295327
e.printStackTrace()
296328
}
297329
}
298330

299-
private fun stopSelfOnError(
331+
private fun stopSelfOnError(
300332
errorMessage: String,
301-
exception: java.lang.Exception,
333+
exception: java.lang.Exception? = null,
302334
) {
303-
Log.e(LOG_TAG, errorMessage, exception)
335+
if (exception != null) {
336+
Log.e(LOG_TAG, errorMessage, exception)
337+
}
304338

305339
val intent =
306340
Intent(BroadcastActions.VPN_SERVICE_ERROR).apply {

server/linux/.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
**/*.pcapng
2+
libs/pcapplusplus/*
3+
build/*

server/linux/CMakeLists.txt

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
cmake_minimum_required(VERSION 3.20)
2+
project(ToyVpnServer)
3+
4+
set(CMAKE_CXX_STANDARD 17)
5+
6+
set(PCAPPLUSPLUS_INCLUDE_DIR "${CMAKE_SOURCE_DIR}/libs/PcapPlusPlus/include")
7+
set(PCAPPLUSPLUS_LIB_DIR "${CMAKE_SOURCE_DIR}/libs/PcapPlusPlus/lib")
8+
9+
# Create the executable target first
10+
add_executable(ToyVpnServer
11+
ToyVpnServer.h
12+
main.cpp)
13+
14+
# Add the argparse directory to the target's include path
15+
target_include_directories(ToyVpnServer PRIVATE argparse)
16+
17+
# Add PcapPlusPlus headers to the include search path
18+
target_include_directories(ToyVpnServer PRIVATE ${PCAPPLUSPLUS_INCLUDE_DIR})
19+
20+
# Link the necessary PcapPlusPlus libraries
21+
target_link_libraries(ToyVpnServer PRIVATE
22+
${PCAPPLUSPLUS_LIB_DIR}/libPcap++.a
23+
${PCAPPLUSPLUS_LIB_DIR}/libPacket++.a
24+
${PCAPPLUSPLUS_LIB_DIR}/libCommon++.a
25+
pcap)

server/linux/ClientHandler.h

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
#pragma once
2+
3+
#include <chrono>
4+
#include <netinet/in.h>
5+
#include "libs/pcapplusplus/include/pcapplusplus/IpAddress.h"
6+
#include "TunInterfaceWrapper.h"
7+
#include "ServerSocketWrapper.h"
8+
#include "VpnSettings.h"
9+
#include "PacketHandler.h"
10+
#include "Log.h"
11+
12+
class ClientHandler {
13+
public:
14+
ClientHandler(const ServerSocketWrapper& serverSocket, const sockaddr_in6& clientExternalAddress, const TunInterfaceWrapper& tunInterface, const VpnSettings& vpnSettings, std::optional<PacketHandler>& packetHandler)
15+
: m_ServerSocket(serverSocket), m_ClientExternalAddress(clientExternalAddress), m_TunInterface(tunInterface),
16+
m_VpnSettings(vpnSettings), m_PacketHandler(packetHandler)
17+
{}
18+
19+
template<std::size_t BUFFER_SIZE>
20+
void handleDataFromClient(const std::array<uint8_t, BUFFER_SIZE>& buffer, size_t dataSize) {
21+
m_LastMessageTimestamp = std::chrono::steady_clock::now();
22+
switch (m_State) {
23+
case State::START: {
24+
if (dataSize < 2 || buffer[0] != 0) {
25+
break;
26+
}
27+
28+
std::string secret(buffer.begin() + 1, buffer.begin() + dataSize);
29+
if (secret != m_VpnSettings.secret) {
30+
TOYVPN_LOG_ERROR("Got the wrong secret: '" << secret << "'");
31+
m_State = State::ERROR;
32+
break;
33+
}
34+
35+
auto params = m_VpnSettings.toParamString();
36+
TOYVPN_LOG_DEBUG("Sending params to client: '" << params << "'");
37+
std::vector<uint8_t> paramsMessage;
38+
paramsMessage.push_back(0);
39+
paramsMessage.insert(paramsMessage.end(), params.begin(), params.end());
40+
auto bytesSent = m_ServerSocket.send(paramsMessage, m_ClientExternalAddress);
41+
if (bytesSent == -1) {
42+
m_State = State::ERROR;
43+
} else {
44+
m_State = State::CONNECTED;
45+
46+
std::array<uint8_t, 16> ipv6AddressBytes;
47+
std::copy(std::begin(m_ClientExternalAddress.sin6_addr.s6_addr),
48+
std::end(m_ClientExternalAddress.sin6_addr.s6_addr), ipv6AddressBytes.begin());
49+
50+
TOYVPN_LOG_INFO("New client connected! External address: ("
51+
<< pcpp::IPv6Address(ipv6AddressBytes).toString()
52+
<< ","
53+
<< m_ClientExternalAddress.sin6_port
54+
<< ")"
55+
<< ", Internal address: "
56+
<< m_VpnSettings.clientAddress);
57+
}
58+
break;
59+
}
60+
case State::CONNECTED: {
61+
// Got a control packet
62+
if (dataSize == 1 && buffer[0] == 0) {
63+
break;
64+
}
65+
66+
if (dataSize == 11 && buffer[0] == 0) {
67+
std::string message(buffer.begin() + 1, buffer.begin() + dataSize);
68+
if (message == m_DisconnectMessage) {
69+
m_State = State::DISCONNECTED;
70+
71+
if (m_PacketHandler.has_value()) {
72+
m_PacketHandler->clientDisconnected(m_VpnSettings.clientAddress);
73+
}
74+
75+
TOYVPN_LOG_INFO("Client disconnected: " << m_VpnSettings.clientAddress);
76+
break;
77+
}
78+
}
79+
80+
m_TunInterface.send(buffer, dataSize);
81+
82+
if (m_PacketHandler.has_value()) {
83+
m_PacketHandler->handlePacket(m_VpnSettings.clientAddress, buffer, dataSize);
84+
}
85+
86+
break;
87+
}
88+
default: {
89+
}
90+
}
91+
}
92+
93+
template<std::size_t BUFFER_SIZE>
94+
void handleDataFromTun(const std::array<uint8_t, BUFFER_SIZE>& buffer, size_t dataSize) {
95+
m_ServerSocket.send(buffer, dataSize, m_ClientExternalAddress);
96+
97+
if (m_PacketHandler.has_value()) {
98+
m_PacketHandler->handlePacket(m_VpnSettings.clientAddress, buffer, dataSize);
99+
}
100+
}
101+
102+
void disconnect() {
103+
if (m_State != State::CONNECTED) {
104+
return;
105+
}
106+
107+
std::vector<uint8_t> disconnectMessage;
108+
disconnectMessage.push_back(0);
109+
disconnectMessage.insert(disconnectMessage.end(), m_DisconnectMessage.begin(), m_DisconnectMessage.end());
110+
for (int i = 0; i < 3; i++) {
111+
m_ServerSocket.send(disconnectMessage, m_ClientExternalAddress);
112+
}
113+
114+
m_State = State::DISCONNECTED;
115+
116+
if (m_PacketHandler.has_value()) {
117+
m_PacketHandler->clientDisconnected(m_VpnSettings.clientAddress);
118+
}
119+
120+
TOYVPN_LOG_INFO("Client disconnected: " << m_VpnSettings.clientAddress);
121+
}
122+
123+
pcpp::IPv4Address getClientVpnAddress() const { return m_VpnSettings.clientAddress; }
124+
125+
bool isIdle(const std::chrono::steady_clock::time_point& now) {
126+
return m_State == State::DISCONNECTED || now - m_LastMessageTimestamp > m_ClientIdleTimeoutSec;
127+
}
128+
129+
private:
130+
enum class State {
131+
START,
132+
CONNECTED,
133+
DISCONNECTED,
134+
ERROR
135+
};
136+
137+
constexpr static std::string_view m_DisconnectMessage = "DISCONNECT";
138+
constexpr static std::chrono::duration m_ClientIdleTimeoutSec = std::chrono::seconds(60);
139+
140+
const ServerSocketWrapper& m_ServerSocket;
141+
const TunInterfaceWrapper& m_TunInterface;
142+
std::optional<PacketHandler>& m_PacketHandler;
143+
int m_BufferSize;
144+
State m_State = State::START;
145+
sockaddr_in6 m_ClientExternalAddress;
146+
VpnSettings m_VpnSettings;
147+
std::chrono::steady_clock::time_point m_LastMessageTimestamp;
148+
};

server/linux/EpollWrapper.h

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#pragma once
2+
3+
#include <functional>
4+
#include <unistd.h>
5+
#include <sys/epoll.h>
6+
#include <unordered_map>
7+
8+
class EPollWrapper {
9+
public:
10+
using EPollCallback = std::function<void(int)>;
11+
12+
virtual ~EPollWrapper() {
13+
if (m_EPollFd != -1) {
14+
close(m_EPollFd);
15+
}
16+
}
17+
18+
void init(int maxEvents) {
19+
m_EPollFd = epoll_create1(0);
20+
if (m_EPollFd == -1) {
21+
throw std::runtime_error("Error creating epoll instance!");
22+
}
23+
24+
m_MaxEvents = maxEvents;
25+
}
26+
27+
void add(int fd, const EPollCallback& callback) {
28+
if (m_EPollFd == -1) {
29+
throw std::runtime_error("Instance not initialized, please call init()!");
30+
}
31+
32+
struct epoll_event event;
33+
event.events = EPOLLIN;
34+
event.data.fd = fd;
35+
if (epoll_ctl(m_EPollFd, EPOLL_CTL_ADD, fd, &event) == -1) {
36+
throw std::runtime_error("Error adding fd to epoll!");
37+
}
38+
39+
m_FdToCallbackMap[event.data.fd] = callback;
40+
}
41+
42+
void remove(int fd) {
43+
m_FdToCallbackMap.erase(fd);
44+
}
45+
46+
void startPolling() {
47+
if (m_IsPolling) {
48+
throw std::runtime_error("Already polling!");
49+
}
50+
51+
m_IsPolling = true;
52+
53+
epoll_event events[m_MaxEvents];
54+
55+
while (m_IsPolling) {
56+
int numEvents = epoll_wait(m_EPollFd, events, m_MaxEvents, -1);
57+
if (numEvents < 0) {
58+
if (!m_IsPolling) {
59+
return;
60+
}
61+
m_IsPolling = false;
62+
throw std::runtime_error("Error with epoll_wait!");
63+
}
64+
65+
for (int i = 0; i < numEvents; ++i) {
66+
m_FdToCallbackMap[events[i].data.fd](events[i].data.fd);
67+
}
68+
}
69+
}
70+
71+
void stopPolling() {
72+
m_IsPolling = false;
73+
close(m_EPollFd);
74+
}
75+
76+
private:
77+
int m_EPollFd = -1;
78+
int m_MaxEvents = -1;
79+
std::unordered_map<int, EPollCallback> m_FdToCallbackMap;
80+
bool m_IsPolling = false;
81+
};
82+

0 commit comments

Comments
 (0)