1   
  2   
  3   
  4   
  5   
  6   
  7   
  8   
  9   
 10   
 11   
 12   
 13   
 14   
 15   
 16   
 17   
 18   
 19  """ 
L{HostKeys} and L{HostKeyEntry}: read, look up and write OpenSSH-style "known_hosts" host key files.
 21  """ 
 22   
 23  import base64 
 24  import binascii 
 25  from Crypto.Hash import SHA, HMAC 
 26  import UserDict 
 27   
 28  from paramiko.common import * 
 29  from paramiko.dsskey import DSSKey 
 30  from paramiko.rsakey import RSAKey 
 31  from paramiko.util import get_logger 
 32  from paramiko.ecdsakey import ECDSAKey 
 33   
 34   
 36   
        # Keep the known_hosts line that failed to parse and the underlying
        # exception (a binascii.Error when raised from from_line).
        self.line = line
        self.exc = exc
        # Exception.__init__ is not called here, so args is set explicitly;
        # str()/repr() of the exception are built from args.
        self.args = (line, exc)
   41   
 42   
 44      """ 
 45      Representation of a line in an OpenSSH-style "known hosts" file. 
 46      """ 
 47   
 48 -    def __init__(self, hostnames=None, key=None): 
  49          self.valid = (hostnames is not None) and (key is not None) 
 50          self.hostnames = hostnames 
 51          self.key = key 
  52   
 53 -    def from_line(cls, line, lineno=None): 
  54          """ 
 55          Parses the given line of text to find the names for the host, 
 56          the type of key, and the key data. The line is expected to be in the 
 57          format used by the openssh known_hosts file. 
 58   
 59          Lines are expected to not have leading or trailing whitespace. 
 60          We don't bother to check for comments or empty lines.  All of 
 61          that should be taken care of before sending the line to us. 
 62   
 63          @param line: a line from an OpenSSH known_hosts file 
 64          @type line: str 
 65          """ 
 66          log = get_logger('paramiko.hostkeys') 
 67          fields = line.split(' ') 
 68          if len(fields) < 3: 
 69               
 70              log.info("Not enough fields found in known_hosts in line %s (%r)" % 
 71                       (lineno, line)) 
 72              return None 
 73          fields = fields[:3] 
 74   
 75          names, keytype, key = fields 
 76          names = names.split(',') 
 77   
 78           
 79           
 80          try: 
 81              if keytype == 'ssh-rsa': 
 82                  key = RSAKey(data=base64.decodestring(key)) 
 83              elif keytype == 'ssh-dss': 
 84                  key = DSSKey(data=base64.decodestring(key)) 
 85              elif keytype == 'ecdsa-sha2-nistp256': 
 86                  key = ECDSAKey(data=base64.decodestring(key)) 
 87              else: 
 88                  log.info("Unable to handle key of type %s" % (keytype,)) 
 89                  return None 
 90   
 91          except binascii.Error, e: 
 92              raise InvalidHostKey(line, e) 
 93   
 94          return cls(names, key) 
  95      from_line = classmethod(from_line) 
 96   
 98          """ 
 99          Returns a string in OpenSSH known_hosts file format, or None if 
100          the object is not in a valid state.  A trailing newline is 
101          included. 
102          """ 
103          if self.valid: 
104              return '%s %s %s\n' % (','.join(self.hostnames), self.key.get_name(), 
105                     self.key.get_base64()) 
106          return None 
 107   
108 -    def __repr__(self): 
 109          return '<HostKeyEntry %r: %r>' % (self.hostnames, self.key) 
  110   
111   
113      """ 
114      Representation of an openssh-style "known hosts" file.  Host keys can be 
115      read from one or more files, and then individual hosts can be looked up to 
116      verify server keys during SSH negotiation. 
117   
118      A HostKeys object can be treated like a dict; any dict lookup is equivalent 
119      to calling L{lookup}. 
120   
121      @since: 1.5.3 
122      """ 
123   
125          """ 
126          Create a new HostKeys object, optionally loading keys from an openssh 
127          style host-key file. 
128   
129          @param filename: filename to load host keys from, or C{None} 
130          @type filename: str 
131          """ 
132           
133          self._entries = [] 
134          if filename is not None: 
135              self.load(filename) 
 136   
137 -    def add(self, hostname, keytype, key): 
 138          """ 
139          Add a host key entry to the table.  Any existing entry for a 
140          C{(hostname, keytype)} pair will be replaced. 
141   
142          @param hostname: the hostname (or IP) to add 
143          @type hostname: str 
144          @param keytype: key type (C{"ssh-rsa"} or C{"ssh-dss"}) 
145          @type keytype: str 
146          @param key: the key to add 
147          @type key: L{PKey} 
148          """ 
149          for e in self._entries: 
150              if (hostname in e.hostnames) and (e.key.get_name() == keytype): 
151                  e.key = key 
152                  return 
153          self._entries.append(HostKeyEntry([hostname], key)) 
 154   
155 -    def load(self, filename): 
 156          """ 
157          Read a file of known SSH host keys, in the format used by openssh. 
158          This type of file unfortunately doesn't exist on Windows, but on 
159          posix, it will usually be stored in 
160          C{os.path.expanduser("~/.ssh/known_hosts")}. 
161   
162          If this method is called multiple times, the host keys are merged, 
163          not cleared.  So multiple calls to C{load} will just call L{add}, 
164          replacing any existing entries and adding new ones. 
165   
166          @param filename: name of the file to read host keys from 
167          @type filename: str 
168   
169          @raise IOError: if there was an error reading the file 
170          """ 
171          f = open(filename, 'r') 
172          for lineno, line in enumerate(f): 
173              line = line.strip() 
174              if (len(line) == 0) or (line[0] == '#'): 
175                  continue 
176              e = HostKeyEntry.from_line(line, lineno) 
177              if e is not None: 
178                  _hostnames = e.hostnames 
179                  for h in _hostnames: 
180                      if self.check(h, e.key): 
181                          e.hostnames.remove(h) 
182                  if len(e.hostnames): 
183                      self._entries.append(e) 
184          f.close() 
 185   
186 -    def save(self, filename): 
 187          """ 
188          Save host keys into a file, in the format used by openssh.  The order of 
189          keys in the file will be preserved when possible (if these keys were 
190          loaded from a file originally).  The single exception is that combined 
191          lines will be split into individual key lines, which is arguably a bug. 
192   
193          @param filename: name of the file to write 
194          @type filename: str 
195   
196          @raise IOError: if there was an error writing the file 
197   
198          @since: 1.6.1 
199          """ 
200          f = open(filename, 'w') 
201          for e in self._entries: 
202              line = e.to_line() 
203              if line: 
204                  f.write(line) 
205          f.close() 
 206   
208          """ 
209          Find a hostkey entry for a given hostname or IP.  If no entry is found, 
210          C{None} is returned.  Otherwise a dictionary of keytype to key is 
211          returned.  The keytype will be either C{"ssh-rsa"} or C{"ssh-dss"}. 
212   
213          @param hostname: the hostname (or IP) to lookup 
214          @type hostname: str 
215          @return: keys associated with this host (or C{None}) 
216          @rtype: dict(str, L{PKey}) 
217          """ 
218          class SubDict (UserDict.DictMixin): 
219              def __init__(self, hostname, entries, hostkeys): 
220                  self._hostname = hostname 
221                  self._entries = entries 
222                  self._hostkeys = hostkeys 
 223   
224              def __getitem__(self, key): 
225                  for e in self._entries: 
226                      if e.key.get_name() == key: 
227                          return e.key 
228                  raise KeyError(key) 
 229   
230              def __setitem__(self, key, val): 
231                  for e in self._entries: 
232                      if e.key is None: 
233                          continue 
234                      if e.key.get_name() == key: 
235                           
236                          e.key = val 
237                          break 
238                  else: 
239                       
240                      e = HostKeyEntry([hostname], val) 
241                      self._entries.append(e) 
242                      self._hostkeys._entries.append(e) 
243   
244              def keys(self): 
245                  return [e.key.get_name() for e in self._entries if e.key is not None] 
246   
247          entries = [] 
248          for e in self._entries: 
249              for h in e.hostnames: 
250                  if (h.startswith('|1|') and (self.hash_host(hostname, h) == h)) or (h == hostname): 
251                      entries.append(e) 
252          if len(entries) == 0: 
253              return None 
254          return SubDict(hostname, entries, self) 
255   
256 -    def check(self, hostname, key): 
 257          """ 
258          Return True if the given key is associated with the given hostname 
259          in this dictionary. 
260   
261          @param hostname: hostname (or IP) of the SSH server 
262          @type hostname: str 
263          @param key: the key to check 
264          @type key: L{PKey} 
265          @return: C{True} if the key is associated with the hostname; C{False} 
266              if not 
267          @rtype: bool 
268          """ 
269          k = self.lookup(hostname) 
270          if k is None: 
271              return False 
272          host_key = k.get(key.get_name(), None) 
273          if host_key is None: 
274              return False 
275          return str(host_key) == str(key) 
 276   
278          """ 
279          Remove all host keys from the dictionary. 
280          """ 
281          self._entries = [] 
 282   
        # Dict-style access: same as lookup(), but an unknown host raises
        # KeyError instead of returning None.
        ret = self.lookup(key)
        if ret is None:
            raise KeyError(key)
        return ret
 288   
290           
        # Store the keys in `entry` (a keytype -> key mapping; presumably a
        # dict or a lookup() result -- verify against callers) for hostname.
        if len(entry) == 0:
            # Empty mapping: record the hostname with no key.  Such an entry
            # is not valid, so to_line() returns None and save() skips it.
            self._entries.append(HostKeyEntry([hostname], None))
            return
        for key_type in entry.keys():
            found = False
            for e in self._entries:
                # NOTE(review): if an entry for this hostname has a None key
                # (e.g. from the empty-mapping branch above), get_name()
                # raises AttributeError here.
                if (hostname in e.hostnames) and (e.key.get_name() == key_type):
                    # Replace the key on every matching entry (no break).
                    e.key = entry[key_type]
                    found = True
            if not found:
                self._entries.append(HostKeyEntry([hostname], entry[key_type]))
 303   
305           
        # Unique hostnames across all entries, in first-seen order.  Hashed
        # names are returned in their stored |1|... form.
        ret = []
        for e in self._entries:
            for h in e.hostnames:
                if h not in ret:
                    ret.append(h)
        return ret
 312   
        # One lookup() result per hostname returned by self.keys(), in the
        # same order.
        ret = []
        for k in self.keys():
            ret.append(self.lookup(k))
        return ret
 318   
320          """ 
321          Return a "hashed" form of the hostname, as used by openssh when storing 
322          hashed hostnames in the known_hosts file. 
323   
324          @param hostname: the hostname to hash 
325          @type hostname: str 
326          @param salt: optional salt to use when hashing (must be 20 bytes long) 
327          @type salt: str 
328          @return: the hashed hostname 
329          @rtype: str 
330          """ 
331          if salt is None: 
332              salt = rng.read(SHA.digest_size) 
333          else: 
334              if salt.startswith('|1|'): 
335                  salt = salt.split('|')[2] 
336              salt = base64.decodestring(salt) 
337          assert len(salt) == SHA.digest_size 
338          hmac = HMAC.HMAC(salt, hostname, SHA).digest() 
339          hostkey = '|1|%s|%s' % (base64.encodestring(salt), base64.encodestring(hmac)) 
340          return hostkey.replace('\n', '') 
 341      hash_host = staticmethod(hash_host) 
342