Golang学习笔记(08-4-ssh服务)

653 阅读5分钟

1. SSH客户端使用

在运维开发中,有时会涉及到在目标机器上执行shell命令,或者需要通过ssh通道传输文件,此时ssh客户端的使用就必不可少。如果涉及频繁连接SSH客户端,建议使用连接池进行管理,降低SSH开销,同时避免过多的ssh连接导致被跳板机拒绝。使用golang实现ssh客户端场景较多,实现ssh服务端场景较少,此处仅实现ssh客户端作为示例。涉及到的包有:

  • golang.org/x/crypto/ssh
  • github.com/bramvdbogaerde/go-scp

1.1. 使用ssh client

在实际操作中,分两种情况,一种是直接操作目标主机,一种是通过跳板机操作目标机器,下面代码对两种情况进行了 ssh.Client 实现!

// 目标主机信息
type Host struct {
	IP         string
	SSHPort    int
	Username   string
	Password   string
	SSHKey     string
	JumpServer *JumpServer
}

// 跳板机
type JumpServer struct {
	IP       string
	SSHPort  int
	Username string
	Password string
	SSHKey   string
}

// 生成密钥信息, 分为两种:password 和 ssh key
func sshClientConfig(passwd, key string) (auth []ssh.AuthMethod, err error) {
	if passwd != "" {
		auth = append(auth, ssh.Password(passwd))
	}
	if key != "" {
		privateKey, err := ssh.ParsePrivateKey([]byte(key))
		if err != nil {
			return nil, err
		}
		auth = append(auth, ssh.PublicKeys(privateKey))
	}
	return
}

// 生成ssh 客户端信息
func opensshClient(host *Host) (client *ssh.Client, err error) {
	if host.JumpServer == nil {
		return clientWithoutJumpServer(host)
	}
	return clientWithJumpServer(host)
}

// 如果没有跳板机,则直接连接目标机器
func clientWithoutJumpServer(host *Host) (client *ssh.Client, err error) {
	auth, err := sshClientConfig(host.Password, host.SSHKey) // 生成密钥
	if err != nil {
		return nil, err
	}
	// 生成ssh client的配置
	config := &ssh.ClientConfig{
		User: host.Username,
		Auth: auth,
		HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
			return nil
		},
		Timeout: time.Second * 10,
	}
	// 拨号
	client, err = ssh.Dial("tcp", fmt.Sprintf("%s:%d", host.IP, host.SSHPort), config)
	if err != nil {
		return nil, err
	}
	return
}

// 如果存在跳板机,则进行中转
func clientWithJumpServer(host *Host) (client *ssh.Client, err error) {
	// 生成跳板机密钥
	jAuth, err := sshClientConfig(host.JumpServer.Password, host.JumpServer.SSHKey)
	if err != nil {
		return nil, err
	}
	// 生成跳板机的ssh client配置
	jConfig := &ssh.ClientConfig{
		User: host.JumpServer.Username,
		Auth: jAuth,
		HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
			return nil
		},
		Timeout: time.Second * 10,
	}
	// 对跳板机进行拨号
	jClient, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", host.JumpServer.IP, host.JumpServer.SSHPort), jConfig)
	if err != nil {
		return nil, err
	}
	// 生成目标机器的密钥
	auth, err := sshClientConfig(host.Password, host.SSHKey)
	if err != nil {
		return nil, err
	}
	// 生成目标机器的ssh client配置
	config := &ssh.ClientConfig{
		User: host.Username,
		Auth: auth,
		HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
			return nil
		},
		Timeout: time.Second * 10,
	}
	// 使用跳板机对目标机器进行拨号
	conn, err := jClient.Dial("tcp", fmt.Sprintf("%s:%d", host.IP, host.SSHPort))
	if err != nil {
		return nil, err
	}
	// 生成目标机器的 ssh client
	clientConn, channels, requests, err := ssh.NewClientConn(conn, fmt.Sprintf("%s:%d", host.IP, host.SSHPort), config)
	if err != nil {
		return nil, err
	}
	return ssh.NewClient(clientConn, channels, requests), nil
}

1.2. 执行命令

func execCommand(client *ssh.Client, cmd ...string) (stdout, stderr string, err error) {
	session, err := client.NewSession() // 开启新的ssh会话
	if err != nil {
		return "", "", err
	}
	defer func() { _ = session.Close() }()
	// 指定标准输出和标准错误
	var stdOut bytes.Buffer 
	var stdErr bytes.Buffer
	session.Stderr = &stdErr
	session.Stdout = &stdOut
	if err := session.Run(strings.Join(cmd, " && ")); err != nil {
		return stdOut.String(), stdErr.String(), err
	}
	return stdOut.String(), stdErr.String(), nil
}
// 之间连接对于主机,并执行命令
func main() {
	node100 := &Host{
		IP:       "10.4.7.100",
		SSHPort:  22,
		Username: "root",
		SSHKey: func() string {
			content, _ := ioutil.ReadFile("/root/.ssh/id_rsa")
			return string(content)
		}(),
	}
	client, err := opensshClient(node100)
	if err != nil {
		logger.Errorf("open ssh connect to %s failed, error:%s", node100.IP, err.Error())
		return
	}
	defer func() { _ = client.Close() }()
	stdout, stderr, err := execCommand(client, "echo $HOSTNAME", "df -h")
	if err != nil {
		logger.Errorf("%s run command echo $HOSTNAME failed, stdout:%s, stderr:%s, err:%s", node100.IP, stdout, stderr, err.Error())
		return
	}
	fmt.Print(stdout)
	fmt.Print(stderr)
}
[root@duduniao ssh]# go run command.go
jumpserver-100
Filesystem      Size  Used Avail Use% Mounted on
udev            1.9G     0  1.9G   0% /dev
tmpfs           393M  1.1M  392M   1% /run
/dev/sda2        20G  7.4G   12G  40% /
tmpfs           2.0G     0  2.0G   0% /dev/shm
tmpfs           5.0M     0  5.0M   0% /run/lock
tmpfs           2.0G     0  2.0G   0% /sys/fs/cgroup
tmpfs           393M     0  393M   0% /run/user/0
// 存在跳板机的情况
func main() {
	node101 := &Host{
		IP:       "10.4.7.101",
		SSHPort:  22,
		Username: "root",
		SSHKey: func() string {
			content, _ := ioutil.ReadFile("/root/.ssh/id_rsa")
			return string(content)
		}(),
		JumpServer: &JumpServer{
			IP:       "10.4.7.100",
			SSHPort:  22,
			Username: "root",
			SSHKey: func() string {
				content, _ := ioutil.ReadFile("/root/.ssh/id_rsa")
				return string(content)
			}(),
		},
	}
	client, err := opensshClient(node101)
	if err != nil {
		logger.Errorf("open ssh connect to %s failed, error:%s", node101.IP, err.Error())
		return
	}
	defer func() { _ = client.Close() }()
	stdout, stderr, err := execCommand(client, "echo $HOSTNAME", "df -h")
	if err != nil {
		logger.Errorf("%s run command echo $HOSTNAME failed, stdout:%s, stderr:%s, err:%s", node101.IP, stdout, stderr, err.Error())
		return
	}
	fmt.Print(stdout)
	fmt.Print(stderr)
}
[root@duduniao ssh]# go run command.go
worker-101
Filesystem            Size  Used Avail Use% Mounted on
udev                  1.9G     0  1.9G   0% /dev
tmpfs                 393M  1.2M  392M   1% /run
/dev/sda2              20G  7.4G   12G  40% /
tmpfs                 2.0G     0  2.0G   0% /dev/shm
tmpfs                 5.0M     0  5.0M   0% /run/lock
tmpfs                 2.0G     0  2.0G   0% /sys/fs/cgroup
10.4.7.100:/data/nfs   20G  7.4G   12G  40% /data/nfs
tmpfs                 393M     0  393M   0% /run/user/0

1.3. 转发文件

// 发送文件
func sendFiles(client *ssh.Client, remoteDir string, localFile string) error {
	// github.com/bramvdbogaerde/go-scp
	// 调用scp模块,创建scp的client 
	scpClient, err := scp.NewClientBySSH(client)
	if err != nil {
		return err
	}
	defer func() { _ = scpClient.Close }()
	if err := scpClient.Connect(); err != nil {
		return err
	}
	file, err := os.Open(localFile)
	if err != nil {
		return err
	}
	// 这里存在几个问题:
	// 1. 文件权限必须要是字符串的数字格式。无法使用 file.stat 中的mode
	// 2. 同一个 scpClient 只能发送一次文件,发多个文件会出现: ssh: StdinPipe after process started
	// 3. 应该有其它的模块能改进避免上述的俩个文件
	if err := scpClient.CopyFile(file, path.Join(remoteDir, path.Base(localFile)), "0755"); err != nil {
		return err
	}
	_ = file.Close()
	return nil
}
func main() {
	node101 := &Host{
		IP:       "10.4.7.101",
		SSHPort:  22,
		Username: "root",
		SSHKey: func() string {
			content, _ := ioutil.ReadFile("/root/.ssh/id_rsa")
			return string(content)
		}(),
		JumpServer: &JumpServer{
			IP:       "10.4.7.100",
			SSHPort:  22,
			Username: "root",
			SSHKey: func() string {
				content, _ := ioutil.ReadFile("/root/.ssh/id_rsa")
				return string(content)
			}(),
		},
	}
	client, err := opensshClient(node101)
	if err != nil {
		logger.Errorf("open ssh connect to %s failed, error:%s", node101.IP, err.Error())
		return
	}
	defer func() { _ = client.Close() }()
	err = sendFiles(client, "/tmp", "/root/bin/scan_host.sh")
	if err != nil {
		logger.Errorf("%s send file failed ,err:%s", node101.IP, err.Error())
	}
}

2. SSH作为代理转发

ssh 服务可以作为隧道进行请求的转发,比如目标机器上存在一个nginx服务器,对外暴露80端口,但是当前服务器与目标机器网络不通,此时可以通过中间的跳板机上ssh通道进行HTTP的请求转发!

// 这里面涉及的channel让gc进行回收,手动关闭容易出现panic
// 端口转发
func forward(client *ssh.Client, protocol, localAddr, remoteAddr string, stop chan bool, errMsg chan error) {
	// 打开本地端口
	listener, err := net.Listen(protocol, localAddr)
	if err != nil {
		errMsg <- err
		return
	}
	defer func() { _ = listener.Close() }()
	// 定义异常退出机制,因为设置了 stop chan,为了避免在stop chan阻塞,引入err chan,两者满足其一就能退出
	var errChan = make(chan error)

	// 循环接收本地端口的请求
	go func() {
		for {
			localConn, err := listener.Accept()
			if localConn == nil {
				errMsg <- err
				return
			}
			if err != nil {
				_ = localConn.Close()
				errMsg <- err
				return
			}
			go establishLocal(client, protocol, remoteAddr, localConn, errChan)
		}
	}()
	select {
	case <-stop:
		errMsg <- nil
	case err := <-errChan:
		errMsg <- err
	}
}

// 处理本地端口的请求
func establishLocal(client *ssh.Client, protocol, remoteAddr string, local net.Conn, errChan chan error) {
	// 打开远程的端口, 每次接收一个新的TCP连接,都得开一次远程转发
	remote, err := client.Dial(protocol, remoteAddr)
	if err != nil {
		errChan <- err
		return
	}
	defer func() { _ = remote.Close() }()
	errCh := make(chan error, 1)
	go exchangeData(local, remote, errCh)
	go exchangeData(remote, local, errCh)
	<-errCh
	<-errCh
}

type closeWriter interface {
	CloseWrite() error
}

// 数据交换
func exchangeData(r io.Reader, w io.Writer, errCh chan error) {
	_, err := io.Copy(w, r)
	if tcpConn, ok := w.(closeWriter); ok {
		_ = tcpConn.CloseWrite() // 必须要关闭,否则内存泄露
	}
	errCh <- err
}
func main() {
	node101 := &Host{
		IP:       "10.4.7.101",
		SSHPort:  22,
		Username: "root",
		SSHKey: func() string {
			content, _ := ioutil.ReadFile("/root/.ssh/id_rsa")
			return string(content)
		}(),
		JumpServer: &JumpServer{
			IP:       "10.4.7.100",
			SSHPort:  22,
			Username: "root",
			SSHKey: func() string {
				content, _ := ioutil.ReadFile("/root/.ssh/id_rsa")
				return string(content)
			}(),
		},
	}
	client, err := opensshClient(node101)
	if err != nil {
		logger.Errorf("open ssh connect to %s failed, error:%s", node101.IP, err.Error())
		return
	}
	defer func() { _ = client.Close() }()

	stop := make(chan bool, 1)
	errMsg := make(chan error, 1)
	go forward(client, "tcp", "127.0.0.1:10080", "172.17.0.2:80", stop, errMsg)

	// 测试网络隧道是否就绪, 因为使用goroutine打开隧道,本地通道可能还能没有就绪
	for i := 0; i < 10; i++ {
		_, err = net.DialTimeout("tcp", "127.0.0.1:10080", time.Millisecond*100)
		if err == nil {
			break
		}
	}
	if err != nil {
		stop <- true
		return
	}
	// 测试
	for i := 0; i < 100000; i++ {
		httpClient := http.Client{Timeout: time.Second}
		resp, err := httpClient.Get("http://127.0.0.1:10080/info")
		if err != nil {
			logger.Errorf("send request failed, err:%s", err.Error())
			break
		}
		content, _ := ioutil.ReadAll(resp.Body)
		_ = resp.Body.Close()
		fmt.Print(string(content))
		time.Sleep(time.Millisecond)
	}
	stop <- true
}