diff mbox series

[PULL,13/25] python/qemu: fix socket.makefile() typing

Message ID 20200531163846.25363-14-philmd@redhat.com (mailing list archive)
State New, archived
Headers show
Series [PULL,01/25] scripts/qemugdb: Remove shebang header | expand

Commit Message

Philippe Mathieu-Daudé May 31, 2020, 4:38 p.m. UTC
From: John Snow <jsnow@redhat.com>

Note:

A bug in typeshed (https://github.com/python/typeshed/issues/3977)
misinterprets the type of makefile(). Work around this by explicitly
stating that we are opening a text-mode file.

Signed-off-by: John Snow <jsnow@redhat.com>
Reviewed-by: Philippe Mathieu-Daudé <philmd@redhat.com>
Message-Id: <20200514055403.18902-13-jsnow@redhat.com>
Signed-off-by: Philippe Mathieu-Daudé <philmd@redhat.com>
---
 python/qemu/qmp.py   | 10 +++++++---
 python/qemu/qtest.py | 12 ++++++++----
 2 files changed, 15 insertions(+), 7 deletions(-)
diff mbox series

Patch

diff --git a/python/qemu/qmp.py b/python/qemu/qmp.py
index 6ae7693965..73d49050ed 100644
--- a/python/qemu/qmp.py
+++ b/python/qemu/qmp.py
@@ -11,6 +11,10 @@ 
 import errno
 import socket
 import logging
+from typing import (
+    Optional,
+    TextIO,
+)
 
 
 class QMPError(Exception):
@@ -61,7 +65,7 @@  def __init__(self, address, server=False, nickname=None):
         self.__events = []
         self.__address = address
         self.__sock = self.__get_sock()
-        self.__sockfile = None
+        self.__sockfile: Optional[TextIO] = None
         self._nickname = nickname
         if self._nickname:
             self.logger = logging.getLogger('QMP').getChild(self._nickname)
@@ -157,7 +161,7 @@  def connect(self, negotiate=True):
         @raise QMPCapabilitiesError if fails to negotiate capabilities
         """
         self.__sock.connect(self.__address)
-        self.__sockfile = self.__sock.makefile()
+        self.__sockfile = self.__sock.makefile(mode='r')
         if negotiate:
             return self.__negotiate_capabilities()
         return None
@@ -180,7 +184,7 @@  def accept(self, timeout=15.0):
         """
         self.__sock.settimeout(timeout)
         self.__sock, _ = self.__sock.accept()
-        self.__sockfile = self.__sock.makefile()
+        self.__sockfile = self.__sock.makefile(mode='r')
         return self.__negotiate_capabilities()
 
     def cmd_obj(self, qmp_cmd):
diff --git a/python/qemu/qtest.py b/python/qemu/qtest.py
index 7943487c2b..4c88590eb0 100644
--- a/python/qemu/qtest.py
+++ b/python/qemu/qtest.py
@@ -19,6 +19,7 @@ 
 
 import socket
 import os
+from typing import Optional, TextIO
 
 from .machine import QEMUMachine
 
@@ -40,7 +41,7 @@  class QEMUQtestProtocol:
     def __init__(self, address, server=False):
         self._address = address
         self._sock = self._get_sock()
-        self._sockfile = None
+        self._sockfile: Optional[TextIO] = None
         if server:
             self._sock.bind(self._address)
             self._sock.listen(1)
@@ -59,7 +60,7 @@  def connect(self):
         @raise socket.error on socket connection errors
         """
         self._sock.connect(self._address)
-        self._sockfile = self._sock.makefile()
+        self._sockfile = self._sock.makefile(mode='r')
 
     def accept(self):
         """
@@ -68,7 +69,7 @@  def accept(self):
         @raise socket.error on socket connection errors
         """
         self._sock, _ = self._sock.accept()
-        self._sockfile = self._sock.makefile()
+        self._sockfile = self._sock.makefile(mode='r')
 
     def cmd(self, qtest_cmd):
         """
@@ -76,6 +77,7 @@  def cmd(self, qtest_cmd):
 
         @param qtest_cmd: qtest command text to be sent
         """
+        assert self._sockfile is not None
         self._sock.sendall((qtest_cmd + "\n").encode('utf-8'))
         resp = self._sockfile.readline()
         return resp
@@ -83,7 +85,9 @@  def cmd(self, qtest_cmd):
     def close(self):
         """Close this socket."""
         self._sock.close()
-        self._sockfile.close()
+        if self._sockfile:
+            self._sockfile.close()
+            self._sockfile = None
 
     def settimeout(self, timeout):
         """Set a timeout, in seconds."""