"""Tests for TPM session and transient-object handle management via tss2.

Each test drives the TPM through one or more ``tss2.Client`` connections and
checks how many sessions / transient objects can be opened, flushed,
consumed and reclaimed.  Written so it runs unchanged on Python 2.7 and 3.
"""
from argparse import ArgumentParser
from argparse import FileType
import os
import sys
import tpm2
from tpm2 import ProtocolError
import unittest
import logging
import tss2

pwd1 = "wibble"
pwd2 = "newpassword"


class SessionTest(unittest.TestCase):
    """Exercise session / transient handle limits of ``tss2.Client``."""

    def setUp(self):
        self.c = tss2.Client()

    def tearDown(self):
        self.c.close()

    def open_until(self, func, limit=10):
        """Call *func* repeatedly until the TPM runs out of handle memory.

        :param func: zero-argument callable that opens and returns a handle.
        :param limit: maximum number of handles to open (default 10); the
            loop stops after this many even if memory is still available.
        :returns: list of the handles that were successfully opened.
        :raises tss2.tpm_error: for any TPM error other than
            TPM2_RC_SESSION_MEMORY or TPM2_RC_OBJECT_MEMORY.
        """
        handles = []
        try:
            for _ in range(limit):
                h = func()
                print("Handle is %08x" % h)
                handles.append(h)
        except tss2.tpm_error as e:
            # Running out of session/object memory is the expected stop
            # condition; any other TPM error is a genuine failure.
            if e.rc not in (tpm2.TPM2_RC_SESSION_MEMORY,
                            tpm2.TPM2_RC_OBJECT_MEMORY):
                raise  # bare raise preserves the original traceback
        return handles

    def open_handles(self):
        """Start HMAC sessions (passing ``self.c.SRK``) until none are left."""
        def func():
            return self.c.start_session(tpm2.TPM2_SE_HMAC, self.c.SRK)
        return self.open_until(func)

    def open_transients(self):
        """Create and load RSA keys under the SRK until none can be loaded."""
        def func():
            k = self.c.create_rsa(self.c.SRK, None)
            return self.c.load(self.c.SRK, k.outPrivate, k.outPublic, None)
        return self.open_until(func)

    def test_handle_clearing(self):
        """Closing a client frees its transients and sessions for reuse."""
        t1 = self.open_transients()
        h1 = self.open_handles()
        print("Opened {} transients and {} handles".format(len(t1), len(h1)))
        self.c.close()
        self.c = tss2.Client()
        h2 = self.open_handles()
        t2 = self.open_transients()
        print("Opened {} transients and {} handles".format(len(t2), len(h2)))
        self.assertEqual(len(h1), len(h2))
        self.assertEqual(len(t1), len(t2))

    def test_transients(self):
        """A flushed transient is unusable and its slot can be reloaded."""
        k = self.open_transients()
        self.c.flush_context(k[0])
        self.c.change_auth(self.c.SRK, k[1], None, pwd1)
        # k[0] was flushed above, so using it must fail.
        with self.assertRaises(tss2.tpm_error) as cm:
            self.c.change_auth(self.c.SRK, k[0], None, pwd1)
        print("Expected failure {}".format(cm.exception))
        # Exactly the one flushed slot should be available again.
        l = self.open_transients()
        self.assertEqual(len(l), 1)

    def test_handle_flush_on_space_close(self):
        """Reopening a client yields the same number of session handles."""
        i = self.open_handles()
        print("Ran out of handles at %d" % len(i))
        self.c.close()
        self.c = tss2.Client()
        # closing and reopening a space session should clear out our handles
        j = self.open_handles()
        print("Ran out of handles at %d" % len(j))
        self.assertNotEqual(len(i), 0)
        self.assertEqual(len(i), len(j))

    def test_flush(self):
        """Flushing two sessions frees exactly two session slots."""
        i = self.open_handles()
        print("opened %d handles" % len(i))
        self.c.flush_context(i[0])
        self.c.flush_context(i[1])
        i = self.open_handles()
        self.assertEqual(len(i), 2)

    def test_session_consumption(self):
        """Sessions passed with continue=0 are consumed and their slots freed."""
        self.c.read_public(self.c.SRK)
        # authorization hmac session
        hmac = self.c.start_session(tpm2.TPM2_SE_HMAC)
        # parameter encryption session
        enc = self.c.start_session(tpm2.TPM2_SE_HMAC, self.c.SRK)
        # fill all remaining handles
        self.open_handles()
        # create rsa key continuing both hmac and encryption sessions
        self.c.create_rsa(self.c.SRK, pwd1, hmac, 1, enc, 1)
        # should be no handles left
        i = self.open_handles()
        self.assertEqual(len(i), 0)
        # now create rsa key continuing hmac and consuming encryption
        k = self.c.create_rsa(self.c.SRK, pwd1, hmac, 1, enc, 0)
        # now should be one handle remaining
        i = self.open_handles()
        self.assertEqual(len(i), 1)
        self.c.flush_context(i[0])
        # check the hmac continuation actually works
        k = self.c.load(self.c.SRK, k.outPrivate, k.outPublic, None)
        print("Loaded key at handle %x" % k)
        # and finally verify with an authenticated encrypted operation
        # consuming both handles
        enc = self.c.start_session(tpm2.TPM2_SE_HMAC, k)
        self.c.change_auth(self.c.SRK, k, pwd1, pwd2, hmac, 0, enc, 0)
        i = self.open_handles()
        self.assertEqual(len(i), 2)

    def test_space_exhaustion(self):
        """Sessions of the oldest clients are evicted when contexts run out."""
        setup_client = self.c
        clients = []
        handles = []
        try:
            # usually 3 max unsaved and 64 max contexts, so 23*3 = 69 should
            # mean that the first 4 are evicted
            for _ in range(23):
                self.c = tss2.Client()
                clients.append(self.c)
                handles.append(self.open_handles())

            # try to use handle by creating an RSA key; this should fail
            # because the session was evicted
            self.c = clients[0]
            with self.assertRaises(tss2.tpm_error) as cm:
                # hmac only
                self.c.create_rsa(self.c.SRK, pwd1, handles[0][0], 1)
            print("Expected Session Failure: {}".format(cm.exception))

            # pick the latest session and handle and it should succeed
            self.c = clients[-1]
            self.c.create_rsa(self.c.SRK, pwd1, handles[-1][0], 1,
                              handles[-1][1], 1)
        finally:
            # tearDown() closes whichever client is left in self.c; close
            # every other connection (including the setUp one) here so none
            # of them leak.
            for client in [setup_client] + clients:
                if client is not self.c:
                    client.close()

    def test_disallow_save_context(self):
        """context_save on a session handle must be rejected."""
        h = self.open_handles()
        with self.assertRaises(tss2.tpm_error):
            self.c.context_save(h[0])

    def test_gap_error_first(self):
        """Cycle 256 sessions while another client holds an old session."""
        other = tss2.Client()
        try:
            h = other.start_session(tpm2.TPM2_SE_HMAC)
            print("Handle %08x" % h)
            for _ in range(256):
                j = self.c.start_session(tpm2.TPM2_SE_HMAC)
                print("Flush Handle %08x" % j)
                self.c.flush_context(j)
            other.flush_context(h)
        finally:
            other.close()

    def test_gap_error_last(self):
        """Like test_gap_error_first, but with this client's slots nearly full."""
        other = tss2.Client()
        try:
            h = other.start_session(tpm2.TPM2_SE_HMAC)
            print("Handle %08x" % h)
            t = self.open_handles()
            # free just the most recently opened session
            self.c.flush_context(t[-1])
            for _ in range(256):
                j = self.c.start_session(tpm2.TPM2_SE_HMAC)
                print("Flush Handle %08x" % j)
                self.c.flush_context(j)
            other.flush_context(h)
        finally:
            other.close()


if __name__ == '__main__':
    unittest.main()