mirror of
https://github.com/EthanShoeDev/fressh.git
synced 2026-01-11 14:22:51 +00:00
rust changes
This commit is contained in:
@@ -15,22 +15,19 @@ use thiserror::Error;
|
||||
use tokio::sync::{broadcast, Mutex as AsyncMutex};
|
||||
|
||||
use russh::{self, client, ChannelMsg, Disconnect};
|
||||
use russh::client::{Config as ClientConfig, Handle as ClientHandle};
|
||||
use russh_keys::{Algorithm as KeyAlgorithm, EcdsaCurve, PrivateKey as RusshKeysPrivateKey};
|
||||
use russh::keys::{PrivateKey as RusshPrivateKey, PrivateKeyWithHashAlg};
|
||||
use russh::client::{Config, Handle as ClientHandle};
|
||||
use russh_keys::{Algorithm, EcdsaCurve};
|
||||
use russh::keys::{PrivateKey, PrivateKeyWithHashAlg};
|
||||
use russh_keys::ssh_key::{self, LineEnding};
|
||||
use bytes::Bytes;
|
||||
|
||||
uniffi::setup_scaffolding!();
|
||||
|
||||
// No global registries; handles are the only access points.
|
||||
|
||||
/// ---------- Types ----------
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, uniffi::Enum)]
|
||||
pub enum Security {
|
||||
Password { password: String },
|
||||
Key { key_id: String }, // (key-based auth can be wired later)
|
||||
Key { private_key_content: String }, // (key-based auth can be wired later)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, uniffi::Record)]
|
||||
@@ -41,30 +38,35 @@ pub struct ConnectionDetails {
|
||||
pub security: Security,
|
||||
}
|
||||
|
||||
/// Options for establishing a TCP connection and authenticating.
|
||||
/// Listener is embedded here so TS has a single arg.
|
||||
#[derive(Clone, uniffi::Record)]
|
||||
pub struct ConnectOptions {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub username: String,
|
||||
pub security: Security,
|
||||
pub on_status_change: Option<Arc<dyn StatusListener>>,
|
||||
pub connection_details: ConnectionDetails,
|
||||
pub on_connection_progress_callback: Option<Arc<dyn ConnectProgressCallback>>,
|
||||
pub on_disconnected_callback: Option<Arc<dyn ConnectionDisconnectedCallback>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, uniffi::Enum)]
|
||||
pub enum SSHConnectionStatus {
|
||||
TcpConnecting,
|
||||
pub enum SshConnectionProgressEvent {
|
||||
// Before any progress events, assume: TcpConnecting
|
||||
TcpConnected,
|
||||
TcpDisconnected,
|
||||
ShellConnecting,
|
||||
ShellConnected,
|
||||
ShellDisconnected,
|
||||
SshHandshake,
|
||||
// If promise has not resolved, assume: Authenticating
|
||||
// After promise resolves, assume: Connected
|
||||
}
|
||||
|
||||
/// PTY types similar to the old TS lib (plus xterm-256color, which is common).
|
||||
#[uniffi::export(with_foreign)]
|
||||
pub trait ConnectProgressCallback: Send + Sync {
|
||||
fn on_change(&self, status: SshConnectionProgressEvent);
|
||||
}
|
||||
|
||||
#[uniffi::export(with_foreign)]
|
||||
pub trait ConnectionDisconnectedCallback: Send + Sync {
|
||||
fn on_change(&self, connection_id: String);
|
||||
}
|
||||
|
||||
// Note: russh accepts an untyped string for the terminal type
|
||||
#[derive(Debug, Clone, Copy, PartialEq, uniffi::Enum)]
|
||||
pub enum PtyType {
|
||||
pub enum TerminalType {
|
||||
Vanilla,
|
||||
Vt100,
|
||||
Vt102,
|
||||
@@ -73,52 +75,20 @@ pub enum PtyType {
|
||||
Xterm,
|
||||
Xterm256,
|
||||
}
|
||||
impl PtyType {
|
||||
impl TerminalType {
|
||||
fn as_ssh_name(self) -> &'static str {
|
||||
match self {
|
||||
PtyType::Vanilla => "vanilla",
|
||||
PtyType::Vt100 => "vt100",
|
||||
PtyType::Vt102 => "vt102",
|
||||
PtyType::Vt220 => "vt220",
|
||||
PtyType::Ansi => "ansi",
|
||||
PtyType::Xterm => "xterm",
|
||||
PtyType::Xterm256 => "xterm-256color",
|
||||
TerminalType::Vanilla => "vanilla",
|
||||
TerminalType::Vt100 => "vt100",
|
||||
TerminalType::Vt102 => "vt102",
|
||||
TerminalType::Vt220 => "vt220",
|
||||
TerminalType::Ansi => "ansi",
|
||||
TerminalType::Xterm => "xterm",
|
||||
TerminalType::Xterm256 => "xterm-256color",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error, uniffi::Error)]
|
||||
pub enum SshError {
|
||||
#[error("Disconnected")]
|
||||
Disconnected,
|
||||
#[error("Unsupported key type")]
|
||||
UnsupportedKeyType,
|
||||
#[error("Auth failed: {0}")]
|
||||
Auth(String),
|
||||
#[error("Shell already running")]
|
||||
ShellAlreadyRunning,
|
||||
#[error("russh error: {0}")]
|
||||
Russh(String),
|
||||
#[error("russh-keys error: {0}")]
|
||||
RusshKeys(String),
|
||||
}
|
||||
impl From<russh::Error> for SshError {
|
||||
fn from(e: russh::Error) -> Self { SshError::Russh(e.to_string()) }
|
||||
}
|
||||
impl From<russh_keys::Error> for SshError {
|
||||
fn from(e: russh_keys::Error) -> Self { SshError::RusshKeys(e.to_string()) }
|
||||
}
|
||||
impl From<ssh_key::Error> for SshError {
|
||||
fn from(e: ssh_key::Error) -> Self { SshError::RusshKeys(e.to_string()) }
|
||||
}
|
||||
|
||||
/// Status callback (used separately by connect and by start_shell)
|
||||
#[uniffi::export(with_foreign)]
|
||||
pub trait StatusListener: Send + Sync {
|
||||
fn on_change(&self, status: SSHConnectionStatus);
|
||||
}
|
||||
|
||||
// Stream kind for terminal output
|
||||
#[derive(Debug, Clone, Copy, PartialEq, uniffi::Enum)]
|
||||
pub enum StreamKind { Stdout, Stderr }
|
||||
|
||||
@@ -153,43 +123,43 @@ pub enum KeyType {
|
||||
Ed448,
|
||||
}
|
||||
|
||||
/// Options for starting a shell.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, uniffi::Record)]
|
||||
pub struct TerminalMode {
|
||||
pub opcode: u8, // PTY opcode (matches russh::Pty discriminants)
|
||||
pub value: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, uniffi::Record)]
|
||||
pub struct TerminalSize {
|
||||
pub row_height: Option<u32>,
|
||||
pub col_width: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, uniffi::Record)]
|
||||
pub struct TerminalPixelSize {
|
||||
pub pixel_width: Option<u32>,
|
||||
pub pixel_height: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Clone, uniffi::Record)]
|
||||
pub struct StartShellOptions {
|
||||
pub pty: PtyType,
|
||||
pub on_status_change: Option<Arc<dyn StatusListener>>,
|
||||
pub term: TerminalType,
|
||||
pub terminal_mode: Option<Vec<TerminalMode>>,
|
||||
pub terminal_size: Option<TerminalSize>,
|
||||
pub terminal_pixel_size: Option<TerminalPixelSize>,
|
||||
pub on_closed_callback: Option<Arc<dyn ShellClosedCallback>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, uniffi::Enum)]
|
||||
pub enum Cursor {
|
||||
Head,
|
||||
TailBytes { bytes: u64 },
|
||||
Seq { seq: u64 },
|
||||
TimeMs { t_ms: f64 },
|
||||
Live,
|
||||
#[uniffi::export(with_foreign)]
|
||||
pub trait ShellClosedCallback: Send + Sync {
|
||||
fn on_change(&self, channel_id: u32);
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, uniffi::Record)]
|
||||
pub struct ListenerOptions {
|
||||
pub cursor: Cursor,
|
||||
pub coalesce_ms: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, uniffi::Record)]
|
||||
pub struct BufferReadResult {
|
||||
pub chunks: Vec<TerminalChunk>,
|
||||
pub next_seq: u64,
|
||||
pub dropped: Option<DroppedRange>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, uniffi::Record)]
|
||||
pub struct BufferStats {
|
||||
pub ring_bytes: u64,
|
||||
pub used_bytes: u64,
|
||||
pub chunks: u64,
|
||||
pub head_seq: u64,
|
||||
pub tail_seq: u64,
|
||||
pub dropped_bytes_total: u64,
|
||||
pub struct SshConnectionInfoProgressTimings {
|
||||
// TODO: We should have a field for each SshConnectionProgressEvent. Would be great if this were enforced by the compiler.
|
||||
pub tcp_established_at_ms: f64,
|
||||
pub ssh_handshake_at_ms: f64,
|
||||
}
|
||||
|
||||
/// Snapshot of current connection info for property-like access in TS.
|
||||
@@ -198,7 +168,8 @@ pub struct SshConnectionInfo {
|
||||
pub connection_id: String,
|
||||
pub connection_details: ConnectionDetails,
|
||||
pub created_at_ms: f64,
|
||||
pub tcp_established_at_ms: f64,
|
||||
pub connected_at_ms: f64,
|
||||
pub progress_timings: SshConnectionInfoProgressTimings,
|
||||
}
|
||||
|
||||
/// Snapshot of shell session info for property-like access in TS.
|
||||
@@ -206,41 +177,35 @@ pub struct SshConnectionInfo {
|
||||
pub struct ShellSessionInfo {
|
||||
pub channel_id: u32,
|
||||
pub created_at_ms: f64,
|
||||
pub pty: PtyType,
|
||||
pub connected_at_ms: f64,
|
||||
pub term: TerminalType,
|
||||
pub connection_id: String,
|
||||
}
|
||||
|
||||
/// ---------- Connection object (no shell until start_shell) ----------
|
||||
|
||||
#[derive(uniffi::Object)]
|
||||
pub struct SSHConnection {
|
||||
connection_id: String,
|
||||
connection_details: ConnectionDetails,
|
||||
created_at_ms: f64,
|
||||
tcp_established_at_ms: f64,
|
||||
pub struct SshConnection {
|
||||
info: SshConnectionInfo,
|
||||
client_handle: AsyncMutex<ClientHandle<NoopHandler>>,
|
||||
|
||||
handle: AsyncMutex<ClientHandle<NoopHandler>>,
|
||||
|
||||
// Shell state (one active shell per connection by design).
|
||||
shell: AsyncMutex<Option<Arc<ShellSession>>>,
|
||||
shells: AsyncMutex<HashMap<u32, Arc<ShellSession>>>,
|
||||
|
||||
// Weak self for child sessions to refer back without cycles.
|
||||
self_weak: AsyncMutex<Weak<SSHConnection>>,
|
||||
self_weak: AsyncMutex<Weak<SshConnection>>,
|
||||
}
|
||||
|
||||
#[derive(uniffi::Object)]
|
||||
pub struct ShellSession {
|
||||
info: ShellSessionInfo,
|
||||
on_closed_callback: Option<Arc<dyn ShellClosedCallback>>,
|
||||
|
||||
// Weak backref; avoid retain cycle.
|
||||
parent: std::sync::Weak<SSHConnection>,
|
||||
channel_id: u32,
|
||||
parent: std::sync::Weak<SshConnection>,
|
||||
|
||||
writer: AsyncMutex<russh::ChannelWriteHalf<client::Msg>>,
|
||||
// We keep the reader task to allow cancellation on close.
|
||||
reader_task: tokio::task::JoinHandle<()>,
|
||||
// Only used for Shell* statuses.
|
||||
shell_status_listener: Option<Arc<dyn StatusListener>>,
|
||||
created_at_ms: f64,
|
||||
pty: PtyType,
|
||||
|
||||
|
||||
// Ring buffer
|
||||
ring: Arc<Mutex<std::collections::VecDeque<Arc<Chunk>>>>,
|
||||
ring_bytes_capacity: Arc<AtomicUsize>,
|
||||
@@ -259,19 +224,55 @@ pub struct ShellSession {
|
||||
rt_handle: tokio::runtime::Handle,
|
||||
}
|
||||
|
||||
impl fmt::Debug for SSHConnection {
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, uniffi::Enum)]
|
||||
pub enum Cursor {
|
||||
Head, // start from the beginning
|
||||
TailBytes { bytes: u64 }, // start from the end of the last N bytes
|
||||
Seq { seq: u64 }, // start from the given sequence number
|
||||
TimeMs { t_ms: f64 }, // start from the given time in milliseconds
|
||||
Live, // start from the live stream
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, uniffi::Record)]
|
||||
pub struct ListenerOptions {
|
||||
pub cursor: Cursor,
|
||||
pub coalesce_ms: Option<u32>, // coalesce chunks into this many milliseconds
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, uniffi::Record)]
|
||||
pub struct BufferReadResult {
|
||||
pub chunks: Vec<TerminalChunk>,
|
||||
pub next_seq: u64,
|
||||
pub dropped: Option<DroppedRange>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, uniffi::Record)]
|
||||
pub struct BufferStats {
|
||||
pub ring_bytes_count: u64,
|
||||
pub used_bytes: u64,
|
||||
pub head_seq: u64,
|
||||
pub tail_seq: u64,
|
||||
pub dropped_bytes_total: u64,
|
||||
|
||||
pub chunks_count: u64,
|
||||
}
|
||||
|
||||
|
||||
|
||||
impl fmt::Debug for SshConnection {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("SSHConnection")
|
||||
.field("connection_details", &self.connection_details)
|
||||
.field("created_at_ms", &self.created_at_ms)
|
||||
.field("tcp_established_at_ms", &self.tcp_established_at_ms)
|
||||
f.debug_struct("SshConnectionHandle")
|
||||
.field("info.connection_details", &self.info.connection_details)
|
||||
.field("info.created_at_ms", &self.info.created_at_ms)
|
||||
.field("info.connected_at_ms", &self.info.connected_at_ms)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
// Internal chunk type kept in ring/broadcast
|
||||
#[derive(Debug)]
|
||||
struct Chunk {
|
||||
struct Chunk { // TODO: This is very similar to TerminalChunk. The only difference is the bytes field
|
||||
seq: u64,
|
||||
t_ms: f64,
|
||||
stream: StreamKind,
|
||||
@@ -293,68 +294,96 @@ impl client::Handler for NoopHandler {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// ---------- Methods ----------
|
||||
static DEFAULT_TERMINAL_MODES: &[(russh::Pty, u32)] = &[
|
||||
(russh::Pty::ECHO, 1), // This will cause the terminal to echo the characters back to the client.
|
||||
(russh::Pty::ECHOK, 1), // After the line-kill character (often Ctrl+U), echo a newline.
|
||||
(russh::Pty::ECHOE, 1), // Visually erase on backspace (erase using BS-SP-BS sequence).
|
||||
(russh::Pty::ICANON, 1), // Canonical (cooked) mode: line editing; input delivered line-by-line.
|
||||
(russh::Pty::ISIG, 1), // Generate signals on special chars (e.g., Ctrl+C -> SIGINT, Ctrl+Z -> SIGTSTP).
|
||||
(russh::Pty::ICRNL, 1), // Convert carriage return (CR, \r) to newline (NL, \n) on input.
|
||||
(russh::Pty::ONLCR, 1), // Convert newline (NL) to CR+NL on output (LF -> CRLF).
|
||||
(russh::Pty::TTY_OP_ISPEED, 38400), // Set input baud rate (here 38400). The baud rate is the number of characters per second.
|
||||
(russh::Pty::TTY_OP_OSPEED, 38400), // Set output baud rate (here 38400). The baud rate is the number of characters per second.
|
||||
];
|
||||
|
||||
static DEFAULT_TERM_ROW_HEIGHT: u32 = 24;
|
||||
static DEFAULT_TERM_COL_WIDTH: u32 = 80;
|
||||
static DEFAULT_TERM_PIXEL_WIDTH: u32 = 0;
|
||||
static DEFAULT_TERM_PIXEL_HEIGHT: u32 = 0;
|
||||
static DEFAULT_TERM_COALESCE_MS: u64 = 16;
|
||||
|
||||
// Number of recent live chunks retained by the broadcast channel for each
|
||||
// subscriber. If a subscriber falls behind this many messages, they will get a
|
||||
// Lagged error and skip to the latest. Tune to: peak_chunks_per_sec × max_pause_sec.
|
||||
static DEFAULT_BROADCAST_CHUNK_CAPACITY: usize = 1024;
|
||||
|
||||
// Byte budget for the on-heap replay/history ring buffer. When the total bytes
|
||||
// of stored chunks exceed this, oldest chunks are evicted. Increase for a
|
||||
// longer replay window at the cost of memory.
|
||||
static DEFAULT_SHELL_RING_BUFFER_CAPACITY: usize = 2 * 1024 * 1024; // default 2MiB
|
||||
|
||||
// Upper bound for the size of a single appended/broadcast chunk. Incoming data
|
||||
// is split into slices no larger than this. Smaller values reduce latency and
|
||||
// loss impact; larger values reduce per-message overhead.
|
||||
static DEFAULT_MAX_CHUNK_SIZE: usize = 16 * 1024; // 16KB
|
||||
|
||||
static DEFAULT_READ_BUFFER_MAX_BYTES: u64 = 512 * 1024; // 512KB
|
||||
|
||||
#[uniffi::export(async_runtime = "tokio")]
|
||||
impl SSHConnection {
|
||||
impl SshConnection {
|
||||
/// Convenience snapshot for property-like access in TS.
|
||||
pub fn info(&self) -> SshConnectionInfo {
|
||||
SshConnectionInfo {
|
||||
connection_id: self.connection_id.clone(),
|
||||
connection_details: self.connection_details.clone(),
|
||||
created_at_ms: self.created_at_ms,
|
||||
tcp_established_at_ms: self.tcp_established_at_ms,
|
||||
}
|
||||
pub fn get_info(&self) -> SshConnectionInfo {
|
||||
self.info.clone()
|
||||
}
|
||||
|
||||
/// Start a shell with the given PTY. Emits only Shell* statuses via options.on_status_change.
|
||||
pub async fn start_shell(&self, opts: StartShellOptions) -> Result<Arc<ShellSession>, SshError> {
|
||||
// Prevent double-start (safe default).
|
||||
if self.shell.lock().await.is_some() {
|
||||
return Err(SshError::ShellAlreadyRunning);
|
||||
}
|
||||
|
||||
let pty = opts.pty;
|
||||
let shell_status_listener = opts.on_status_change.clone();
|
||||
if let Some(sl) = shell_status_listener.as_ref() {
|
||||
sl.on_change(SSHConnectionStatus::ShellConnecting);
|
||||
}
|
||||
let started_at_ms = now_ms();
|
||||
|
||||
// Open session channel.
|
||||
let handle = self.handle.lock().await;
|
||||
let ch = handle.channel_open_session().await?;
|
||||
let term = opts.term;
|
||||
let on_closed_callback = opts.on_closed_callback.clone();
|
||||
|
||||
let client_handle = self.client_handle.lock().await;
|
||||
|
||||
let ch = client_handle.channel_open_session().await?;
|
||||
let channel_id: u32 = ch.id().into();
|
||||
|
||||
// Request PTY & shell.
|
||||
// Request a PTY with basic sane defaults: enable ECHO and set speeds.
|
||||
// RFC4254 terminal mode opcodes: 53=ECHO, 128=TTY_OP_ISPEED, 129=TTY_OP_OSPEED
|
||||
let modes: &[(russh::Pty, u32)] = &[
|
||||
(russh::Pty::ECHO, 1),
|
||||
(russh::Pty::ECHOK, 1),
|
||||
(russh::Pty::ECHOE, 1),
|
||||
(russh::Pty::ICANON, 1),
|
||||
(russh::Pty::ISIG, 1),
|
||||
(russh::Pty::ICRNL, 1),
|
||||
(russh::Pty::ONLCR, 1),
|
||||
(russh::Pty::TTY_OP_ISPEED, 38400),
|
||||
(russh::Pty::TTY_OP_OSPEED, 38400),
|
||||
];
|
||||
ch.request_pty(true, pty.as_ssh_name(), 80, 24, 0, 0, modes).await?;
|
||||
let mut modes: Vec<(russh::Pty, u32)> = DEFAULT_TERMINAL_MODES.to_vec();
|
||||
if let Some(terminal_mode_params) = &opts.terminal_mode {
|
||||
for m in terminal_mode_params {
|
||||
if let Some(pty) = russh::Pty::from_u8(m.opcode) {
|
||||
if let Some(pos) = modes.iter().position(|(p, _)| *p as u8 == m.opcode) {
|
||||
modes[pos].1 = m.value; // override
|
||||
} else {
|
||||
modes.push((pty, m.value)); // add
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let row_height = opts.terminal_size.as_ref().and_then(|s| s.row_height).unwrap_or(DEFAULT_TERM_ROW_HEIGHT);
|
||||
let col_width = opts.terminal_size.as_ref().and_then(|s| s.col_width).unwrap_or(DEFAULT_TERM_COL_WIDTH);
|
||||
let pixel_width = opts.terminal_pixel_size.as_ref().and_then(|s| s.pixel_width).unwrap_or(DEFAULT_TERM_PIXEL_WIDTH);
|
||||
let pixel_height= opts.terminal_pixel_size.as_ref().and_then(|s| s.pixel_height).unwrap_or(DEFAULT_TERM_PIXEL_HEIGHT);
|
||||
|
||||
ch.request_pty(true, term.as_ssh_name(), col_width, row_height, pixel_width, pixel_height, &modes).await?;
|
||||
ch.request_shell(true).await?;
|
||||
|
||||
// Split for read/write; spawn reader.
|
||||
let (mut reader, writer) = ch.split();
|
||||
|
||||
// Setup ring + broadcast for this session
|
||||
let (tx, _rx) = broadcast::channel::<Arc<Chunk>>(1024);
|
||||
let (tx, _rx) = broadcast::channel::<Arc<Chunk>>(DEFAULT_BROADCAST_CHUNK_CAPACITY);
|
||||
let ring = Arc::new(Mutex::new(std::collections::VecDeque::<Arc<Chunk>>::new()));
|
||||
let used_bytes = Arc::new(Mutex::new(0usize));
|
||||
let next_seq = Arc::new(AtomicU64::new(1));
|
||||
let head_seq = Arc::new(AtomicU64::new(1));
|
||||
let tail_seq = Arc::new(AtomicU64::new(0));
|
||||
let dropped_bytes_total = Arc::new(AtomicU64::new(0));
|
||||
let ring_bytes_capacity = Arc::new(AtomicUsize::new(2 * 1024 * 1024)); // default 2MiB
|
||||
let default_coalesce_ms = AtomicU64::new(16); // default 16ms
|
||||
let ring_bytes_capacity = Arc::new(AtomicUsize::new(DEFAULT_SHELL_RING_BUFFER_CAPACITY));
|
||||
let default_coalesce_ms = AtomicU64::new(DEFAULT_TERM_COALESCE_MS);
|
||||
|
||||
let ring_clone = ring.clone();
|
||||
let used_bytes_clone = used_bytes.clone();
|
||||
@@ -364,9 +393,11 @@ impl SSHConnection {
|
||||
let head_seq_c = head_seq.clone();
|
||||
let tail_seq_c = tail_seq.clone();
|
||||
let next_seq_c = next_seq.clone();
|
||||
let shell_listener_for_task = shell_status_listener.clone();
|
||||
|
||||
let on_closed_callback_for_reader = on_closed_callback.clone();
|
||||
|
||||
let reader_task = tokio::spawn(async move {
|
||||
let max_chunk = 16 * 1024; // 16KB
|
||||
let max_chunk = DEFAULT_MAX_CHUNK_SIZE;
|
||||
loop {
|
||||
match reader.wait().await {
|
||||
Some(ChannelMsg::Data { data }) => {
|
||||
@@ -400,8 +431,8 @@ impl SSHConnection {
|
||||
);
|
||||
}
|
||||
Some(ChannelMsg::Close) | None => {
|
||||
if let Some(sl) = shell_listener_for_task.as_ref() {
|
||||
sl.on_change(SSHConnectionStatus::ShellDisconnected);
|
||||
if let Some(sl) = on_closed_callback_for_reader.as_ref() {
|
||||
sl.on_change(channel_id);
|
||||
}
|
||||
break;
|
||||
}
|
||||
@@ -411,19 +442,28 @@ impl SSHConnection {
|
||||
});
|
||||
|
||||
let session = Arc::new(ShellSession {
|
||||
info: ShellSessionInfo {
|
||||
channel_id,
|
||||
created_at_ms: started_at_ms,
|
||||
connected_at_ms: now_ms(),
|
||||
term,
|
||||
connection_id: self.info.connection_id.clone(),
|
||||
},
|
||||
on_closed_callback,
|
||||
parent: self.self_weak.lock().await.clone(),
|
||||
channel_id,
|
||||
|
||||
writer: AsyncMutex::new(writer),
|
||||
reader_task,
|
||||
shell_status_listener,
|
||||
created_at_ms: now_ms(),
|
||||
pty,
|
||||
|
||||
// Ring buffer
|
||||
ring,
|
||||
ring_bytes_capacity,
|
||||
used_bytes,
|
||||
dropped_bytes_total,
|
||||
head_seq,
|
||||
tail_seq,
|
||||
|
||||
// Listener tasks management
|
||||
sender: tx,
|
||||
listener_tasks: Arc::new(Mutex::new(HashMap::new())),
|
||||
next_listener_id: AtomicU64::new(1),
|
||||
@@ -431,28 +471,24 @@ impl SSHConnection {
|
||||
rt_handle: tokio::runtime::Handle::current(),
|
||||
});
|
||||
|
||||
*self.shell.lock().await = Some(session.clone());
|
||||
self.shells.lock().await.insert(channel_id, session.clone());
|
||||
|
||||
// Report ShellConnected.
|
||||
if let Some(sl) = session.shell_status_listener.as_ref() {
|
||||
sl.on_change(SSHConnectionStatus::ShellConnected);
|
||||
}
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
// Note: send_data now lives on ShellSession
|
||||
|
||||
// No exported close_shell: shell closure is handled via ShellSession::close()
|
||||
|
||||
/// Disconnect TCP (also closes any active shell).
|
||||
pub async fn disconnect(&self) -> Result<(), SshError> {
|
||||
// Close shell first.
|
||||
if let Some(session) = self.shell.lock().await.take() {
|
||||
let _ = ShellSession::close_internal(&session).await;
|
||||
// TODO: Check if we need to close all these if we are about to disconnect?
|
||||
let sessions: Vec<Arc<ShellSession>> = {
|
||||
let map = self.shells.lock().await;
|
||||
map.values().cloned().collect()
|
||||
};
|
||||
for s in sessions {
|
||||
s.close().await?;
|
||||
}
|
||||
|
||||
let h = self.handle.lock().await;
|
||||
let h = self.client_handle.lock().await;
|
||||
h.disconnect(Disconnect::ByApplication, "bye", "").await?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -460,13 +496,8 @@ impl SSHConnection {
|
||||
|
||||
#[uniffi::export(async_runtime = "tokio")]
|
||||
impl ShellSession {
|
||||
pub fn info(&self) -> ShellSessionInfo {
|
||||
ShellSessionInfo {
|
||||
channel_id: self.channel_id,
|
||||
created_at_ms: self.created_at_ms,
|
||||
pty: self.pty,
|
||||
connection_id: self.parent.upgrade().map(|p| p.connection_id.clone()).unwrap_or_default(),
|
||||
}
|
||||
pub fn get_info(&self) -> ShellSessionInfo {
|
||||
self.info.clone()
|
||||
}
|
||||
|
||||
/// Send bytes to the active shell (stdin).
|
||||
@@ -479,20 +510,14 @@ impl ShellSession {
|
||||
/// Close the associated shell channel and stop its reader task.
|
||||
pub async fn close(&self) -> Result<(), SshError> { self.close_internal().await }
|
||||
|
||||
/// Configure ring buffer policy.
|
||||
pub async fn set_buffer_policy(&self, ring_bytes: Option<u64>, coalesce_ms: Option<u32>) {
|
||||
if let Some(rb) = ring_bytes { self.ring_bytes_capacity.store(rb as usize, Ordering::Relaxed); self.evict_if_needed(); }
|
||||
if let Some(cm) = coalesce_ms { self.default_coalesce_ms.store(cm as u64, Ordering::Relaxed); }
|
||||
}
|
||||
|
||||
/// Buffer statistics snapshot.
|
||||
pub fn buffer_stats(&self) -> BufferStats {
|
||||
let used = *self.used_bytes.lock().unwrap_or_else(|p| p.into_inner()) as u64;
|
||||
let chunks = match self.ring.lock() { Ok(q) => q.len() as u64, Err(p) => p.into_inner().len() as u64 };
|
||||
let chunks_count = match self.ring.lock() { Ok(q) => q.len() as u64, Err(p) => p.into_inner().len() as u64 };
|
||||
BufferStats {
|
||||
ring_bytes: self.ring_bytes_capacity.load(Ordering::Relaxed) as u64,
|
||||
ring_bytes_count: self.ring_bytes_capacity.load(Ordering::Relaxed) as u64,
|
||||
used_bytes: used,
|
||||
chunks,
|
||||
chunks_count,
|
||||
head_seq: self.head_seq.load(Ordering::Relaxed),
|
||||
tail_seq: self.tail_seq.load(Ordering::Relaxed),
|
||||
dropped_bytes_total: self.dropped_bytes_total.load(Ordering::Relaxed),
|
||||
@@ -504,7 +529,7 @@ impl ShellSession {
|
||||
|
||||
/// Read the ring buffer from a cursor.
|
||||
pub fn read_buffer(&self, cursor: Cursor, max_bytes: Option<u64>) -> BufferReadResult {
|
||||
let max_total = max_bytes.unwrap_or(512 * 1024) as usize; // default 512KB
|
||||
let max_total = max_bytes.unwrap_or(DEFAULT_READ_BUFFER_MAX_BYTES) as usize;
|
||||
let mut out_chunks: Vec<TerminalChunk> = Vec::new();
|
||||
let mut dropped: Option<DroppedRange> = None;
|
||||
let head_seq_now = self.head_seq.load(Ordering::Relaxed);
|
||||
@@ -561,7 +586,6 @@ impl ShellSession {
|
||||
let replay = self.read_buffer(opts.cursor.clone(), None);
|
||||
let mut rx = self.sender.subscribe();
|
||||
let id = self.next_listener_id.fetch_add(1, Ordering::Relaxed);
|
||||
eprintln!("ShellSession.add_listener -> id={id}");
|
||||
let default_coalesce_ms = self.default_coalesce_ms.load(Ordering::Relaxed) as u32;
|
||||
let coalesce_ms = opts.coalesce_ms.unwrap_or(default_coalesce_ms);
|
||||
|
||||
@@ -646,67 +670,86 @@ impl ShellSession {
|
||||
// Try to close channel gracefully; ignore error.
|
||||
self.writer.lock().await.close().await.ok();
|
||||
self.reader_task.abort();
|
||||
if let Some(sl) = self.shell_status_listener.as_ref() {
|
||||
sl.on_change(SSHConnectionStatus::ShellDisconnected);
|
||||
if let Some(sl) = self.on_closed_callback.as_ref() {
|
||||
sl.on_change(self.info.channel_id);
|
||||
}
|
||||
// Clear parent's notion of active shell if it matches us.
|
||||
if let Some(parent) = self.parent.upgrade() {
|
||||
let mut guard = parent.shell.lock().await;
|
||||
if let Some(current) = guard.as_ref() {
|
||||
if current.channel_id == self.channel_id { *guard = None; }
|
||||
}
|
||||
parent.shells.lock().await.remove(&self.info.channel_id);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn evict_if_needed(&self) {
|
||||
let cap = self.ring_bytes_capacity.load(Ordering::Relaxed);
|
||||
let mut ring = match self.ring.lock() { Ok(g) => g, Err(p) => p.into_inner() };
|
||||
let mut used = self.used_bytes.lock().unwrap_or_else(|p| p.into_inner());
|
||||
while *used > cap {
|
||||
if let Some(front) = ring.pop_front() {
|
||||
*used -= front.bytes.len();
|
||||
self.dropped_bytes_total.fetch_add(front.bytes.len() as u64, Ordering::Relaxed);
|
||||
self.head_seq.store(front.seq.saturating_add(1), Ordering::Relaxed);
|
||||
} else { break; }
|
||||
}
|
||||
}
|
||||
// /// This was on the public interface but I don't think we need it
|
||||
// pub async fn set_buffer_policy(&self, ring_bytes: Option<u64>, coalesce_ms: Option<u32>) {
|
||||
// if let Some(rb) = ring_bytes { self.ring_bytes_capacity.store(rb as usize, Ordering::Relaxed); self.evict_if_needed(); }
|
||||
// if let Some(cm) = coalesce_ms { self.default_coalesce_ms.store(cm as u64, Ordering::Relaxed); }
|
||||
// }
|
||||
|
||||
// fn evict_if_needed(&self) {
|
||||
// let cap = self.ring_bytes_capacity.load(Ordering::Relaxed);
|
||||
// let mut ring = match self.ring.lock() { Ok(g) => g, Err(p) => p.into_inner() };
|
||||
// let mut used = self.used_bytes.lock().unwrap_or_else(|p| p.into_inner());
|
||||
// while *used > cap {
|
||||
// if let Some(front) = ring.pop_front() {
|
||||
// *used -= front.bytes.len();
|
||||
// self.dropped_bytes_total.fetch_add(front.bytes.len() as u64, Ordering::Relaxed);
|
||||
// self.head_seq.store(front.seq.saturating_add(1), Ordering::Relaxed);
|
||||
// } else { break; }
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
/// ---------- Top-level API ----------
|
||||
|
||||
#[uniffi::export(async_runtime = "tokio")]
|
||||
pub async fn connect(options: ConnectOptions) -> Result<Arc<SSHConnection>, SshError> {
|
||||
pub async fn connect(options: ConnectOptions) -> Result<Arc<SshConnection>, SshError> {
|
||||
let started_at_ms = now_ms();
|
||||
|
||||
let details = ConnectionDetails {
|
||||
host: options.host.clone(),
|
||||
port: options.port,
|
||||
username: options.username.clone(),
|
||||
security: options.security.clone(),
|
||||
host: options.connection_details.host.clone(),
|
||||
port: options.connection_details.port,
|
||||
username: options.connection_details.username.clone(),
|
||||
security: options.connection_details.security.clone(),
|
||||
};
|
||||
if let Some(sl) = options.on_status_change.as_ref() {
|
||||
sl.on_change(SSHConnectionStatus::TcpConnecting);
|
||||
}
|
||||
|
||||
|
||||
// TCP
|
||||
let cfg = Arc::new(ClientConfig::default());
|
||||
let addr = format!("{}:{}", details.host, details.port);
|
||||
let mut handle: ClientHandle<NoopHandler> = client::connect(cfg, addr, NoopHandler).await?;
|
||||
|
||||
if let Some(sl) = options.on_status_change.as_ref() {
|
||||
sl.on_change(SSHConnectionStatus::TcpConnected);
|
||||
|
||||
|
||||
let socket = tokio::net::TcpStream::connect(&addr).await?;
|
||||
let local_port = socket.local_addr()?.port(); // ephemeral local port
|
||||
|
||||
|
||||
let tcp_established_at_ms = now_ms();
|
||||
if let Some(sl) = options.on_connection_progress_callback.as_ref() {
|
||||
sl.on_change(SshConnectionProgressEvent::TcpConnected);
|
||||
}
|
||||
|
||||
|
||||
let cfg = Arc::new(Config::default());
|
||||
let mut handle: ClientHandle<NoopHandler> =
|
||||
russh::client::connect_stream(cfg, socket, NoopHandler).await?;
|
||||
|
||||
|
||||
let ssh_handshake_at_ms = now_ms();
|
||||
if let Some(sl) = options.on_connection_progress_callback.as_ref() {
|
||||
sl.on_change(SshConnectionProgressEvent::SshHandshake);
|
||||
}
|
||||
|
||||
|
||||
// Auth
|
||||
let auth = match &details.security {
|
||||
let auth_result = match &details.security {
|
||||
Security::Password { password } => {
|
||||
handle
|
||||
.authenticate_password(details.username.clone(), password.clone())
|
||||
.await?
|
||||
}
|
||||
// Treat key_id as the OpenSSH PEM-encoded private key content
|
||||
Security::Key { key_id } => {
|
||||
Security::Key { private_key_content } => {
|
||||
// Parse OpenSSH private key text into a russh::keys::PrivateKey
|
||||
let parsed: RusshPrivateKey = RusshPrivateKey::from_openssh(key_id.as_str())
|
||||
let parsed: PrivateKey = PrivateKey::from_openssh(private_key_content.as_str())
|
||||
.map_err(|e| SshError::RusshKeys(e.to_string()))?;
|
||||
// Wrap; omit hash preference (server selects or default applies)
|
||||
let pk_with_hash = PrivateKeyWithHashAlg::new(Arc::new(parsed), None);
|
||||
@@ -715,20 +758,22 @@ pub async fn connect(options: ConnectOptions) -> Result<Arc<SSHConnection>, SshE
|
||||
.await?
|
||||
}
|
||||
};
|
||||
match auth {
|
||||
client::AuthResult::Success => {}
|
||||
other => return Err(SshError::Auth(format!("{other:?}"))),
|
||||
if !matches!(auth_result, russh::client::AuthResult::Success) {
|
||||
return Err(auth_result.into());
|
||||
}
|
||||
|
||||
let now = now_ms();
|
||||
let connection_id = format!("{}@{}:{}|{}", details.username, details.host, details.port, now as u64);
|
||||
let conn = Arc::new(SSHConnection {
|
||||
connection_id,
|
||||
connection_details: details,
|
||||
created_at_ms: now,
|
||||
tcp_established_at_ms: now,
|
||||
handle: AsyncMutex::new(handle),
|
||||
shell: AsyncMutex::new(None),
|
||||
|
||||
let connection_id = format!("{}@{}:{}:{}", details.username, details.host, details.port, local_port);
|
||||
let conn = Arc::new(SshConnection {
|
||||
info: SshConnectionInfo {
|
||||
connection_id,
|
||||
connection_details: details,
|
||||
created_at_ms: started_at_ms,
|
||||
connected_at_ms: now_ms(),
|
||||
progress_timings: SshConnectionInfoProgressTimings { tcp_established_at_ms, ssh_handshake_at_ms },
|
||||
},
|
||||
client_handle: AsyncMutex::new(handle),
|
||||
shells: AsyncMutex::new(HashMap::new()),
|
||||
self_weak: AsyncMutex::new(Weak::new()),
|
||||
});
|
||||
// Initialize weak self reference.
|
||||
@@ -736,20 +781,22 @@ pub async fn connect(options: ConnectOptions) -> Result<Arc<SSHConnection>, SshE
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
#[uniffi::export(async_runtime = "tokio")]
|
||||
pub async fn generate_key_pair(key_type: KeyType) -> Result<String, SshError> {
|
||||
#[uniffi::export]
|
||||
pub fn validate_private_key(private_key_content: String) -> Result<String, SshError> {
|
||||
let parsed: russh_keys::PrivateKey = russh_keys::PrivateKey::from_openssh(&private_key_content)?;
|
||||
Ok(parsed.to_openssh(LineEnding::LF)?.to_string())
|
||||
}
|
||||
|
||||
#[uniffi::export]
|
||||
pub fn generate_key_pair(key_type: KeyType) -> Result<String, SshError> {
|
||||
let mut rng = OsRng;
|
||||
let key = match key_type {
|
||||
KeyType::Rsa => RusshKeysPrivateKey::random(&mut rng, KeyAlgorithm::Rsa { hash: None })?,
|
||||
KeyType::Ecdsa => RusshKeysPrivateKey::random(
|
||||
&mut rng,
|
||||
KeyAlgorithm::Ecdsa { curve: EcdsaCurve::NistP256 },
|
||||
)?,
|
||||
KeyType::Ed25519 => RusshKeysPrivateKey::random(&mut rng, KeyAlgorithm::Ed25519)?,
|
||||
KeyType::Rsa => russh_keys::PrivateKey::random(&mut rng, Algorithm::Rsa { hash: None })?,
|
||||
KeyType::Ecdsa => russh_keys::PrivateKey::random(&mut rng, Algorithm::Ecdsa { curve: EcdsaCurve::NistP256 })?,
|
||||
KeyType::Ed25519 => russh_keys::PrivateKey::random(&mut rng, Algorithm::Ed25519)?,
|
||||
KeyType::Ed448 => return Err(SshError::UnsupportedKeyType),
|
||||
};
|
||||
let pem = key.to_openssh(LineEnding::LF)?; // Zeroizing<String>
|
||||
Ok(pem.to_string())
|
||||
Ok(key.to_openssh(LineEnding::LF)?.to_string())
|
||||
}
|
||||
|
||||
fn now_ms() -> f64 {
|
||||
@@ -808,3 +855,38 @@ fn append_and_broadcast(
|
||||
offset = end;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// TODO: Split this into different errors for each public function
|
||||
#[derive(Debug, Error, uniffi::Error)]
|
||||
pub enum SshError {
|
||||
#[error("Disconnected")]
|
||||
Disconnected,
|
||||
#[error("Unsupported key type")]
|
||||
UnsupportedKeyType,
|
||||
#[error("Auth failed: {0}")]
|
||||
Auth(String),
|
||||
#[error("Shell already running")]
|
||||
ShellAlreadyRunning,
|
||||
#[error("russh error: {0}")]
|
||||
Russh(String),
|
||||
#[error("russh-keys error: {0}")]
|
||||
RusshKeys(String),
|
||||
}
|
||||
impl From<russh::Error> for SshError {
|
||||
fn from(e: russh::Error) -> Self { SshError::Russh(e.to_string()) }
|
||||
}
|
||||
impl From<russh_keys::Error> for SshError {
|
||||
fn from(e: russh_keys::Error) -> Self { SshError::RusshKeys(e.to_string()) }
|
||||
}
|
||||
impl From<ssh_key::Error> for SshError {
|
||||
fn from(e: ssh_key::Error) -> Self { SshError::RusshKeys(e.to_string()) }
|
||||
}
|
||||
impl From<std::io::Error> for SshError {
|
||||
fn from(e: std::io::Error) -> Self { SshError::Russh(e.to_string()) }
|
||||
}
|
||||
impl From<russh::client::AuthResult> for SshError {
|
||||
fn from(a: russh::client::AuthResult) -> Self {
|
||||
SshError::Auth(format!("{a:?}"))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,6 +20,9 @@ import * as GeneratedRussh from './index';
|
||||
// Core types
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
export type PtyType =
|
||||
| 'Vanilla' | 'Vt100' | 'Vt102' | 'Vt220' | 'Ansi' | 'Xterm' | 'Xterm256';
|
||||
|
||||
export type ConnectionDetails = {
|
||||
host: string;
|
||||
port: number;
|
||||
@@ -29,25 +32,27 @@ export type ConnectionDetails = {
|
||||
| { type: 'key'; privateKey: string };
|
||||
};
|
||||
|
||||
export type SshConnectionStatus =
|
||||
| 'tcpConnecting'
|
||||
| 'tcpConnected'
|
||||
| 'tcpDisconnected'
|
||||
| 'shellConnecting'
|
||||
| 'shellConnected'
|
||||
| 'shellDisconnected';
|
||||
/**
|
||||
* This status is only to provide updates for discrete events
|
||||
* during the connect() promise.
|
||||
*
|
||||
* It is no longer relevant after the connect() promise is resolved.
|
||||
*/
|
||||
export type SshConnectionProgress =
|
||||
| 'tcpConnected' // TCP established, starting SSH handshake
|
||||
| 'sshHandshake' // SSH protocol negotiation complete
|
||||
|
||||
|
||||
export type PtyType =
|
||||
| 'Vanilla' | 'Vt100' | 'Vt102' | 'Vt220' | 'Ansi' | 'Xterm' | 'Xterm256';
|
||||
|
||||
export type ConnectOptions = ConnectionDetails & {
|
||||
onStatusChange?: (status: SshConnectionStatus) => void;
|
||||
onConnectionProgress?: (status: SshConnectionProgress) => void;
|
||||
onDisconnected?: (connectionId: string) => void;
|
||||
abortSignal?: AbortSignal;
|
||||
};
|
||||
|
||||
export type StartShellOptions = {
|
||||
pty: PtyType;
|
||||
onStatusChange?: (status: SshConnectionStatus) => void;
|
||||
onClosed?: (shellId: string) => void;
|
||||
abortSignal?: AbortSignal;
|
||||
};
|
||||
|
||||
@@ -162,13 +167,9 @@ const ptyEnumToLiteral: Record<GeneratedRussh.PtyType, PtyType> = {
|
||||
};
|
||||
|
||||
const sshConnStatusEnumToLiteral = {
|
||||
[GeneratedRussh.SshConnectionStatus.TcpConnecting]: 'tcpConnecting',
|
||||
[GeneratedRussh.SshConnectionStatus.TcpConnected]: 'tcpConnected',
|
||||
[GeneratedRussh.SshConnectionStatus.TcpDisconnected]: 'tcpDisconnected',
|
||||
[GeneratedRussh.SshConnectionStatus.ShellConnecting]: 'shellConnecting',
|
||||
[GeneratedRussh.SshConnectionStatus.ShellConnected]: 'shellConnected',
|
||||
[GeneratedRussh.SshConnectionStatus.ShellDisconnected]: 'shellDisconnected',
|
||||
} as const satisfies Record<GeneratedRussh.SshConnectionStatus, SshConnectionStatus>;
|
||||
[GeneratedRussh.SshConnectionStatus.SshHandshake]: 'sshHandshake',
|
||||
} as const satisfies Record<GeneratedRussh.SshConnectionStatus, SshConnectionProgress>;
|
||||
|
||||
const streamEnumToLiteral = {
|
||||
[GeneratedRussh.StreamKind.Stdout]: 'stdout',
|
||||
|
||||
Reference in New Issue
Block a user