diff --git a/test/test.py b/test/test.py index ae5e390..fdabc11 100755 --- a/test/test.py +++ b/test/test.py @@ -195,7 +195,7 @@ def test_in_place_yaml(self): self.assertEqual(tf.read(), b"foo\n...\n") self.assertEqual(tf2.read(), b"foo\n...\n") - # Files do not get overwritten on error (DeferredOutputStream logic) + # Files do not get overwritten on error self.run_yq("", ["-i", "-y", tf.name, tf2.name], expect_exit_codes=[3]) tf.seek(0) tf2.seek(0) diff --git a/yq/__init__.py b/yq/__init__.py index 477dd78..0bd2a0a 100644 --- a/yq/__init__.py +++ b/yq/__init__.py @@ -50,30 +50,6 @@ def tq_cli(): cli(input_format="toml", program_name="tomlq") -class DeferredOutputStream: - def __init__(self, name, mode="w"): - self.name = name - self.mode = mode - self._fh = None - - @property - def fh(self): - if self._fh is None: - self._fh = open(self.name, self.mode) - return self._fh - - def flush(self): - if self._fh is not None: - return self.fh.flush() - - def close(self): - if self._fh is not None: - return self.fh.close() - - def __getattr__(self, a): - return getattr(self.fh, a) - - def cli(args=None, input_format="yaml", program_name="yq"): parser = get_parser(program_name, __doc__) argcomplete.autocomplete(parser) @@ -131,8 +107,8 @@ def cli(args=None, input_format="yaml", program_name="yq"): yq_args = dict(input_format=input_format, program_name=program_name, jq_args=jq_args, **vars(args)) if in_place: - if args.output_format not in {"yaml", "annotated_yaml", "toml"}: - sys.exit("{}: -i/--in-place can only be used with -y/-Y/-t".format(program_name)) + if args.output_format not in {"yaml", "annotated_yaml", "toml", "xml"}: + sys.exit("{}: -i/--in-place can only be used with -y/-Y/-t/-x".format(program_name)) input_streams = yq_args.pop("input_streams") if len(input_streams) == 1 and input_streams[0].name == "": msg = "{}: -i/--in-place can only be used with filename arguments, not on standard input" @@ -145,7 +121,11 @@ def exit_handler(arg=None): if i < len(input_streams): yq_args["exit_func"] = exit_handler - yq(input_streams=[input_stream], output_stream=DeferredOutputStream(input_stream.name), **yq_args) + + with io.StringIO() as out_fh: + yq(input_streams=[input_stream], output_stream=out_fh, **yq_args) + with open(input_stream.name, "w") as fh: + fh.write(out_fh.getvalue()) else: yq(**yq_args)