From 841c2f47edfb2967f8ad18e6c9568c8dd8be6298 Mon Sep 17 00:00:00 2001 From: Julien Dessaux Date: Fri, 18 Mar 2022 10:48:37 +0100 Subject: Added basic ssh client to run commands --- .gitignore | 3 +++ build.zig | 61 +++++++++++++++++++++++++++++++++++++++++++++++ src/main.zig | 22 +++++++++++++++++ src/ssh.zig | 78 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 164 insertions(+) create mode 100644 .gitignore create mode 100644 build.zig create mode 100644 src/main.zig create mode 100644 src/ssh.zig diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6380ab0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +kcov-output/ +zig-cache +zig-out diff --git a/build.zig b/build.zig new file mode 100644 index 0000000..34c25f7 --- /dev/null +++ b/build.zig @@ -0,0 +1,61 @@ +const std = @import("std"); + +pub fn build(b: *std.build.Builder) void { + // Standard target options allows the person running `zig build` to choose + // what target to build for. Here we do not override the defaults, which + // means any target is allowed, and the default is native. Other options + // for restricting supported target set are available. + const target = b.standardTargetOptions(.{}); + + // Standard release options allow the person running `zig build` to select + // between Debug, ReleaseSafe, ReleaseFast, and ReleaseSmall. + const mode = b.standardReleaseOptions(); + + const exe = b.addExecutable("zigod", "src/main.zig"); + exe.addLibPath("/usr/lib64/"); + exe.linkSystemLibrary("c"); + exe.linkSystemLibrary("libssh"); + exe.setTarget(target); + exe.setBuildMode(mode); + exe.install(); + + const run_cmd = exe.run(); + run_cmd.step.dependOn(b.getInstallStep()); + if (b.args) |args| { + run_cmd.addArgs(args); + } + + const coverage = b.option(bool, "test-coverage", "Generate test coverage") orelse false; + + const run_step = b.step("run", "Run the app"); + run_step.dependOn(&run_cmd.step); + + const exe_tests = b.addTest("src/main.zig"); + exe_tests.addLibPath("/usr/lib64/"); + exe_tests.linkSystemLibrary("c"); + exe_tests.linkSystemLibrary("libssh"); + exe_tests.setTarget(target); + exe_tests.setBuildMode(mode); + + // Code coverage with kcov, we need an allocator for the setup + var general_purpose_allocator = std.heap.GeneralPurposeAllocator(.{}){}; + defer _ = general_purpose_allocator.deinit(); + const gpa = general_purpose_allocator.allocator(); + // We want to exclude the $HOME/.zig path + const home = std.process.getEnvVarOwned(gpa, "HOME") catch ""; + defer gpa.free(home); + const exclude = std.fmt.allocPrint(gpa, "--exclude-path={s}/.zig/,/usr", .{home}) catch ""; + defer gpa.free(exclude); + if (coverage) { + exe_tests.setExecCmd(&[_]?[]const u8{ + "kcov", + exclude, + //"--path-strip-level=3", // any kcov flags can be specified here + "kcov-output", // output dir for kcov + null, // to get zig to use the --test-cmd-bin flag + }); + } + + const test_step = b.step("test", "Run unit tests"); + test_step.dependOn(&exe_tests.step); +} diff --git a/src/main.zig b/src/main.zig new file mode 100644 index 0000000..3b16914 --- /dev/null +++ b/src/main.zig @@ -0,0 +1,22 @@ +const std = @import("std"); + +const ssh = @import("ssh.zig"); + +pub fn main() anyerror!void { + var client = ssh.Client.init("localhost") catch unreachable; + const stdout = std.io.getStdOut(); + client.run("pwd", stdout) catch unreachable; + client.run("who", stdout) catch unreachable; + client.deinit(); +} + +test "basic test" { + // this test requires you can ssh localhost without a password prompt + // (typically by having your ssh_agent running) + var client = ssh.Client.init("localhost") catch unreachable; + var buffer = std.ArrayList(u8).init(std.testing.allocator); + defer buffer.deinit(); + client.run("echo test", buffer.writer()) catch unreachable; + try std.testing.expectEqualSlices(u8, buffer.items, "test\n"); + client.deinit(); +} diff --git a/src/ssh.zig b/src/ssh.zig new file mode 100644 index 0000000..d8b1fda --- /dev/null +++ b/src/ssh.zig @@ -0,0 +1,78 @@ +const std = @import("std"); +const ssh = @cImport(@cInclude("libssh/libssh.h")); + +pub const Client = struct { + session: ?ssh.ssh_session, + + pub fn deinit(self: *Client) void { + if (self.session) |session| { + if (ssh.ssh_is_connected(session) == 1) { + ssh.ssh_disconnect(session); + } + ssh.ssh_free(session); + self.session = null; + } + } + pub fn init(hostname: [*]const u8) !Client { + var client = Client{ + .session = ssh.ssh_new(), + }; + var port: i64 = 22; + + if (client.session) |session| { + _ = ssh.ssh_options_set(session, ssh.SSH_OPTIONS_HOST, hostname); + _ = ssh.ssh_options_set(session, ssh.SSH_OPTIONS_PORT, &port); + //var verbosity = ssh.SSH_LOG_PROTOCOL; + //_ = ssh.ssh_options_set(session, ssh.SSH_OPTIONS_LOG_VERBOSITY, &verbosity); + + var rc = ssh.ssh_connect(session); + if (rc != ssh.SSH_OK) { + std.log.info("Error connecting to localhost: {s}\n", .{ssh.ssh_get_error(session)}); + std.os.exit(2); + } + + //if (ssh.verify_knownhost(session) < 0) { + // std.log.info("knownhost verification error", .{}); + // std.os.exit(3); + //} + + if (ssh.ssh_userauth_publickey_auto(session, null, null) == ssh.SSH_AUTH_ERROR) { + std.log.info("Error authenticating with public key: {s}\n", .{ssh.ssh_get_error(session)}); + std.os.exit(3); + } + } else { + std.log.info("failed to initialise ssh session\n", .{}); + std.os.exit(1); + } + return client; + } + pub fn run(self: *Client, cmd: [*]const u8, writer: anytype) !void { + if (self.session) |session| { + if (ssh.ssh_channel_new(session)) |channel| { + defer ssh.ssh_channel_free(channel); + defer _ = ssh.ssh_channel_close(channel); + if (ssh.ssh_channel_open_session(channel) != ssh.SSH_OK) { + std.log.info("Error opening channel session: {s}\n", .{ssh.ssh_get_error(session)}); + std.os.exit(3); + } + + if (ssh.ssh_channel_request_exec(channel, cmd) != ssh.SSH_OK) { + std.log.info("Error executing command: {s}\n", .{ssh.ssh_get_error(session)}); + std.os.exit(4); + } + + var buffer: [256]u8 = undefined; + var nbytes = ssh.ssh_channel_read(channel, &buffer, buffer.len, 0); + while (nbytes > 0) : (nbytes = ssh.ssh_channel_read(channel, &buffer, buffer.len, 0)) { + var w = try writer.write(buffer[0..@intCast(usize, nbytes)]); + if (w != nbytes) { + std.os.exit(5); + } + } + _ = ssh.ssh_channel_send_eof(channel); + } else { + std.log.info("Error creating channel", .{}); + } + } + } +}; -- cgit v1.2.3