Skip to content

Commit fd4a0f5

Browse files
committed
Refactored cp (#56)
1 parent 42f2342 commit fd4a0f5

File tree

1 file changed

+102
-72
lines changed

1 file changed

+102
-72
lines changed

cp.go

Lines changed: 102 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package main
22

33
import (
4+
"fmt"
45
"io"
56
"os"
67
"os/exec"
@@ -39,35 +40,26 @@ func init() {
3940
// Flags
4041
var cpHelp bool // -h, --help flag
4142

42-
func runCp(cmd *Command, args []string) {
43-
if cpHelp {
44-
cmd.PrintUsage()
45-
}
46-
if len(args) != 2 {
47-
cmd.PrintShortUsage()
48-
}
49-
43+
func TarFromSource(api *ScalewayAPI, source string) (*io.ReadCloser, error) {
5044
var tarOutputStream io.ReadCloser
51-
var tarErrorStream io.ReadCloser
5245

53-
// source
54-
source := args[0]
55-
if strings.Index(source, ":") > -1 { // source server address
46+
// source is a server address + path (scp-like uri)
47+
if strings.Index(source, ":") > -1 {
5648
log.Debugf("Creating a tarball remotely and streaming it using SSH")
57-
serverParts := strings.Split(args[0], ":")
49+
serverParts := strings.Split(source, ":")
5850
if len(serverParts) != 2 {
59-
log.Fatalf("usage: scw %s", cmd.UsageLine)
51+
return nil, fmt.Errorf("invalid source uri, see 'scw cp -h' for usage")
6052
}
6153

62-
serverID := cmd.API.GetServerID(serverParts[0])
54+
serverID := api.GetServerID(serverParts[0])
6355

64-
server, err := cmd.API.GetServer(serverID)
56+
server, err := api.GetServer(serverID)
6557
if err != nil {
66-
log.Fatalf("Failed to get server information for %s: %v", serverID, err)
58+
return nil, err
6759
}
6860

6961
dir, base := PathToTARPathparts(serverParts[1])
70-
log.Debugf("Kind of equivalent of 'scp root@%s:%s/%s ...'", server.PublicAddress.IP, dir, base)
62+
log.Debugf("Equivalent to 'scp root@%s:%s/%s ...'", server.PublicAddress.IP, dir, base)
7163

7264
// remoteCommand is executed on the remote server
7365
// it streams a tarball raw content
@@ -86,61 +78,73 @@ func runCp(cmd *Command, args []string) {
8678

8779
tarOutputStream, err = spawnSrc.StdoutPipe()
8880
if err != nil {
89-
log.Fatal(err)
81+
return nil, err
9082
}
91-
tarErrorStream, err = spawnSrc.StderrPipe()
83+
defer tarOutputStream.Close()
84+
85+
tarErrorStream, err := spawnSrc.StderrPipe()
9286
if err != nil {
93-
log.Fatal(err)
87+
return nil, err
9488
}
89+
defer tarErrorStream.Close()
90+
io.Copy(os.Stderr, tarErrorStream)
9591

9692
err = spawnSrc.Start()
9793
if err != nil {
98-
log.Fatalf("Failed to start ssh command: %v", err)
94+
return nil, err
9995
}
100-
10196
defer spawnSrc.Wait()
10297

103-
io.Copy(os.Stderr, tarErrorStream)
104-
} else if source == "-" { // stdin
98+
return &tarOutputStream, nil
99+
}
100+
101+
// source is stdin
102+
if source == "-" {
105103
log.Debugf("Streaming tarball from stdin")
106104
tarOutputStream = os.Stdin
107-
} else { // source host path
108-
log.Debugf("Taring local path %s", source)
109-
path, err := filepath.Abs(source)
110-
if err != nil {
111-
log.Fatalf("Cannot tar local path: %v", err)
112-
}
113-
path, err = filepath.EvalSymlinks(path)
114-
if err != nil {
115-
log.Fatalf("Cannot tar local path: %v", err)
116-
}
117-
log.Debugf("Real local path is %s", path)
105+
defer tarOutputStream.Close()
106+
return &tarOutputStream, nil
107+
}
118108

119-
dir, base := PathToTARPathparts(path)
109+
// source is a path on localhost
110+
log.Debugf("Taring local path %s", source)
111+
path, err := filepath.Abs(source)
112+
if err != nil {
113+
return nil, err
114+
}
115+
path, err = filepath.EvalSymlinks(path)
116+
if err != nil {
117+
return nil, err
118+
}
119+
log.Debugf("Real local path is %s", path)
120120

121-
tarOutputStream, err = archive.TarWithOptions(dir, &archive.TarOptions{
122-
Compression: archive.Uncompressed,
123-
IncludeFiles: []string{base},
124-
})
125-
if err != nil {
126-
log.Fatalf("Cannot tar local path: %v", err)
127-
}
121+
dir, base := PathToTARPathparts(path)
122+
123+
tarOutputStream, err = archive.TarWithOptions(dir, &archive.TarOptions{
124+
Compression: archive.Uncompressed,
125+
IncludeFiles: []string{base},
126+
})
127+
if err != nil {
128+
return nil, err
128129
}
130+
defer tarOutputStream.Close()
131+
return &tarOutputStream, nil
132+
}
129133

130-
// destination
131-
destination := args[1]
132-
if strings.Index(destination, ":") > -1 { // destination server address
134+
func UntarToDest(api *ScalewayAPI, sourceStream *io.ReadCloser, destination string) error {
135+
// destination is a server address + path (scp-like uri)
136+
if strings.Index(destination, ":") > -1 {
133137
log.Debugf("Streaming using ssh and untaring remotely")
134138
serverParts := strings.Split(destination, ":")
135139
if len(serverParts) != 2 {
136-
log.Fatalf("usage: scw %s", cmd.UsageLine)
140+
return fmt.Errorf("invalid destination uri, see 'scw cp -h' for usage")
137141
}
138142

139-
serverID := cmd.API.GetServerID(serverParts[0])
143+
serverID := api.GetServerID(serverParts[0])
140144

141-
server, err := cmd.API.GetServer(serverID)
145+
server, err := api.GetServer(serverID)
142146
if err != nil {
143-
log.Fatalf("Failed to get server information for %s: %v", serverID, err)
147+
return err
144148
}
145149

146150
// remoteCommand is executed on the remote server
@@ -150,7 +154,7 @@ func runCp(cmd *Command, args []string) {
150154
if os.Getenv("DEBUG") == "1" {
151155
remoteCommand = append(remoteCommand, "-v")
152156
}
153-
remoteCommand = append(remoteCommand, "-tf", "-")
157+
remoteCommand = append(remoteCommand, "-xf", "-")
154158

155159
// execCmd contains the ssh connection + the remoteCommand
156160
execCmd := append(NewSSHExecCmd(server.PublicAddress.IP, false, remoteCommand))
@@ -159,40 +163,66 @@ func runCp(cmd *Command, args []string) {
159163

160164
untarInputStream, err := spawnDst.StdinPipe()
161165
if err != nil {
162-
log.Fatal(err)
166+
return err
163167
}
168+
defer untarInputStream.Close()
169+
164170
untarErrorStream, err := spawnDst.StderrPipe()
165171
if err != nil {
166-
log.Fatal(err)
172+
return err
167173
}
174+
defer untarErrorStream.Close()
175+
168176
untarOutputStream, err := spawnDst.StdoutPipe()
169177
if err != nil {
170-
log.Fatal(err)
178+
return err
171179
}
180+
defer untarOutputStream.Close()
172181

173182
err = spawnDst.Start()
174183
if err != nil {
175-
log.Fatalf("Failed to start ssh command: %v", err)
184+
return err
176185
}
177-
178186
defer spawnDst.Wait()
179187

180-
io.Copy(untarInputStream, tarOutputStream)
188+
io.Copy(untarInputStream, *sourceStream)
181189
io.Copy(os.Stderr, untarErrorStream)
182-
io.Copy(os.Stdout, untarOutputStream)
183-
} else if destination == "-" { // stdout
184-
log.Debugf("Writing tarOutputStream(%v) to os.Stdout(%v)", tarOutputStream, os.Stdout)
185-
written, err := io.Copy(os.Stdout, tarOutputStream)
186-
log.Debugf("%d bytes written", written)
187-
if err != nil {
188-
log.Fatal(err)
189-
}
190+
_, err = io.Copy(os.Stdout, untarOutputStream)
191+
return err
192+
}
190193

191-
} else { // destination host path
192-
log.Debugf("Untaring to local path: %s", destination)
193-
err := archive.Untar(tarOutputStream, destination, &archive.TarOptions{NoLchown: true})
194-
if err != nil {
195-
log.Fatalf("Failed to untar the remote archive: %v", err)
196-
}
194+
// destination is stdout
195+
if destination == "-" { // stdout
196+
log.Debugf("Writing sourceStream(%v) to os.Stdout(%v)", sourceStream, os.Stdout)
197+
_, err := io.Copy(os.Stdout, *sourceStream)
198+
return err
199+
}
200+
201+
// destination is a path on localhost
202+
log.Debugf("Untaring to local path: %s", destination)
203+
err := archive.Untar(*sourceStream, destination, &archive.TarOptions{NoLchown: true})
204+
return err
205+
}
206+
207+
func runCp(cmd *Command, args []string) {
208+
if cpHelp {
209+
cmd.PrintUsage()
210+
}
211+
if len(args) != 2 {
212+
cmd.PrintShortUsage()
213+
}
214+
215+
if strings.Count(args[0], ":") > 1 || strings.Count(args[1], ":") > 1 {
216+
log.Fatalf("usage: scw %s", cmd.UsageLine)
217+
}
218+
219+
sourceStream, err := TarFromSource(cmd.API, args[0])
220+
if err != nil {
221+
log.Fatalf("Cannot tar from source '%s': %v", args[0], err)
222+
}
223+
224+
err = UntarToDest(cmd.API, sourceStream, args[1])
225+
if err != nil {
226+
log.Fatalf("Cannot untar to destionation '%s': %v", args[1], err)
197227
}
198228
}

0 commit comments

Comments
 (0)