diff --git a/jaydebeapi/__init__.py b/jaydebeapi/__init__.py index e3248c0..a28461a 100644 --- a/jaydebeapi/__init__.py +++ b/jaydebeapi/__init__.py @@ -163,7 +163,7 @@ def _handle_sql_exception_jpype(): exc_type = InterfaceError reraise(exc_type, exc_info[1], exc_info[2]) - + def _jdbc_connect_jpype(jclassname, url, driver_args, jars, libs): import jpype if not jpype.isJVMStarted(): @@ -451,6 +451,12 @@ class Connection(object): def cursor(self): return Cursor(self, self._converters) + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + # DB-API 2.0 Cursor Object class Cursor(object): @@ -604,6 +610,12 @@ class Cursor(object): def setoutputsize(self, size, column=None): pass + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + def _unknownSqlTypeConverter(rs, col): return rs.getObject(col) diff --git a/test/test_integration.py b/test/test_integration.py index 8818975..a99c69d 100644 --- a/test/test_integration.py +++ b/test/test_integration.py @@ -207,6 +207,17 @@ class IntegrationTestBase(object): cursor.execute("select * from ACCOUNT") self.assertEqual(cursor.rowcount, -1) + def test_connection_with_statement(self): + with self.connect() as conn: + self.assertEqual(conn._closed, False) + self.assertEqual(conn._closed, True) + + def test_cursor_with_statement(self): + with self.conn.cursor() as cursor: + cursor.execute("select 1 from ACCOUNT") + self.assertIsNotNone(cursor._connection) + self.assertIsNone(cursor._connection) + class SqliteTestBase(IntegrationTestBase): def setUpSql(self):