Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 73 additions & 8 deletions crates/remote/src/transport/ssh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ impl RemoteConnection for SshRemoteConnection {
let mut command = util::command::new_smol_command("scp");
let output = self
.socket
.ssh_options(&mut command)
.ssh_options(&mut command, false)
.args(
self.socket
.connection_options
Expand Down Expand Up @@ -648,7 +648,7 @@ impl SshRemoteConnection {
let mut command = util::command::new_smol_command("scp");
let output = self
.socket
.ssh_options(&mut command)
.ssh_options(&mut command, false)
.args(
self.socket
.connection_options
Expand Down Expand Up @@ -729,7 +729,7 @@ impl SshSocket {
to_run.push_str(&shlex::try_quote(arg.as_ref()).unwrap());
}
let to_run = format!("cd; {to_run}");
self.ssh_options(&mut command)
self.ssh_options(&mut command, true)
.arg(self.connection_options.ssh_url())
.arg("-T")
.arg(to_run);
Expand All @@ -748,23 +748,43 @@ impl SshSocket {
}

#[cfg(not(target_os = "windows"))]
fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
fn ssh_options<'a>(
&self,
command: &'a mut process::Command,
include_port_forwards: bool,
) -> &'a mut process::Command {
let args = if include_port_forwards {
self.connection_options.additional_args()
} else {
self.connection_options.additional_args_for_scp()
};

command
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.args(self.connection_options.additional_args())
.args(args)
.args(["-o", "ControlMaster=no", "-o"])
.arg(format!("ControlPath={}", self.socket_path.display()))
}

#[cfg(target_os = "windows")]
fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
fn ssh_options<'a>(
&self,
command: &'a mut process::Command,
include_port_forwards: bool,
) -> &'a mut process::Command {
let args = if include_port_forwards {
self.connection_options.additional_args()
} else {
self.connection_options.additional_args_for_scp()
};

command
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.args(self.connection_options.additional_args())
.args(args)
.envs(self.envs.clone())
}

Expand Down Expand Up @@ -991,8 +1011,12 @@ impl SshConnectionOptions {
result
}

pub fn additional_args_for_scp(&self) -> Vec<String> {
self.args.iter().flatten().cloned().collect::<Vec<String>>()
}

pub fn additional_args(&self) -> Vec<String> {
let mut args = self.args.iter().flatten().cloned().collect::<Vec<String>>();
let mut args = self.additional_args_for_scp();

if let Some(forwards) = &self.port_forwards {
args.extend(forwards.iter().map(|pf| {
Expand Down Expand Up @@ -1169,4 +1193,45 @@ mod tests {

Ok(())
}

#[test]
fn scp_args_exclude_port_forward_flags() {
let options = SshConnectionOptions {
host: "example.com".into(),
args: Some(vec![
"-p".to_string(),
"2222".to_string(),
"-o".to_string(),
"StrictHostKeyChecking=no".to_string(),
]),
port_forwards: Some(vec![SshPortForwardOption {
local_host: Some("127.0.0.1".to_string()),
local_port: 8080,
remote_host: Some("127.0.0.1".to_string()),
remote_port: 80,
}]),
..Default::default()
};

let ssh_args = options.additional_args();
assert!(
ssh_args.iter().any(|arg| arg.starts_with("-L")),
"expected ssh args to include port-forward: {ssh_args:?}"
);

let scp_args = options.additional_args_for_scp();
assert_eq!(
scp_args,
vec![
"-p".to_string(),
"2222".to_string(),
"-o".to_string(),
"StrictHostKeyChecking=no".to_string()
]
);
assert!(
scp_args.iter().all(|arg| !arg.starts_with("-L")),
"scp args should not contain port forward flags: {scp_args:?}"
);
}
}
Loading