diff --git a/cmd/launcher/gui/gui_windows.go b/cmd/launcher/gui/gui_windows.go index 62508f8..3f0be53 100644 --- a/cmd/launcher/gui/gui_windows.go +++ b/cmd/launcher/gui/gui_windows.go @@ -1,7 +1,6 @@ package gui import ( - "unicode/utf16" "unsafe" "github.com/setlog/trivrost/pkg/system" @@ -145,7 +144,7 @@ func applyWindowStyle(handle uintptr) { } func loadIcons() { - binaryPath := goStringToConstantUTF16WinApiString(system.GetBinaryPath()) + binaryPath := C.LPCWSTR(system.StringToUTF16UnmanagedString(system.GetBinaryPath())) extractedIconCount := C.loadIcons(binaryPath) didLoadIcons = true C.free(unsafe.Pointer(binaryPath)) @@ -160,17 +159,6 @@ func loadIcons() { } } -func goStringToConstantUTF16WinApiString(s string) C.LPCWSTR { - utf16String := utf16.Encode([]rune(s)) - utf16StringPointer := (*uint16)(C.calloc(C.size_t(len(utf16String)+1), C.size_t(unsafe.Sizeof(uint16(0))))) - currentCharPointer := utf16StringPointer - for _, c := range utf16String { - *currentCharPointer = c - currentCharPointer = (*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(currentCharPointer)) + unsafe.Sizeof(uint16(0)))) - } - return (C.LPCWSTR)(unsafe.Pointer(utf16StringPointer)) -} - func setProgressState(s progressState) { C.setProgressBarState(C.ULONG_PTR(panelDownloadStatus.barTotalProgress.Handle()), C.int(s)) } diff --git a/cmd/launcher/launcher/install.go b/cmd/launcher/launcher/install.go index 365364c..ed1f079 100644 --- a/cmd/launcher/launcher/install.go +++ b/cmd/launcher/launcher/install.go @@ -50,7 +50,7 @@ func IsInstanceInstalledInSystemMode() bool { // IsInstanceInstalledForCurrentUser returns true iff the launcher's desired path under user files is occupied by the program running this code. func IsInstanceInstalledForCurrentUser() bool { - return system.GetProgramPath() == getTargetProgramPath() + return system.FilepathsEquivalent(system.GetProgramPath(), getTargetProgramPath()) } // IsInstallationOutdated returns true if the time the installed launcher binary was built diff --git a/pkg/system/api_unix.go b/pkg/system/api_unix.go index 752ebbd..1fae6ed 100644 --- a/pkg/system/api_unix.go +++ b/pkg/system/api_unix.go @@ -47,3 +47,7 @@ func showLocalFileInFileManager(path string) error { func isProcessRunning(p *os.Process) bool { return p.Signal(unix.Signal(0)) == nil } + +func universalPathName(p string) (string, error) { + return p, nil +} diff --git a/pkg/system/api_windows.go b/pkg/system/api_windows.go index 48be9ed..dfd2f63 100644 --- a/pkg/system/api_windows.go +++ b/pkg/system/api_windows.go @@ -6,11 +6,15 @@ import ( "os/exec" "runtime" "strings" + "unicode/utf16" + "unsafe" "golang.org/x/sys/windows" ) +// #cgo LDFLAGS: -lMpr //#include +//#include import "C" func mustDetectArchitecture() { @@ -59,3 +63,130 @@ func isProcessRunning(p *os.Process) bool { result := C.GetExitCodeProcess(handle, &lpExitCode) return (result != 0) && (lpExitCode == C.STILL_ACTIVE) } + +func universalPathName(p string) (string, error) { + s, lpBufferSize, err := universalPathNameWithBufferSize(p, 1000) + if err != nil && err.(*universalNameRetrievalError).ErrorType() == errorMoreData { + s, _, err = universalPathNameWithBufferSize(s, lpBufferSize) + } + if err != nil { + return p, err + } + return s, err +} + +func universalPathNameWithBufferSize(p string, lpBufferSizeUse C.DWORD) (universalPath string, lpBufferSize C.DWORD, err error) { + cp := C.LPCWSTR(StringToUTF16UnmanagedString(p)) + defer C.free(unsafe.Pointer(cp)) + + // The possible data written to infoStruct (we request a UNIVERSAL_NAME_INFO below) not only consists of the struct, but also of the data (strings) + // pointed to by pointer-members within the struct. That's why this allocation needs to be much larger than just large enough to hold the struct itself. + infoStruct := C.LPVOID(C.calloc(C.size_t(lpBufferSizeUse), 1)) + defer C.free(unsafe.Pointer(infoStruct)) + + lpBufferSize = lpBufferSizeUse + errorCode := C.WNetGetUniversalNameW(cp, C.UNIVERSAL_NAME_INFO_LEVEL, infoStruct, &lpBufferSize) + err = getErrorOfWNetGetUniversalNameW(errorCode) + if err == nil { + lpUniversalName := unsafe.Pointer(*(*C.LPWSTR)(infoStruct)) + universalPath = UTF16StringToString(lpUniversalName) + } + return universalPath, lpBufferSize, err +} + +func getErrorOfWNetGetUniversalNameW(returnCode C.DWORD) error { + if returnCode == C.NO_ERROR { + return nil + } + if returnCode == C.ERROR_BAD_DEVICE { + return &universalNameRetrievalError{errorType: errorBadDevice, + message: `the string pointed to by the lpLocalPath parameter is invalid`} + } + if returnCode == C.ERROR_CONNECTION_UNAVAIL { + return &universalNameRetrievalError{errorType: errorConnectionUnavailable, + message: `there is no current connection to the remote device, but there is a remembered (persistent) connection to it`} + } + if returnCode == C.ERROR_EXTENDED_ERROR { + errorMessage, providerName, err := getLastWNetError() + if err != nil { + return &universalNameRetrievalError{errorType: errorExtendedError, + message: `a network-specific error occurred; getting extended error information failed: ` + err.Error()} + } + return &universalNameRetrievalError{errorType: errorExtendedError, + message: `a network-specific error occurred; Network provider "` + providerName + `" reports: ` + errorMessage} + } + if returnCode == C.ERROR_MORE_DATA { + return &universalNameRetrievalError{errorType: errorMoreData, + message: `despite trying to query with the requested buffer size, the buffer pointed to by the lpBuffer parameter was too small`} + } + if returnCode == C.ERROR_NOT_SUPPORTED { + return &universalNameRetrievalError{errorType: errorNotSupported, + message: `the dwInfoLevel parameter is set to UNIVERSAL_NAME_INFO_LEVEL, but the network provider does not support UNC names. (None of the network providers support this function)`} + } + if returnCode == C.ERROR_NO_NET_OR_BAD_PATH { + return &universalNameRetrievalError{errorType: errorNoNetOrBadPath, + message: `none of the network providers recognize the local name as having a connection. However, the network is not available for at least one provider to whom the connection may belong`} + } + if returnCode == C.ERROR_NO_NETWORK { + return &universalNameRetrievalError{errorType: errorNoNetwork, + message: `the network is unavailable`} + } + if returnCode == C.ERROR_NOT_CONNECTED { + return &universalNameRetrievalError{errorType: errorNotConnected, + message: `the device specified by the path is not redirected`} + } + return &universalNameRetrievalError{errorType: errorUndocumented, + message: fmt.Sprintf(`undocumented error code %d`, returnCode)} +} + +func getLastWNetError() (errorMessage, providerName string, err error) { + var lpError C.DWORD + + const errorBufferSize = 5000 + const nErrorBufSize C.DWORD = errorBufferSize + lpErrorBuf := (C.LPWSTR)(C.calloc(C.size_t(errorBufferSize+1), C.size_t(unsafe.Sizeof(uint16(0))))) + defer C.free(unsafe.Pointer(lpErrorBuf)) + + const nameBufferSize = 1000 + const nNameBufSize C.DWORD = nameBufferSize + lpNameBuf := (C.LPWSTR)(C.calloc(C.size_t(nameBufferSize+1), C.size_t(unsafe.Sizeof(uint16(0))))) + defer C.free(unsafe.Pointer(lpNameBuf)) + + returnCode := C.WNetGetLastErrorW(&lpError, lpErrorBuf, nErrorBufSize, lpNameBuf, nNameBufSize) + if returnCode == C.NO_ERROR { + return UTF16StringToString(unsafe.Pointer(lpErrorBuf)), UTF16StringToString(unsafe.Pointer(lpNameBuf)), nil + } + if returnCode == C.ERROR_INVALID_ADDRESS { + return "", "", fmt.Errorf("could not get last WNet error: ERROR_INVALID_ADDRESS") + } + return "", "", fmt.Errorf("could not get last WNet error: undocumented extended error code %d", returnCode) +} + +// StringToUTF16UnmanagedString returns an unmanaged, null-terminated UTF16 string for given string. +// The caller is responsible for freeing the returned pointer. +func StringToUTF16UnmanagedString(s string) unsafe.Pointer { + utf16String := utf16.Encode([]rune(s)) + utf16StringPointer := (*uint16)(C.calloc(C.size_t(len(utf16String)+1), C.size_t(unsafe.Sizeof(uint16(0))))) + currentCharPointer := utf16StringPointer + for _, c := range utf16String { + *currentCharPointer = c + currentCharPointer = (*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(currentCharPointer)) + unsafe.Sizeof(uint16(0)))) + } + return unsafe.Pointer(utf16StringPointer) +} + +// UTF16StringToString returns a string for a given null-terminated UTF16 string. +// This function does not call free on the parameter. +func UTF16StringToString(lpwString unsafe.Pointer) string { + ptr := (*uint16)(lpwString) + data := make([]uint16, 0, 0) + for { + if *ptr == 0 { + break + } + data = append(data, *ptr) + ptr = (*uint16)(unsafe.Pointer(((uintptr)(unsafe.Pointer(ptr))) + unsafe.Sizeof(uint16(0)))) + } + s := utf16.Decode(data) + return string(s) +} diff --git a/pkg/system/file_system_funcs.go b/pkg/system/file_system_funcs.go index 02409be..0afa765 100644 --- a/pkg/system/file_system_funcs.go +++ b/pkg/system/file_system_funcs.go @@ -303,3 +303,25 @@ func CleanUpFileOperation(file *os.File, returnError *error) { } } } + +// FilepathsEquivalent returns true if the filepaths a and b are semantically equivalent (exceptions may exist). +func FilepathsEquivalent(a, b string) bool { + a = filepath.Clean(a) + b = filepath.Clean(b) + if a == b { + return true + } + aResolved, aErr := universalPathName(a) + if aErr != nil && aErr.(*universalNameRetrievalError).ErrorType() != errorNotConnected { + log.Warnf(`could not determine UNC path for filepath "%s": %v\n`, a, aErr) + } + bResolved, bErr := universalPathName(b) + if bErr != nil && bErr.(*universalNameRetrievalError).ErrorType() != errorNotConnected { + log.Warnf(`could not determine UNC path for filepath "%s": %v\n`, b, bErr) + } + aResolved = filepath.Clean(aResolved) + bResolved = filepath.Clean(bResolved) + return (a == bResolved && bErr == nil) || + (b == aResolved && aErr == nil) || + (aResolved == bResolved && aErr == nil && bErr == nil) +} diff --git a/pkg/system/universal_path_error.go b/pkg/system/universal_path_error.go new file mode 100644 index 0000000..1292f9e --- /dev/null +++ b/pkg/system/universal_path_error.go @@ -0,0 +1,30 @@ +package system + +type universalNameRetrievalErrorType int + +const errorBadDevice universalNameRetrievalErrorType = 1 +const errorConnectionUnavailable universalNameRetrievalErrorType = 2 +const errorExtendedError universalNameRetrievalErrorType = 3 +const errorMoreData universalNameRetrievalErrorType = 4 +const errorNotSupported universalNameRetrievalErrorType = 5 +const errorNoNetOrBadPath universalNameRetrievalErrorType = 6 +const errorNoNetwork universalNameRetrievalErrorType = 7 +const errorNotConnected universalNameRetrievalErrorType = 8 +const errorUndocumented universalNameRetrievalErrorType = 9 + +type universalNameRetrievalError struct { + message string + errorType universalNameRetrievalErrorType +} + +func (err *universalNameRetrievalError) Error() string { + if err == nil { + return "" + } + return err.message +} + +// ErrorType returns the corresponsing WINAPI error type of the WNetGetUniversalNameW function call which generated the error. +func (err *universalNameRetrievalError) ErrorType() universalNameRetrievalErrorType { + return err.errorType +}