summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/main.zig22
-rw-r--r--src/ssh.zig78
2 files changed, 100 insertions, 0 deletions
diff --git a/src/main.zig b/src/main.zig
new file mode 100644
index 0000000..3b16914
--- /dev/null
+++ b/src/main.zig
@@ -0,0 +1,22 @@
+const std = @import("std");
+
+const ssh = @import("ssh.zig");
+
+pub fn main() anyerror!void {
+ var client = ssh.Client.init("localhost") catch unreachable;
+ const stdout = std.io.getStdOut();
+ client.run("pwd", stdout) catch unreachable;
+ client.run("who", stdout) catch unreachable;
+ client.deinit();
+}
+
+test "basic test" {
+ // this test requires you can ssh localhost without a password prompt
+ // (typically by having your ssh_agent running)
+ var client = ssh.Client.init("localhost") catch unreachable;
+ var buffer = std.ArrayList(u8).init(std.testing.allocator);
+ defer buffer.deinit();
+ client.run("echo test", buffer.writer()) catch unreachable;
+ try std.testing.expectEqualSlices(u8, buffer.items, "test\n");
+ client.deinit();
+}
diff --git a/src/ssh.zig b/src/ssh.zig
new file mode 100644
index 0000000..d8b1fda
--- /dev/null
+++ b/src/ssh.zig
@@ -0,0 +1,78 @@
+const std = @import("std");
+const ssh = @cImport(@cInclude("libssh/libssh.h"));
+
+pub const Client = struct {
+ session: ?ssh.ssh_session,
+
+ pub fn deinit(self: *Client) void {
+ if (self.session) |session| {
+ if (ssh.ssh_is_connected(session) == 1) {
+ ssh.ssh_disconnect(session);
+ }
+ ssh.ssh_free(session);
+ self.session = null;
+ }
+ }
+ pub fn init(hostname: [*]const u8) !Client {
+ var client = Client{
+ .session = ssh.ssh_new(),
+ };
+ var port: i64 = 22;
+
+ if (client.session) |session| {
+ _ = ssh.ssh_options_set(session, ssh.SSH_OPTIONS_HOST, hostname);
+ _ = ssh.ssh_options_set(session, ssh.SSH_OPTIONS_PORT, &port);
+ //var verbosity = ssh.SSH_LOG_PROTOCOL;
+ //_ = ssh.ssh_options_set(session, ssh.SSH_OPTIONS_LOG_VERBOSITY, &verbosity);
+
+ var rc = ssh.ssh_connect(session);
+ if (rc != ssh.SSH_OK) {
+ std.log.info("Error connecting to localhost: {s}\n", .{ssh.ssh_get_error(session)});
+ std.os.exit(2);
+ }
+
+ //if (ssh.verify_knownhost(session) < 0) {
+ // std.log.info("knownhost verification error", .{});
+ // std.os.exit(3);
+ //}
+
+ if (ssh.ssh_userauth_publickey_auto(session, null, null) == ssh.SSH_AUTH_ERROR) {
+ std.log.info("Error authenticating with public key: {s}\n", .{ssh.ssh_get_error(session)});
+ std.os.exit(3);
+ }
+ } else {
+ std.log.info("failed to initialise ssh session\n", .{});
+ std.os.exit(1);
+ }
+ return client;
+ }
+ pub fn run(self: *Client, cmd: [*]const u8, writer: anytype) !void {
+ if (self.session) |session| {
+ if (ssh.ssh_channel_new(session)) |channel| {
+ defer ssh.ssh_channel_free(channel);
+ defer _ = ssh.ssh_channel_close(channel);
+ if (ssh.ssh_channel_open_session(channel) != ssh.SSH_OK) {
+ std.log.info("Error opening channel session: {s}\n", .{ssh.ssh_get_error(session)});
+ std.os.exit(3);
+ }
+
+ if (ssh.ssh_channel_request_exec(channel, cmd) != ssh.SSH_OK) {
+ std.log.info("Error executing command: {s}\n", .{ssh.ssh_get_error(session)});
+ std.os.exit(4);
+ }
+
+ var buffer: [256]u8 = undefined;
+ var nbytes = ssh.ssh_channel_read(channel, &buffer, buffer.len, 0);
+ while (nbytes > 0) : (nbytes = ssh.ssh_channel_read(channel, &buffer, buffer.len, 0)) {
+ var w = try writer.write(buffer[0..@intCast(usize, nbytes)]);
+ if (w != nbytes) {
+ std.os.exit(5);
+ }
+ }
+ _ = ssh.ssh_channel_send_eof(channel);
+ } else {
+ std.log.info("Error creating channel", .{});
+ }
+ }
+ }
+};