Question

I wrote the following Python module to handle SSH connections in my program:

#!/usr/bin/env python
from vxpty import VX_PTY

class SSHError(Exception):
  def __init__(self, msg):
    self.msg = msg
  def __str__(self):
    return repr(self.msg)

class SSHShell:
  def __init__(self, host, port, user, password):
    self.host = host
    self.port = port
    self.user = user
    self.password = password
    self.authenticated = False
  def authenticate(self):
    self.tty = VX_PTY(['/usr/bin/ssh', 'ssh', '-p'+str(self.port), self.user+'@'+self.host])
    resp = self.tty.read()
    if "authenticity of host" in resp:
      self.tty.println('yes')
      while 1:
        resp = self.tty.read()
        if "added" in resp:
          break
      resp = self.tty.read()
    if "assword:" in resp:
      self.tty.println(self.password)
      tmp_resp = self.tty.read()
      tmp_resp += self.tty.read()
      if "denied" in tmp_resp or "assword:" in tmp_resp:
        raise(SSHError("Authentication failed"))
      else:
        self.authenticated = True
        self.tty.println("PS1=''")
    return self.authenticated
  def execute(self, os_cmd):
    self.tty.println(os_cmd)
    resp_buf = self.tty.read().replace(os_cmd+'\r\n', '')
    return resp_buf

It uses a pty module I wrote earlier:

#!/usr/bin/env python
import os,pty

class PTYError(Exception):
  def __init__(self, msg):
    self.msg = msg
  def __str__(self):
    return repr(self.msg)

class VX_PTY:
  def __init__(self, execlp_args):
    self.execlp_args = execlp_args
    self.pty_execlp(execlp_args)
  def pty_execlp(self, execlp_args):
    (self.pid, self.f) = pty.fork()
    if self.pid==0:
      os.execlp(*execlp_args)
    elif self.pid<0:
      raise(PTYError("Failed to fork pty"))
  def read(self):
    data = None
    try:
      data = os.read(self.f, 1024)
    except Exception:
      raise(PTYError("Read failed"))
    return data
  def write(self, data):
    try:
      os.write(self.f, data)
    except Exception:
      raise(PTYError("Write failed"))
  def fsync(self):
    os.fsync(self.f)
  def seek_end(self):
    os.lseek(self.f, os.SEEK_END, os.SEEK_CUR)
  def println(self, ln):
    self.write(ln+'\n')

However, the first time I call the execute() method, I end up reading the leftover prompt and PS1='' line instead of the output of my command:

>>> import SSH;shell=SSH.SSHShell('localhost',22,'735tesla','notmypassword');shell.authenticate()
True
>>> shell.execute('whoami')
"\x1b[?1034hLaptop:~ 735Tesla$ PS1=''\r\n"
>>>

Then the second time I call read() I get the output:

>>> shell.tty.read()
'whoami\r\n735Tesla\r\n'
>>> 

Removing whoami\r\n from the output is not a problem, but is there any way to clear the pending output so that I don't have to call read() twice for the first command?

Solution

I think your problem is deeper than you realize. Luckily, it's also easier to solve than you realize.

What you seem to want is for os.read to return everything the shell has to send you in a single call. That's not something you can ask for. Depending on several factors, including but not limited to the shell's implementation, network bandwidth and latency, and the behavior of the PTYs (yours and the remote host's), each call to read can return anything from a single character to the entire output.
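The usual way to cope with this is to keep reading and accumulating until some expected string shows up. Below is a minimal sketch of that idea, not part of the original modules: read_until is a hypothetical helper, a plain os.pipe stands in for the pty, and it works on bytes as os.read does in Python 3. The small chunk_size is chosen only to force several partial reads.

```python
import os

def read_until(fd, marker, chunk_size=1024):
    """Accumulate data from fd until marker appears, then return all of it."""
    buf = b''
    while marker not in buf:
        chunk = os.read(fd, chunk_size)
        if not chunk:
            # An empty read on a pipe means the writer closed it.
            raise EOFError('stream closed before marker %r was seen' % marker)
        buf += chunk
    return buf

# A pipe stands in for the pty; the data arrives in two separate writes.
r, w = os.pipe()
os.write(w, b'whoami\r\n735')
os.write(w, b'Tesla\r\n$ ')
os.close(w)

# chunk_size=4 forces several partial reads before the marker is complete.
result = read_until(r, b'$ ', chunk_size=4)
os.close(r)
print(repr(result))
```

No single read returns the whole reply here; the loop is what reassembles it.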

If you want to receive just the output of your command, bracket it with unique markers and stop worrying about PS1. That is, make the shell print one unique string before your command runs and another after it finishes. Your execute method should then keep reading until it has seen both markers and return only the text between them. The easiest way to make the shell print these strings is the echo command.

For multiline commands, you have to wrap the command in a shell function, and echo the markers before and after executing the function.

A simple implementation is as follows:

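def execute(self, cmd):
    # The quotes inside the echo arguments keep the literal markers out of the
    # command line that the pty echoes back; only echo's output contains them.
    if '\n' in cmd:
        # Multiline command: wrap it in a shell function, then call it.
        self.tty.println(
            '__cmd_func__(){\n%s\n' % cmd +
            '}; echo __"cmd_start"__; __cmd_func__; echo __"cmd_end"__; unset -f __cmd_func__'
        )
    else:
        self.tty.println('echo __"cmd_start"__; %s; echo __"cmd_end"__' % cmd)

    start = '__cmd_start__\r\n'
    end = '__cmd_end__'

    # read() may return any amount of data, so accumulate until the start marker arrives.
    resp = ''
    while start not in resp:
        resp += self.tty.read()

    # Discard everything up to and including the start marker.
    resp = resp[resp.find(start) + len(start):]

    # Keep reading until the end marker arrives, then return what precedes it.
    while end not in resp:
        resp += self.tty.read()

    return resp[:resp.find(end)]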
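
The quotes inside the echo arguments matter: the pty echoes the command line back, and without them that echoed line would itself contain the markers. The sketch below runs the same extraction on a hand-written transcript. The transcript is an assumption about what the pty would deliver for whoami, and extract_output is a hypothetical helper, not part of the modules above.

```python
START = '__cmd_start__'
END = '__cmd_end__'

def extract_output(transcript):
    """Return the text between the START line and the END marker."""
    start_line = START + '\r\n'
    begin = transcript.index(start_line) + len(start_line)
    return transcript[begin:transcript.index(END, begin)]

# Assumed transcript: the echoed command line comes first (its markers still
# contain the quotes), followed by what the shell actually printed.
transcript = (
    'echo __"cmd_start"__; whoami; echo __"cmd_end"__\r\n'
    '__cmd_start__\r\n'
    '735Tesla\r\n'
    '__cmd_end__\r\n'
)

output = extract_output(transcript)
print(repr(output))
```

Because `__"cmd_start"__` does not contain the substring `__cmd_start__`, the search skips the echoed command line and matches only the line printed by echo.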
Licensed under: CC-BY-SA with attribution