From 0ca48ba6949264eec050b9d7a09fa85e603e1dc2 Mon Sep 17 00:00:00 2001 From: Steven Tracey Date: Tue, 29 Oct 2024 08:25:53 -0400 Subject: [PATCH] Added GoLang :sob: --- .idea/modules.xml | 8 - src/main/cpp/.idea/editor.xml | 580 ++++++++++++++++++ src/main/cpp/.idea/misc.xml | 3 + src/main/cpp/CMakeLists.txt | 3 +- src/main/go/.idea/.gitignore | 8 + src/main/go/.idea/vcs.xml | 4 + src/main/go/go.mod | 13 + src/main/go/go.sum | 8 + src/main/go/main.go | 45 ++ src/main/go/service/install.go | 117 ++++ src/main/go/service/ipc_client.go | 398 ++++++++++++ src/main/go/service/ipc_driver.go | 63 ++ src/main/go/service/ipc_server.go | 490 +++++++++++++++ src/main/go/service/tunneltracker.go | 345 +++++++++++ .../tech/nevets/tvpn/wg/WireGuardJNI.java | 2 +- 15 files changed, 2076 insertions(+), 11 deletions(-) delete mode 100644 .idea/modules.xml create mode 100644 src/main/cpp/.idea/editor.xml create mode 100644 src/main/go/.idea/.gitignore create mode 100644 src/main/go/.idea/vcs.xml create mode 100644 src/main/go/go.mod create mode 100644 src/main/go/go.sum create mode 100644 src/main/go/main.go create mode 100644 src/main/go/service/install.go create mode 100644 src/main/go/service/ipc_client.go create mode 100644 src/main/go/service/ipc_driver.go create mode 100644 src/main/go/service/ipc_server.go create mode 100644 src/main/go/service/tunneltracker.go diff --git a/.idea/modules.xml b/.idea/modules.xml deleted file mode 100644 index 053b796..0000000 --- a/.idea/modules.xml +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - - \ No newline at end of file diff --git a/src/main/cpp/.idea/editor.xml b/src/main/cpp/.idea/editor.xml new file mode 100644 index 0000000..226ca24 --- /dev/null +++ b/src/main/cpp/.idea/editor.xml @@ -0,0 +1,580 @@ + + + + + \ No newline at end of file diff --git a/src/main/cpp/.idea/misc.xml b/src/main/cpp/.idea/misc.xml index 79b3c94..0b76fe5 100644 --- a/src/main/cpp/.idea/misc.xml +++ b/src/main/cpp/.idea/misc.xml @@ -1,4 +1,7 @@ + + \ No newline at end of file diff --git a/src/main/cpp/CMakeLists.txt b/src/main/cpp/CMakeLists.txt index e5e7f5e..6862d06 100644 --- a/src/main/cpp/CMakeLists.txt +++ b/src/main/cpp/CMakeLists.txt @@ -15,8 +15,7 @@ set(JNI_INCLUDE_DIR2 "${JAVA_HOME}/include/win32") include_directories(${JNI_INCLUDE_DIR} ${JNI_INCLUDE_DIR2}) # Add the source file that includes your JNI C wrapper code -add_library(wireguard_wrapper SHARED wireguard_wrapper.c - example.c) +add_library(wireguard_wrapper SHARED wireguard_wrapper.c) #target_link_libraries(wireguard_wrapper PRIVATE Rpcrt4) target_link_libraries(wireguard_wrapper ws2_32 rpcrt4) diff --git a/src/main/go/.idea/.gitignore b/src/main/go/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/src/main/go/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/src/main/go/.idea/vcs.xml b/src/main/go/.idea/vcs.xml new file mode 100644 index 0000000..d843f34 --- /dev/null +++ b/src/main/go/.idea/vcs.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/src/main/go/go.mod b/src/main/go/go.mod new file mode 100644 index 0000000..0d63826 --- /dev/null +++ b/src/main/go/go.mod @@ -0,0 +1,13 @@ +module git.nevets.tech/Steven/TVPN + +go 1.23.2 + +require ( + golang.org/x/sys v0.26.0 + golang.zx2c4.com/wireguard/windows v0.5.3 +) + +require ( + golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 // indirect + golang.org/x/text v0.3.8-0.20211105212822-18b340fc7af2 // indirect +) diff --git a/src/main/go/go.sum b/src/main/go/go.sum new file mode 100644 index 0000000..e3a7754 --- /dev/null +++ b/src/main/go/go.sum @@ -0,0 +1,8 @@ +golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 h1:0es+/5331RGQPcXlMfP+WrnIIS6dNnNRe0WB02W0F4M= +golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= +golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.3.8-0.20211105212822-18b340fc7af2 h1:GLw7MR8AfAG2GmGcmVgObFOHXYypgGjnGno25RDwn3Y= +golang.org/x/text v0.3.8-0.20211105212822-18b340fc7af2/go.mod h1:EFNZuWvGYxIRUEX+K8UmCFwYmZjqcrnq15ZuVldZkZ0= +golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= +golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= diff --git a/src/main/go/main.go b/src/main/go/main.go new file mode 100644 index 0000000..28d0975 --- /dev/null +++ b/src/main/go/main.go @@ -0,0 +1,45 @@ +package main + +import ( + "crypto/rand" + "golang.org/x/crypto/curve25519" + "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/tunnel" + "log" + "path/filepath" + "unsafe" +) + +//export InstallTunnelService +func InstallTunnelService(confPath string) { + +} + +//export CreateTunnelService +func CreateTunnelService(confFile16 *uint16) bool { + confFile := windows.UTF16PtrToString(confFile16) + conf.PresetRootDirectory(filepath.Dir(confFile)) + tunnel.UseFixedGUIDInsteadOfDeterministic = true + err := tunnel.Run(confFile) + if err != nil { + log.Printf("Service run error: %v", err) + } + return err == nil +} + +//export GenerateKeyPair +func GenerateKeypair(publicKey, privateKey *byte) { + publicKeyArray := (*[32]byte)(unsafe.Pointer(publicKey)) + privateKeyArray := (*[32]byte)(unsafe.Pointer(privateKey)) + n, err := rand.Read(privateKeyArray[:]) + if err != nil || n != len(privateKeyArray) { + panic("Unable to generate random bytes") + } + privateKeyArray[0] &= 248 + privateKeyArray[31] = (privateKeyArray[31] & 127) | 64 + + curve25519.ScalarBaseMult(publicKeyArray, privateKeyArray) +} + +func main() {} diff --git a/src/main/go/service/install.go b/src/main/go/service/install.go new file mode 100644 index 0000000..d16bbec --- /dev/null +++ b/src/main/go/service/install.go @@ -0,0 +1,117 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + * + * Modified by Steven Tracey + */ + +package service + +import ( + "errors" + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/mgr" + "golang.zx2c4.com/wireguard/windows/conf" + "os" + "time" +) + +var cachedServiceManager *mgr.Mgr + +func serviceManager() (*mgr.Mgr, error) { + if cachedServiceManager != nil { + return cachedServiceManager, nil + } + m, err := mgr.Connect() + if err != nil { + return nil, err + } + cachedServiceManager = m + return cachedServiceManager, nil +} + +func InstallTunnel(configPath string) error { + m, err := serviceManager() + if err != nil { + return err + } + path, err := os.Executable() + if err != nil { + return nil + } + + name, err := conf.NameFromPath(configPath) + if err != nil { + return err + } + + serviceName, err := conf.ServiceNameOfTunnel(name) + if err != nil { + return err + } + service, err := m.OpenService(serviceName) + if err == nil { + status, err := service.Query() + if err != nil && err != windows.ERROR_SERVICE_MARKED_FOR_DELETE { + service.Close() + return err + } + if status.State != svc.Stopped && err != windows.ERROR_SERVICE_MARKED_FOR_DELETE { + service.Close() + return errors.New("Tunnel already installed and running") + } + err = service.Delete() + service.Close() + if err != nil && err != windows.ERROR_SERVICE_MARKED_FOR_DELETE { + return err + } + for { + service, err = m.OpenService(serviceName) + if err != nil && err != windows.ERROR_SERVICE_MARKED_FOR_DELETE { + break + } + service.Close() + time.Sleep(time.Second / 3) + } + } + + config := mgr.Config{ + ServiceType: windows.SERVICE_WIN32_OWN_PROCESS, + StartType: mgr.StartAutomatic, + ErrorControl: mgr.ErrorNormal, + Dependencies: []string{"Nsi", "TcpIp"}, + DisplayName: "WireGuard Tunnel: " + name, + SidType: windows.SERVICE_SID_TYPE_UNRESTRICTED, + } + service, err = m.CreateService(serviceName, path, config, "/tunnelservice", configPath) + if err != nil { + return err + } + + err = service.Start() + go trackTunnelService(name, service) // Pass off reference to handle. + return err +} + +func UninstallTunnel(name string) error { + m, err := serviceManager() + if err != nil { + return err + } + serviceName, err := conf.ServiceNameOfTunnel(name) + if err != nil { + return err + } + service, err := m.OpenService(serviceName) + if err != nil { + return err + } + service.Control(svc.Stop) + err = service.Delete() + err2 := service.Close() + if err != nil && err != windows.ERROR_SERVICE_MARKED_FOR_DELETE { + return err + } + return err2 +} diff --git a/src/main/go/service/ipc_client.go b/src/main/go/service/ipc_client.go new file mode 100644 index 0000000..261cbad --- /dev/null +++ b/src/main/go/service/ipc_client.go @@ -0,0 +1,398 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + * + * Modified by Steven Tracey + */ + +package service + +import ( + "encoding/gob" + "errors" + "os" + "sync" + + "golang.zx2c4.com/wireguard/windows/conf" +) + +type Tunnel struct { + Name string +} + +type TunnelState int + +const ( + TunnelUnknown TunnelState = iota + TunnelStarted + TunnelStopped + TunnelStarting + TunnelStopping +) + +type NotificationType int + +const ( + TunnelChangeNotificationType NotificationType = iota + TunnelsChangeNotificationType + ManagerStoppingNotificationType + UpdateFoundNotificationType + UpdateProgressNotificationType +) + +type MethodType int + +const ( + StoredConfigMethodType MethodType = iota + RuntimeConfigMethodType + StartMethodType + StopMethodType + WaitForStopMethodType + DeleteMethodType + StateMethodType + GlobalStateMethodType + CreateMethodType + TunnelsMethodType + QuitMethodType + UpdateStateMethodType + UpdateMethodType +) + +var ( + rpcEncoder *gob.Encoder + rpcDecoder *gob.Decoder + rpcMutex sync.Mutex +) + +type TunnelChangeCallback struct { + cb func(tunnel *Tunnel, state, globalState TunnelState, err error) +} + +var tunnelChangeCallbacks = make(map[*TunnelChangeCallback]bool) + +type TunnelsChangeCallback struct { + cb func() +} + +var tunnelsChangeCallbacks = make(map[*TunnelsChangeCallback]bool) + +type ManagerStoppingCallback struct { + cb func() +} + +var managerStoppingCallbacks = make(map[*ManagerStoppingCallback]bool) + +func InitializeIPCClient(reader, writer, events *os.File) { + rpcDecoder = gob.NewDecoder(reader) + rpcEncoder = gob.NewEncoder(writer) + go func() { + decoder := gob.NewDecoder(events) + for { + var notificationType NotificationType + err := decoder.Decode(¬ificationType) + if err != nil { + return + } + switch notificationType { + case TunnelChangeNotificationType: + var tunnel string + err := decoder.Decode(&tunnel) + if err != nil || len(tunnel) == 0 { + continue + } + var state TunnelState + err = decoder.Decode(&state) + if err != nil { + continue + } + var globalState TunnelState + err = decoder.Decode(&globalState) + if err != nil { + continue + } + var errStr string + err = decoder.Decode(&errStr) + if err != nil { + continue + } + var retErr error + if len(errStr) > 0 { + retErr = errors.New(errStr) + } + if state == TunnelUnknown { + continue + } + t := &Tunnel{tunnel} + for cb := range tunnelChangeCallbacks { + cb.cb(t, state, globalState, retErr) + } + case TunnelsChangeNotificationType: + for cb := range tunnelsChangeCallbacks { + cb.cb() + } + case ManagerStoppingNotificationType: + for cb := range managerStoppingCallbacks { + cb.cb() + } + } + } + }() +} + +func rpcDecodeError() error { + var str string + err := rpcDecoder.Decode(&str) + if err != nil { + return err + } + if len(str) == 0 { + return nil + } + return errors.New(str) +} + +func (t *Tunnel) StoredConfig() (c conf.Config, err error) { + rpcMutex.Lock() + defer rpcMutex.Unlock() + + err = rpcEncoder.Encode(StoredConfigMethodType) + if err != nil { + return + } + err = rpcEncoder.Encode(t.Name) + if err != nil { + return + } + err = rpcDecoder.Decode(&c) + if err != nil { + return + } + err = rpcDecodeError() + return +} + +func (t *Tunnel) RuntimeConfig() (c conf.Config, err error) { + rpcMutex.Lock() + defer rpcMutex.Unlock() + + err = rpcEncoder.Encode(RuntimeConfigMethodType) + if err != nil { + return + } + err = rpcEncoder.Encode(t.Name) + if err != nil { + return + } + err = rpcDecoder.Decode(&c) + if err != nil { + return + } + err = rpcDecodeError() + return +} + +func (t *Tunnel) Start() (err error) { + rpcMutex.Lock() + defer rpcMutex.Unlock() + + err = rpcEncoder.Encode(StartMethodType) + if err != nil { + return + } + err = rpcEncoder.Encode(t.Name) + if err != nil { + return + } + err = rpcDecodeError() + return +} + +func (t *Tunnel) Stop() (err error) { + rpcMutex.Lock() + defer rpcMutex.Unlock() + + err = rpcEncoder.Encode(StopMethodType) + if err != nil { + return + } + err = rpcEncoder.Encode(t.Name) + if err != nil { + return + } + err = rpcDecodeError() + return +} + +func (t *Tunnel) Toggle() (oldState TunnelState, err error) { + oldState, err = t.State() + if err != nil { + oldState = TunnelUnknown + return + } + if oldState == TunnelStarted { + err = t.Stop() + } else if oldState == TunnelStopped { + err = t.Start() + } + return +} + +func (t *Tunnel) WaitForStop() (err error) { + rpcMutex.Lock() + defer rpcMutex.Unlock() + + err = rpcEncoder.Encode(WaitForStopMethodType) + if err != nil { + return + } + err = rpcEncoder.Encode(t.Name) + if err != nil { + return + } + err = rpcDecodeError() + return +} + +func (t *Tunnel) Delete() (err error) { + rpcMutex.Lock() + defer rpcMutex.Unlock() + + err = rpcEncoder.Encode(DeleteMethodType) + if err != nil { + return + } + err = rpcEncoder.Encode(t.Name) + if err != nil { + return + } + err = rpcDecodeError() + return +} + +func (t *Tunnel) State() (tunnelState TunnelState, err error) { + rpcMutex.Lock() + defer rpcMutex.Unlock() + + err = rpcEncoder.Encode(StateMethodType) + if err != nil { + return + } + err = rpcEncoder.Encode(t.Name) + if err != nil { + return + } + err = rpcDecoder.Decode(&tunnelState) + if err != nil { + return + } + err = rpcDecodeError() + return +} + +func IPCClientGlobalState() (tunnelState TunnelState, err error) { + rpcMutex.Lock() + defer rpcMutex.Unlock() + + err = rpcEncoder.Encode(GlobalStateMethodType) + if err != nil { + return + } + err = rpcDecoder.Decode(&tunnelState) + if err != nil { + return + } + return +} + +func IPCClientNewTunnel(conf *conf.Config) (tunnel Tunnel, err error) { + rpcMutex.Lock() + defer rpcMutex.Unlock() + + err = rpcEncoder.Encode(CreateMethodType) + if err != nil { + return + } + err = rpcEncoder.Encode(*conf) + if err != nil { + return + } + err = rpcDecoder.Decode(&tunnel) + if err != nil { + return + } + err = rpcDecodeError() + return +} + +func IPCClientTunnels() (tunnels []Tunnel, err error) { + rpcMutex.Lock() + defer rpcMutex.Unlock() + + err = rpcEncoder.Encode(TunnelsMethodType) + if err != nil { + return + } + err = rpcDecoder.Decode(&tunnels) + if err != nil { + return + } + err = rpcDecodeError() + return +} + +func IPCClientQuit(stopTunnelsOnQuit bool) (alreadyQuit bool, err error) { + rpcMutex.Lock() + defer rpcMutex.Unlock() + + err = rpcEncoder.Encode(QuitMethodType) + if err != nil { + return + } + err = rpcEncoder.Encode(stopTunnelsOnQuit) + if err != nil { + return + } + err = rpcDecoder.Decode(&alreadyQuit) + if err != nil { + return + } + err = rpcDecodeError() + return +} + +func IPCClientUpdate() error { + rpcMutex.Lock() + defer rpcMutex.Unlock() + + return rpcEncoder.Encode(UpdateMethodType) +} + +func IPCClientRegisterTunnelChange(cb func(tunnel *Tunnel, state, globalState TunnelState, err error)) *TunnelChangeCallback { + s := &TunnelChangeCallback{cb} + tunnelChangeCallbacks[s] = true + return s +} + +func (cb *TunnelChangeCallback) Unregister() { + delete(tunnelChangeCallbacks, cb) +} + +func IPCClientRegisterTunnelsChange(cb func()) *TunnelsChangeCallback { + s := &TunnelsChangeCallback{cb} + tunnelsChangeCallbacks[s] = true + return s +} + +func (cb *TunnelsChangeCallback) Unregister() { + delete(tunnelsChangeCallbacks, cb) +} + +func IPCClientRegisterManagerStopping(cb func()) *ManagerStoppingCallback { + s := &ManagerStoppingCallback{cb} + managerStoppingCallbacks[s] = true + return s +} + +func (cb *ManagerStoppingCallback) Unregister() { + delete(managerStoppingCallbacks, cb) +} diff --git a/src/main/go/service/ipc_driver.go b/src/main/go/service/ipc_driver.go new file mode 100644 index 0000000..baa0c00 --- /dev/null +++ b/src/main/go/service/ipc_driver.go @@ -0,0 +1,63 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + * + * Modified by Steven Tracey + */ + +package service + +import ( + "sync" + + "golang.zx2c4.com/wireguard/windows/driver" +) + +type lockedDriverAdapter struct { + *driver.Adapter + sync.Mutex +} + +var ( + driverAdapters = make(map[string]*lockedDriverAdapter) + driverAdaptersLock sync.RWMutex +) + +func findDriverAdapter(tunnelName string) (*lockedDriverAdapter, error) { + driverAdaptersLock.RLock() + driverAdapter, ok := driverAdapters[tunnelName] + if ok { + driverAdapter.Lock() + driverAdaptersLock.RUnlock() + return driverAdapter, nil + } + driverAdaptersLock.RUnlock() + driverAdaptersLock.Lock() + defer driverAdaptersLock.Unlock() + driverAdapter, ok = driverAdapters[tunnelName] + if ok { + driverAdapter.Lock() + return driverAdapter, nil + } + driverAdapter = &lockedDriverAdapter{} + var err error + driverAdapter.Adapter, err = driver.OpenAdapter(tunnelName) + if err != nil { + return nil, err + } + driverAdapters[tunnelName] = driverAdapter + driverAdapter.Lock() + return driverAdapter, nil +} + +func releaseDriverAdapter(tunnelName string) { + driverAdaptersLock.Lock() + defer driverAdaptersLock.Unlock() + driverAdapter, ok := driverAdapters[tunnelName] + if !ok { + return + } + driverAdapter.Lock() + delete(driverAdapters, tunnelName) + driverAdapter.Unlock() +} diff --git a/src/main/go/service/ipc_server.go b/src/main/go/service/ipc_server.go new file mode 100644 index 0000000..445218c --- /dev/null +++ b/src/main/go/service/ipc_server.go @@ -0,0 +1,490 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + * + * Modified by Steven Tracey + */ + +package service + +import ( + "bytes" + "encoding/gob" + "fmt" + "io" + "log" + "os" + "sync" + "sync/atomic" + "time" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc" + + "golang.zx2c4.com/wireguard/windows/conf" +) + +var ( + managerServices = make(map[*ManagerService]bool) + managerServicesLock sync.RWMutex + haveQuit uint32 + quitManagersChan = make(chan struct{}, 1) +) + +type ManagerService struct { + events *os.File + eventLock sync.Mutex + elevatedToken windows.Token +} + +func (s *ManagerService) StoredConfig(tunnelName string) (*conf.Config, error) { + conf, err := conf.LoadFromName(tunnelName) + if err != nil { + return nil, err + } + if s.elevatedToken == 0 { + conf.Redact() + } + return conf, nil +} + +func (s *ManagerService) RuntimeConfig(tunnelName string) (*conf.Config, error) { + storedConfig, err := conf.LoadFromName(tunnelName) + if err != nil { + return nil, err + } + driverAdapter, err := findDriverAdapter(tunnelName) + if err != nil { + return nil, err + } + runtimeConfig, err := driverAdapter.Configuration() + if err != nil { + driverAdapter.Unlock() + releaseDriverAdapter(tunnelName) + return nil, err + } + conf := conf.FromDriverConfiguration(runtimeConfig, storedConfig) + driverAdapter.Unlock() + if s.elevatedToken == 0 { + conf.Redact() + } + return conf, nil +} + +func (s *ManagerService) Start(tunnelName string) error { + c, err := conf.LoadFromName(tunnelName) + if err != nil { + return err + } + + // Figure out which tunnels have intersecting addresses/routes and stop those. + trackedTunnelsLock.Lock() + tt := make([]string, 0, len(trackedTunnels)) + var inTransition string + for t, state := range trackedTunnels { + c2, err := conf.LoadFromName(t) + if err != nil || !c.IntersectsWith(c2) { + // If we can't get the config, assume it doesn't intersect. + continue + } + tt = append(tt, t) + if len(t) > 0 && (state == TunnelStarting || state == TunnelUnknown) { + inTransition = t + break + } + } + trackedTunnelsLock.Unlock() + if len(inTransition) != 0 { + return fmt.Errorf("Please allow the tunnel ā€˜%sā€™ to finish activating", inTransition) + } + + // Stop those intersecting tunnels asynchronously. + go func() { + for _, t := range tt { + s.Stop(t) + } + for _, t := range tt { + state, err := s.State(t) + if err == nil && (state == TunnelStarted || state == TunnelStarting) { + log.Printf("[%s] Trying again to stop zombie tunnel", t) + s.Stop(t) + time.Sleep(time.Millisecond * 100) + } + } + }() + // After the stop process has begun, but before it's finished, we install the new one. + path, err := c.Path() + if err != nil { + return err + } + return InstallTunnel(path) +} + +func (s *ManagerService) Stop(tunnelName string) error { + err := UninstallTunnel(tunnelName) + if err == windows.ERROR_SERVICE_DOES_NOT_EXIST { + _, notExistsError := conf.LoadFromName(tunnelName) + if notExistsError == nil { + return nil + } + } + return err +} + +func (s *ManagerService) WaitForStop(tunnelName string) error { + serviceName, err := conf.ServiceNameOfTunnel(tunnelName) + if err != nil { + return err + } + m, err := serviceManager() + if err != nil { + return err + } + for { + service, err := m.OpenService(serviceName) + if err == nil || err == windows.ERROR_SERVICE_MARKED_FOR_DELETE { + service.Close() + time.Sleep(time.Second / 3) + } else { + return nil + } + } +} + +func (s *ManagerService) Delete(tunnelName string) error { + if s.elevatedToken == 0 { + return windows.ERROR_ACCESS_DENIED + } + err := s.Stop(tunnelName) + if err != nil { + return err + } + return conf.DeleteName(tunnelName) +} + +func (s *ManagerService) State(tunnelName string) (TunnelState, error) { + serviceName, err := conf.ServiceNameOfTunnel(tunnelName) + if err != nil { + return 0, err + } + m, err := serviceManager() + if err != nil { + return 0, err + } + service, err := m.OpenService(serviceName) + if err != nil { + return TunnelStopped, nil + } + defer service.Close() + status, err := service.Query() + if err != nil { + return TunnelUnknown, nil + } + switch status.State { + case svc.Stopped: + return TunnelStopped, nil + case svc.StopPending: + return TunnelStopping, nil + case svc.Running: + return TunnelStarted, nil + case svc.StartPending: + return TunnelStarting, nil + default: + return TunnelUnknown, nil + } +} + +func (s *ManagerService) GlobalState() TunnelState { + return trackedTunnelsGlobalState() +} + +func (s *ManagerService) Create(tunnelConfig *conf.Config) (*Tunnel, error) { + if s.elevatedToken == 0 { + return nil, windows.ERROR_ACCESS_DENIED + } + err := tunnelConfig.Save(true) + if err != nil { + return nil, err + } + return &Tunnel{tunnelConfig.Name}, nil + // TODO: handle already existing situation + // TODO: handle already running and existing situation +} + +func (s *ManagerService) Tunnels() ([]Tunnel, error) { + names, err := conf.ListConfigNames() + if err != nil { + return nil, err + } + tunnels := make([]Tunnel, len(names)) + for i := 0; i < len(tunnels); i++ { + tunnels[i].Name = names[i] + } + return tunnels, nil + // TODO: account for running ones that aren't in the configuration store somehow +} + +func (s *ManagerService) Quit(stopTunnelsOnQuit bool) (alreadyQuit bool, err error) { + if s.elevatedToken == 0 { + return false, windows.ERROR_ACCESS_DENIED + } + if !atomic.CompareAndSwapUint32(&haveQuit, 0, 1) { + return true, nil + } + + // Work around potential race condition of delivering messages to the wrong process by removing from notifications. + managerServicesLock.Lock() + s.eventLock.Lock() + s.events = nil + s.eventLock.Unlock() + delete(managerServices, s) + managerServicesLock.Unlock() + + if stopTunnelsOnQuit { + names, err := conf.ListConfigNames() + if err != nil { + return false, err + } + for _, name := range names { + UninstallTunnel(name) + } + } + + quitManagersChan <- struct{}{} + return false, nil +} + +func (s *ManagerService) ServeConn(reader io.Reader, writer io.Writer) { + decoder := gob.NewDecoder(reader) + encoder := gob.NewEncoder(writer) + for { + var methodType MethodType + err := decoder.Decode(&methodType) + if err != nil { + return + } + switch methodType { + case StoredConfigMethodType: + var tunnelName string + err := decoder.Decode(&tunnelName) + if err != nil { + return + } + config, retErr := s.StoredConfig(tunnelName) + if config == nil { + config = &conf.Config{} + } + err = encoder.Encode(*config) + if err != nil { + return + } + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case RuntimeConfigMethodType: + var tunnelName string + err := decoder.Decode(&tunnelName) + if err != nil { + return + } + config, retErr := s.RuntimeConfig(tunnelName) + if config == nil { + config = &conf.Config{} + } + err = encoder.Encode(*config) + if err != nil { + return + } + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case StartMethodType: + var tunnelName string + err := decoder.Decode(&tunnelName) + if err != nil { + return + } + retErr := s.Start(tunnelName) + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case StopMethodType: + var tunnelName string + err := decoder.Decode(&tunnelName) + if err != nil { + return + } + retErr := s.Stop(tunnelName) + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case WaitForStopMethodType: + var tunnelName string + err := decoder.Decode(&tunnelName) + if err != nil { + return + } + retErr := s.WaitForStop(tunnelName) + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case DeleteMethodType: + var tunnelName string + err := decoder.Decode(&tunnelName) + if err != nil { + return + } + retErr := s.Delete(tunnelName) + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case StateMethodType: + var tunnelName string + err := decoder.Decode(&tunnelName) + if err != nil { + return + } + state, retErr := s.State(tunnelName) + err = encoder.Encode(state) + if err != nil { + return + } + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case GlobalStateMethodType: + state := s.GlobalState() + err = encoder.Encode(state) + if err != nil { + return + } + case CreateMethodType: + var config conf.Config + err := decoder.Decode(&config) + if err != nil { + return + } + tunnel, retErr := s.Create(&config) + if tunnel == nil { + tunnel = &Tunnel{} + } + err = encoder.Encode(tunnel) + if err != nil { + return + } + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case TunnelsMethodType: + tunnels, retErr := s.Tunnels() + err = encoder.Encode(tunnels) + if err != nil { + return + } + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + case QuitMethodType: + var stopTunnelsOnQuit bool + err := decoder.Decode(&stopTunnelsOnQuit) + if err != nil { + return + } + alreadyQuit, retErr := s.Quit(stopTunnelsOnQuit) + err = encoder.Encode(alreadyQuit) + if err != nil { + return + } + err = encoder.Encode(errToString(retErr)) + if err != nil { + return + } + default: + return + } + } +} + +func IPCServerListen(reader, writer, events *os.File, elevatedToken windows.Token) { + service := &ManagerService{ + events: events, + elevatedToken: elevatedToken, + } + + go func() { + managerServicesLock.Lock() + managerServices[service] = true + managerServicesLock.Unlock() + service.ServeConn(reader, writer) + managerServicesLock.Lock() + service.eventLock.Lock() + service.events = nil + service.eventLock.Unlock() + delete(managerServices, service) + managerServicesLock.Unlock() + }() +} + +func notifyAll(notificationType NotificationType, adminOnly bool, ifaces ...any) { + if len(managerServices) == 0 { + return + } + + var buf bytes.Buffer + encoder := gob.NewEncoder(&buf) + err := encoder.Encode(notificationType) + if err != nil { + return + } + for _, iface := range ifaces { + err = encoder.Encode(iface) + if err != nil { + return + } + } + + managerServicesLock.RLock() + for m := range managerServices { + if m.elevatedToken == 0 && adminOnly { + continue + } + go func(m *ManagerService) { + m.eventLock.Lock() + defer m.eventLock.Unlock() + if m.events != nil { + m.events.SetWriteDeadline(time.Now().Add(time.Second)) + m.events.Write(buf.Bytes()) + } + }(m) + } + managerServicesLock.RUnlock() +} + +func errToString(err error) string { + if err == nil { + return "" + } + return err.Error() +} + +func IPCServerNotifyTunnelChange(name string, state TunnelState, err error) { + notifyAll(TunnelChangeNotificationType, false, name, state, trackedTunnelsGlobalState(), errToString(err)) +} + +func IPCServerNotifyTunnelsChange() { + notifyAll(TunnelsChangeNotificationType, false) +} + +func IPCServerNotifyManagerStopping() { + notifyAll(ManagerStoppingNotificationType, false) + time.Sleep(time.Millisecond * 200) +} diff --git a/src/main/go/service/tunneltracker.go b/src/main/go/service/tunneltracker.go new file mode 100644 index 0000000..0c5d51e --- /dev/null +++ b/src/main/go/service/tunneltracker.go @@ -0,0 +1,345 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + * + * Modified by Steven Tracey + */ + +package service + +import ( + "errors" + "fmt" + "golang.zx2c4.com/wireguard/windows/services" + "log" + "runtime" + "sync" + "sync/atomic" + "syscall" + "time" + "unsafe" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/mgr" + + "golang.zx2c4.com/wireguard/windows/conf" +) + +var ( + trackedTunnels = make(map[string]TunnelState) + trackedTunnelsLock = sync.Mutex{} +) + +func trackedTunnelsGlobalState() (state TunnelState) { + state = TunnelStopped + trackedTunnelsLock.Lock() + defer trackedTunnelsLock.Unlock() + for _, s := range trackedTunnels { + if s == TunnelStarting { + return TunnelStarting + } else if s == TunnelStopping { + return TunnelStopping + } else if s == TunnelStarted || s == TunnelUnknown { + state = TunnelStarted + } + } + return +} + +var serviceTrackerCallbackPtr = windows.NewCallback(func(notifier *windows.SERVICE_NOTIFY) uintptr { + return 0 +}) + +type serviceSubscriptionState struct { + service *mgr.Service + cb func(status uint32) bool + done sync.WaitGroup + once uint32 +} + +var serviceSubscriptionCallbackPtr = windows.NewCallback(func(notification uint32, context uintptr) uintptr { + state := (*serviceSubscriptionState)(unsafe.Pointer(context)) + if atomic.LoadUint32(&state.once) != 0 { + return 0 + } + if notification == 0 { + status, err := state.service.Query() + if err == nil { + notification = svcStateToNotifyState(uint32(status.State)) + } + } + if state.cb(notification) && atomic.CompareAndSwapUint32(&state.once, 0, 1) { + state.done.Done() + } + return 0 +}) + +func svcStateToNotifyState(s uint32) uint32 { + switch s { + case windows.SERVICE_STOPPED: + return windows.SERVICE_NOTIFY_STOPPED + case windows.SERVICE_START_PENDING: + return windows.SERVICE_NOTIFY_START_PENDING + case windows.SERVICE_STOP_PENDING: + return windows.SERVICE_NOTIFY_STOP_PENDING + case windows.SERVICE_RUNNING: + return windows.SERVICE_NOTIFY_RUNNING + case windows.SERVICE_CONTINUE_PENDING: + return windows.SERVICE_NOTIFY_CONTINUE_PENDING + case windows.SERVICE_PAUSE_PENDING: + return windows.SERVICE_NOTIFY_PAUSE_PENDING + case windows.SERVICE_PAUSED: + return windows.SERVICE_NOTIFY_PAUSED + case windows.SERVICE_NO_CHANGE: + return 0 + default: + return 0 + } +} + +func notifyStateToTunState(s uint32) TunnelState { + if s&(windows.SERVICE_NOTIFY_STOPPED|windows.SERVICE_NOTIFY_DELETED) != 0 { + return TunnelStopped + } else if s&(windows.SERVICE_NOTIFY_DELETE_PENDING|windows.SERVICE_NOTIFY_STOP_PENDING) != 0 { + return TunnelStopping + } else if s&windows.SERVICE_NOTIFY_RUNNING != 0 { + return TunnelStarted + } else if s&windows.SERVICE_NOTIFY_START_PENDING != 0 { + return TunnelStarting + } else { + return TunnelUnknown + } +} + +func trackService(service *mgr.Service, callback func(status uint32) bool) error { + var subscription uintptr + state := &serviceSubscriptionState{service: service, cb: callback} + state.done.Add(1) + err := windows.SubscribeServiceChangeNotifications(service.Handle, windows.SC_EVENT_STATUS_CHANGE, serviceSubscriptionCallbackPtr, uintptr(unsafe.Pointer(state)), &subscription) + if err == nil { + defer windows.UnsubscribeServiceChangeNotifications(subscription) + status, err := service.Query() + if err == nil { + if callback(svcStateToNotifyState(uint32(status.State))) { + return nil + } + } + state.done.Wait() + runtime.KeepAlive(state.cb) + return nil + } + if !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) { + return err + } + + // TODO: Below this line is Windows 7 compatibility code, which hopefully we can delete at some point. + + runtime.LockOSThread() + // This line would be fitting but is intentionally commented out: + // + // defer runtime.UnlockOSThread() + // + // The reason is that NotifyServiceStatusChange used queued APC, which winds up messing + // with the thread local context, which in turn appears to corrupt Go's own usage of TLS, + // leading to crashes sometime later (usually in runtime_unlock()) when the thread is recycled. + + const serviceNotifications = windows.SERVICE_NOTIFY_RUNNING | windows.SERVICE_NOTIFY_START_PENDING | windows.SERVICE_NOTIFY_STOP_PENDING | windows.SERVICE_NOTIFY_STOPPED | windows.SERVICE_NOTIFY_DELETE_PENDING + notifier := &windows.SERVICE_NOTIFY{ + Version: windows.SERVICE_NOTIFY_STATUS_CHANGE, + NotifyCallback: serviceTrackerCallbackPtr, + } + for { + err := windows.NotifyServiceStatusChange(service.Handle, serviceNotifications, notifier) + switch err { + case nil: + for { + if windows.SleepEx(uint32(time.Second*3/time.Millisecond), true) == windows.WAIT_IO_COMPLETION { + break + } else if callback(0) { + return nil + } + } + case windows.ERROR_SERVICE_MARKED_FOR_DELETE: + // Should be SERVICE_NOTIFY_DELETE_PENDING, but actually, we must release the handle and return here; otherwise it never deletes. + if callback(windows.SERVICE_NOTIFY_DELETED) { + return nil + } + case windows.ERROR_SERVICE_NOTIFY_CLIENT_LAGGING: + continue + default: + return err + } + if callback(svcStateToNotifyState(notifier.ServiceStatus.CurrentState)) { + return nil + } + } +} + +func trackTunnelService(tunnelName string, service *mgr.Service) { + trackedTunnelsLock.Lock() + if _, found := trackedTunnels[tunnelName]; found { + trackedTunnelsLock.Unlock() + service.Close() + return + } + + defer func() { + service.Close() + log.Printf("[%s] Tunnel service tracker finished", tunnelName) + }() + trackedTunnels[tunnelName] = TunnelUnknown + trackedTunnelsLock.Unlock() + defer func() { + trackedTunnelsLock.Lock() + delete(trackedTunnels, tunnelName) + trackedTunnelsLock.Unlock() + }() + + for i := 0; i < 20; i++ { + if i > 0 { + time.Sleep(time.Second / 5) + } + if status, err := service.Query(); err != nil || status.State != svc.Stopped { + break + } + } + + checkForDisabled := func() (shouldReturn bool) { + config, err := service.Config() + if err == windows.ERROR_SERVICE_MARKED_FOR_DELETE || (err != nil && config.StartType == windows.SERVICE_DISABLED) { + log.Printf("[%s] Found disabled service via timeout, so deleting", tunnelName) + service.Delete() + trackedTunnelsLock.Lock() + trackedTunnels[tunnelName] = TunnelStopped + trackedTunnelsLock.Unlock() + IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, nil) + return true + } + return false + } + if checkForDisabled() { + return + } + lastState := TunnelUnknown + err := trackService(service, func(status uint32) bool { + state := notifyStateToTunState(status) + var tunnelError error + if state == TunnelStopped { + serviceStatus, err := service.Query() + if err == nil { + if serviceStatus.Win32ExitCode == uint32(windows.ERROR_SERVICE_SPECIFIC_ERROR) { + maybeErr := services.Error(serviceStatus.ServiceSpecificExitCode) + if maybeErr != services.ErrorSuccess { + tunnelError = maybeErr + } + } else { + switch serviceStatus.Win32ExitCode { + case uint32(windows.NO_ERROR), uint32(windows.ERROR_SERVICE_NEVER_STARTED): + default: + tunnelError = syscall.Errno(serviceStatus.Win32ExitCode) + } + } + } + if tunnelError != nil { + service.Delete() + } + } + if state != lastState { + trackedTunnelsLock.Lock() + trackedTunnels[tunnelName] = state + trackedTunnelsLock.Unlock() + IPCServerNotifyTunnelChange(tunnelName, state, tunnelError) + lastState = state + } + if state == TunnelUnknown && checkForDisabled() { + return true + } + return state == TunnelStopped + }) + if err != nil && !checkForDisabled() { + trackedTunnelsLock.Lock() + trackedTunnels[tunnelName] = TunnelStopped + trackedTunnelsLock.Unlock() + IPCServerNotifyTunnelChange(tunnelName, TunnelStopped, fmt.Errorf("Unable to continue monitoring service, so stopping: %w", err)) + service.Control(svc.Stop) + } +} + +func trackExistingTunnels() error { + m, err := serviceManager() + if err != nil { + return err + } + names, err := conf.ListConfigNames() + if err != nil { + return err + } + for _, name := range names { + trackedTunnelsLock.Lock() + if _, found := trackedTunnels[name]; found { + trackedTunnelsLock.Unlock() + continue + } + trackedTunnelsLock.Unlock() + serviceName, err := conf.ServiceNameOfTunnel(name) + if err != nil { + continue + } + service, err := m.OpenService(serviceName) + if err != nil { + continue + } + go trackTunnelService(name, service) + } + return nil +} + +var servicesSubscriptionWatcherCallbackPtr = windows.NewCallback(func(notification uint32, context uintptr) uintptr { + trackExistingTunnels() + return 0 +}) + +func watchNewTunnelServices() error { + m, err := serviceManager() + if err != nil { + return err + } + var subscription uintptr + err = windows.SubscribeServiceChangeNotifications(m.Handle, windows.SC_EVENT_DATABASE_CHANGE, servicesSubscriptionWatcherCallbackPtr, 0, &subscription) + if err == nil { + // We probably could do: + // defer windows.UnsubscribeServiceChangeNotifications(subscription) + // and then terminate after some point, but instead we just let this go forever; it's process-lived. + return trackExistingTunnels() + } + if !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) { + return err + } + + // TODO: Below this line is Windows 7 compatibility code, which hopefully we can delete at some point. + go func() { + runtime.LockOSThread() + notifier := &windows.SERVICE_NOTIFY{ + Version: windows.SERVICE_NOTIFY_STATUS_CHANGE, + NotifyCallback: serviceTrackerCallbackPtr, + } + for { + err := windows.NotifyServiceStatusChange(m.Handle, windows.SERVICE_NOTIFY_CREATED, notifier) + if err == nil { + windows.SleepEx(windows.INFINITE, true) + if notifier.ServiceNames != nil { + windows.LocalFree(windows.Handle(unsafe.Pointer(notifier.ServiceNames))) + notifier.ServiceNames = nil + } + trackExistingTunnels() + } else if err == windows.ERROR_SERVICE_NOTIFY_CLIENT_LAGGING { + continue + } else { + time.Sleep(time.Second * 3) + trackExistingTunnels() + } + } + }() + return trackExistingTunnels() +} diff --git a/src/main/java/tech/nevets/tvpn/wg/WireGuardJNI.java b/src/main/java/tech/nevets/tvpn/wg/WireGuardJNI.java index dc11f5e..0dfdae4 100644 --- a/src/main/java/tech/nevets/tvpn/wg/WireGuardJNI.java +++ b/src/main/java/tech/nevets/tvpn/wg/WireGuardJNI.java @@ -34,7 +34,7 @@ public class WireGuardJNI { public native int initializeWireGuard(String configFilePath); public native void cleanup(); - public native String getWGVersion(); + public native long getWGVersion(); public native long parseConfig(WGConf conf); // Other methods for WireGuard interaction can be added here if needed