goutils/cmd/clustersh/main.go

348 lines
7.5 KiB
Go

package main
import (
"bufio"
"flag"
"fmt"
"io"
"log"
"net"
"os"
"strings"
"sync"
"git.wntrmute.dev/kyle/goutils/lib"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
// usage writes the command-line help text to w. The program name in the
// first line comes from lib.ProgName.
func usage(w io.Writer) {
	fmt.Fprintf(w, `Usage: %s [-a addresses] [-c chunksize] [-h] [-u user] command args
Flags:
-a addresses The comma-separated list of servers to send
to. There must not be any spaces in this list.
-c chunk The size of chunks to transfer, in bytes. The
default is 16MB.
-u user The SSH username to use. It must be the
same for all hosts. Defaults to the value
of the USER environment variable.
Commands:
exec, run: run a command on the servers
The args list must be the command line to send to the hosts.
upload, up, push, send: upload a file to the servers
The first argument is the local file to upload, and the second
is the filename to store it as on the remote.
download, down, pull, fetch: download a file from the servers
The first argument is the filename to fetch. The second argument
is the base name for the local file. It will have a '-' and the
hostname appended.
`, lib.ProgName())
}
// chunkSize is the size, in bytes, of the buffer used when copying a file
// to or from a host. It defaults to 16 MiB and is overridden by the -c flag.
var chunkSize = 16 << 20
// modes lists the terminal settings passed to RequestPty when exec sets up
// a remote pseudo-terminal.
var modes = ssh.TerminalModes{
	ssh.TTY_OP_ISPEED: 14400, // input speed
	ssh.TTY_OP_OSPEED: 14400, // output speed
	ssh.ECHO:          0,     // echo off
}
// sshAgent returns an auth method that signs with the keys offered by the
// SSH agent listening on the unix socket named by SSH_AUTH_SOCK. If the
// socket cannot be dialled the failure is reported through lib.Err and nil
// is returned.
func sshAgent() ssh.AuthMethod {
	sock, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
	if err != nil {
		lib.Err(lib.ExitFailure, err, "failed to authenticate with SSH agent")
		return nil
	}

	keyring := agent.NewClient(sock)
	return ssh.PublicKeysCallback(keyring.Signers)
}
// sshConfig builds the client configuration used for every connection:
// it logs in as user and authenticates with the keys held by the local SSH
// agent (see sshAgent).
func sshConfig(user string) *ssh.ClientConfig {
	return &ssh.ClientConfig{
		User: user,
		Auth: []ssh.AuthMethod{
			sshAgent(),
		},
		// golang.org/x/crypto/ssh rejects the handshake with "must specify
		// HostKeyCallback" when this field is nil, so it has to be set for
		// any connection to succeed.
		//
		// SECURITY: InsecureIgnoreHostKey accepts whatever key the server
		// presents, i.e. hosts are not authenticated and a
		// man-in-the-middle is not detected. Replace it with a
		// known_hosts-backed callback if that matters for your network.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}
}
// scanner copies in to out one line at a time, prefixing each line with the
// host name in square brackets so that interleaved output from several
// hosts stays attributable. Line terminators ("\n" or "\r\n") are replaced
// by a single "\n". It returns once in is exhausted or a read fails.
//
// Lines of up to maxLine bytes are accepted; bufio.Scanner's default limit
// of 64 KiB would otherwise stop the scan, discarding the long line and
// everything after it. Any error that ends the scan early is logged
// instead of being dropped.
func scanner(host string, in io.Reader, out io.Writer) {
	const maxLine = 16 << 20

	lines := bufio.NewScanner(in)
	// Start with the default-sized buffer and let it grow on demand.
	lines.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), maxLine)

	for lines.Scan() {
		fmt.Fprintf(out, "[%s] %s\n", host, lines.Text())
	}
	if err := lines.Err(); err != nil {
		log.Printf("[%s] FAILED: reading output: %v\n", host, err)
	}
}
// logError reports a failure on host through the standard logger. The
// description is built from format and args in the manner of fmt.Sprintf
// and is followed by err.
func logError(host string, err error, format string, args ...interface{}) {
	log.Printf("[%s] FAILED: %s: %v\n", host, fmt.Sprintf(format, args...), err)
}
func exec(wg *sync.WaitGroup, user, host string, commands []string) {
var shutdown []func() error
defer func() {
for i := len(shutdown) - 1; i >= 0; i-- {
err := shutdown[i]()
if err != nil && err != io.EOF {
logError(host, err, "shutting down")
}
}
}()
defer wg.Done()
conf := sshConfig(user)
conn, err := ssh.Dial("tcp", host+":22", conf)
if err != nil {
logError(host, err, "failed to connect")
return
}
shutdown = append(shutdown, conn.Close)
session, err := conn.NewSession()
if err != nil {
logError(host, err, "failed to establish session")
return
}
shutdown = append(shutdown, session.Close)
if err := session.RequestPty("xterm", 80, 40, modes); err != nil {
session.Close()
logError(host, err, "request for pty failed")
return
}
stdout, err := session.StdoutPipe()
if err != nil {
logError(host, err, "failed to setup standard output")
return
}
go scanner(host, stdout, os.Stdout)
stderr, err := session.StderrPipe()
if err != nil {
logError(host, err, "failed to setup standard error")
return
}
go scanner(host, stderr, os.Stderr)
for _, command := range commands {
err = session.Run(command)
if err != nil {
logError(host, err, "running command failed")
return
}
}
}
// upload copies the local file named local to the path remote on host over
// SFTP, connecting on port 22 as user. The file is sent in chunks of
// chunkSize bytes and a progress line is printed per chunk. Failures are
// logged with logError; nothing is returned. wg.Done is called once all
// work, including closing every resource, has finished.
func upload(wg *sync.WaitGroup, user, host, local, remote string) {
	// Registered first so that it runs last: the caller must not be
	// released before the files and connection have been closed.
	defer wg.Done()

	// Resources are closed in reverse order of acquisition.
	var shutdown []func() error
	defer func() {
		for i := len(shutdown) - 1; i >= 0; i-- {
			if err := shutdown[i](); err != nil && err != io.EOF {
				logError(host, err, "shutting down")
			}
		}
	}()

	// Open the source before touching the remote side, so that a bad local
	// path never leaves a freshly created remote file behind.
	localFile, err := os.Open(local)
	if err != nil {
		logError(host, err, "opening local file")
		return
	}
	shutdown = append(shutdown, localFile.Close)

	conn, err := ssh.Dial("tcp", net.JoinHostPort(host, "22"), sshConfig(user))
	if err != nil {
		logError(host, err, "failed to connect")
		return
	}
	shutdown = append(shutdown, conn.Close)

	client, err := sftp.NewClient(conn)
	if err != nil {
		logError(host, err, "setting up SFTP client")
		return
	}
	shutdown = append(shutdown, client.Close)

	remoteFile, err := client.Create(remote)
	if err != nil {
		logError(host, err, "creating file %s on remote", remote)
		return
	}
	shutdown = append(shutdown, remoteFile.Close)

	buf := make([]byte, chunkSize)
	for {
		// Keep the read error separate from the write error: a Read may
		// return data together with an error, and the data must be
		// written before the error is acted on.
		n, readErr := localFile.Read(buf)
		if n > 0 {
			if _, err := remoteFile.Write(buf[:n]); err != nil {
				logError(host, err, "writing chunk")
				return
			}
			fmt.Printf("[%s] wrote %d-byte chunk\n", host, n)
		}

		if readErr == io.EOF {
			break
		}
		if readErr != nil {
			logError(host, readErr, "reading chunk")
			return
		}
	}

	fmt.Printf("[%s] %s uploaded to %s\n", host, local, remote)
}
// download fetches the file remote from host over SFTP, connecting on port
// 22 as user, and stores it locally as local with "-" and the host name
// appended. The file is received in chunks of chunkSize bytes and a
// progress line is printed per chunk. Failures are logged with logError;
// nothing is returned. wg.Done is called once all work, including closing
// every resource, has finished.
func download(wg *sync.WaitGroup, user, host, local, remote string) {
	// Registered first so that it runs last: the caller must not be
	// released before the files and connection have been closed.
	defer wg.Done()

	// Resources are closed in reverse order of acquisition.
	var shutdown []func() error
	defer func() {
		for i := len(shutdown) - 1; i >= 0; i-- {
			if err := shutdown[i](); err != nil && err != io.EOF {
				logError(host, err, "shutting down")
			}
		}
	}()

	conn, err := ssh.Dial("tcp", net.JoinHostPort(host, "22"), sshConfig(user))
	if err != nil {
		logError(host, err, "failed to connect")
		return
	}
	shutdown = append(shutdown, conn.Close)

	client, err := sftp.NewClient(conn)
	if err != nil {
		logError(host, err, "setting up SFTP client")
		return
	}
	shutdown = append(shutdown, client.Close)

	remoteFile, err := client.Open(remote)
	if err != nil {
		logError(host, err, "opening file %s on remote", remote)
		return
	}
	shutdown = append(shutdown, remoteFile.Close)

	// One output file per host, so parallel downloads do not collide.
	local = local + "-" + host
	localFile, err := os.Create(local)
	if err != nil {
		logError(host, err, "opening local file")
		return
	}
	shutdown = append(shutdown, localFile.Close)

	buf := make([]byte, chunkSize)
	for {
		// Keep the read error separate from the write error: a Read may
		// return data together with an error, and the data must be
		// written before the error is acted on.
		n, readErr := remoteFile.Read(buf)
		if n > 0 {
			if _, err := localFile.Write(buf[:n]); err != nil {
				logError(host, err, "writing chunk")
				return
			}
			fmt.Printf("[%s] wrote %d-byte chunk\n", host, n)
		}

		if readErr == io.EOF {
			break
		}
		if readErr != nil {
			logError(host, readErr, "reading chunk")
			return
		}
	}

	fmt.Printf("[%s] %s downloaded to %s\n", host, remote, local)
}
// init replaces the flag package's default usage output with this
// program's own help text, written to standard output.
func init() {
	flag.Usage = func() { usage(os.Stdout) }
}
// main parses the command line, selects the operation named by the first
// positional argument and runs it against every host given with -a, one
// goroutine per host, then waits for all of them to finish.
//
// It exits with status 1 on a usage error and with status 0 when there is
// nothing to do because no hosts were given.
func main() {
	var hostsFlag, user string
	flag.StringVar(&hostsFlag, "a", "", "`hosts` to run on")
	flag.IntVar(&chunkSize, "c", chunkSize, "`size` in bytes of transfer chunks")
	flag.StringVar(&user, "u", os.Getenv("USER"), "`username` to run commands as")
	flag.Parse()

	if flag.NArg() < 1 {
		usage(os.Stderr)
		os.Exit(1)
	}

	// A zero-length transfer buffer would never make progress and a
	// negative length cannot be allocated at all.
	if chunkSize <= 0 {
		fmt.Fprintf(os.Stderr, "Invalid chunk size %d: it must be positive.\n", chunkSize)
		os.Exit(1)
	}

	// strings.Split returns a single empty string for empty input, and
	// stray commas produce empty entries too; drop them so that no
	// connection to an empty host name is attempted.
	var hosts []string
	for _, host := range strings.Split(hostsFlag, ",") {
		if host != "" {
			hosts = append(hosts, host)
		}
	}

	var (
		wg  sync.WaitGroup
		run func(host string) // the per-host operation chosen below
	)

	switch cmd := flag.Arg(0); cmd {
	case "exec", "run":
		if flag.NArg() < 2 {
			usage(os.Stderr)
			os.Exit(1)
		}
		commands := []string{strings.Join(flag.Args()[1:], " ")}
		run = func(host string) { exec(&wg, user, host, commands) }
	case "upload", "up", "push", "send":
		if flag.NArg() != 3 {
			usage(os.Stderr)
			os.Exit(1)
		}
		local, remote := flag.Arg(1), flag.Arg(2)
		run = func(host string) { upload(&wg, user, host, local, remote) }
	case "download", "down", "pull", "fetch":
		if flag.NArg() != 3 {
			usage(os.Stderr)
			os.Exit(1)
		}
		// On the command line the remote name comes first, then the
		// local base name; download takes them the other way round.
		remote, local := flag.Arg(1), flag.Arg(2)
		run = func(host string) { download(&wg, user, host, local, remote) }
	case "help":
		usage(os.Stdout)
		os.Exit(0)
	default:
		fmt.Fprintf(os.Stderr, "Unknown command %s.\n", cmd)
		usage(os.Stderr)
		os.Exit(1)
	}

	if len(hosts) == 0 {
		fmt.Fprintln(os.Stderr, "No hosts specified; nothing to do.")
		os.Exit(0)
	}

	for _, host := range hosts {
		wg.Add(1)
		go run(host)
	}

	log.Printf("waiting for sessions to complete")
	wg.Wait()
}