Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add directory transfer support for SFTPOperator #44126

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions providers/src/airflow/providers/sftp/hooks/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import warnings
from collections.abc import Sequence
from fnmatch import fnmatch
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable

import asyncssh
Expand Down Expand Up @@ -276,6 +277,49 @@ def delete_file(self, path: str) -> None:
conn = self.get_conn()
conn.remove(path)

def retrieve_directory(self, remote_full_path: str, local_full_path: str, prefetch: bool = True) -> None:
    """
    Transfer the remote directory tree to a local location.

    The local destination must not exist yet; it is created (including any
    missing parents) and then populated with every sub-directory and file
    reported by ``get_tree_map`` for the remote directory, keeping the same
    relative layout.

    :param remote_full_path: full path to the remote directory
    :param local_full_path: full path to the local directory to create
    :param prefetch: controls whether prefetch is performed (default: True)
    :raises AirflowException: if ``local_full_path`` already exists
    """
    # Local import: remote SFTP paths are "/"-separated whatever the local OS is,
    # so the remote-relative part is computed with posixpath.
    import posixpath

    local_root = Path(local_full_path)
    # exists() rather than is_dir(): an existing regular file would otherwise slip
    # past the guard and make mkdir() fail with a bare FileExistsError.
    if local_root.exists():
        raise AirflowException(f"{local_full_path} already exists")
    local_root.mkdir(parents=True)

    files, dirs, _ = self.get_tree_map(remote_full_path)
    # Create the directory skeleton first so every file has a parent to land in.
    for dir_path in dirs:
        relative_dir = posixpath.relpath(dir_path, remote_full_path)
        Path(os.path.join(local_full_path, relative_dir)).mkdir(parents=True, exist_ok=True)
    for file_path in files:
        relative_file = posixpath.relpath(file_path, remote_full_path)
        self.retrieve_file(file_path, os.path.join(local_full_path, relative_file), prefetch)

def store_directory(self, remote_full_path: str, local_full_path: str, confirm: bool = True) -> None:
    """
    Transfer a local directory tree to the remote location.

    The remote root directory is created first. The local tree is then walked
    top-down: every sub-directory is created remotely and every file is
    uploaded with ``store_file`` under the same relative path.

    :param remote_full_path: full path to the remote directory
    :param local_full_path: full path to the local directory
    :param confirm: passed through to ``store_file`` for every uploaded file (default: True)
    """
    # Local import: remote SFTP paths are "/"-separated whatever the local OS is,
    # so they must be assembled with posixpath rather than os.path.
    import posixpath

    self.create_directory(remote_full_path)
    for root, dirs, files in os.walk(local_full_path):
        # Compute the remote counterpart of the walked directory once per directory.
        relative_root = os.path.relpath(root, local_full_path)
        if relative_root == os.curdir:
            remote_root = remote_full_path
        else:
            remote_root = posixpath.join(remote_full_path, *relative_root.split(os.sep))
        for dir_name in dirs:
            self.create_directory(posixpath.join(remote_root, dir_name))
        for file_name in files:
            self.store_file(posixpath.join(remote_root, file_name), os.path.join(root, file_name), confirm)

def get_mod_time(self, path: str) -> str:
"""
Get an entry's modification time.
Expand Down
12 changes: 10 additions & 2 deletions providers/src/airflow/providers/sftp/operators/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,14 +180,22 @@ def execute(self, context: Any) -> str | list[str] | None:
Path(local_folder).mkdir(parents=True, exist_ok=True)
file_msg = f"from {_remote_filepath} to {_local_filepath}"
self.log.info("Starting to transfer %s", file_msg)
self.sftp_hook.retrieve_file(_remote_filepath, _local_filepath)
if self.sftp_hook.isdir(_remote_filepath):
self.sftp_hook.retrieve_directory(_remote_filepath, _local_filepath)
else:
self.sftp_hook.retrieve_file(_remote_filepath, _local_filepath)
else:
remote_folder = os.path.dirname(_remote_filepath)
if self.create_intermediate_dirs:
self.sftp_hook.create_directory(remote_folder)
file_msg = f"from {_local_filepath} to {_remote_filepath}"
self.log.info("Starting to transfer file %s", file_msg)
self.sftp_hook.store_file(_remote_filepath, _local_filepath, confirm=self.confirm)
if os.path.isdir(_local_filepath):
self.sftp_hook.store_directory(
_remote_filepath, _local_filepath, confirm=self.confirm
)
else:
self.sftp_hook.store_file(_remote_filepath, _local_filepath, confirm=self.confirm)

except Exception as e:
raise AirflowException(f"Error while transferring {file_msg}, error: {e}")
Expand Down
17 changes: 17 additions & 0 deletions providers/tests/sftp/hooks/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,23 @@ def test_get_matched_files_with_different_pattern(self):
output = self.hook.get_files_by_pattern(self.temp_dir, "*_file_*.txt")
assert output == [ANOTHER_FILE_FOR_TESTS]

def test_store_and_retrieve_directory(self):
    """Round-trip a directory through ``store_directory`` and ``retrieve_directory``."""
    base_dir = os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS)
    stored_dir = os.path.join(base_dir, "stored_dir")
    retrieved_dir_name = "retrieved_dir"
    retrieved_dir = os.path.join(base_dir, retrieved_dir_name)

    self.hook.store_directory(
        remote_full_path=stored_dir,
        local_full_path=os.path.join(base_dir, SUB_DIR),
    )
    output = self.hook.list_directory(path=stored_dir)
    assert output == [TMP_FILE_FOR_TESTS]

    self.hook.retrieve_directory(
        remote_full_path=stored_dir,
        local_full_path=retrieved_dir,
    )
    assert retrieved_dir_name in os.listdir(base_dir)
    # Check the contents too, not just that the destination directory was created.
    assert os.listdir(retrieved_dir) == [TMP_FILE_FOR_TESTS]


class MockSFTPClient:
def __init__(self):
Expand Down
30 changes: 30 additions & 0 deletions providers/tests/sftp/operators/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,21 @@ def test_multiple_paths_get(self, mock_get):
assert args0 == (remote_filepath[0], local_filepath[0])
assert args1 == (remote_filepath[1], local_filepath[1])

@mock.patch("airflow.providers.sftp.operators.sftp.SFTPHook.retrieve_directory")
def test_str_dirpaths_get(self, mock_get):
    """A GET whose remote path is a directory is delegated to ``retrieve_directory``."""
    local_dirpath = "/tmp_local"
    remote_dirpath = "/tmp"
    SFTPOperator(
        # task_id matches the test name (was a copy-pasted "test_str_to_list").
        task_id="test_str_dirpaths_get",
        sftp_hook=self.sftp_hook,
        local_filepath=local_dirpath,
        remote_filepath=remote_dirpath,
        operation=SFTPOperation.GET,
    ).execute(None)
    assert mock_get.call_count == 1
    args, _ = mock_get.call_args_list[0]
    assert args == (remote_dirpath, local_dirpath)

@mock.patch("airflow.providers.sftp.operators.sftp.SFTPHook.store_file")
def test_str_filepaths_put(self, mock_get):
local_filepath = "/tmp/test"
Expand Down Expand Up @@ -505,6 +520,21 @@ def test_multiple_paths_put(self, mock_put):
assert args0 == (remote_filepath[0], local_filepath[0])
assert args1 == (remote_filepath[1], local_filepath[1])

@mock.patch("airflow.providers.sftp.operators.sftp.SFTPHook.store_directory")
def test_str_dirpaths_put(self, mock_put):
    """A PUT whose local path is a directory is delegated to ``store_directory``."""
    local_dirpath = "/tmp"
    remote_dirpath = "/tmp_remote"
    SFTPOperator(
        task_id="test_str_dirpaths_put",
        sftp_hook=self.sftp_hook,
        local_filepath=local_dirpath,
        remote_filepath=remote_dirpath,
        operation=SFTPOperation.PUT,
    ).execute(None)
    # The mock stands in for store_directory, hence "mock_put" as in the other PUT tests.
    assert mock_put.call_count == 1
    args, _ = mock_put.call_args_list[0]
    assert args == (remote_dirpath, local_dirpath)

@mock.patch("airflow.providers.sftp.operators.sftp.SFTPHook.retrieve_file")
def test_return_str_when_local_filepath_was_str(self, mock_get):
local_filepath = "/tmp/ltest1"
Expand Down