diff mbox series

[RESEND,b4,4/5] Provide overloads for git_run_command

Message ID 20241108-better-type-annotations-v1-4-ebaf0129d9e1@gmail.com (mailing list archive)
State New
Headers show
Series Resolve some static typing errors | expand

Commit Message

Tamir Duberstein Nov. 8, 2024, 12:31 p.m. UTC
This requires explicit use of kwargs to allow decode to be a required
parameter in the negative case (which is necessary to disambiguate the
overloads).

Callers now know for certain whether they get str or bytes.

Signed-off-by: Tamir Duberstein <tamird@gmail.com>
---
 src/b4/__init__.py | 37 ++++++++++++++++++++++++++-----------
 1 file changed, 26 insertions(+), 11 deletions(-)
diff mbox series

Patch

diff --git a/src/b4/__init__.py b/src/b4/__init__.py
index d1f0fc5b56693d7d234bc4aced65166244928728..e71fbe7d8ea8d4ca19228690ce7c186cc91b81ad 100644
--- a/src/b4/__init__.py
+++ b/src/b4/__init__.py
@@ -36,7 +36,7 @@  import requests
 
 from pathlib import Path
 from contextlib import contextmanager
-from typing import Optional, Tuple, Set, List, BinaryIO, Union, Sequence, Literal, Iterator, Dict, overload
+from typing import Optional, Tuple, Set, List, BinaryIO, Union, Sequence, Literal, Iterator, Dict, TypeVar, overload
 
 from email import charset
 
@@ -2697,14 +2697,27 @@  def gpg_run_command(args: List[str], stdin: Optional[bytes] = None) -> Tuple[int
     return _run_command(cmdargs, stdin=stdin)
 
 
+@overload
+def git_run_command(gitdir: Optional[Union[str, Path]], args: List[str], stdin: Optional[bytes] = ...,
+                    *, logstderr: bool = ..., decode: Literal[False],
+                    rundir: Optional[str] = ...) -> Tuple[int, bytes]:
+    ...
+
+
+@overload
+def git_run_command(gitdir: Optional[Union[str, Path]], args: List[str], stdin: Optional[bytes] = ...,
+                    *, logstderr: bool = ..., decode: Literal[True] = ...,
+                    rundir: Optional[str] = ...) -> Tuple[int, str]:
+    ...
+
 def git_run_command(gitdir: Optional[Union[str, Path]], args: List[str], stdin: Optional[bytes] = None,
-                    logstderr: bool = False, decode: bool = True,
+                    *, logstderr: bool = False, decode: bool = True,
                     rundir: Optional[str] = None) -> Tuple[int, Union[str, bytes]]:
     cmdargs = ['git', '--no-pager']
     if gitdir:
         if os.path.exists(os.path.join(gitdir, '.git')):
             gitdir = os.path.join(gitdir, '.git')
-        cmdargs += ['--git-dir', gitdir]
+        cmdargs += ['--git-dir', str(gitdir)]
 
     # counteract some potential local settings
     if args[0] == 'log':
@@ -2714,16 +2727,18 @@  def git_run_command(gitdir: Optional[Union[str, Path]], args: List[str], stdin:
 
     ecode, out, err = _run_command(cmdargs, stdin=stdin, rundir=rundir)
 
-    if decode:
-        out = out.decode(errors='replace')
+    U = TypeVar('U', str, bytes)
+
+    def _handle(out: U, err: U) -> Tuple[int, Union[str, bytes]]:
+        if logstderr and len(err.strip()):
+            logger.debug('Stderr: %s', err)
+            out += err
 
-    if logstderr and len(err.strip()):
-        if decode:
-            err = err.decode(errors='replace')
-        logger.debug('Stderr: %s', err)
-        out += err
+        return ecode, out
 
-    return ecode, out
+    if decode:
+        return _handle(out.decode(errors='replace'), err.decode(errors='replace'))
+    return _handle(out, err)
 
 
 def git_credential_fill(gitdir: Optional[str], protocol: str, host: str, username: str) -> Optional[str]: