diff --git a/.idea/modules.xml b/.idea/modules.xml
deleted file mode 100644
index 053b796..0000000
--- a/.idea/modules.xml
+++ /dev/null
@@ -1,8 +0,0 @@
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/src/main/cpp/.idea/editor.xml b/src/main/cpp/.idea/editor.xml
new file mode 100644
index 0000000..226ca24
--- /dev/null
+++ b/src/main/cpp/.idea/editor.xml
@@ -0,0 +1,580 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/src/main/cpp/.idea/misc.xml b/src/main/cpp/.idea/misc.xml
index 79b3c94..0b76fe5 100644
--- a/src/main/cpp/.idea/misc.xml
+++ b/src/main/cpp/.idea/misc.xml
@@ -1,4 +1,7 @@
+
+
+
\ No newline at end of file
diff --git a/src/main/cpp/CMakeLists.txt b/src/main/cpp/CMakeLists.txt
index e5e7f5e..6862d06 100644
--- a/src/main/cpp/CMakeLists.txt
+++ b/src/main/cpp/CMakeLists.txt
@@ -15,8 +15,7 @@ set(JNI_INCLUDE_DIR2 "${JAVA_HOME}/include/win32")
include_directories(${JNI_INCLUDE_DIR} ${JNI_INCLUDE_DIR2})
# Add the source file that includes your JNI C wrapper code
-add_library(wireguard_wrapper SHARED wireguard_wrapper.c
- example.c)
+add_library(wireguard_wrapper SHARED wireguard_wrapper.c)
#target_link_libraries(wireguard_wrapper PRIVATE Rpcrt4)
target_link_libraries(wireguard_wrapper ws2_32 rpcrt4)
diff --git a/src/main/go/.idea/.gitignore b/src/main/go/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/src/main/go/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/src/main/go/.idea/vcs.xml b/src/main/go/.idea/vcs.xml
new file mode 100644
index 0000000..d843f34
--- /dev/null
+++ b/src/main/go/.idea/vcs.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/src/main/go/go.mod b/src/main/go/go.mod
new file mode 100644
index 0000000..0d63826
--- /dev/null
+++ b/src/main/go/go.mod
@@ -0,0 +1,13 @@
+module git.nevets.tech/Steven/TVPN
+
+go 1.23.2
+
+require (
+ golang.org/x/sys v0.26.0
+ golang.zx2c4.com/wireguard/windows v0.5.3
+)
+
+require (
+ golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 // indirect
+ golang.org/x/text v0.3.8-0.20211105212822-18b340fc7af2 // indirect
+)
diff --git a/src/main/go/go.sum b/src/main/go/go.sum
new file mode 100644
index 0000000..e3a7754
--- /dev/null
+++ b/src/main/go/go.sum
@@ -0,0 +1,8 @@
+golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 h1:0es+/5331RGQPcXlMfP+WrnIIS6dNnNRe0WB02W0F4M=
+golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
+golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
+golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/text v0.3.8-0.20211105212822-18b340fc7af2 h1:GLw7MR8AfAG2GmGcmVgObFOHXYypgGjnGno25RDwn3Y=
+golang.org/x/text v0.3.8-0.20211105212822-18b340fc7af2/go.mod h1:EFNZuWvGYxIRUEX+K8UmCFwYmZjqcrnq15ZuVldZkZ0=
+golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
+golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
diff --git a/src/main/go/main.go b/src/main/go/main.go
new file mode 100644
index 0000000..28d0975
--- /dev/null
+++ b/src/main/go/main.go
@@ -0,0 +1,45 @@
+package main
+
+import (
+ "crypto/rand"
+ "golang.org/x/crypto/curve25519"
+ "golang.org/x/sys/windows"
+ "golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/tunnel"
+ "log"
+ "path/filepath"
+ "unsafe"
+)
+
+//export InstallTunnelService
+func InstallTunnelService(confPath string) {
+
+}
+
+//export CreateTunnelService
+func CreateTunnelService(confFile16 *uint16) bool {
+ confFile := windows.UTF16PtrToString(confFile16)
+ conf.PresetRootDirectory(filepath.Dir(confFile))
+ tunnel.UseFixedGUIDInsteadOfDeterministic = true
+ err := tunnel.Run(confFile)
+ if err != nil {
+ log.Printf("Service run error: %v", err)
+ }
+ return err == nil
+}
+
+//export GenerateKeyPair
+func GenerateKeypair(publicKey, privateKey *byte) {
+ publicKeyArray := (*[32]byte)(unsafe.Pointer(publicKey))
+ privateKeyArray := (*[32]byte)(unsafe.Pointer(privateKey))
+ n, err := rand.Read(privateKeyArray[:])
+ if err != nil || n != len(privateKeyArray) {
+ panic("Unable to generate random bytes")
+ }
+ privateKeyArray[0] &= 248
+ privateKeyArray[31] = (privateKeyArray[31] & 127) | 64
+
+ curve25519.ScalarBaseMult(publicKeyArray, privateKeyArray)
+}
+
+func main() {}
diff --git a/src/main/go/service/install.go b/src/main/go/service/install.go
new file mode 100644
index 0000000..d16bbec
--- /dev/null
+++ b/src/main/go/service/install.go
@@ -0,0 +1,117 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
+ *
+ * Modified by Steven Tracey
+ */
+
+package service
+
+import (
+ "errors"
+ "golang.org/x/sys/windows"
+ "golang.org/x/sys/windows/svc"
+ "golang.org/x/sys/windows/svc/mgr"
+ "golang.zx2c4.com/wireguard/windows/conf"
+ "os"
+ "time"
+)
+
+var cachedServiceManager *mgr.Mgr
+
+func serviceManager() (*mgr.Mgr, error) {
+ if cachedServiceManager != nil {
+ return cachedServiceManager, nil
+ }
+ m, err := mgr.Connect()
+ if err != nil {
+ return nil, err
+ }
+ cachedServiceManager = m
+ return cachedServiceManager, nil
+}
+
+func InstallTunnel(configPath string) error {
+ m, err := serviceManager()
+ if err != nil {
+ return err
+ }
+ path, err := os.Executable()
+ if err != nil {
+ return nil
+ }
+
+ name, err := conf.NameFromPath(configPath)
+ if err != nil {
+ return err
+ }
+
+ serviceName, err := conf.ServiceNameOfTunnel(name)
+ if err != nil {
+ return err
+ }
+ service, err := m.OpenService(serviceName)
+ if err == nil {
+ status, err := service.Query()
+ if err != nil && err != windows.ERROR_SERVICE_MARKED_FOR_DELETE {
+ service.Close()
+ return err
+ }
+ if status.State != svc.Stopped && err != windows.ERROR_SERVICE_MARKED_FOR_DELETE {
+ service.Close()
+ return errors.New("Tunnel already installed and running")
+ }
+ err = service.Delete()
+ service.Close()
+ if err != nil && err != windows.ERROR_SERVICE_MARKED_FOR_DELETE {
+ return err
+ }
+ for {
+ service, err = m.OpenService(serviceName)
+ if err != nil && err != windows.ERROR_SERVICE_MARKED_FOR_DELETE {
+ break
+ }
+ service.Close()
+ time.Sleep(time.Second / 3)
+ }
+ }
+
+ config := mgr.Config{
+ ServiceType: windows.SERVICE_WIN32_OWN_PROCESS,
+ StartType: mgr.StartAutomatic,
+ ErrorControl: mgr.ErrorNormal,
+ Dependencies: []string{"Nsi", "TcpIp"},
+ DisplayName: "WireGuard Tunnel: " + name,
+ SidType: windows.SERVICE_SID_TYPE_UNRESTRICTED,
+ }
+ service, err = m.CreateService(serviceName, path, config, "/tunnelservice", configPath)
+ if err != nil {
+ return err
+ }
+
+ err = service.Start()
+ go trackTunnelService(name, service) // Pass off reference to handle.
+ return err
+}
+
+func UninstallTunnel(name string) error {
+ m, err := serviceManager()
+ if err != nil {
+ return err
+ }
+ serviceName, err := conf.ServiceNameOfTunnel(name)
+ if err != nil {
+ return err
+ }
+ service, err := m.OpenService(serviceName)
+ if err != nil {
+ return err
+ }
+ service.Control(svc.Stop)
+ err = service.Delete()
+ err2 := service.Close()
+ if err != nil && err != windows.ERROR_SERVICE_MARKED_FOR_DELETE {
+ return err
+ }
+ return err2
+}
diff --git a/src/main/go/service/ipc_client.go b/src/main/go/service/ipc_client.go
new file mode 100644
index 0000000..261cbad
--- /dev/null
+++ b/src/main/go/service/ipc_client.go
@@ -0,0 +1,398 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
+ *
+ * Modified by Steven Tracey
+ */
+
+package service
+
+import (
+ "encoding/gob"
+ "errors"
+ "os"
+ "sync"
+
+ "golang.zx2c4.com/wireguard/windows/conf"
+)
+
+type Tunnel struct {
+ Name string
+}
+
+type TunnelState int
+
+const (
+ TunnelUnknown TunnelState = iota
+ TunnelStarted
+ TunnelStopped
+ TunnelStarting
+ TunnelStopping
+)
+
+type NotificationType int
+
+const (
+ TunnelChangeNotificationType NotificationType = iota
+ TunnelsChangeNotificationType
+ ManagerStoppingNotificationType
+ UpdateFoundNotificationType
+ UpdateProgressNotificationType
+)
+
+type MethodType int
+
+const (
+ StoredConfigMethodType MethodType = iota
+ RuntimeConfigMethodType
+ StartMethodType
+ StopMethodType
+ WaitForStopMethodType
+ DeleteMethodType
+ StateMethodType
+ GlobalStateMethodType
+ CreateMethodType
+ TunnelsMethodType
+ QuitMethodType
+ UpdateStateMethodType
+ UpdateMethodType
+)
+
+var (
+ rpcEncoder *gob.Encoder
+ rpcDecoder *gob.Decoder
+ rpcMutex sync.Mutex
+)
+
+type TunnelChangeCallback struct {
+ cb func(tunnel *Tunnel, state, globalState TunnelState, err error)
+}
+
+var tunnelChangeCallbacks = make(map[*TunnelChangeCallback]bool)
+
+type TunnelsChangeCallback struct {
+ cb func()
+}
+
+var tunnelsChangeCallbacks = make(map[*TunnelsChangeCallback]bool)
+
+type ManagerStoppingCallback struct {
+ cb func()
+}
+
+var managerStoppingCallbacks = make(map[*ManagerStoppingCallback]bool)
+
+func InitializeIPCClient(reader, writer, events *os.File) {
+ rpcDecoder = gob.NewDecoder(reader)
+ rpcEncoder = gob.NewEncoder(writer)
+ go func() {
+ decoder := gob.NewDecoder(events)
+ for {
+ var notificationType NotificationType
+ err := decoder.Decode(¬ificationType)
+ if err != nil {
+ return
+ }
+ switch notificationType {
+ case TunnelChangeNotificationType:
+ var tunnel string
+ err := decoder.Decode(&tunnel)
+ if err != nil || len(tunnel) == 0 {
+ continue
+ }
+ var state TunnelState
+ err = decoder.Decode(&state)
+ if err != nil {
+ continue
+ }
+ var globalState TunnelState
+ err = decoder.Decode(&globalState)
+ if err != nil {
+ continue
+ }
+ var errStr string
+ err = decoder.Decode(&errStr)
+ if err != nil {
+ continue
+ }
+ var retErr error
+ if len(errStr) > 0 {
+ retErr = errors.New(errStr)
+ }
+ if state == TunnelUnknown {
+ continue
+ }
+ t := &Tunnel{tunnel}
+ for cb := range tunnelChangeCallbacks {
+ cb.cb(t, state, globalState, retErr)
+ }
+ case TunnelsChangeNotificationType:
+ for cb := range tunnelsChangeCallbacks {
+ cb.cb()
+ }
+ case ManagerStoppingNotificationType:
+ for cb := range managerStoppingCallbacks {
+ cb.cb()
+ }
+ }
+ }
+ }()
+}
+
+func rpcDecodeError() error {
+ var str string
+ err := rpcDecoder.Decode(&str)
+ if err != nil {
+ return err
+ }
+ if len(str) == 0 {
+ return nil
+ }
+ return errors.New(str)
+}
+
+func (t *Tunnel) StoredConfig() (c conf.Config, err error) {
+ rpcMutex.Lock()
+ defer rpcMutex.Unlock()
+
+ err = rpcEncoder.Encode(StoredConfigMethodType)
+ if err != nil {
+ return
+ }
+ err = rpcEncoder.Encode(t.Name)
+ if err != nil {
+ return
+ }
+ err = rpcDecoder.Decode(&c)
+ if err != nil {
+ return
+ }
+ err = rpcDecodeError()
+ return
+}
+
+func (t *Tunnel) RuntimeConfig() (c conf.Config, err error) {
+ rpcMutex.Lock()
+ defer rpcMutex.Unlock()
+
+ err = rpcEncoder.Encode(RuntimeConfigMethodType)
+ if err != nil {
+ return
+ }
+ err = rpcEncoder.Encode(t.Name)
+ if err != nil {
+ return
+ }
+ err = rpcDecoder.Decode(&c)
+ if err != nil {
+ return
+ }
+ err = rpcDecodeError()
+ return
+}
+
+func (t *Tunnel) Start() (err error) {
+ rpcMutex.Lock()
+ defer rpcMutex.Unlock()
+
+ err = rpcEncoder.Encode(StartMethodType)
+ if err != nil {
+ return
+ }
+ err = rpcEncoder.Encode(t.Name)
+ if err != nil {
+ return
+ }
+ err = rpcDecodeError()
+ return
+}
+
+func (t *Tunnel) Stop() (err error) {
+ rpcMutex.Lock()
+ defer rpcMutex.Unlock()
+
+ err = rpcEncoder.Encode(StopMethodType)
+ if err != nil {
+ return
+ }
+ err = rpcEncoder.Encode(t.Name)
+ if err != nil {
+ return
+ }
+ err = rpcDecodeError()
+ return
+}
+
+func (t *Tunnel) Toggle() (oldState TunnelState, err error) {
+ oldState, err = t.State()
+ if err != nil {
+ oldState = TunnelUnknown
+ return
+ }
+ if oldState == TunnelStarted {
+ err = t.Stop()
+ } else if oldState == TunnelStopped {
+ err = t.Start()
+ }
+ return
+}
+
+func (t *Tunnel) WaitForStop() (err error) {
+ rpcMutex.Lock()
+ defer rpcMutex.Unlock()
+
+ err = rpcEncoder.Encode(WaitForStopMethodType)
+ if err != nil {
+ return
+ }
+ err = rpcEncoder.Encode(t.Name)
+ if err != nil {
+ return
+ }
+ err = rpcDecodeError()
+ return
+}
+
+func (t *Tunnel) Delete() (err error) {
+ rpcMutex.Lock()
+ defer rpcMutex.Unlock()
+
+ err = rpcEncoder.Encode(DeleteMethodType)
+ if err != nil {
+ return
+ }
+ err = rpcEncoder.Encode(t.Name)
+ if err != nil {
+ return
+ }
+ err = rpcDecodeError()
+ return
+}
+
+func (t *Tunnel) State() (tunnelState TunnelState, err error) {
+ rpcMutex.Lock()
+ defer rpcMutex.Unlock()
+
+ err = rpcEncoder.Encode(StateMethodType)
+ if err != nil {
+ return
+ }
+ err = rpcEncoder.Encode(t.Name)
+ if err != nil {
+ return
+ }
+ err = rpcDecoder.Decode(&tunnelState)
+ if err != nil {
+ return
+ }
+ err = rpcDecodeError()
+ return
+}
+
+func IPCClientGlobalState() (tunnelState TunnelState, err error) {
+ rpcMutex.Lock()
+ defer rpcMutex.Unlock()
+
+ err = rpcEncoder.Encode(GlobalStateMethodType)
+ if err != nil {
+ return
+ }
+ err = rpcDecoder.Decode(&tunnelState)
+ if err != nil {
+ return
+ }
+ return
+}
+
+func IPCClientNewTunnel(conf *conf.Config) (tunnel Tunnel, err error) {
+ rpcMutex.Lock()
+ defer rpcMutex.Unlock()
+
+ err = rpcEncoder.Encode(CreateMethodType)
+ if err != nil {
+ return
+ }
+ err = rpcEncoder.Encode(*conf)
+ if err != nil {
+ return
+ }
+ err = rpcDecoder.Decode(&tunnel)
+ if err != nil {
+ return
+ }
+ err = rpcDecodeError()
+ return
+}
+
+func IPCClientTunnels() (tunnels []Tunnel, err error) {
+ rpcMutex.Lock()
+ defer rpcMutex.Unlock()
+
+ err = rpcEncoder.Encode(TunnelsMethodType)
+ if err != nil {
+ return
+ }
+ err = rpcDecoder.Decode(&tunnels)
+ if err != nil {
+ return
+ }
+ err = rpcDecodeError()
+ return
+}
+
+func IPCClientQuit(stopTunnelsOnQuit bool) (alreadyQuit bool, err error) {
+ rpcMutex.Lock()
+ defer rpcMutex.Unlock()
+
+ err = rpcEncoder.Encode(QuitMethodType)
+ if err != nil {
+ return
+ }
+ err = rpcEncoder.Encode(stopTunnelsOnQuit)
+ if err != nil {
+ return
+ }
+ err = rpcDecoder.Decode(&alreadyQuit)
+ if err != nil {
+ return
+ }
+ err = rpcDecodeError()
+ return
+}
+
+func IPCClientUpdate() error {
+ rpcMutex.Lock()
+ defer rpcMutex.Unlock()
+
+ return rpcEncoder.Encode(UpdateMethodType)
+}
+
+func IPCClientRegisterTunnelChange(cb func(tunnel *Tunnel, state, globalState TunnelState, err error)) *TunnelChangeCallback {
+ s := &TunnelChangeCallback{cb}
+ tunnelChangeCallbacks[s] = true
+ return s
+}
+
+func (cb *TunnelChangeCallback) Unregister() {
+ delete(tunnelChangeCallbacks, cb)
+}
+
+func IPCClientRegisterTunnelsChange(cb func()) *TunnelsChangeCallback {
+ s := &TunnelsChangeCallback{cb}
+ tunnelsChangeCallbacks[s] = true
+ return s
+}
+
+func (cb *TunnelsChangeCallback) Unregister() {
+ delete(tunnelsChangeCallbacks, cb)
+}
+
+func IPCClientRegisterManagerStopping(cb func()) *ManagerStoppingCallback {
+ s := &ManagerStoppingCallback{cb}
+ managerStoppingCallbacks[s] = true
+ return s
+}
+
+func (cb *ManagerStoppingCallback) Unregister() {
+ delete(managerStoppingCallbacks, cb)
+}
diff --git a/src/main/go/service/ipc_driver.go b/src/main/go/service/ipc_driver.go
new file mode 100644
index 0000000..baa0c00
--- /dev/null
+++ b/src/main/go/service/ipc_driver.go
@@ -0,0 +1,63 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
+ *
+ * Modified by Steven Tracey
+ */
+
+package service
+
+import (
+ "sync"
+
+ "golang.zx2c4.com/wireguard/windows/driver"
+)
+
+type lockedDriverAdapter struct {
+ *driver.Adapter
+ sync.Mutex
+}
+
+var (
+ driverAdapters = make(map[string]*lockedDriverAdapter)
+ driverAdaptersLock sync.RWMutex
+)
+
+func findDriverAdapter(tunnelName string) (*lockedDriverAdapter, error) {
+ driverAdaptersLock.RLock()
+ driverAdapter, ok := driverAdapters[tunnelName]
+ if ok {
+ driverAdapter.Lock()
+ driverAdaptersLock.RUnlock()
+ return driverAdapter, nil
+ }
+ driverAdaptersLock.RUnlock()
+ driverAdaptersLock.Lock()
+ defer driverAdaptersLock.Unlock()
+ driverAdapter, ok = driverAdapters[tunnelName]
+ if ok {
+ driverAdapter.Lock()
+ return driverAdapter, nil
+ }
+ driverAdapter = &lockedDriverAdapter{}
+ var err error
+ driverAdapter.Adapter, err = driver.OpenAdapter(tunnelName)
+ if err != nil {
+ return nil, err
+ }
+ driverAdapters[tunnelName] = driverAdapter
+ driverAdapter.Lock()
+ return driverAdapter, nil
+}
+
+func releaseDriverAdapter(tunnelName string) {
+ driverAdaptersLock.Lock()
+ defer driverAdaptersLock.Unlock()
+ driverAdapter, ok := driverAdapters[tunnelName]
+ if !ok {
+ return
+ }
+ driverAdapter.Lock()
+ delete(driverAdapters, tunnelName)
+ driverAdapter.Unlock()
+}
diff --git a/src/main/go/service/ipc_server.go b/src/main/go/service/ipc_server.go
new file mode 100644
index 0000000..445218c
--- /dev/null
+++ b/src/main/go/service/ipc_server.go
@@ -0,0 +1,490 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
+ *
+ * Modified by Steven Tracey
+ */
+
+package service
+
+import (
+ "bytes"
+ "encoding/gob"
+ "fmt"
+ "io"
+ "log"
+ "os"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "golang.org/x/sys/windows"
+ "golang.org/x/sys/windows/svc"
+
+ "golang.zx2c4.com/wireguard/windows/conf"
+)
+
+var (
+ managerServices = make(map[*ManagerService]bool)
+ managerServicesLock sync.RWMutex
+ haveQuit uint32
+ quitManagersChan = make(chan struct{}, 1)
+)
+
+type ManagerService struct {
+ events *os.File
+ eventLock sync.Mutex
+ elevatedToken windows.Token
+}
+
+func (s *ManagerService) StoredConfig(tunnelName string) (*conf.Config, error) {
+ conf, err := conf.LoadFromName(tunnelName)
+ if err != nil {
+ return nil, err
+ }
+ if s.elevatedToken == 0 {
+ conf.Redact()
+ }
+ return conf, nil
+}
+
+func (s *ManagerService) RuntimeConfig(tunnelName string) (*conf.Config, error) {
+ storedConfig, err := conf.LoadFromName(tunnelName)
+ if err != nil {
+ return nil, err
+ }
+ driverAdapter, err := findDriverAdapter(tunnelName)
+ if err != nil {
+ return nil, err
+ }
+ runtimeConfig, err := driverAdapter.Configuration()
+ if err != nil {
+ driverAdapter.Unlock()
+ releaseDriverAdapter(tunnelName)
+ return nil, err
+ }
+ conf := conf.FromDriverConfiguration(runtimeConfig, storedConfig)
+ driverAdapter.Unlock()
+ if s.elevatedToken == 0 {
+ conf.Redact()
+ }
+ return conf, nil
+}
+
+func (s *ManagerService) Start(tunnelName string) error {
+ c, err := conf.LoadFromName(tunnelName)
+ if err != nil {
+ return err
+ }
+
+ // Figure out which tunnels have intersecting addresses/routes and stop those.
+ trackedTunnelsLock.Lock()
+ tt := make([]string, 0, len(trackedTunnels))
+ var inTransition string
+ for t, state := range trackedTunnels {
+ c2, err := conf.LoadFromName(t)
+ if err != nil || !c.IntersectsWith(c2) {
+ // If we can't get the config, assume it doesn't intersect.
+ continue
+ }
+ tt = append(tt, t)
+ if len(t) > 0 && (state == TunnelStarting || state == TunnelUnknown) {
+ inTransition = t
+ break
+ }
+ }
+ trackedTunnelsLock.Unlock()
+ if len(inTransition) != 0 {
+ return fmt.Errorf("Please allow the tunnel ā%sā to finish activating", inTransition)
+ }
+
+ // Stop those intersecting tunnels asynchronously.
+ go func() {
+ for _, t := range tt {
+ s.Stop(t)
+ }
+ for _, t := range tt {
+ state, err := s.State(t)
+ if err == nil && (state == TunnelStarted || state == TunnelStarting) {
+ log.Printf("[%s] Trying again to stop zombie tunnel", t)
+ s.Stop(t)
+ time.Sleep(time.Millisecond * 100)
+ }
+ }
+ }()
+ // After the stop process has begun, but before it's finished, we install the new one.
+ path, err := c.Path()
+ if err != nil {
+ return err
+ }
+ return InstallTunnel(path)
+}
+
+func (s *ManagerService) Stop(tunnelName string) error {
+ err := UninstallTunnel(tunnelName)
+ if err == windows.ERROR_SERVICE_DOES_NOT_EXIST {
+ _, notExistsError := conf.LoadFromName(tunnelName)
+ if notExistsError == nil {
+ return nil
+ }
+ }
+ return err
+}
+
+func (s *ManagerService) WaitForStop(tunnelName string) error {
+ serviceName, err := conf.ServiceNameOfTunnel(tunnelName)
+ if err != nil {
+ return err
+ }
+ m, err := serviceManager()
+ if err != nil {
+ return err
+ }
+ for {
+ service, err := m.OpenService(serviceName)
+ if err == nil || err == windows.ERROR_SERVICE_MARKED_FOR_DELETE {
+ service.Close()
+ time.Sleep(time.Second / 3)
+ } else {
+ return nil
+ }
+ }
+}
+
+func (s *ManagerService) Delete(tunnelName string) error {
+ if s.elevatedToken == 0 {
+ return windows.ERROR_ACCESS_DENIED
+ }
+ err := s.Stop(tunnelName)
+ if err != nil {
+ return err
+ }
+ return conf.DeleteName(tunnelName)
+}
+
+func (s *ManagerService) State(tunnelName string) (TunnelState, error) {
+ serviceName, err := conf.ServiceNameOfTunnel(tunnelName)
+ if err != nil {
+ return 0, err
+ }
+ m, err := serviceManager()
+ if err != nil {
+ return 0, err
+ }
+ service, err := m.OpenService(serviceName)
+ if err != nil {
+ return TunnelStopped, nil
+ }
+ defer service.Close()
+ status, err := service.Query()
+ if err != nil {
+ return TunnelUnknown, nil
+ }
+ switch status.State {
+ case svc.Stopped:
+ return TunnelStopped, nil
+ case svc.StopPending:
+ return TunnelStopping, nil
+ case svc.Running:
+ return TunnelStarted, nil
+ case svc.StartPending:
+ return TunnelStarting, nil
+ default:
+ return TunnelUnknown, nil
+ }
+}
+
+func (s *ManagerService) GlobalState() TunnelState {
+ return trackedTunnelsGlobalState()
+}
+
+func (s *ManagerService) Create(tunnelConfig *conf.Config) (*Tunnel, error) {
+ if s.elevatedToken == 0 {
+ return nil, windows.ERROR_ACCESS_DENIED
+ }
+ err := tunnelConfig.Save(true)
+ if err != nil {
+ return nil, err
+ }
+ return &Tunnel{tunnelConfig.Name}, nil
+ // TODO: handle already existing situation
+ // TODO: handle already running and existing situation
+}
+
+func (s *ManagerService) Tunnels() ([]Tunnel, error) {
+ names, err := conf.ListConfigNames()
+ if err != nil {
+ return nil, err
+ }
+ tunnels := make([]Tunnel, len(names))
+ for i := 0; i < len(tunnels); i++ {
+ tunnels[i].Name = names[i]
+ }
+ return tunnels, nil
+ // TODO: account for running ones that aren't in the configuration store somehow
+}
+
+func (s *ManagerService) Quit(stopTunnelsOnQuit bool) (alreadyQuit bool, err error) {
+ if s.elevatedToken == 0 {
+ return false, windows.ERROR_ACCESS_DENIED
+ }
+ if !atomic.CompareAndSwapUint32(&haveQuit, 0, 1) {
+ return true, nil
+ }
+
+ // Work around potential race condition of delivering messages to the wrong process by removing from notifications.
+ managerServicesLock.Lock()
+ s.eventLock.Lock()
+ s.events = nil
+ s.eventLock.Unlock()
+ delete(managerServices, s)
+ managerServicesLock.Unlock()
+
+ if stopTunnelsOnQuit {
+ names, err := conf.ListConfigNames()
+ if err != nil {
+ return false, err
+ }
+ for _, name := range names {
+ UninstallTunnel(name)
+ }
+ }
+
+ quitManagersChan <- struct{}{}
+ return false, nil
+}
+
+func (s *ManagerService) ServeConn(reader io.Reader, writer io.Writer) {
+ decoder := gob.NewDecoder(reader)
+ encoder := gob.NewEncoder(writer)
+ for {
+ var methodType MethodType
+ err := decoder.Decode(&methodType)
+ if err != nil {
+ return
+ }
+ switch methodType {
+ case StoredConfigMethodType:
+ var tunnelName string
+ err := decoder.Decode(&tunnelName)
+ if err != nil {
+ return
+ }
+ config, retErr := s.StoredConfig(tunnelName)
+ if config == nil {
+ config = &conf.Config{}
+ }
+ err = encoder.Encode(*config)
+ if err != nil {
+ return
+ }
+ err = encoder.Encode(errToString(retErr))
+ if err != nil {
+ return
+ }
+ case RuntimeConfigMethodType:
+ var tunnelName string
+ err := decoder.Decode(&tunnelName)
+ if err != nil {
+ return
+ }
+ config, retErr := s.RuntimeConfig(tunnelName)
+ if config == nil {
+ config = &conf.Config{}
+ }
+ err = encoder.Encode(*config)
+ if err != nil {
+ return
+ }
+ err = encoder.Encode(errToString(retErr))
+ if err != nil {
+ return
+ }
+ case StartMethodType:
+ var tunnelName string
+ err := decoder.Decode(&tunnelName)
+ if err != nil {
+ return
+ }
+ retErr := s.Start(tunnelName)
+ err = encoder.Encode(errToString(retErr))
+ if err != nil {
+ return
+ }
+ case StopMethodType:
+ var tunnelName string
+ err := decoder.Decode(&tunnelName)
+ if err != nil {
+ return
+ }
+ retErr := s.Stop(tunnelName)
+ err = encoder.Encode(errToString(retErr))
+ if err != nil {
+ return
+ }
+ case WaitForStopMethodType:
+ var tunnelName string
+ err := decoder.Decode(&tunnelName)
+ if err != nil {
+ return
+ }
+ retErr := s.WaitForStop(tunnelName)
+ err = encoder.Encode(errToString(retErr))
+ if err != nil {
+ return
+ }
+ case DeleteMethodType:
+ var tunnelName string
+ err := decoder.Decode(&tunnelName)
+ if err != nil {
+ return
+ }
+ retErr := s.Delete(tunnelName)
+ err = encoder.Encode(errToString(retErr))
+ if err != nil {
+ return
+ }
+ case StateMethodType:
+ var tunnelName string
+ err := decoder.Decode(&tunnelName)
+ if err != nil {
+ return
+ }
+ state, retErr := s.State(tunnelName)
+ err = encoder.Encode(state)
+ if err != nil {
+ return
+ }
+ err = encoder.Encode(errToString(retErr))
+ if err != nil {
+ return
+ }
+ case GlobalStateMethodType:
+ state := s.GlobalState()
+ err = encoder.Encode(state)
+ if err != nil {
+ return
+ }
+ case CreateMethodType:
+ var config conf.Config
+ err := decoder.Decode(&config)
+ if err != nil {
+ return
+ }
+ tunnel, retErr := s.Create(&config)
+ if tunnel == nil {
+ tunnel = &Tunnel{}
+ }
+ err = encoder.Encode(tunnel)
+ if err != nil {
+ return
+ }
+ err = encoder.Encode(errToString(retErr))
+ if err != nil {
+ return
+ }
+ case TunnelsMethodType:
+ tunnels, retErr := s.Tunnels()
+ err = encoder.Encode(tunnels)
+ if err != nil {
+ return
+ }
+ err = encoder.Encode(errToString(retErr))
+ if err != nil {
+ return
+ }
+ case QuitMethodType:
+ var stopTunnelsOnQuit bool
+ err := decoder.Decode(&stopTunnelsOnQuit)
+ if err != nil {
+ return
+ }
+ alreadyQuit, retErr := s.Quit(stopTunnelsOnQuit)
+ err = encoder.Encode(alreadyQuit)
+ if err != nil {
+ return
+ }
+ err = encoder.Encode(errToString(retErr))
+ if err != nil {
+ return
+ }
+ default:
+ return
+ }
+ }
+}
+
+func IPCServerListen(reader, writer, events *os.File, elevatedToken windows.Token) {
+ service := &ManagerService{
+ events: events,
+ elevatedToken: elevatedToken,
+ }
+
+ go func() {
+ managerServicesLock.Lock()
+ managerServices[service] = true
+ managerServicesLock.Unlock()
+ service.ServeConn(reader, writer)
+ managerServicesLock.Lock()
+ service.eventLock.Lock()
+ service.events = nil
+ service.eventLock.Unlock()
+ delete(managerServices, service)
+ managerServicesLock.Unlock()
+ }()
+}
+
+func notifyAll(notificationType NotificationType, adminOnly bool, ifaces ...any) {
+ if len(managerServices) == 0 {
+ return
+ }
+
+ var buf bytes.Buffer
+ encoder := gob.NewEncoder(&buf)
+ err := encoder.Encode(notificationType)
+ if err != nil {
+ return
+ }
+ for _, iface := range ifaces {
+ err = encoder.Encode(iface)
+ if err != nil {
+ return
+ }
+ }
+
+ managerServicesLock.RLock()
+ for m := range managerServices {
+ if m.elevatedToken == 0 && adminOnly {
+ continue
+ }
+ go func(m *ManagerService) {
+ m.eventLock.Lock()
+ defer m.eventLock.Unlock()
+ if m.events != nil {
+ m.events.SetWriteDeadline(time.Now().Add(time.Second))
+ m.events.Write(buf.Bytes())
+ }
+ }(m)
+ }
+ managerServicesLock.RUnlock()
+}
+
+func errToString(err error) string {
+ if err == nil {
+ return ""
+ }
+ return err.Error()
+}
+
+func IPCServerNotifyTunnelChange(name string, state TunnelState, err error) {
+ notifyAll(TunnelChangeNotificationType, false, name, state, trackedTunnelsGlobalState(), errToString(err))
+}
+
+func IPCServerNotifyTunnelsChange() {
+ notifyAll(TunnelsChangeNotificationType, false)
+}
+
+func IPCServerNotifyManagerStopping() {
+ notifyAll(ManagerStoppingNotificationType, false)
+ time.Sleep(time.Millisecond * 200)
+}
diff --git a/src/main/go/service/tunneltracker.go b/src/main/go/service/tunneltracker.go
new file mode 100644
index 0000000..0c5d51e
--- /dev/null
+++ b/src/main/go/service/tunneltracker.go
@@ -0,0 +1,345 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
+ *
+ * Modified by Steven Tracey
+ */
+
+package service
+
+import (
+ "errors"
+ "fmt"
+ "golang.zx2c4.com/wireguard/windows/services"
+ "log"
+ "runtime"
+ "sync"
+ "sync/atomic"
+ "syscall"
+ "time"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+ "golang.org/x/sys/windows/svc"
+ "golang.org/x/sys/windows/svc/mgr"
+
+ "golang.zx2c4.com/wireguard/windows/conf"
+)
+
+var (
+ trackedTunnels = make(map[string]TunnelState)
+ trackedTunnelsLock = sync.Mutex{}
+)
+
+func trackedTunnelsGlobalState() (state TunnelState) {
+ state = TunnelStopped
+ trackedTunnelsLock.Lock()
+ defer trackedTunnelsLock.Unlock()
+ for _, s := range trackedTunnels {
+ if s == TunnelStarting {
+ return TunnelStarting
+ } else if s == TunnelStopping {
+ return TunnelStopping
+ } else if s == TunnelStarted || s == TunnelUnknown {
+ state = TunnelStarted
+ }
+ }
+ return
+}
+
+var serviceTrackerCallbackPtr = windows.NewCallback(func(notifier *windows.SERVICE_NOTIFY) uintptr {
+ return 0
+})
+
+type serviceSubscriptionState struct {
+ service *mgr.Service
+ cb func(status uint32) bool
+ done sync.WaitGroup
+ once uint32
+}
+
+var serviceSubscriptionCallbackPtr = windows.NewCallback(func(notification uint32, context uintptr) uintptr {
+ state := (*serviceSubscriptionState)(unsafe.Pointer(context))
+ if atomic.LoadUint32(&state.once) != 0 {
+ return 0
+ }
+ if notification == 0 {
+ status, err := state.service.Query()
+ if err == nil {
+ notification = svcStateToNotifyState(uint32(status.State))
+ }
+ }
+ if state.cb(notification) && atomic.CompareAndSwapUint32(&state.once, 0, 1) {
+ state.done.Done()
+ }
+ return 0
+})
+
+func svcStateToNotifyState(s uint32) uint32 {
+ switch s {
+ case windows.SERVICE_STOPPED:
+ return windows.SERVICE_NOTIFY_STOPPED
+ case windows.SERVICE_START_PENDING:
+ return windows.SERVICE_NOTIFY_START_PENDING
+ case windows.SERVICE_STOP_PENDING:
+ return windows.SERVICE_NOTIFY_STOP_PENDING
+ case windows.SERVICE_RUNNING:
+ return windows.SERVICE_NOTIFY_RUNNING
+ case windows.SERVICE_CONTINUE_PENDING:
+ return windows.SERVICE_NOTIFY_CONTINUE_PENDING
+ case windows.SERVICE_PAUSE_PENDING:
+ return windows.SERVICE_NOTIFY_PAUSE_PENDING
+ case windows.SERVICE_PAUSED:
+ return windows.SERVICE_NOTIFY_PAUSED
+ case windows.SERVICE_NO_CHANGE:
+ return 0
+ default:
+ return 0
+ }
+}
+
+func notifyStateToTunState(s uint32) TunnelState {
+ if s&(windows.SERVICE_NOTIFY_STOPPED|windows.SERVICE_NOTIFY_DELETED) != 0 {
+ return TunnelStopped
+ } else if s&(windows.SERVICE_NOTIFY_DELETE_PENDING|windows.SERVICE_NOTIFY_STOP_PENDING) != 0 {
+ return TunnelStopping
+ } else if s&windows.SERVICE_NOTIFY_RUNNING != 0 {
+ return TunnelStarted
+ } else if s&windows.SERVICE_NOTIFY_START_PENDING != 0 {
+ return TunnelStarting
+ } else {
+ return TunnelUnknown
+ }
+}
+
+func trackService(service *mgr.Service, callback func(status uint32) bool) error {
+ var subscription uintptr
+ state := &serviceSubscriptionState{service: service, cb: callback}
+ state.done.Add(1)
+ err := windows.SubscribeServiceChangeNotifications(service.Handle, windows.SC_EVENT_STATUS_CHANGE, serviceSubscriptionCallbackPtr, uintptr(unsafe.Pointer(state)), &subscription)
+ if err == nil {
+ defer windows.UnsubscribeServiceChangeNotifications(subscription)
+ status, err := service.Query()
+ if err == nil {
+ if callback(svcStateToNotifyState(uint32(status.State))) {
+ return nil
+ }
+ }
+ state.done.Wait()
+ runtime.KeepAlive(state.cb)
+ return nil
+ }
+ if !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) {
+ return err
+ }
+
+ // TODO: Below this line is Windows 7 compatibility code, which hopefully we can delete at some point.
+
+ runtime.LockOSThread()
+ // This line would be fitting but is intentionally commented out:
+ //
+ // defer runtime.UnlockOSThread()
+ //
+ // The reason is that NotifyServiceStatusChange used queued APC, which winds up messing
+ // with the thread local context, which in turn appears to corrupt Go's own usage of TLS,
+ // leading to crashes sometime later (usually in runtime_unlock()) when the thread is recycled.
+
+ const serviceNotifications = windows.SERVICE_NOTIFY_RUNNING | windows.SERVICE_NOTIFY_START_PENDING | windows.SERVICE_NOTIFY_STOP_PENDING | windows.SERVICE_NOTIFY_STOPPED | windows.SERVICE_NOTIFY_DELETE_PENDING
+ notifier := &windows.SERVICE_NOTIFY{
+ Version: windows.SERVICE_NOTIFY_STATUS_CHANGE,
+ NotifyCallback: serviceTrackerCallbackPtr,
+ }
+ for {
+ err := windows.NotifyServiceStatusChange(service.Handle, serviceNotifications, notifier)
+ switch err {
+ case nil:
+ for {
+ if windows.SleepEx(uint32(time.Second*3/time.Millisecond), true) == windows.WAIT_IO_COMPLETION {
+ break
+ } else if callback(0) {
+ return nil
+ }
+ }
+ case windows.ERROR_SERVICE_MARKED_FOR_DELETE:
+ // Should be SERVICE_NOTIFY_DELETE_PENDING, but actually, we must release the handle and return here; otherwise it never deletes.
+ if callback(windows.SERVICE_NOTIFY_DELETED) {
+ return nil
+ }
+ case windows.ERROR_SERVICE_NOTIFY_CLIENT_LAGGING:
+ continue
+ default:
+ return err
+ }
+ if callback(svcStateToNotifyState(notifier.ServiceStatus.CurrentState)) {
+ return nil
+ }
+ }
+}
+
+func trackTunnelService(tunnelName string, service *mgr.Service) {
+ trackedTunnelsLock.Lock()
+ if _, found := trackedTunnels[tunnelName]; found {
+ trackedTunnelsLock.Unlock()
+ service.Close()
+ return
+ }
+
+ defer func() {
+ service.Close()
+ log.Printf("[%s] Tunnel service tracker finished", tunnelName)
+ }()
+ trackedTunnels[tunnelName] = TunnelUnknown
+ trackedTunnelsLock.Unlock()
+ defer func() {
+ trackedTunnelsLock.Lock()
+ delete(trackedTunnels, tunnelName)
+ trackedTunnelsLock.Unlock()
+ }()
+
+ for i := 0; i < 20; i++ {
+ if i > 0 {
+ time.Sleep(time.Second / 5)
+ }
+ if status, err := service.Query(); err != nil || status.State != svc.Stopped {
+ break
+ }
+ }
+
+ checkForDisabled := func() (shouldReturn bool) {
+ config, err := service.Config()
+ if err == windows.ERROR_SERVICE_MARKED_FOR_DELETE || (err != nil && config.StartType == windows.SERVICE_DISABLED) {
+ log.Printf("[%s] Found disabled service via timeout, so deleting", tunnelName)
+ service.Delete()
+ trackedTunnelsLock.Lock()
+ trackedTunnels[tunnelName] = TunnelStopped
+ trackedTunnelsLock.Unlock()
+ IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil)
+ return true
+ }
+ return false
+ }
+ if checkForDisabled() {
+ return
+ }
+ lastState := TunnelUnknown
+ err := trackService(service, func(status uint32) bool {
+ state := notifyStateToTunState(status)
+ var tunnelError error
+ if state == TunnelStopped {
+ serviceStatus, err := service.Query()
+ if err == nil {
+ if serviceStatus.Win32ExitCode == uint32(windows.ERROR_SERVICE_SPECIFIC_ERROR) {
+ maybeErr := services.Error(serviceStatus.ServiceSpecificExitCode)
+ if maybeErr != services.ErrorSuccess {
+ tunnelError = maybeErr
+ }
+ } else {
+ switch serviceStatus.Win32ExitCode {
+ case uint32(windows.NO_ERROR), uint32(windows.ERROR_SERVICE_NEVER_STARTED):
+ default:
+ tunnelError = syscall.Errno(serviceStatus.Win32ExitCode)
+ }
+ }
+ }
+ if tunnelError != nil {
+ service.Delete()
+ }
+ }
+ if state != lastState {
+ trackedTunnelsLock.Lock()
+ trackedTunnels[tunnelName] = state
+ trackedTunnelsLock.Unlock()
+ IPCServerNotifyTunnelChange(tunnelName, state, tunnelError)
+ lastState = state
+ }
+ if state == TunnelUnknown && checkForDisabled() {
+ return true
+ }
+ return state == TunnelStopped
+ })
+ if err != nil && !checkForDisabled() {
+ trackedTunnelsLock.Lock()
+ trackedTunnels[tunnelName] = TunnelStopped
+ trackedTunnelsLock.Unlock()
+ IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, fmt.Errorf("Unable to continue monitoring service, so stopping: %w", err))
+ service.Control(svc.Stop)
+ }
+}
+
+func trackExistingTunnels() error {
+ m, err := serviceManager()
+ if err != nil {
+ return err
+ }
+ names, err := conf.ListConfigNames()
+ if err != nil {
+ return err
+ }
+ for _, name := range names {
+ trackedTunnelsLock.Lock()
+ if _, found := trackedTunnels[name]; found {
+ trackedTunnelsLock.Unlock()
+ continue
+ }
+ trackedTunnelsLock.Unlock()
+ serviceName, err := conf.ServiceNameOfTunnel(name)
+ if err != nil {
+ continue
+ }
+ service, err := m.OpenService(serviceName)
+ if err != nil {
+ continue
+ }
+ go trackTunnelService(name, service)
+ }
+ return nil
+}
+
+var servicesSubscriptionWatcherCallbackPtr = windows.NewCallback(func(notification uint32, context uintptr) uintptr {
+ trackExistingTunnels()
+ return 0
+})
+
+func watchNewTunnelServices() error {
+ m, err := serviceManager()
+ if err != nil {
+ return err
+ }
+ var subscription uintptr
+ err = windows.SubscribeServiceChangeNotifications(m.Handle, windows.SC_EVENT_DATABASE_CHANGE, servicesSubscriptionWatcherCallbackPtr, 0, &subscription)
+ if err == nil {
+ // We probably could do:
+ // defer windows.UnsubscribeServiceChangeNotifications(subscription)
+ // and then terminate after some point, but instead we just let this go forever; it's process-lived.
+ return trackExistingTunnels()
+ }
+ if !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) {
+ return err
+ }
+
+ // TODO: Below this line is Windows 7 compatibility code, which hopefully we can delete at some point.
+ go func() {
+ runtime.LockOSThread()
+ notifier := &windows.SERVICE_NOTIFY{
+ Version: windows.SERVICE_NOTIFY_STATUS_CHANGE,
+ NotifyCallback: serviceTrackerCallbackPtr,
+ }
+ for {
+ err := windows.NotifyServiceStatusChange(m.Handle, windows.SERVICE_NOTIFY_CREATED, notifier)
+ if err == nil {
+ windows.SleepEx(windows.INFINITE, true)
+ if notifier.ServiceNames != nil {
+ windows.LocalFree(windows.Handle(unsafe.Pointer(notifier.ServiceNames)))
+ notifier.ServiceNames = nil
+ }
+ trackExistingTunnels()
+ } else if err == windows.ERROR_SERVICE_NOTIFY_CLIENT_LAGGING {
+ continue
+ } else {
+ time.Sleep(time.Second * 3)
+ trackExistingTunnels()
+ }
+ }
+ }()
+ return trackExistingTunnels()
+}
diff --git a/src/main/java/tech/nevets/tvpn/wg/WireGuardJNI.java b/src/main/java/tech/nevets/tvpn/wg/WireGuardJNI.java
index dc11f5e..0dfdae4 100644
--- a/src/main/java/tech/nevets/tvpn/wg/WireGuardJNI.java
+++ b/src/main/java/tech/nevets/tvpn/wg/WireGuardJNI.java
@@ -34,7 +34,7 @@ public class WireGuardJNI {
public native int initializeWireGuard(String configFilePath);
public native void cleanup();
- public native String getWGVersion();
+ public native long getWGVersion();
public native long parseConfig(WGConf conf);
// Other methods for WireGuard interaction can be added here if needed