Skip to content

Commit 73b4e7b

Browse files
pretty print cli output
1 parent 82d011c commit 73b4e7b

File tree

2 files changed

+28
-5
lines changed

2 files changed

+28
-5
lines changed

ads/cli.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import sys
99

1010
import fire
11+
from dataclasses import is_dataclass
1112
from ads.common import logger
1213

1314
try:
@@ -70,11 +71,32 @@ def _SeparateFlagArgs(args):
7071
fire.core.parser.SeparateFlagArgs = _SeparateFlagArgs
7172

7273

74+
def serialize(data):
75+
"""Serialize dataclass objects or lists of dataclass objects.
76+
Parameters:
77+
data: A dataclass object or a list of dataclass objects.
78+
Returns:
79+
None
80+
Prints:
81+
The string representation of each dataclass object, or the string representation of any other type of object.
82+
"""
83+
if isinstance(data, list):
84+
[print(str(item) if hasattr(item, "__str__") else repr(item)) for item in data]
85+
else:
86+
print(
87+
str(data)
88+
if (is_dataclass(data) and hasattr(data, "__str__"))
89+
else repr(data)
90+
)
91+
92+
7393
def cli():
7494
if len(sys.argv) > 1 and sys.argv[1] == "aqua":
7595
from ads.aqua.cli import AquaCommand
7696

77-
fire.Fire(AquaCommand, command=sys.argv[2:], name="ads aqua")
97+
fire.Fire(
98+
AquaCommand, command=sys.argv[2:], name="ads aqua", serialize=serialize
99+
)
78100
else:
79101
click_cli()
80102

ads/common/serializer.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,10 @@ def to_json(
195195
`None` in case when `uri` provided.
196196
"""
197197
json_string = json.dumps(
198-
self.to_dict(**kwargs), cls=encoder, default=default or self.serialize
198+
self.to_dict(**kwargs),
199+
cls=encoder,
200+
default=default or self.serialize,
201+
indent=4,
199202
)
200203
if uri:
201204
self._write_to_file(s=json_string, uri=uri, **kwargs)
@@ -463,9 +466,7 @@ def from_dict(
463466
"These fields will be ignored."
464467
)
465468

466-
obj = cls(
467-
**{key: obj_dict.get(key) for key in allowed_fields}
468-
)
469+
obj = cls(**{key: obj_dict.get(key) for key in allowed_fields})
469470

470471
for key, value in obj_dict.items():
471472
if (

0 commit comments

Comments
 (0)